From 4117f63375c681a800068b56beb201699807801a Mon Sep 17 00:00:00 2001 From: Shyam Sunder Date: Mon, 22 Apr 2019 18:44:02 -0400 Subject: [PATCH] server/model/posts: Make post flags a hybrid attribute in model This should (hopefully) fix #250 and #252 --- server/szurubooru/func/posts.py | 8 ++++---- server/szurubooru/func/snapshots.py | 2 +- server/szurubooru/model/post.py | 12 +++++++++++- .../szurubooru/search/configs/post_search_config.py | 2 +- 4 files changed, 17 insertions(+), 7 deletions(-) diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index 9a5307b..532bd23 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -222,7 +222,7 @@ class PostSerializer(serialization.BaseSerializer): return get_post_thumbnail_url(self.post) def serialize_flags(self) -> Any: - return [x for x in self.post.flags.split(',') if x] + return self.post.flags def serialize_tags(self) -> Any: return [ @@ -356,7 +356,7 @@ def create_post( post.safety = model.Post.SAFETY_SAFE post.user = user post.creation_time = datetime.utcnow() - post.flags = '' + post.flags = [] post.type = '' post.checksum = '' @@ -477,7 +477,7 @@ def test_sound(post: model.Post, content: bytes) -> None: assert content if mime.is_video(mime.get_mime_type(content)): if images.Image(content).check_for_sound(): - flags = [x for x in post.flags.split(',') if x] + flags = post.flags if model.Post.FLAG_SOUND not in flags: flags.append(model.Post.FLAG_SOUND) update_post_flags(post, flags) @@ -637,7 +637,7 @@ def update_post_flags(post: model.Post, flags: List[str]) -> None: raise InvalidPostFlagError( 'Flag must be one of %r.' % list(FLAG_MAP.values())) target_flags.append(flag) - post.flags = ','.join(target_flags) + post.flags = target_flags def feature_post(post: model.Post, user: Optional[model.User]) -> None: diff --git a/server/szurubooru/func/snapshots.py b/server/szurubooru/func/snapshots.py index b6a13e0..240c3bc 100644 --- a/server/szurubooru/func/snapshots.py +++ b/server/szurubooru/func/snapshots.py @@ -29,7 +29,7 @@ def get_post_snapshot(post: model.Post) -> Dict[str, Any]: 'source': post.source, 'safety': post.safety, 'checksum': post.checksum, - 'flags': sorted(post.flags.split(',')), + 'flags': post.flags, 'featured': post.is_featured, 'tags': sorted([tag.first_name for tag in post.tags]), 'relations': sorted([rel.post_id for rel in post.relations]), diff --git a/server/szurubooru/model/post.py b/server/szurubooru/model/post.py index c0e3b13..bba4c54 100644 --- a/server/szurubooru/model/post.py +++ b/server/szurubooru/model/post.py @@ -1,6 +1,8 @@ +from typing import List import sqlalchemy as sa from szurubooru.model.base import Base from szurubooru.model.comment import Comment +from sqlalchemy.ext.hybrid import hybrid_property class PostFeature(Base): @@ -169,7 +171,7 @@ class Post(Base): last_edit_time = sa.Column('last_edit_time', sa.DateTime) safety = sa.Column('safety', sa.Unicode(32), nullable=False) source = sa.Column('source', sa.Unicode(200)) - flags = sa.Column('flags', sa.Unicode(200), default='') + flags_string = sa.Column('flags', sa.Unicode(200), default='') # content description type = sa.Column('type', sa.Unicode(32), nullable=False) @@ -219,6 +221,14 @@ class Post(Base): .first()) return featured_post and featured_post.post_id == self.post_id + @hybrid_property + def flags(self) -> List[str]: + return sorted([x for x in self.flags_string.split(',') if x]) + + @flags.setter + def flags(self, data: List[str]) -> None: + self.flags_string = ','.join([x for x in data if x]) + score = sa.orm.column_property( sa.sql.expression.select( [sa.sql.expression.func.coalesce( diff --git a/server/szurubooru/search/configs/post_search_config.py b/server/szurubooru/search/configs/post_search_config.py index ec9bcae..9c8de2e 100644 --- a/server/szurubooru/search/configs/post_search_config.py +++ b/server/szurubooru/search/configs/post_search_config.py @@ -348,7 +348,7 @@ class PostSearchConfig(BaseSearchConfig): ( ['flag'], search_util.create_str_filter( - model.Post.flags, _flag_transformer) + model.Post.flags_string, _flag_transformer) ), ])