#!/usr/bin/env python3

"""Migrate database from szurubooru v1.x to v2.x."""

import os
import sys
import datetime
import argparse
import json
import zlib
import concurrent.futures
import logging
import urllib.parse
import coloredlogs
import sqlalchemy
import sqlalchemy.orm
from szurubooru import config, db
from szurubooru.func import files, images, posts, comments

coloredlogs.install(fmt='[%(asctime)-15s] %(message)s')
logger = logging.getLogger(__name__)


def read_file(path):
    """Return the whole content of the file at ``path`` as bytes."""
    with open(path, 'rb') as handle:
        return handle.read()


def translate_note_polygon(row):
    """Convert a v1.x rectangular note into a list of four corner points.

    ``row`` is a mapping with ``x`` and ``y`` keys and either
    ``width``/``height`` or ``w``/``h`` keys.  Every value is divided by 100
    (presumably v1.x stored percentages -- confirm against the v1.x schema).

    Returns the corners as ``(x, y)`` tuples in the order: top-left,
    top-right, bottom-right, bottom-left.
    """
    x, y = row['x'], row['y']
    w = row.get('width', row.get('w'))
    h = row.get('height', row.get('h'))
    x /= 100.0
    y /= 100.0
    w /= 100.0
    h /= 100.0
    return [
        (x, y),
        (x + w, y),
        (x + w, y + h),
        (x, y + h),
    ]


def get_v1_session(args):
    """Open an SQLAlchemy session to the v1.x MySQL database.

    ``args`` is the namespace returned by :func:`parse_args`; its ``user``,
    ``password``, ``host``, ``port`` and ``name`` attributes are used.
    """
    def build_dsn(password):
        return (
            '{schema}://{user}:{password}@{host}:{port}/{name}?charset=utf8'
            .format(
                schema='mysql+pymysql',
                # Credentials are percent-encoded so that characters such as
                # '@', ':' or '/' cannot be mistaken for URL delimiters.
                user=urllib.parse.quote(args.user, safe=''),
                password=password,
                host=args.host,
                port=args.port,
                name=args.name))

    dsn = build_dsn(urllib.parse.quote(args.password, safe=''))
    # Never write the real password to the log.
    logger.info('Connecting to %r...', build_dsn('***'))
    engine = sqlalchemy.create_engine(dsn)
    session_maker = sqlalchemy.orm.sessionmaker(bind=engine)
    return session_maker()


def parse_args():
    """Parse and return the command line arguments of the migration script."""
    parser = argparse.ArgumentParser(
        description='Migrate database from szurubooru v1.x to v2.x.\n\n')
    parser.add_argument(
        '--data-dir', metavar='PATH', required=True,
        help='root directory of v1.x deploy')
    parser.add_argument(
        '--host', metavar='HOST', required=True,
        help='name of v1.x database host')
    parser.add_argument(
        '--port', metavar='NUM', type=int, default=3306,
        help='port to v1.x database host')
    parser.add_argument(
        '--name', metavar='NAME', required=True,
        help='v1.x database name')
    parser.add_argument(
        '--user', metavar='NAME', required=True,
        help='v1.x database user')
    parser.add_argument(
        '--password', metavar='PASSWORD', required=True,
        help='v1.x database password')
    parser.add_argument(
        '--no-data', action='store_true',
        help='don\'t migrate post data')
    return parser.parse_args()


def exec_query(session, query):
    """Execute ``query`` on ``session`` and yield each row as a dict.

    The result set is read into a list before the first row is yielded.
    """
    for row in list(session.execute(query)):
        row = dict(zip(row.keys(), row))
        yield row


def exec_scalar(session, query):
    """Return the first column of the first row produced by ``query``.

    Returns ``None`` when the query yields no rows at all.
    """
    rows = list(exec_query(session, query))
    if not rows:
        return None
    first_row = rows[0]
    return list(first_row.values())[0]


