server: lint

This commit is contained in:
rr- 2017-04-24 23:30:53 +02:00
parent fea9a94945
commit 4bc58a3c95
42 changed files with 192 additions and 169 deletions

View file

@ -27,7 +27,7 @@ def _serialize(
@rest.routes.get('/comments/?')
def get_comments(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'comments:list')
return _search_executor.execute_and_serialize(
ctx, lambda comment: _serialize(ctx, comment))
@ -35,7 +35,7 @@ def get_comments(
@rest.routes.post('/comments/?')
def create_comment(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'comments:create')
text = ctx.get_param_as_string('text')
post_id = ctx.get_param_as_int('postId')

View file

@ -28,7 +28,7 @@ def _get_disk_usage() -> int:
@rest.routes.get('/info/?')
def get_info(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
post_feature = posts.try_get_current_post_feature()
return {
'postCount': posts.get_post_count(),

View file

@ -5,10 +5,10 @@ from hashlib import md5
MAIL_SUBJECT = 'Password reset for {name}'
MAIL_BODY = \
'You (or someone else) requested to reset your password on {name}.\n' \
'If you wish to proceed, click this link: {url}\n' \
'Otherwise, please ignore this email.'
MAIL_BODY = (
'You (or someone else) requested to reset your password on {name}.\n'
'If you wish to proceed, click this link: {url}\n'
'Otherwise, please ignore this email.')
@rest.routes.get('/password-reset/(?P<user_name>[^/]+)/?')

View file

@ -31,7 +31,7 @@ def _serialize_post(
@rest.routes.get('/posts/?')
def get_posts(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:list')
_search_executor_config.user = ctx.user
return _search_executor.execute_and_serialize(
@ -40,7 +40,7 @@ def get_posts(
@rest.routes.post('/posts/?')
def create_post(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
anonymous = ctx.get_param_as_bool('anonymous', default=False)
if anonymous:
auth.verify_privilege(ctx.user, 'posts:create:anonymous')
@ -144,7 +144,7 @@ def delete_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
@rest.routes.post('/post-merge/?')
def merge_posts(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
source_post_id = ctx.get_param_as_int('remove')
target_post_id = ctx.get_param_as_int('mergeTo')
source_post = posts.get_post_by_id(source_post_id)
@ -162,14 +162,14 @@ def merge_posts(
@rest.routes.get('/featured-post/?')
def get_featured_post(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
post = posts.try_get_featured_post()
return _serialize_post(ctx, post)
@rest.routes.post('/featured-post/?')
def set_featured_post(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:feature')
post_id = ctx.get_param_as_int('id')
post = posts.get_post_by_id(post_id)
@ -235,7 +235,7 @@ def get_posts_around(
@rest.routes.post('/posts/reverse-search/?')
def get_posts_by_image(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:reverse_search')
content = ctx.get_file('content')

View file

@ -8,7 +8,7 @@ _search_executor = search.Executor(search.configs.SnapshotSearchConfig())
@rest.routes.get('/snapshots/?')
def get_snapshots(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'snapshots:list')
return _search_executor.execute_and_serialize(
ctx, lambda snapshot: snapshots.serialize_snapshot(snapshot, ctx.user))

View file

@ -28,14 +28,15 @@ def _create_if_needed(tag_names: List[str], user: model.User) -> None:
@rest.routes.get('/tags/?')
def get_tags(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
def get_tags(ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'tags:list')
return _search_executor.execute_and_serialize(
ctx, lambda tag: _serialize(ctx, tag))
@rest.routes.post('/tags/?')
def create_tag(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
def create_tag(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'tags:create')
names = ctx.get_param_as_string_list('names')
@ -112,7 +113,7 @@ def delete_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
@rest.routes.post('/tag-merge/?')
def merge_tags(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
source_tag_name = ctx.get_param_as_string('remove')
target_tag_name = ctx.get_param_as_string('mergeTo')
source_tag = tags.get_tag_by_name(source_tag_name)

View file

@ -12,7 +12,7 @@ def _serialize(
@rest.routes.get('/tag-categories/?')
def get_tag_categories(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'tag_categories:list')
categories = tag_categories.get_all_categories()
return {
@ -22,7 +22,7 @@ def get_tag_categories(
@rest.routes.post('/tag-categories/?')
def create_tag_category(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'tag_categories:create')
name = ctx.get_param_as_string('name')
color = ctx.get_param_as_string('color')

View file

@ -5,7 +5,7 @@ from szurubooru.func import auth, file_uploads
@rest.routes.post('/uploads/?')
def create_temporary_file(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'uploads:create')
content = ctx.get_file('content', allow_tokens=False)
token = file_uploads.save(content)

View file

@ -16,7 +16,8 @@ def _serialize(
@rest.routes.get('/users/?')
def get_users(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
def get_users(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'users:list')
return _search_executor.execute_and_serialize(
ctx, lambda user: _serialize(ctx, user))
@ -24,7 +25,7 @@ def get_users(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
@rest.routes.post('/users/?')
def create_user(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'users:create')
name = ctx.get_param_as_string('name')
password = ctx.get_param_as_string('password')

View file

@ -4,8 +4,8 @@ from typing import Dict
class BaseError(RuntimeError):
def __init__(
self,
message: str='Unknown error',
extra_fields: Dict[str, str]=None) -> None:
message: str = 'Unknown error',
extra_fields: Dict[str, str] = None) -> None:
super().__init__(message)
self.extra_fields = extra_fields

View file

@ -21,9 +21,9 @@ class LruCache:
i
for i, v in enumerate(self.item_list)
if v.key == item.key)
self.item_list[:] \
= self.item_list[:item_index] \
+ self.item_list[item_index + 1:]
self.item_list[:] = (
self.item_list[:item_index] +
self.item_list[item_index + 1:])
self.item_list.insert(0, item)
else:
if len(self.item_list) > self.length:

View file

@ -1,7 +1,7 @@
from datetime import datetime
from typing import Any, Optional, List, Dict, Callable
from szurubooru import db, model, errors, rest
from szurubooru.func import users, scores, util, serialization
from szurubooru.func import users, scores, serialization
class InvalidCommentIdError(errors.ValidationError):
@ -65,7 +65,7 @@ class CommentSerializer(serialization.BaseSerializer):
def serialize_comment(
comment: model.Comment,
auth_user: model.User,
options: List[str]=[]) -> rest.Response:
options: List[str] = []) -> rest.Response:
if comment is None:
return None
return CommentSerializer(comment, auth_user).serialize(options)
@ -73,10 +73,11 @@ def serialize_comment(
def try_get_comment_by_id(comment_id: int) -> Optional[model.Comment]:
comment_id = int(comment_id)
return db.session \
.query(model.Comment) \
.filter(model.Comment.comment_id == comment_id) \
.one_or_none()
return (
db.session
.query(model.Comment)
.filter(model.Comment.comment_id == comment_id)
.one_or_none())
def get_comment_by_id(comment_id: int) -> model.Comment:

View file

@ -99,7 +99,7 @@ def _normalize_and_threshold(
def _compute_grid_points(
image: NpMatrix,
n: float,
window: Window=None) -> Tuple[NpMatrix, NpMatrix]:
window: Window = None) -> Tuple[NpMatrix, NpMatrix]:
if window is None:
window = ((0, image.shape[0]), (0, image.shape[1]))
x_coords = np.linspace(window[0][0], window[0][1], n + 2, dtype=int)[1:-1]
@ -219,7 +219,7 @@ def _max_contrast(array: NpMatrix) -> None:
def _normalized_distance(
target_array: NpMatrix,
vec: NpMatrix,
nan_value: float=1.0) -> List[float]:
nan_value: float = 1.0) -> List[float]:
target_array = target_array.astype(int)
vec = vec.astype(int)
topvec = np.linalg.norm(vec - target_array, axis=1)

View file

@ -11,8 +11,8 @@ from szurubooru.func import mime, util
logger = logging.getLogger(__name__)
_SCALE_FIT_FMT = \
r'scale=iw*max({width}/iw\,{height}/ih):ih*max({width}/iw\,{height}/ih)'
_SCALE_FIT_FMT = (
r'scale=iw*max({width}/iw\,{height}/ih):ih*max({width}/iw\,{height}/ih)')
class Image:
@ -77,7 +77,7 @@ class Image:
'-',
])
def _execute(self, cli: List[str], program: str='ffmpeg') -> bytes:
def _execute(self, cli: List[str], program: str = 'ffmpeg') -> bytes:
extension = mime.get_extension(mime.get_mime_type(self.content))
assert extension
with util.create_temp_file(suffix='.' + extension) as handle:

View file

@ -7,10 +7,10 @@ from szurubooru.func import (
mime, images, files, image_hash, serialization)
EMPTY_PIXEL = \
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'
EMPTY_PIXEL = (
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00'
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00'
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b')
class PostNotFoundError(errors.NotFoundError):
@ -283,7 +283,7 @@ class PostSerializer(serialization.BaseSerializer):
def serialize_post(
post: Optional[model.Post],
auth_user: model.User,
options: List[str]=[]) -> Optional[rest.Response]:
options: List[str] = []) -> Optional[rest.Response]:
if not post:
return None
return PostSerializer(post, auth_user).serialize(options)
@ -300,10 +300,11 @@ def get_post_count() -> int:
def try_get_post_by_id(post_id: int) -> Optional[model.Post]:
return db.session \
.query(model.Post) \
.filter(model.Post.post_id == post_id) \
.one_or_none()
return (
db.session
.query(model.Post)
.filter(model.Post.post_id == post_id)
.one_or_none())
def get_post_by_id(post_id: int) -> model.Post:
@ -314,10 +315,11 @@ def get_post_by_id(post_id: int) -> model.Post:
def try_get_current_post_feature() -> Optional[model.PostFeature]:
return db.session \
.query(model.PostFeature) \
.order_by(model.PostFeature.time.desc()) \
.first()
return (
db.session
.query(model.PostFeature)
.order_by(model.PostFeature.time.desc())
.first())
def try_get_featured_post() -> Optional[model.Post]:
@ -426,11 +428,12 @@ def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
'Unhandled file type: %r' % post.mime_type)
post.checksum = util.get_sha1(content)
other_post = db.session \
.query(model.Post) \
.filter(model.Post.checksum == post.checksum) \
.filter(model.Post.post_id != post.post_id) \
.one_or_none()
other_post = (
db.session
.query(model.Post)
.filter(model.Post.checksum == post.checksum)
.filter(model.Post.post_id != post.post_id)
.one_or_none())
if other_post \
and other_post.post_id \
and other_post.post_id != post.post_id:
@ -452,7 +455,7 @@ def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
def update_post_thumbnail(
post: model.Post, content: Optional[bytes]=None) -> None:
post: model.Post, content: Optional[bytes] = None) -> None:
assert post
setattr(post, '__thumbnail', content)
@ -492,10 +495,11 @@ def update_post_relations(post: model.Post, new_post_ids: List[int]) -> None:
old_posts = post.relations
old_post_ids = [int(p.post_id) for p in old_posts]
if new_post_ids:
new_posts = db.session \
.query(model.Post) \
.filter(model.Post.post_id.in_(new_post_ids)) \
.all()
new_posts = (
db.session
.query(model.Post)
.filter(model.Post.post_id.in_(new_post_ids))
.all())
else:
new_posts = []
if len(new_posts) != len(new_post_ids):
@ -673,10 +677,11 @@ def merge_posts(
def search_by_image_exact(image_content: bytes) -> Optional[model.Post]:
checksum = util.get_sha1(image_content)
return db.session \
.query(model.Post) \
.filter(model.Post.checksum == checksum) \
.one_or_none()
return (
db.session
.query(model.Post)
.filter(model.Post.checksum == checksum)
.one_or_none())
def search_by_image(image_content: bytes) -> List[PostLookalike]:

View file

@ -39,11 +39,12 @@ def get_score(entity: model.Base, user: model.User) -> int:
assert entity
assert user
table, get_column = _get_table_info(entity)
row = db.session \
.query(table.score) \
.filter(get_column(table) == get_column(entity)) \
.filter(table.user_id == user.user_id) \
.one_or_none()
row = (
db.session
.query(table.score)
.filter(get_column(table) == get_column(entity))
.filter(table.user_id == user.user_id)
.one_or_none())
return row[0] if row else 0

View file

@ -1,5 +1,5 @@
from typing import Any, Optional, List, Dict, Callable
from szurubooru import db, model, rest, errors
from typing import Any, List, Dict, Callable
from szurubooru import model, rest, errors
def get_serialization_options(ctx: rest.Context) -> List[str]:

View file

@ -66,7 +66,7 @@ class TagCategorySerializer(serialization.BaseSerializer):
def serialize_category(
category: Optional[model.TagCategory],
options: List[str]=[]) -> Optional[rest.Response]:
options: List[str] = []) -> Optional[rest.Response]:
if not category:
return None
return TagCategorySerializer(category).serialize(options)
@ -113,16 +113,17 @@ def update_category_color(category: model.TagCategory, color: str) -> None:
def try_get_category_by_name(
name: str, lock: bool=False) -> Optional[model.TagCategory]:
query = db.session \
.query(model.TagCategory) \
.filter(sa.func.lower(model.TagCategory.name) == name.lower())
name: str, lock: bool = False) -> Optional[model.TagCategory]:
query = (
db.session
.query(model.TagCategory)
.filter(sa.func.lower(model.TagCategory.name) == name.lower()))
if lock:
query = query.with_lockmode('update')
return query.one_or_none()
def get_category_by_name(name: str, lock: bool=False) -> model.TagCategory:
def get_category_by_name(name: str, lock: bool = False) -> model.TagCategory:
category = try_get_category_by_name(name, lock)
if not category:
raise TagCategoryNotFoundError('Tag category %r not found.' % name)
@ -137,26 +138,29 @@ def get_all_categories() -> List[model.TagCategory]:
return db.session.query(model.TagCategory).all()
def try_get_default_category(lock: bool=False) -> Optional[model.TagCategory]:
query = db.session \
.query(model.TagCategory) \
.filter(model.TagCategory.default)
def try_get_default_category(
lock: bool = False) -> Optional[model.TagCategory]:
query = (
db.session
.query(model.TagCategory)
.filter(model.TagCategory.default))
if lock:
query = query.with_lockmode('update')
category = query.first()
# if for some reason (e.g. as a result of migration) there's no default
# category, get the first record available.
if not category:
query = db.session \
.query(model.TagCategory) \
.order_by(model.TagCategory.tag_category_id.asc())
query = (
db.session
.query(model.TagCategory)
.order_by(model.TagCategory.tag_category_id.asc()))
if lock:
query = query.with_lockmode('update')
category = query.first()
return category
def get_default_category(lock: bool=False) -> model.TagCategory:
def get_default_category(lock: bool = False) -> model.TagCategory:
category = try_get_default_category(lock)
if not category:
raise TagCategoryNotFoundError('No tag category created yet.')

View file

@ -122,7 +122,7 @@ class TagSerializer(serialization.BaseSerializer):
def serialize_tag(
tag: model.Tag, options: List[str]=[]) -> Optional[rest.Response]:
tag: model.Tag, options: List[str] = []) -> Optional[rest.Response]:
if not tag:
return None
return TagSerializer(tag).serialize(options)
@ -209,7 +209,8 @@ def get_tags_by_names(names: List[str]) -> List[model.Tag]:
names = util.icase_unique(names)
if len(names) == 0:
return []
return (db.session.query(model.Tag)
return (
db.session.query(model.Tag)
.join(model.TagName)
.filter(
sa.sql.or_(

View file

@ -86,7 +86,7 @@ class UserSerializer(serialization.BaseSerializer):
self,
user: model.User,
auth_user: model.User,
force_show_email: bool=False) -> None:
force_show_email: bool = False) -> None:
self.user = user
self.auth_user = auth_user
self.force_show_email = force_show_email
@ -151,8 +151,8 @@ class UserSerializer(serialization.BaseSerializer):
def serialize_user(
user: Optional[model.User],
auth_user: model.User,
options: List[str]=[],
force_show_email: bool=False) -> Optional[rest.Response]:
options: List[str] = [],
force_show_email: bool = False) -> Optional[rest.Response]:
if not user:
return None
return UserSerializer(user, auth_user, force_show_email).serialize(options)
@ -170,10 +170,11 @@ def get_user_count() -> int:
def try_get_user_by_name(name: str) -> Optional[model.User]:
return db.session \
.query(model.User) \
.filter(sa.func.lower(model.User.name) == sa.func.lower(name)) \
.one_or_none()
return (
db.session
.query(model.User)
.filter(sa.func.lower(model.User.name) == sa.func.lower(name))
.one_or_none())
def get_user_by_name(name: str) -> model.User:
@ -276,7 +277,7 @@ def update_user_rank(
def update_user_avatar(
user: model.User,
avatar_style: str,
avatar_content: Optional[bytes]=None) -> None:
avatar_content: Optional[bytes] = None) -> None:
assert user
if avatar_style == 'gravatar':
user.avatar_style = user.AVATAR_GRAVATAR

View file

@ -2,8 +2,7 @@ import os
import hashlib
import re
import tempfile
from typing import (
Any, Optional, Union, Tuple, List, Dict, Generator, Callable, TypeVar)
from typing import Any, Optional, Union, Tuple, List, Dict, Generator, TypeVar
from datetime import datetime, timedelta
from contextlib import contextmanager
from szurubooru import errors

View file

@ -4,7 +4,7 @@ from szurubooru import errors, rest, model
def verify_version(
entity: model.Base,
context: rest.Context,
field_name: str='version') -> None:
field_name: str = 'version') -> None:
actual_version = context.get_param_as_int(field_name)
expected_version = entity.version
if actual_version != expected_version:

View file

@ -27,8 +27,9 @@ def _get_user(ctx: rest.Context) -> Optional[model.User]:
credentials.encode('ascii')).decode('utf8').split(':')
return _authenticate(username, password)
except ValueError as err:
msg = 'Basic authentication header value are not properly formed. ' \
+ 'Supplied header {0}. Got error: {1}'
msg = (
'Basic authentication header value are not properly formed. '
'Supplied header {0}. Got error: {1}')
raise HttpBadRequest(
'ValidationError',
msg.format(ctx.get_header('Authorization'), str(err)))

View file

@ -50,13 +50,14 @@ def upgrade():
def downgrade():
session = sa.orm.session.Session(bind=op.get_bind())
default_category = session \
.query(TagCategory) \
.filter(TagCategory.name == 'default') \
.filter(TagCategory.color == 'default') \
.filter(TagCategory.version == 1) \
.filter(TagCategory.default == True) \
.one_or_none()
default_category = (
session
.query(TagCategory)
.filter(TagCategory.name == 'default')
.filter(TagCategory.color == 'default')
.filter(TagCategory.version == 1)
.filter(TagCategory.default == 1)
.one_or_none())
if default_category:
session.delete(default_category)
session.commit()

View file

@ -211,10 +211,11 @@ class Post(Base):
@property
def is_featured(self) -> bool:
featured_post = sa.orm.object_session(self) \
.query(PostFeature) \
.order_by(PostFeature.time.desc()) \
.first()
featured_post = (
sa.orm.object_session(self)
.query(PostFeature)
.order_by(PostFeature.time.desc())
.first())
return featured_post and featured_post.post_id == self.post_id
score = sa.orm.column_property(

View file

@ -14,7 +14,7 @@ class TagCategory(Base):
'color', sa.Unicode(32), nullable=False, default='#000000')
default = sa.Column('default', sa.Boolean, nullable=False, default=False)
def __init__(self, name: Optional[str]=None) -> None:
def __init__(self, name: Optional[str] = None) -> None:
self.name = name
tag_count = sa.orm.column_property(

View file

@ -13,9 +13,9 @@ class Context:
self,
method: str,
url: str,
headers: Dict[str, str]=None,
params: Request=None,
files: Dict[str, bytes]=None) -> None:
headers: Dict[str, str] = None,
params: Request = None,
files: Dict[str, bytes] = None) -> None:
self.method = method
self.url = url
self._headers = headers or {}
@ -34,7 +34,7 @@ class Context:
def get_header(self, name: str) -> str:
return self._headers.get(name, '')
def has_file(self, name: str, allow_tokens: bool=True) -> bool:
def has_file(self, name: str, allow_tokens: bool = True) -> bool:
return (
name in self._files or
name + 'Url' in self._params or
@ -43,8 +43,8 @@ class Context:
def get_file(
self,
name: str,
default: Union[object, bytes]=MISSING,
allow_tokens: bool=True) -> bytes:
default: Union[object, bytes] = MISSING,
allow_tokens: bool = True) -> bytes:
if name in self._files and self._files[name]:
return self._files[name]
@ -70,7 +70,7 @@ class Context:
def get_param_as_list(
self,
name: str,
default: Union[object, List[Any]]=MISSING) -> List[Any]:
default: Union[object, List[Any]] = MISSING) -> List[Any]:
if name not in self._params:
if default is not MISSING:
return cast(List[Any], default)
@ -89,7 +89,7 @@ class Context:
def get_param_as_int_list(
self,
name: str,
default: Union[object, List[int]]=MISSING) -> List[int]:
default: Union[object, List[int]] = MISSING) -> List[int]:
ret = self.get_param_as_list(name, default)
for item in ret:
if type(item) is not int:
@ -100,7 +100,7 @@ class Context:
def get_param_as_string_list(
self,
name: str,
default: Union[object, List[str]]=MISSING) -> List[str]:
default: Union[object, List[str]] = MISSING) -> List[str]:
ret = self.get_param_as_list(name, default)
for item in ret:
if type(item) is not str:
@ -111,7 +111,7 @@ class Context:
def get_param_as_string(
self,
name: str,
default: Union[object, str]=MISSING) -> str:
default: Union[object, str] = MISSING) -> str:
if name not in self._params:
if default is not MISSING:
return cast(str, default)
@ -135,9 +135,9 @@ class Context:
def get_param_as_int(
self,
name: str,
default: Union[object, int]=MISSING,
min: Optional[int]=None,
max: Optional[int]=None) -> int:
default: Union[object, int] = MISSING,
min: Optional[int] = None,
max: Optional[int] = None) -> int:
if name not in self._params:
if default is not MISSING:
return cast(int, default)
@ -161,7 +161,7 @@ class Context:
def get_param_as_bool(
self,
name: str,
default: Union[object, bool]=MISSING) -> bool:
default: Union[object, bool] = MISSING) -> bool:
if name not in self._params:
if default is not MISSING:
return cast(bool, default)

View file

@ -1,4 +1,4 @@
from typing import Callable, Type, Dict
from typing import Optional, Callable, Type, Dict
error_handlers = {} # pylint: disable=invalid-name
@ -12,8 +12,8 @@ class BaseHttpError(RuntimeError):
self,
name: str,
description: str,
title: str=None,
extra_fields: Dict[str, str]=None) -> None:
title: Optional[str] = None,
extra_fields: Optional[Dict[str, str]] = None) -> None:
super().__init__()
# error name for programmers
self.name = name

View file

@ -1,4 +1,4 @@
from typing import Callable
from typing import List, Callable
from szurubooru.rest.context import Context

View file

@ -1,4 +1,4 @@
from typing import Callable, Dict, Any
from typing import Callable, Dict
from collections import defaultdict
from szurubooru.rest.context import Context, Response

View file

@ -52,10 +52,11 @@ def _create_score_filter(score: int) -> Filter:
user_alias.name, criterion)
if negated:
expr = ~expr
ret = query \
.join(score_alias, score_alias.post_id == model.Post.post_id) \
.join(user_alias, user_alias.user_id == score_alias.user_id) \
.filter(expr)
ret = (
query
.join(score_alias, score_alias.post_id == model.Post.post_id)
.join(user_alias, user_alias.user_id == score_alias.user_id)
.filter(expr))
return ret
return wrapper
@ -124,7 +125,8 @@ class PostSearchConfig(BaseSearchConfig):
sa.orm.lazyload
if disable_eager_loads
else sa.orm.subqueryload)
return db.session.query(model.Post) \
return (
db.session.query(model.Post)
.options(
sa.orm.lazyload('*'),
# use config optimized for official client
@ -141,7 +143,7 @@ class PostSearchConfig(BaseSearchConfig):
strategy(model.Post.tags).subqueryload(model.Tag.names),
strategy(model.Post.tags).defer(model.Tag.post_count),
strategy(model.Post.tags).lazyload(model.Tag.implications),
strategy(model.Post.tags).lazyload(model.Tag.suggestions))
strategy(model.Post.tags).lazyload(model.Tag.suggestions)))
def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
return db.session.query(model.Post)

View file

@ -14,8 +14,9 @@ class TagSearchConfig(BaseSearchConfig):
sa.orm.lazyload
if _disable_eager_loads
else sa.orm.subqueryload)
return db.session.query(model.Tag) \
.join(model.TagCategory) \
return (
db.session.query(model.Tag)
.join(model.TagCategory)
.options(
sa.orm.defer(model.Tag.first_name),
sa.orm.defer(model.Tag.suggestion_count),
@ -23,7 +24,7 @@ class TagSearchConfig(BaseSearchConfig):
sa.orm.defer(model.Tag.post_count),
strategy(model.Tag.names),
strategy(model.Tag.suggestions).joinedload(model.Tag.names),
strategy(model.Tag.implications).joinedload(model.Tag.names))
strategy(model.Tag.implications).joinedload(model.Tag.names)))
def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
return db.session.query(model.Tag)

