server/general: be more pythonic

This commit is contained in:
rr- 2016-04-03 22:03:58 +02:00
parent 7f4708c696
commit 219ab7c2c3
36 changed files with 335 additions and 428 deletions

View file

@ -1,5 +1,5 @@
import hashlib from szurubooru import config, errors
from szurubooru import errors from szurubooru.util import auth, mailer, users
from szurubooru.api.base_api import BaseApi from szurubooru.api.base_api import BaseApi
MAIL_SUBJECT = 'Password reset for {name}' MAIL_SUBJECT = 'Password reset for {name}'
@ -9,45 +9,35 @@ MAIL_BODY = \
'Otherwise, please ignore this email.' 'Otherwise, please ignore this email.'
class PasswordReminderApi(BaseApi): class PasswordReminderApi(BaseApi):
def __init__(self, config, mailer, user_service):
super().__init__()
self._config = config
self._mailer = mailer
self._user_service = user_service
def get(self, context, user_name): def get(self, context, user_name):
user = self._user_service.get_by_name(context.session, user_name) ''' Send a mail with secure token to the correlated user. '''
user = users.get_by_name(context.session, user_name)
if not user: if not user:
raise errors.NotFoundError('User %r not found.' % user_name) raise errors.NotFoundError('User %r not found.' % user_name)
if not user.email: if not user.email:
raise errors.ValidationError( raise errors.ValidationError(
'User %r hasn\'t supplied email. Cannot reset password.' % user_name) 'User %r hasn\'t supplied email. Cannot reset password.' % user_name)
token = self._generate_authentication_token(user) token = auth.generate_authentication_token(user)
url = '%s/password-reset/%s' % ( url = '%s/password-reset/%s' % (
self._config['basic']['base_url'].rstrip('/'), token) config.config['basic']['base_url'].rstrip('/'), token)
self._mailer.send( mailer.send_mail(
'noreply@%s' % self._config['basic']['name'], 'noreply@%s' % config.config['basic']['name'],
user.email, user.email,
MAIL_SUBJECT.format(name=self._config['basic']['name']), MAIL_SUBJECT.format(name=config.config['basic']['name']),
MAIL_BODY.format(name=self._config['basic']['name'], url=url)) MAIL_BODY.format(name=config.config['basic']['name'], url=url))
return {} return {}
def post(self, context, user_name): def post(self, context, user_name):
user = self._user_service.get_by_name(context.session, user_name) ''' Verify token from mail, generate a new password and return it. '''
user = users.get_by_name(context.session, user_name)
if not user: if not user:
raise errors.NotFoundError('User %r not found.' % user_name) raise errors.NotFoundError('User %r not found.' % user_name)
good_token = self._generate_authentication_token(user) good_token = auth.generate_authentication_token(user)
if not 'token' in context.request: if not 'token' in context.request:
raise errors.ValidationError('Missing password reset token.') raise errors.ValidationError('Missing password reset token.')
token = context.request['token'] token = context.request['token']
if token != good_token: if token != good_token:
raise errors.ValidationError('Invalid password reset token.') raise errors.ValidationError('Invalid password reset token.')
new_password = self._user_service.reset_password(user) new_password = users.reset_password(user)
context.session.commit() context.session.commit()
return {'password': new_password} return {'password': new_password}
def _generate_authentication_token(self, user):
digest = hashlib.sha256()
digest.update(self._config['basic']['secret'].encode('utf8'))
digest.update(user.password_salt.encode('utf8'))
return digest.hexdigest()

View file