def _restart_sequence(v1_session, v2_session, v1_table, v2_sequence):
    """Restart ``v2_sequence`` just past the biggest ``id`` in ``v1_table``."""
    max_id = exec_scalar(v1_session, 'SELECT MAX(id) FROM %s' % v1_table)
    # MAX(id) is NULL for an empty table; start the sequence at 1 then.
    counter = (max_id or 0) + 1
    v2_session.execute(
        'ALTER SEQUENCE %s RESTART WITH %d' % (v2_sequence, counter))


def import_users(v1_data_dir, v1_session, v2_session, no_data=False):
    """Copy all v1.x users into the v2.x database and commit.

    Unless ``no_data`` is set, manually uploaded avatars are read from
    ``v1_data_dir``, resized to the configured avatar size and saved as PNG.
    """
    for row in exec_query(v1_session, 'SELECT * FROM users'):
        logger.info('Importing user %s...', row['name'])
        user = db.User()
        user.user_id = row['id']
        user.name = row['name']
        user.password_salt = row['passwordSalt']
        user.password_hash = row['passwordHash']
        user.email = row['email']
        user.rank = {
            6: db.User.RANK_ADMINISTRATOR,
            5: db.User.RANK_MODERATOR,
            4: db.User.RANK_POWER,
            3: db.User.RANK_REGULAR,
            2: db.User.RANK_RESTRICTED,
            1: db.User.RANK_ANONYMOUS,
        }[row['accessRank']]
        user.creation_time = row['creationTime']
        user.last_login_time = row['lastLoginTime']
        user.avatar_style = {
            2: db.User.AVATAR_MANUAL,
            1: db.User.AVATAR_GRAVATAR,
            0: db.User.AVATAR_GRAVATAR,
        }[row['avatarStyle']]
        v2_session.add(user)

        if user.avatar_style == db.User.AVATAR_MANUAL and not no_data:
            source_avatar_path = os.path.join(
                v1_data_dir, 'public_html', 'data', 'avatars',
                str(user.user_id))
            avatar_content = read_file(source_avatar_path)
            image = images.Image(avatar_content)
            image.resize_fill(
                int(config.config['thumbnails']['avatar_width']),
                int(config.config['thumbnails']['avatar_height']))
            files.save(
                'avatars/' + user.name.lower() + '.png', image.to_png())
    _restart_sequence(v1_session, v2_session, 'users', 'user_id_seq')
    v2_session.commit()


def import_tag_categories(v1_session, v2_session):
    """Create one v2.x tag category per distinct v1.x tag category name.

    Categories get consecutive ids starting at 1.  Returns a dict mapping
    category name to the assigned id.  Nothing is committed here.
    """
    category_to_id_map = {}
    for row in exec_query(v1_session, 'SELECT DISTINCT category FROM tags'):
        logger.info('Importing tag category %s...', row['category'])
        category = db.TagCategory()
        category.tag_category_id = len(category_to_id_map) + 1
        category.name = row['category']
        category.color = 'default'
        v2_session.add(category)
        category_to_id_map[category.name] = category.tag_category_id
    v2_session.execute(
        'ALTER SEQUENCE tag_category_id_seq RESTART WITH %d' % (
            len(category_to_id_map) + 1,))
    return category_to_id_map


def import_tags(category_to_id_map, v1_session, v2_session):
    """Copy all non-banned v1.x tags into the v2.x database and commit.

    Returns the list of v1.x tag ids that were skipped because the tag was
    banned.
    """
    unused_tag_ids = []
    for row in exec_query(v1_session, 'SELECT * FROM tags'):
        logger.info('Importing tag %s...', row['name'])
        if row['banned']:
            logger.info('Ignored banned tag %s', row['name'])
            unused_tag_ids.append(row['id'])
            continue
        tag = db.Tag()
        tag.tag_id = row['id']
        tag.names = [db.TagName(name=row['name'])]
        tag.category_id = category_to_id_map[row['category']]
        tag.creation_time = row['creationTime']
        tag.last_edit_time = row['lastEditTime']
        v2_session.add(tag)
    _restart_sequence(v1_session, v2_session, 'tags', 'tag_id_seq')
    v2_session.commit()
    return unused_tag_ids


