server/scores+favorites: merge duplicate code
This commit is contained in:
parent
f140ae6176
commit
cd15cdff7a
6 changed files with 41 additions and 47 deletions
|
@ -22,3 +22,4 @@ from szurubooru.db.session import (
|
|||
session,
|
||||
reset_query_count,
|
||||
get_query_count)
|
||||
import szurubooru.db.util
|
||||
|
|
32
server/szurubooru/db/util.py
Normal file
32
server/szurubooru/db/util.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
from sqlalchemy.inspection import inspect
|
||||
|
||||
def get_resource_info(entity):
|
||||
serializers = {
|
||||
'tag': lambda tag: tag.first_name,
|
||||
'tag_category': lambda category: category.name,
|
||||
'comment': lambda comment: comment.comment_id,
|
||||
'post': lambda post: post.post_id,
|
||||
}
|
||||
|
||||
resource_type = entity.__table__.name
|
||||
assert resource_type in serializers
|
||||
|
||||
primary_key = inspect(entity).identity
|
||||
assert primary_key is not None
|
||||
assert len(primary_key) == 1
|
||||
|
||||
resource_repr = serializers[resource_type](entity)
|
||||
assert resource_repr
|
||||
|
||||
resource_id = primary_key[0]
|
||||
assert resource_id
|
||||
|
||||
return (resource_type, resource_id, resource_repr)
|
||||
|
||||
def get_aux_entity(session, get_table_info, entity, user):
|
||||
table, get_column = get_table_info(entity)
|
||||
return session \
|
||||
.query(table) \
|
||||
.filter(get_column(table) == get_column(entity)) \
|
||||
.filter(table.user_id == user.user_id) \
|
||||
.one_or_none()
|
|
@ -1,21 +1,14 @@
|
|||
import datetime
|
||||
from szurubooru import db
|
||||
from szurubooru.func import util
|
||||
|
||||
def _get_table_info(entity):
|
||||
resource_type, _, _ = util.get_resource_info(entity)
|
||||
resource_type, _, _ = db.util.get_resource_info(entity)
|
||||
if resource_type == 'post':
|
||||
return db.PostFavorite, lambda table: table.post_id
|
||||
else:
|
||||
assert False
|
||||
assert False
|
||||
|
||||
def _get_fav_entity(entity, user):
|
||||
table, get_column = _get_table_info(entity)
|
||||
return db.session \
|
||||
.query(table) \
|
||||
.filter(get_column(table) == get_column(entity)) \
|
||||
.filter(table.user_id == user.user_id) \
|
||||
.one_or_none()
|
||||
return db.util.get_aux_entity(db.session, _get_table_info, entity, user)
|
||||
|
||||
def has_favorited(entity, user):
|
||||
return _get_fav_entity(entity, user) is not None
|
||||
|
|
|
@ -1,25 +1,18 @@
|
|||
import datetime
|
||||
from szurubooru import db, errors
|
||||
from szurubooru.func import util
|
||||
|
||||
class InvalidScoreError(errors.ValidationError): pass
|
||||
|
||||
def _get_table_info(entity):
|
||||
resource_type, _, _ = util.get_resource_info(entity)
|
||||
resource_type, _, _ = db.util.get_resource_info(entity)
|
||||
if resource_type == 'post':
|
||||
return db.PostScore, lambda table: table.post_id
|
||||
elif resource_type == 'comment':
|
||||
return db.CommentScore, lambda table: table.comment_id
|
||||
else:
|
||||
assert False
|
||||
assert False
|
||||
|
||||
def _get_score_entity(entity, user):
|
||||
table, get_column = _get_table_info(entity)
|
||||
return db.session \
|
||||
.query(table) \
|
||||
.filter(get_column(table) == get_column(entity)) \
|
||||
.filter(table.user_id == user.user_id) \
|
||||
.one_or_none()
|
||||
return db.util.get_aux_entity(db.session, _get_table_info, entity, user)
|
||||
|
||||
def delete_score(entity, user):
|
||||
score_entity = _get_score_entity(entity, user)
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import datetime
|
||||
from szurubooru import db
|
||||
from szurubooru.func import util
|
||||
|
||||
def get_tag_snapshot(tag):
|
||||
return {
|
||||
|
@ -49,7 +48,7 @@ def get_previous_snapshot(snapshot):
|
|||
.first()
|
||||
|
||||
def get_snapshots(entity):
|
||||
resource_type, resource_id, _ = util.get_resource_info(entity)
|
||||
resource_type, resource_id, _ = db.util.get_resource_info(entity)
|
||||
return db.session \
|
||||
.query(db.Snapshot) \
|
||||
.filter(db.Snapshot.resource_type == resource_type) \
|
||||
|
@ -81,7 +80,7 @@ def get_serialized_history(entity):
|
|||
return ret
|
||||
|
||||
def _save(operation, entity, auth_user):
|
||||
resource_type, resource_id, resource_repr = util.get_resource_info(entity)
|
||||
resource_type, resource_id, resource_repr = db.util.get_resource_info(entity)
|
||||
now = datetime.datetime.now()
|
||||
|
||||
snapshot = db.Snapshot()
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import datetime
|
||||
import hashlib
|
||||
import re
|
||||
from sqlalchemy.inspection import inspect
|
||||
from szurubooru.errors import ValidationError
|
||||
|
||||
def unalias_dict(input_dict):
|
||||
|
@ -23,29 +22,6 @@ def get_md5(source):
|
|||
def flip(source):
|
||||
return {v: k for k, v in source.items()}
|
||||
|
||||
def get_resource_info(entity):
|
||||
serializers = {
|
||||
'tag': lambda tag: tag.first_name,
|
||||
'tag_category': lambda category: category.name,
|
||||
'comment': lambda comment: comment.comment_id,
|
||||
'post': lambda post: post.post_id,
|
||||
}
|
||||
|
||||
resource_type = entity.__table__.name
|
||||
assert resource_type in serializers
|
||||
|
||||
primary_key = inspect(entity).identity
|
||||
assert primary_key is not None
|
||||
assert len(primary_key) == 1
|
||||
|
||||
resource_repr = serializers[resource_type](entity)
|
||||
assert resource_repr
|
||||
|
||||
resource_id = primary_key[0]
|
||||
assert resource_id
|
||||
|
||||
return (resource_type, resource_id, resource_repr)
|
||||
|
||||
def is_valid_email(email):
|
||||
''' Return whether given email address is valid or empty. '''
|
||||
return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email)
|
||||
|
|
Loading…
Reference in a new issue