server: lint
This commit is contained in:
parent
fea9a94945
commit
4bc58a3c95
42 changed files with 192 additions and 169 deletions
|
@ -27,7 +27,7 @@ def _serialize(
|
|||
|
||||
@rest.routes.get('/comments/?')
|
||||
def get_comments(
|
||||
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
|
||||
auth.verify_privilege(ctx.user, 'comments:list')
|
||||
return _search_executor.execute_and_serialize(
|
||||
ctx, lambda comment: _serialize(ctx, comment))
|
||||
|
@ -35,7 +35,7 @@ def get_comments(
|
|||
|
||||
@rest.routes.post('/comments/?')
|
||||
def create_comment(
|
||||
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
|
||||
auth.verify_privilege(ctx.user, 'comments:create')
|
||||
text = ctx.get_param_as_string('text')
|
||||
post_id = ctx.get_param_as_int('postId')
|
||||
|
|
|
@ -28,7 +28,7 @@ def _get_disk_usage() -> int:
|
|||
|
||||
@rest.routes.get('/info/?')
|
||||
def get_info(
|
||||
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
|
||||
post_feature = posts.try_get_current_post_feature()
|
||||
return {
|
||||
'postCount': posts.get_post_count(),
|
||||
|
|
|
@ -5,10 +5,10 @@ from hashlib import md5
|
|||
|
||||
|
||||
MAIL_SUBJECT = 'Password reset for {name}'
|
||||
MAIL_BODY = \
|
||||
'You (or someone else) requested to reset your password on {name}.\n' \
|
||||
'If you wish to proceed, click this link: {url}\n' \
|
||||
'Otherwise, please ignore this email.'
|
||||
MAIL_BODY = (
|
||||
'You (or someone else) requested to reset your password on {name}.\n'
|
||||
'If you wish to proceed, click this link: {url}\n'
|
||||
'Otherwise, please ignore this email.')
|
||||
|
||||
|
||||
@rest.routes.get('/password-reset/(?P<user_name>[^/]+)/?')
|
||||
|
|
|
@ -31,7 +31,7 @@ def _serialize_post(
|
|||
|
||||
@rest.routes.get('/posts/?')
|
||||
def get_posts(
|
||||
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
|
||||
auth.verify_privilege(ctx.user, 'posts:list')
|
||||
_search_executor_config.user = ctx.user
|
||||
return _search_executor.execute_and_serialize(
|
||||
|
@ -40,7 +40,7 @@ def get_posts(
|
|||
|
||||
@rest.routes.post('/posts/?')
|
||||
def create_post(
|
||||
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
|
||||
anonymous = ctx.get_param_as_bool('anonymous', default=False)
|
||||
if anonymous:
|
||||
auth.verify_privilege(ctx.user, 'posts:create:anonymous')
|
||||
|
@ -144,7 +144,7 @@ def delete_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
|||
|
||||
@rest.routes.post('/post-merge/?')
|
||||
def merge_posts(
|
||||
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
|
||||
source_post_id = ctx.get_param_as_int('remove')
|
||||
target_post_id = ctx.get_param_as_int('mergeTo')
|
||||
source_post = posts.get_post_by_id(source_post_id)
|
||||
|
@ -162,14 +162,14 @@ def merge_posts(
|
|||
|
||||
@rest.routes.get('/featured-post/?')
|
||||
def get_featured_post(
|
||||
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
|
||||
post = posts.try_get_featured_post()
|
||||
return _serialize_post(ctx, post)
|
||||
|
||||
|
||||
@rest.routes.post('/featured-post/?')
|
||||
def set_featured_post(
|
||||
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
|
||||
auth.verify_privilege(ctx.user, 'posts:feature')
|
||||
post_id = ctx.get_param_as_int('id')
|
||||
post = posts.get_post_by_id(post_id)
|
||||
|
@ -235,7 +235,7 @@ def get_posts_around(
|
|||
|
||||
@rest.routes.post('/posts/reverse-search/?')
|
||||
def get_posts_by_image(
|
||||
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
|
||||
auth.verify_privilege(ctx.user, 'posts:reverse_search')
|
||||
content = ctx.get_file('content')
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ _search_executor = search.Executor(search.configs.SnapshotSearchConfig())
|
|||
|
||||
@rest.routes.get('/snapshots/?')
|
||||
def get_snapshots(
|
||||
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
|
||||
auth.verify_privilege(ctx.user, 'snapshots:list')
|
||||
return _search_executor.execute_and_serialize(
|
||||
ctx, lambda snapshot: snapshots.serialize_snapshot(snapshot, ctx.user))
|
||||
|
|
|
@ -28,14 +28,15 @@ def _create_if_needed(tag_names: List[str], user: model.User) -> None:
|
|||
|
||||
|
||||
@rest.routes.get('/tags/?')
|
||||
def get_tags(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||
def get_tags(ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
|
||||
auth.verify_privilege(ctx.user, 'tags:list')
|
||||
return _search_executor.execute_and_serialize(
|
||||
ctx, lambda tag: _serialize(ctx, tag))
|
||||
|
||||
|
||||
@rest.routes.post('/tags/?')
|
||||
def create_tag(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||
def create_tag(
|
||||
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
|
||||
auth.verify_privilege(ctx.user, 'tags:create')
|
||||
|
||||
names = ctx.get_param_as_string_list('names')
|
||||
|
@ -112,7 +113,7 @@ def delete_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
|||
|
||||
@rest.routes.post('/tag-merge/?')
|
||||
def merge_tags(
|
||||
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
|
||||
source_tag_name = ctx.get_param_as_string('remove')
|
||||
target_tag_name = ctx.get_param_as_string('mergeTo')
|
||||
source_tag = tags.get_tag_by_name(source_tag_name)
|
||||
|
|
|
@ -12,7 +12,7 @@ def _serialize(
|
|||
|
||||
@rest.routes.get('/tag-categories/?')
|
||||
def get_tag_categories(
|
||||
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
|
||||
auth.verify_privilege(ctx.user, 'tag_categories:list')
|
||||
categories = tag_categories.get_all_categories()
|
||||
return {
|
||||
|
@ -22,7 +22,7 @@ def get_tag_categories(
|
|||
|
||||
@rest.routes.post('/tag-categories/?')
|
||||
def create_tag_category(
|
||||
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
|
||||
auth.verify_privilege(ctx.user, 'tag_categories:create')
|
||||
name = ctx.get_param_as_string('name')
|
||||
color = ctx.get_param_as_string('color')
|
||||
|
|
|
@ -5,7 +5,7 @@ from szurubooru.func import auth, file_uploads
|
|||
|
||||
@rest.routes.post('/uploads/?')
|
||||
def create_temporary_file(
|
||||
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
|
||||
auth.verify_privilege(ctx.user, 'uploads:create')
|
||||
content = ctx.get_file('content', allow_tokens=False)
|
||||
token = file_uploads.save(content)
|
||||
|
|
|
@ -16,7 +16,8 @@ def _serialize(
|
|||
|
||||
|
||||
@rest.routes.get('/users/?')
|
||||
def get_users(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||
def get_users(
|
||||
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
|
||||
auth.verify_privilege(ctx.user, 'users:list')
|
||||
return _search_executor.execute_and_serialize(
|
||||
ctx, lambda user: _serialize(ctx, user))
|
||||
|
@ -24,7 +25,7 @@ def get_users(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
|||
|
||||
@rest.routes.post('/users/?')
|
||||
def create_user(
|
||||
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
|
||||
auth.verify_privilege(ctx.user, 'users:create')
|
||||
name = ctx.get_param_as_string('name')
|
||||
password = ctx.get_param_as_string('password')
|
||||
|
|
|
@ -4,8 +4,8 @@ from typing import Dict
|
|||
class BaseError(RuntimeError):
|
||||
def __init__(
|
||||
self,
|
||||
message: str='Unknown error',
|
||||
extra_fields: Dict[str, str]=None) -> None:
|
||||
message: str = 'Unknown error',
|
||||
extra_fields: Dict[str, str] = None) -> None:
|
||||
super().__init__(message)
|
||||
self.extra_fields = extra_fields
|
||||
|
||||
|
|
|
@ -21,9 +21,9 @@ class LruCache:
|
|||
i
|
||||
for i, v in enumerate(self.item_list)
|
||||
if v.key == item.key)
|
||||
self.item_list[:] \
|
||||
= self.item_list[:item_index] \
|
||||
+ self.item_list[item_index + 1:]
|
||||
self.item_list[:] = (
|
||||
self.item_list[:item_index] +
|
||||
self.item_list[item_index + 1:])
|
||||
self.item_list.insert(0, item)
|
||||
else:
|
||||
if len(self.item_list) > self.length:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from datetime import datetime
|
||||
from typing import Any, Optional, List, Dict, Callable
|
||||
from szurubooru import db, model, errors, rest
|
||||
from szurubooru.func import users, scores, util, serialization
|
||||
from szurubooru.func import users, scores, serialization
|
||||
|
||||
|
||||
class InvalidCommentIdError(errors.ValidationError):
|
||||
|
@ -65,7 +65,7 @@ class CommentSerializer(serialization.BaseSerializer):
|
|||
def serialize_comment(
|
||||
comment: model.Comment,
|
||||
auth_user: model.User,
|
||||
options: List[str]=[]) -> rest.Response:
|
||||
options: List[str] = []) -> rest.Response:
|
||||
if comment is None:
|
||||
return None
|
||||
return CommentSerializer(comment, auth_user).serialize(options)
|
||||
|
@ -73,10 +73,11 @@ def serialize_comment(
|
|||
|
||||
def try_get_comment_by_id(comment_id: int) -> Optional[model.Comment]:
|
||||
comment_id = int(comment_id)
|
||||
return db.session \
|
||||
.query(model.Comment) \
|
||||
.filter(model.Comment.comment_id == comment_id) \
|
||||
.one_or_none()
|
||||
return (
|
||||
db.session
|
||||
.query(model.Comment)
|
||||
.filter(model.Comment.comment_id == comment_id)
|
||||
.one_or_none())
|
||||
|
||||
|
||||
def get_comment_by_id(comment_id: int) -> model.Comment:
|
||||
|
|
|
@ -99,7 +99,7 @@ def _normalize_and_threshold(
|
|||
def _compute_grid_points(
|
||||
image: NpMatrix,
|
||||
n: float,
|
||||
window: Window=None) -> Tuple[NpMatrix, NpMatrix]:
|
||||
window: Window = None) -> Tuple[NpMatrix, NpMatrix]:
|
||||
if window is None:
|
||||
window = ((0, image.shape[0]), (0, image.shape[1]))
|
||||
x_coords = np.linspace(window[0][0], window[0][1], n + 2, dtype=int)[1:-1]
|
||||
|
@ -219,7 +219,7 @@ def _max_contrast(array: NpMatrix) -> None:
|
|||
def _normalized_distance(
|
||||
target_array: NpMatrix,
|
||||
vec: NpMatrix,
|
||||
nan_value: float=1.0) -> List[float]:
|
||||
nan_value: float = 1.0) -> List[float]:
|
||||
target_array = target_array.astype(int)
|
||||
vec = vec.astype(int)
|
||||
topvec = np.linalg.norm(vec - target_array, axis=1)
|
||||
|
|
|
@ -11,8 +11,8 @@ from szurubooru.func import mime, util
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_SCALE_FIT_FMT = \
|
||||
r'scale=iw*max({width}/iw\,{height}/ih):ih*max({width}/iw\,{height}/ih)'
|
||||
_SCALE_FIT_FMT = (
|
||||
r'scale=iw*max({width}/iw\,{height}/ih):ih*max({width}/iw\,{height}/ih)')
|
||||
|
||||
|
||||
class Image:
|
||||
|
@ -77,7 +77,7 @@ class Image:
|
|||
'-',
|
||||
])
|
||||
|
||||
def _execute(self, cli: List[str], program: str='ffmpeg') -> bytes:
|
||||
def _execute(self, cli: List[str], program: str = 'ffmpeg') -> bytes:
|
||||
extension = mime.get_extension(mime.get_mime_type(self.content))
|
||||
assert extension
|
||||
with util.create_temp_file(suffix='.' + extension) as handle:
|
||||
|
|
|
@ -7,10 +7,10 @@ from szurubooru.func import (
|
|||
mime, images, files, image_hash, serialization)
|
||||
|
||||
|
||||
EMPTY_PIXEL = \
|
||||
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \
|
||||
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \
|
||||
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'
|
||||
EMPTY_PIXEL = (
|
||||
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00'
|
||||
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00'
|
||||
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b')
|
||||
|
||||
|
||||
class PostNotFoundError(errors.NotFoundError):
|
||||
|
@ -283,7 +283,7 @@ class PostSerializer(serialization.BaseSerializer):
|
|||
def serialize_post(
|
||||
post: Optional[model.Post],
|
||||
auth_user: model.User,
|
||||
options: List[str]=[]) -> Optional[rest.Response]:
|
||||
options: List[str] = []) -> Optional[rest.Response]:
|
||||
if not post:
|
||||
return None
|
||||
return PostSerializer(post, auth_user).serialize(options)
|
||||
|
@ -300,10 +300,11 @@ def get_post_count() -> int:
|
|||
|
||||
|
||||
def try_get_post_by_id(post_id: int) -> Optional[model.Post]:
|
||||
return db.session \
|
||||
.query(model.Post) \
|
||||
.filter(model.Post.post_id == post_id) \
|
||||
.one_or_none()
|
||||
return (
|
||||
db.session
|
||||
.query(model.Post)
|
||||
.filter(model.Post.post_id == post_id)
|
||||
.one_or_none())
|
||||
|
||||
|
||||
def get_post_by_id(post_id: int) -> model.Post:
|
||||
|
@ -314,10 +315,11 @@ def get_post_by_id(post_id: int) -> model.Post:
|
|||
|
||||
|
||||
def try_get_current_post_feature() -> Optional[model.PostFeature]:
|
||||
return db.session \
|
||||
.query(model.PostFeature) \
|
||||
.order_by(model.PostFeature.time.desc()) \
|
||||
.first()
|
||||
return (
|
||||
db.session
|
||||
.query(model.PostFeature)
|
||||
.order_by(model.PostFeature.time.desc())
|
||||
.first())
|
||||
|
||||
|
||||
def try_get_featured_post() -> Optional[model.Post]:
|
||||
|
@ -426,11 +428,12 @@ def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
|
|||
'Unhandled file type: %r' % post.mime_type)
|
||||
|
||||
post.checksum = util.get_sha1(content)
|
||||
other_post = db.session \
|
||||
.query(model.Post) \
|
||||
.filter(model.Post.checksum == post.checksum) \
|
||||
.filter(model.Post.post_id != post.post_id) \
|
||||
.one_or_none()
|
||||
other_post = (
|
||||
db.session
|
||||
.query(model.Post)
|
||||
.filter(model.Post.checksum == post.checksum)
|
||||
.filter(model.Post.post_id != post.post_id)
|
||||
.one_or_none())
|
||||
if other_post \
|
||||
and other_post.post_id \
|
||||
and other_post.post_id != post.post_id:
|
||||
|
@ -452,7 +455,7 @@ def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
|
|||
|
||||
|
||||
def update_post_thumbnail(
|
||||
post: model.Post, content: Optional[bytes]=None) -> None:
|
||||
post: model.Post, content: Optional[bytes] = None) -> None:
|
||||
assert post
|
||||
setattr(post, '__thumbnail', content)
|
||||
|
||||
|
@ -492,10 +495,11 @@ def update_post_relations(post: model.Post, new_post_ids: List[int]) -> None:
|
|||
old_posts = post.relations
|
||||
old_post_ids = [int(p.post_id) for p in old_posts]
|
||||
if new_post_ids:
|
||||
new_posts = db.session \
|
||||
.query(model.Post) \
|
||||
.filter(model.Post.post_id.in_(new_post_ids)) \
|
||||
.all()
|
||||
new_posts = (
|
||||
db.session
|
||||
.query(model.Post)
|
||||
.filter(model.Post.post_id.in_(new_post_ids))
|
||||
.all())
|
||||
else:
|
||||
new_posts = []
|
||||
if len(new_posts) != len(new_post_ids):
|
||||
|
@ -673,10 +677,11 @@ def merge_posts(
|
|||
|
||||
def search_by_image_exact(image_content: bytes) -> Optional[model.Post]:
|
||||
checksum = util.get_sha1(image_content)
|
||||
return db.session \
|
||||
.query(model.Post) \
|
||||
.filter(model.Post.checksum == checksum) \
|
||||
.one_or_none()
|
||||
return (
|
||||
db.session
|
||||
.query(model.Post)
|
||||
.filter(model.Post.checksum == checksum)
|
||||
.one_or_none())
|
||||
|
||||
|
||||
def search_by_image(image_content: bytes) -> List[PostLookalike]:
|
||||
|
|
|
@ -39,11 +39,12 @@ def get_score(entity: model.Base, user: model.User) -> int:
|
|||
assert entity
|
||||
assert user
|
||||
table, get_column = _get_table_info(entity)
|
||||
row = db.session \
|
||||
.query(table.score) \
|
||||
.filter(get_column(table) == get_column(entity)) \
|
||||
.filter(table.user_id == user.user_id) \
|
||||
.one_or_none()
|
||||
row = (
|
||||
db.session
|
||||
.query(table.score)
|
||||
.filter(get_column(table) == get_column(entity))
|
||||
.filter(table.user_id == user.user_id)
|
||||
.one_or_none())
|
||||
return row[0] if row else 0
|
||||
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from typing import Any, Optional, List, Dict, Callable
|
||||
from szurubooru import db, model, rest, errors
|
||||
from typing import Any, List, Dict, Callable
|
||||
from szurubooru import model, rest, errors
|
||||
|
||||
|
||||
def get_serialization_options(ctx: rest.Context) -> List[str]:
|
||||
|
|
|
@ -66,7 +66,7 @@ class TagCategorySerializer(serialization.BaseSerializer):
|
|||
|
||||
def serialize_category(
|
||||
category: Optional[model.TagCategory],
|
||||
options: List[str]=[]) -> Optional[rest.Response]:
|
||||
options: List[str] = []) -> Optional[rest.Response]:
|
||||
if not category:
|
||||
return None
|
||||
return TagCategorySerializer(category).serialize(options)
|
||||
|
@ -113,16 +113,17 @@ def update_category_color(category: model.TagCategory, color: str) -> None:
|
|||
|
||||
|
||||
def try_get_category_by_name(
|
||||
name: str, lock: bool=False) -> Optional[model.TagCategory]:
|
||||
query = db.session \
|
||||
.query(model.TagCategory) \
|
||||
.filter(sa.func.lower(model.TagCategory.name) == name.lower())
|
||||
name: str, lock: bool = False) -> Optional[model.TagCategory]:
|
||||
query = (
|
||||
db.session
|
||||
.query(model.TagCategory)
|
||||
.filter(sa.func.lower(model.TagCategory.name) == name.lower()))
|
||||
if lock:
|
||||
query = query.with_lockmode('update')
|
||||
return query.one_or_none()
|
||||
|
||||
|
||||
def get_category_by_name(name: str, lock: bool=False) -> model.TagCategory:
|
||||
def get_category_by_name(name: str, lock: bool = False) -> model.TagCategory:
|
||||
category = try_get_category_by_name(name, lock)
|
||||
if not category:
|
||||
raise TagCategoryNotFoundError('Tag category %r not found.' % name)
|
||||
|
@ -137,26 +138,29 @@ def get_all_categories() -> List[model.TagCategory]:
|
|||
return db.session.query(model.TagCategory).all()
|
||||
|
||||
|
||||
def try_get_default_category(lock: bool=False) -> Optional[model.TagCategory]:
|
||||
query = db.session \
|
||||
.query(model.TagCategory) \
|
||||
.filter(model.TagCategory.default)
|
||||
def try_get_default_category(
|
||||
lock: bool = False) -> Optional[model.TagCategory]:
|
||||
query = (
|
||||
db.session
|
||||
.query(model.TagCategory)
|
||||
.filter(model.TagCategory.default))
|
||||
if lock:
|
||||
query = query.with_lockmode('update')
|
||||
category = query.first()
|
||||
# if for some reason (e.g. as a result of migration) there's no default
|
||||
# category, get the first record available.
|
||||
if not category:
|
||||
query = db.session \
|
||||
.query(model.TagCategory) \
|
||||
.order_by(model.TagCategory.tag_category_id.asc())
|
||||
query = (
|
||||
db.session
|
||||
.query(model.TagCategory)
|
||||
.order_by(model.TagCategory.tag_category_id.asc()))
|
||||
if lock:
|
||||
query = query.with_lockmode('update')
|
||||
category = query.first()
|
||||
return category
|
||||
|
||||
|
||||
def get_default_category(lock: bool=False) -> model.TagCategory:
|
||||
def get_default_category(lock: bool = False) -> model.TagCategory:
|
||||
category = try_get_default_category(lock)
|
||||
if not category:
|
||||
raise TagCategoryNotFoundError('No tag category created yet.')
|
||||
|
|
|
@ -122,7 +122,7 @@ class TagSerializer(serialization.BaseSerializer):
|
|||
|
||||
|
||||
def serialize_tag(
|
||||
tag: model.Tag, options: List[str]=[]) -> Optional[rest.Response]:
|
||||
tag: model.Tag, options: List[str] = []) -> Optional[rest.Response]:
|
||||
if not tag:
|
||||
return None
|
||||
return TagSerializer(tag).serialize(options)
|
||||
|
@ -209,7 +209,8 @@ def get_tags_by_names(names: List[str]) -> List[model.Tag]:
|
|||
names = util.icase_unique(names)
|
||||
if len(names) == 0:
|
||||
return []
|
||||
return (db.session.query(model.Tag)
|
||||
return (
|
||||
db.session.query(model.Tag)
|
||||
.join(model.TagName)
|
||||
.filter(
|
||||
sa.sql.or_(
|
||||
|
|
|
@ -86,7 +86,7 @@ class UserSerializer(serialization.BaseSerializer):
|
|||
self,
|
||||
user: model.User,
|
||||
auth_user: model.User,
|
||||
force_show_email: bool=False) -> None:
|
||||
force_show_email: bool = False) -> None:
|
||||
self.user = user
|
||||
self.auth_user = auth_user
|
||||
self.force_show_email = force_show_email
|
||||
|
@ -151,8 +151,8 @@ class UserSerializer(serialization.BaseSerializer):
|
|||
def serialize_user(
|
||||
user: Optional[model.User],
|
||||
auth_user: model.User,
|
||||
options: List[str]=[],
|
||||
force_show_email: bool=False) -> Optional[rest.Response]:
|
||||
options: List[str] = [],
|
||||
force_show_email: bool = False) -> Optional[rest.Response]:
|
||||
if not user:
|
||||
return None
|
||||
return UserSerializer(user, auth_user, force_show_email).serialize(options)
|
||||
|
@ -170,10 +170,11 @@ def get_user_count() -> int:
|
|||
|
||||
|
||||
def try_get_user_by_name(name: str) -> Optional[model.User]:
|
||||
return db.session \
|
||||
.query(model.User) \
|
||||
.filter(sa.func.lower(model.User.name) == sa.func.lower(name)) \
|
||||
.one_or_none()
|
||||
return (
|
||||
db.session
|
||||
.query(model.User)
|
||||
.filter(sa.func.lower(model.User.name) == sa.func.lower(name))
|
||||
.one_or_none())
|
||||
|
||||
|
||||
def get_user_by_name(name: str) -> model.User:
|
||||
|
@ -276,7 +277,7 @@ def update_user_rank(
|
|||
def update_user_avatar(
|
||||
user: model.User,
|
||||
avatar_style: str,
|
||||
avatar_content: Optional[bytes]=None) -> None:
|
||||
avatar_content: Optional[bytes] = None) -> None:
|
||||
assert user
|
||||
if avatar_style == 'gravatar':
|
||||
user.avatar_style = user.AVATAR_GRAVATAR
|
||||
|
|
|
@ -2,8 +2,7 @@ import os
|
|||
import hashlib
|
||||
import re
|
||||
import tempfile
|
||||
from typing import (
|
||||
Any, Optional, Union, Tuple, List, Dict, Generator, Callable, TypeVar)
|
||||
from typing import Any, Optional, Union, Tuple, List, Dict, Generator, TypeVar
|
||||
from datetime import datetime, timedelta
|
||||
from contextlib import contextmanager
|
||||
from szurubooru import errors
|
||||
|
|
|
@ -4,7 +4,7 @@ from szurubooru import errors, rest, model
|
|||
def verify_version(
|
||||
entity: model.Base,
|
||||
context: rest.Context,
|
||||
field_name: str='version') -> None:
|
||||
field_name: str = 'version') -> None:
|
||||
actual_version = context.get_param_as_int(field_name)
|
||||
expected_version = entity.version
|
||||
if actual_version != expected_version:
|
||||
|
|
|
@ -27,8 +27,9 @@ def _get_user(ctx: rest.Context) -> Optional[model.User]:
|
|||
credentials.encode('ascii')).decode('utf8').split(':')
|
||||
return _authenticate(username, password)
|
||||
except ValueError as err:
|
||||
msg = 'Basic authentication header value are not properly formed. ' \
|
||||
+ 'Supplied header {0}. Got error: {1}'
|
||||
msg = (
|
||||
'Basic authentication header value are not properly formed. '
|
||||
'Supplied header {0}. Got error: {1}')
|
||||
raise HttpBadRequest(
|
||||
'ValidationError',
|
||||
msg.format(ctx.get_header('Authorization'), str(err)))
|
||||
|
|
|
@ -50,13 +50,14 @@ def upgrade():
|
|||
|
||||
def downgrade():
|
||||
session = sa.orm.session.Session(bind=op.get_bind())
|
||||
default_category = session \
|
||||
.query(TagCategory) \
|
||||
.filter(TagCategory.name == 'default') \
|
||||
.filter(TagCategory.color == 'default') \
|
||||
.filter(TagCategory.version == 1) \
|
||||
.filter(TagCategory.default == True) \
|
||||
.one_or_none()
|
||||
default_category = (
|
||||
session
|
||||
.query(TagCategory)
|
||||
.filter(TagCategory.name == 'default')
|
||||
.filter(TagCategory.color == 'default')
|
||||
.filter(TagCategory.version == 1)
|
||||
.filter(TagCategory.default == 1)
|
||||
.one_or_none())
|
||||
if default_category:
|
||||
session.delete(default_category)
|
||||
session.commit()
|
||||
|
|
|
@ -211,10 +211,11 @@ class Post(Base):
|
|||
|
||||
@property
|
||||
def is_featured(self) -> bool:
|
||||
featured_post = sa.orm.object_session(self) \
|
||||
.query(PostFeature) \
|
||||
.order_by(PostFeature.time.desc()) \
|
||||
.first()
|
||||
featured_post = (
|
||||
sa.orm.object_session(self)
|
||||
.query(PostFeature)
|
||||
.order_by(PostFeature.time.desc())
|
||||
.first())
|
||||
return featured_post and featured_post.post_id == self.post_id
|
||||
|
||||
score = sa.orm.column_property(
|
||||
|
|
|
@ -14,7 +14,7 @@ class TagCategory(Base):
|
|||
'color', sa.Unicode(32), nullable=False, default='#000000')
|
||||
default = sa.Column('default', sa.Boolean, nullable=False, default=False)
|
||||
|
||||
def __init__(self, name: Optional[str]=None) -> None:
|
||||
def __init__(self, name: Optional[str] = None) -> None:
|
||||
self.name = name
|
||||
|
||||
tag_count = sa.orm.column_property(
|
||||
|
|
|
@ -13,9 +13,9 @@ class Context:
|
|||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
headers: Dict[str, str]=None,
|
||||
params: Request=None,
|
||||
files: Dict[str, bytes]=None) -> None:
|
||||
headers: Dict[str, str] = None,
|
||||
params: Request = None,
|
||||
files: Dict[str, bytes] = None) -> None:
|
||||
self.method = method
|
||||
self.url = url
|
||||
self._headers = headers or {}
|
||||
|
@ -34,7 +34,7 @@ class Context:
|
|||
def get_header(self, name: str) -> str:
|
||||
return self._headers.get(name, '')
|
||||
|
||||
def has_file(self, name: str, allow_tokens: bool=True) -> bool:
|
||||
def has_file(self, name: str, allow_tokens: bool = True) -> bool:
|
||||
return (
|
||||
name in self._files or
|
||||
name + 'Url' in self._params or
|
||||
|
@ -43,8 +43,8 @@ class Context:
|
|||
def get_file(
|
||||
self,
|
||||
name: str,
|
||||
default: Union[object, bytes]=MISSING,
|
||||
allow_tokens: bool=True) -> bytes:
|
||||
default: Union[object, bytes] = MISSING,
|
||||
allow_tokens: bool = True) -> bytes:
|
||||
if name in self._files and self._files[name]:
|
||||
return self._files[name]
|
||||
|
||||
|
@ -70,7 +70,7 @@ class Context:
|
|||
def get_param_as_list(
|
||||
self,
|
||||
name: str,
|
||||
default: Union[object, List[Any]]=MISSING) -> List[Any]:
|
||||
default: Union[object, List[Any]] = MISSING) -> List[Any]:
|
||||
if name not in self._params:
|
||||
if default is not MISSING:
|
||||
return cast(List[Any], default)
|
||||
|
@ -89,7 +89,7 @@ class Context:
|
|||
def get_param_as_int_list(
|
||||
self,
|
||||
name: str,
|
||||
default: Union[object, List[int]]=MISSING) -> List[int]:
|
||||
default: Union[object, List[int]] = MISSING) -> List[int]:
|
||||
ret = self.get_param_as_list(name, default)
|
||||
for item in ret:
|
||||
if type(item) is not int:
|
||||
|
@ -100,7 +100,7 @@ class Context:
|
|||
def get_param_as_string_list(
|
||||
self,
|
||||
name: str,
|
||||
default: Union[object, List[str]]=MISSING) -> List[str]:
|
||||
default: Union[object, List[str]] = MISSING) -> List[str]:
|
||||
ret = self.get_param_as_list(name, default)
|
||||
for item in ret:
|
||||
if type(item) is not str:
|
||||
|
@ -111,7 +111,7 @@ class Context:
|
|||
def get_param_as_string(
|
||||
self,
|
||||
name: str,
|
||||
default: Union[object, str]=MISSING) -> str:
|
||||
default: Union[object, str] = MISSING) -> str:
|
||||
if name not in self._params:
|
||||
if default is not MISSING:
|
||||
return cast(str, default)
|
||||
|
@ -135,9 +135,9 @@ class Context:
|
|||
def get_param_as_int(
|
||||
self,
|
||||
name: str,
|
||||
default: Union[object, int]=MISSING,
|
||||
min: Optional[int]=None,
|
||||
max: Optional[int]=None) -> int:
|
||||
default: Union[object, int] = MISSING,
|
||||
min: Optional[int] = None,
|
||||
max: Optional[int] = None) -> int:
|
||||
if name not in self._params:
|
||||
if default is not MISSING:
|
||||
return cast(int, default)
|
||||
|
@ -161,7 +161,7 @@ class Context:
|
|||
def get_param_as_bool(
|
||||
self,
|
||||
name: str,
|
||||
default: Union[object, bool]=MISSING) -> bool:
|
||||
default: Union[object, bool] = MISSING) -> bool:
|
||||
if name not in self._params:
|
||||
if default is not MISSING:
|
||||
return cast(bool, default)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Callable, Type, Dict
|
||||
from typing import Optional, Callable, Type, Dict
|
||||
|
||||
|
||||
error_handlers = {} # pylint: disable=invalid-name
|
||||
|
@ -12,8 +12,8 @@ class BaseHttpError(RuntimeError):
|
|||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
title: str=None,
|
||||
extra_fields: Dict[str, str]=None) -> None:
|
||||
title: Optional[str] = None,
|
||||
extra_fields: Optional[Dict[str, str]] = None) -> None:
|
||||
super().__init__()
|
||||
# error name for programmers
|
||||
self.name = name
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Callable
|
||||
from typing import List, Callable
|
||||
from szurubooru.rest.context import Context
|
||||
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Callable, Dict, Any
|
||||
from typing import Callable, Dict
|
||||
from collections import defaultdict
|
||||
from szurubooru.rest.context import Context, Response
|
||||
|
||||
|
|
|
@ -52,10 +52,11 @@ def _create_score_filter(score: int) -> Filter:
|
|||
user_alias.name, criterion)
|
||||
if negated:
|
||||
expr = ~expr
|
||||
ret = query \
|
||||
.join(score_alias, score_alias.post_id == model.Post.post_id) \
|
||||
.join(user_alias, user_alias.user_id == score_alias.user_id) \
|
||||
.filter(expr)
|
||||
ret = (
|
||||
query
|
||||
.join(score_alias, score_alias.post_id == model.Post.post_id)
|
||||
.join(user_alias, user_alias.user_id == score_alias.user_id)
|
||||
.filter(expr))
|
||||
return ret
|
||||
return wrapper
|
||||
|
||||
|
@ -124,7 +125,8 @@ class PostSearchConfig(BaseSearchConfig):
|
|||
sa.orm.lazyload
|
||||
if disable_eager_loads
|
||||
else sa.orm.subqueryload)
|
||||
return db.session.query(model.Post) \
|
||||
return (
|
||||
db.session.query(model.Post)
|
||||
.options(
|
||||
sa.orm.lazyload('*'),
|
||||
# use config optimized for official client
|
||||
|
@ -141,7 +143,7 @@ class PostSearchConfig(BaseSearchConfig):
|
|||
strategy(model.Post.tags).subqueryload(model.Tag.names),
|
||||
strategy(model.Post.tags).defer(model.Tag.post_count),
|
||||
strategy(model.Post.tags).lazyload(model.Tag.implications),
|
||||
strategy(model.Post.tags).lazyload(model.Tag.suggestions))
|
||||
strategy(model.Post.tags).lazyload(model.Tag.suggestions)))
|
||||
|
||||
def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
|
||||
return db.session.query(model.Post)
|
||||
|
|
|
@ -14,8 +14,9 @@ class TagSearchConfig(BaseSearchConfig):
|
|||
sa.orm.lazyload
|
||||
if _disable_eager_loads
|
||||
else sa.orm.subqueryload)
|
||||
return db.session.query(model.Tag) \
|
||||
.join(model.TagCategory) \
|
||||
return (
|
||||
db.session.query(model.Tag)
|
||||
.join(model.TagCategory)
|
||||
.options(
|
||||
sa.orm.defer(model.Tag.first_name),
|
||||
sa.orm.defer(model.Tag.suggestion_count),
|
||||
|
@ -23,7 +24,7 @@ class TagSearchConfig(BaseSearchConfig):
|
|||
sa.orm.defer(model.Tag.post_count),
|
||||
strategy(model.Tag.names),
|
||||
strategy(model.Tag.suggestions).joinedload(model.Tag.names),
|
||||
strategy(model.Tag.implications).joinedload(model.Tag.names))
|
||||
strategy(model.Tag.implications).joinedload(model.Tag.names)))
|
||||
|
||||
def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
|
||||
return db.session.query(model.Tag)
|
||||
|
|
|
@ -69,7 +69,7 @@ def float_transformer(value: str) -> float:
|
|||
def apply_num_criterion_to_column(
|
||||
column: Any,
|
||||
criterion: criteria.BaseCriterion,
|
||||
transformer: Callable[[str], Number]=integer_transformer) -> SaQuery:
|
||||
transformer: Callable[[str], Number] = integer_transformer) -> SaQuery:
|
||||
try:
|
||||
if isinstance(criterion, criteria.PlainCriterion):
|
||||
expr = column == transformer(criterion.value)
|
||||
|
@ -95,7 +95,7 @@ def apply_num_criterion_to_column(
|
|||
|
||||
def create_num_filter(
|
||||
column: Any,
|
||||
transformer: Callable[[str], Number]=integer_transformer) -> SaQuery:
|
||||
transformer: Callable[[str], Number] = integer_transformer) -> SaQuery:
|
||||
def wrapper(
|
||||
query: SaQuery,
|
||||
criterion: Optional[criteria.BaseCriterion],
|
||||
|
@ -111,7 +111,7 @@ def create_num_filter(
|
|||
def apply_str_criterion_to_column(
|
||||
column: SaColumn,
|
||||
criterion: criteria.BaseCriterion,
|
||||
transformer: Callable[[str], str]=wildcard_transformer) -> SaQuery:
|
||||
transformer: Callable[[str], str] = wildcard_transformer) -> SaQuery:
|
||||
if isinstance(criterion, criteria.PlainCriterion):
|
||||
expr = column.ilike(transformer(criterion.value))
|
||||
elif isinstance(criterion, criteria.ArrayCriterion):
|
||||
|
@ -128,8 +128,8 @@ def apply_str_criterion_to_column(
|
|||
|
||||
|
||||
def create_str_filter(
|
||||
column: SaColumn, transformer: Callable[[str], str]=wildcard_transformer
|
||||
) -> Filter:
|
||||
column: SaColumn,
|
||||
transformer: Callable[[str], str] = wildcard_transformer) -> Filter:
|
||||
def wrapper(
|
||||
query: SaQuery,
|
||||
criterion: Optional[criteria.BaseCriterion],
|
||||
|
@ -187,7 +187,7 @@ def create_subquery_filter(
|
|||
right_id_column: SaColumn,
|
||||
filter_column: SaColumn,
|
||||
filter_factory: SaColumn,
|
||||
subquery_decorator: Callable[[SaQuery], None]=None) -> Filter:
|
||||
subquery_decorator: Callable[[SaQuery], None] = None) -> Filter:
|
||||
filter_func = filter_factory(filter_column)
|
||||
|
||||
def wrapper(
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional, List, Callable
|
||||
from typing import Optional, List
|
||||
from szurubooru.search.typing import SaQuery
|
||||
|
||||
|
||||
|
|
|
@ -100,18 +100,20 @@ class Executor:
|
|||
filter_query = self.config.create_filter_query(disable_eager_loads)
|
||||
filter_query = filter_query.options(sa.orm.lazyload('*'))
|
||||
filter_query = self._prepare_db_query(filter_query, search_query, True)
|
||||
entities = filter_query \
|
||||
.offset(offset) \
|
||||
.limit(limit) \
|
||||
.all()
|
||||
entities = (
|
||||
filter_query
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.all())
|
||||
|
||||
count_query = self.config.create_count_query(disable_eager_loads)
|
||||
count_query = count_query.options(sa.orm.lazyload('*'))
|
||||
count_query = self._prepare_db_query(count_query, search_query, False)
|
||||
count_statement = count_query \
|
||||
.statement \
|
||||
.with_only_columns([sa.func.count()]) \
|
||||
.order_by(None)
|
||||
count_statement = (
|
||||
count_query
|
||||
.statement
|
||||
.with_only_columns([sa.func.count()])
|
||||
.order_by(None))
|
||||
count = db.session.execute(count_statement).scalar()
|
||||
|
||||
ret = (count, entities)
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import re
|
||||
from typing import Match, List
|
||||
from szurubooru import errors
|
||||
from szurubooru.search import criteria, tokens
|
||||
from szurubooru.search.query import SearchQuery
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from szurubooru.search import tokens
|
||||
from typing import List
|
||||
|
||||
|
||||
class SearchQuery:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from unittest.mock import patch
|
||||
import pytest
|
||||
from szurubooru import api, db, model, errors
|
||||
from szurubooru import api, model, errors
|
||||
from szurubooru.func import tags, snapshots
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from unittest.mock import patch
|
||||
import pytest
|
||||
from szurubooru import api, db, model, errors
|
||||
from szurubooru import api, model, errors
|
||||
from szurubooru.func import users
|
||||
|
||||
|
||||
|
|
|
@ -5,10 +5,10 @@ from szurubooru import db, model, errors
|
|||
from szurubooru.func import auth, users, files, util
|
||||
|
||||
|
||||
EMPTY_PIXEL = \
|
||||
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \
|
||||
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \
|
||||
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'
|
||||
EMPTY_PIXEL = (
|
||||
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00'
|
||||
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00'
|
||||
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b')
|
||||
|
||||
|
||||
@pytest.mark.parametrize('user_name', ['test', 'TEST'])
|
||||
|
|
|
@ -28,11 +28,12 @@ def test_saving_tag(tag_factory):
|
|||
tag.implications.append(imp2)
|
||||
db.session.commit()
|
||||
|
||||
tag = db.session \
|
||||
.query(model.Tag) \
|
||||
.join(model.TagName) \
|
||||
.filter(model.TagName.name == 'alias1') \
|
||||
.one()
|
||||
tag = (
|
||||
db.session
|
||||
.query(model.Tag)
|
||||
.join(model.TagName)
|
||||
.filter(model.TagName.name == 'alias1')
|
||||
.one())
|
||||
assert [tag_name.name for tag_name in tag.names] == ['alias1', 'alias2']
|
||||
assert tag.category.name == 'category'
|
||||
assert tag.creation_time == datetime(1997, 1, 1)
|
||||
|
|
|
@ -300,7 +300,7 @@ def test_filter_by_note_count(
|
|||
('note-text:text3*', [3]),
|
||||
('note-text:text3a,text2', [2, 3]),
|
||||
])
|
||||
def test_filter_by_note_count(
|
||||
def test_filter_by_note_text(
|
||||
verify_unpaged, post_factory, note_factory, input, expected_post_ids):
|
||||
post1 = post_factory(id=1)
|
||||
post2 = post_factory(id=2)
|
||||
|
|
Loading…
Reference in a new issue