server/general: be more pythonic

This commit is contained in:
rr- 2016-04-03 22:03:58 +02:00
parent 7f4708c696
commit 219ab7c2c3
36 changed files with 335 additions and 428 deletions

View file

@ -1,5 +1,5 @@
import hashlib
from szurubooru import errors
from szurubooru import config, errors
from szurubooru.util import auth, mailer, users
from szurubooru.api.base_api import BaseApi
MAIL_SUBJECT = 'Password reset for {name}'
@ -9,45 +9,35 @@ MAIL_BODY = \
'Otherwise, please ignore this email.'
class PasswordReminderApi(BaseApi):
def __init__(self, config, mailer, user_service):
super().__init__()
self._config = config
self._mailer = mailer
self._user_service = user_service
def get(self, context, user_name):
user = self._user_service.get_by_name(context.session, user_name)
''' Send a mail with secure token to the correlated user. '''
user = users.get_by_name(context.session, user_name)
if not user:
raise errors.NotFoundError('User %r not found.' % user_name)
if not user.email:
raise errors.ValidationError(
'User %r hasn\'t supplied email. Cannot reset password.' % user_name)
token = self._generate_authentication_token(user)
token = auth.generate_authentication_token(user)
url = '%s/password-reset/%s' % (
self._config['basic']['base_url'].rstrip('/'), token)
self._mailer.send(
'noreply@%s' % self._config['basic']['name'],
config.config['basic']['base_url'].rstrip('/'), token)
mailer.send_mail(
'noreply@%s' % config.config['basic']['name'],
user.email,
MAIL_SUBJECT.format(name=self._config['basic']['name']),
MAIL_BODY.format(name=self._config['basic']['name'], url=url))
MAIL_SUBJECT.format(name=config.config['basic']['name']),
MAIL_BODY.format(name=config.config['basic']['name'], url=url))
return {}
def post(self, context, user_name):
user = self._user_service.get_by_name(context.session, user_name)
''' Verify token from mail, generate a new password and return it. '''
user = users.get_by_name(context.session, user_name)
if not user:
raise errors.NotFoundError('User %r not found.' % user_name)
good_token = self._generate_authentication_token(user)
good_token = auth.generate_authentication_token(user)
if not 'token' in context.request:
raise errors.ValidationError('Missing password reset token.')
token = context.request['token']
if token != good_token:
raise errors.ValidationError('Invalid password reset token.')
new_password = self._user_service.reset_password(user)
new_password = users.reset_password(user)
context.session.commit()
return {'password': new_password}
def _generate_authentication_token(self, user):
digest = hashlib.sha256()
digest.update(self._config['basic']['secret'].encode('utf8'))
digest.update(user.password_salt.encode('utf8'))
return digest.hexdigest()

View file

