"""Schema migrations for the permission-system SQLite database (``perm.sqlite3``).

Each entry of ``migrations`` moves the schema up or down by exactly one
version. The current version is stored in a migrate-version table whose SQL
files live next to this module.
"""

from dataclasses import dataclass
from pathlib import Path

import aiosqlite
from loguru import logger

from konabot.common.database import DatabaseManager
from konabot.common.path import DATA_PATH

PATH_THISFOLDER = Path(__file__).parent

# SQL used to track the schema version; read once at import time.
SQL_CHECK_EXISTS = (PATH_THISFOLDER / "./check_migrate_version_exists.sql").read_text()
SQL_CREATE_TABLE = (PATH_THISFOLDER / "./create_migrate_version_table.sql").read_text()
SQL_GET_MIGRATE_VERSION = (PATH_THISFOLDER / "get_migrate_version.sql").read_text()
SQL_UPDATE_VERSION = (PATH_THISFOLDER / "./update_migrate_version.sql").read_text()

db = DatabaseManager(DATA_PATH / "perm.sqlite3")


@dataclass
class Migration:
    """One reversible schema step.

    ``upgrade`` and ``downgrade`` are each either inline SQL text or a ``Path``
    to a file containing the SQL. Files are read when the script is requested,
    not when the ``Migration`` is constructed.
    """

    upgrade: str | Path
    downgrade: str | Path

    def get_upgrade_script(self) -> str:
        """Return the SQL that moves the schema up by one version."""
        if isinstance(self.upgrade, Path):
            return self.upgrade.read_text()
        return self.upgrade

    def get_downgrade_script(self) -> str:
        """Return the SQL that moves the schema down by one version."""
        if isinstance(self.downgrade, Path):
            return self.downgrade.read_text()
        return self.downgrade


# Ordered steps: migrations[i] moves the schema from version i to i + 1.
migrations = [
    Migration(
        PATH_THISFOLDER / "./mu1_create_permsys_table.sql",
        PATH_THISFOLDER / "./md1_remove_permsys_table.sql",
    ),
]

# Latest schema version known to this code.
TARGET_VERSION = len(migrations)


async def get_current_version(conn: aiosqlite.Connection) -> int:
    """Return the schema version recorded in the database.

    If the version table does not exist yet, it is created (and committed)
    and 0 is returned. If the table exists but holds no row, 0 is returned.

    Raises:
        RuntimeError: if the existence-check query returns no row at all.
    """
    async with conn.execute(SQL_CHECK_EXISTS) as cursor:
        count = await cursor.fetchone()
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if count is None:
        raise RuntimeError("migrate-version existence check returned no row")

    if count[0] < 1:
        logger.info("权限系统数据表不存在,现在创建表")
        await conn.executescript(SQL_CREATE_TABLE)
        await conn.commit()
        return 0

    async with conn.execute(SQL_GET_MIGRATE_VERSION) as cursor:
        row = await cursor.fetchone()
    if row is None:
        return 0
    return row[0]


async def execute_migration(
    conn: aiosqlite.Connection,
    version: int = TARGET_VERSION,
    migrations: list[Migration] = migrations,
):
    """Bring the schema of ``conn`` to ``version``, upgrading or downgrading.

    Each step runs one migration script, records the new version with
    ``SQL_UPDATE_VERSION`` and commits, so an interrupted run leaves the
    database at the last completed step.

    Args:
        conn: Open connection to the permission database.
        version: Target schema version, from 0 to ``len(migrations)``
            inclusive. Defaults to ``TARGET_VERSION``.
        migrations: Ordered steps; ``migrations[i]`` moves the schema between
            versions ``i`` and ``i + 1``.

    Raises:
        ValueError: if ``version`` is outside ``0..len(migrations)``.
        RuntimeError: if the stored version is outside ``0..len(migrations)``,
            e.g. the database was migrated by code that knows more steps.
    """
    step_count = len(migrations)

    # Validate before changing anything. A negative target would otherwise
    # index migrations[-1] and run the *last* downgrade. A too-large target
    # would apply and commit every step, then fail with IndexError.
    if not 0 <= version <= step_count:
        raise ValueError(
            f"target version {version} out of range 0..{step_count}"
        )

    now_version = await get_current_version(conn)
    if not 0 <= now_version <= step_count:
        raise RuntimeError(
            f"database schema version {now_version} is outside the known "
            f"range 0..{step_count}; refusing to migrate"
        )

    # NOTE(review): executescript() commits any pending transaction before
    # running the script (sqlite3 semantics). So a migration script and the
    # following version update are not one atomic unit. A crash between them
    # leaves the schema changed but the old version recorded.
    while now_version < version:
        migration = migrations[now_version]
        await conn.executescript(migration.get_upgrade_script())
        now_version += 1
        await conn.execute(SQL_UPDATE_VERSION, (now_version,))
        await conn.commit()

    while now_version > version:
        # Going down from version v runs the step that originally produced v.
        migration = migrations[now_version - 1]
        await conn.executescript(migration.get_downgrade_script())
        now_version -= 1
        await conn.execute(SQL_UPDATE_VERSION, (now_version,))
        await conn.commit()