diff --git a/szurubooru/api/users.py b/szurubooru/api/users.py index ed13502..b313c63 100644 --- a/szurubooru/api/users.py +++ b/szurubooru/api/users.py @@ -2,6 +2,7 @@ import re import falcon +from szurubooru.services.errors import IntegrityError def _serialize_user(user): return { @@ -24,7 +25,7 @@ class UserListApi(object): def on_get(self, request, response): ''' Retrieves a list of users. ''' self._auth_service.verify_privilege(request.context['user'], 'users:list') - request.context['reuslt'] = {'message': 'Searching for users'} + request.context['result'] = {'message': 'Searching for users'} def on_post(self, request, response): ''' Creates a new user. ''' @@ -33,7 +34,7 @@ class UserListApi(object): password_regex = self._config['service']['password_regex'] try: - name = request.context['doc']['user'] + name = request.context['doc']['name'] password = request.context['doc']['password'] email = request.context['doc']['email'].strip() if not email: @@ -52,7 +53,12 @@ class UserListApi(object): 'Malformed data', 'Password must validate %r expression' % password_regex) - user = self._user_service.create_user(name, password, email) + session = request.context['session'] + try: + user = self._user_service.create_user(session, name, password, email) + session.commit() + except: + raise IntegrityError('User %r already exists.' % name) request.context['result'] = {'user': _serialize_user(user)} class UserDetailApi(object): @@ -65,7 +71,8 @@ class UserDetailApi(object): def on_get(self, request, response, user_name): ''' Retrieves an user. ''' self._auth_service.verify_privilege(request.context['user'], 'users:view') - user = self._user_service.get_by_name(user_name) + session = request.context['session'] + user = self._user_service.get_by_name(session, user_name) request.context['result'] = _serialize_user(user) def on_put(self, request, response, user_name): diff --git a/szurubooru/app.py b/szurubooru/app.py index c94aed1..75b0a2a 100644 --- a/szurubooru/app.py +++ b/szurubooru/app.py @@ -6,7 +6,6 @@ import sqlalchemy import sqlalchemy.orm import szurubooru.api import szurubooru.config -import szurubooru.db import szurubooru.middleware import szurubooru.services @@ -30,15 +29,13 @@ def create_app(): host=config['database']['host'], port=config['database']['port'], name=config['database']['name'])) - session_factory = sqlalchemy.orm.sessionmaker(bind=engine) - transaction_manager = szurubooru.db.TransactionManager(session_factory) + session_maker = sqlalchemy.orm.sessionmaker(bind=engine) + scoped_session = sqlalchemy.orm.scoped_session(session_maker) # TODO: is there a better way? password_service = szurubooru.services.PasswordService(config) - user_service = szurubooru.services.UserService( - config, transaction_manager, password_service) - auth_service = szurubooru.services.AuthService( - config, user_service, password_service) + auth_service = szurubooru.services.AuthService(config, password_service) + user_service = szurubooru.services.UserService(config, password_service) user_list = szurubooru.api.UserListApi(config, auth_service, user_service) user = szurubooru.api.UserDetailApi(config, auth_service, user_service) @@ -46,7 +43,8 @@ def create_app(): app = falcon.API(middleware=[ szurubooru.middleware.RequireJson(), szurubooru.middleware.JsonTranslator(), - szurubooru.middleware.Authenticator(auth_service), + szurubooru.middleware.DbSession(session_maker), + szurubooru.middleware.Authenticator(auth_service, user_service), ]) app.add_error_handler(szurubooru.services.AuthError, _on_auth_error) diff --git a/szurubooru/db.py b/szurubooru/db.py deleted file mode 100644 index 9527485..0000000 --- a/szurubooru/db.py +++ /dev/null @@ -1,36 +0,0 @@ -''' Exports TransactionManager. ''' - -from contextlib import contextmanager - -class TransactionManager(object): - ''' Helper class for managing database transactions. ''' - - def __init__(self, session_factory): - self._session_factory = session_factory - - @contextmanager - def transaction(self): - ''' - Provides a transactional scope around a series of DB operations that - might change the database. - ''' - return self._open_transaction(lambda session: session.commit) - - @contextmanager - def read_only_transaction(self): - ''' - Provides a transactional scope around a series of read-only DB - operations. - ''' - return self._open_transaction(lambda session: session.rollback) - - def _open_transaction(self, session_finalizer): - session = self._session_factory() - try: - yield session - session_finalizer(session) - except: - session.rollback() - raise - finally: - session.close() diff --git a/szurubooru/middleware/__init__.py b/szurubooru/middleware/__init__.py index 8d7dccb..a2ba328 100644 --- a/szurubooru/middleware/__init__.py +++ b/szurubooru/middleware/__init__.py @@ -3,3 +3,4 @@ from szurubooru.middleware.authenticator import Authenticator from szurubooru.middleware.json_translator import JsonTranslator from szurubooru.middleware.require_json import RequireJson +from szurubooru.middleware.db_session import DbSession diff --git a/szurubooru/middleware/authenticator.py b/szurubooru/middleware/authenticator.py index d34c52c..99a1fe9 100644 --- a/szurubooru/middleware/authenticator.py +++ b/szurubooru/middleware/authenticator.py @@ -2,6 +2,8 @@ import base64 import falcon +from szurubooru.model.user import User +from szurubooru.services.errors import AuthError class Authenticator(object): ''' @@ -9,8 +11,9 @@ class Authenticator(object): request context. ''' - def __init__(self, auth_service): + def __init__(self, auth_service, user_service): self._auth_service = auth_service + self._user_service = user_service def process_request(self, request, response): ''' Executed before passing the request to the API. ''' @@ -18,7 +21,7 @@ class Authenticator(object): def _get_user(self, request): if not request.auth: - return self._auth_service.authenticate(None, None) + return self._create_anonymous_user() try: auth_type, user_and_password = request.auth.split(' ', 1) @@ -31,10 +34,27 @@ class Authenticator(object): username, password = base64.decodebytes( user_and_password.encode('ascii')).decode('utf8').split(':') - return self._auth_service.authenticate(username, password) + session = request.context['session'] + return self._authenticate(session, username, password) except ValueError as err: msg = 'Basic authentication header value not properly formed. ' \ + 'Supplied header {0}. Got error: {1}' raise falcon.HTTPBadRequest( 'Malformed authentication request', msg.format(request.auth, str(err))) + + def _authenticate(self, session, username, password): + ''' Tries to authenticate user. Throws AuthError for invalid users. ''' + user = self._user_service.get_by_name(session, username) + if not user: + raise AuthError('No such user.') + if not self._auth_service.is_valid_password(user, password): + raise AuthError('Invalid password.') + return user + + def _create_anonymous_user(self): + user = User() + user.name = None + user.access_rank = 'anonymous' + user.password = None + return user diff --git a/szurubooru/middleware/db_session.py b/szurubooru/middleware/db_session.py new file mode 100644 index 0000000..3b69059 --- /dev/null +++ b/szurubooru/middleware/db_session.py @@ -0,0 +1,14 @@ +''' Exports DbSession. ''' + +class DbSession(object): + ''' Attaches database session to the context of every request. ''' + + def __init__(self, session_factory): + self._session_factory = session_factory + + def process_request(self, request, response): + ''' Executed before passing the request to the API. ''' + request.context['session'] = self._session_factory() + + def process_response(self, request, response, resource): + request.context['session'].close() diff --git a/szurubooru/middleware/json_translator.py b/szurubooru/middleware/json_translator.py index 3dc02a1..502fac0 100644 --- a/szurubooru/middleware/json_translator.py +++ b/szurubooru/middleware/json_translator.py @@ -1,8 +1,8 @@ ''' Exports JsonTranslator. ''' import json -import falcon from datetime import datetime +import falcon def json_serial(obj): ''' JSON serializer for objects not serializable by default JSON code ''' diff --git a/szurubooru/services/auth_service.py b/szurubooru/services/auth_service.py index 6d83102..f11a356 100644 --- a/szurubooru/services/auth_service.py +++ b/szurubooru/services/auth_service.py @@ -1,27 +1,14 @@ ''' Exports AuthService. ''' -from szurubooru.model.user import User from szurubooru.services.errors import AuthError class AuthService(object): ''' Services related to user authentication ''' - def __init__(self, config, user_service, password_service): + def __init__(self, config, password_service): self._config = config - self._user_service = user_service self._password_service = password_service - def authenticate(self, username, password): - ''' Tries to authenticate user. Throws AuthError for invalid users. ''' - if not username: - return self._create_anonymous_user() - user = self._user_service.get_by_name(username) - if not user: - raise AuthError('No such user.') - if not self.is_valid_password(user, password): - raise AuthError('Invalid password.') - return user - def is_valid_password(self, user, password): ''' Returns whether the given password for a given user is valid. ''' salt, valid_hash = user.password_salt, user.password_hash @@ -45,10 +32,3 @@ class AuthService(object): good_ranks = all_ranks[all_ranks.index(minimal_rank):] if user.access_rank not in good_ranks: raise AuthError('Insufficient privileges to do this.') - - def _create_anonymous_user(self): - user = User() - user.name = None - user.access_rank = 'anonymous' - user.password = None - return user diff --git a/szurubooru/services/user_service.py b/szurubooru/services/user_service.py index 9849bdf..d8ce8b6 100644 --- a/szurubooru/services/user_service.py +++ b/szurubooru/services/user_service.py @@ -2,39 +2,30 @@ from datetime import datetime from szurubooru.model.user import User -from szurubooru.services.errors import IntegrityError class UserService(object): ''' User management ''' - def __init__(self, config, transaction_manager, password_service): + def __init__(self, config, password_service): self._config = config - self._transaction_manager = transaction_manager self._password_service = password_service - def create_user(self, name, password, email): + def create_user(self, session, name, password, email): ''' Creates an user with given parameters and returns it. ''' - with self._transaction_manager.transaction() as session: - user = User() - user.name = name - user.password = password - user.password_salt = self._password_service.create_password() - user.password_hash = self._password_service.get_password_hash( - user.password_salt, user.password) - user.email = email - user.access_rank = self._config['service']['default_user_rank'] - user.creation_time = datetime.now() - user.avatar_style = User.AVATAR_GRAVATAR + user = User() + user.name = name + user.password = password + user.password_salt = self._password_service.create_password() + user.password_hash = self._password_service.get_password_hash( + user.password_salt, user.password) + user.email = email + user.access_rank = self._config['service']['default_user_rank'] + user.creation_time = datetime.now() + user.avatar_style = User.AVATAR_GRAVATAR - try: - session.add(user) - session.commit() - except: - raise IntegrityError('User %r already exists.' % name) + session.add(user) + return user - return user - - def get_by_name(self, name): + def get_by_name(self, session, name): ''' Retrieves an user by its name. ''' - with self._transaction_manager.read_only_transaction() as session: - return session.query(User).filter_by(name=name).first() + return session.query(User).filter_by(name=name).first()