@ -1,9 +1,7 @@
import re
import sqlalchemy
from szurubooru import errors
from szurubooru import util
from szurubooru import errors, search
from szurubooru.util import auth, users
from szurubooru.api.base_api import BaseApi
from szurubooru.services import search
def _serialize_user(authenticated_user, user):
ret = {
@ -21,29 +19,27 @@ def _serialize_user(authenticated_user, user):
class UserListApi(BaseApi):
''' API for lists of users. '''
def __init__(self, auth_service, user_service):
def __init__(self):
super().__init__()
self._auth_service = auth_service
self._user_service = user_service
self._search_executor = search.SearchExecutor(search.UserSearchConfig())
def get(self, context):
''' Retrieves a list of users. '''
self._auth_service.verify_privilege(context.user, 'users:list')
''' Retrieve a list of users. '''
auth.verify_privilege(context.user, 'users:list')
query = context.get_param_as_string('query')
page = context.get_param_as_int('page', 1)
count, users = self._search_executor.execute(context.session, query, page)
count, user_list = self._search_executor.execute(context.session, query, page)
return {
'query': query,
'page': page,
'page_size': self._search_executor.page_size,
'total': count,
'users': [_serialize_user(context.user, user) for user in users],
'users': [_serialize_user(context.user, user) for user in user_list],
}
def post(self, context):
''' Creates a new user. '''
self._auth_service.verify_privilege(context.user, 'users:create')
''' Create a new user. '''
auth.verify_privilege(context.user, 'users:create')
try:
name = context.request['name'].strip()
@ -52,9 +48,9 @@ class UserListApi(BaseApi):
except KeyError as ex:
raise errors.ValidationError('Field %r not found.' % ex.args[0])
user = self._user_service.create_user(
context.session, name, password, email)
user = users.create_user(name, password, email)
try:
context.session.add(user)
context.session.commit()
except sqlalchemy.exc.IntegrityError:
raise errors.IntegrityError('User %r already exists.' % name)
@ -63,26 +59,17 @@ class UserListApi(BaseApi):
class UserDetailApi(BaseApi):
''' API for individual users. '''
def __init__(self, config, auth_service, password_service, user_service):
super().__init__()
self._available_access_ranks = config['service']['user_ranks']
self._name_regex = config['service']['user_name_regex']
self._password_regex = config['service']['password_regex']
self._password_service = password_service
self._auth_service = auth_service
self._user_service = user_service
def get(self, context, user_name):
''' Retrieves an user. '''
self._auth_service.verify_privilege(context.user, 'users:view')
user = self._user_service.get_by_name(context.session, user_name)
''' Retrieve an user. '''
auth.verify_privilege(context.user, 'users:view')
user = users.get_by_name(context.session, user_name)
if not user:
raise errors.NotFoundError('User %r not found.' % user_name)
return {'user': _serialize_user(context.user, user)}
def put(self, context, user_name):
''' Updates an existing user. '''
user = self._user_service.get_by_name(context.session, user_name)
''' Update an existing user. '''
user = users.get_by_name(context.session, user_name)
if not user:
raise errors.NotFoundError('User %r not found.' % user_name)
@ -92,53 +79,26 @@ class UserDetailApi(BaseApi):
infix = 'any'
if 'name' in context.request:
self._auth_service.verify_privilege(
context.user, 'users:edit:%s:name' % infix)
name = context.request['name'].strip()
if not re.match(self._name_regex, name):
raise errors.ValidationError(
'Name must satisfy regex %r.' % self._name_regex)
user.name = name
auth.verify_privilege(context.user, 'users:edit:%s:name' % infix)
users.update_name(user, context.request['name'])
if 'password' in context.request:
password = context.request['password']
self._auth_service.verify_privilege(
context.user, 'users:edit:%s:pass' % infix)
if not re.match(self._password_regex, password):
raise errors.ValidationError(
'Password must satisfy regex %r.' % self._password_regex)
user.password_salt = self._password_service.create_password()
user.password_hash = self._password_service.get_password_hash(
user.password_salt, password)
auth.verify_privilege(context.user, 'users:edit:%s:pass' % infix)
users.update_password(user, context.request['password'])
if 'email' in context.request:
self._auth_service.verify_privilege(
context.user, 'users:edit:%s:email' % infix)
email = context.request['email'].strip() or None
if not util.is_valid_email(email):
raise errors.ValidationError(
'%r is not a vaild email address.' % email)
user.email = email
auth.verify_privilege(context.user, 'users:edit:%s:email' % infix)
users.update_email(user, context.request['email'])
if 'accessRank' in context.request:
self._auth_service.verify_privilege(
context.user, 'users:edit:%s:rank' % infix)
rank = context.request['accessRank'].strip()
if not rank in self._available_access_ranks:
raise errors.ValidationError(
'Bad access rank. Valid access ranks: %r' \
% self._available_access_ranks)
if self._available_access_ranks.index(context.user.access_rank) \
< self._available_access_ranks.index(rank):
raise errors.AuthError(
'Trying to set higher access rank than one has')
user.access_rank = rank
auth.verify_privilege(context.user, 'users:edit:%s:rank' % infix)
users.update_rank(user, context.request['accessRank'], context.user)
# TODO: avatar
try:
context.session.commit()
except sqlalchemy.exc.IntegrityError:
raise errors.IntegrityError('User %r already exists.' % name)
raise errors.IntegrityError('User %r already exists.' % user.name)
return {'user': _serialize_user(context.user, user)}

View file

@ -3,16 +3,11 @@
import falcon
import sqlalchemy
import sqlalchemy.orm
import szurubooru.api
import szurubooru.config
import szurubooru.errors
import szurubooru.middleware
import szurubooru.services
import szurubooru.services.search
import szurubooru.util
from szurubooru import api, config, errors, middleware
from szurubooru.util import misc
class _CustomRequest(falcon.Request):
context_type = szurubooru.util.dotdict
context_type = misc.dotdict
def get_param_as_string(self, name, required=False, store=None, default=None):
params = self._params
@ -45,47 +40,37 @@ def _on_not_found_error(ex, _request, _response, _params):
raise falcon.HTTPNotFound(title='Not found', description=str(ex))
def create_app():
''' Creates a WSGI compatible App object. '''
config = szurubooru.config.Config()
''' Create a WSGI compatible App object. '''
engine = sqlalchemy.create_engine(
'{schema}://{user}:{password}@{host}:{port}/{name}'.format(
schema=config['database']['schema'],
user=config['database']['user'],
password=config['database']['pass'],
host=config['database']['host'],
port=config['database']['port'],
name=config['database']['name']))
schema=config.config['database']['schema'],
user=config.config['database']['user'],
password=config.config['database']['pass'],
host=config.config['database']['host'],
port=config.config['database']['port'],
name=config.config['database']['name']))
session_maker = sqlalchemy.orm.sessionmaker(bind=engine)
scoped_session = sqlalchemy.orm.scoped_session(session_maker)
# TODO: is there a better way?
mailer = szurubooru.services.Mailer(config)
password_service = szurubooru.services.PasswordService(config)
auth_service = szurubooru.services.AuthService(config, password_service)
user_service = szurubooru.services.UserService(config, password_service)
user_list_api = szurubooru.api.UserListApi(auth_service, user_service)
user_detail_api = szurubooru.api.UserDetailApi(
config, auth_service, password_service, user_service)
password_reminder_api = szurubooru.api.PasswordReminderApi(
config, mailer, user_service)
app = falcon.API(
request_type=_CustomRequest,
middleware=[
szurubooru.middleware.ImbueContext(),
szurubooru.middleware.RequireJson(),
szurubooru.middleware.JsonTranslator(),
szurubooru.middleware.DbSession(scoped_session),
szurubooru.middleware.Authenticator(auth_service, user_service),
middleware.ImbueContext(),
middleware.RequireJson(),
middleware.JsonTranslator(),
middleware.DbSession(scoped_session),
middleware.Authenticator(),
])
app.add_error_handler(szurubooru.errors.AuthError, _on_auth_error)
app.add_error_handler(szurubooru.errors.IntegrityError, _on_integrity_error)
app.add_error_handler(szurubooru.errors.ValidationError, _on_validation_error)
app.add_error_handler(szurubooru.errors.SearchError, _on_search_error)
app.add_error_handler(szurubooru.errors.NotFoundError, _on_not_found_error)
user_list_api = api.UserListApi()
user_detail_api = api.UserDetailApi()
password_reminder_api = api.PasswordReminderApi()
app.add_error_handler(errors.AuthError, _on_auth_error)
app.add_error_handler(errors.IntegrityError, _on_integrity_error)
app.add_error_handler(errors.ValidationError, _on_validation_error)
app.add_error_handler(errors.SearchError, _on_search_error)
app.add_error_handler(errors.NotFoundError, _on_not_found_error)
app.add_route('/users/', user_list_api)
app.add_route('/user/{user_name}', user_detail_api)