@ -1,9 +1,7 @@
import re
import sqlalchemy import sqlalchemy
from szurubooru import errors from szurubooru import errors, search
from szurubooru import util from szurubooru.util import auth, users
from szurubooru.api.base_api import BaseApi from szurubooru.api.base_api import BaseApi
from szurubooru.services import search
def _serialize_user(authenticated_user, user): def _serialize_user(authenticated_user, user):
ret = { ret = {
@ -21,29 +19,27 @@ def _serialize_user(authenticated_user, user):
class UserListApi(BaseApi): class UserListApi(BaseApi):
''' API for lists of users. ''' ''' API for lists of users. '''
def __init__(self, auth_service, user_service): def __init__(self):
super().__init__() super().__init__()
self._auth_service = auth_service
self._user_service = user_service
self._search_executor = search.SearchExecutor(search.UserSearchConfig()) self._search_executor = search.SearchExecutor(search.UserSearchConfig())
def get(self, context): def get(self, context):
''' Retrieves a list of users. ''' ''' Retrieve a list of users. '''
self._auth_service.verify_privilege(context.user, 'users:list') auth.verify_privilege(context.user, 'users:list')
query = context.get_param_as_string('query') query = context.get_param_as_string('query')
page = context.get_param_as_int('page', 1) page = context.get_param_as_int('page', 1)
count, users = self._search_executor.execute(context.session, query, page) count, user_list = self._search_executor.execute(context.session, query, page)
return { return {
'query': query, 'query': query,
'page': page, 'page': page,
'page_size': self._search_executor.page_size, 'page_size': self._search_executor.page_size,
'total': count, 'total': count,
'users': [_serialize_user(context.user, user) for user in users], 'users': [_serialize_user(context.user, user) for user in user_list],
} }
def post(self, context): def post(self, context):
''' Creates a new user. ''' ''' Create a new user. '''
self._auth_service.verify_privilege(context.user, 'users:create') auth.verify_privilege(context.user, 'users:create')
try: try:
name = context.request['name'].strip() name = context.request['name'].strip()
@ -52,9 +48,9 @@ class UserListApi(BaseApi):
except KeyError as ex: except KeyError as ex:
raise errors.ValidationError('Field %r not found.' % ex.args[0]) raise errors.ValidationError('Field %r not found.' % ex.args[0])
user = self._user_service.create_user( user = users.create_user(name, password, email)
context.session, name, password, email)
try: try:
context.session.add(user)
context.session.commit() context.session.commit()
except sqlalchemy.exc.IntegrityError: except sqlalchemy.exc.IntegrityError:
raise errors.IntegrityError('User %r already exists.' % name) raise errors.IntegrityError('User %r already exists.' % name)
@ -63,26 +59,17 @@ class UserListApi(BaseApi):
class UserDetailApi(BaseApi): class UserDetailApi(BaseApi):
''' API for individual users. ''' ''' API for individual users. '''
def __init__(self, config, auth_service, password_service, user_service):
super().__init__()
self._available_access_ranks = config['service']['user_ranks']
self._name_regex = config['service']['user_name_regex']
self._password_regex = config['service']['password_regex']
self._password_service = password_service
self._auth_service = auth_service
self._user_service = user_service
def get(self, context, user_name): def get(self, context, user_name):
''' Retrieves an user. ''' ''' Retrieve an user. '''
self._auth_service.verify_privilege(context.user, 'users:view') auth.verify_privilege(context.user, 'users:view')
user = self._user_service.get_by_name(context.session, user_name) user = users.get_by_name(context.session, user_name)
if not user: if not user:
raise errors.NotFoundError('User %r not found.' % user_name) raise errors.NotFoundError('User %r not found.' % user_name)
return {'user': _serialize_user(context.user, user)} return {'user': _serialize_user(context.user, user)}
def put(self, context, user_name): def put(self, context, user_name):
''' Updates an existing user. ''' ''' Update an existing user. '''
user = self._user_service.get_by_name(context.session, user_name) user = users.get_by_name(context.session, user_name)
if not user: if not user:
raise errors.NotFoundError('User %r not found.' % user_name) raise errors.NotFoundError('User %r not found.' % user_name)
@ -92,53 +79,26 @@ class UserDetailApi(BaseApi):
infix = 'any' infix = 'any'
if 'name' in context.request: if 'name' in context.request:
self._auth_service.verify_privilege( auth.verify_privilege(context.user, 'users:edit:%s:name' % infix)
context.user, 'users:edit:%s:name' % infix) users.update_name(user, context.request['name'])
name = context.request['name'].strip()
if not re.match(self._name_regex, name):
raise errors.ValidationError(
'Name must satisfy regex %r.' % self._name_regex)
user.name = name
if 'password' in context.request: if 'password' in context.request:
password = context.request['password'] auth.verify_privilege(context.user, 'users:edit:%s:pass' % infix)
self._auth_service.verify_privilege( users.update_password(user, context.request['password'])
context.user, 'users:edit:%s:pass' % infix)
if not re.match(self._password_regex, password):
raise errors.ValidationError(
'Password must satisfy regex %r.' % self._password_regex)
user.password_salt = self._password_service.create_password()
user.password_hash = self._password_service.get_password_hash(
user.password_salt, password)
if 'email' in context.request: if 'email' in context.request:
self._auth_service.verify_privilege( auth.verify_privilege(context.user, 'users:edit:%s:email' % infix)
context.user, 'users:edit:%s:email' % infix) users.update_email(user, context.request['email'])
email = context.request['email'].strip() or None
if not util.is_valid_email(email):
raise errors.ValidationError(
'%r is not a vaild email address.' % email)
user.email = email
if 'accessRank' in context.request: if 'accessRank' in context.request:
self._auth_service.verify_privilege( auth.verify_privilege(context.user, 'users:edit:%s:rank' % infix)
context.user, 'users:edit:%s:rank' % infix) users.update_rank(user, context.request['accessRank'], context.user)
rank = context.request['accessRank'].strip()
if not rank in self._available_access_ranks:
raise errors.ValidationError(
'Bad access rank. Valid access ranks: %r' \
% self._available_access_ranks)
if self._available_access_ranks.index(context.user.access_rank) \
< self._available_access_ranks.index(rank):
raise errors.AuthError(
'Trying to set higher access rank than one has')
user.access_rank = rank
# TODO: avatar # TODO: avatar
try: try:
context.session.commit() context.session.commit()
except sqlalchemy.exc.IntegrityError: except sqlalchemy.exc.IntegrityError:
raise errors.IntegrityError('User %r already exists.' % name) raise errors.IntegrityError('User %r already exists.' % user.name)
return {'user': _serialize_user(context.user, user)} return {'user': _serialize_user(context.user, user)}

View file

@ -3,16 +3,11 @@
import falcon import falcon
import sqlalchemy import sqlalchemy
import sqlalchemy.orm import sqlalchemy.orm
import szurubooru.api from szurubooru import api, config, errors, middleware
import szurubooru.config from szurubooru.util import misc
import szurubooru.errors
import szurubooru.middleware
import szurubooru.services
import szurubooru.services.search
import szurubooru.util
class _CustomRequest(falcon.Request): class _CustomRequest(falcon.Request):
context_type = szurubooru.util.dotdict context_type = misc.dotdict
def get_param_as_string(self, name, required=False, store=None, default=None): def get_param_as_string(self, name, required=False, store=None, default=None):
params = self._params params = self._params
@ -45,47 +40,37 @@ def _on_not_found_error(ex, _request, _response, _params):
raise falcon.HTTPNotFound(title='Not found', description=str(ex)) raise falcon.HTTPNotFound(title='Not found', description=str(ex))
def create_app(): def create_app():
''' Creates a WSGI compatible App object. ''' ''' Create a WSGI compatible App object. '''
config = szurubooru.config.Config()
engine = sqlalchemy.create_engine( engine = sqlalchemy.create_engine(
'{schema}://{user}:{password}@{host}:{port}/{name}'.format( '{schema}://{user}:{password}@{host}:{port}/{name}'.format(
schema=config['database']['schema'], schema=config.config['database']['schema'],
user=config['database']['user'], user=config.config['database']['user'],
password=config['database']['pass'], password=config.config['database']['pass'],
host=config['database']['host'], host=config.config['database']['host'],
port=config['database']['port'], port=config.config['database']['port'],
name=config['database']['name'])) name=config.config['database']['name']))
session_maker = sqlalchemy.orm.sessionmaker(bind=engine) session_maker = sqlalchemy.orm.sessionmaker(bind=engine)
scoped_session = sqlalchemy.orm.scoped_session(session_maker) scoped_session = sqlalchemy.orm.scoped_session(session_maker)
# TODO: is there a better way?
mailer = szurubooru.services.Mailer(config)
password_service = szurubooru.services.PasswordService(config)
auth_service = szurubooru.services.AuthService(config, password_service)
user_service = szurubooru.services.UserService(config, password_service)
user_list_api = szurubooru.api.UserListApi(auth_service, user_service)
user_detail_api = szurubooru.api.UserDetailApi(
config, auth_service, password_service, user_service)
password_reminder_api = szurubooru.api.PasswordReminderApi(
config, mailer, user_service)
app = falcon.API( app = falcon.API(
request_type=_CustomRequest, request_type=_CustomRequest,
middleware=[ middleware=[
szurubooru.middleware.ImbueContext(), middleware.ImbueContext(),
szurubooru.middleware.RequireJson(), middleware.RequireJson(),
szurubooru.middleware.JsonTranslator(), middleware.JsonTranslator(),
szurubooru.middleware.DbSession(scoped_session), middleware.DbSession(scoped_session),
szurubooru.middleware.Authenticator(auth_service, user_service), middleware.Authenticator(),
]) ])
app.add_error_handler(szurubooru.errors.AuthError, _on_auth_error) user_list_api = api.UserListApi()
app.add_error_handler(szurubooru.errors.IntegrityError, _on_integrity_error) user_detail_api = api.UserDetailApi()
app.add_error_handler(szurubooru.errors.ValidationError, _on_validation_error) password_reminder_api = api.PasswordReminderApi()
app.add_error_handler(szurubooru.errors.SearchError, _on_search_error)
app.add_error_handler(szurubooru.errors.NotFoundError, _on_not_found_error) app.add_error_handler(errors.AuthError, _on_auth_error)
app.add_error_handler(errors.IntegrityError, _on_integrity_error)
app.add_error_handler(errors.ValidationError, _on_validation_error)
app.add_error_handler(errors.SearchError, _on_search_error)
app.add_error_handler(errors.NotFoundError, _on_not_found_error)
app.add_route('/users/', user_list_api) app.add_route('/users/', user_list_api)
app.add_route('/user/{user_name}', user_detail_api) app.add_route('/user/{user_name}', user_detail_api)

View file

@ -1,6 +1,6 @@
import os import os
import configobj import configobj
import szurubooru.errors from szurubooru import errors
class Config(object): class Config(object):
''' INI config parser and container. ''' ''' INI config parser and container. '''
@ -15,20 +15,22 @@ class Config(object):
def _validate(self): def _validate(self):
''' '''
Checks whether config.ini doesn't contain errors that might prove Check whether config.ini doesn't contain errors that might prove
lethal at runtime. lethal at runtime.
''' '''
all_ranks = self['service']['user_ranks'] all_ranks = self['service']['user_ranks']
for privilege, rank in self['privileges'].items(): for privilege, rank in self['privileges'].items():
if rank not in all_ranks: if rank not in all_ranks:
raise szurubooru.errors.ConfigError( raise errors.ConfigError(
'Rank %r for privilege %r is missing from user_ranks' % ( 'Rank %r for privilege %r is missing from user_ranks' % (
rank, privilege)) rank, privilege))
for rank in ['anonymous', 'admin', 'nobody']: for rank in ['anonymous', 'admin', 'nobody']:
if rank not in all_ranks: if rank not in all_ranks:
raise szurubooru.errors.ConfigError( raise errors.ConfigError(
'Fixed rank %r is missing from user_ranks' % rank) 'Fixed rank %r is missing from user_ranks' % rank)
if self['service']['default_user_rank'] not in all_ranks: if self['service']['default_user_rank'] not in all_ranks:
raise szurubooru.errors.ConfigError( raise errors.ConfigError(
'Default rank %r is missing from user_ranks' % ( 'Default rank %r is missing from user_ranks' % (
self['service']['default_user_rank'])) self['service']['default_user_rank']))
config = Config() # pylint: disable=invalid-name

View file

@ -0,0 +1,4 @@
''' Database layer. '''
from szurubooru.db.base import Base
from szurubooru.db.user import User

View file

@ -1,7 +1,5 @@
# pylint: disable=too-many-instance-attributes,too-few-public-methods
import sqlalchemy as sa import sqlalchemy as sa
from szurubooru.model.base import Base from szurubooru.db.base import Base
class User(Base): class User(Base):
__tablename__ = 'user' __tablename__ = 'user'

View file

@ -1,24 +1,20 @@
import base64 import base64
import falcon import falcon
from szurubooru import errors from szurubooru import db, errors
from szurubooru import model from szurubooru.util import auth, users
class Authenticator(object): class Authenticator(object):
''' '''
Authenticates every request and puts information on active user in the Authenticates every request and put information on active user in the
request context. request context.
''' '''
def __init__(self, auth_service, user_service):
self._auth_service = auth_service
self._user_service = user_service
def process_request(self, request, _response): def process_request(self, request, _response):
''' Executed before passing the request to the API. ''' ''' Bind the user to request. Update last login time if needed. '''
request.context.user = self._get_user(request) request.context.user = self._get_user(request)
if request.get_param_as_bool('bump-login') \ if request.get_param_as_bool('bump-login') \
and request.context.user.user_id: and request.context.user.user_id:
self._user_service.bump_login_time(request.context.user) users.bump_login_time(request.context.user)
request.context.session.commit() request.context.session.commit()
def _get_user(self, request): def _get_user(self, request):
@ -27,15 +23,12 @@ class Authenticator(object):
try: try:
auth_type, user_and_password = request.auth.split(' ', 1) auth_type, user_and_password = request.auth.split(' ', 1)
if auth_type.lower() != 'basic': if auth_type.lower() != 'basic':
raise falcon.HTTPBadRequest( raise falcon.HTTPBadRequest(
'Invalid authentication type', 'Invalid authentication type',
'Only basic authorization is supported.') 'Only basic authorization is supported.')
username, password = base64.decodebytes( username, password = base64.decodebytes(
user_and_password.encode('ascii')).decode('utf8').split(':') user_and_password.encode('ascii')).decode('utf8').split(':')
return self._authenticate( return self._authenticate(
request.context.session, username, password) request.context.session, username, password)
except ValueError as err: except ValueError as err:
@ -46,16 +39,16 @@ class Authenticator(object):
msg.format(request.auth, str(err))) msg.format(request.auth, str(err)))
def _authenticate(self, session, username, password): def _authenticate(self, session, username, password):
''' Tries to authenticate user. Throws AuthError for invalid users. ''' ''' Try to authenticate user. Throw AuthError for invalid users. '''
user = self._user_service.get_by_name(session, username) user = users.get_by_name(session, username)
if not user: if not user:
raise errors.AuthError('No such user.') raise errors.AuthError('No such user.')
if not self._auth_service.is_valid_password(user, password): if not auth.is_valid_password(user, password):
raise errors.AuthError('Invalid password.') raise errors.AuthError('Invalid password.')
return user return user
def _create_anonymous_user(self): def _create_anonymous_user(self):
user = model.User() user = db.User()
user.name = None user.name = None
user.access_rank = 'anonymous' user.access_rank = 'anonymous'
user.password = None user.password = None

