server/favorites: favoriting sets score to 1
This commit is contained in:
parent
519f606a39
commit
16d4d3ca68
5 changed files with 23 additions and 10 deletions
|
@ -1,11 +1,14 @@
|
|||
import datetime
|
||||
from szurubooru import db
|
||||
from szurubooru import db, errors
|
||||
from szurubooru.func import scores
|
||||
|
||||
class InvalidFavoriteTargetError(errors.ValidationError): pass
|
||||
|
||||
def _get_table_info(entity):
|
||||
resource_type, _, _ = db.util.get_resource_info(entity)
|
||||
if resource_type == 'post':
|
||||
return db.PostFavorite, lambda table: table.post_id
|
||||
assert False
|
||||
raise InvalidFavoriteTargetError()
|
||||
|
||||
def _get_fav_entity(entity, user):
|
||||
return db.util.get_aux_entity(db.session, _get_table_info, entity, user)
|
||||
|
@ -19,6 +22,10 @@ def unset_favorite(entity, user):
|
|||
db.session.delete(fav_entity)
|
||||
|
||||
def set_favorite(entity, user):
|
||||
try:
|
||||
scores.set_score(entity, user, 1)
|
||||
except scores.InvalidScoreTargetError:
|
||||
pass
|
||||
fav_entity = _get_fav_entity(entity, user)
|
||||
if not fav_entity:
|
||||
table, get_column = _get_table_info(entity)
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import datetime
|
||||
from szurubooru import db, errors
|
||||
|
||||
class InvalidScoreError(errors.ValidationError): pass
|
||||
class InvalidScoreTargetError(errors.ValidationError): pass
|
||||
class InvalidScoreValueError(errors.ValidationError): pass
|
||||
|
||||
def _get_table_info(entity):
|
||||
resource_type, _, _ = db.util.get_resource_info(entity)
|
||||
|
@ -9,7 +10,7 @@ def _get_table_info(entity):
|
|||
return db.PostScore, lambda table: table.post_id
|
||||
elif resource_type == 'comment':
|
||||
return db.CommentScore, lambda table: table.comment_id
|
||||
assert False
|
||||
raise InvalidScoreTargetError()
|
||||
|
||||
def _get_score_entity(entity, user):
|
||||
return db.util.get_aux_entity(db.session, _get_table_info, entity, user)
|
||||
|
@ -31,7 +32,7 @@ def set_score(entity, user, score):
|
|||
delete_score(entity, user)
|
||||
return
|
||||
if score not in (-1, 1):
|
||||
raise InvalidScoreError(
|
||||
raise InvalidScoreValueError(
|
||||
'Score %r is invalid. Valid scores: %r.' % (score, (-1, 1)))
|
||||
score_entity = _get_score_entity(entity, user)
|
||||
if score_entity:
|
||||
|
|
|
@ -110,8 +110,8 @@ def test_ratings_from_multiple_users(test_ctx, fake_datetime):
|
|||
@pytest.mark.parametrize('input,expected_exception', [
|
||||
({'score': None}, errors.ValidationError),
|
||||
({'score': ''}, errors.ValidationError),
|
||||
({'score': -2}, scores.InvalidScoreError),
|
||||
({'score': 2}, scores.InvalidScoreError),
|
||||
({'score': -2}, scores.InvalidScoreValueError),
|
||||
({'score': 2}, scores.InvalidScoreValueError),
|
||||
({'score': [1]}, errors.ValidationError),
|
||||
])
|
||||
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):
|
||||
|
|
|
@ -23,10 +23,11 @@ def test_ctx(
|
|||
ret.api = api.PostFavoriteApi()
|
||||
return ret
|
||||
|
||||
def test_simple_rating(test_ctx, fake_datetime):
|
||||
def test_adding_to_favorites(test_ctx, fake_datetime):
|
||||
post = test_ctx.post_factory()
|
||||
db.session.add(post)
|
||||
db.session.commit()
|
||||
assert post.score == 0
|
||||
with fake_datetime('1997-12-01'):
|
||||
result = test_ctx.api.post(
|
||||
test_ctx.context_factory(user=test_ctx.user_factory()),
|
||||
|
@ -37,21 +38,25 @@ def test_simple_rating(test_ctx, fake_datetime):
|
|||
assert db.session.query(db.PostFavorite).count() == 1
|
||||
assert post is not None
|
||||
assert post.favorite_count == 1
|
||||
assert post.score == 1
|
||||
|
||||
def test_removing_from_favorites(test_ctx, fake_datetime):
|
||||
user = test_ctx.user_factory()
|
||||
post = test_ctx.post_factory()
|
||||
db.session.add(post)
|
||||
db.session.commit()
|
||||
assert post.score == 0
|
||||
with fake_datetime('1997-12-01'):
|
||||
result = test_ctx.api.post(
|
||||
test_ctx.context_factory(user=user),
|
||||
post.post_id)
|
||||
assert post.score == 1
|
||||
with fake_datetime('1997-12-02'):
|
||||
result = test_ctx.api.delete(
|
||||
test_ctx.context_factory(user=user),
|
||||
post.post_id)
|
||||
post = db.session.query(db.Post).one()
|
||||
assert post.score == 1
|
||||
assert db.session.query(db.PostFavorite).count() == 0
|
||||
assert post.favorite_count == 0
|
||||
|
||||
|
|
|
@ -106,8 +106,8 @@ def test_ratings_from_multiple_users(test_ctx, fake_datetime):
|
|||
@pytest.mark.parametrize('input,expected_exception', [
|
||||
({'score': None}, errors.ValidationError),
|
||||
({'score': ''}, errors.ValidationError),
|
||||
({'score': -2}, scores.InvalidScoreError),
|
||||
({'score': 2}, scores.InvalidScoreError),
|
||||
({'score': -2}, scores.InvalidScoreValueError),
|
||||
({'score': 2}, scores.InvalidScoreValueError),
|
||||
({'score': [1]}, errors.ValidationError),
|
||||
])
|
||||
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):
|
||||
|
|
Loading…
Reference in a new issue