View file

@ -1,6 +1,6 @@
import os
import configobj
import szurubooru.errors
from szurubooru import errors
class Config(object):
''' INI config parser and container. '''
@ -15,20 +15,22 @@ class Config(object):
def _validate(self):
'''
Checks whether config.ini doesn't contain errors that might prove
Check whether config.ini doesn't contain errors that might prove
lethal at runtime.
'''
all_ranks = self['service']['user_ranks']
for privilege, rank in self['privileges'].items():
if rank not in all_ranks:
raise szurubooru.errors.ConfigError(
raise errors.ConfigError(
'Rank %r for privilege %r is missing from user_ranks' % (
rank, privilege))
for rank in ['anonymous', 'admin', 'nobody']:
if rank not in all_ranks:
raise szurubooru.errors.ConfigError(
raise errors.ConfigError(
'Fixed rank %r is missing from user_ranks' % rank)
if self['service']['default_user_rank'] not in all_ranks:
raise szurubooru.errors.ConfigError(
raise errors.ConfigError(
'Default rank %r is missing from user_ranks' % (
self['service']['default_user_rank']))
config = Config() # pylint: disable=invalid-name

View file

@ -0,0 +1,4 @@
''' Database layer. '''
from szurubooru.db.base import Base
from szurubooru.db.user import User

View file

@ -1,7 +1,5 @@
# pylint: disable=too-many-instance-attributes,too-few-public-methods
import sqlalchemy as sa
from szurubooru.model.base import Base
from szurubooru.db.base import Base
class User(Base):
__tablename__ = 'user'

View file

@ -1,24 +1,20 @@
import base64
import falcon
from szurubooru import errors
from szurubooru import model
from szurubooru import db, errors
from szurubooru.util import auth, users
class Authenticator(object):
'''
Authenticates every request and puts information on active user in the
Authenticates every request and put information on active user in the
request context.
'''
def __init__(self, auth_service, user_service):
self._auth_service = auth_service
self._user_service = user_service
def process_request(self, request, _response):
''' Executed before passing the request to the API. '''
''' Bind the user to request. Update last login time if needed. '''
request.context.user = self._get_user(request)
if request.get_param_as_bool('bump-login') \
and request.context.user.user_id:
self._user_service.bump_login_time(request.context.user)
users.bump_login_time(request.context.user)
request.context.session.commit()
def _get_user(self, request):
@ -27,15 +23,12 @@ class Authenticator(object):
try:
auth_type, user_and_password = request.auth.split(' ', 1)
if auth_type.lower() != 'basic':
raise falcon.HTTPBadRequest(
'Invalid authentication type',
'Only basic authorization is supported.')
username, password = base64.decodebytes(
user_and_password.encode('ascii')).decode('utf8').split(':')
return self._authenticate(
request.context.session, username, password)
except ValueError as err:
@ -46,16 +39,16 @@ class Authenticator(object):
msg.format(request.auth, str(err)))
def _authenticate(self, session, username, password):
''' Tries to authenticate user. Throws AuthError for invalid users. '''
user = self._user_service.get_by_name(session, username)
''' Try to authenticate user. Throw AuthError for invalid users. '''
user = users.get_by_name(session, username)
if not user:
raise errors.AuthError('No such user.')
if not self._auth_service.is_valid_password(user, password):
if not auth.is_valid_password(user, password):
raise errors.AuthError('Invalid password.')
return user
def _create_anonymous_user(self):
user = model.User()
user = db.User()
user.name = None
user.access_rank = 'anonymous'
user.password = None

View file

@ -5,12 +5,8 @@ class DbSession(object):
self._session_factory = session_factory
def process_request(self, request, _response):
''' Executed before passing the request to the API. '''
request.context.session = self._session_factory()
def process_response(self, request, _response, _resource):
'''
Executed before passing the response to falcon.
Any commits to database need to happen explicitly in the API layer.
'''
# any commits need to happen explicitly in the API layer.
request.context.session.close()

View file

@ -16,7 +16,6 @@ class JsonTranslator(object):
'''
def process_request(self, request, _response):
''' Executed before passing the request to the API. '''
if request.content_length in (None, 0):
return
@ -36,7 +35,6 @@ class JsonTranslator(object):
'JSON was incorrect or not encoded as UTF-8.')
def process_response(self, request, response, _resource):
''' Executed before passing the response to falcon. '''
if 'result' not in request.context:
return
response.body = json.dumps(

View file

@ -4,7 +4,6 @@ class RequireJson(object):
''' Sanitizes requests so that only JSON is accepted. '''
def process_request(self, request, _response):
''' Executed before passing the request to the API. '''
if not request.client_accepts_json:
raise falcon.HTTPNotAcceptable(
'This API only supports responses encoded as JSON.')

View file

@ -9,13 +9,13 @@ import logging.config
dir_to_self = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(dir_to_self, *[os.pardir] * 2))
import szurubooru.model.base
import szurubooru.db.base
import szurubooru.config
alembic_config = alembic.context.config
logging.config.fileConfig(alembic_config.config_file_name)
szuru_config = szurubooru.config.Config()
szuru_config = szurubooru.config.config
alembic_config.set_main_option(
'sqlalchemy.url',
'{schema}://{user}:{password}@{host}:{port}/{name}'.format(
@ -26,7 +26,7 @@ alembic_config.set_main_option(
port=szuru_config['database']['port'],
name=szuru_config['database']['name']))
target_metadata = szurubooru.model.Base.metadata
target_metadata = szurubooru.db.Base.metadata
def run_migrations_offline():
'''

View file

@ -1,6 +0,0 @@
'''
Database models.
'''
from szurubooru.model.base import Base
from szurubooru.model.user import User

View file

@ -0,0 +1,4 @@
''' Search parsers and services. '''
from szurubooru.search.search_executor import SearchExecutor
from szurubooru.search.user_search_config import UserSearchConfig

View file

@ -1,11 +1,11 @@
import sqlalchemy
import szurubooru.errors
from szurubooru import util
from szurubooru.services.search import criteria
from szurubooru.util import misc
from szurubooru.search import criteria
def _apply_criterion_to_column(
column, query, criterion, allow_composite=True, allow_ranged=True):
''' Decorates SQLAlchemy filter on given column using supplied criterion. '''
''' Decorate SQLAlchemy filter on given column using supplied criterion. '''
if isinstance(criterion, criteria.StringSearchCriterion):
expr = column == criterion.value
if criterion.negated:
@ -32,11 +32,11 @@ def _apply_criterion_to_column(
def _apply_date_criterion_to_column(column, query, criterion):
'''
Decorates SQLAlchemy filter on given column using supplied criterion.
Parses the datetime inside the criterion.
Decorate SQLAlchemy filter on given column using supplied criterion.
Parse the datetime inside the criterion.
'''
if isinstance(criterion, criteria.StringSearchCriterion):
min_date, max_date = util.parse_time_range(criterion.value)
min_date, max_date = misc.parse_time_range(criterion.value)
expr = column.between(min_date, max_date)
if criterion.negated:
expr = ~expr
@ -44,7 +44,7 @@ def _apply_date_criterion_to_column(column, query, criterion):
elif isinstance(criterion, criteria.ArraySearchCriterion):
expr = sqlalchemy.sql.false()
for value in criterion.values:
min_date, max_date = util.parse_time_range(value)
min_date, max_date = misc.parse_time_range(value)
expr = expr | column.between(min_date, max_date)
if criterion.negated:
expr = ~expr
@ -52,14 +52,14 @@ def _apply_date_criterion_to_column(column, query, criterion):
elif isinstance(criterion, criteria.RangedSearchCriterion):
assert criterion.min_value or criterion.max_value
if criterion.min_value and criterion.max_value:
min_date = util.parse_time_range(criterion.min_value)[0]
max_date = util.parse_time_range(criterion.max_value)[1]
min_date = misc.parse_time_range(criterion.min_value)[0]
max_date = misc.parse_time_range(criterion.max_value)[1]
expr = column.between(min_date, max_date)
elif criterion.min_value:
min_date = util.parse_time_range(criterion.min_value)[0]
min_date = misc.parse_time_range(criterion.min_value)[0]
expr = column >= min_date
elif criterion.max_value:
max_date = util.parse_time_range(criterion.max_value)[1]
max_date = misc.parse_time_range(criterion.max_value)[1]
expr = column <= max_date
if criterion.negated:
expr = ~expr

View file

@ -1,19 +1,17 @@
''' Exports SearchExecutor. '''
import re
import sqlalchemy
from szurubooru import errors
from szurubooru.services.search import criteria
from szurubooru.search import criteria
class SearchExecutor(object):
ORDER_DESC = 1
ORDER_ASC = 2
'''
Class for search parsing and execution. Handles plaintext parsing and
delegates sqlalchemy filter decoration to SearchConfig instances.
'''
ORDER_DESC = 1
ORDER_ASC = 2
def __init__(self, search_config):
self.page_size = 100
self._search_config = search_config

View file

@ -0,0 +1,42 @@
from sqlalchemy.sql.expression import func
from szurubooru import db
from szurubooru.search.base_search_config import BaseSearchConfig
class UserSearchConfig(BaseSearchConfig):
''' Executes searches related to the users. '''
def create_query(self, session):
return session.query(db.User)
@property
def anonymous_filter(self):
return self._create_basic_filter(db.User.name, allow_ranged=False)
@property
def special_filters(self):
return {}
@property
def named_filters(self):
return {
'name': self._create_basic_filter(db.User.name, allow_ranged=False),
'creation_date': self._create_date_filter(db.User.creation_time),
'creation_time': self._create_date_filter(db.User.creation_time),
'last_login_date': self._create_date_filter(db.User.last_login_time),
'last_login_time': self._create_date_filter(db.User.last_login_time),
'login_date': self._create_date_filter(db.User.last_login_time),
'login_time': self._create_date_filter(db.User.last_login_time),
}
@property
def order_columns(self):
return {
'random': func.random(),
'name': db.User.name,
'creation_date': db.User.creation_time,
'creation_time': db.User.creation_time,
'last_login_date': db.User.last_login_time,
'last_login_time': db.User.last_login_time,
'login_date': db.User.last_login_time,
'login_time': db.User.last_login_time,
}

View file

@ -1,9 +0,0 @@
'''
Middle layer between REST API and database.
All the business logic goes here.
'''
from szurubooru.services.mailer import Mailer
from szurubooru.services.auth_service import AuthService
from szurubooru.services.user_service import UserService
from szurubooru.services.password_service import PasswordService

View file

@ -1,28 +0,0 @@
from szurubooru import errors
class AuthService(object):
def __init__(self, config, password_service):
self._config = config
self._password_service = password_service
def is_valid_password(self, user, password):
''' Returns whether the given password for a given user is valid. '''
salt, valid_hash = user.password_salt, user.password_hash
possible_hashes = [
self._password_service.get_password_hash(salt, password),
self._password_service.get_legacy_password_hash(salt, password)
]
return valid_hash in possible_hashes
def verify_privilege(self, user, privilege_name):
'''
Throws an AuthError if the given user doesn't have given privilege.
'''
all_ranks = self._config['service']['user_ranks']
assert privilege_name in self._config['privileges']
assert user.access_rank in all_ranks
minimal_rank = self._config['privileges'][privilege_name]
good_ranks = all_ranks[all_ranks.index(minimal_rank):]
if user.access_rank not in good_ranks:
raise errors.AuthError('Insufficient privileges to do this.')

View file

@ -1,19 +0,0 @@
import smtplib
import email.mime.text
class Mailer(object):
def __init__(self, config):
self._config = config
def send(self, sender, recipient, subject, body):
msg = email.mime.text.MIMEText(body)
msg['Subject'] = subject
msg['From'] = sender
msg['To'] = recipient
smtp = smtplib.SMTP(
self._config['smtp']['host'],
int(self._config['smtp']['port']))
smtp.login(self._config['smtp']['user'], self._config['smtp']['pass'])
smtp.send_message(msg)
smtp.quit()

View file

@ -1,34 +0,0 @@
import hashlib
import random
class PasswordService(object):
''' Stateless utilities for passwords '''
def __init__(self, config):
self._config = config
def get_password_hash(self, salt, password):
''' Retrieves new-style password hash. '''
digest = hashlib.sha256()
digest.update(self._config['basic']['secret'].encode('utf8'))
digest.update(salt.encode('utf8'))
digest.update(password.encode('utf8'))
return digest.hexdigest()
def get_legacy_password_hash(self, salt, password):
''' Retrieves old-style password hash. '''
digest = hashlib.sha1()
digest.update(b'1A2/$_4xVa')
digest.update(salt.encode('utf8'))
digest.update(password.encode('utf8'))
return digest.hexdigest()
def create_password(self):
''' Creates an easy-to-remember password. '''
alphabet = {
'c': list('bcdfghijklmnpqrstvwxyz'),
'v': list('aeiou'),
'n': list('0123456789'),
}
pattern = 'cvcvnncvcv'
return ''.join(random.choice(alphabet[l]) for l in list(pattern))

View file

@ -1,2 +0,0 @@
from szurubooru.services.search.search_executor import SearchExecutor
from szurubooru.services.search.user_search_config import UserSearchConfig

View file

@ -1,44 +0,0 @@
''' Exports UserSearchConfig. '''
from sqlalchemy.sql.expression import func
from szurubooru.model import User
from szurubooru.services.search.base_search_config import BaseSearchConfig
class UserSearchConfig(BaseSearchConfig):
''' Executes searches related to the users. '''
def create_query(self, session):
return session.query(User)
@property
def anonymous_filter(self):
return self._create_basic_filter(User.name, allow_ranged=False)
@property
def special_filters(self):
return {}
@property
def named_filters(self):
return {
'name': self._create_basic_filter(User.name, allow_ranged=False),
'creation_date': self._create_date_filter(User.creation_time),
'creation_time': self._create_date_filter(User.creation_time),
'last_login_date': self._create_date_filter(User.last_login_time),
'last_login_time': self._create_date_filter(User.last_login_time),
'login_date': self._create_date_filter(User.last_login_time),
'login_time': self._create_date_filter(User.last_login_time),
}
@property
def order_columns(self):
return {
'random': func.random(),
'name': User.name,
'creation_date': User.creation_time,
'creation_time': User.creation_time,
'last_login_date': User.last_login_time,
'last_login_time': User.last_login_time,
'login_date': User.last_login_time,
'login_time': User.last_login_time,
}

View file

@ -1,56 +0,0 @@
import re
from datetime import datetime
from szurubooru import errors
from szurubooru import model
from szurubooru import util
class UserService(object):
''' User management '''
def __init__(self, config, password_service):
self._config = config
self._password_service = password_service
self._name_regex = self._config['service']['user_name_regex']
self._password_regex = self._config['service']['password_regex']
def create_user(self, session, name, password, email):
''' Creates an user with given parameters and returns it. '''
if not re.match(self._name_regex, name):
raise errors.ValidationError(
'Name must satisfy regex %r.' % self._name_regex)
if not re.match(self._password_regex, password):
raise errors.ValidationError(
'Password must satisfy regex %r.' % self._password_regex)
if not util.is_valid_email(email):
raise errors.ValidationError(
'%r is not a vaild email address.' % email)
user = model.User()
user.name = name
user.password_salt = self._password_service.create_password()
user.password_hash = self._password_service.get_password_hash(
user.password_salt, password)
user.email = email or None
user.access_rank = self._config['service']['default_user_rank']
user.creation_time = datetime.now()
user.avatar_style = model.User.AVATAR_GRAVATAR
session.add(user)
return user
def bump_login_time(self, user):
user.last_login_time = datetime.now()
def reset_password(self, user):
password = self._password_service.create_password()
user.password_salt = self._password_service.create_password()
user.password_hash = self._password_service.get_password_hash(
user.password_salt, password)
return password
def get_by_name(self, session, name):
''' Retrieves an user by its name. '''
return session.query(model.User).filter_by(name=name).first()

View file

@ -1,15 +1,12 @@
from datetime import datetime
import szurubooru.services
from szurubooru.api.user_api import UserDetailApi
from szurubooru.errors import AuthError, ValidationError
from szurubooru.model.user import User
from szurubooru import api, db, errors, config
from szurubooru.util import auth, misc
from szurubooru.tests.database_test_case import DatabaseTestCase
from szurubooru.util import dotdict
class TestUserDetailApi(DatabaseTestCase):
def setUp(self):
super().setUp()
config = {
config_mock = {
'basic': {
'secret': '',
},
@ -30,18 +27,18 @@ class TestUserDetailApi(DatabaseTestCase):
'users:edit:any:rank': 'admin',
}
}
password_service = szurubooru.services.PasswordService(config)
auth_service = szurubooru.services.AuthService(config, password_service)
user_service = szurubooru.services.UserService(config, password_service)
self.auth_service = auth_service
self.api = UserDetailApi(
config, auth_service, password_service, user_service)
self.context = dotdict()
self.old_config = config.config
config.config = config_mock
self.api = api.UserDetailApi()
self.context = misc.dotdict()
self.context.session = self.session
self.context.request = {}
def tearDown(self):
config.config = self.old_config
def _create_user(self, name, rank='admin'):
user = User()
user = db.User()
user.name = name
user.password = 'dummy'
user.password_salt = 'dummy'
@ -49,7 +46,7 @@ class TestUserDetailApi(DatabaseTestCase):
user.email = 'dummy'
user.access_rank = rank
user.creation_time = datetime.now()
user.avatar_style = User.AVATAR_GRAVATAR
user.avatar_style = db.User.AVATAR_GRAVATAR
return user
def test_updating_nothing(self):
@ -57,7 +54,7 @@ class TestUserDetailApi(DatabaseTestCase):
self.session.add(admin_user)
self.context.user = admin_user
self.api.put(self.context, 'u1')
admin_user = self.session.query(User).filter_by(name='u1').one()
admin_user = self.session.query(db.User).filter_by(name='u1').one()
self.assertEqual(admin_user.name, 'u1')
self.assertEqual(admin_user.email, 'dummy')
self.assertEqual(admin_user.access_rank, 'admin')
@ -69,16 +66,16 @@ class TestUserDetailApi(DatabaseTestCase):
self.context.request = {
'name': 'chewie',
'email': 'asd@asd.asd',
'password': 'valid',
'password': 'oks',
'accessRank': 'mod',
}
self.api.put(self.context, 'u1')
admin_user = self.session.query(User).filter_by(name='chewie').one()
admin_user = self.session.query(db.User).filter_by(name='chewie').one()
self.assertEqual(admin_user.name, 'chewie')
self.assertEqual(admin_user.email, 'asd@asd.asd')
self.assertEqual(admin_user.access_rank, 'mod')
self.assertTrue(self.auth_service.is_valid_password(admin_user, 'valid'))
self.assertFalse(self.auth_service.is_valid_password(admin_user, 'invalid'))
self.assertTrue(auth.is_valid_password(admin_user, 'oks'))
self.assertFalse(auth.is_valid_password(admin_user, 'invalid'))
def test_removing_email(self):
admin_user = self._create_user('u1', 'admin')
@ -86,7 +83,7 @@ class TestUserDetailApi(DatabaseTestCase):
self.context.user = admin_user
self.context.request = {'email': ''}
self.api.put(self.context, 'u1')
admin_user = self.session.query(User).filter_by(name='u1').one()
admin_user = self.session.query(db.User).filter_by(name='u1').one()
self.assertEqual(admin_user.email, None)
def test_invalid_inputs(self):
@ -95,16 +92,16 @@ class TestUserDetailApi(DatabaseTestCase):
self.context.user = admin_user
self.context.request = {'name': '.'}
self.assertRaises(
ValidationError, self.api.put, self.context, 'u1')
errors.ValidationError, self.api.put, self.context, 'u1')
self.context.request = {'password': '.'}
self.assertRaises(
ValidationError, self.api.put, self.context, 'u1')
errors.ValidationError, self.api.put, self.context, 'u1')
self.context.request = {'accessRank': '.'}
self.assertRaises(
ValidationError, self.api.put, self.context, 'u1')
errors.ValidationError, self.api.put, self.context, 'u1')
self.context.request = {'email': '.'}
self.assertRaises(
ValidationError, self.api.put, self.context, 'u1')
errors.ValidationError, self.api.put, self.context, 'u1')
def test_user_trying_to_update_someone_else(self):
user1 = self._create_user('u1', 'regular_user')
@ -118,7 +115,7 @@ class TestUserDetailApi(DatabaseTestCase):
{'password': 'whatever'}]:
self.context.request = request
self.assertRaises(
AuthError, self.api.put, self.context, user2.name)
errors.AuthError, self.api.put, self.context, user2.name)
def test_user_trying_to_become_someone_else(self):
user1 = self._create_user('u1', 'regular_user')
@ -127,7 +124,7 @@ class TestUserDetailApi(DatabaseTestCase):
self.context.user = user1
self.context.request = {'name': 'u2'}
self.assertRaises(
ValidationError, self.api.put, self.context, 'u1')
errors.ValidationError, self.api.put, self.context, 'u1')
def test_mods_trying_to_become_admin(self):
user1 = self._create_user('u1', 'mod')
@ -136,6 +133,6 @@ class TestUserDetailApi(DatabaseTestCase):
self.context.user = user1
self.context.request = {'accessRank': 'admin'}
self.assertRaises(
AuthError, self.api.put, self.context, user1.name)
errors.AuthError, self.api.put, self.context, user1.name)
self.assertRaises(
AuthError, self.api.put, self.context, user2.name)
errors.AuthError, self.api.put, self.context, user2.name)