View file

@ -5,12 +5,8 @@ class DbSession(object):
self._session_factory = session_factory self._session_factory = session_factory
def process_request(self, request, _response): def process_request(self, request, _response):
''' Executed before passing the request to the API. '''
request.context.session = self._session_factory() request.context.session = self._session_factory()
def process_response(self, request, _response, _resource): def process_response(self, request, _response, _resource):
''' # any commits need to happen explicitly in the API layer.
Executed before passing the response to falcon.
Any commits to database need to happen explicitly in the API layer.
'''
request.context.session.close() request.context.session.close()

View file

@ -16,7 +16,6 @@ class JsonTranslator(object):
''' '''
def process_request(self, request, _response): def process_request(self, request, _response):
''' Executed before passing the request to the API. '''
if request.content_length in (None, 0): if request.content_length in (None, 0):
return return
@ -36,7 +35,6 @@ class JsonTranslator(object):
'JSON was incorrect or not encoded as UTF-8.') 'JSON was incorrect or not encoded as UTF-8.')
def process_response(self, request, response, _resource): def process_response(self, request, response, _resource):
''' Executed before passing the response to falcon. '''
if 'result' not in request.context: if 'result' not in request.context:
return return
response.body = json.dumps( response.body = json.dumps(

View file

@ -4,7 +4,6 @@ class RequireJson(object):
''' Sanitizes requests so that only JSON is accepted. ''' ''' Sanitizes requests so that only JSON is accepted. '''
def process_request(self, request, _response): def process_request(self, request, _response):
''' Executed before passing the request to the API. '''
if not request.client_accepts_json: if not request.client_accepts_json:
raise falcon.HTTPNotAcceptable( raise falcon.HTTPNotAcceptable(
'This API only supports responses encoded as JSON.') 'This API only supports responses encoded as JSON.')

View file

@ -9,13 +9,13 @@ import logging.config
dir_to_self = os.path.dirname(os.path.realpath(__file__)) dir_to_self = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(dir_to_self, *[os.pardir] * 2)) sys.path.append(os.path.join(dir_to_self, *[os.pardir] * 2))
import szurubooru.model.base import szurubooru.db.base
import szurubooru.config import szurubooru.config
alembic_config = alembic.context.config alembic_config = alembic.context.config
logging.config.fileConfig(alembic_config.config_file_name) logging.config.fileConfig(alembic_config.config_file_name)
szuru_config = szurubooru.config.Config() szuru_config = szurubooru.config.config
alembic_config.set_main_option( alembic_config.set_main_option(
'sqlalchemy.url', 'sqlalchemy.url',
'{schema}://{user}:{password}@{host}:{port}/{name}'.format( '{schema}://{user}:{password}@{host}:{port}/{name}'.format(
@ -26,7 +26,7 @@ alembic_config.set_main_option(
port=szuru_config['database']['port'], port=szuru_config['database']['port'],
name=szuru_config['database']['name'])) name=szuru_config['database']['name']))
target_metadata = szurubooru.model.Base.metadata target_metadata = szurubooru.db.Base.metadata
def run_migrations_offline(): def run_migrations_offline():
''' '''

View file

@ -1,6 +0,0 @@
'''
Database models.
'''
from szurubooru.model.base import Base
from szurubooru.model.user import User

View file

@ -0,0 +1,4 @@
''' Search parsers and services. '''
from szurubooru.search.search_executor import SearchExecutor
from szurubooru.search.user_search_config import UserSearchConfig

View file

@ -1,11 +1,11 @@
import sqlalchemy import sqlalchemy
import szurubooru.errors import szurubooru.errors
from szurubooru import util from szurubooru.util import misc
from szurubooru.services.search import criteria from szurubooru.search import criteria
def _apply_criterion_to_column( def _apply_criterion_to_column(
column, query, criterion, allow_composite=True, allow_ranged=True): column, query, criterion, allow_composite=True, allow_ranged=True):
''' Decorates SQLAlchemy filter on given column using supplied criterion. ''' ''' Decorate SQLAlchemy filter on given column using supplied criterion. '''
if isinstance(criterion, criteria.StringSearchCriterion): if isinstance(criterion, criteria.StringSearchCriterion):
expr = column == criterion.value expr = column == criterion.value
if criterion.negated: if criterion.negated:
@ -32,11 +32,11 @@ def _apply_criterion_to_column(
def _apply_date_criterion_to_column(column, query, criterion): def _apply_date_criterion_to_column(column, query, criterion):
''' '''
Decorates SQLAlchemy filter on given column using supplied criterion. Decorate SQLAlchemy filter on given column using supplied criterion.
Parses the datetime inside the criterion. Parse the datetime inside the criterion.
''' '''
if isinstance(criterion, criteria.StringSearchCriterion): if isinstance(criterion, criteria.StringSearchCriterion):
min_date, max_date = util.parse_time_range(criterion.value) min_date, max_date = misc.parse_time_range(criterion.value)
expr = column.between(min_date, max_date) expr = column.between(min_date, max_date)
if criterion.negated: if criterion.negated:
expr = ~expr expr = ~expr
@ -44,7 +44,7 @@ def _apply_date_criterion_to_column(column, query, criterion):
elif isinstance(criterion, criteria.ArraySearchCriterion): elif isinstance(criterion, criteria.ArraySearchCriterion):
expr = sqlalchemy.sql.false() expr = sqlalchemy.sql.false()
for value in criterion.values: for value in criterion.values:
min_date, max_date = util.parse_time_range(value) min_date, max_date = misc.parse_time_range(value)
expr = expr | column.between(min_date, max_date) expr = expr | column.between(min_date, max_date)
if criterion.negated: if criterion.negated:
expr = ~expr expr = ~expr
@ -52,14 +52,14 @@ def _apply_date_criterion_to_column(column, query, criterion):
elif isinstance(criterion, criteria.RangedSearchCriterion): elif isinstance(criterion, criteria.RangedSearchCriterion):
assert criterion.min_value or criterion.max_value assert criterion.min_value or criterion.max_value
if criterion.min_value and criterion.max_value: if criterion.min_value and criterion.max_value:
min_date = util.parse_time_range(criterion.min_value)[0] min_date = misc.parse_time_range(criterion.min_value)[0]
max_date = util.parse_time_range(criterion.max_value)[1] max_date = misc.parse_time_range(criterion.max_value)[1]
expr = column.between(min_date, max_date) expr = column.between(min_date, max_date)
elif criterion.min_value: elif criterion.min_value:
min_date = util.parse_time_range(criterion.min_value)[0] min_date = misc.parse_time_range(criterion.min_value)[0]
expr = column >= min_date expr = column >= min_date
elif criterion.max_value: elif criterion.max_value:
max_date = util.parse_time_range(criterion.max_value)[1] max_date = misc.parse_time_range(criterion.max_value)[1]
expr = column <= max_date expr = column <= max_date
if criterion.negated: if criterion.negated:
expr = ~expr expr = ~expr

View file

@ -1,19 +1,17 @@
''' Exports SearchExecutor. '''
import re import re
import sqlalchemy import sqlalchemy
from szurubooru import errors from szurubooru import errors
from szurubooru.services.search import criteria from szurubooru.search import criteria
class SearchExecutor(object): class SearchExecutor(object):
ORDER_DESC = 1
ORDER_ASC = 2
''' '''
Class for search parsing and execution. Handles plaintext parsing and Class for search parsing and execution. Handles plaintext parsing and
delegates sqlalchemy filter decoration to SearchConfig instances. delegates sqlalchemy filter decoration to SearchConfig instances.
''' '''
ORDER_DESC = 1
ORDER_ASC = 2
def __init__(self, search_config): def __init__(self, search_config):
self.page_size = 100 self.page_size = 100
self._search_config = search_config self._search_config = search_config

View file

@ -0,0 +1,42 @@
from sqlalchemy.sql.expression import func
from szurubooru import db
from szurubooru.search.base_search_config import BaseSearchConfig
class UserSearchConfig(BaseSearchConfig):
''' Executes searches related to the users. '''
def create_query(self, session):
return session.query(db.User)
@property
def anonymous_filter(self):
return self._create_basic_filter(db.User.name, allow_ranged=False)
@property
def special_filters(self):
return {}
@property
def named_filters(self):
return {
'name': self._create_basic_filter(db.User.name, allow_ranged=False),
'creation_date': self._create_date_filter(db.User.creation_time),
'creation_time': self._create_date_filter(db.User.creation_time),
'last_login_date': self._create_date_filter(db.User.last_login_time),
'last_login_time': self._create_date_filter(db.User.last_login_time),
'login_date': self._create_date_filter(db.User.last_login_time),
'login_time': self._create_date_filter(db.User.last_login_time),
}
@property
def order_columns(self):
return {
'random': func.random(),
'name': db.User.name,
'creation_date': db.User.creation_time,
'creation_time': db.User.creation_time,
'last_login_date': db.User.last_login_time,
'last_login_time': db.User.last_login_time,
'login_date': db.User.last_login_time,
'login_time': db.User.last_login_time,
}

View file

@ -1,9 +0,0 @@
'''
Middle layer between REST API and database.
All the business logic goes here.
'''
from szurubooru.services.mailer import Mailer
from szurubooru.services.auth_service import AuthService
from szurubooru.services.user_service import UserService
from szurubooru.services.password_service import PasswordService

View file

@ -1,28 +0,0 @@
from szurubooru import errors
class AuthService(object):
def __init__(self, config, password_service):
self._config = config
self._password_service = password_service
def is_valid_password(self, user, password):
''' Returns whether the given password for a given user is valid. '''
salt, valid_hash = user.password_salt, user.password_hash
possible_hashes = [
self._password_service.get_password_hash(salt, password),
self._password_service.get_legacy_password_hash(salt, password)
]
return valid_hash in possible_hashes
def verify_privilege(self, user, privilege_name):
'''
Throws an AuthError if the given user doesn't have given privilege.
'''
all_ranks = self._config['service']['user_ranks']
assert privilege_name in self._config['privileges']
assert user.access_rank in all_ranks
minimal_rank = self._config['privileges'][privilege_name]
good_ranks = all_ranks[all_ranks.index(minimal_rank):]
if user.access_rank not in good_ranks:
raise errors.AuthError('Insufficient privileges to do this.')

View file

@ -1,19 +0,0 @@
import smtplib
import email.mime.text
class Mailer(object):
def __init__(self, config):
self._config = config
def send(self, sender, recipient, subject, body):
msg = email.mime.text.MIMEText(body)
msg['Subject'] = subject
msg['From'] = sender
msg['To'] = recipient
smtp = smtplib.SMTP(
self._config['smtp']['host'],
int(self._config['smtp']['port']))
smtp.login(self._config['smtp']['user'], self._config['smtp']['pass'])
smtp.send_message(msg)
smtp.quit()

View file

@ -1,34 +0,0 @@
import hashlib
import random
class PasswordService(object):
''' Stateless utilities for passwords '''
def __init__(self, config):
self._config = config
def get_password_hash(self, salt, password):
''' Retrieves new-style password hash. '''
digest = hashlib.sha256()
digest.update(self._config['basic']['secret'].encode('utf8'))
digest.update(salt.encode('utf8'))
digest.update(password.encode('utf8'))
return digest.hexdigest()
def get_legacy_password_hash(self, salt, password):
''' Retrieves old-style password hash. '''
digest = hashlib.sha1()
digest.update(b'1A2/$_4xVa')
digest.update(salt.encode('utf8'))
digest.update(password.encode('utf8'))
return digest.hexdigest()
def create_password(self):
''' Creates an easy-to-remember password. '''
alphabet = {
'c': list('bcdfghijklmnpqrstvwxyz'),
'v': list('aeiou'),
'n': list('0123456789'),
}
pattern = 'cvcvnncvcv'
return ''.join(random.choice(alphabet[l]) for l in list(pattern))

View file

@ -1,2 +0,0 @@
from szurubooru.services.search.search_executor import SearchExecutor
from szurubooru.services.search.user_search_config import UserSearchConfig

View file

@ -1,44 +0,0 @@
''' Exports UserSearchConfig. '''
from sqlalchemy.sql.expression import func
from szurubooru.model import User
from szurubooru.services.search.base_search_config import BaseSearchConfig
class UserSearchConfig(BaseSearchConfig):
''' Executes searches related to the users. '''
def create_query(self, session):
return session.query(User)
@property
def anonymous_filter(self):
return self._create_basic_filter(User.name, allow_ranged=False)
@property
def special_filters(self):
return {}
@property
def named_filters(self):
return {
'name': self._create_basic_filter(User.name, allow_ranged=False),
'creation_date': self._create_date_filter(User.creation_time),
'creation_time': self._create_date_filter(User.creation_time),
'last_login_date': self._create_date_filter(User.last_login_time),
'last_login_time': self._create_date_filter(User.last_login_time),
'login_date': self._create_date_filter(User.last_login_time),
'login_time': self._create_date_filter(User.last_login_time),
}
@property
def order_columns(self):
return {
'random': func.random(),
'name': User.name,
'creation_date': User.creation_time,
'creation_time': User.creation_time,
'last_login_date': User.last_login_time,
'last_login_time': User.last_login_time,
'login_date': User.last_login_time,
'login_time': User.last_login_time,
}

View file

@ -1,56 +0,0 @@
import re
from datetime import datetime
from szurubooru import errors
from szurubooru import model
from szurubooru import util
class UserService(object):
''' User management '''
def __init__(self, config, password_service):
self._config = config
self._password_service = password_service
self._name_regex = self._config['service']['user_name_regex']
self._password_regex = self._config['service']['password_regex']
def create_user(self, session, name, password, email):
''' Creates an user with given parameters and returns it. '''
if not re.match(self._name_regex, name):
raise errors.ValidationError(
'Name must satisfy regex %r.' % self._name_regex)
if not re.match(self._password_regex, password):
raise errors.ValidationError(
'Password must satisfy regex %r.' % self._password_regex)
if not util.is_valid_email(email):
raise errors.ValidationError(
'%r is not a vaild email address.' % email)
user = model.User()
user.name = name
user.password_salt = self._password_service.create_password()
user.password_hash = self._password_service.get_password_hash(
user.password_salt, password)
user.email = email or None
user.access_rank = self._config['service']['default_user_rank']
user.creation_time = datetime.now()
user.avatar_style = model.User.AVATAR_GRAVATAR
session.add(user)
return user
def bump_login_time(self, user):
user.last_login_time = datetime.now()
def reset_password(self, user):
password = self._password_service.create_password()
user.password_salt = self._password_service.create_password()
user.password_hash = self._password_service.get_password_hash(
user.password_salt, password)
return password
def get_by_name(self, session, name):
''' Retrieves an user by its name. '''
return session.query(model.User).filter_by(name=name).first()

View file

@ -1,15 +1,12 @@
from datetime import datetime from datetime import datetime
import szurubooru.services from szurubooru import api, db, errors, config
from szurubooru.api.user_api import UserDetailApi from szurubooru.util import auth, misc
from szurubooru.errors import AuthError, ValidationError
from szurubooru.model.user import User
from szurubooru.tests.database_test_case import DatabaseTestCase from szurubooru.tests.database_test_case import DatabaseTestCase
from szurubooru.util import dotdict
class TestUserDetailApi(DatabaseTestCase): class TestUserDetailApi(DatabaseTestCase):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
config = { config_mock = {
'basic': { 'basic': {
'secret': '', 'secret': '',
}, },
@ -30,18 +27,18 @@ class TestUserDetailApi(DatabaseTestCase):
'users:edit:any:rank': 'admin', 'users:edit:any:rank': 'admin',
} }
} }
password_service = szurubooru.services.PasswordService(config) self.old_config = config.config
auth_service = szurubooru.services.AuthService(config, password_service) config.config = config_mock
user_service = szurubooru.services.UserService(config, password_service) self.api = api.UserDetailApi()
self.auth_service = auth_service self.context = misc.dotdict()
self.api = UserDetailApi(
config, auth_service, password_service, user_service)
self.context = dotdict()
self.context.session = self.session self.context.session = self.session
self.context.request = {} self.context.request = {}
def tearDown(self):
config.config = self.old_config
def _create_user(self, name, rank='admin'): def _create_user(self, name, rank='admin'):
user = User() user = db.User()
user.name = name user.name = name
user.password = 'dummy' user.password = 'dummy'
user.password_salt = 'dummy' user.password_salt = 'dummy'
@ -49,7 +46,7 @@ class TestUserDetailApi(DatabaseTestCase):
user.email = 'dummy' user.email = 'dummy'
user.access_rank = rank user.access_rank = rank
user.creation_time = datetime.now() user.creation_time = datetime.now()
user.avatar_style = User.AVATAR_GRAVATAR user.avatar_style = db.User.AVATAR_GRAVATAR
return user return user
def test_updating_nothing(self): def test_updating_nothing(self):
@ -57,7 +54,7 @@ class TestUserDetailApi(DatabaseTestCase):
self.session.add(admin_user) self.session.add(admin_user)
self.context.user = admin_user self.context.user = admin_user
self.api.put(self.context, 'u1') self.api.put(self.context, 'u1')
admin_user = self.session.query(User).filter_by(name='u1').one() admin_user = self.session.query(db.User).filter_by(name='u1').one()
self.assertEqual(admin_user.name, 'u1') self.assertEqual(admin_user.name, 'u1')
self.assertEqual(admin_user.email, 'dummy') self.assertEqual(admin_user.email, 'dummy')
self.assertEqual(admin_user.access_rank, 'admin') self.assertEqual(admin_user.access_rank, 'admin')
@ -69,16 +66,16 @@ class TestUserDetailApi(DatabaseTestCase):
self.context.request = { self.context.request = {
'name': 'chewie', 'name': 'chewie',
'email': 'asd@asd.asd', 'email': 'asd@asd.asd',
'password': 'valid', 'password': 'oks',
'accessRank': 'mod', 'accessRank': 'mod',
} }
self.api.put(self.context, 'u1') self.api.put(self.context, 'u1')
admin_user = self.session.query(User).filter_by(name='chewie').one() admin_user = self.session.query(db.User).filter_by(name='chewie').one()
self.assertEqual(admin_user.name, 'chewie') self.assertEqual(admin_user.name, 'chewie')
self.assertEqual(admin_user.email, 'asd@asd.asd') self.assertEqual(admin_user.email, 'asd@asd.asd')
self.assertEqual(admin_user.access_rank, 'mod') self.assertEqual(admin_user.access_rank, 'mod')
self.assertTrue(self.auth_service.is_valid_password(admin_user, 'valid')) self.assertTrue(auth.is_valid_password(admin_user, 'oks'))
self.assertFalse(self.auth_service.is_valid_password(admin_user, 'invalid')) self.assertFalse(auth.is_valid_password(admin_user, 'invalid'))
def test_removing_email(self): def test_removing_email(self):
admin_user = self._create_user('u1', 'admin') admin_user = self._create_user('u1', 'admin')
@ -86,7 +83,7 @@ class TestUserDetailApi(DatabaseTestCase):
self.context.user = admin_user self.context.user = admin_user
self.context.request = {'email': ''} self.context.request = {'email': ''}
self.api.put(self.context, 'u1') self.api.put(self.context, 'u1')
admin_user = self.session.query(User).filter_by(name='u1').one() admin_user = self.session.query(db.User).filter_by(name='u1').one()
self.assertEqual(admin_user.email, None) self.assertEqual(admin_user.email, None)
def test_invalid_inputs(self): def test_invalid_inputs(self):
@ -95,16 +92,16 @@ class TestUserDetailApi(DatabaseTestCase):
self.context.user = admin_user self.context.user = admin_user
self.context.request = {'name': '.'} self.context.request = {'name': '.'}
self.assertRaises( self.assertRaises(
ValidationError, self.api.put, self.context, 'u1') errors.ValidationError, self.api.put, self.context, 'u1')
self.context.request = {'password': '.'} self.context.request = {'password': '.'}
self.assertRaises( self.assertRaises(
ValidationError, self.api.put, self.context, 'u1') errors.ValidationError, self.api.put, self.context, 'u1')
self.context.request = {'accessRank': '.'} self.context.request = {'accessRank': '.'}
self.assertRaises( self.assertRaises(
ValidationError, self.api.put, self.context, 'u1') errors.ValidationError, self.api.put, self.context, 'u1')
self.context.request = {'email': '.'} self.context.request = {'email': '.'}
self.assertRaises( self.assertRaises(
ValidationError, self.api.put, self.context, 'u1') errors.ValidationError, self.api.put, self.context, 'u1')
def test_user_trying_to_update_someone_else(self): def test_user_trying_to_update_someone_else(self):
user1 = self._create_user('u1', 'regular_user') user1 = self._create_user('u1', 'regular_user')
@ -118,7 +115,7 @@ class TestUserDetailApi(DatabaseTestCase):
{'password': 'whatever'}]: {'password': 'whatever'}]:
self.context.request = request self.context.request = request
self.assertRaises( self.assertRaises(
AuthError, self.api.put, self.context, user2.name) errors.AuthError, self.api.put, self.context, user2.name)
def test_user_trying_to_become_someone_else(self): def test_user_trying_to_become_someone_else(self):
user1 = self._create_user('u1', 'regular_user') user1 = self._create_user('u1', 'regular_user')
@ -127,7 +124,7 @@ class TestUserDetailApi(DatabaseTestCase):
self.context.user = user1 self.context.user = user1
self.context.request = {'name': 'u2'} self.context.request = {'name': 'u2'}
self.assertRaises( self.assertRaises(
ValidationError, self.api.put, self.context, 'u1') errors.ValidationError, self.api.put, self.context, 'u1')
def test_mods_trying_to_become_admin(self): def test_mods_trying_to_become_admin(self):
user1 = self._create_user('u1', 'mod') user1 = self._create_user('u1', 'mod')
@ -136,6 +133,6 @@ class TestUserDetailApi(DatabaseTestCase):
self.context.user = user1 self.context.user = user1
self.context.request = {'accessRank': 'admin'} self.context.request = {'accessRank': 'admin'}
self.assertRaises( self.assertRaises(
AuthError, self.api.put, self.context, user1.name) errors.AuthError, self.api.put, self.context, user1.name)
self.assertRaises( self.assertRaises(
AuthError, self.api.put, self.context, user2.name) errors.AuthError, self.api.put, self.context, user2.name)

