forked from mttu-developers/konabot
通过了单元测试嗯
This commit is contained in:
@ -1,4 +1,3 @@
|
||||
from typing import Iterable
|
||||
import nonebot
|
||||
from nonebot.adapters import Event
|
||||
|
||||
@ -9,8 +8,6 @@ from konabot.common.permsys.migrates import execute_migration
|
||||
from konabot.common.permsys.repo import PermRepo
|
||||
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
db = DatabaseManager(DATA_PATH / "perm.sqlite3")
|
||||
|
||||
|
||||
@ -26,8 +23,12 @@ class PermManager:
|
||||
if isinstance(entities, PermEntity):
|
||||
entities = [entities]
|
||||
|
||||
key = key.removesuffix("*").removesuffix(".")
|
||||
key_split = key.split(".")
|
||||
keys = [".".join(key_split[: i + 1]) for i in range(len(key_split))][::-1]
|
||||
key_split = [s for s in key_split if len(s) > 0]
|
||||
keys = [".".join(key_split[: i + 1]) for i in range(len(key_split))][::-1] + [
|
||||
"*"
|
||||
]
|
||||
|
||||
async with self.db.get_conn() as conn:
|
||||
repo = PermRepo(conn)
|
||||
@ -44,6 +45,11 @@ class PermManager:
|
||||
return p
|
||||
return False
|
||||
|
||||
async def update_permission(self, entity: PermEntity, key: str, perm: bool | None):
|
||||
async with self.db.get_conn() as conn:
|
||||
repo = PermRepo(conn)
|
||||
await repo.update_perm_info(entity, key, perm)
|
||||
|
||||
|
||||
def perm_manager(_db: DatabaseManager | None = None) -> PermManager:
|
||||
if _db is None:
|
||||
@ -51,12 +57,14 @@ def perm_manager(_db: DatabaseManager | None = None) -> PermManager:
|
||||
return PermManager(_db)
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
async def _():
|
||||
async with db.get_conn() as conn:
|
||||
await execute_migration(conn)
|
||||
def create_startup():
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
@driver.on_startup
|
||||
async def _():
|
||||
async with db.get_conn() as conn:
|
||||
await execute_migration(conn)
|
||||
|
||||
@driver.on_shutdown
|
||||
async def _():
|
||||
await db.close_all_connections()
|
||||
@driver.on_shutdown
|
||||
async def _():
|
||||
await db.close_all_connections()
|
||||
|
||||
@ -14,7 +14,7 @@ from nonebot.adapters.minecraft.event import MessageEvent as MinecraftMessageEve
|
||||
from nonebot.adapters.console.event import MessageEvent as ConsoleEvent
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class PermEntity:
|
||||
platform: str
|
||||
entity_type: str
|
||||
|
||||
@ -29,9 +29,9 @@ class Migration:
|
||||
return self.upgrade
|
||||
|
||||
def get_downgrade_script(self) -> str:
|
||||
if isinstance(self.upgrade, Path):
|
||||
return self.upgrade.read_text()
|
||||
return self.upgrade
|
||||
if isinstance(self.downgrade, Path):
|
||||
return self.downgrade.read_text()
|
||||
return self.downgrade
|
||||
|
||||
|
||||
migrations = [
|
||||
@ -53,11 +53,11 @@ async def get_current_version(conn: aiosqlite.Connection) -> int:
|
||||
logger.info("权限系统数据表不存在,现在创建表")
|
||||
await conn.executescript(SQL_CREATE_TABLE)
|
||||
await conn.commit()
|
||||
return -1
|
||||
return 0
|
||||
cursor = await conn.execute(SQL_GET_MIGRATE_VERSION)
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
return -1
|
||||
return 0
|
||||
return row[0]
|
||||
|
||||
|
||||
|
||||
@ -25,6 +25,6 @@ ON perm_entity BEGIN
|
||||
END;
|
||||
CREATE TRIGGER perm_info_update AFTER UPDATE
|
||||
ON perm_info BEGIN
|
||||
UPDATE perm_info SET updated_at=CURRENT_TIMESTAMP WHERE id=old.id;
|
||||
UPDATE perm_info SET updated_at=CURRENT_TIMESTAMP WHERE entity_id=old.entity_id;
|
||||
END;
|
||||
|
||||
|
||||
@ -99,7 +99,7 @@ class PermRepo:
|
||||
row = await res.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return row[0]
|
||||
return bool(row[0])
|
||||
|
||||
async def update_perm_info(
|
||||
self, entity: PermEntity, config_key: str, value: bool | None
|
||||
@ -177,4 +177,4 @@ class PermRepo:
|
||||
cursor = await self.conn.execute(sql, params)
|
||||
rows = await cursor.fetchall()
|
||||
|
||||
return {(entity_ids[row[0]], row[1]): row[2] for row in rows}
|
||||
return {(entity_ids[row[0]], row[1]): bool(row[2]) for row in rows}
|
||||
|
||||
@ -25,4 +25,4 @@ async def after_nonebot_init(after_nonebot_init: None):
|
||||
|
||||
|
||||
def pytest_configure(config: pytest.Config):
|
||||
config.stash[NONEBOT_START_LIFESPAN] = False
|
||||
config.stash[NONEBOT_START_LIFESPAN] = True
|
||||
|
||||
@ -4,6 +4,8 @@ from tempfile import TemporaryDirectory
|
||||
import pytest
|
||||
|
||||
from konabot.common.database import DatabaseManager
|
||||
from konabot.common.permsys import PermManager
|
||||
from konabot.common.permsys.entity import PermEntity
|
||||
from konabot.common.permsys.migrates import execute_migration, get_current_version
|
||||
|
||||
|
||||
@ -21,7 +23,7 @@ async def test_get_db_version():
|
||||
async with tempdb() as db:
|
||||
async with db.get_conn() as conn:
|
||||
v = await get_current_version(conn)
|
||||
assert v == -1
|
||||
assert v == 0
|
||||
v = await get_current_version(conn)
|
||||
assert v == 0
|
||||
await execute_migration(conn, version=1)
|
||||
@ -30,3 +32,54 @@ async def test_get_db_version():
|
||||
await execute_migration(conn, version=0)
|
||||
v = await get_current_version(conn)
|
||||
assert v == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perm():
|
||||
async with tempdb() as db:
|
||||
async with db.get_conn() as conn:
|
||||
await execute_migration(conn)
|
||||
|
||||
service = PermManager(db)
|
||||
entity_global = PermEntity("sys", "global", "global")
|
||||
entity1 = PermEntity("nonexist-platform", "user", "passthem")
|
||||
chain1 = [entity1, entity_global]
|
||||
entity2 = PermEntity("nonexist-platform", "user", "jack")
|
||||
chain2 = [entity2, entity_global]
|
||||
|
||||
assert not await service.check_has_permission(chain1, "*")
|
||||
|
||||
await service.update_permission(entity1, "*", True)
|
||||
assert await service.check_has_permission(chain1, "*")
|
||||
assert await service.check_has_permission(chain1, "module1")
|
||||
assert await service.check_has_permission(chain1, "module1.pack1")
|
||||
assert not await service.check_has_permission(chain2, "*")
|
||||
assert not await service.check_has_permission(chain2, "module1")
|
||||
assert not await service.check_has_permission(chain2, "module1.pack1")
|
||||
|
||||
await service.update_permission(entity2, "module1", True)
|
||||
assert not await service.check_has_permission(chain2, "*")
|
||||
assert await service.check_has_permission(chain2, "module1")
|
||||
assert await service.check_has_permission(chain2, "module1.pack1")
|
||||
assert await service.check_has_permission(chain2, "module1.pack2")
|
||||
assert not await service.check_has_permission(chain2, "module2")
|
||||
assert not await service.check_has_permission(chain2, "module2.pack1")
|
||||
assert not await service.check_has_permission(chain2, "module2.pack2")
|
||||
|
||||
await service.update_permission(entity2, "module1.pack2", False)
|
||||
assert not await service.check_has_permission(chain2, "*")
|
||||
assert await service.check_has_permission(chain2, "module1")
|
||||
assert await service.check_has_permission(chain2, "module1.pack1")
|
||||
assert not await service.check_has_permission(chain2, "module1.pack2")
|
||||
assert not await service.check_has_permission(chain2, "module2")
|
||||
assert not await service.check_has_permission(chain2, "module2.pack1")
|
||||
assert not await service.check_has_permission(chain2, "module2.pack2")
|
||||
|
||||
await service.update_permission(entity_global, "module2", True)
|
||||
assert not await service.check_has_permission(chain2, "*")
|
||||
assert await service.check_has_permission(chain2, "module1")
|
||||
assert await service.check_has_permission(chain2, "module1.pack1")
|
||||
assert not await service.check_has_permission(chain2, "module1.pack2")
|
||||
assert await service.check_has_permission(chain2, "module2")
|
||||
assert await service.check_has_permission(chain2, "module2.pack1")
|
||||
assert await service.check_has_permission(chain2, "module2.pack2")
|
||||
|
||||
Reference in New Issue
Block a user