View file

@ -1,11 +1,11 @@
import unittest
import sqlalchemy
from szurubooru.model import Base
from szurubooru import db
class DatabaseTestCase(unittest.TestCase):
def setUp(self):
engine = sqlalchemy.create_engine('sqlite:///:memory:')
session_maker = sqlalchemy.orm.sessionmaker(bind=engine)
self.session = sqlalchemy.orm.scoped_session(session_maker)
Base.query = self.session.query_property()
Base.metadata.create_all(bind=engine)
db.Base.query = self.session.query_property()
db.Base.metadata.create_all(bind=engine)

View file

@ -1,7 +1,5 @@
from datetime import datetime
from szurubooru import errors
from szurubooru import model
from szurubooru.services import search
from szurubooru import db, errors, search
from szurubooru.tests.database_test_case import DatabaseTestCase
class TestUserSearchExecutor(DatabaseTestCase):
@ -11,7 +9,7 @@ class TestUserSearchExecutor(DatabaseTestCase):
self.executor = search.SearchExecutor(self.search_config)
def _create_user(self, name):
user = model.User()
user = db.User()
user.name = name
user.password = 'dummy'
user.password_salt = 'dummy'
@ -19,7 +17,7 @@ class TestUserSearchExecutor(DatabaseTestCase):
user.email = 'dummy'
user.access_rank = 'dummy'
user.creation_time = datetime.now()
user.avatar_style = model.User.AVATAR_GRAVATAR
user.avatar_style = db.User.AVATAR_GRAVATAR
return user
def _test(self, query, page, expected_count, expected_user_names):