View file

@ -1,11 +1,11 @@
import unittest import unittest
import sqlalchemy import sqlalchemy
from szurubooru.model import Base from szurubooru import db
class DatabaseTestCase(unittest.TestCase): class DatabaseTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
engine = sqlalchemy.create_engine('sqlite:///:memory:') engine = sqlalchemy.create_engine('sqlite:///:memory:')
session_maker = sqlalchemy.orm.sessionmaker(bind=engine) session_maker = sqlalchemy.orm.sessionmaker(bind=engine)
self.session = sqlalchemy.orm.scoped_session(session_maker) self.session = sqlalchemy.orm.scoped_session(session_maker)
Base.query = self.session.query_property() db.Base.query = self.session.query_property()
Base.metadata.create_all(bind=engine) db.Base.metadata.create_all(bind=engine)

View file

@ -1,7 +1,5 @@
from datetime import datetime from datetime import datetime
from szurubooru import errors from szurubooru import db, errors, search
from szurubooru import model
from szurubooru.services import search
from szurubooru.tests.database_test_case import DatabaseTestCase from szurubooru.tests.database_test_case import DatabaseTestCase
class TestUserSearchExecutor(DatabaseTestCase): class TestUserSearchExecutor(DatabaseTestCase):
@ -11,7 +9,7 @@ class TestUserSearchExecutor(DatabaseTestCase):
self.executor = search.SearchExecutor(self.search_config) self.executor = search.SearchExecutor(self.search_config)
def _create_user(self, name): def _create_user(self, name):
user = model.User() user = db.User()
user.name = name user.name = name
user.password = 'dummy' user.password = 'dummy'
user.password_salt = 'dummy' user.password_salt = 'dummy'
@ -19,7 +17,7 @@ class TestUserSearchExecutor(DatabaseTestCase):
user.email = 'dummy' user.email = 'dummy'
user.access_rank = 'dummy' user.access_rank = 'dummy'
user.creation_time = datetime.now() user.creation_time = datetime.now()
user.avatar_style = model.User.AVATAR_GRAVATAR user.avatar_style = db.User.AVATAR_GRAVATAR
return user return user
def _test(self, query, page, expected_count, expected_user_names): def _test(self, query, page, expected_count, expected_user_names):

