server/posts: add sketch of post table

This commit is contained in:
rr- 2016-04-17 12:55:07 +02:00
parent 9ac70dbed4
commit bc15fb6675
9 changed files with 226 additions and 14 deletions

View file

@ -1,3 +1,4 @@
from szurubooru.db.base import Base from szurubooru.db.base import Base
from szurubooru.db.user import User from szurubooru.db.user import User
from szurubooru.db.tag import Tag, TagName, TagSuggestion, TagImplication from szurubooru.db.tag import Tag, TagName, TagSuggestion, TagImplication
from szurubooru.db.post import Post, PostTag, PostRelation

View file

@ -0,0 +1,82 @@
from sqlalchemy import Column, Integer, DateTime, String, ForeignKey
from sqlalchemy.orm import relationship, column_property
from sqlalchemy.sql.expression import func, select
from szurubooru.db.base import Base
class PostRelation(Base):
__tablename__ = 'post_relation'
parent_id = Column('parent_id', Integer, ForeignKey('post.id'), primary_key=True)
child_id = Column('child_id', Integer, ForeignKey('post.id'), primary_key=True)
def __init__(self, parent_id, child_id):
self.parent_id = parent_id
self.child_id = child_id
class PostTag(Base):
__tablename__ = 'post_tag'
post_id = Column('post_id', Integer, ForeignKey('post.id'), primary_key=True)
tag_id = Column('tag_id', Integer, ForeignKey('tag.id'), primary_key=True)
def __init__(self, tag_id, post_id):
self.tag_id = tag_id
self.post_id = post_id
class Post(Base):
__tablename__ = 'post'
SAFETY_SAFE = 'safe'
SAFETY_SKETCHY = 'sketchy'
SAFETY_UNSAFE = 'unsafe'
TYPE_IMAGE = 'anim'
TYPE_ANIMATION = 'anim'
TYPE_FLASH = 'flash'
TYPE_VIDEO = 'video'
TYPE_YOUTUBE = 'youtube'
FLAG_LOOP_VIDEO = 1
post_id = Column('id', Integer, primary_key=True)
user_id = Column('user_id', Integer, ForeignKey('user.id'))
creation_time = Column('creation_time', DateTime, nullable=False)
last_edit_time = Column('last_edit_time', DateTime)
safety = Column('safety', String(32), nullable=False)
type = Column('type', String(32), nullable=False)
checksum = Column('checksum', String(64), nullable=False)
source = Column('source', String(200))
file_size = Column('file_size', Integer)
image_width = Column('image_width', Integer)
image_height = Column('image_height', Integer)
flags = Column('flags', Integer, nullable=False, default=0)
user = relationship('User')
tags = relationship('Tag', backref='posts', secondary='post_tag')
relations = relationship(
'Post',
secondary='post_relation',
primaryjoin=post_id == PostRelation.parent_id,
secondaryjoin=post_id == PostRelation.child_id)
tag_count = column_property(
select(
[func.count('1')],
PostTag.post_id == post_id
) \
.correlate('Post') \
.label('tag_count')
)
# TODO: wire these
fav_count = Column('auto_fav_count', Integer, nullable=False, default=0)
score = Column('auto_score', Integer, nullable=False, default=0)
feature_count = Column('auto_feature_count', Integer, nullable=False, default=0)
comment_count = Column('auto_comment_count', Integer, nullable=False, default=0)
note_count = Column('auto_note_count', Integer, nullable=False, default=0)
last_fav_time = Column(
'auto_fav_time', Integer, nullable=False, default=0)
last_feature_time = Column(
'auto_feature_time', Integer, nullable=False, default=0)
last_comment_edit_time = Column(
'auto_comment_creation_time', Integer, nullable=False, default=0)
last_comment_creation_time = Column(
'auto_comment_edit_time', Integer, nullable=False, default=0)

View file

@ -1,6 +1,8 @@
from sqlalchemy import Column, Integer, DateTime, String, ForeignKey from sqlalchemy import Column, Integer, DateTime, String, ForeignKey
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship, column_property
from sqlalchemy.sql.expression import func, select
from szurubooru.db.base import Base from szurubooru.db.base import Base
from szurubooru.db.post import PostTag
class TagSuggestion(Base): class TagSuggestion(Base):
__tablename__ = 'tag_suggestion' __tablename__ = 'tag_suggestion'
@ -52,5 +54,11 @@ class Tag(Base):
primaryjoin=tag_id == TagImplication.parent_id, primaryjoin=tag_id == TagImplication.parent_id,
secondaryjoin=tag_id == TagImplication.child_id) secondaryjoin=tag_id == TagImplication.child_id)
# TODO: wire this post_count = column_property(
post_count = Column('auto_post_count', Integer, nullable=False, default=0) select(
[func.count('Post.post_id')],
PostTag.tag_id == tag_id
) \
.correlate('Tag') \
.label('post_count')
)