View file

@ -1,8 +1,7 @@
import unittest
from datetime import datetime
import szurubooru.util
from szurubooru.util import parse_time_range
from szurubooru.errors import ValidationError
from szurubooru import errors
from szurubooru.util import misc
class FakeDatetime(datetime):
@staticmethod
@ -11,33 +10,33 @@ class FakeDatetime(datetime):
class TestParseTime(unittest.TestCase):
def test_empty(self):
self.assertRaises(ValidationError, parse_time_range, '')
self.assertRaises(errors.ValidationError, misc.parse_time_range, '')
def test_today(self):
szurubooru.util.datetime.datetime = FakeDatetime
date_min, date_max = parse_time_range('today')
misc.datetime.datetime = FakeDatetime
date_min, date_max = misc.parse_time_range('today')
self.assertEqual(date_min, datetime(1997, 1, 2, 0, 0, 0))
self.assertEqual(date_max, datetime(1997, 1, 2, 23, 59, 59))
def test_yesterday(self):
szurubooru.util.datetime.datetime = FakeDatetime
date_min, date_max = parse_time_range('yesterday')
misc.datetime.datetime = FakeDatetime
date_min, date_max = misc.parse_time_range('yesterday')
self.assertEqual(date_min, datetime(1997, 1, 1, 0, 0, 0))
self.assertEqual(date_max, datetime(1997, 1, 1, 23, 59, 59))
def test_year(self):
date_min, date_max = parse_time_range('1999')
date_min, date_max = misc.parse_time_range('1999')
self.assertEqual(date_min, datetime(1999, 1, 1, 0, 0, 0))
self.assertEqual(date_max, datetime(1999, 12, 31, 23, 59, 59))
def test_month(self):
for text in ['1999-2', '1999-02']:
date_min, date_max = parse_time_range(text)
date_min, date_max = misc.parse_time_range(text)
self.assertEqual(date_min, datetime(1999, 2, 1, 0, 0, 0))
self.assertEqual(date_max, datetime(1999, 2, 28, 23, 59, 59))
def test_day(self):
for text in ['1999-2-6', '1999-02-6', '1999-2-06', '1999-02-06']:
date_min, date_max = parse_time_range(text)
date_min, date_max = misc.parse_time_range(text)
self.assertEqual(date_min, datetime(1999, 2, 6, 0, 0, 0))
self.assertEqual(date_max, datetime(1999, 2, 6, 23, 59, 59))