View file

@ -1,8 +1,7 @@
import unittest import unittest
from datetime import datetime from datetime import datetime
import szurubooru.util from szurubooru import errors
from szurubooru.util import parse_time_range from szurubooru.util import misc
from szurubooru.errors import ValidationError
class FakeDatetime(datetime): class FakeDatetime(datetime):
@staticmethod @staticmethod
@ -11,33 +10,33 @@ class FakeDatetime(datetime):
class TestParseTime(unittest.TestCase): class TestParseTime(unittest.TestCase):
def test_empty(self): def test_empty(self):
self.assertRaises(ValidationError, parse_time_range, '') self.assertRaises(errors.ValidationError, misc.parse_time_range, '')
def test_today(self): def test_today(self):
szurubooru.util.datetime.datetime = FakeDatetime misc.datetime.datetime = FakeDatetime
date_min, date_max = parse_time_range('today') date_min, date_max = misc.parse_time_range('today')
self.assertEqual(date_min, datetime(1997, 1, 2, 0, 0, 0)) self.assertEqual(date_min, datetime(1997, 1, 2, 0, 0, 0))
self.assertEqual(date_max, datetime(1997, 1, 2, 23, 59, 59)) self.assertEqual(date_max, datetime(1997, 1, 2, 23, 59, 59))
def test_yesterday(self): def test_yesterday(self):
szurubooru.util.datetime.datetime = FakeDatetime misc.datetime.datetime = FakeDatetime
date_min, date_max = parse_time_range('yesterday') date_min, date_max = misc.parse_time_range('yesterday')
self.assertEqual(date_min, datetime(1997, 1, 1, 0, 0, 0)) self.assertEqual(date_min, datetime(1997, 1, 1, 0, 0, 0))
self.assertEqual(date_max, datetime(1997, 1, 1, 23, 59, 59)) self.assertEqual(date_max, datetime(1997, 1, 1, 23, 59, 59))
def test_year(self): def test_year(self):
date_min, date_max = parse_time_range('1999') date_min, date_max = misc.parse_time_range('1999')
self.assertEqual(date_min, datetime(1999, 1, 1, 0, 0, 0)) self.assertEqual(date_min, datetime(1999, 1, 1, 0, 0, 0))
self.assertEqual(date_max, datetime(1999, 12, 31, 23, 59, 59)) self.assertEqual(date_max, datetime(1999, 12, 31, 23, 59, 59))
def test_month(self): def test_month(self):
for text in ['1999-2', '1999-02']: for text in ['1999-2', '1999-02']:
date_min, date_max = parse_time_range(text) date_min, date_max = misc.parse_time_range(text)
self.assertEqual(date_min, datetime(1999, 2, 1, 0, 0, 0)) self.assertEqual(date_min, datetime(1999, 2, 1, 0, 0, 0))
self.assertEqual(date_max, datetime(1999, 2, 28, 23, 59, 59)) self.assertEqual(date_max, datetime(1999, 2, 28, 23, 59, 59))
def test_day(self): def test_day(self):
for text in ['1999-2-6', '1999-02-6', '1999-2-06', '1999-02-06']: for text in ['1999-2-6', '1999-02-6', '1999-2-06', '1999-02-06']:
date_min, date_max = parse_time_range(text) date_min, date_max = misc.parse_time_range(text)
self.assertEqual(date_min, datetime(1999, 2, 6, 0, 0, 0)) self.assertEqual(date_min, datetime(1999, 2, 6, 0, 0, 0))
self.assertEqual(date_max, datetime(1999, 2, 6, 23, 59, 59)) self.assertEqual(date_max, datetime(1999, 2, 6, 23, 59, 59))

