server/search: fix missing default search order

This commit is contained in:
rr- 2016-04-16 18:55:04 +02:00
parent fa6b808659
commit c71c082000
4 changed files with 81 additions and 65 deletions

View file

@ -69,6 +69,9 @@ def _apply_str_criterion_to_column(column, query, criterion):
return query.filter(expr) return query.filter(expr)
class BaseSearchConfig(object): class BaseSearchConfig(object):
ORDER_DESC = 1
ORDER_ASC = 2
def create_query(self, session): def create_query(self, session):
raise NotImplementedError() raise NotImplementedError()

View file

@ -9,9 +9,6 @@ class SearchExecutor(object):
delegates sqlalchemy filter decoration to SearchConfig instances. delegates sqlalchemy filter decoration to SearchConfig instances.
''' '''
ORDER_DESC = 1
ORDER_ASC = 2
def __init__(self, search_config): def __init__(self, search_config):
self._search_config = search_config self._search_config = search_config
@ -52,13 +49,13 @@ class SearchExecutor(object):
def _handle_key_value(self, query, key, value, negated): def _handle_key_value(self, query, key, value, negated):
if key == 'order': if key == 'order':
if value.count(',') == 0: if value.count(',') == 0:
order = self.ORDER_ASC order = None
elif value.count(',') == 1: elif value.count(',') == 1:
value, order_str = value.split(',') value, order_str = value.split(',')
if order_str == 'asc': if order_str == 'asc':
order = self.ORDER_ASC order = self._search_config.ORDER_ASC
elif order_str == 'desc': elif order_str == 'desc':
order = self.ORDER_DESC order = self._search_config.ORDER_DESC
else: else:
raise errors.SearchError( raise errors.SearchError(
'Unknown search direction: %r.' % order_str) 'Unknown search direction: %r.' % order_str)
@ -66,10 +63,12 @@ class SearchExecutor(object):
raise errors.SearchError( raise errors.SearchError(
'Too many commas in order search token.') 'Too many commas in order search token.')
if negated: if negated:
if order == self.ORDER_DESC: if order == self._search_config.ORDER_DESC:
order = self.ORDER_ASC order = self._search_config.ORDER_ASC
elif order == self._search_config.ORDER_ASC:
order = self._search_config.ORDER_DESC
else: else:
order = self.ORDER_DESC order = -1
return self._handle_order(query, value, order) return self._handle_order(query, value, order)
elif key == 'special': elif key == 'special':
return self._handle_special(query, value, negated) return self._handle_special(query, value, negated)
@ -100,8 +99,17 @@ class SearchExecutor(object):
def _handle_order(self, query, value, order): def _handle_order(self, query, value, order):
if value in self._search_config.order_columns: if value in self._search_config.order_columns:
column = self._search_config.order_columns[value] column, default_order = self._search_config.order_columns[value]
if order == self.ORDER_ASC: if order is None:
order = default_order
elif order == -1:
if default_order == self._search_config.ORDER_ASC:
order = self._search_config.ORDER_DESC
elif default_order == self._search_config.ORDER_DESC:
order = self._search_config.ORDER_ASC
else:
order = self._search_config.ORDER_ASC
if order == self._search_config.ORDER_ASC:
column = column.asc() column = column.asc()
else: else:
column = column.desc() column = column.desc()

View file

@ -35,11 +35,11 @@ class UserSearchConfig(BaseSearchConfig):
def order_columns(self): def order_columns(self):
return { return {
'random': func.random(), 'random': func.random(),
'name': db.User.name, 'name': (db.User.name, self.ORDER_ASC),
'creation-date': db.User.creation_time, 'creation-date': (db.User.creation_time, self.ORDER_DESC),
'creation-time': db.User.creation_time, 'creation-time': (db.User.creation_time, self.ORDER_DESC),
'last-login-date': db.User.last_login_time, 'last-login-date': (db.User.last_login_time, self.ORDER_DESC),
'last-login-time': db.User.last_login_time, 'last-login-time': (db.User.last_login_time, self.ORDER_DESC),
'login-date': db.User.last_login_time, 'login-date': (db.User.last_login_time, self.ORDER_DESC),
'login-time': db.User.last_login_time, 'login-time': (db.User.last_login_time, self.ORDER_DESC),
} }

View file

