通过了单元测试嗯

This commit is contained in:
2026-03-07 15:53:13 +08:00
parent 7f1035ff43
commit ca1db103b5
7 changed files with 83 additions and 22 deletions

View File

@ -1,4 +1,3 @@
from typing import Iterable
import nonebot import nonebot
from nonebot.adapters import Event from nonebot.adapters import Event
@ -9,8 +8,6 @@ from konabot.common.permsys.migrates import execute_migration
from konabot.common.permsys.repo import PermRepo from konabot.common.permsys.repo import PermRepo
driver = nonebot.get_driver()
db = DatabaseManager(DATA_PATH / "perm.sqlite3") db = DatabaseManager(DATA_PATH / "perm.sqlite3")
@ -26,8 +23,12 @@ class PermManager:
if isinstance(entities, PermEntity): if isinstance(entities, PermEntity):
entities = [entities] entities = [entities]
key = key.removesuffix("*").removesuffix(".")
key_split = key.split(".") key_split = key.split(".")
keys = [".".join(key_split[: i + 1]) for i in range(len(key_split))][::-1] key_split = [s for s in key_split if len(s) > 0]
keys = [".".join(key_split[: i + 1]) for i in range(len(key_split))][::-1] + [
"*"
]
async with self.db.get_conn() as conn: async with self.db.get_conn() as conn:
repo = PermRepo(conn) repo = PermRepo(conn)
@ -44,6 +45,11 @@ class PermManager:
return p return p
return False return False
async def update_permission(self, entity: PermEntity, key: str, perm: bool | None):
async with self.db.get_conn() as conn:
repo = PermRepo(conn)
await repo.update_perm_info(entity, key, perm)
def perm_manager(_db: DatabaseManager | None = None) -> PermManager: def perm_manager(_db: DatabaseManager | None = None) -> PermManager:
if _db is None: if _db is None:
@ -51,12 +57,14 @@ def perm_manager(_db: DatabaseManager | None = None) -> PermManager:
return PermManager(_db) return PermManager(_db)
@driver.on_startup def create_startup():
async def _(): driver = nonebot.get_driver()
async with db.get_conn() as conn:
await execute_migration(conn)
@driver.on_startup
async def _():
async with db.get_conn() as conn:
await execute_migration(conn)
@driver.on_shutdown @driver.on_shutdown
async def _(): async def _():
await db.close_all_connections() await db.close_all_connections()

View File

@ -14,7 +14,7 @@ from nonebot.adapters.minecraft.event import MessageEvent as MinecraftMessageEve
from nonebot.adapters.console.event import MessageEvent as ConsoleEvent from nonebot.adapters.console.event import MessageEvent as ConsoleEvent
@dataclass @dataclass(frozen=True)
class PermEntity: class PermEntity:
platform: str platform: str
entity_type: str entity_type: str

View File