View file

@ -0,0 +1 @@
''' Cool functions. '''

View file

@ -0,0 +1,59 @@
import hashlib
import random
from szurubooru import config
from szurubooru import errors
def get_password_hash(salt, password):
''' Retrieve new-style password hash. '''
digest = hashlib.sha256()
digest.update(config.config['basic']['secret'].encode('utf8'))
digest.update(salt.encode('utf8'))
digest.update(password.encode('utf8'))
return digest.hexdigest()
def get_legacy_password_hash(salt, password):
''' Retrieve old-style password hash. '''
digest = hashlib.sha1()
digest.update(b'1A2/$_4xVa')
digest.update(salt.encode('utf8'))
digest.update(password.encode('utf8'))
return digest.hexdigest()
def create_password():
''' Create an easy-to-remember password. '''
alphabet = {
'c': list('bcdfghijklmnpqrstvwxyz'),
'v': list('aeiou'),
'n': list('0123456789'),
}
pattern = 'cvcvnncvcv'
return ''.join(random.choice(alphabet[l]) for l in list(pattern))
def is_valid_password(user, password):
''' Return whether the given password for a given user is valid. '''
salt, valid_hash = user.password_salt, user.password_hash
possible_hashes = [
get_password_hash(salt, password),
get_legacy_password_hash(salt, password)
]
return valid_hash in possible_hashes
def verify_privilege(user, privilege_name):
'''
Throw an AuthError if the given user doesn't have given privilege.
'''
all_ranks = config.config['service']['user_ranks']
assert privilege_name in config.config['privileges']
assert user.access_rank in all_ranks
minimal_rank = config.config['privileges'][privilege_name]
good_ranks = all_ranks[all_ranks.index(minimal_rank):]
if user.access_rank not in good_ranks:
raise errors.AuthError('Insufficient privileges to do this.')
def generate_authentication_token(user):
''' Generate nonguessable challenge (e.g. links in password reminder). '''
digest = hashlib.sha256()
digest.update(config.config['basic']['secret'].encode('utf8'))
digest.update(user.password_salt.encode('utf8'))
return digest.hexdigest()