@ -1,19 +1,7 @@
from datetime import datetime import datetime
import pytest import pytest
from szurubooru import db, errors, search from szurubooru import db, errors, search
def mock_user(name):
user = db.User()
user.name = name
user.password = 'dummy'
user.password_salt = 'dummy'
user.password_hash = 'dummy'
user.email = 'dummy'
user.rank = 'dummy'
user.creation_time = datetime(1997, 1, 1)
user.avatar_style = db.User.AVATAR_GRAVATAR
return user
@pytest.fixture @pytest.fixture
def executor(session): def executor(session):
search_config = search.UserSearchConfig() search_config = search.UserSearchConfig()
@ -29,7 +17,6 @@ def verify_unpaged(session, executor):
assert actual_user_names == expected_user_names assert actual_user_names == expected_user_names
return verify return verify
# -----------------------------------------------------------------------------
@pytest.mark.parametrize('input,expected_user_names', [ @pytest.mark.parametrize('input,expected_user_names', [
('creation-time:2014', ['u1', 'u2']), ('creation-time:2014', ['u1', 'u2']),
('creation-date:2014', ['u1', 'u2']), ('creation-date:2014', ['u1', 'u2']),
@ -53,17 +40,16 @@ def verify_unpaged(session, executor):
('-creation-date:2014-01,2015', ['u2']), ('-creation-date:2014-01,2015', ['u2']),
]) ])
def test_filter_by_creation_time( def test_filter_by_creation_time(
verify_unpaged, session, input, expected_user_names): verify_unpaged, session, input, expected_user_names, user_factory):
user1 = mock_user('u1') user1 = user_factory(name='u1')
user2 = mock_user('u2') user2 = user_factory(name='u2')
user3 = mock_user('u3') user3 = user_factory(name='u3')
user1.creation_time = datetime(2014, 1, 1) user1.creation_time = datetime.datetime(2014, 1, 1)
user2.creation_time = datetime(2014, 6, 1) user2.creation_time = datetime.datetime(2014, 6, 1)
user3.creation_time = datetime(2015, 1, 1) user3.creation_time = datetime.datetime(2015, 1, 1)
session.add_all([user1, user2, user3]) session.add_all([user1, user2, user3])
verify_unpaged(input, expected_user_names) verify_unpaged(input, expected_user_names)
# -----------------------------------------------------------------------------
@pytest.mark.parametrize('input,expected_user_names', [ @pytest.mark.parametrize('input,expected_user_names', [
('name:user1', ['user1']), ('name:user1', ['user1']),
('name:user2', ['user2']), ('name:user2', ['user2']),
@ -82,41 +68,41 @@ def test_filter_by_creation_time(
('name:user1,user2', ['user1', 'user2']), ('name:user1,user2', ['user1', 'user2']),
('-name:user1,user3', ['user2']), ('-name:user1,user3', ['user2']),
]) ])
def test_filter_by_name(session, verify_unpaged, input, expected_user_names): def test_filter_by_name(
session.add(mock_user('user1')) session, verify_unpaged, input, expected_user_names, user_factory):
session.add(mock_user('user2')) session.add(user_factory(name='user1'))
session.add(mock_user('user3')) session.add(user_factory(name='user2'))
session.add(user_factory(name='user3'))
verify_unpaged(input, expected_user_names) verify_unpaged(input, expected_user_names)
# -----------------------------------------------------------------------------
@pytest.mark.parametrize('input,expected_user_names', [ @pytest.mark.parametrize('input,expected_user_names', [
('', ['u1', 'u2']), ('', ['u1', 'u2']),
('u1', ['u1']), ('u1', ['u1']),
('u2', ['u2']), ('u2', ['u2']),
('u1,u2', ['u1', 'u2']), ('u1,u2', ['u1', 'u2']),
]) ])
def test_anonymous(session, verify_unpaged, input, expected_user_names): def test_anonymous(
session.add(mock_user('u1')) session, verify_unpaged, input, expected_user_names, user_factory):
session.add(mock_user('u2')) session.add(user_factory(name='u1'))
session.add(user_factory(name='u2'))
verify_unpaged(input, expected_user_names) verify_unpaged(input, expected_user_names)
# -----------------------------------------------------------------------------
@pytest.mark.parametrize('input,expected_user_names', [ @pytest.mark.parametrize('input,expected_user_names', [
('creation-time:2014 u1', ['u1']), ('creation-time:2014 u1', ['u1']),
('creation-time:2014 u2', ['u2']), ('creation-time:2014 u2', ['u2']),
('creation-time:2016 u2', []), ('creation-time:2016 u2', []),
]) ])
def test_combining_tokens(session, verify_unpaged, input, expected_user_names): def test_combining_tokens(
user1 = mock_user('u1') session, verify_unpaged, input, expected_user_names, user_factory):
user2 = mock_user('u2') user1 = user_factory(name='u1')
user3 = mock_user('u3') user2 = user_factory(name='u2')
user1.creation_time = datetime(2014, 1, 1) user3 = user_factory(name='u3')
user2.creation_time = datetime(2014, 6, 1) user1.creation_time = datetime.datetime(2014, 1, 1)
user3.creation_time = datetime(2015, 1, 1) user2.creation_time = datetime.datetime(2014, 6, 1)
user3.creation_time = datetime.datetime(2015, 1, 1)
session.add_all([user1, user2, user3]) session.add_all([user1, user2, user3])
verify_unpaged(input, expected_user_names) verify_unpaged(input, expected_user_names)
# -----------------------------------------------------------------------------
@pytest.mark.parametrize( @pytest.mark.parametrize(
'page,page_size,expected_total_count,expected_user_names', [ 'page,page_size,expected_total_count,expected_user_names', [
(1, 1, 2, ['u1']), (1, 1, 2, ['u1']),
@ -126,17 +112,16 @@ def test_combining_tokens(session, verify_unpaged, input, expected_user_names):
(0, 0, 2, []), (0, 0, 2, []),
]) ])
def test_paging( def test_paging(
session, executor, page, page_size, session, executor, user_factory, page, page_size,
expected_total_count, expected_user_names): expected_total_count, expected_user_names):
session.add(mock_user('u1')) session.add(user_factory(name='u1'))
session.add(mock_user('u2')) session.add(user_factory(name='u2'))
actual_count, actual_users = executor.execute( actual_count, actual_users = executor.execute(
session, '', page=page, page_size=page_size) session, '', page=page, page_size=page_size)
actual_user_names = [u.name for u in actual_users] actual_user_names = [u.name for u in actual_users]
assert actual_count == expected_total_count assert actual_count == expected_total_count
assert actual_user_names == expected_user_names assert actual_user_names == expected_user_names
# -----------------------------------------------------------------------------
@pytest.mark.parametrize('input,expected_user_names', [ @pytest.mark.parametrize('input,expected_user_names', [
('', ['u1', 'u2']), ('', ['u1', 'u2']),
('order:name', ['u1', 'u2']), ('order:name', ['u1', 'u2']),
@ -146,12 +131,32 @@ def test_paging(
('-order:name,asc', ['u2', 'u1']), ('-order:name,asc', ['u2', 'u1']),
('-order:name,desc', ['u1', 'u2']), ('-order:name,desc', ['u1', 'u2']),
]) ])
def test_order_by_name(session, verify_unpaged, input, expected_user_names): def test_order_by_name(
session.add(mock_user('u2')) session, verify_unpaged, input, expected_user_names, user_factory):
session.add(mock_user('u1')) session.add(user_factory(name='u2'))
session.add(user_factory(name='u1'))
verify_unpaged(input, expected_user_names)
@pytest.mark.parametrize('input,expected_user_names', [
('', ['u1', 'u2', 'u3']),
('order:creation-date', ['u3', 'u2', 'u1']),
('-order:creation-date', ['u1', 'u2', 'u3']),
('order:creation-date,asc', ['u1', 'u2', 'u3']),
('order:creation-date,desc', ['u3', 'u2', 'u1']),
('-order:creation-date,asc', ['u3', 'u2', 'u1']),
('-order:creation-date,desc', ['u1', 'u2', 'u3']),
])
def test_order_by_name(
session, verify_unpaged, input, expected_user_names, user_factory):
user1 = user_factory(name='u1')
user2 = user_factory(name='u2')
user3 = user_factory(name='u3')
user1.creation_time = datetime.datetime(1991, 1, 1)
user2.creation_time = datetime.datetime(1991, 1, 2)
user3.creation_time = datetime.datetime(1991, 1, 3)
session.add_all([user3, user1, user2])
verify_unpaged(input, expected_user_names) verify_unpaged(input, expected_user_names)
# -----------------------------------------------------------------------------
@pytest.mark.parametrize('input,expected_error', [ @pytest.mark.parametrize('input,expected_error', [
('creation-date:..', errors.SearchError), ('creation-date:..', errors.SearchError),
('creation-date:bad..', errors.ValidationError), ('creation-date:bad..', errors.ValidationError),