#!/usr/bin/env python3
import os
import sys
import datetime
import argparse
import json
import zlib
import concurrent.futures
import logging
import coloredlogs
import sqlalchemy
import sqlalchemy.orm
from szurubooru import config, db
from szurubooru.func import files, images, posts, comments

coloredlogs.install(fmt='[%(asctime)-15s] %(message)s')
logger = logging.getLogger(__name__)


def read_file(path):
    with open(path, 'rb') as handle:
        return handle.read()


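# Open a SQLAlchemy session against the legacy MySQL database via PyMySQL.
# Note that the DSN logged below contains the database password in clear text.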
def get_v1_session(args):
    dsn = '{schema}://{user}:{password}@{host}:{port}/{name}?charset=utf8'.format(
        schema='mysql+pymysql',
        user=args.user,
        password=args.password,
        host=args.host,
        port=args.port,
        name=args.name)
    logger.info('Connecting to %r...', dsn)
    engine = sqlalchemy.create_engine(dsn)
    session_maker = sqlalchemy.orm.sessionmaker(bind=engine)
    return session_maker()


def parse_args():
    parser = argparse.ArgumentParser(
        description='Migrate database from szurubooru v1.x to v2.x.')
    parser.add_argument(
        '--data-dir',
        metavar='PATH', required=True,
        help='root directory of v1.x deploy')
    parser.add_argument(
        '--host',
        metavar='HOST', required=True,
        help='name of v1.x database host')
    parser.add_argument(
        '--port',
        metavar='NUM', type=int, default=3306,
        help='port of v1.x database host')
    parser.add_argument(
        '--name',
        metavar='NAME', required=True,
        help='v1.x database name')
    parser.add_argument(
        '--user',
        metavar='NAME', required=True,
        help='v1.x database user')
    parser.add_argument(
        '--password',
        metavar='PASSWORD', required=True,
        help='v1.x database password')
    parser.add_argument(
        '--no-data',
        action='store_true',
        help='don\'t migrate post data')
    return parser.parse_args()


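# Thin wrappers over raw SQL: exec_query yields each result row as a plain
# dict; exec_scalar returns the first column of the first row.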
def exec_query(session, query):
    for row in list(session.execute(query)):
        yield dict(zip(row.keys(), row))


def exec_scalar(session, query):
    rows = list(exec_query(session, query))
    first_row = rows[0]
    return list(first_row.values())[0]


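# Users migrate verbatim, password hashes and salts included. The numeric v1
# access ranks and avatar styles are mapped onto their v2 constants, and
# manually uploaded avatars are re-rendered as PNGs at the thumbnail size
# configured for v2.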
def import_users(v1_data_dir, v1_session, v2_session):
    for row in exec_query(v1_session, 'SELECT * FROM users'):
        logger.info('Importing user %s...', row['name'])
        user = db.User()
        user.user_id = row['id']
        user.name = row['name']
        user.password_salt = row['passwordSalt']
        user.password_hash = row['passwordHash']
        user.email = row['email']
        user.rank = {
            6: db.User.RANK_ADMINISTRATOR,
            5: db.User.RANK_MODERATOR,
            4: db.User.RANK_POWER,
            3: db.User.RANK_REGULAR,
            2: db.User.RANK_RESTRICTED,
            1: db.User.RANK_ANONYMOUS,
        }[row['accessRank']]
        user.creation_time = row['creationTime']
        user.last_login_time = row['lastLoginTime']
        user.avatar_style = {
            2: db.User.AVATAR_MANUAL,
            1: db.User.AVATAR_GRAVATAR,
            0: db.User.AVATAR_GRAVATAR,
        }[row['avatarStyle']]
        v2_session.add(user)
        if user.avatar_style == db.User.AVATAR_MANUAL:
            source_avatar_path = os.path.join(
                v1_data_dir, 'public_html', 'data', 'avatars', str(user.user_id))
            avatar_content = read_file(source_avatar_path)
            image = images.Image(avatar_content)
            image.resize_fill(
                int(config.config['thumbnails']['avatar_width']),
                int(config.config['thumbnails']['avatar_height']))
            files.save('avatars/' + user.name.lower() + '.png', image.to_png())
    counter = exec_scalar(v1_session, 'SELECT MAX(id) FROM users') + 1
    v2_session.execute('ALTER SEQUENCE user_id_seq RESTART WITH %d' % counter)
    v2_session.commit()


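# v1 keeps a tag's category as a plain string column; v2 models categories as
# rows, so one TagCategory is synthesized per distinct category string.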
def import_tag_categories(v1_session, v2_session):
    category_to_id_map = {}
    for row in exec_query(v1_session, 'SELECT DISTINCT category FROM tags'):
        logger.info('Importing tag category %s...', row['category'])
        category = db.TagCategory()
        category.tag_category_id = len(category_to_id_map) + 1
        category.name = row['category']
        category.color = 'default'
        v2_session.add(category)
        category_to_id_map[category.name] = category.tag_category_id
    v2_session.execute(
        'ALTER SEQUENCE tag_category_id_seq RESTART WITH %d' % (
            len(category_to_id_map) + 1,))
    return category_to_id_map


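# Banned tags have no v2 counterpart. Their ids are collected so that later
# passes can skip any rows referencing them.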
def import_tags(category_to_id_map, v1_session, v2_session):
    unused_tag_ids = []
    for row in exec_query(v1_session, 'SELECT * FROM tags'):
        logger.info('Importing tag %s...', row['name'])
        if row['banned']:
            logger.info('Ignored banned tag %s', row['name'])
            unused_tag_ids.append(row['id'])
            continue
        tag = db.Tag()
        tag.tag_id = row['id']
        tag.names = [db.TagName(name=row['name'])]
        tag.category_id = category_to_id_map[row['category']]
        tag.creation_time = row['creationTime']
        tag.last_edit_time = row['lastEditTime']
        v2_session.add(tag)
    counter = exec_scalar(v1_session, 'SELECT MAX(id) FROM tags') + 1
    v2_session.execute('ALTER SEQUENCE tag_id_seq RESTART WITH %d' % counter)
    v2_session.commit()
    return unused_tag_ids


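# In the v1 tagRelations table, type 1 denotes an implication; any other type
# is imported as a suggestion.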
def import_tag_relations(unused_tag_ids, v1_session, v2_session):
    logger.info('Importing tag relations...')
    for row in exec_query(v1_session, 'SELECT * FROM tagRelations'):
        if row['tag1id'] in unused_tag_ids or row['tag2id'] in unused_tag_ids:
            continue
        if row['type'] == 1:
            v2_session.add(
                db.TagImplication(
                    parent_id=row['tag1id'], child_id=row['tag2id']))
        else:
            v2_session.add(
                db.TagSuggestion(
                    parent_id=row['tag1id'], child_id=row['tag2id']))
    v2_session.commit()


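# Posts migrate almost 1:1. YouTube posts (contentType 4) have no v2
# equivalent and are skipped, together with everything referencing them.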
def import_posts(v1_session, v2_session):
    unused_post_ids = []
    for row in exec_query(v1_session, 'SELECT * FROM posts'):
        logger.info('Importing post %d...', row['id'])
        if row['contentType'] == 4:
            logger.warning('Ignoring youtube post %d', row['id'])
            unused_post_ids.append(row['id'])
            continue
        post = db.Post()
        post.post_id = row['id']
        post.user_id = row['userId']
        post.type = {
            1: db.Post.TYPE_IMAGE,
            2: db.Post.TYPE_FLASH,
            3: db.Post.TYPE_VIDEO,
            5: db.Post.TYPE_ANIMATION,
        }[row['contentType']]
        post.source = row['source']
        post.canvas_width = row['imageWidth']
        post.canvas_height = row['imageHeight']
        post.file_size = row['originalFileSize']
        post.creation_time = row['creationTime']
        post.last_edit_time = row['lastEditTime']
        post.checksum = row['contentChecksum']
        post.mime_type = row['contentMimeType']
        post.safety = {
            1: db.Post.SAFETY_SAFE,
            2: db.Post.SAFETY_SKETCHY,
            3: db.Post.SAFETY_UNSAFE,
        }[row['safety']]
        if row['flags'] & 1:
            post.flags = [db.Post.FLAG_LOOP]
        v2_session.add(post)
    counter = exec_scalar(v1_session, 'SELECT MAX(id) FROM posts') + 1
    v2_session.execute('ALTER SEQUENCE post_id_seq RESTART WITH %d' % counter)
    v2_session.commit()
    return unused_post_ids


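# Copy a single post's media file (plus its optional custom thumbnail) out of
# the v1 data directory, then let v2 regenerate the displayed thumbnail.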
def _import_post_content_for_post(
        unused_post_ids, v1_data_dir, v1_session, v2_session, row, post):
    logger.info('Importing post %d content...', row['id'])
    if row['id'] in unused_post_ids:
        logger.warning('Ignoring unimported post %d', row['id'])
        return
    assert post
    source_content_path = os.path.join(
        v1_data_dir,
        'public_html',
        'data',
        'posts',
        row['name'])
    source_thumb_path = os.path.join(
        v1_data_dir,
        'public_html',
        'data',
        'posts',
        row['name'] + '-custom-thumb')
    post_content = read_file(source_content_path)
    files.save(posts.get_post_content_path(post), post_content)
    if os.path.exists(source_thumb_path):
        thumb_content = read_file(source_thumb_path)
        files.save(posts.get_post_thumbnail_backup_path(post), thumb_content)
    posts.generate_post_thumbnail(post)


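# Copying files is I/O bound, so the per-post work is fanned out over a small
# thread pool. The returned futures are never inspected, so exceptions raised
# inside workers are silently discarded.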
def import_post_content(unused_post_ids, v1_data_dir, v1_session, v2_session):
    rows = list(exec_query(v1_session, 'SELECT * FROM posts'))
    # Map of post id to v2 post; named to avoid shadowing the posts module.
    post_map = {post.post_id: post for post in v2_session.query(db.Post).all()}
    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        for row in rows:
            post = post_map.get(row['id'])
            executor.submit(
                _import_post_content_for_post,
                unused_post_ids, v1_data_dir, v1_session, v2_session, row, post)


def import_post_tags(unused_post_ids, v1_session, v2_session):
    logger.info('Importing post tags...')
    for row in exec_query(v1_session, 'SELECT * FROM postTags'):
        if row['postId'] in unused_post_ids:
            continue
        v2_session.add(db.PostTag(post_id=row['postId'], tag_id=row['tagId']))
    v2_session.commit()


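# v1 stores note rectangles as 0-100 percentages; they are scaled down here to
# the 0-1 fractions v2 uses and expressed as a polygon of corner points.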
def import_post_notes(unused_post_ids, v1_session, v2_session):
    logger.info('Importing post notes...')
    for row in exec_query(v1_session, 'SELECT * FROM postNotes'):
        if row['postId'] in unused_post_ids:
            continue
        x, y, w, h = row['x'], row['y'], row['width'], row['height']
        x /= 100
        y /= 100
        w /= 100
        h /= 100
        post_note = db.PostNote()
        post_note.post_id = row['postId']
        post_note.text = row['text']
        post_note.polygon = [
            (x,     y),
            (x + w, y),
            (x + w, y + h),
            (x,     y + h),
        ]
        v2_session.add(post_note)
    v2_session.commit()


def import_post_relations(unused_post_ids, v1_session, v2_session):
    logger.info('Importing post relations...')
    for row in exec_query(v1_session, 'SELECT * FROM postRelations'):
        if row['post1id'] in unused_post_ids or row['post2id'] in unused_post_ids:
            continue
        v2_session.add(
            db.PostRelation(
                parent_id=row['post1id'], child_id=row['post2id']))
    v2_session.commit()


def import_post_favorites(unused_post_ids, v1_session, v2_session):
    logger.info('Importing post favorites...')
    for row in exec_query(v1_session, 'SELECT * FROM favorites'):
        if row['postId'] in unused_post_ids:
            continue
        v2_session.add(
            db.PostFavorite(
                post_id=row['postId'],
                user_id=row['userId'],
                time=row['time'] or datetime.datetime.min))
    v2_session.commit()


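# Comments whose post was skipped during import, or whose post no longer
# exists at all, are dropped.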
def import_comments(unused_post_ids, v1_session, v2_session):
    for row in exec_query(v1_session, 'SELECT * FROM comments'):
        logger.info('Importing comment %d...', row['id'])
        if row['postId'] in unused_post_ids:
            logger.warning('Ignoring comment for unimported post %d', row['postId'])
            continue
        if not posts.try_get_post_by_id(row['postId']):
            logger.warning('Ignoring comment for non-existing post %d', row['postId'])
            continue
        comment = db.Comment()
        comment.comment_id = row['id']
        comment.user_id = row['userId']
        comment.post_id = row['postId']
        comment.creation_time = row['creationTime']
        comment.last_edit_time = row['lastEditTime']
        comment.text = row['text']
        v2_session.add(comment)
    counter = exec_scalar(v1_session, 'SELECT MAX(id) FROM comments') + 1
    v2_session.execute('ALTER SEQUENCE comment_id_seq RESTART WITH %d' % counter)
    v2_session.commit()


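# A v1 score row targets either a post or a comment; the matching v2 score
# type is chosen accordingly, and scores whose target was not migrated are
# skipped.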
def import_scores(v1_session, v2_session):
    logger.info('Importing scores...')
    for row in exec_query(v1_session, 'SELECT * FROM scores'):
        if row['postId']:
            post = posts.try_get_post_by_id(row['postId'])
            if not post:
                logger.warning('Ignoring score for unimported post %d', row['postId'])
                continue
            score = db.PostScore()
            score.post = post
        elif row['commentId']:
            comment = comments.try_get_comment_by_id(row['commentId'])
            if not comment:
                logger.warning(
                    'Ignoring score for unimported comment %d', row['commentId'])
                continue
            score = db.CommentScore()
            score.comment = comment
        else:
            # A row that references neither a post nor a comment has nothing
            # to attach to.
            continue
        score.score = row['score']
        score.time = row['time'] or datetime.datetime.min
        score.user_id = row['userId']
        v2_session.add(score)
    v2_session.commit()


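# v1 snapshot payloads are raw-DEFLATE-compressed JSON (hence the -15
# window-bits argument to zlib.decompress). The decoded payload is renamed to
# match v2 conventions: contentChecksum becomes checksum, and a tag's single
# name becomes a one-element names list.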
def import_snapshots(v1_session, v2_session):
    logger.info('Importing snapshots...')
    for row in exec_query(v1_session, 'SELECT * FROM snapshots ORDER BY time ASC'):
        snapshot = db.Snapshot()
        snapshot.creation_time = row['time']
        snapshot.user_id = row['userId']
        snapshot.operation = {
            0: db.Snapshot.OPERATION_CREATED,
            1: db.Snapshot.OPERATION_MODIFIED,
            2: db.Snapshot.OPERATION_DELETED,
        }[row['operation']]
        snapshot.resource_type = {
            0: 'post',
            1: 'tag',
        }[row['type']]
        snapshot.resource_id = row['primaryKey']
        data = json.loads(zlib.decompress(row['data'], -15).decode('utf-8'))
        if snapshot.resource_type == 'post':
            if 'contentChecksum' in data:
                data['checksum'] = data['contentChecksum']
                del data['contentChecksum']
            if 'tags' in data and isinstance(data['tags'], dict):
                data['tags'] = list(data['tags'].values())
            snapshot.resource_repr = row['primaryKey']
        elif snapshot.resource_type == 'tag':
            if 'banned' in data:
                del data['banned']
            if 'name' in data:
                data['names'] = [data['name']]
                del data['name']
            snapshot.resource_repr = data['names'][0]
        snapshot.data = data
        v2_session.add(snapshot)
    v2_session.commit()


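# Refuse to run against a non-empty v2 database: the import reuses the
# original primary keys, so any pre-existing rows could collide.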
def main():
    args = parse_args()
    v1_data_dir = args.data_dir
    v1_session = get_v1_session(args)
    v2_session = db.session
    if v2_session.query(db.Tag).count() \
            or v2_session.query(db.Post).count() \
            or v2_session.query(db.Comment).count() \
            or v2_session.query(db.User).count():
        logger.error('v2.x database is dirty! Aborting.')
        sys.exit(1)
    import_users(v1_data_dir, v1_session, v2_session)
    category_to_id_map = import_tag_categories(v1_session, v2_session)
    unused_tag_ids = import_tags(category_to_id_map, v1_session, v2_session)
    import_tag_relations(unused_tag_ids, v1_session, v2_session)
    unused_post_ids = import_posts(v1_session, v2_session)
    if not args.no_data:
        import_post_content(unused_post_ids, v1_data_dir, v1_session, v2_session)
    import_post_tags(unused_post_ids, v1_session, v2_session)
    import_post_notes(unused_post_ids, v1_session, v2_session)
    import_post_relations(unused_post_ids, v1_session, v2_session)
    import_post_favorites(unused_post_ids, v1_session, v2_session)
    import_comments(unused_post_ids, v1_session, v2_session)
    import_scores(v1_session, v2_session)
    import_snapshots(v1_session, v2_session)


if __name__ == '__main__':
    main()