server/general: be more pythonic
This commit is contained in:
parent
7f4708c696
commit
219ab7c2c3
36 changed files with 335 additions and 428 deletions
|
@ -1,5 +1,5 @@
|
||||||
import hashlib
|
from szurubooru import config, errors
|
||||||
from szurubooru import errors
|
from szurubooru.util import auth, mailer, users
|
||||||
from szurubooru.api.base_api import BaseApi
|
from szurubooru.api.base_api import BaseApi
|
||||||
|
|
||||||
MAIL_SUBJECT = 'Password reset for {name}'
|
MAIL_SUBJECT = 'Password reset for {name}'
|
||||||
|
@ -9,45 +9,35 @@ MAIL_BODY = \
|
||||||
'Otherwise, please ignore this email.'
|
'Otherwise, please ignore this email.'
|
||||||
|
|
||||||
class PasswordReminderApi(BaseApi):
|
class PasswordReminderApi(BaseApi):
|
||||||
def __init__(self, config, mailer, user_service):
|
|
||||||
super().__init__()
|
|
||||||
self._config = config
|
|
||||||
self._mailer = mailer
|
|
||||||
self._user_service = user_service
|
|
||||||
|
|
||||||
def get(self, context, user_name):
|
def get(self, context, user_name):
|
||||||
user = self._user_service.get_by_name(context.session, user_name)
|
''' Send a mail with secure token to the correlated user. '''
|
||||||
|
user = users.get_by_name(context.session, user_name)
|
||||||
if not user:
|
if not user:
|
||||||
raise errors.NotFoundError('User %r not found.' % user_name)
|
raise errors.NotFoundError('User %r not found.' % user_name)
|
||||||
if not user.email:
|
if not user.email:
|
||||||
raise errors.ValidationError(
|
raise errors.ValidationError(
|
||||||
'User %r hasn\'t supplied email. Cannot reset password.' % user_name)
|
'User %r hasn\'t supplied email. Cannot reset password.' % user_name)
|
||||||
token = self._generate_authentication_token(user)
|
token = auth.generate_authentication_token(user)
|
||||||
url = '%s/password-reset/%s' % (
|
url = '%s/password-reset/%s' % (
|
||||||
self._config['basic']['base_url'].rstrip('/'), token)
|
config.config['basic']['base_url'].rstrip('/'), token)
|
||||||
self._mailer.send(
|
mailer.send_mail(
|
||||||
'noreply@%s' % self._config['basic']['name'],
|
'noreply@%s' % config.config['basic']['name'],
|
||||||
user.email,
|
user.email,
|
||||||
MAIL_SUBJECT.format(name=self._config['basic']['name']),
|
MAIL_SUBJECT.format(name=config.config['basic']['name']),
|
||||||
MAIL_BODY.format(name=self._config['basic']['name'], url=url))
|
MAIL_BODY.format(name=config.config['basic']['name'], url=url))
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def post(self, context, user_name):
|
def post(self, context, user_name):
|
||||||
user = self._user_service.get_by_name(context.session, user_name)
|
''' Verify token from mail, generate a new password and return it. '''
|
||||||
|
user = users.get_by_name(context.session, user_name)
|
||||||
if not user:
|
if not user:
|
||||||
raise errors.NotFoundError('User %r not found.' % user_name)
|
raise errors.NotFoundError('User %r not found.' % user_name)
|
||||||
good_token = self._generate_authentication_token(user)
|
good_token = auth.generate_authentication_token(user)
|
||||||
if not 'token' in context.request:
|
if not 'token' in context.request:
|
||||||
raise errors.ValidationError('Missing password reset token.')
|
raise errors.ValidationError('Missing password reset token.')
|
||||||
token = context.request['token']
|
token = context.request['token']
|
||||||
if token != good_token:
|
if token != good_token:
|
||||||
raise errors.ValidationError('Invalid password reset token.')
|
raise errors.ValidationError('Invalid password reset token.')
|
||||||
new_password = self._user_service.reset_password(user)
|
new_password = users.reset_password(user)
|
||||||
context.session.commit()
|
context.session.commit()
|
||||||
return {'password': new_password}
|
return {'password': new_password}
|
||||||
|
|
||||||
def _generate_authentication_token(self, user):
|
|
||||||
digest = hashlib.sha256()
|
|
||||||
digest.update(self._config['basic']['secret'].encode('utf8'))
|
|
||||||
digest.update(user.password_salt.encode('utf8'))
|
|
||||||
return digest.hexdigest()
|
|
||||||
|
|
|
@ -1,9 +1,7 @@
|
||||||
import re
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from szurubooru import errors
|
from szurubooru import errors, search
|
||||||
from szurubooru import util
|
from szurubooru.util import auth, users
|
||||||
from szurubooru.api.base_api import BaseApi
|
from szurubooru.api.base_api import BaseApi
|
||||||
from szurubooru.services import search
|
|
||||||
|
|
||||||
def _serialize_user(authenticated_user, user):
|
def _serialize_user(authenticated_user, user):
|
||||||
ret = {
|
ret = {
|
||||||
|
@ -21,29 +19,27 @@ def _serialize_user(authenticated_user, user):
|
||||||
class UserListApi(BaseApi):
|
class UserListApi(BaseApi):
|
||||||
''' API for lists of users. '''
|
''' API for lists of users. '''
|
||||||
|
|
||||||
def __init__(self, auth_service, user_service):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._auth_service = auth_service
|
|
||||||
self._user_service = user_service
|
|
||||||
self._search_executor = search.SearchExecutor(search.UserSearchConfig())
|
self._search_executor = search.SearchExecutor(search.UserSearchConfig())
|
||||||
|
|
||||||
def get(self, context):
|
def get(self, context):
|
||||||
''' Retrieves a list of users. '''
|
''' Retrieve a list of users. '''
|
||||||
self._auth_service.verify_privilege(context.user, 'users:list')
|
auth.verify_privilege(context.user, 'users:list')
|
||||||
query = context.get_param_as_string('query')
|
query = context.get_param_as_string('query')
|
||||||
page = context.get_param_as_int('page', 1)
|
page = context.get_param_as_int('page', 1)
|
||||||
count, users = self._search_executor.execute(context.session, query, page)
|
count, user_list = self._search_executor.execute(context.session, query, page)
|
||||||
return {
|
return {
|
||||||
'query': query,
|
'query': query,
|
||||||
'page': page,
|
'page': page,
|
||||||
'page_size': self._search_executor.page_size,
|
'page_size': self._search_executor.page_size,
|
||||||
'total': count,
|
'total': count,
|
||||||
'users': [_serialize_user(context.user, user) for user in users],
|
'users': [_serialize_user(context.user, user) for user in user_list],
|
||||||
}
|
}
|
||||||
|
|
||||||
def post(self, context):
|
def post(self, context):
|
||||||
''' Creates a new user. '''
|
''' Create a new user. '''
|
||||||
self._auth_service.verify_privilege(context.user, 'users:create')
|
auth.verify_privilege(context.user, 'users:create')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
name = context.request['name'].strip()
|
name = context.request['name'].strip()
|
||||||
|
@ -52,9 +48,9 @@ class UserListApi(BaseApi):
|
||||||
except KeyError as ex:
|
except KeyError as ex:
|
||||||
raise errors.ValidationError('Field %r not found.' % ex.args[0])
|
raise errors.ValidationError('Field %r not found.' % ex.args[0])
|
||||||
|
|
||||||
user = self._user_service.create_user(
|
user = users.create_user(name, password, email)
|
||||||
context.session, name, password, email)
|
|
||||||
try:
|
try:
|
||||||
|
context.session.add(user)
|
||||||
context.session.commit()
|
context.session.commit()
|
||||||
except sqlalchemy.exc.IntegrityError:
|
except sqlalchemy.exc.IntegrityError:
|
||||||
raise errors.IntegrityError('User %r already exists.' % name)
|
raise errors.IntegrityError('User %r already exists.' % name)
|
||||||
|
@ -63,26 +59,17 @@ class UserListApi(BaseApi):
|
||||||
class UserDetailApi(BaseApi):
|
class UserDetailApi(BaseApi):
|
||||||
''' API for individual users. '''
|
''' API for individual users. '''
|
||||||
|
|
||||||
def __init__(self, config, auth_service, password_service, user_service):
|
|
||||||
super().__init__()
|
|
||||||
self._available_access_ranks = config['service']['user_ranks']
|
|
||||||
self._name_regex = config['service']['user_name_regex']
|
|
||||||
self._password_regex = config['service']['password_regex']
|
|
||||||
self._password_service = password_service
|
|
||||||
self._auth_service = auth_service
|
|
||||||
self._user_service = user_service
|
|
||||||
|
|
||||||
def get(self, context, user_name):
|
def get(self, context, user_name):
|
||||||
''' Retrieves an user. '''
|
''' Retrieve an user. '''
|
||||||
self._auth_service.verify_privilege(context.user, 'users:view')
|
auth.verify_privilege(context.user, 'users:view')
|
||||||
user = self._user_service.get_by_name(context.session, user_name)
|
user = users.get_by_name(context.session, user_name)
|
||||||
if not user:
|
if not user:
|
||||||
raise errors.NotFoundError('User %r not found.' % user_name)
|
raise errors.NotFoundError('User %r not found.' % user_name)
|
||||||
return {'user': _serialize_user(context.user, user)}
|
return {'user': _serialize_user(context.user, user)}
|
||||||
|
|
||||||
def put(self, context, user_name):
|
def put(self, context, user_name):
|
||||||
''' Updates an existing user. '''
|
''' Update an existing user. '''
|
||||||
user = self._user_service.get_by_name(context.session, user_name)
|
user = users.get_by_name(context.session, user_name)
|
||||||
if not user:
|
if not user:
|
||||||
raise errors.NotFoundError('User %r not found.' % user_name)
|
raise errors.NotFoundError('User %r not found.' % user_name)
|
||||||
|
|
||||||
|
@ -92,53 +79,26 @@ class UserDetailApi(BaseApi):
|
||||||
infix = 'any'
|
infix = 'any'
|
||||||
|
|
||||||
if 'name' in context.request:
|
if 'name' in context.request:
|
||||||
self._auth_service.verify_privilege(
|
auth.verify_privilege(context.user, 'users:edit:%s:name' % infix)
|
||||||
context.user, 'users:edit:%s:name' % infix)
|
users.update_name(user, context.request['name'])
|
||||||
name = context.request['name'].strip()
|
|
||||||
if not re.match(self._name_regex, name):
|
|
||||||
raise errors.ValidationError(
|
|
||||||
'Name must satisfy regex %r.' % self._name_regex)
|
|
||||||
user.name = name
|
|
||||||
|
|
||||||
if 'password' in context.request:
|
if 'password' in context.request:
|
||||||
password = context.request['password']
|
auth.verify_privilege(context.user, 'users:edit:%s:pass' % infix)
|
||||||
self._auth_service.verify_privilege(
|
users.update_password(user, context.request['password'])
|
||||||
context.user, 'users:edit:%s:pass' % infix)
|
|
||||||
if not re.match(self._password_regex, password):
|
|
||||||
raise errors.ValidationError(
|
|
||||||
'Password must satisfy regex %r.' % self._password_regex)
|
|
||||||
user.password_salt = self._password_service.create_password()
|
|
||||||
user.password_hash = self._password_service.get_password_hash(
|
|
||||||
user.password_salt, password)
|
|
||||||
|
|
||||||
if 'email' in context.request:
|
if 'email' in context.request:
|
||||||
self._auth_service.verify_privilege(
|
auth.verify_privilege(context.user, 'users:edit:%s:email' % infix)
|
||||||
context.user, 'users:edit:%s:email' % infix)
|
users.update_email(user, context.request['email'])
|
||||||
email = context.request['email'].strip() or None
|
|
||||||
if not util.is_valid_email(email):
|
|
||||||
raise errors.ValidationError(
|
|
||||||
'%r is not a vaild email address.' % email)
|
|
||||||
user.email = email
|
|
||||||
|
|
||||||
if 'accessRank' in context.request:
|
if 'accessRank' in context.request:
|
||||||
self._auth_service.verify_privilege(
|
auth.verify_privilege(context.user, 'users:edit:%s:rank' % infix)
|
||||||
context.user, 'users:edit:%s:rank' % infix)
|
users.update_rank(user, context.request['accessRank'], context.user)
|
||||||
rank = context.request['accessRank'].strip()
|
|
||||||
if not rank in self._available_access_ranks:
|
|
||||||
raise errors.ValidationError(
|
|
||||||
'Bad access rank. Valid access ranks: %r' \
|
|
||||||
% self._available_access_ranks)
|
|
||||||
if self._available_access_ranks.index(context.user.access_rank) \
|
|
||||||
< self._available_access_ranks.index(rank):
|
|
||||||
raise errors.AuthError(
|
|
||||||
'Trying to set higher access rank than one has')
|
|
||||||
user.access_rank = rank
|
|
||||||
|
|
||||||
# TODO: avatar
|
# TODO: avatar
|
||||||
|
|
||||||
try:
|
try:
|
||||||
context.session.commit()
|
context.session.commit()
|
||||||
except sqlalchemy.exc.IntegrityError:
|
except sqlalchemy.exc.IntegrityError:
|
||||||
raise errors.IntegrityError('User %r already exists.' % name)
|
raise errors.IntegrityError('User %r already exists.' % user.name)
|
||||||
|
|
||||||
return {'user': _serialize_user(context.user, user)}
|
return {'user': _serialize_user(context.user, user)}
|
||||||
|
|
|
@ -3,16 +3,11 @@
|
||||||
import falcon
|
import falcon
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
import sqlalchemy.orm
|
import sqlalchemy.orm
|
||||||
import szurubooru.api
|
from szurubooru import api, config, errors, middleware
|
||||||
import szurubooru.config
|
from szurubooru.util import misc
|
||||||
import szurubooru.errors
|
|
||||||
import szurubooru.middleware
|
|
||||||
import szurubooru.services
|
|
||||||
import szurubooru.services.search
|
|
||||||
import szurubooru.util
|
|
||||||
|
|
||||||
class _CustomRequest(falcon.Request):
|
class _CustomRequest(falcon.Request):
|
||||||
context_type = szurubooru.util.dotdict
|
context_type = misc.dotdict
|
||||||
|
|
||||||
def get_param_as_string(self, name, required=False, store=None, default=None):
|
def get_param_as_string(self, name, required=False, store=None, default=None):
|
||||||
params = self._params
|
params = self._params
|
||||||
|
@ -45,47 +40,37 @@ def _on_not_found_error(ex, _request, _response, _params):
|
||||||
raise falcon.HTTPNotFound(title='Not found', description=str(ex))
|
raise falcon.HTTPNotFound(title='Not found', description=str(ex))
|
||||||
|
|
||||||
def create_app():
|
def create_app():
|
||||||
''' Creates a WSGI compatible App object. '''
|
''' Create a WSGI compatible App object. '''
|
||||||
config = szurubooru.config.Config()
|
|
||||||
|
|
||||||
engine = sqlalchemy.create_engine(
|
engine = sqlalchemy.create_engine(
|
||||||
'{schema}://{user}:{password}@{host}:{port}/{name}'.format(
|
'{schema}://{user}:{password}@{host}:{port}/{name}'.format(
|
||||||
schema=config['database']['schema'],
|
schema=config.config['database']['schema'],
|
||||||
user=config['database']['user'],
|
user=config.config['database']['user'],
|
||||||
password=config['database']['pass'],
|
password=config.config['database']['pass'],
|
||||||
host=config['database']['host'],
|
host=config.config['database']['host'],
|
||||||
port=config['database']['port'],
|
port=config.config['database']['port'],
|
||||||
name=config['database']['name']))
|
name=config.config['database']['name']))
|
||||||
session_maker = sqlalchemy.orm.sessionmaker(bind=engine)
|
session_maker = sqlalchemy.orm.sessionmaker(bind=engine)
|
||||||
scoped_session = sqlalchemy.orm.scoped_session(session_maker)
|
scoped_session = sqlalchemy.orm.scoped_session(session_maker)
|
||||||
|
|
||||||
# TODO: is there a better way?
|
|
||||||
mailer = szurubooru.services.Mailer(config)
|
|
||||||
password_service = szurubooru.services.PasswordService(config)
|
|
||||||
auth_service = szurubooru.services.AuthService(config, password_service)
|
|
||||||
user_service = szurubooru.services.UserService(config, password_service)
|
|
||||||
|
|
||||||
user_list_api = szurubooru.api.UserListApi(auth_service, user_service)
|
|
||||||
user_detail_api = szurubooru.api.UserDetailApi(
|
|
||||||
config, auth_service, password_service, user_service)
|
|
||||||
password_reminder_api = szurubooru.api.PasswordReminderApi(
|
|
||||||
config, mailer, user_service)
|
|
||||||
|
|
||||||
app = falcon.API(
|
app = falcon.API(
|
||||||
request_type=_CustomRequest,
|
request_type=_CustomRequest,
|
||||||
middleware=[
|
middleware=[
|
||||||
szurubooru.middleware.ImbueContext(),
|
middleware.ImbueContext(),
|
||||||
szurubooru.middleware.RequireJson(),
|
middleware.RequireJson(),
|
||||||
szurubooru.middleware.JsonTranslator(),
|
middleware.JsonTranslator(),
|
||||||
szurubooru.middleware.DbSession(scoped_session),
|
middleware.DbSession(scoped_session),
|
||||||
szurubooru.middleware.Authenticator(auth_service, user_service),
|
middleware.Authenticator(),
|
||||||
])
|
])
|
||||||
|
|
||||||
app.add_error_handler(szurubooru.errors.AuthError, _on_auth_error)
|
user_list_api = api.UserListApi()
|
||||||
app.add_error_handler(szurubooru.errors.IntegrityError, _on_integrity_error)
|
user_detail_api = api.UserDetailApi()
|
||||||
app.add_error_handler(szurubooru.errors.ValidationError, _on_validation_error)
|
password_reminder_api = api.PasswordReminderApi()
|
||||||
app.add_error_handler(szurubooru.errors.SearchError, _on_search_error)
|
|
||||||
app.add_error_handler(szurubooru.errors.NotFoundError, _on_not_found_error)
|
app.add_error_handler(errors.AuthError, _on_auth_error)
|
||||||
|
app.add_error_handler(errors.IntegrityError, _on_integrity_error)
|
||||||
|
app.add_error_handler(errors.ValidationError, _on_validation_error)
|
||||||
|
app.add_error_handler(errors.SearchError, _on_search_error)
|
||||||
|
app.add_error_handler(errors.NotFoundError, _on_not_found_error)
|
||||||
|
|
||||||
app.add_route('/users/', user_list_api)
|
app.add_route('/users/', user_list_api)
|
||||||
app.add_route('/user/{user_name}', user_detail_api)
|
app.add_route('/user/{user_name}', user_detail_api)
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import configobj
|
import configobj
|
||||||
import szurubooru.errors
|
from szurubooru import errors
|
||||||
|
|
||||||
class Config(object):
|
class Config(object):
|
||||||
''' INI config parser and container. '''
|
''' INI config parser and container. '''
|
||||||
|
@ -15,20 +15,22 @@ class Config(object):
|
||||||
|
|
||||||
def _validate(self):
|
def _validate(self):
|
||||||
'''
|
'''
|
||||||
Checks whether config.ini doesn't contain errors that might prove
|
Check whether config.ini doesn't contain errors that might prove
|
||||||
lethal at runtime.
|
lethal at runtime.
|
||||||
'''
|
'''
|
||||||
all_ranks = self['service']['user_ranks']
|
all_ranks = self['service']['user_ranks']
|
||||||
for privilege, rank in self['privileges'].items():
|
for privilege, rank in self['privileges'].items():
|
||||||
if rank not in all_ranks:
|
if rank not in all_ranks:
|
||||||
raise szurubooru.errors.ConfigError(
|
raise errors.ConfigError(
|
||||||
'Rank %r for privilege %r is missing from user_ranks' % (
|
'Rank %r for privilege %r is missing from user_ranks' % (
|
||||||
rank, privilege))
|
rank, privilege))
|
||||||
for rank in ['anonymous', 'admin', 'nobody']:
|
for rank in ['anonymous', 'admin', 'nobody']:
|
||||||
if rank not in all_ranks:
|
if rank not in all_ranks:
|
||||||
raise szurubooru.errors.ConfigError(
|
raise errors.ConfigError(
|
||||||
'Fixed rank %r is missing from user_ranks' % rank)
|
'Fixed rank %r is missing from user_ranks' % rank)
|
||||||
if self['service']['default_user_rank'] not in all_ranks:
|
if self['service']['default_user_rank'] not in all_ranks:
|
||||||
raise szurubooru.errors.ConfigError(
|
raise errors.ConfigError(
|
||||||
'Default rank %r is missing from user_ranks' % (
|
'Default rank %r is missing from user_ranks' % (
|
||||||
self['service']['default_user_rank']))
|
self['service']['default_user_rank']))
|
||||||
|
|
||||||
|
config = Config() # pylint: disable=invalid-name
|
||||||
|
|
4
server/szurubooru/db/__init__.py
Normal file
4
server/szurubooru/db/__init__.py
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
''' Database layer. '''
|
||||||
|
|
||||||
|
from szurubooru.db.base import Base
|
||||||
|
from szurubooru.db.user import User
|
|
@ -1,7 +1,5 @@
|
||||||
# pylint: disable=too-many-instance-attributes,too-few-public-methods
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from szurubooru.model.base import Base
|
from szurubooru.db.base import Base
|
||||||
|
|
||||||
class User(Base):
|
class User(Base):
|
||||||
__tablename__ = 'user'
|
__tablename__ = 'user'
|
|
@ -1,24 +1,20 @@
|
||||||
import base64
|
import base64
|
||||||
import falcon
|
import falcon
|
||||||
from szurubooru import errors
|
from szurubooru import db, errors
|
||||||
from szurubooru import model
|
from szurubooru.util import auth, users
|
||||||
|
|
||||||
class Authenticator(object):
|
class Authenticator(object):
|
||||||
'''
|
'''
|
||||||
Authenticates every request and puts information on active user in the
|
Authenticates every request and put information on active user in the
|
||||||
request context.
|
request context.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, auth_service, user_service):
|
|
||||||
self._auth_service = auth_service
|
|
||||||
self._user_service = user_service
|
|
||||||
|
|
||||||
def process_request(self, request, _response):
|
def process_request(self, request, _response):
|
||||||
''' Executed before passing the request to the API. '''
|
''' Bind the user to request. Update last login time if needed. '''
|
||||||
request.context.user = self._get_user(request)
|
request.context.user = self._get_user(request)
|
||||||
if request.get_param_as_bool('bump-login') \
|
if request.get_param_as_bool('bump-login') \
|
||||||
and request.context.user.user_id:
|
and request.context.user.user_id:
|
||||||
self._user_service.bump_login_time(request.context.user)
|
users.bump_login_time(request.context.user)
|
||||||
request.context.session.commit()
|
request.context.session.commit()
|
||||||
|
|
||||||
def _get_user(self, request):
|
def _get_user(self, request):
|
||||||
|
@ -27,15 +23,12 @@ class Authenticator(object):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
auth_type, user_and_password = request.auth.split(' ', 1)
|
auth_type, user_and_password = request.auth.split(' ', 1)
|
||||||
|
|
||||||
if auth_type.lower() != 'basic':
|
if auth_type.lower() != 'basic':
|
||||||
raise falcon.HTTPBadRequest(
|
raise falcon.HTTPBadRequest(
|
||||||
'Invalid authentication type',
|
'Invalid authentication type',
|
||||||
'Only basic authorization is supported.')
|
'Only basic authorization is supported.')
|
||||||
|
|
||||||
username, password = base64.decodebytes(
|
username, password = base64.decodebytes(
|
||||||
user_and_password.encode('ascii')).decode('utf8').split(':')
|
user_and_password.encode('ascii')).decode('utf8').split(':')
|
||||||
|
|
||||||
return self._authenticate(
|
return self._authenticate(
|
||||||
request.context.session, username, password)
|
request.context.session, username, password)
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
|
@ -46,16 +39,16 @@ class Authenticator(object):
|
||||||
msg.format(request.auth, str(err)))
|
msg.format(request.auth, str(err)))
|
||||||
|
|
||||||
def _authenticate(self, session, username, password):
|
def _authenticate(self, session, username, password):
|
||||||
''' Tries to authenticate user. Throws AuthError for invalid users. '''
|
''' Try to authenticate user. Throw AuthError for invalid users. '''
|
||||||
user = self._user_service.get_by_name(session, username)
|
user = users.get_by_name(session, username)
|
||||||
if not user:
|
if not user:
|
||||||
raise errors.AuthError('No such user.')
|
raise errors.AuthError('No such user.')
|
||||||
if not self._auth_service.is_valid_password(user, password):
|
if not auth.is_valid_password(user, password):
|
||||||
raise errors.AuthError('Invalid password.')
|
raise errors.AuthError('Invalid password.')
|
||||||
return user
|
return user
|
||||||
|
|
||||||
def _create_anonymous_user(self):
|
def _create_anonymous_user(self):
|
||||||
user = model.User()
|
user = db.User()
|
||||||
user.name = None
|
user.name = None
|
||||||
user.access_rank = 'anonymous'
|
user.access_rank = 'anonymous'
|
||||||
user.password = None
|
user.password = None
|
||||||
|
|
|
@ -5,12 +5,8 @@ class DbSession(object):
|
||||||
self._session_factory = session_factory
|
self._session_factory = session_factory
|
||||||
|
|
||||||
def process_request(self, request, _response):
|
def process_request(self, request, _response):
|
||||||
''' Executed before passing the request to the API. '''
|
|
||||||
request.context.session = self._session_factory()
|
request.context.session = self._session_factory()
|
||||||
|
|
||||||
def process_response(self, request, _response, _resource):
|
def process_response(self, request, _response, _resource):
|
||||||
'''
|
# any commits need to happen explicitly in the API layer.
|
||||||
Executed before passing the response to falcon.
|
|
||||||
Any commits to database need to happen explicitly in the API layer.
|
|
||||||
'''
|
|
||||||
request.context.session.close()
|
request.context.session.close()
|
||||||
|
|
|
@ -16,7 +16,6 @@ class JsonTranslator(object):
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def process_request(self, request, _response):
|
def process_request(self, request, _response):
|
||||||
''' Executed before passing the request to the API. '''
|
|
||||||
if request.content_length in (None, 0):
|
if request.content_length in (None, 0):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -36,7 +35,6 @@ class JsonTranslator(object):
|
||||||
'JSON was incorrect or not encoded as UTF-8.')
|
'JSON was incorrect or not encoded as UTF-8.')
|
||||||
|
|
||||||
def process_response(self, request, response, _resource):
|
def process_response(self, request, response, _resource):
|
||||||
''' Executed before passing the response to falcon. '''
|
|
||||||
if 'result' not in request.context:
|
if 'result' not in request.context:
|
||||||
return
|
return
|
||||||
response.body = json.dumps(
|
response.body = json.dumps(
|
||||||
|
|
|
@ -4,7 +4,6 @@ class RequireJson(object):
|
||||||
''' Sanitizes requests so that only JSON is accepted. '''
|
''' Sanitizes requests so that only JSON is accepted. '''
|
||||||
|
|
||||||
def process_request(self, request, _response):
|
def process_request(self, request, _response):
|
||||||
''' Executed before passing the request to the API. '''
|
|
||||||
if not request.client_accepts_json:
|
if not request.client_accepts_json:
|
||||||
raise falcon.HTTPNotAcceptable(
|
raise falcon.HTTPNotAcceptable(
|
||||||
'This API only supports responses encoded as JSON.')
|
'This API only supports responses encoded as JSON.')
|
||||||
|
|
|
@ -9,13 +9,13 @@ import logging.config
|
||||||
dir_to_self = os.path.dirname(os.path.realpath(__file__))
|
dir_to_self = os.path.dirname(os.path.realpath(__file__))
|
||||||
sys.path.append(os.path.join(dir_to_self, *[os.pardir] * 2))
|
sys.path.append(os.path.join(dir_to_self, *[os.pardir] * 2))
|
||||||
|
|
||||||
import szurubooru.model.base
|
import szurubooru.db.base
|
||||||
import szurubooru.config
|
import szurubooru.config
|
||||||
|
|
||||||
alembic_config = alembic.context.config
|
alembic_config = alembic.context.config
|
||||||
logging.config.fileConfig(alembic_config.config_file_name)
|
logging.config.fileConfig(alembic_config.config_file_name)
|
||||||
|
|
||||||
szuru_config = szurubooru.config.Config()
|
szuru_config = szurubooru.config.config
|
||||||
alembic_config.set_main_option(
|
alembic_config.set_main_option(
|
||||||
'sqlalchemy.url',
|
'sqlalchemy.url',
|
||||||
'{schema}://{user}:{password}@{host}:{port}/{name}'.format(
|
'{schema}://{user}:{password}@{host}:{port}/{name}'.format(
|
||||||
|
@ -26,7 +26,7 @@ alembic_config.set_main_option(
|
||||||
port=szuru_config['database']['port'],
|
port=szuru_config['database']['port'],
|
||||||
name=szuru_config['database']['name']))
|
name=szuru_config['database']['name']))
|
||||||
|
|
||||||
target_metadata = szurubooru.model.Base.metadata
|
target_metadata = szurubooru.db.Base.metadata
|
||||||
|
|
||||||
def run_migrations_offline():
|
def run_migrations_offline():
|
||||||
'''
|
'''
|
||||||
|
|
|
@ -1,6 +0,0 @@
|
||||||
'''
|
|
||||||
Database models.
|
|
||||||
'''
|
|
||||||
|
|
||||||
from szurubooru.model.base import Base
|
|
||||||
from szurubooru.model.user import User
|
|
4
server/szurubooru/search/__init__.py
Normal file
4
server/szurubooru/search/__init__.py
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
''' Search parsers and services. '''
|
||||||
|
|
||||||
|
from szurubooru.search.search_executor import SearchExecutor
|
||||||
|
from szurubooru.search.user_search_config import UserSearchConfig
|
|
@ -1,11 +1,11 @@
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
import szurubooru.errors
|
import szurubooru.errors
|
||||||
from szurubooru import util
|
from szurubooru.util import misc
|
||||||
from szurubooru.services.search import criteria
|
from szurubooru.search import criteria
|
||||||
|
|
||||||
def _apply_criterion_to_column(
|
def _apply_criterion_to_column(
|
||||||
column, query, criterion, allow_composite=True, allow_ranged=True):
|
column, query, criterion, allow_composite=True, allow_ranged=True):
|
||||||
''' Decorates SQLAlchemy filter on given column using supplied criterion. '''
|
''' Decorate SQLAlchemy filter on given column using supplied criterion. '''
|
||||||
if isinstance(criterion, criteria.StringSearchCriterion):
|
if isinstance(criterion, criteria.StringSearchCriterion):
|
||||||
expr = column == criterion.value
|
expr = column == criterion.value
|
||||||
if criterion.negated:
|
if criterion.negated:
|
||||||
|
@ -32,11 +32,11 @@ def _apply_criterion_to_column(
|
||||||
|
|
||||||
def _apply_date_criterion_to_column(column, query, criterion):
|
def _apply_date_criterion_to_column(column, query, criterion):
|
||||||
'''
|
'''
|
||||||
Decorates SQLAlchemy filter on given column using supplied criterion.
|
Decorate SQLAlchemy filter on given column using supplied criterion.
|
||||||
Parses the datetime inside the criterion.
|
Parse the datetime inside the criterion.
|
||||||
'''
|
'''
|
||||||
if isinstance(criterion, criteria.StringSearchCriterion):
|
if isinstance(criterion, criteria.StringSearchCriterion):
|
||||||
min_date, max_date = util.parse_time_range(criterion.value)
|
min_date, max_date = misc.parse_time_range(criterion.value)
|
||||||
expr = column.between(min_date, max_date)
|
expr = column.between(min_date, max_date)
|
||||||
if criterion.negated:
|
if criterion.negated:
|
||||||
expr = ~expr
|
expr = ~expr
|
||||||
|
@ -44,7 +44,7 @@ def _apply_date_criterion_to_column(column, query, criterion):
|
||||||
elif isinstance(criterion, criteria.ArraySearchCriterion):
|
elif isinstance(criterion, criteria.ArraySearchCriterion):
|
||||||
expr = sqlalchemy.sql.false()
|
expr = sqlalchemy.sql.false()
|
||||||
for value in criterion.values:
|
for value in criterion.values:
|
||||||
min_date, max_date = util.parse_time_range(value)
|
min_date, max_date = misc.parse_time_range(value)
|
||||||
expr = expr | column.between(min_date, max_date)
|
expr = expr | column.between(min_date, max_date)
|
||||||
if criterion.negated:
|
if criterion.negated:
|
||||||
expr = ~expr
|
expr = ~expr
|
||||||
|
@ -52,14 +52,14 @@ def _apply_date_criterion_to_column(column, query, criterion):
|
||||||
elif isinstance(criterion, criteria.RangedSearchCriterion):
|
elif isinstance(criterion, criteria.RangedSearchCriterion):
|
||||||
assert criterion.min_value or criterion.max_value
|
assert criterion.min_value or criterion.max_value
|
||||||
if criterion.min_value and criterion.max_value:
|
if criterion.min_value and criterion.max_value:
|
||||||
min_date = util.parse_time_range(criterion.min_value)[0]
|
min_date = misc.parse_time_range(criterion.min_value)[0]
|
||||||
max_date = util.parse_time_range(criterion.max_value)[1]
|
max_date = misc.parse_time_range(criterion.max_value)[1]
|
||||||
expr = column.between(min_date, max_date)
|
expr = column.between(min_date, max_date)
|
||||||
elif criterion.min_value:
|
elif criterion.min_value:
|
||||||
min_date = util.parse_time_range(criterion.min_value)[0]
|
min_date = misc.parse_time_range(criterion.min_value)[0]
|
||||||
expr = column >= min_date
|
expr = column >= min_date
|
||||||
elif criterion.max_value:
|
elif criterion.max_value:
|
||||||
max_date = util.parse_time_range(criterion.max_value)[1]
|
max_date = misc.parse_time_range(criterion.max_value)[1]
|
||||||
expr = column <= max_date
|
expr = column <= max_date
|
||||||
if criterion.negated:
|
if criterion.negated:
|
||||||
expr = ~expr
|
expr = ~expr
|
|
@ -1,19 +1,17 @@
|
||||||
''' Exports SearchExecutor. '''
|
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from szurubooru import errors
|
from szurubooru import errors
|
||||||
from szurubooru.services.search import criteria
|
from szurubooru.search import criteria
|
||||||
|
|
||||||
class SearchExecutor(object):
|
class SearchExecutor(object):
|
||||||
ORDER_DESC = 1
|
|
||||||
ORDER_ASC = 2
|
|
||||||
|
|
||||||
'''
|
'''
|
||||||
Class for search parsing and execution. Handles plaintext parsing and
|
Class for search parsing and execution. Handles plaintext parsing and
|
||||||
delegates sqlalchemy filter decoration to SearchConfig instances.
|
delegates sqlalchemy filter decoration to SearchConfig instances.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
ORDER_DESC = 1
|
||||||
|
ORDER_ASC = 2
|
||||||
|
|
||||||
def __init__(self, search_config):
|
def __init__(self, search_config):
|
||||||
self.page_size = 100
|
self.page_size = 100
|
||||||
self._search_config = search_config
|
self._search_config = search_config
|
42
server/szurubooru/search/user_search_config.py
Normal file
42
server/szurubooru/search/user_search_config.py
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
from sqlalchemy.sql.expression import func
|
||||||
|
from szurubooru import db
|
||||||
|
from szurubooru.search.base_search_config import BaseSearchConfig
|
||||||
|
|
||||||
|
class UserSearchConfig(BaseSearchConfig):
|
||||||
|
''' Executes searches related to the users. '''
|
||||||
|
|
||||||
|
def create_query(self, session):
|
||||||
|
return session.query(db.User)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def anonymous_filter(self):
|
||||||
|
return self._create_basic_filter(db.User.name, allow_ranged=False)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def special_filters(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def named_filters(self):
|
||||||
|
return {
|
||||||
|
'name': self._create_basic_filter(db.User.name, allow_ranged=False),
|
||||||
|
'creation_date': self._create_date_filter(db.User.creation_time),
|
||||||
|
'creation_time': self._create_date_filter(db.User.creation_time),
|
||||||
|
'last_login_date': self._create_date_filter(db.User.last_login_time),
|
||||||
|
'last_login_time': self._create_date_filter(db.User.last_login_time),
|
||||||
|
'login_date': self._create_date_filter(db.User.last_login_time),
|
||||||
|
'login_time': self._create_date_filter(db.User.last_login_time),
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def order_columns(self):
|
||||||
|
return {
|
||||||
|
'random': func.random(),
|
||||||
|
'name': db.User.name,
|
||||||
|
'creation_date': db.User.creation_time,
|
||||||
|
'creation_time': db.User.creation_time,
|
||||||
|
'last_login_date': db.User.last_login_time,
|
||||||
|
'last_login_time': db.User.last_login_time,
|
||||||
|
'login_date': db.User.last_login_time,
|
||||||
|
'login_time': db.User.last_login_time,
|
||||||
|
}
|
|
@ -1,9 +0,0 @@
|
||||||
'''
|
|
||||||
Middle layer between REST API and database.
|
|
||||||
All the business logic goes here.
|
|
||||||
'''
|
|
||||||
|
|
||||||
from szurubooru.services.mailer import Mailer
|
|
||||||
from szurubooru.services.auth_service import AuthService
|
|
||||||
from szurubooru.services.user_service import UserService
|
|
||||||
from szurubooru.services.password_service import PasswordService
|
|
|
@ -1,28 +0,0 @@
|
||||||
from szurubooru import errors
|
|
||||||
|
|
||||||
class AuthService(object):
|
|
||||||
def __init__(self, config, password_service):
|
|
||||||
self._config = config
|
|
||||||
self._password_service = password_service
|
|
||||||
|
|
||||||
def is_valid_password(self, user, password):
|
|
||||||
''' Returns whether the given password for a given user is valid. '''
|
|
||||||
salt, valid_hash = user.password_salt, user.password_hash
|
|
||||||
possible_hashes = [
|
|
||||||
self._password_service.get_password_hash(salt, password),
|
|
||||||
self._password_service.get_legacy_password_hash(salt, password)
|
|
||||||
]
|
|
||||||
return valid_hash in possible_hashes
|
|
||||||
|
|
||||||
def verify_privilege(self, user, privilege_name):
|
|
||||||
'''
|
|
||||||
Throws an AuthError if the given user doesn't have given privilege.
|
|
||||||
'''
|
|
||||||
all_ranks = self._config['service']['user_ranks']
|
|
||||||
|
|
||||||
assert privilege_name in self._config['privileges']
|
|
||||||
assert user.access_rank in all_ranks
|
|
||||||
minimal_rank = self._config['privileges'][privilege_name]
|
|
||||||
good_ranks = all_ranks[all_ranks.index(minimal_rank):]
|
|
||||||
if user.access_rank not in good_ranks:
|
|
||||||
raise errors.AuthError('Insufficient privileges to do this.')
|
|
|
@ -1,19 +0,0 @@
|
||||||
import smtplib
|
|
||||||
import email.mime.text
|
|
||||||
|
|
||||||
class Mailer(object):
|
|
||||||
def __init__(self, config):
|
|
||||||
self._config = config
|
|
||||||
|
|
||||||
def send(self, sender, recipient, subject, body):
|
|
||||||
msg = email.mime.text.MIMEText(body)
|
|
||||||
msg['Subject'] = subject
|
|
||||||
msg['From'] = sender
|
|
||||||
msg['To'] = recipient
|
|
||||||
|
|
||||||
smtp = smtplib.SMTP(
|
|
||||||
self._config['smtp']['host'],
|
|
||||||
int(self._config['smtp']['port']))
|
|
||||||
smtp.login(self._config['smtp']['user'], self._config['smtp']['pass'])
|
|
||||||
smtp.send_message(msg)
|
|
||||||
smtp.quit()
|
|
|
@ -1,34 +0,0 @@
|
||||||
import hashlib
|
|
||||||
import random
|
|
||||||
|
|
||||||
class PasswordService(object):
|
|
||||||
''' Stateless utilities for passwords '''
|
|
||||||
|
|
||||||
def __init__(self, config):
|
|
||||||
self._config = config
|
|
||||||
|
|
||||||
def get_password_hash(self, salt, password):
|
|
||||||
''' Retrieves new-style password hash. '''
|
|
||||||
digest = hashlib.sha256()
|
|
||||||
digest.update(self._config['basic']['secret'].encode('utf8'))
|
|
||||||
digest.update(salt.encode('utf8'))
|
|
||||||
digest.update(password.encode('utf8'))
|
|
||||||
return digest.hexdigest()
|
|
||||||
|
|
||||||
def get_legacy_password_hash(self, salt, password):
|
|
||||||
''' Retrieves old-style password hash. '''
|
|
||||||
digest = hashlib.sha1()
|
|
||||||
digest.update(b'1A2/$_4xVa')
|
|
||||||
digest.update(salt.encode('utf8'))
|
|
||||||
digest.update(password.encode('utf8'))
|
|
||||||
return digest.hexdigest()
|
|
||||||
|
|
||||||
def create_password(self):
|
|
||||||
''' Creates an easy-to-remember password. '''
|
|
||||||
alphabet = {
|
|
||||||
'c': list('bcdfghijklmnpqrstvwxyz'),
|
|
||||||
'v': list('aeiou'),
|
|
||||||
'n': list('0123456789'),
|
|
||||||
}
|
|
||||||
pattern = 'cvcvnncvcv'
|
|
||||||
return ''.join(random.choice(alphabet[l]) for l in list(pattern))
|
|
|
@ -1,2 +0,0 @@
|
||||||
from szurubooru.services.search.search_executor import SearchExecutor
|
|
||||||
from szurubooru.services.search.user_search_config import UserSearchConfig
|
|
|
@ -1,44 +0,0 @@
|
||||||
''' Exports UserSearchConfig. '''
|
|
||||||
|
|
||||||
from sqlalchemy.sql.expression import func
|
|
||||||
from szurubooru.model import User
|
|
||||||
from szurubooru.services.search.base_search_config import BaseSearchConfig
|
|
||||||
|
|
||||||
class UserSearchConfig(BaseSearchConfig):
|
|
||||||
''' Executes searches related to the users. '''
|
|
||||||
|
|
||||||
def create_query(self, session):
|
|
||||||
return session.query(User)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def anonymous_filter(self):
|
|
||||||
return self._create_basic_filter(User.name, allow_ranged=False)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def special_filters(self):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def named_filters(self):
|
|
||||||
return {
|
|
||||||
'name': self._create_basic_filter(User.name, allow_ranged=False),
|
|
||||||
'creation_date': self._create_date_filter(User.creation_time),
|
|
||||||
'creation_time': self._create_date_filter(User.creation_time),
|
|
||||||
'last_login_date': self._create_date_filter(User.last_login_time),
|
|
||||||
'last_login_time': self._create_date_filter(User.last_login_time),
|
|
||||||
'login_date': self._create_date_filter(User.last_login_time),
|
|
||||||
'login_time': self._create_date_filter(User.last_login_time),
|
|
||||||
}
|
|
||||||
|
|
||||||
@property
|
|
||||||
def order_columns(self):
|
|
||||||
return {
|
|
||||||
'random': func.random(),
|
|
||||||
'name': User.name,
|
|
||||||
'creation_date': User.creation_time,
|
|
||||||
'creation_time': User.creation_time,
|
|
||||||
'last_login_date': User.last_login_time,
|
|
||||||
'last_login_time': User.last_login_time,
|
|
||||||
'login_date': User.last_login_time,
|
|
||||||
'login_time': User.last_login_time,
|
|
||||||
}
|
|
|
@ -1,56 +0,0 @@
|
||||||
import re
|
|
||||||
from datetime import datetime
|
|
||||||
from szurubooru import errors
|
|
||||||
from szurubooru import model
|
|
||||||
from szurubooru import util
|
|
||||||
|
|
||||||
class UserService(object):
|
|
||||||
''' User management '''
|
|
||||||
|
|
||||||
def __init__(self, config, password_service):
|
|
||||||
self._config = config
|
|
||||||
self._password_service = password_service
|
|
||||||
self._name_regex = self._config['service']['user_name_regex']
|
|
||||||
self._password_regex = self._config['service']['password_regex']
|
|
||||||
|
|
||||||
def create_user(self, session, name, password, email):
|
|
||||||
''' Creates an user with given parameters and returns it. '''
|
|
||||||
|
|
||||||
if not re.match(self._name_regex, name):
|
|
||||||
raise errors.ValidationError(
|
|
||||||
'Name must satisfy regex %r.' % self._name_regex)
|
|
||||||
|
|
||||||
if not re.match(self._password_regex, password):
|
|
||||||
raise errors.ValidationError(
|
|
||||||
'Password must satisfy regex %r.' % self._password_regex)
|
|
||||||
|
|
||||||
if not util.is_valid_email(email):
|
|
||||||
raise errors.ValidationError(
|
|
||||||
'%r is not a vaild email address.' % email)
|
|
||||||
|
|
||||||
user = model.User()
|
|
||||||
user.name = name
|
|
||||||
user.password_salt = self._password_service.create_password()
|
|
||||||
user.password_hash = self._password_service.get_password_hash(
|
|
||||||
user.password_salt, password)
|
|
||||||
user.email = email or None
|
|
||||||
user.access_rank = self._config['service']['default_user_rank']
|
|
||||||
user.creation_time = datetime.now()
|
|
||||||
user.avatar_style = model.User.AVATAR_GRAVATAR
|
|
||||||
|
|
||||||
session.add(user)
|
|
||||||
return user
|
|
||||||
|
|
||||||
def bump_login_time(self, user):
|
|
||||||
user.last_login_time = datetime.now()
|
|
||||||
|
|
||||||
def reset_password(self, user):
|
|
||||||
password = self._password_service.create_password()
|
|
||||||
user.password_salt = self._password_service.create_password()
|
|
||||||
user.password_hash = self._password_service.get_password_hash(
|
|
||||||
user.password_salt, password)
|
|
||||||
return password
|
|
||||||
|
|
||||||
def get_by_name(self, session, name):
|
|
||||||
''' Retrieves an user by its name. '''
|
|
||||||
return session.query(model.User).filter_by(name=name).first()
|
|
|
@ -1,15 +1,12 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import szurubooru.services
|
from szurubooru import api, db, errors, config
|
||||||
from szurubooru.api.user_api import UserDetailApi
|
from szurubooru.util import auth, misc
|
||||||
from szurubooru.errors import AuthError, ValidationError
|
|
||||||
from szurubooru.model.user import User
|
|
||||||
from szurubooru.tests.database_test_case import DatabaseTestCase
|
from szurubooru.tests.database_test_case import DatabaseTestCase
|
||||||
from szurubooru.util import dotdict
|
|
||||||
|
|
||||||
class TestUserDetailApi(DatabaseTestCase):
|
class TestUserDetailApi(DatabaseTestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
config = {
|
config_mock = {
|
||||||
'basic': {
|
'basic': {
|
||||||
'secret': '',
|
'secret': '',
|
||||||
},
|
},
|
||||||
|
@ -30,18 +27,18 @@ class TestUserDetailApi(DatabaseTestCase):
|
||||||
'users:edit:any:rank': 'admin',
|
'users:edit:any:rank': 'admin',
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
password_service = szurubooru.services.PasswordService(config)
|
self.old_config = config.config
|
||||||
auth_service = szurubooru.services.AuthService(config, password_service)
|
config.config = config_mock
|
||||||
user_service = szurubooru.services.UserService(config, password_service)
|
self.api = api.UserDetailApi()
|
||||||
self.auth_service = auth_service
|
self.context = misc.dotdict()
|
||||||
self.api = UserDetailApi(
|
|
||||||
config, auth_service, password_service, user_service)
|
|
||||||
self.context = dotdict()
|
|
||||||
self.context.session = self.session
|
self.context.session = self.session
|
||||||
self.context.request = {}
|
self.context.request = {}
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
config.config = self.old_config
|
||||||
|
|
||||||
def _create_user(self, name, rank='admin'):
|
def _create_user(self, name, rank='admin'):
|
||||||
user = User()
|
user = db.User()
|
||||||
user.name = name
|
user.name = name
|
||||||
user.password = 'dummy'
|
user.password = 'dummy'
|
||||||
user.password_salt = 'dummy'
|
user.password_salt = 'dummy'
|
||||||
|
@ -49,7 +46,7 @@ class TestUserDetailApi(DatabaseTestCase):
|
||||||
user.email = 'dummy'
|
user.email = 'dummy'
|
||||||
user.access_rank = rank
|
user.access_rank = rank
|
||||||
user.creation_time = datetime.now()
|
user.creation_time = datetime.now()
|
||||||
user.avatar_style = User.AVATAR_GRAVATAR
|
user.avatar_style = db.User.AVATAR_GRAVATAR
|
||||||
return user
|
return user
|
||||||
|
|
||||||
def test_updating_nothing(self):
|
def test_updating_nothing(self):
|
||||||
|
@ -57,7 +54,7 @@ class TestUserDetailApi(DatabaseTestCase):
|
||||||
self.session.add(admin_user)
|
self.session.add(admin_user)
|
||||||
self.context.user = admin_user
|
self.context.user = admin_user
|
||||||
self.api.put(self.context, 'u1')
|
self.api.put(self.context, 'u1')
|
||||||
admin_user = self.session.query(User).filter_by(name='u1').one()
|
admin_user = self.session.query(db.User).filter_by(name='u1').one()
|
||||||
self.assertEqual(admin_user.name, 'u1')
|
self.assertEqual(admin_user.name, 'u1')
|
||||||
self.assertEqual(admin_user.email, 'dummy')
|
self.assertEqual(admin_user.email, 'dummy')
|
||||||
self.assertEqual(admin_user.access_rank, 'admin')
|
self.assertEqual(admin_user.access_rank, 'admin')
|
||||||
|
@ -69,16 +66,16 @@ class TestUserDetailApi(DatabaseTestCase):
|
||||||
self.context.request = {
|
self.context.request = {
|
||||||
'name': 'chewie',
|
'name': 'chewie',
|
||||||
'email': 'asd@asd.asd',
|
'email': 'asd@asd.asd',
|
||||||
'password': 'valid',
|
'password': 'oks',
|
||||||
'accessRank': 'mod',
|
'accessRank': 'mod',
|
||||||
}
|
}
|
||||||
self.api.put(self.context, 'u1')
|
self.api.put(self.context, 'u1')
|
||||||
admin_user = self.session.query(User).filter_by(name='chewie').one()
|
admin_user = self.session.query(db.User).filter_by(name='chewie').one()
|
||||||
self.assertEqual(admin_user.name, 'chewie')
|
self.assertEqual(admin_user.name, 'chewie')
|
||||||
self.assertEqual(admin_user.email, 'asd@asd.asd')
|
self.assertEqual(admin_user.email, 'asd@asd.asd')
|
||||||
self.assertEqual(admin_user.access_rank, 'mod')
|
self.assertEqual(admin_user.access_rank, 'mod')
|
||||||
self.assertTrue(self.auth_service.is_valid_password(admin_user, 'valid'))
|
self.assertTrue(auth.is_valid_password(admin_user, 'oks'))
|
||||||
self.assertFalse(self.auth_service.is_valid_password(admin_user, 'invalid'))
|
self.assertFalse(auth.is_valid_password(admin_user, 'invalid'))
|
||||||
|
|
||||||
def test_removing_email(self):
|
def test_removing_email(self):
|
||||||
admin_user = self._create_user('u1', 'admin')
|
admin_user = self._create_user('u1', 'admin')
|
||||||
|
@ -86,7 +83,7 @@ class TestUserDetailApi(DatabaseTestCase):
|
||||||
self.context.user = admin_user
|
self.context.user = admin_user
|
||||||
self.context.request = {'email': ''}
|
self.context.request = {'email': ''}
|
||||||
self.api.put(self.context, 'u1')
|
self.api.put(self.context, 'u1')
|
||||||
admin_user = self.session.query(User).filter_by(name='u1').one()
|
admin_user = self.session.query(db.User).filter_by(name='u1').one()
|
||||||
self.assertEqual(admin_user.email, None)
|
self.assertEqual(admin_user.email, None)
|
||||||
|
|
||||||
def test_invalid_inputs(self):
|
def test_invalid_inputs(self):
|
||||||
|
@ -95,16 +92,16 @@ class TestUserDetailApi(DatabaseTestCase):
|
||||||
self.context.user = admin_user
|
self.context.user = admin_user
|
||||||
self.context.request = {'name': '.'}
|
self.context.request = {'name': '.'}
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
ValidationError, self.api.put, self.context, 'u1')
|
errors.ValidationError, self.api.put, self.context, 'u1')
|
||||||
self.context.request = {'password': '.'}
|
self.context.request = {'password': '.'}
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
ValidationError, self.api.put, self.context, 'u1')
|
errors.ValidationError, self.api.put, self.context, 'u1')
|
||||||
self.context.request = {'accessRank': '.'}
|
self.context.request = {'accessRank': '.'}
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
ValidationError, self.api.put, self.context, 'u1')
|
errors.ValidationError, self.api.put, self.context, 'u1')
|
||||||
self.context.request = {'email': '.'}
|
self.context.request = {'email': '.'}
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
ValidationError, self.api.put, self.context, 'u1')
|
errors.ValidationError, self.api.put, self.context, 'u1')
|
||||||
|
|
||||||
def test_user_trying_to_update_someone_else(self):
|
def test_user_trying_to_update_someone_else(self):
|
||||||
user1 = self._create_user('u1', 'regular_user')
|
user1 = self._create_user('u1', 'regular_user')
|
||||||
|
@ -118,7 +115,7 @@ class TestUserDetailApi(DatabaseTestCase):
|
||||||
{'password': 'whatever'}]:
|
{'password': 'whatever'}]:
|
||||||
self.context.request = request
|
self.context.request = request
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
AuthError, self.api.put, self.context, user2.name)
|
errors.AuthError, self.api.put, self.context, user2.name)
|
||||||
|
|
||||||
def test_user_trying_to_become_someone_else(self):
|
def test_user_trying_to_become_someone_else(self):
|
||||||
user1 = self._create_user('u1', 'regular_user')
|
user1 = self._create_user('u1', 'regular_user')
|
||||||
|
@ -127,7 +124,7 @@ class TestUserDetailApi(DatabaseTestCase):
|
||||||
self.context.user = user1
|
self.context.user = user1
|
||||||
self.context.request = {'name': 'u2'}
|
self.context.request = {'name': 'u2'}
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
ValidationError, self.api.put, self.context, 'u1')
|
errors.ValidationError, self.api.put, self.context, 'u1')
|
||||||
|
|
||||||
def test_mods_trying_to_become_admin(self):
|
def test_mods_trying_to_become_admin(self):
|
||||||
user1 = self._create_user('u1', 'mod')
|
user1 = self._create_user('u1', 'mod')
|
||||||
|
@ -136,6 +133,6 @@ class TestUserDetailApi(DatabaseTestCase):
|
||||||
self.context.user = user1
|
self.context.user = user1
|
||||||
self.context.request = {'accessRank': 'admin'}
|
self.context.request = {'accessRank': 'admin'}
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
AuthError, self.api.put, self.context, user1.name)
|
errors.AuthError, self.api.put, self.context, user1.name)
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
AuthError, self.api.put, self.context, user2.name)
|
errors.AuthError, self.api.put, self.context, user2.name)
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
import unittest
|
import unittest
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from szurubooru.model import Base
|
from szurubooru import db
|
||||||
|
|
||||||
class DatabaseTestCase(unittest.TestCase):
|
class DatabaseTestCase(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
engine = sqlalchemy.create_engine('sqlite:///:memory:')
|
engine = sqlalchemy.create_engine('sqlite:///:memory:')
|
||||||
session_maker = sqlalchemy.orm.sessionmaker(bind=engine)
|
session_maker = sqlalchemy.orm.sessionmaker(bind=engine)
|
||||||
self.session = sqlalchemy.orm.scoped_session(session_maker)
|
self.session = sqlalchemy.orm.scoped_session(session_maker)
|
||||||
Base.query = self.session.query_property()
|
db.Base.query = self.session.query_property()
|
||||||
Base.metadata.create_all(bind=engine)
|
db.Base.metadata.create_all(bind=engine)
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from szurubooru import errors
|
from szurubooru import db, errors, search
|
||||||
from szurubooru import model
|
|
||||||
from szurubooru.services import search
|
|
||||||
from szurubooru.tests.database_test_case import DatabaseTestCase
|
from szurubooru.tests.database_test_case import DatabaseTestCase
|
||||||
|
|
||||||
class TestUserSearchExecutor(DatabaseTestCase):
|
class TestUserSearchExecutor(DatabaseTestCase):
|
||||||
|
@ -11,7 +9,7 @@ class TestUserSearchExecutor(DatabaseTestCase):
|
||||||
self.executor = search.SearchExecutor(self.search_config)
|
self.executor = search.SearchExecutor(self.search_config)
|
||||||
|
|
||||||
def _create_user(self, name):
|
def _create_user(self, name):
|
||||||
user = model.User()
|
user = db.User()
|
||||||
user.name = name
|
user.name = name
|
||||||
user.password = 'dummy'
|
user.password = 'dummy'
|
||||||
user.password_salt = 'dummy'
|
user.password_salt = 'dummy'
|
||||||
|
@ -19,7 +17,7 @@ class TestUserSearchExecutor(DatabaseTestCase):
|
||||||
user.email = 'dummy'
|
user.email = 'dummy'
|
||||||
user.access_rank = 'dummy'
|
user.access_rank = 'dummy'
|
||||||
user.creation_time = datetime.now()
|
user.creation_time = datetime.now()
|
||||||
user.avatar_style = model.User.AVATAR_GRAVATAR
|
user.avatar_style = db.User.AVATAR_GRAVATAR
|
||||||
return user
|
return user
|
||||||
|
|
||||||
def _test(self, query, page, expected_count, expected_user_names):
|
def _test(self, query, page, expected_count, expected_user_names):
|
|
@ -1,8 +1,7 @@
|
||||||
import unittest
|
import unittest
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import szurubooru.util
|
from szurubooru import errors
|
||||||
from szurubooru.util import parse_time_range
|
from szurubooru.util import misc
|
||||||
from szurubooru.errors import ValidationError
|
|
||||||
|
|
||||||
class FakeDatetime(datetime):
|
class FakeDatetime(datetime):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -11,33 +10,33 @@ class FakeDatetime(datetime):
|
||||||
|
|
||||||
class TestParseTime(unittest.TestCase):
|
class TestParseTime(unittest.TestCase):
|
||||||
def test_empty(self):
|
def test_empty(self):
|
||||||
self.assertRaises(ValidationError, parse_time_range, '')
|
self.assertRaises(errors.ValidationError, misc.parse_time_range, '')
|
||||||
|
|
||||||
def test_today(self):
|
def test_today(self):
|
||||||
szurubooru.util.datetime.datetime = FakeDatetime
|
misc.datetime.datetime = FakeDatetime
|
||||||
date_min, date_max = parse_time_range('today')
|
date_min, date_max = misc.parse_time_range('today')
|
||||||
self.assertEqual(date_min, datetime(1997, 1, 2, 0, 0, 0))
|
self.assertEqual(date_min, datetime(1997, 1, 2, 0, 0, 0))
|
||||||
self.assertEqual(date_max, datetime(1997, 1, 2, 23, 59, 59))
|
self.assertEqual(date_max, datetime(1997, 1, 2, 23, 59, 59))
|
||||||
|
|
||||||
def test_yesterday(self):
|
def test_yesterday(self):
|
||||||
szurubooru.util.datetime.datetime = FakeDatetime
|
misc.datetime.datetime = FakeDatetime
|
||||||
date_min, date_max = parse_time_range('yesterday')
|
date_min, date_max = misc.parse_time_range('yesterday')
|
||||||
self.assertEqual(date_min, datetime(1997, 1, 1, 0, 0, 0))
|
self.assertEqual(date_min, datetime(1997, 1, 1, 0, 0, 0))
|
||||||
self.assertEqual(date_max, datetime(1997, 1, 1, 23, 59, 59))
|
self.assertEqual(date_max, datetime(1997, 1, 1, 23, 59, 59))
|
||||||
|
|
||||||
def test_year(self):
|
def test_year(self):
|
||||||
date_min, date_max = parse_time_range('1999')
|
date_min, date_max = misc.parse_time_range('1999')
|
||||||
self.assertEqual(date_min, datetime(1999, 1, 1, 0, 0, 0))
|
self.assertEqual(date_min, datetime(1999, 1, 1, 0, 0, 0))
|
||||||
self.assertEqual(date_max, datetime(1999, 12, 31, 23, 59, 59))
|
self.assertEqual(date_max, datetime(1999, 12, 31, 23, 59, 59))
|
||||||
|
|
||||||
def test_month(self):
|
def test_month(self):
|
||||||
for text in ['1999-2', '1999-02']:
|
for text in ['1999-2', '1999-02']:
|
||||||
date_min, date_max = parse_time_range(text)
|
date_min, date_max = misc.parse_time_range(text)
|
||||||
self.assertEqual(date_min, datetime(1999, 2, 1, 0, 0, 0))
|
self.assertEqual(date_min, datetime(1999, 2, 1, 0, 0, 0))
|
||||||
self.assertEqual(date_max, datetime(1999, 2, 28, 23, 59, 59))
|
self.assertEqual(date_max, datetime(1999, 2, 28, 23, 59, 59))
|
||||||
|
|
||||||
def test_day(self):
|
def test_day(self):
|
||||||
for text in ['1999-2-6', '1999-02-6', '1999-2-06', '1999-02-06']:
|
for text in ['1999-2-6', '1999-02-6', '1999-2-06', '1999-02-06']:
|
||||||
date_min, date_max = parse_time_range(text)
|
date_min, date_max = misc.parse_time_range(text)
|
||||||
self.assertEqual(date_min, datetime(1999, 2, 6, 0, 0, 0))
|
self.assertEqual(date_min, datetime(1999, 2, 6, 0, 0, 0))
|
||||||
self.assertEqual(date_max, datetime(1999, 2, 6, 23, 59, 59))
|
self.assertEqual(date_max, datetime(1999, 2, 6, 23, 59, 59))
|
1
server/szurubooru/util/__init__.py
Normal file
1
server/szurubooru/util/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
''' Cool functions. '''
|
59
server/szurubooru/util/auth.py
Normal file
59
server/szurubooru/util/auth.py
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
import hashlib
|
||||||
|
import random
|
||||||
|
from szurubooru import config
|
||||||
|
from szurubooru import errors
|
||||||
|
|
||||||
|
def get_password_hash(salt, password):
|
||||||
|
''' Retrieve new-style password hash. '''
|
||||||
|
digest = hashlib.sha256()
|
||||||
|
digest.update(config.config['basic']['secret'].encode('utf8'))
|
||||||
|
digest.update(salt.encode('utf8'))
|
||||||
|
digest.update(password.encode('utf8'))
|
||||||
|
return digest.hexdigest()
|
||||||
|
|
||||||
|
def get_legacy_password_hash(salt, password):
|
||||||
|
''' Retrieve old-style password hash. '''
|
||||||
|
digest = hashlib.sha1()
|
||||||
|
digest.update(b'1A2/$_4xVa')
|
||||||
|
digest.update(salt.encode('utf8'))
|
||||||
|
digest.update(password.encode('utf8'))
|
||||||
|
return digest.hexdigest()
|
||||||
|
|
||||||
|
def create_password():
|
||||||
|
''' Create an easy-to-remember password. '''
|
||||||
|
alphabet = {
|
||||||
|
'c': list('bcdfghijklmnpqrstvwxyz'),
|
||||||
|
'v': list('aeiou'),
|
||||||
|
'n': list('0123456789'),
|
||||||
|
}
|
||||||
|
pattern = 'cvcvnncvcv'
|
||||||
|
return ''.join(random.choice(alphabet[l]) for l in list(pattern))
|
||||||
|
|
||||||
|
def is_valid_password(user, password):
|
||||||
|
''' Return whether the given password for a given user is valid. '''
|
||||||
|
salt, valid_hash = user.password_salt, user.password_hash
|
||||||
|
possible_hashes = [
|
||||||
|
get_password_hash(salt, password),
|
||||||
|
get_legacy_password_hash(salt, password)
|
||||||
|
]
|
||||||
|
return valid_hash in possible_hashes
|
||||||
|
|
||||||
|
def verify_privilege(user, privilege_name):
|
||||||
|
'''
|
||||||
|
Throw an AuthError if the given user doesn't have given privilege.
|
||||||
|
'''
|
||||||
|
all_ranks = config.config['service']['user_ranks']
|
||||||
|
|
||||||
|
assert privilege_name in config.config['privileges']
|
||||||
|
assert user.access_rank in all_ranks
|
||||||
|
minimal_rank = config.config['privileges'][privilege_name]
|
||||||
|
good_ranks = all_ranks[all_ranks.index(minimal_rank):]
|
||||||
|
if user.access_rank not in good_ranks:
|
||||||
|
raise errors.AuthError('Insufficient privileges to do this.')
|
||||||
|
|
||||||
|
def generate_authentication_token(user):
|
||||||
|
''' Generate nonguessable challenge (e.g. links in password reminder). '''
|
||||||
|
digest = hashlib.sha256()
|
||||||
|
digest.update(config.config['basic']['secret'].encode('utf8'))
|
||||||
|
digest.update(user.password_salt.encode('utf8'))
|
||||||
|
return digest.hexdigest()
|
15
server/szurubooru/util/mailer.py
Normal file
15
server/szurubooru/util/mailer.py
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
import smtplib
|
||||||
|
import email.mime.text
|
||||||
|
from szurubooru import config
|
||||||
|
|
||||||
|
def send_mail(sender, recipient, subject, body):
|
||||||
|
msg = email.mime.text.MIMEText(body)
|
||||||
|
msg['Subject'] = subject
|
||||||
|
msg['From'] = sender
|
||||||
|
msg['To'] = recipient
|
||||||
|
|
||||||
|
smtp = smtplib.SMTP(
|
||||||
|
config.config['smtp']['host'], int(config.config['smtp']['port']))
|
||||||
|
smtp.login(config.config['smtp']['user'], config.config['smtp']['pass'])
|
||||||
|
smtp.send_message(msg)
|
||||||
|
smtp.quit()
|
|
@ -3,7 +3,7 @@ import re
|
||||||
from szurubooru.errors import ValidationError
|
from szurubooru.errors import ValidationError
|
||||||
|
|
||||||
def is_valid_email(email):
|
def is_valid_email(email):
|
||||||
''' Validates given email address. '''
|
''' Return whether given email address is valid or empty. '''
|
||||||
return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email)
|
return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email)
|
||||||
|
|
||||||
class dotdict(dict): # pylint: disable=invalid-name
|
class dotdict(dict): # pylint: disable=invalid-name
|
||||||
|
@ -14,7 +14,7 @@ class dotdict(dict): # pylint: disable=invalid-name
|
||||||
__delattr__ = dict.__delitem__
|
__delattr__ = dict.__delitem__
|
||||||
|
|
||||||
def parse_time_range(value, timezone=datetime.timezone(datetime.timedelta())):
|
def parse_time_range(value, timezone=datetime.timezone(datetime.timedelta())):
|
||||||
''' Returns tuple containing min/max time for given text representation. '''
|
''' Return tuple containing min/max time for given text representation. '''
|
||||||
one_day = datetime.timedelta(days=1)
|
one_day = datetime.timedelta(days=1)
|
||||||
one_second = datetime.timedelta(seconds=1)
|
one_second = datetime.timedelta(seconds=1)
|
||||||
|
|
67
server/szurubooru/util/users.py
Normal file
67
server/szurubooru/util/users.py
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
import re
|
||||||
|
from datetime import datetime
|
||||||
|
from szurubooru import config, db, errors
|
||||||
|
from szurubooru.util import auth, misc
|
||||||
|
|
||||||
|
def create_user(name, password, email):
|
||||||
|
''' Create an user with given parameters and returns it. '''
|
||||||
|
user = db.User()
|
||||||
|
update_name(user, name)
|
||||||
|
update_password(user, password)
|
||||||
|
update_email(user, email)
|
||||||
|
user.access_rank = config.config['service']['default_user_rank']
|
||||||
|
user.creation_time = datetime.now()
|
||||||
|
user.avatar_style = db.User.AVATAR_GRAVATAR
|
||||||
|
return user
|
||||||
|
|
||||||
|
def update_name(user, name):
|
||||||
|
''' Validate and update user's name. '''
|
||||||
|
name = name.strip()
|
||||||
|
name_regex = config.config['service']['user_name_regex']
|
||||||
|
if not re.match(name_regex, name):
|
||||||
|
raise errors.ValidationError(
|
||||||
|
'Name must satisfy regex %r.' % name_regex)
|
||||||
|
user.name = name
|
||||||
|
|
||||||
|
def update_password(user, password):
|
||||||
|
''' Validate and update user's password. '''
|
||||||
|
password_regex = config.config['service']['password_regex']
|
||||||
|
if not re.match(password_regex, password):
|
||||||
|
raise errors.ValidationError(
|
||||||
|
'Password must satisfy regex %r.' % password_regex)
|
||||||
|
user.password_salt = auth.create_password()
|
||||||
|
user.password_hash = auth.get_password_hash(user.password_salt, password)
|
||||||
|
|
||||||
|
def update_email(user, email):
|
||||||
|
''' Validate and update user's email. '''
|
||||||
|
email = email.strip() or None
|
||||||
|
if not misc.is_valid_email(email):
|
||||||
|
raise errors.ValidationError(
|
||||||
|
'%r is not a vaild email address.' % email)
|
||||||
|
user.email = email
|
||||||
|
|
||||||
|
def update_rank(user, rank, authenticated_user):
|
||||||
|
rank = rank.strip()
|
||||||
|
available_access_ranks = config.config['service']['user_ranks']
|
||||||
|
if not rank in available_access_ranks:
|
||||||
|
raise errors.ValidationError(
|
||||||
|
'Bad access rank. Valid access ranks: %r' % available_access_ranks)
|
||||||
|
if available_access_ranks.index(authenticated_user.access_rank) \
|
||||||
|
< available_access_ranks.index(rank):
|
||||||
|
raise errors.AuthError('Trying to set higher access rank than one has')
|
||||||
|
user.access_rank = rank
|
||||||
|
|
||||||
|
def bump_login_time(user):
|
||||||
|
''' Update user's login time to current date. '''
|
||||||
|
user.last_login_time = datetime.now()
|
||||||
|
|
||||||
|
def reset_password(user):
|
||||||
|
''' Reset password for an user. '''
|
||||||
|
password = auth.create_password()
|
||||||
|
user.password_salt = auth.create_password()
|
||||||
|
user.password_hash = auth.get_password_hash(user.password_salt, password)
|
||||||
|
return password
|
||||||
|
|
||||||
|
def get_by_name(session, name):
|
||||||
|
''' Retrieve an user by its name. '''
|
||||||
|
return session.query(db.User).filter_by(name=name).first()
|
Loading…
Reference in a new issue