View file

@ -0,0 +1,15 @@
import smtplib
import email.mime.text
from szurubooru import config
def send_mail(sender, recipient, subject, body):
msg = email.mime.text.MIMEText(body)
msg['Subject'] = subject
msg['From'] = sender
msg['To'] = recipient
smtp = smtplib.SMTP(
config.config['smtp']['host'], int(config.config['smtp']['port']))
smtp.login(config.config['smtp']['user'], config.config['smtp']['pass'])
smtp.send_message(msg)
smtp.quit()

View file

@ -3,7 +3,7 @@ import re
from szurubooru.errors import ValidationError from szurubooru.errors import ValidationError
def is_valid_email(email): def is_valid_email(email):
''' Validates given email address. ''' ''' Return whether given email address is valid or empty. '''
return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email) return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email)
class dotdict(dict): # pylint: disable=invalid-name class dotdict(dict): # pylint: disable=invalid-name
@ -14,7 +14,7 @@ class dotdict(dict): # pylint: disable=invalid-name
__delattr__ = dict.__delitem__ __delattr__ = dict.__delitem__
def parse_time_range(value, timezone=datetime.timezone(datetime.timedelta())): def parse_time_range(value, timezone=datetime.timezone(datetime.timedelta())):
''' Returns tuple containing min/max time for given text representation. ''' ''' Return tuple containing min/max time for given text representation. '''
one_day = datetime.timedelta(days=1) one_day = datetime.timedelta(days=1)
one_second = datetime.timedelta(seconds=1) one_second = datetime.timedelta(seconds=1)