def import_tag_relations(unused_tag_ids, v1_session, v2_session):
    """Copy tag implications and suggestions, skipping unimported tags.

    v1.x relations of type 1 become implications; every other type becomes
    a suggestion.
    """
    logger.info('Importing tag relations...')
    skipped_ids = set(unused_tag_ids)
    for row in exec_query(v1_session, 'SELECT * FROM tagRelations'):
        if row['tag1id'] in skipped_ids or row['tag2id'] in skipped_ids:
            continue
        if row['type'] == 1:
            v2_session.add(
                db.TagImplication(
                    parent_id=row['tag1id'], child_id=row['tag2id']))
        else:
            v2_session.add(
                db.TagSuggestion(
                    parent_id=row['tag1id'], child_id=row['tag2id']))
    v2_session.commit()


def import_posts(v1_session, v2_session):
    """Copy v1.x post metadata into the v2.x database and commit.

    Posts with content type 4 (youtube) are not imported.  Returns the list
    of v1.x post ids that were skipped.
    """
    unused_post_ids = []
    for row in exec_query(v1_session, 'SELECT * FROM posts'):
        logger.info('Importing post %d...', row['id'])
        if row['contentType'] == 4:
            logger.warning('Ignoring youtube post %d', row['id'])
            unused_post_ids.append(row['id'])
            continue
        post = db.Post()
        post.post_id = row['id']
        post.user_id = row['userId']
        post.type = {
            1: db.Post.TYPE_IMAGE,
            2: db.Post.TYPE_FLASH,
            3: db.Post.TYPE_VIDEO,
            5: db.Post.TYPE_ANIMATION,
        }[row['contentType']]
        post.source = row['source']
        post.canvas_width = row['imageWidth']
        post.canvas_height = row['imageHeight']
        post.file_size = row['originalFileSize']
        post.creation_time = row['creationTime']
        post.last_edit_time = row['lastEditTime']
        post.checksum = row['contentChecksum']
        post.mime_type = row['contentMimeType']
        post.safety = {
            1: db.Post.SAFETY_SAFE,
            2: db.Post.SAFETY_SKETCHY,
            3: db.Post.SAFETY_UNSAFE,
        }[row['safety']]
        # The lowest bit of the v1.x flags marks looped playback.
        if row['flags'] & 1:
            post.flags = [db.Post.FLAG_LOOP]
        v2_session.add(post)
    _restart_sequence(v1_session, v2_session, 'posts', 'post_id_seq')
    v2_session.commit()
    return unused_post_ids


def _import_post_content_for_post(
        unused_post_ids, v1_data_dir, v1_session, v2_session, row, post):
    """Copy the content (and custom thumbnail, if any) of a single post.

    ``row`` is the v1.x ``posts`` row and ``post`` the matching v2.x post
    object, or ``None`` if there is none.  Posts listed in
    ``unused_post_ids`` are skipped.
    """
    logger.info('Importing post %d content...', row['id'])
    if row['id'] in unused_post_ids:
        logger.warning('Ignoring unimported post %d', row['id'])
        return
    if not post:
        # An explicit check instead of ``assert``, which vanishes under -O.
        logger.warning(
            'Ignoring post %d missing from v2.x database', row['id'])
        return
    source_content_path = os.path.join(
        v1_data_dir, 'public_html', 'data', 'posts', row['name'])
    source_thumb_path = os.path.join(
        v1_data_dir, 'public_html', 'data', 'posts',
        row['name'] + '-custom-thumb')
    post_content = read_file(source_content_path)
    files.save(posts.get_post_content_path(post), post_content)
    if os.path.exists(source_thumb_path):
        thumb_content = read_file(source_thumb_path)
        files.save(posts.get_post_thumbnail_backup_path(post), thumb_content)
    posts.generate_post_thumbnail(post)