View file

@ -69,7 +69,7 @@ def float_transformer(value: str) -> float:
def apply_num_criterion_to_column(
column: Any,
criterion: criteria.BaseCriterion,
transformer: Callable[[str], Number]=integer_transformer) -> SaQuery:
transformer: Callable[[str], Number] = integer_transformer) -> SaQuery:
try:
if isinstance(criterion, criteria.PlainCriterion):
expr = column == transformer(criterion.value)
@ -95,7 +95,7 @@ def apply_num_criterion_to_column(
def create_num_filter(
column: Any,
transformer: Callable[[str], Number]=integer_transformer) -> SaQuery:
transformer: Callable[[str], Number] = integer_transformer) -> SaQuery:
def wrapper(
query: SaQuery,
criterion: Optional[criteria.BaseCriterion],
@ -111,7 +111,7 @@ def create_num_filter(
def apply_str_criterion_to_column(
column: SaColumn,
criterion: criteria.BaseCriterion,
transformer: Callable[[str], str]=wildcard_transformer) -> SaQuery:
transformer: Callable[[str], str] = wildcard_transformer) -> SaQuery:
if isinstance(criterion, criteria.PlainCriterion):
expr = column.ilike(transformer(criterion.value))
elif isinstance(criterion, criteria.ArrayCriterion):
@ -128,8 +128,8 @@ def apply_str_criterion_to_column(
def create_str_filter(
column: SaColumn, transformer: Callable[[str], str]=wildcard_transformer
) -> Filter:
column: SaColumn,
transformer: Callable[[str], str] = wildcard_transformer) -> Filter:
def wrapper(
query: SaQuery,
criterion: Optional[criteria.BaseCriterion],
@ -187,7 +187,7 @@ def create_subquery_filter(
right_id_column: SaColumn,
filter_column: SaColumn,
filter_factory: SaColumn,
subquery_decorator: Callable[[SaQuery], None]=None) -> Filter:
subquery_decorator: Callable[[SaQuery], None] = None) -> Filter:
filter_func = filter_factory(filter_column)
def wrapper(

View file

@ -1,4 +1,4 @@
from typing import Optional, List, Callable
from typing import Optional, List
from szurubooru.search.typing import SaQuery

View file

@ -100,18 +100,20 @@ class Executor:
filter_query = self.config.create_filter_query(disable_eager_loads)
filter_query = filter_query.options(sa.orm.lazyload('*'))
filter_query = self._prepare_db_query(filter_query, search_query, True)
entities = filter_query \
.offset(offset) \
.limit(limit) \
.all()
entities = (
filter_query
.offset(offset)
.limit(limit)
.all())
count_query = self.config.create_count_query(disable_eager_loads)
count_query = count_query.options(sa.orm.lazyload('*'))
count_query = self._prepare_db_query(count_query, search_query, False)
count_statement = count_query \
.statement \
.with_only_columns([sa.func.count()]) \
.order_by(None)
count_statement = (
count_query
.statement
.with_only_columns([sa.func.count()])
.order_by(None))
count = db.session.execute(count_statement).scalar()
ret = (count, entities)

View file

@ -1,5 +1,4 @@
import re
from typing import Match, List
from szurubooru import errors
from szurubooru.search import criteria, tokens
from szurubooru.search.query import SearchQuery

View file

@ -1,4 +1,5 @@
from szurubooru.search import tokens
from typing import List
class SearchQuery:

View file

@ -1,6 +1,6 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db, model, errors
from szurubooru import api, model, errors
from szurubooru.func import tags, snapshots

View file

@ -1,6 +1,6 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db, model, errors
from szurubooru import api, model, errors
from szurubooru.func import users

View file

@ -5,10 +5,10 @@ from szurubooru import db, model, errors
from szurubooru.func import auth, users, files, util
EMPTY_PIXEL = \
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'
EMPTY_PIXEL = (
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00'
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00'
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b')
@pytest.mark.parametrize('user_name', ['test', 'TEST'])

View file

@ -28,11 +28,12 @@ def test_saving_tag(tag_factory):
tag.implications.append(imp2)
db.session.commit()
tag = db.session \
.query(model.Tag) \
.join(model.TagName) \
.filter(model.TagName.name == 'alias1') \
.one()
tag = (
db.session
.query(model.Tag)
.join(model.TagName)
.filter(model.TagName.name == 'alias1')
.one())
assert [tag_name.name for tag_name in tag.names] == ['alias1', 'alias2']
assert tag.category.name == 'category'
assert tag.creation_time == datetime(1997, 1, 1)

View file

@ -300,7 +300,7 @@ def test_filter_by_note_count(
('note-text:text3*', [3]),
('note-text:text3a,text2', [2, 3]),
])
def test_filter_by_note_count(
def test_filter_by_note_text(
verify_unpaged, post_factory, note_factory, input, expected_post_ids):
post1 = post_factory(id=1)
post2 = post_factory(id=2)