View file

@ -0,0 +1 @@
''' Cool functions. '''

View file

@ -0,0 +1,59 @@
import hashlib
import random
from szurubooru import config
from szurubooru import errors
def get_password_hash(salt, password):
''' Retrieve new-style password hash. '''
digest = hashlib.sha256()
digest.update(config.config['basic']['secret'].encode('utf8'))
digest.update(salt.encode('utf8'))
digest.update(password.encode('utf8'))
return digest.hexdigest()
def get_legacy_password_hash(salt, password):
''' Retrieve old-style password hash. '''
digest = hashlib.sha1()
digest.update(b'1A2/$_4xVa')
digest.update(salt.encode('utf8'))
digest.update(password.encode('utf8'))
return digest.hexdigest()
def create_password():
''' Create an easy-to-remember password. '''
alphabet = {
'c': list('bcdfghijklmnpqrstvwxyz'),
'v': list('aeiou'),
'n': list('0123456789'),
}
pattern = 'cvcvnncvcv'
return ''.join(random.choice(alphabet[l]) for l in list(pattern))
def is_valid_password(user, password):
''' Return whether the given password for a given user is valid. '''
salt, valid_hash = user.password_salt, user.password_hash
possible_hashes = [
get_password_hash(salt, password),
get_legacy_password_hash(salt, password)
]
return valid_hash in possible_hashes
def verify_privilege(user, privilege_name):
'''
Throw an AuthError if the given user doesn't have given privilege.
'''
all_ranks = config.config['service']['user_ranks']
assert privilege_name in config.config['privileges']
assert user.access_rank in all_ranks
minimal_rank = config.config['privileges'][privilege_name]
good_ranks = all_ranks[all_ranks.index(minimal_rank):]
if user.access_rank not in good_ranks:
raise errors.AuthError('Insufficient privileges to do this.')
def generate_authentication_token(user):
''' Generate nonguessable challenge (e.g. links in password reminder). '''
digest = hashlib.sha256()
digest.update(config.config['basic']['secret'].encode('utf8'))
digest.update(user.password_salt.encode('utf8'))
return digest.hexdigest()

