server/general: be more pythonic
This commit is contained in:
parent
7f4708c696
commit
219ab7c2c3
36 changed files with 335 additions and 428 deletions
|
@ -1,5 +1,5 @@
|
|||
import hashlib
|
||||
from szurubooru import errors
|
||||
from szurubooru import config, errors
|
||||
from szurubooru.util import auth, mailer, users
|
||||
from szurubooru.api.base_api import BaseApi
|
||||
|
||||
MAIL_SUBJECT = 'Password reset for {name}'
|
||||
|
@ -9,45 +9,35 @@ MAIL_BODY = \
|
|||
'Otherwise, please ignore this email.'
|
||||
|
||||
class PasswordReminderApi(BaseApi):
|
||||
def __init__(self, config, mailer, user_service):
|
||||
super().__init__()
|
||||
self._config = config
|
||||
self._mailer = mailer
|
||||
self._user_service = user_service
|
||||
|
||||
def get(self, context, user_name):
|
||||
user = self._user_service.get_by_name(context.session, user_name)
|
||||
''' Send a mail with secure token to the correlated user. '''
|
||||
user = users.get_by_name(context.session, user_name)
|
||||
if not user:
|
||||
raise errors.NotFoundError('User %r not found.' % user_name)
|
||||
if not user.email:
|
||||
raise errors.ValidationError(
|
||||
'User %r hasn\'t supplied email. Cannot reset password.' % user_name)
|
||||
token = self._generate_authentication_token(user)
|
||||
token = auth.generate_authentication_token(user)
|
||||
url = '%s/password-reset/%s' % (
|
||||
self._config['basic']['base_url'].rstrip('/'), token)
|
||||
self._mailer.send(
|
||||
'noreply@%s' % self._config['basic']['name'],
|
||||
config.config['basic']['base_url'].rstrip('/'), token)
|
||||
mailer.send_mail(
|
||||
'noreply@%s' % config.config['basic']['name'],
|
||||
user.email,
|
||||
MAIL_SUBJECT.format(name=self._config['basic']['name']),
|
||||
MAIL_BODY.format(name=self._config['basic']['name'], url=url))
|
||||
MAIL_SUBJECT.format(name=config.config['basic']['name']),
|
||||
MAIL_BODY.format(name=config.config['basic']['name'], url=url))
|
||||
return {}
|
||||
|
||||
def post(self, context, user_name):
|
||||
user = self._user_service.get_by_name(context.session, user_name)
|
||||
''' Verify token from mail, generate a new password and return it. '''
|
||||
user = users.get_by_name(context.session, user_name)
|
||||
if not user:
|
||||
raise errors.NotFoundError('User %r not found.' % user_name)
|
||||
good_token = self._generate_authentication_token(user)
|
||||
good_token = auth.generate_authentication_token(user)
|
||||
if not 'token' in context.request:
|
||||
raise errors.ValidationError('Missing password reset token.')
|
||||
token = context.request['token']
|
||||
if token != good_token:
|
||||
raise errors.ValidationError('Invalid password reset token.')
|
||||
new_password = self._user_service.reset_password(user)
|
||||
new_password = users.reset_password(user)
|
||||
context.session.commit()
|
||||
return {'password': new_password}
|
||||
|
||||
def _generate_authentication_token(self, user):
|
||||
digest = hashlib.sha256()
|
||||
digest.update(self._config['basic']['secret'].encode('utf8'))
|
||||
digest.update(user.password_salt.encode('utf8'))
|
||||
return digest.hexdigest()
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
import re
|
||||
import sqlalchemy
|
||||
from szurubooru import errors
|
||||
from szurubooru import util
|
||||
from szurubooru import errors, search
|
||||
from szurubooru.util import auth, users
|
||||
from szurubooru.api.base_api import BaseApi
|
||||
from szurubooru.services import search
|
||||
|
||||
def _serialize_user(authenticated_user, user):
|
||||
ret = {
|
||||
|
@ -21,29 +19,27 @@ def _serialize_user(authenticated_user, user):
|
|||
class UserListApi(BaseApi):
|
||||
''' API for lists of users. '''
|
||||
|
||||
def __init__(self, auth_service, user_service):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._auth_service = auth_service
|
||||
self._user_service = user_service
|
||||
self._search_executor = search.SearchExecutor(search.UserSearchConfig())
|
||||
|
||||
def get(self, context):
|
||||
''' Retrieves a list of users. '''
|
||||
self._auth_service.verify_privilege(context.user, 'users:list')
|
||||
''' Retrieve a list of users. '''
|
||||
auth.verify_privilege(context.user, 'users:list')
|
||||
query = context.get_param_as_string('query')
|
||||
page = context.get_param_as_int('page', 1)
|
||||
count, users = self._search_executor.execute(context.session, query, page)
|
||||
count, user_list = self._search_executor.execute(context.session, query, page)
|
||||
return {
|
||||
'query': query,
|
||||
'page': page,
|
||||
'page_size': self._search_executor.page_size,
|
||||
'total': count,
|
||||
'users': [_serialize_user(context.user, user) for user in users],
|
||||
'users': [_serialize_user(context.user, user) for user in user_list],
|
||||
}
|
||||
|
||||
def post(self, context):
|
||||
''' Creates a new user. '''
|
||||
self._auth_service.verify_privilege(context.user, 'users:create')
|
||||
''' Create a new user. '''
|
||||
auth.verify_privilege(context.user, 'users:create')
|
||||
|
||||
try:
|
||||
name = context.request['name'].strip()
|
||||
|
@ -52,9 +48,9 @@ class UserListApi(BaseApi):
|
|||
except KeyError as ex:
|
||||
raise errors.ValidationError('Field %r not found.' % ex.args[0])
|
||||
|
||||
user = self._user_service.create_user(
|
||||
context.session, name, password, email)
|
||||
user = users.create_user(name, password, email)
|
||||
try:
|
||||
context.session.add(user)
|
||||
context.session.commit()
|
||||
except sqlalchemy.exc.IntegrityError:
|
||||
raise errors.IntegrityError('User %r already exists.' % name)
|
||||
|
@ -63,26 +59,17 @@ class UserListApi(BaseApi):
|
|||
class UserDetailApi(BaseApi):
|
||||
''' API for individual users. '''
|
||||
|
||||
def __init__(self, config, auth_service, password_service, user_service):
|
||||
super().__init__()
|
||||
self._available_access_ranks = config['service']['user_ranks']
|
||||
self._name_regex = config['service']['user_name_regex']
|
||||
self._password_regex = config['service']['password_regex']
|
||||
self._password_service = password_service
|
||||
self._auth_service = auth_service
|
||||
self._user_service = user_service
|
||||
|
||||
def get(self, context, user_name):
|
||||
''' Retrieves an user. '''
|
||||
self._auth_service.verify_privilege(context.user, 'users:view')
|
||||
user = self._user_service.get_by_name(context.session, user_name)
|
||||
''' Retrieve an user. '''
|
||||
auth.verify_privilege(context.user, 'users:view')
|
||||
user = users.get_by_name(context.session, user_name)
|
||||
if not user:
|
||||
raise errors.NotFoundError('User %r not found.' % user_name)
|
||||
return {'user': _serialize_user(context.user, user)}
|
||||
|
||||
def put(self, context, user_name):
|
||||
''' Updates an existing user. '''
|
||||
user = self._user_service.get_by_name(context.session, user_name)
|
||||
''' Update an existing user. '''
|
||||
user = users.get_by_name(context.session, user_name)
|
||||
if not user:
|
||||
raise errors.NotFoundError('User %r not found.' % user_name)
|
||||
|
||||
|
@ -92,53 +79,26 @@ class UserDetailApi(BaseApi):
|
|||
infix = 'any'
|
||||
|
||||
if 'name' in context.request:
|
||||
self._auth_service.verify_privilege(
|
||||
context.user, 'users:edit:%s:name' % infix)
|
||||
name = context.request['name'].strip()
|
||||
if not re.match(self._name_regex, name):
|
||||
raise errors.ValidationError(
|
||||
'Name must satisfy regex %r.' % self._name_regex)
|
||||
user.name = name
|
||||
auth.verify_privilege(context.user, 'users:edit:%s:name' % infix)
|
||||
users.update_name(user, context.request['name'])
|
||||
|
||||
if 'password' in context.request:
|
||||
password = context.request['password']
|
||||
self._auth_service.verify_privilege(
|
||||
context.user, 'users:edit:%s:pass' % infix)
|
||||
if not re.match(self._password_regex, password):
|
||||
raise errors.ValidationError(
|
||||
'Password must satisfy regex %r.' % self._password_regex)
|
||||
user.password_salt = self._password_service.create_password()
|
||||
user.password_hash = self._password_service.get_password_hash(
|
||||
user.password_salt, password)
|
||||
auth.verify_privilege(context.user, 'users:edit:%s:pass' % infix)
|
||||
users.update_password(user, context.request['password'])
|
||||
|
||||
if 'email' in context.request:
|
||||
self._auth_service.verify_privilege(
|
||||
context.user, 'users:edit:%s:email' % infix)
|
||||
email = context.request['email'].strip() or None
|
||||
if not util.is_valid_email(email):
|
||||
raise errors.ValidationError(
|
||||
'%r is not a vaild email address.' % email)
|
||||
user.email = email
|
||||
auth.verify_privilege(context.user, 'users:edit:%s:email' % infix)
|
||||
users.update_email(user, context.request['email'])
|
||||
|
||||
if 'accessRank' in context.request:
|
||||
self._auth_service.verify_privilege(
|
||||
context.user, 'users:edit:%s:rank' % infix)
|
||||
rank = context.request['accessRank'].strip()
|
||||
if not rank in self._available_access_ranks:
|
||||
raise errors.ValidationError(
|
||||
'Bad access rank. Valid access ranks: %r' \
|
||||
% self._available_access_ranks)
|
||||
if self._available_access_ranks.index(context.user.access_rank) \
|
||||
< self._available_access_ranks.index(rank):
|
||||
raise errors.AuthError(
|
||||
'Trying to set higher access rank than one has')
|
||||
user.access_rank = rank
|
||||
auth.verify_privilege(context.user, 'users:edit:%s:rank' % infix)
|
||||
users.update_rank(user, context.request['accessRank'], context.user)
|
||||
|
||||
# TODO: avatar
|
||||
|
||||
try:
|
||||
context.session.commit()
|
||||
except sqlalchemy.exc.IntegrityError:
|
||||
raise errors.IntegrityError('User %r already exists.' % name)
|
||||
raise errors.IntegrityError('User %r already exists.' % user.name)
|
||||
|
||||
return {'user': _serialize_user(context.user, user)}
|
||||
|
|
|
@ -3,16 +3,11 @@
|
|||
import falcon
|
||||
import sqlalchemy
|
||||
import sqlalchemy.orm
|
||||
import szurubooru.api
|
||||
import szurubooru.config
|
||||
import szurubooru.errors
|
||||
import szurubooru.middleware
|
||||
import szurubooru.services
|
||||
import szurubooru.services.search
|
||||
import szurubooru.util
|
||||
from szurubooru import api, config, errors, middleware
|
||||
from szurubooru.util import misc
|
||||
|
||||
class _CustomRequest(falcon.Request):
|
||||
context_type = szurubooru.util.dotdict
|
||||
context_type = misc.dotdict
|
||||
|
||||
def get_param_as_string(self, name, required=False, store=None, default=None):
|
||||
params = self._params
|
||||
|
@ -45,47 +40,37 @@ def _on_not_found_error(ex, _request, _response, _params):
|
|||
raise falcon.HTTPNotFound(title='Not found', description=str(ex))
|
||||
|
||||
def create_app():
|
||||
''' Creates a WSGI compatible App object. '''
|
||||
config = szurubooru.config.Config()
|
||||
|
||||
''' Create a WSGI compatible App object. '''
|
||||
engine = sqlalchemy.create_engine(
|
||||
'{schema}://{user}:{password}@{host}:{port}/{name}'.format(
|
||||
schema=config['database']['schema'],
|
||||
user=config['database']['user'],
|
||||
password=config['database']['pass'],
|
||||
host=config['database']['host'],
|
||||
port=config['database']['port'],
|
||||
name=config['database']['name']))
|
||||
schema=config.config['database']['schema'],
|
||||
user=config.config['database']['user'],
|
||||
password=config.config['database']['pass'],
|
||||
host=config.config['database']['host'],
|
||||
port=config.config['database']['port'],
|
||||
name=config.config['database']['name']))
|
||||
session_maker = sqlalchemy.orm.sessionmaker(bind=engine)
|
||||
scoped_session = sqlalchemy.orm.scoped_session(session_maker)
|
||||
|
||||
# TODO: is there a better way?
|
||||
mailer = szurubooru.services.Mailer(config)
|
||||
password_service = szurubooru.services.PasswordService(config)
|
||||
auth_service = szurubooru.services.AuthService(config, password_service)
|
||||
user_service = szurubooru.services.UserService(config, password_service)
|
||||
|
||||
user_list_api = szurubooru.api.UserListApi(auth_service, user_service)
|
||||
user_detail_api = szurubooru.api.UserDetailApi(
|
||||
config, auth_service, password_service, user_service)
|
||||
password_reminder_api = szurubooru.api.PasswordReminderApi(
|
||||
config, mailer, user_service)
|
||||
|
||||
app = falcon.API(
|
||||
request_type=_CustomRequest,
|
||||
middleware=[
|
||||
szurubooru.middleware.ImbueContext(),
|
||||
szurubooru.middleware.RequireJson(),
|
||||
szurubooru.middleware.JsonTranslator(),
|
||||
szurubooru.middleware.DbSession(scoped_session),
|
||||
szurubooru.middleware.Authenticator(auth_service, user_service),
|
||||
middleware.ImbueContext(),
|
||||
middleware.RequireJson(),
|
||||
middleware.JsonTranslator(),
|
||||
middleware.DbSession(scoped_session),
|
||||
middleware.Authenticator(),
|
||||
])
|
||||
|
||||
app.add_error_handler(szurubooru.errors.AuthError, _on_auth_error)
|
||||
app.add_error_handler(szurubooru.errors.IntegrityError, _on_integrity_error)
|
||||
app.add_error_handler(szurubooru.errors.ValidationError, _on_validation_error)
|
||||
app.add_error_handler(szurubooru.errors.SearchError, _on_search_error)
|
||||
app.add_error_handler(szurubooru.errors.NotFoundError, _on_not_found_error)
|
||||
user_list_api = api.UserListApi()
|
||||
user_detail_api = api.UserDetailApi()
|
||||
password_reminder_api = api.PasswordReminderApi()
|
||||
|
||||
app.add_error_handler(errors.AuthError, _on_auth_error)
|
||||
app.add_error_handler(errors.IntegrityError, _on_integrity_error)
|
||||
app.add_error_handler(errors.ValidationError, _on_validation_error)
|
||||
app.add_error_handler(errors.SearchError, _on_search_error)
|
||||
app.add_error_handler(errors.NotFoundError, _on_not_found_error)
|
||||
|
||||
app.add_route('/users/', user_list_api)
|
||||
app.add_route('/user/{user_name}', user_detail_api)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import os
|
||||
import configobj
|
||||
import szurubooru.errors
|
||||
from szurubooru import errors
|
||||
|
||||
class Config(object):
|
||||
''' INI config parser and container. '''
|
||||
|
@ -15,20 +15,22 @@ class Config(object):
|
|||
|
||||
def _validate(self):
|
||||
'''
|
||||
Checks whether config.ini doesn't contain errors that might prove
|
||||
Check whether config.ini doesn't contain errors that might prove
|
||||
lethal at runtime.
|
||||
'''
|
||||
all_ranks = self['service']['user_ranks']
|
||||
for privilege, rank in self['privileges'].items():
|
||||
if rank not in all_ranks:
|
||||
raise szurubooru.errors.ConfigError(
|
||||
raise errors.ConfigError(
|
||||
'Rank %r for privilege %r is missing from user_ranks' % (
|
||||
rank, privilege))
|
||||
for rank in ['anonymous', 'admin', 'nobody']:
|
||||
if rank not in all_ranks:
|
||||
raise szurubooru.errors.ConfigError(
|
||||
raise errors.ConfigError(
|
||||
'Fixed rank %r is missing from user_ranks' % rank)
|
||||
if self['service']['default_user_rank'] not in all_ranks:
|
||||
raise szurubooru.errors.ConfigError(
|
||||
raise errors.ConfigError(
|
||||
'Default rank %r is missing from user_ranks' % (
|
||||
self['service']['default_user_rank']))
|
||||
|
||||
config = Config() # pylint: disable=invalid-name
|
||||
|
|
4
server/szurubooru/db/__init__.py
Normal file
4
server/szurubooru/db/__init__.py
Normal file
|
@ -0,0 +1,4 @@
|
|||
''' Database layer. '''
|
||||
|
||||
from szurubooru.db.base import Base
|
||||
from szurubooru.db.user import User
|
|
@ -1,7 +1,5 @@
|
|||
# pylint: disable=too-many-instance-attributes,too-few-public-methods
|
||||
|
||||
import sqlalchemy as sa
|
||||
from szurubooru.model.base import Base
|
||||
from szurubooru.db.base import Base
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
|
@ -1,24 +1,20 @@
|
|||
import base64
|
||||
import falcon
|
||||
from szurubooru import errors
|
||||
from szurubooru import model
|
||||
from szurubooru import db, errors
|
||||
from szurubooru.util import auth, users
|
||||
|
||||
class Authenticator(object):
|
||||
'''
|
||||
Authenticates every request and puts information on active user in the
|
||||
Authenticates every request and put information on active user in the
|
||||
request context.
|
||||
'''
|
||||
|
||||
def __init__(self, auth_service, user_service):
|
||||
self._auth_service = auth_service
|
||||
self._user_service = user_service
|
||||
|
||||
def process_request(self, request, _response):
|
||||
''' Executed before passing the request to the API. '''
|
||||
''' Bind the user to request. Update last login time if needed. '''
|
||||
request.context.user = self._get_user(request)
|
||||
if request.get_param_as_bool('bump-login') \
|
||||
and request.context.user.user_id:
|
||||
self._user_service.bump_login_time(request.context.user)
|
||||
users.bump_login_time(request.context.user)
|
||||
request.context.session.commit()
|
||||
|
||||
def _get_user(self, request):
|
||||
|
@ -27,15 +23,12 @@ class Authenticator(object):
|
|||
|
||||
try:
|
||||
auth_type, user_and_password = request.auth.split(' ', 1)
|
||||
|
||||
if auth_type.lower() != 'basic':
|
||||
raise falcon.HTTPBadRequest(
|
||||
'Invalid authentication type',
|
||||
'Only basic authorization is supported.')
|
||||
|
||||
username, password = base64.decodebytes(
|
||||
user_and_password.encode('ascii')).decode('utf8').split(':')
|
||||
|
||||
return self._authenticate(
|
||||
request.context.session, username, password)
|
||||
except ValueError as err:
|
||||
|
@ -46,16 +39,16 @@ class Authenticator(object):
|
|||
msg.format(request.auth, str(err)))
|
||||
|
||||
def _authenticate(self, session, username, password):
|
||||
''' Tries to authenticate user. Throws AuthError for invalid users. '''
|
||||
user = self._user_service.get_by_name(session, username)
|
||||
''' Try to authenticate user. Throw AuthError for invalid users. '''
|
||||
user = users.get_by_name(session, username)
|
||||
if not user:
|
||||
raise errors.AuthError('No such user.')
|
||||
if not self._auth_service.is_valid_password(user, password):
|
||||
if not auth.is_valid_password(user, password):
|
||||
raise errors.AuthError('Invalid password.')
|
||||
return user
|
||||
|
||||
def _create_anonymous_user(self):
|
||||
user = model.User()
|
||||
user = db.User()
|
||||
user.name = None
|
||||
user.access_rank = 'anonymous'
|
||||
user.password = None
|
||||
|
|
|
@ -5,12 +5,8 @@ class DbSession(object):
|
|||
self._session_factory = session_factory
|
||||
|
||||
def process_request(self, request, _response):
|
||||
''' Executed before passing the request to the API. '''
|
||||
request.context.session = self._session_factory()
|
||||
|
||||
def process_response(self, request, _response, _resource):
|
||||
'''
|
||||
Executed before passing the response to falcon.
|
||||
Any commits to database need to happen explicitly in the API layer.
|
||||
'''
|
||||
# any commits need to happen explicitly in the API layer.
|
||||
request.context.session.close()
|
||||
|
|
|
@ -16,7 +16,6 @@ class JsonTranslator(object):
|
|||
'''
|
||||
|
||||
def process_request(self, request, _response):
|
||||
''' Executed before passing the request to the API. '''
|
||||
if request.content_length in (None, 0):
|
||||
return
|
||||
|
||||
|
@ -36,7 +35,6 @@ class JsonTranslator(object):
|
|||
'JSON was incorrect or not encoded as UTF-8.')
|
||||
|
||||
def process_response(self, request, response, _resource):
|
||||
''' Executed before passing the response to falcon. '''
|
||||
if 'result' not in request.context:
|
||||
return
|
||||
response.body = json.dumps(
|
||||
|
|
|
@ -4,7 +4,6 @@ class RequireJson(object):
|
|||
''' Sanitizes requests so that only JSON is accepted. '''
|
||||
|
||||
def process_request(self, request, _response):
|
||||
''' Executed before passing the request to the API. '''
|
||||
if not request.client_accepts_json:
|
||||
raise falcon.HTTPNotAcceptable(
|
||||
'This API only supports responses encoded as JSON.')
|
||||
|
|
|
@ -9,13 +9,13 @@ import logging.config
|
|||
dir_to_self = os.path.dirname(os.path.realpath(__file__))
|
||||
sys.path.append(os.path.join(dir_to_self, *[os.pardir] * 2))
|
||||
|
||||
import szurubooru.model.base
|
||||
import szurubooru.db.base
|
||||
import szurubooru.config
|
||||
|
||||
alembic_config = alembic.context.config
|
||||
logging.config.fileConfig(alembic_config.config_file_name)
|
||||
|
||||
szuru_config = szurubooru.config.Config()
|
||||
szuru_config = szurubooru.config.config
|
||||
alembic_config.set_main_option(
|
||||
'sqlalchemy.url',
|
||||
'{schema}://{user}:{password}@{host}:{port}/{name}'.format(
|
||||
|
@ -26,7 +26,7 @@ alembic_config.set_main_option(
|
|||
port=szuru_config['database']['port'],
|
||||
name=szuru_config['database']['name']))
|
||||
|
||||
target_metadata = szurubooru.model.Base.metadata
|
||||
target_metadata = szurubooru.db.Base.metadata
|
||||
|
||||
def run_migrations_offline():
|
||||
'''
|
||||
|
|
|
@ -1,6 +0,0 @@
|
|||
'''
|
||||
Database models.
|
||||
'''
|
||||
|
||||
from szurubooru.model.base import Base
|
||||
from szurubooru.model.user import User
|
4
server/szurubooru/search/__init__.py
Normal file
4
server/szurubooru/search/__init__.py
Normal file
|
@ -0,0 +1,4 @@
|
|||
''' Search parsers and services. '''
|
||||
|
||||
from szurubooru.search.search_executor import SearchExecutor
|
||||
from szurubooru.search.user_search_config import UserSearchConfig
|
|
@ -1,11 +1,11 @@
|
|||
import sqlalchemy
|
||||
import szurubooru.errors
|
||||
from szurubooru import util
|
||||
from szurubooru.services.search import criteria
|
||||
from szurubooru.util import misc
|
||||
from szurubooru.search import criteria
|
||||
|
||||
def _apply_criterion_to_column(
|
||||
column, query, criterion, allow_composite=True, allow_ranged=True):
|
||||
''' Decorates SQLAlchemy filter on given column using supplied criterion. '''
|
||||
''' Decorate SQLAlchemy filter on given column using supplied criterion. '''
|
||||
if isinstance(criterion, criteria.StringSearchCriterion):
|
||||
expr = column == criterion.value
|
||||
if criterion.negated:
|
||||
|
@ -32,11 +32,11 @@ def _apply_criterion_to_column(
|
|||
|
||||
def _apply_date_criterion_to_column(column, query, criterion):
|
||||
'''
|
||||
Decorates SQLAlchemy filter on given column using supplied criterion.
|
||||
Parses the datetime inside the criterion.
|
||||
Decorate SQLAlchemy filter on given column using supplied criterion.
|
||||
Parse the datetime inside the criterion.
|
||||
'''
|
||||
if isinstance(criterion, criteria.StringSearchCriterion):
|
||||
min_date, max_date = util.parse_time_range(criterion.value)
|
||||
min_date, max_date = misc.parse_time_range(criterion.value)
|
||||
expr = column.between(min_date, max_date)
|
||||
if criterion.negated:
|
||||
expr = ~expr
|
||||
|
@ -44,7 +44,7 @@ def _apply_date_criterion_to_column(column, query, criterion):
|
|||
elif isinstance(criterion, criteria.ArraySearchCriterion):
|
||||
expr = sqlalchemy.sql.false()
|
||||
for value in criterion.values:
|
||||
min_date, max_date = util.parse_time_range(value)
|
||||
min_date, max_date = misc.parse_time_range(value)
|
||||
expr = expr | column.between(min_date, max_date)
|
||||
if criterion.negated:
|
||||
expr = ~expr
|
||||
|
@ -52,14 +52,14 @@ def _apply_date_criterion_to_column(column, query, criterion):
|
|||
elif isinstance(criterion, criteria.RangedSearchCriterion):
|
||||
assert criterion.min_value or criterion.max_value
|
||||
if criterion.min_value and criterion.max_value:
|
||||
min_date = util.parse_time_range(criterion.min_value)[0]
|
||||
max_date = util.parse_time_range(criterion.max_value)[1]
|
||||
min_date = misc.parse_time_range(criterion.min_value)[0]
|
||||
max_date = misc.parse_time_range(criterion.max_value)[1]
|
||||
expr = column.between(min_date, max_date)
|
||||
elif criterion.min_value:
|
||||
min_date = util.parse_time_range(criterion.min_value)[0]
|
||||
min_date = misc.parse_time_range(criterion.min_value)[0]
|
||||
expr = column >= min_date
|
||||
elif criterion.max_value:
|
||||
max_date = util.parse_time_range(criterion.max_value)[1]
|
||||
max_date = misc.parse_time_range(criterion.max_value)[1]
|
||||
expr = column <= max_date
|
||||
if criterion.negated:
|
||||
expr = ~expr
|
|
@ -1,19 +1,17 @@
|
|||
''' Exports SearchExecutor. '''
|
||||
|
||||
import re
|
||||
import sqlalchemy
|
||||
from szurubooru import errors
|
||||
from szurubooru.services.search import criteria
|
||||
from szurubooru.search import criteria
|
||||
|
||||
class SearchExecutor(object):
|
||||
ORDER_DESC = 1
|
||||
ORDER_ASC = 2
|
||||
|
||||
'''
|
||||
Class for search parsing and execution. Handles plaintext parsing and
|
||||
delegates sqlalchemy filter decoration to SearchConfig instances.
|
||||
'''
|
||||
|
||||
ORDER_DESC = 1
|
||||
ORDER_ASC = 2
|
||||
|
||||
def __init__(self, search_config):
|
||||
self.page_size = 100
|
||||
self._search_config = search_config
|
42
server/szurubooru/search/user_search_config.py
Normal file
42
server/szurubooru/search/user_search_config.py
Normal file
|
@ -0,0 +1,42 @@
|
|||
from sqlalchemy.sql.expression import func
|
||||
from szurubooru import db
|
||||
from szurubooru.search.base_search_config import BaseSearchConfig
|
||||
|
||||
class UserSearchConfig(BaseSearchConfig):
|
||||
''' Executes searches related to the users. '''
|
||||
|
||||
def create_query(self, session):
|
||||
return session.query(db.User)
|
||||
|
||||
@property
|
||||
def anonymous_filter(self):
|
||||
return self._create_basic_filter(db.User.name, allow_ranged=False)
|
||||
|
||||
@property
|
||||
def special_filters(self):
|
||||
return {}
|
||||
|
||||
@property
|
||||
def named_filters(self):
|
||||
return {
|
||||
'name': self._create_basic_filter(db.User.name, allow_ranged=False),
|
||||
'creation_date': self._create_date_filter(db.User.creation_time),
|
||||
'creation_time': self._create_date_filter(db.User.creation_time),
|
||||
'last_login_date': self._create_date_filter(db.User.last_login_time),
|
||||
'last_login_time': self._create_date_filter(db.User.last_login_time),
|
||||
'login_date': self._create_date_filter(db.User.last_login_time),
|
||||
'login_time': self._create_date_filter(db.User.last_login_time),
|
||||
}
|
||||
|
||||
@property
|
||||
def order_columns(self):
|
||||
return {
|
||||
'random': func.random(),
|
||||
'name': db.User.name,
|
||||
'creation_date': db.User.creation_time,
|
||||
'creation_time': db.User.creation_time,
|
||||
'last_login_date': db.User.last_login_time,
|
||||
'last_login_time': db.User.last_login_time,
|
||||
'login_date': db.User.last_login_time,
|
||||
'login_time': db.User.last_login_time,
|
||||
}
|
|
@ -1,9 +0,0 @@
|
|||
'''
|
||||
Middle layer between REST API and database.
|
||||
All the business logic goes here.
|
||||
'''
|
||||
|
||||
from szurubooru.services.mailer import Mailer
|
||||
from szurubooru.services.auth_service import AuthService
|
||||
from szurubooru.services.user_service import UserService
|
||||
from szurubooru.services.password_service import PasswordService
|
|
@ -1,28 +0,0 @@
|
|||
from szurubooru import errors
|
||||
|
||||
class AuthService(object):
|
||||
def __init__(self, config, password_service):
|
||||
self._config = config
|
||||
self._password_service = password_service
|
||||
|
||||
def is_valid_password(self, user, password):
|
||||
''' Returns whether the given password for a given user is valid. '''
|
||||
salt, valid_hash = user.password_salt, user.password_hash
|
||||
possible_hashes = [
|
||||
self._password_service.get_password_hash(salt, password),
|
||||
self._password_service.get_legacy_password_hash(salt, password)
|
||||
]
|
||||
return valid_hash in possible_hashes
|
||||
|
||||
def verify_privilege(self, user, privilege_name):
|
||||
'''
|
||||
Throws an AuthError if the given user doesn't have given privilege.
|
||||
'''
|
||||
all_ranks = self._config['service']['user_ranks']
|
||||
|
||||
assert privilege_name in self._config['privileges']
|
||||
assert user.access_rank in all_ranks
|
||||
minimal_rank = self._config['privileges'][privilege_name]
|
||||
good_ranks = all_ranks[all_ranks.index(minimal_rank):]
|
||||
if user.access_rank not in good_ranks:
|
||||
raise errors.AuthError('Insufficient privileges to do this.')
|
|
@ -1,19 +0,0 @@
|
|||
import smtplib
|
||||
import email.mime.text
|
||||
|
||||
class Mailer(object):
|
||||
def __init__(self, config):
|
||||
self._config = config
|
||||
|
||||
def send(self, sender, recipient, subject, body):
|
||||
msg = email.mime.text.MIMEText(body)
|
||||
msg['Subject'] = subject
|
||||
msg['From'] = sender
|
||||
msg['To'] = recipient
|
||||
|
||||
smtp = smtplib.SMTP(
|
||||
self._config['smtp']['host'],
|
||||
int(self._config['smtp']['port']))
|
||||
smtp.login(self._config['smtp']['user'], self._config['smtp']['pass'])
|
||||
smtp.send_message(msg)
|
||||
smtp.quit()
|
|
@ -1,34 +0,0 @@
|
|||
import hashlib
|
||||
import random
|
||||
|
||||
class PasswordService(object):
|
||||
''' Stateless utilities for passwords '''
|
||||
|
||||
def __init__(self, config):
|
||||
self._config = config
|
||||
|
||||
def get_password_hash(self, salt, password):
|
||||
''' Retrieves new-style password hash. '''
|
||||
digest = hashlib.sha256()
|
||||
digest.update(self._config['basic']['secret'].encode('utf8'))
|
||||
digest.update(salt.encode('utf8'))
|
||||
digest.update(password.encode('utf8'))
|
||||
return digest.hexdigest()
|
||||
|
||||
def get_legacy_password_hash(self, salt, password):
|
||||
''' Retrieves old-style password hash. '''
|
||||
digest = hashlib.sha1()
|
||||
digest.update(b'1A2/$_4xVa')
|
||||
digest.update(salt.encode('utf8'))
|
||||
digest.update(password.encode('utf8'))
|
||||
return digest.hexdigest()
|
||||
|
||||
def create_password(self):
|
||||
''' Creates an easy-to-remember password. '''
|
||||
alphabet = {
|
||||
'c': list('bcdfghijklmnpqrstvwxyz'),
|
||||
'v': list('aeiou'),
|
||||
'n': list('0123456789'),
|
||||
}
|
||||
pattern = 'cvcvnncvcv'
|
||||
return ''.join(random.choice(alphabet[l]) for l in list(pattern))
|
|
@ -1,2 +0,0 @@
|
|||
from szurubooru.services.search.search_executor import SearchExecutor
|
||||
from szurubooru.services.search.user_search_config import UserSearchConfig
|
|
@ -1,44 +0,0 @@
|
|||
''' Exports UserSearchConfig. '''
|
||||
|
||||
from sqlalchemy.sql.expression import func
|
||||
from szurubooru.model import User
|
||||
from szurubooru.services.search.base_search_config import BaseSearchConfig
|
||||
|
||||
class UserSearchConfig(BaseSearchConfig):
|
||||
''' Executes searches related to the users. '''
|
||||
|
||||
def create_query(self, session):
|
||||
return session.query(User)
|
||||
|
||||
@property
|
||||
def anonymous_filter(self):
|
||||
return self._create_basic_filter(User.name, allow_ranged=False)
|
||||
|
||||
@property
|
||||
def special_filters(self):
|
||||
return {}
|
||||
|
||||
@property
|
||||
def named_filters(self):
|
||||
return {
|
||||
'name': self._create_basic_filter(User.name, allow_ranged=False),
|
||||
'creation_date': self._create_date_filter(User.creation_time),
|
||||
'creation_time': self._create_date_filter(User.creation_time),
|
||||
'last_login_date': self._create_date_filter(User.last_login_time),
|
||||
'last_login_time': self._create_date_filter(User.last_login_time),
|
||||
'login_date': self._create_date_filter(User.last_login_time),
|
||||
'login_time': self._create_date_filter(User.last_login_time),
|
||||
}
|
||||
|
||||
@property
|
||||
def order_columns(self):
|
||||
return {
|
||||
'random': func.random(),
|
||||
'name': User.name,
|
||||
'creation_date': User.creation_time,
|
||||
'creation_time': User.creation_time,
|
||||
'last_login_date': User.last_login_time,
|
||||
'last_login_time': User.last_login_time,
|
||||
'login_date': User.last_login_time,
|
||||
'login_time': User.last_login_time,
|
||||
}
|
|
@ -1,56 +0,0 @@
|
|||
import re
|
||||
from datetime import datetime
|
||||
from szurubooru import errors
|
||||
from szurubooru import model
|
||||
from szurubooru import util
|
||||
|
||||
class UserService(object):
|
||||
''' User management '''
|
||||
|
||||
def __init__(self, config, password_service):
|
||||
self._config = config
|
||||
self._password_service = password_service
|
||||
self._name_regex = self._config['service']['user_name_regex']
|
||||
self._password_regex = self._config['service']['password_regex']
|
||||
|
||||
def create_user(self, session, name, password, email):
|
||||
''' Creates an user with given parameters and returns it. '''
|
||||
|
||||
if not re.match(self._name_regex, name):
|
||||
raise errors.ValidationError(
|
||||
'Name must satisfy regex %r.' % self._name_regex)
|
||||
|
||||
if not re.match(self._password_regex, password):
|
||||
raise errors.ValidationError(
|
||||
'Password must satisfy regex %r.' % self._password_regex)
|
||||
|
||||
if not util.is_valid_email(email):
|
||||
raise errors.ValidationError(
|
||||
'%r is not a vaild email address.' % email)
|
||||
|
||||
user = model.User()
|
||||
user.name = name
|
||||
user.password_salt = self._password_service.create_password()
|
||||
user.password_hash = self._password_service.get_password_hash(
|
||||
user.password_salt, password)
|
||||
user.email = email or None
|
||||
user.access_rank = self._config['service']['default_user_rank']
|
||||
user.creation_time = datetime.now()
|
||||
user.avatar_style = model.User.AVATAR_GRAVATAR
|
||||
|
||||
session.add(user)
|
||||
return user
|
||||
|
||||
def bump_login_time(self, user):
|
||||
user.last_login_time = datetime.now()
|
||||
|
||||
def reset_password(self, user):
|
||||
password = self._password_service.create_password()
|
||||
user.password_salt = self._password_service.create_password()
|
||||
user.password_hash = self._password_service.get_password_hash(
|
||||
user.password_salt, password)
|
||||
return password
|
||||
|
||||
def get_by_name(self, session, name):
|
||||
''' Retrieves an user by its name. '''
|
||||
return session.query(model.User).filter_by(name=name).first()
|
|
@ -1,15 +1,12 @@
|
|||
from datetime import datetime
|
||||
import szurubooru.services
|
||||
from szurubooru.api.user_api import UserDetailApi
|
||||
from szurubooru.errors import AuthError, ValidationError
|
||||
from szurubooru.model.user import User
|
||||
from szurubooru import api, db, errors, config
|
||||
from szurubooru.util import auth, misc
|
||||
from szurubooru.tests.database_test_case import DatabaseTestCase
|
||||
from szurubooru.util import dotdict
|
||||
|
||||
class TestUserDetailApi(DatabaseTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
config = {
|
||||
config_mock = {
|
||||
'basic': {
|
||||
'secret': '',
|
||||
},
|
||||
|
@ -30,18 +27,18 @@ class TestUserDetailApi(DatabaseTestCase):
|
|||
'users:edit:any:rank': 'admin',
|
||||
}
|
||||
}
|
||||
password_service = szurubooru.services.PasswordService(config)
|
||||
auth_service = szurubooru.services.AuthService(config, password_service)
|
||||
user_service = szurubooru.services.UserService(config, password_service)
|
||||
self.auth_service = auth_service
|
||||
self.api = UserDetailApi(
|
||||
config, auth_service, password_service, user_service)
|
||||
self.context = dotdict()
|
||||
self.old_config = config.config
|
||||
config.config = config_mock
|
||||
self.api = api.UserDetailApi()
|
||||
self.context = misc.dotdict()
|
||||
self.context.session = self.session
|
||||
self.context.request = {}
|
||||
|
||||
def tearDown(self):
|
||||
config.config = self.old_config
|
||||
|
||||
def _create_user(self, name, rank='admin'):
|
||||
user = User()
|
||||
user = db.User()
|
||||
user.name = name
|
||||
user.password = 'dummy'
|
||||
user.password_salt = 'dummy'
|
||||
|
@ -49,7 +46,7 @@ class TestUserDetailApi(DatabaseTestCase):
|
|||
user.email = 'dummy'
|
||||
user.access_rank = rank
|
||||
user.creation_time = datetime.now()
|
||||
user.avatar_style = User.AVATAR_GRAVATAR
|
||||
user.avatar_style = db.User.AVATAR_GRAVATAR
|
||||
return user
|
||||
|
||||
def test_updating_nothing(self):
|
||||
|
@ -57,7 +54,7 @@ class TestUserDetailApi(DatabaseTestCase):
|
|||
self.session.add(admin_user)
|
||||
self.context.user = admin_user
|
||||
self.api.put(self.context, 'u1')
|
||||
admin_user = self.session.query(User).filter_by(name='u1').one()
|
||||
admin_user = self.session.query(db.User).filter_by(name='u1').one()
|
||||
self.assertEqual(admin_user.name, 'u1')
|
||||
self.assertEqual(admin_user.email, 'dummy')
|
||||
self.assertEqual(admin_user.access_rank, 'admin')
|
||||
|
@ -69,16 +66,16 @@ class TestUserDetailApi(DatabaseTestCase):
|
|||
self.context.request = {
|
||||
'name': 'chewie',
|
||||
'email': 'asd@asd.asd',
|
||||
'password': 'valid',
|
||||
'password': 'oks',
|
||||
'accessRank': 'mod',
|
||||
}
|
||||
self.api.put(self.context, 'u1')
|
||||
admin_user = self.session.query(User).filter_by(name='chewie').one()
|
||||
admin_user = self.session.query(db.User).filter_by(name='chewie').one()
|
||||
self.assertEqual(admin_user.name, 'chewie')
|
||||
self.assertEqual(admin_user.email, 'asd@asd.asd')
|
||||
self.assertEqual(admin_user.access_rank, 'mod')
|
||||
self.assertTrue(self.auth_service.is_valid_password(admin_user, 'valid'))
|
||||
self.assertFalse(self.auth_service.is_valid_password(admin_user, 'invalid'))
|
||||
self.assertTrue(auth.is_valid_password(admin_user, 'oks'))
|
||||
self.assertFalse(auth.is_valid_password(admin_user, 'invalid'))
|
||||
|
||||
def test_removing_email(self):
|
||||
admin_user = self._create_user('u1', 'admin')
|
||||
|
@ -86,7 +83,7 @@ class TestUserDetailApi(DatabaseTestCase):
|
|||
self.context.user = admin_user
|
||||
self.context.request = {'email': ''}
|
||||
self.api.put(self.context, 'u1')
|
||||
admin_user = self.session.query(User).filter_by(name='u1').one()
|
||||
admin_user = self.session.query(db.User).filter_by(name='u1').one()
|
||||
self.assertEqual(admin_user.email, None)
|
||||
|
||||
def test_invalid_inputs(self):
|
||||
|
@ -95,16 +92,16 @@ class TestUserDetailApi(DatabaseTestCase):
|
|||
self.context.user = admin_user
|
||||
self.context.request = {'name': '.'}
|
||||
self.assertRaises(
|
||||
ValidationError, self.api.put, self.context, 'u1')
|
||||
errors.ValidationError, self.api.put, self.context, 'u1')
|
||||
self.context.request = {'password': '.'}
|
||||
self.assertRaises(
|
||||
ValidationError, self.api.put, self.context, 'u1')
|
||||
errors.ValidationError, self.api.put, self.context, 'u1')
|
||||
self.context.request = {'accessRank': '.'}
|
||||
self.assertRaises(
|
||||
ValidationError, self.api.put, self.context, 'u1')
|
||||
errors.ValidationError, self.api.put, self.context, 'u1')
|
||||
self.context.request = {'email': '.'}
|
||||
self.assertRaises(
|
||||
ValidationError, self.api.put, self.context, 'u1')
|
||||
errors.ValidationError, self.api.put, self.context, 'u1')
|
||||
|
||||
def test_user_trying_to_update_someone_else(self):
|
||||
user1 = self._create_user('u1', 'regular_user')
|
||||
|
@ -118,7 +115,7 @@ class TestUserDetailApi(DatabaseTestCase):
|
|||
{'password': 'whatever'}]:
|
||||
self.context.request = request
|
||||
self.assertRaises(
|
||||
AuthError, self.api.put, self.context, user2.name)
|
||||
errors.AuthError, self.api.put, self.context, user2.name)
|
||||
|
||||
def test_user_trying_to_become_someone_else(self):
|
||||
user1 = self._create_user('u1', 'regular_user')
|
||||
|
@ -127,7 +124,7 @@ class TestUserDetailApi(DatabaseTestCase):
|
|||
self.context.user = user1
|
||||
self.context.request = {'name': 'u2'}
|
||||
self.assertRaises(
|
||||
ValidationError, self.api.put, self.context, 'u1')
|
||||
errors.ValidationError, self.api.put, self.context, 'u1')
|
||||
|
||||
def test_mods_trying_to_become_admin(self):
|
||||
user1 = self._create_user('u1', 'mod')
|
||||
|
@ -136,6 +133,6 @@ class TestUserDetailApi(DatabaseTestCase):
|
|||
self.context.user = user1
|
||||
self.context.request = {'accessRank': 'admin'}
|
||||
self.assertRaises(
|
||||
AuthError, self.api.put, self.context, user1.name)
|
||||
errors.AuthError, self.api.put, self.context, user1.name)
|
||||
self.assertRaises(
|
||||
AuthError, self.api.put, self.context, user2.name)
|
||||
errors.AuthError, self.api.put, self.context, user2.name)
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
import unittest
|
||||
import sqlalchemy
|
||||
from szurubooru.model import Base
|
||||
from szurubooru import db
|
||||
|
||||
class DatabaseTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
engine = sqlalchemy.create_engine('sqlite:///:memory:')
|
||||
session_maker = sqlalchemy.orm.sessionmaker(bind=engine)
|
||||
self.session = sqlalchemy.orm.scoped_session(session_maker)
|
||||
Base.query = self.session.query_property()
|
||||
Base.metadata.create_all(bind=engine)
|
||||
db.Base.query = self.session.query_property()
|
||||
db.Base.metadata.create_all(bind=engine)
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
from datetime import datetime
|
||||
from szurubooru import errors
|
||||
from szurubooru import model
|
||||
from szurubooru.services import search
|
||||
from szurubooru import db, errors, search
|
||||
from szurubooru.tests.database_test_case import DatabaseTestCase
|
||||
|
||||
class TestUserSearchExecutor(DatabaseTestCase):
|
||||
|
@ -11,7 +9,7 @@ class TestUserSearchExecutor(DatabaseTestCase):
|
|||
self.executor = search.SearchExecutor(self.search_config)
|
||||
|
||||
def _create_user(self, name):
|
||||
user = model.User()
|
||||
user = db.User()
|
||||
user.name = name
|
||||
user.password = 'dummy'
|
||||
user.password_salt = 'dummy'
|
||||
|
@ -19,7 +17,7 @@ class TestUserSearchExecutor(DatabaseTestCase):
|
|||
user.email = 'dummy'
|
||||
user.access_rank = 'dummy'
|
||||
user.creation_time = datetime.now()
|
||||
user.avatar_style = model.User.AVATAR_GRAVATAR
|
||||
user.avatar_style = db.User.AVATAR_GRAVATAR
|
||||
return user
|
||||
|
||||
def _test(self, query, page, expected_count, expected_user_names):
|
|
@ -1,8 +1,7 @@
|
|||
import unittest
|
||||
from datetime import datetime
|
||||
import szurubooru.util
|
||||
from szurubooru.util import parse_time_range
|
||||
from szurubooru.errors import ValidationError
|
||||
from szurubooru import errors
|
||||
from szurubooru.util import misc
|
||||
|
||||
class FakeDatetime(datetime):
|
||||
@staticmethod
|
||||
|
@ -11,33 +10,33 @@ class FakeDatetime(datetime):
|
|||
|
||||
class TestParseTime(unittest.TestCase):
|
||||
def test_empty(self):
|
||||
self.assertRaises(ValidationError, parse_time_range, '')
|
||||
self.assertRaises(errors.ValidationError, misc.parse_time_range, '')
|
||||
|
||||
def test_today(self):
|
||||
szurubooru.util.datetime.datetime = FakeDatetime
|
||||
date_min, date_max = parse_time_range('today')
|
||||
misc.datetime.datetime = FakeDatetime
|
||||
date_min, date_max = misc.parse_time_range('today')
|
||||
self.assertEqual(date_min, datetime(1997, 1, 2, 0, 0, 0))
|
||||
self.assertEqual(date_max, datetime(1997, 1, 2, 23, 59, 59))
|
||||
|
||||
def test_yesterday(self):
|
||||
szurubooru.util.datetime.datetime = FakeDatetime
|
||||
date_min, date_max = parse_time_range('yesterday')
|
||||
misc.datetime.datetime = FakeDatetime
|
||||
date_min, date_max = misc.parse_time_range('yesterday')
|
||||
self.assertEqual(date_min, datetime(1997, 1, 1, 0, 0, 0))
|
||||
self.assertEqual(date_max, datetime(1997, 1, 1, 23, 59, 59))
|
||||
|
||||
def test_year(self):
|
||||
date_min, date_max = parse_time_range('1999')
|
||||
date_min, date_max = misc.parse_time_range('1999')
|
||||
self.assertEqual(date_min, datetime(1999, 1, 1, 0, 0, 0))
|
||||
self.assertEqual(date_max, datetime(1999, 12, 31, 23, 59, 59))
|
||||
|
||||
def test_month(self):
|
||||
for text in ['1999-2', '1999-02']:
|
||||
date_min, date_max = parse_time_range(text)
|
||||
date_min, date_max = misc.parse_time_range(text)
|
||||
self.assertEqual(date_min, datetime(1999, 2, 1, 0, 0, 0))
|
||||
self.assertEqual(date_max, datetime(1999, 2, 28, 23, 59, 59))
|
||||
|
||||
def test_day(self):
|
||||
for text in ['1999-2-6', '1999-02-6', '1999-2-06', '1999-02-06']:
|
||||
date_min, date_max = parse_time_range(text)
|
||||
date_min, date_max = misc.parse_time_range(text)
|
||||
self.assertEqual(date_min, datetime(1999, 2, 6, 0, 0, 0))
|
||||
self.assertEqual(date_max, datetime(1999, 2, 6, 23, 59, 59))
|
1
server/szurubooru/util/__init__.py
Normal file
1
server/szurubooru/util/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
''' Cool functions. '''
|
59
server/szurubooru/util/auth.py
Normal file
59
server/szurubooru/util/auth.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
import hashlib
|
||||
import random
|
||||
from szurubooru import config
|
||||
from szurubooru import errors
|
||||
|
||||
def get_password_hash(salt, password):
|
||||
''' Retrieve new-style password hash. '''
|
||||
digest = hashlib.sha256()
|
||||
digest.update(config.config['basic']['secret'].encode('utf8'))
|
||||
digest.update(salt.encode('utf8'))
|
||||
digest.update(password.encode('utf8'))
|
||||
return digest.hexdigest()
|
||||
|
||||
def get_legacy_password_hash(salt, password):
|
||||
''' Retrieve old-style password hash. '''
|
||||
digest = hashlib.sha1()
|
||||
digest.update(b'1A2/$_4xVa')
|
||||
digest.update(salt.encode('utf8'))
|
||||
digest.update(password.encode('utf8'))
|
||||
return digest.hexdigest()
|
||||
|
||||
def create_password():
|
||||
''' Create an easy-to-remember password. '''
|
||||
alphabet = {
|
||||
'c': list('bcdfghijklmnpqrstvwxyz'),
|
||||
'v': list('aeiou'),
|
||||
'n': list('0123456789'),
|
||||
}
|
||||
pattern = 'cvcvnncvcv'
|
||||
return ''.join(random.choice(alphabet[l]) for l in list(pattern))
|
||||
|
||||
def is_valid_password(user, password):
|
||||
''' Return whether the given password for a given user is valid. '''
|
||||
salt, valid_hash = user.password_salt, user.password_hash
|
||||
possible_hashes = [
|
||||
get_password_hash(salt, password),
|
||||
get_legacy_password_hash(salt, password)
|
||||
]
|
||||
return valid_hash in possible_hashes
|
||||
|
||||
def verify_privilege(user, privilege_name):
|
||||
'''
|
||||
Throw an AuthError if the given user doesn't have given privilege.
|
||||
'''
|
||||
all_ranks = config.config['service']['user_ranks']
|
||||
|
||||
assert privilege_name in config.config['privileges']
|
||||
assert user.access_rank in all_ranks
|
||||
minimal_rank = config.config['privileges'][privilege_name]
|
||||
good_ranks = all_ranks[all_ranks.index(minimal_rank):]
|
||||
if user.access_rank not in good_ranks:
|
||||
raise errors.AuthError('Insufficient privileges to do this.')
|
||||
|
||||
def generate_authentication_token(user):
|
||||
''' Generate nonguessable challenge (e.g. links in password reminder). '''
|
||||
digest = hashlib.sha256()
|
||||
digest.update(config.config['basic']['secret'].encode('utf8'))
|
||||
digest.update(user.password_salt.encode('utf8'))
|
||||
return digest.hexdigest()
|
15
server/szurubooru/util/mailer.py
Normal file
15
server/szurubooru/util/mailer.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
import smtplib
|
||||
import email.mime.text
|
||||
from szurubooru import config
|
||||
|
||||
def send_mail(sender, recipient, subject, body):
|
||||
msg = email.mime.text.MIMEText(body)
|
||||
msg['Subject'] = subject
|
||||
msg['From'] = sender
|
||||
msg['To'] = recipient
|
||||
|
||||
smtp = smtplib.SMTP(
|
||||
config.config['smtp']['host'], int(config.config['smtp']['port']))
|
||||
smtp.login(config.config['smtp']['user'], config.config['smtp']['pass'])
|
||||
smtp.send_message(msg)
|
||||
smtp.quit()
|
|
@ -3,7 +3,7 @@ import re
|
|||
from szurubooru.errors import ValidationError
|
||||
|
||||
def is_valid_email(email):
|
||||
''' Validates given email address. '''
|
||||
''' Return whether given email address is valid or empty. '''
|
||||
return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email)
|
||||
|
||||
class dotdict(dict): # pylint: disable=invalid-name
|
||||
|
@ -14,7 +14,7 @@ class dotdict(dict): # pylint: disable=invalid-name
|
|||
__delattr__ = dict.__delitem__
|
||||
|
||||
def parse_time_range(value, timezone=datetime.timezone(datetime.timedelta())):
|
||||
''' Returns tuple containing min/max time for given text representation. '''
|
||||
''' Return tuple containing min/max time for given text representation. '''
|
||||
one_day = datetime.timedelta(days=1)
|
||||
one_second = datetime.timedelta(seconds=1)
|
||||
|
67
server/szurubooru/util/users.py
Normal file
67
server/szurubooru/util/users.py
Normal file
|
@ -0,0 +1,67 @@
|
|||
import re
|
||||
from datetime import datetime
|
||||
from szurubooru import config, db, errors
|
||||
from szurubooru.util import auth, misc
|
||||
|
||||
def create_user(name, password, email):
|
||||
''' Create an user with given parameters and returns it. '''
|
||||
user = db.User()
|
||||
update_name(user, name)
|
||||
update_password(user, password)
|
||||
update_email(user, email)
|
||||
user.access_rank = config.config['service']['default_user_rank']
|
||||
user.creation_time = datetime.now()
|
||||
user.avatar_style = db.User.AVATAR_GRAVATAR
|
||||
return user
|
||||
|
||||
def update_name(user, name):
|
||||
''' Validate and update user's name. '''
|
||||
name = name.strip()
|
||||
name_regex = config.config['service']['user_name_regex']
|
||||
if not re.match(name_regex, name):
|
||||
raise errors.ValidationError(
|
||||
'Name must satisfy regex %r.' % name_regex)
|
||||
user.name = name
|
||||
|
||||
def update_password(user, password):
|
||||
''' Validate and update user's password. '''
|
||||
password_regex = config.config['service']['password_regex']
|
||||
if not re.match(password_regex, password):
|
||||
raise errors.ValidationError(
|
||||
'Password must satisfy regex %r.' % password_regex)
|
||||
user.password_salt = auth.create_password()
|
||||
user.password_hash = auth.get_password_hash(user.password_salt, password)
|
||||
|
||||
def update_email(user, email):
|
||||
''' Validate and update user's email. '''
|
||||
email = email.strip() or None
|
||||
if not misc.is_valid_email(email):
|
||||
raise errors.ValidationError(
|
||||
'%r is not a vaild email address.' % email)
|
||||
user.email = email
|
||||
|
||||
def update_rank(user, rank, authenticated_user):
|
||||
rank = rank.strip()
|
||||
available_access_ranks = config.config['service']['user_ranks']
|
||||
if not rank in available_access_ranks:
|
||||
raise errors.ValidationError(
|
||||
'Bad access rank. Valid access ranks: %r' % available_access_ranks)
|
||||
if available_access_ranks.index(authenticated_user.access_rank) \
|
||||
< available_access_ranks.index(rank):
|
||||
raise errors.AuthError('Trying to set higher access rank than one has')
|
||||
user.access_rank = rank
|
||||
|
||||
def bump_login_time(user):
|
||||
''' Update user's login time to current date. '''
|
||||
user.last_login_time = datetime.now()
|
||||
|
||||
def reset_password(user):
|
||||
''' Reset password for an user. '''
|
||||
password = auth.create_password()
|
||||
user.password_salt = auth.create_password()
|
||||
user.password_hash = auth.get_password_hash(user.password_salt, password)
|
||||
return password
|
||||
|
||||
def get_by_name(session, name):
|
||||
''' Retrieve an user by its name. '''
|
||||
return session.query(db.User).filter_by(name=name).first()
|
Loading…
Reference in a new issue