def import_post_content(unused_post_ids, v1_data_dir, v1_session, v2_session):
    """Copy post files from the v1.x data directory using a thread pool.

    A failure while copying one post is logged and does not stop the
    remaining posts from being processed.
    """
    rows = list(exec_query(v1_session, 'SELECT * FROM posts'))
    posts_by_id = {
        post.post_id: post for post in v2_session.query(db.Post).all()}
    skipped_ids = set(unused_post_ids)
    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        future_to_post_id = {}
        for row in rows:
            future = executor.submit(
                _import_post_content_for_post,
                skipped_ids,
                v1_data_dir,
                v1_session,
                v2_session,
                row,
                posts_by_id.get(row['id']))
            future_to_post_id[future] = row['id']
        # Exceptions raised inside workers are stored on the futures; they
        # would go unnoticed unless they are collected explicitly.
        for future in concurrent.futures.as_completed(future_to_post_id):
            error = future.exception()
            if error is not None:
                logger.error(
                    'Failed to import content of post %d: %r',
                    future_to_post_id[future], error)


def import_post_tags(unused_post_ids, v1_session, v2_session):
    """Copy post-tag associations for all imported posts and commit."""
    logger.info('Importing post tags...')
    skipped_ids = set(unused_post_ids)
    for row in exec_query(v1_session, 'SELECT * FROM postTags'):
        if row['postId'] in skipped_ids:
            continue
        v2_session.add(db.PostTag(post_id=row['postId'], tag_id=row['tagId']))
    v2_session.commit()


def import_post_notes(unused_post_ids, v1_session, v2_session):
    """Copy post notes for all imported posts and commit."""
    logger.info('Importing post notes...')
    skipped_ids = set(unused_post_ids)
    for row in exec_query(v1_session, 'SELECT * FROM postNotes'):
        if row['postId'] in skipped_ids:
            continue
        post_note = db.PostNote()
        post_note.post_id = row['postId']
        post_note.text = row['text']
        post_note.polygon = translate_note_polygon(row)
        v2_session.add(post_note)
    v2_session.commit()


def import_post_relations(unused_post_ids, v1_session, v2_session):
    """Copy relations between posts, skipping unimported posts, and commit."""
    logger.info('Importing post relations...')
    skipped_ids = set(unused_post_ids)
    for row in exec_query(v1_session, 'SELECT * FROM postRelations'):
        if row['post1id'] in skipped_ids or row['post2id'] in skipped_ids:
            continue
        v2_session.add(
            db.PostRelation(
                parent_id=row['post1id'], child_id=row['post2id']))
    v2_session.commit()


def import_post_favorites(unused_post_ids, v1_session, v2_session):
    """Copy post favorites for all imported posts and commit.

    Favorites without a time are stored with ``datetime.datetime.min``.
    """
    logger.info('Importing post favorites...')
    skipped_ids = set(unused_post_ids)
    for row in exec_query(v1_session, 'SELECT * FROM favorites'):
        if row['postId'] in skipped_ids:
            continue
        v2_session.add(
            db.PostFavorite(
                post_id=row['postId'],
                user_id=row['userId'],
                time=row['time'] or datetime.datetime.min))
    v2_session.commit()


def import_comments(unused_post_ids, v1_session, v2_session):
    """Copy comments into the v2.x database and commit.

    Comments that belong to a skipped post, or to a post that cannot be
    found in the v2.x database, are ignored.
    """
    skipped_ids = set(unused_post_ids)
    for row in exec_query(v1_session, 'SELECT * FROM comments'):
        logger.info('Importing comment %d...', row['id'])
        if row['postId'] in skipped_ids:
            logger.warning(
                'Ignoring comment for unimported post %d', row['postId'])
            continue
        if not posts.try_get_post_by_id(row['postId']):
            logger.warning(
                'Ignoring comment for non existing post %d', row['postId'])
            continue
        comment = db.Comment()
        comment.comment_id = row['id']
        comment.user_id = row['userId']
        comment.post_id = row['postId']
        comment.creation_time = row['creationTime']
        comment.last_edit_time = row['lastEditTime']
        comment.text = row['text']
        v2_session.add(comment)
    _restart_sequence(v1_session, v2_session, 'comments', 'comment_id_seq')
    v2_session.commit()


