forked from mttu-developers/konabot
82 lines
2.4 KiB
Python
82 lines
2.4 KiB
Python
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import aiosqlite
|
|
from loguru import logger
|
|
|
|
from konabot.common.database import DatabaseManager
|
|
from konabot.common.path import DATA_PATH
|
|
|
|
|
|
PATH_THISFOLDER = Path(__file__).parent

# SQL used by the migration helpers in this module, read once at import time.
# The files are decoded explicitly as UTF-8 so loading does not depend on the
# platform's locale encoding (which is not UTF-8 on, e.g., Chinese-locale
# Windows, where the default would mis-decode UTF-8 files).
SQL_CHECK_EXISTS = (PATH_THISFOLDER / "./check_migrate_version_exists.sql").read_text(
    encoding="utf-8"
)
SQL_CREATE_TABLE = (PATH_THISFOLDER / "./create_migrate_version_table.sql").read_text(
    encoding="utf-8"
)
SQL_GET_MIGRATE_VERSION = (PATH_THISFOLDER / "get_migrate_version.sql").read_text(
    encoding="utf-8"
)
SQL_UPDATE_VERSION = (PATH_THISFOLDER / "./update_migrate_version.sql").read_text(
    encoding="utf-8"
)

# Database manager for the SQLite file at DATA_PATH / "perm.sqlite3".
db = DatabaseManager(DATA_PATH / "perm.sqlite3")
|
|
|
|
|
|
@dataclass
|
|
class Migration:
|
|
upgrade: str | Path
|
|
downgrade: str | Path
|
|
|
|
def get_upgrade_script(self) -> str:
|
|
if isinstance(self.upgrade, Path):
|
|
return self.upgrade.read_text()
|
|
return self.upgrade
|
|
|
|
def get_downgrade_script(self) -> str:
|
|
if isinstance(self.downgrade, Path):
|
|
return self.downgrade.read_text()
|
|
return self.downgrade
|
|
|
|
|
|
# Ordered schema migrations. Entry ``i`` upgrades the schema from version ``i``
# to ``i + 1``; its downgrade script moves it back from ``i + 1`` to ``i``.
migrations = [
    Migration(
        upgrade=PATH_THISFOLDER / "./mu1_create_permsys_table.sql",
        downgrade=PATH_THISFOLDER / "./md1_remove_permsys_table.sql",
    ),
]

# The newest schema version equals the number of known migrations.
TARGET_VERSION = len(migrations)
|
|
|
|
|
|
async def get_current_version(conn: aiosqlite.Connection) -> int:
    """Return the schema version recorded in the migration-version table.

    If ``SQL_CHECK_EXISTS`` reports a count below 1, the version table is
    created with ``SQL_CREATE_TABLE``, the change is committed, and 0 is
    returned. Otherwise the value read by ``SQL_GET_MIGRATE_VERSION`` is
    returned, or 0 when that query yields no row or a NULL value.

    Args:
        conn: Open aiosqlite connection to the database being migrated.

    Returns:
        The current migration version (0 means no migration has been applied).

    Raises:
        RuntimeError: If the existence check unexpectedly returns no row.
    """
    # Use the cursors as async context managers so they are closed promptly.
    async with conn.execute(SQL_CHECK_EXISTS) as cursor:
        count = await cursor.fetchone()
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if count is None:
        raise RuntimeError("migration version existence check returned no row")

    if count[0] < 1:
        logger.info("权限系统数据表不存在,现在创建表")
        await conn.executescript(SQL_CREATE_TABLE)
        await conn.commit()
        return 0

    async with conn.execute(SQL_GET_MIGRATE_VERSION) as cursor:
        row = await cursor.fetchone()
    # A missing row or NULL value means no version has been recorded yet.
    if row is None or row[0] is None:
        return 0
    return int(row[0])
|
|
|
|
|
|
async def execute_migration(
    conn: aiosqlite.Connection,
    version: int = TARGET_VERSION,
    migrations: list[Migration] = migrations,
):
    """Upgrade or downgrade the database schema to ``version``.

    ``migrations[i]`` moves the schema from version ``i`` to ``i + 1``, and its
    downgrade script moves it back. Starting from the version reported by
    ``get_current_version``, scripts are applied one step at a time. After each
    step the new version is written with ``SQL_UPDATE_VERSION`` and committed.

    Args:
        conn: Open aiosqlite connection to the database being migrated.
        version: Target schema version; defaults to ``TARGET_VERSION``.
        migrations: Ordered migrations to apply; defaults to the module-level
            list. It is only read here, never mutated.

    Raises:
        ValueError: If ``version`` is outside ``0..len(migrations)``.
        RuntimeError: If the stored version is outside ``0..len(migrations)``
            (e.g. the database was migrated by newer code), since the needed
            scripts are not available.
    """
    total = len(migrations)
    # Validate before touching the database. A negative target would otherwise
    # index migrations[-1] and run the last downgrade on a version-0 schema.
    # A too-large target would apply some upgrades and then crash.
    if not 0 <= version <= total:
        raise ValueError(
            f"target migration version {version} is out of range 0..{total}"
        )

    now_version = await get_current_version(conn)
    if not 0 <= now_version <= total:
        raise RuntimeError(
            f"stored migration version {now_version} is out of range 0..{total}; "
            "no scripts are available to migrate from it"
        )

    # NOTE(review): sqlite3's executescript() commits any pending transaction
    # before running. A crash between a script and its version update can
    # therefore leave the two out of sync. Confirm this is acceptable.
    while now_version < version:
        migration = migrations[now_version]
        await conn.executescript(migration.get_upgrade_script())
        now_version += 1
        await conn.execute(SQL_UPDATE_VERSION, (now_version,))
        await conn.commit()

    while now_version > version:
        migration = migrations[now_version - 1]
        await conn.executescript(migration.get_downgrade_script())
        now_version -= 1
        await conn.execute(SQL_UPDATE_VERSION, (now_version,))
        await conn.commit()
|