@ -29,9 +29,9 @@ class Migration:
return self.upgrade return self.upgrade
def get_downgrade_script(self) -> str: def get_downgrade_script(self) -> str:
if isinstance(self.upgrade, Path): if isinstance(self.downgrade, Path):
return self.upgrade.read_text() return self.downgrade.read_text()
return self.upgrade return self.downgrade
migrations = [ migrations = [
@ -53,11 +53,11 @@ async def get_current_version(conn: aiosqlite.Connection) -> int:
logger.info("权限系统数据表不存在,现在创建表") logger.info("权限系统数据表不存在,现在创建表")
await conn.executescript(SQL_CREATE_TABLE) await conn.executescript(SQL_CREATE_TABLE)
await conn.commit() await conn.commit()
return -1 return 0
cursor = await conn.execute(SQL_GET_MIGRATE_VERSION) cursor = await conn.execute(SQL_GET_MIGRATE_VERSION)
row = await cursor.fetchone() row = await cursor.fetchone()
if row is None: if row is None:
return -1 return 0
return row[0] return row[0]

View File

@ -25,6 +25,6 @@ ON perm_entity BEGIN
END; END;
CREATE TRIGGER perm_info_update AFTER UPDATE CREATE TRIGGER perm_info_update AFTER UPDATE
ON perm_info BEGIN ON perm_info BEGIN
UPDATE perm_info SET updated_at=CURRENT_TIMESTAMP WHERE id=old.id; UPDATE perm_info SET updated_at=CURRENT_TIMESTAMP WHERE entity_id=old.entity_id;
END; END;

View File

@ -99,7 +99,7 @@ class PermRepo:
row = await res.fetchone() row = await res.fetchone()
if row is None: if row is None:
return None return None
return row[0] return bool(row[0])
async def update_perm_info( async def update_perm_info(
self, entity: PermEntity, config_key: str, value: bool | None self, entity: PermEntity, config_key: str, value: bool | None
@ -177,4 +177,4 @@ class PermRepo:
cursor = await self.conn.execute(sql, params) cursor = await self.conn.execute(sql, params)
rows = await cursor.fetchall() rows = await cursor.fetchall()
return {(entity_ids[row[0]], row[1]): row[2] for row in rows} return {(entity_ids[row[0]], row[1]): bool(row[2]) for row in rows}

View File

@ -25,4 +25,4 @@ async def after_nonebot_init(after_nonebot_init: None):
def pytest_configure(config: pytest.Config): def pytest_configure(config: pytest.Config):
config.stash[NONEBOT_START_LIFESPAN] = False config.stash[NONEBOT_START_LIFESPAN] = True

View File

@ -4,6 +4,8 @@ from tempfile import TemporaryDirectory
import pytest import pytest
from konabot.common.database import DatabaseManager from konabot.common.database import DatabaseManager
from konabot.common.permsys import PermManager
from konabot.common.permsys.entity import PermEntity
from konabot.common.permsys.migrates import execute_migration, get_current_version from konabot.common.permsys.migrates import execute_migration, get_current_version
@ -21,7 +23,7 @@ async def test_get_db_version():
async with tempdb() as db: async with tempdb() as db:
async with db.get_conn() as conn: async with db.get_conn() as conn:
v = await get_current_version(conn) v = await get_current_version(conn)
assert v == -1 assert v == 0
v = await get_current_version(conn) v = await get_current_version(conn)
assert v == 0 assert v == 0
await execute_migration(conn, version=1) await execute_migration(conn, version=1)
@ -30,3 +32,54 @@ async def test_get_db_version():
await execute_migration(conn, version=0) await execute_migration(conn, version=0)
v = await get_current_version(conn) v = await get_current_version(conn)
assert v == 0 assert v == 0
@pytest.mark.asyncio
async def test_perm():
async with tempdb() as db:
async with db.get_conn() as conn:
await execute_migration(conn)
service = PermManager(db)
entity_global = PermEntity("sys", "global", "global")
entity1 = PermEntity("nonexist-platform", "user", "passthem")
chain1 = [entity1, entity_global]
entity2 = PermEntity("nonexist-platform", "user", "jack")
chain2 = [entity2, entity_global]
assert not await service.check_has_permission(chain1, "*")
await service.update_permission(entity1, "*", True)
assert await service.check_has_permission(chain1, "*")
assert await service.check_has_permission(chain1, "module1")
assert await service.check_has_permission(chain1, "module1.pack1")
assert not await service.check_has_permission(chain2, "*")
assert not await service.check_has_permission(chain2, "module1")
assert not await service.check_has_permission(chain2, "module1.pack1")
await service.update_permission(entity2, "module1", True)
assert not await service.check_has_permission(chain2, "*")
assert await service.check_has_permission(chain2, "module1")
assert await service.check_has_permission(chain2, "module1.pack1")
assert await service.check_has_permission(chain2, "module1.pack2")
assert not await service.check_has_permission(chain2, "module2")
assert not await service.check_has_permission(chain2, "module2.pack1")
assert not await service.check_has_permission(chain2, "module2.pack2")
await service.update_permission(entity2, "module1.pack2", False)
assert not await service.check_has_permission(chain2, "*")
assert await service.check_has_permission(chain2, "module1")
assert await service.check_has_permission(chain2, "module1.pack1")
assert not await service.check_has_permission(chain2, "module1.pack2")
assert not await service.check_has_permission(chain2, "module2")
assert not await service.check_has_permission(chain2, "module2.pack1")
assert not await service.check_has_permission(chain2, "module2.pack2")
await service.update_permission(entity_global, "module2", True)
assert not await service.check_has_permission(chain2, "*")
assert await service.check_has_permission(chain2, "module1")
assert await service.check_has_permission(chain2, "module1.pack1")
assert not await service.check_has_permission(chain2, "module1.pack2")
assert await service.check_has_permission(chain2, "module2")
assert await service.check_has_permission(chain2, "module2.pack1")
assert await service.check_has_permission(chain2, "module2.pack2")