def import_scores(v1_session, v2_session):
    """Copy post and comment scores into the v2.x database and commit.

    Scores whose post or comment is missing from the v2.x database, and
    scores that reference neither a post nor a comment, are ignored.
    """
    logger.info('Importing scores...')
    for row in exec_query(v1_session, 'SELECT * FROM scores'):
        if row['postId']:
            post = posts.try_get_post_by_id(row['postId'])
            if not post:
                logger.warning(
                    'Ignoring score for unimported post %d', row['postId'])
                continue
            score = db.PostScore()
            score.post = post
        elif row['commentId']:
            comment = comments.try_get_comment_by_id(row['commentId'])
            if not comment:
                logger.warning(
                    'Ignoring score for unimported comment %d',
                    row['commentId'])
                continue
            score = db.CommentScore()
            score.comment = comment
        else:
            # Without this branch ``score`` would be unbound, or would still
            # refer to the object created for the previous row.
            logger.warning(
                'Ignoring score of user %r without post or comment',
                row['userId'])
            continue
        score.score = row['score']
        score.time = row['time'] or datetime.datetime.min
        score.user_id = row['userId']
        v2_session.add(score)
    v2_session.commit()


def import_snapshots(v1_session, v2_session):
    """Copy post and tag snapshots, oldest first, and commit.

    The v1.x snapshot payload is a raw-deflate compressed JSON document; its
    keys are renamed to the names used by v2.x before it is stored.
    """
    logger.info('Importing snapshots...')
    for row in exec_query(
            v1_session, 'SELECT * FROM snapshots ORDER BY time ASC'):
        snapshot = db.Snapshot()
        snapshot.creation_time = row['time']
        snapshot.user_id = row['userId']
        snapshot.operation = {
            0: db.Snapshot.OPERATION_CREATED,
            1: db.Snapshot.OPERATION_MODIFIED,
            2: db.Snapshot.OPERATION_DELETED,
        }[row['operation']]
        snapshot.resource_type = {
            0: 'post',
            1: 'tag',
        }[row['type']]
        snapshot.resource_id = row['primaryKey']

        # wbits=-15 tells zlib that the stream has no header or trailer.
        data = json.loads(zlib.decompress(row['data'], -15).decode('utf-8'))
        if snapshot.resource_type == 'post':
            if 'contentChecksum' in data:
                data['checksum'] = data.pop('contentChecksum')
            if 'tags' in data and isinstance(data['tags'], dict):
                data['tags'] = list(data['tags'].values())
            if 'notes' in data:
                data['notes'] = [
                    {
                        'polygon': translate_note_polygon(note),
                        'text': note['text'],
                    }
                    for note in data['notes']
                ]
            snapshot.resource_repr = row['primaryKey']
        elif snapshot.resource_type == 'tag':
            data.pop('banned', None)
            if 'name' in data:
                data['names'] = [data.pop('name')]
            # NOTE(review): raises KeyError when the payload has neither
            # 'name' nor 'names' -- confirm v1.x always stores a tag name.
            snapshot.resource_repr = data['names'][0]
        snapshot.data = data
        v2_session.add(snapshot)
    v2_session.commit()


def main():
    """Run the whole migration; abort if the v2.x database is not empty."""
    args = parse_args()

    v1_data_dir = args.data_dir
    v1_session = get_v1_session(args)
    v2_session = db.session

    # Refuse to import on top of existing data.
    if any(
            v2_session.query(model).count()
            for model in (db.Tag, db.Post, db.Comment, db.User)):
        logger.error('v2.x database is dirty! Aborting.')
        sys.exit(1)

    import_users(v1_data_dir, v1_session, v2_session, args.no_data)
    category_to_id_map = import_tag_categories(v1_session, v2_session)
    unused_tag_ids = import_tags(category_to_id_map, v1_session, v2_session)
    import_tag_relations(unused_tag_ids, v1_session, v2_session)
    unused_post_ids = import_posts(v1_session, v2_session)
    if not args.no_data:
        import_post_content(
            unused_post_ids, v1_data_dir, v1_session, v2_session)
    import_post_tags(unused_post_ids, v1_session, v2_session)
    import_post_notes(unused_post_ids, v1_session, v2_session)
    import_post_relations(unused_post_ids, v1_session, v2_session)
    import_post_favorites(unused_post_ids, v1_session, v2_session)
    import_comments(unused_post_ids, v1_session, v2_session)
    import_scores(v1_session, v2_session)
    import_snapshots(v1_session, v2_session)


if __name__ == '__main__':
    main()