View file

@ -20,7 +20,6 @@ def upgrade():
sa.Column('category', sa.String(length=32), nullable=False), sa.Column('category', sa.String(length=32), nullable=False),
sa.Column('creation_time', sa.DateTime(), nullable=False), sa.Column('creation_time', sa.DateTime(), nullable=False),
sa.Column('last_edit_time', sa.DateTime(), nullable=True), sa.Column('last_edit_time', sa.DateTime(), nullable=True),
sa.Column('auto_post_count', sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint('id')) sa.PrimaryKeyConstraint('id'))
op.create_table( op.create_table(

View file

@ -4,11 +4,7 @@ from szurubooru import db
from szurubooru.search.base_search_config import BaseSearchConfig from szurubooru.search.base_search_config import BaseSearchConfig
class TagSearchConfig(BaseSearchConfig): class TagSearchConfig(BaseSearchConfig):
def __init__(self):
self._session = None
def create_query(self, session): def create_query(self, session):
self._session = session
return session.query(db.Tag) return session.query(db.Tag)
def finalize_query(self, query): def finalize_query(self, query):
@ -65,7 +61,7 @@ class TagSearchConfig(BaseSearchConfig):
str_filter = self._create_str_filter(db.TagName.name) str_filter = self._create_str_filter(db.TagName.name)
return query.filter( return query.filter(
db.Tag.tag_id.in_( db.Tag.tag_id.in_(
str_filter(self._session.query(db.TagName.tag_id), criterion))) str_filter(query.session.query(db.TagName.tag_id), criterion)))
def _suggestion_count_filter(self, query, criterion): def _suggestion_count_filter(self, query, criterion):
return query.filter( return query.filter(

View file

@ -48,10 +48,11 @@ def test_removing_tags_without_privileges(test_ctx):
'tag') 'tag')
assert test_ctx.session.query(db.Tag).count() == 1 assert test_ctx.session.query(db.Tag).count() == 1
def test_removing_tags_with_usages(test_ctx): def test_removing_tags_with_usages(test_ctx, post_factory):
tag = test_ctx.tag_factory(names=['tag']) tag = test_ctx.tag_factory(names=['tag'])
tag.post_count = 5 post = post_factory()
test_ctx.session.add(tag) post.tags.append(tag)
test_ctx.session.add_all([tag, post])
test_ctx.session.commit() test_ctx.session.commit()
with pytest.raises(tags.TagIsInUseError): with pytest.raises(tags.TagIsInUseError):
test_ctx.api.delete( test_ctx.api.delete(

View file

@ -71,3 +71,18 @@ def tag_factory():
tag.creation_time = datetime.datetime(1996, 1, 1) tag.creation_time = datetime.datetime(1996, 1, 1)
return tag return tag
return factory return factory
@pytest.fixture
def post_factory():
def factory(
safety=db.Post.SAFETY_SAFE,
type=db.Post.TYPE_IMAGE,
checksum='...'):
post = db.Post()
post.safety = safety
post.type = type
post.checksum = checksum
post.flags = 0
post.creation_time = datetime.datetime(1996, 1, 1)
return post
return factory

View file

@ -0,0 +1,91 @@
from datetime import datetime
from szurubooru import db
def test_saving_post(session, post_factory, user_factory, tag_factory):
user = user_factory()
tag1 = tag_factory()
tag2 = tag_factory()
related_post1 = post_factory()
related_post2 = post_factory()
post = db.Post()
post.safety = 'safety'
post.type = 'type'
post.checksum = 'deadbeef'
post.creation_time = datetime(1997, 1, 1)
post.last_edit_time = datetime(1998, 1, 1)
session.add_all([user, tag1, tag2, related_post1, related_post2, post])
post.user = user
post.tags.append(tag1)
post.tags.append(tag2)
post.relations.append(related_post1)
post.relations.append(related_post2)
session.commit()
post = session.query(db.Post).filter(db.Post.post_id == post.post_id).one()
assert not session.dirty
assert post.user.user_id is not None
assert post.safety == 'safety'
assert post.type == 'type'
assert post.checksum == 'deadbeef'
assert post.creation_time == datetime(1997, 1, 1)
assert post.last_edit_time == datetime(1998, 1, 1)
assert len(post.relations) == 2
def test_cascade_deletions(session, post_factory, user_factory, tag_factory):
user = user_factory()
tag1 = tag_factory()
tag2 = tag_factory()
related_post1 = post_factory()
related_post2 = post_factory()
post = post_factory()
session.add_all([user, tag1, tag2, post, related_post1, related_post2])
session.flush()
post.user = user
post.tags.append(tag1)
post.tags.append(tag2)
post.relations.append(related_post1)
post.relations.append(related_post2)
session.flush()
assert not session.dirty
assert post.user.user_id is not None
assert len(post.relations) == 2
assert session.query(db.User).count() == 1
assert session.query(db.Tag).count() == 2
assert session.query(db.Post).count() == 3
assert session.query(db.PostTag).count() == 2
assert session.query(db.PostRelation).count() == 2
session.delete(post)
session.commit()
assert not session.dirty
assert session.query(db.User).count() == 1
assert session.query(db.Tag).count() == 2
assert session.query(db.Post).count() == 2
assert session.query(db.PostTag).count() == 0
assert session.query(db.PostRelation).count() == 0
def test_tracking_tag_count(session, post_factory, tag_factory):
post = post_factory()
tag1 = tag_factory()
tag2 = tag_factory()
session.add_all([tag1, tag2, post])
session.flush()
post.tags.append(tag1)
post.tags.append(tag2)
session.commit()
assert len(post.tags) == 2
assert post.tag_count == 2
session.delete(tag1)
session.commit()
session.refresh(post)
assert len(post.tags) == 1
assert post.tag_count == 1
session.delete(tag2)
session.commit()
session.refresh(post)
assert len(post.tags) == 0
assert post.tag_count == 0

View file

@ -13,7 +13,6 @@ def test_saving_tag(session, tag_factory):
tag.category = 'category' tag.category = 'category'
tag.creation_time = datetime(1997, 1, 1) tag.creation_time = datetime(1997, 1, 1)
tag.last_edit_time = datetime(1998, 1, 1) tag.last_edit_time = datetime(1998, 1, 1)
tag.post_count = 1
session.add_all([ session.add_all([
tag, suggested_tag1, suggested_tag2, implied_tag1, implied_tag2]) tag, suggested_tag1, suggested_tag2, implied_tag1, implied_tag2])
session.commit() session.commit()
@ -37,7 +36,6 @@ def test_saving_tag(session, tag_factory):
assert tag.category == 'category' assert tag.category == 'category'
assert tag.creation_time == datetime(1997, 1, 1) assert tag.creation_time == datetime(1997, 1, 1)
assert tag.last_edit_time == datetime(1998, 1, 1) assert tag.last_edit_time == datetime(1998, 1, 1)
assert tag.post_count == 1
assert [relation.names[0].name for relation in tag.suggestions] \ assert [relation.names[0].name for relation in tag.suggestions] \
== ['suggested1', 'suggested2'] == ['suggested1', 'suggested2']
assert [relation.names[0].name for relation in tag.implications] \ assert [relation.names[0].name for relation in tag.implications] \
@ -77,3 +75,24 @@ def test_cascade_deletions(session, tag_factory):
assert session.query(db.TagName).count() == 4 assert session.query(db.TagName).count() == 4
assert session.query(db.TagImplication).count() == 0 assert session.query(db.TagImplication).count() == 0
assert session.query(db.TagSuggestion).count() == 0 assert session.query(db.TagSuggestion).count() == 0
def test_tracking_post_count(session, post_factory, tag_factory):
tag = tag_factory()
post1 = post_factory()
post2 = post_factory()
session.add_all([tag, post1, post2])
session.flush()
post1.tags.append(tag)
post2.tags.append(tag)
session.commit()
assert len(post1.tags) == 1
assert len(post2.tags) == 1
assert tag.post_count == 2
session.delete(post1)
session.commit()
session.refresh(tag)
assert tag.post_count == 1
session.delete(post2)
session.commit()
session.refresh(tag)
assert tag.post_count == 0