View file

@ -0,0 +1,15 @@
import smtplib
import email.mime.text
from szurubooru import config
def send_mail(sender, recipient, subject, body):
msg = email.mime.text.MIMEText(body)
msg['Subject'] = subject
msg['From'] = sender
msg['To'] = recipient
smtp = smtplib.SMTP(
config.config['smtp']['host'], int(config.config['smtp']['port']))
smtp.login(config.config['smtp']['user'], config.config['smtp']['pass'])
smtp.send_message(msg)
smtp.quit()

View file

@ -3,7 +3,7 @@ import re
from szurubooru.errors import ValidationError
def is_valid_email(email):
''' Validates given email address. '''
''' Return whether given email address is valid or empty. '''
return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email)
class dotdict(dict): # pylint: disable=invalid-name
@ -14,7 +14,7 @@ class dotdict(dict): # pylint: disable=invalid-name
__delattr__ = dict.__delitem__
def parse_time_range(value, timezone=datetime.timezone(datetime.timedelta())):
''' Returns tuple containing min/max time for given text representation. '''
''' Return tuple containing min/max time for given text representation. '''
one_day = datetime.timedelta(days=1)
one_second = datetime.timedelta(seconds=1)

View file

@ -0,0 +1,67 @@
import re
from datetime import datetime
from szurubooru import config, db, errors
from szurubooru.util import auth, misc
def create_user(name, password, email):
''' Create an user with given parameters and returns it. '''
user = db.User()
update_name(user, name)
update_password(user, password)
update_email(user, email)
user.access_rank = config.config['service']['default_user_rank']
user.creation_time = datetime.now()
user.avatar_style = db.User.AVATAR_GRAVATAR
return user
def update_name(user, name):
''' Validate and update user's name. '''
name = name.strip()
name_regex = config.config['service']['user_name_regex']
if not re.match(name_regex, name):
raise errors.ValidationError(
'Name must satisfy regex %r.' % name_regex)
user.name = name
def update_password(user, password):
''' Validate and update user's password. '''
password_regex = config.config['service']['password_regex']
if not re.match(password_regex, password):
raise errors.ValidationError(
'Password must satisfy regex %r.' % password_regex)
user.password_salt = auth.create_password()
user.password_hash = auth.get_password_hash(user.password_salt, password)
def update_email(user, email):
''' Validate and update user's email. '''
email = email.strip() or None
if not misc.is_valid_email(email):
raise errors.ValidationError(
'%r is not a vaild email address.' % email)
user.email = email
def update_rank(user, rank, authenticated_user):
rank = rank.strip()
available_access_ranks = config.config['service']['user_ranks']
if not rank in available_access_ranks:
raise errors.ValidationError(
'Bad access rank. Valid access ranks: %r' % available_access_ranks)
if available_access_ranks.index(authenticated_user.access_rank) \
< available_access_ranks.index(rank):
raise errors.AuthError('Trying to set higher access rank than one has')
user.access_rank = rank
def bump_login_time(user):
''' Update user's login time to current date. '''
user.last_login_time = datetime.now()
def reset_password(user):
''' Reset password for an user. '''
password = auth.create_password()
user.password_salt = auth.create_password()
user.password_hash = auth.get_password_hash(user.password_salt, password)
return password
def get_by_name(session, name):
''' Retrieve an user by its name. '''
return session.query(db.User).filter_by(name=name).first()