View file

@ -0,0 +1,67 @@
import re
from datetime import datetime
from szurubooru import config, db, errors
from szurubooru.util import auth, misc
def create_user(name, password, email):
''' Create an user with given parameters and returns it. '''
user = db.User()
update_name(user, name)
update_password(user, password)
update_email(user, email)
user.access_rank = config.config['service']['default_user_rank']
user.creation_time = datetime.now()
user.avatar_style = db.User.AVATAR_GRAVATAR
return user
def update_name(user, name):
''' Validate and update user's name. '''
name = name.strip()
name_regex = config.config['service']['user_name_regex']
if not re.match(name_regex, name):
raise errors.ValidationError(
'Name must satisfy regex %r.' % name_regex)
user.name = name
def update_password(user, password):
''' Validate and update user's password. '''
password_regex = config.config['service']['password_regex']
if not re.match(password_regex, password):
raise errors.ValidationError(
'Password must satisfy regex %r.' % password_regex)
user.password_salt = auth.create_password()
user.password_hash = auth.get_password_hash(user.password_salt, password)
def update_email(user, email):
''' Validate and update user's email. '''
email = email.strip() or None
if not misc.is_valid_email(email):
raise errors.ValidationError(
'%r is not a vaild email address.' % email)
user.email = email
def update_rank(user, rank, authenticated_user):
rank = rank.strip()
available_access_ranks = config.config['service']['user_ranks']
if not rank in available_access_ranks:
raise errors.ValidationError(
'Bad access rank. Valid access ranks: %r' % available_access_ranks)
if available_access_ranks.index(authenticated_user.access_rank) \
< available_access_ranks.index(rank):
raise errors.AuthError('Trying to set higher access rank than one has')
user.access_rank = rank
def bump_login_time(user):
''' Update user's login time to current date. '''
user.last_login_time = datetime.now()
def reset_password(user):
''' Reset password for an user. '''
password = auth.create_password()
user.password_salt = auth.create_password()
user.password_hash = auth.get_password_hash(user.password_salt, password)
return password
def get_by_name(session, name):
''' Retrieve an user by its name. '''
return session.query(db.User).filter_by(name=name).first()