Compare commits

...

7 Commits

26 changed files with 2156 additions and 1320 deletions

View File

@ -13,7 +13,7 @@ steps:
- name: submodules
image: alpine/git
commands:
- git submodule update --init --recursive
- git submodule update --init --recursive
- name: 构建 Docker 镜像
image: plugins/docker:latest
privileged: true
@ -76,7 +76,7 @@ steps:
- name: submodules
image: alpine/git
commands:
- git submodule update --init --recursive
- git submodule update --init --recursive
- name: 构建并推送 Release Docker 镜像
image: plugins/docker:latest
privileged: true

5
.gitignore vendored
View File

@ -9,3 +9,8 @@ __pycache__
# 可能会偶然生成的 diff 文件
/*.diff
# 代码覆盖报告
/.coverage
/.coverage.db
/htmlcov

6
.sqls.yml Normal file
View File

@ -0,0 +1,6 @@
lowercaseKeywords: false
connections:
- driver: sqlite
dataSourceName: "./data/database.db"
- driver: sqlite
dataSourceName: "./data/perm.sqlite3"

View File

@ -96,6 +96,21 @@ poetry run python bot.py
- [事件处理](https://nonebot.dev/docs/tutorial/handler)
- [Alconna 插件](https://nonebot.dev/docs/best-practice/alconna/)
## 数据库模块
## 代码测试
本项目的数据库模块已更新为异步实现,使用连接池来提高性能,并支持现代的`pathlib.Path`参数类型。详细使用方法请参考[数据库使用文档](/docs/database.md)
本项目使用 pytest 进行自动化测试,你可以把你的测试代码放在 `./tests` 目录下
使用命令行执行测试:
```bash
poetry run just test
```
使用命令行,在浏览器查看测试覆盖率报告:
```bash
poetry run just coverage
# 此时会打开一个 :8000 端口的 Web 服务器
# 你可以在 http://localhost:8000 查看覆盖率报告
# 在控制台使用 Ctrl+C 关闭这个 Web 服务器
```

24
bot.py
View File

@ -7,6 +7,7 @@ from nonebot.adapters.discord import Adapter as DiscordAdapter
from nonebot.adapters.minecraft import Adapter as MinecraftAdapter
from nonebot.adapters.onebot.v11 import Adapter as OnebotAdapter
from konabot.common import permsys
from konabot.common.log import init_logger
from konabot.common.nb.exc import BotExceptionMessage
from konabot.common.path import LOG_PATH
@ -22,19 +23,25 @@ env_enable_minecraft = os.environ.get("ENABLE_MINECRAFT", "none")
def main():
if env.upper() == 'DEBUG' or env.upper() == 'DEV':
console_log_level = 'DEBUG'
if env.upper() == "DEBUG" or env.upper() == "DEV":
console_log_level = "DEBUG"
else:
console_log_level = 'INFO'
init_logger(LOG_PATH, [
BotExceptionMessage,
], console_log_level=console_log_level)
console_log_level = "INFO"
init_logger(
LOG_PATH,
[
BotExceptionMessage,
],
console_log_level=console_log_level,
)
nonebot.init()
driver = nonebot.get_driver()
if (env != "prod" and env != "test" and env_enable_console.upper() != "FALSE") or (env_enable_console.upper() == "TRUE"):
if (env != "prod" and env != "test" and env_enable_console.upper() != "FALSE") or (
env_enable_console.upper() == "TRUE"
):
driver.register_adapter(ConsoleAdapter)
if env_enable_qq.upper() == "TRUE":
@ -50,6 +57,8 @@ def main():
nonebot.load_plugins("konabot/plugins")
nonebot.load_plugin("nonebot_plugin_analysis_bilibili")
permsys.create_startup()
# 注册关闭钩子
@driver.on_shutdown
async def shutdown_handler():
@ -59,5 +68,6 @@ def main():
nonebot.run()
if __name__ == "__main__":
main()

View File

@ -1,4 +1,9 @@
watch:
poetry run watchfiles bot.main . --filter scripts.watch_filter.filter
test:
poetry run pytest --cov-report term-missing:skip-covered
coverage:
poetry run pytest --cov-report html
python -m http.server -d htmlcov

View File

@ -1,3 +1,4 @@
from contextlib import asynccontextmanager
import os
import asyncio
import sqlparse
@ -10,16 +11,19 @@ if TYPE_CHECKING:
from . import DatabaseManager
# 全局数据库管理器实例
_global_db_manager: Optional['DatabaseManager'] = None
_global_db_manager: Optional["DatabaseManager"] = None
def get_global_db_manager() -> 'DatabaseManager':
def get_global_db_manager() -> "DatabaseManager":
"""获取全局数据库管理器实例"""
global _global_db_manager
if _global_db_manager is None:
from . import DatabaseManager
_global_db_manager = DatabaseManager()
return _global_db_manager
def close_global_db_manager() -> None:
"""关闭全局数据库管理器实例"""
global _global_db_manager
@ -87,6 +91,12 @@ class DatabaseManager:
except:
pass
@asynccontextmanager
async def get_conn(self):
conn = await self._get_connection()
yield conn
await self._return_connection(conn)
async def query(
self, query: str, params: Optional[tuple] = None
) -> List[Dict[str, Any]]:
@ -143,22 +153,24 @@ class DatabaseManager:
# 使用sqlparse库更准确地分割SQL语句
parsed = sqlparse.split(script)
statements = []
for statement in parsed:
statement = statement.strip()
if statement:
statements.append(statement)
return statements
async def execute_by_sql_file(
self, file_path: Union[str, Path], params: Optional[Union[tuple, List[tuple]]] = None
self,
file_path: Union[str, Path],
params: Optional[Union[tuple, List[tuple]]] = None,
) -> None:
"""从 SQL 文件中读取非查询语句并执行"""
path = str(file_path) if isinstance(file_path, Path) else file_path
with open(path, "r", encoding="utf-8") as f:
script = f.read()
# 如果有参数且是元组使用execute执行整个脚本
if params is not None and isinstance(params, tuple):
await self.execute(script, params)
@ -167,8 +179,10 @@ class DatabaseManager:
# 使用sqlparse准确分割SQL语句
statements = self._parse_sql_statements(script)
if len(statements) != len(params):
raise ValueError(f"语句数量({len(statements)})与参数组数量({len(params)})不匹配")
raise ValueError(
f"语句数量({len(statements)})与参数组数量({len(params)})不匹配"
)
for statement, stmt_params in zip(statements, params):
if statement:
await self.execute(statement, stmt_params)
@ -215,4 +229,3 @@ class DatabaseManager:
except:
pass
self._in_use.clear()

View File

@ -0,0 +1,70 @@
import nonebot
from nonebot.adapters import Event
from konabot.common.database import DatabaseManager
from konabot.common.path import DATA_PATH
from konabot.common.permsys.entity import PermEntity, get_entity_chain
from konabot.common.permsys.migrates import execute_migration
from konabot.common.permsys.repo import PermRepo
db = DatabaseManager(DATA_PATH / "perm.sqlite3")
class PermManager:
def __init__(self, db: DatabaseManager) -> None:
self.db = db
async def check_has_permission(
self, entities: Event | PermEntity | list[PermEntity], key: str
) -> bool:
if isinstance(entities, Event):
entities = await get_entity_chain(entities) # pragma: no cover
if isinstance(entities, PermEntity):
entities = [entities]
key = key.removesuffix("*").removesuffix(".")
key_split = key.split(".")
key_split = [s for s in key_split if len(s) > 0]
keys = [".".join(key_split[: i + 1]) for i in range(len(key_split))][::-1] + [
"*"
]
async with self.db.get_conn() as conn:
repo = PermRepo(conn)
# for entity in entities:
# for k in keys:
# perm = await repo.get_perm_info(entity, k)
# if perm is not None:
# return perm
data = await repo.get_perm_info_batch(entities, keys)
for entity in entities:
for k in keys:
p = data.get((entity, k))
if p is not None:
return p
return False
async def update_permission(self, entity: PermEntity, key: str, perm: bool | None):
async with self.db.get_conn() as conn:
repo = PermRepo(conn)
await repo.update_perm_info(entity, key, perm)
def perm_manager(_db: DatabaseManager | None = None) -> PermManager: # pragma: no cover
if _db is None:
_db = db
return PermManager(_db)
def create_startup(): # pragma: no cover
driver = nonebot.get_driver()
@driver.on_startup
async def _():
async with db.get_conn() as conn:
await execute_migration(conn)
@driver.on_shutdown
async def _():
await db.close_all_connections()

View File

@ -0,0 +1,61 @@
from dataclasses import dataclass
from nonebot.internal.adapter import Event
from nonebot.adapters.onebot.v11 import Event as OB11Event
from nonebot.adapters.onebot.v11.event import GroupMessageEvent as OB11GroupEvent
from nonebot.adapters.onebot.v11.event import PrivateMessageEvent as OB11PrivateEvent
from nonebot.adapters.discord.event import Event as DiscordEvent
from nonebot.adapters.discord.event import GuildMessageCreateEvent as DiscordGMEvent
from nonebot.adapters.discord.event import DirectMessageCreateEvent as DiscordDMEvent
from nonebot.adapters.minecraft.event import MessageEvent as MinecraftMessageEvent
from nonebot.adapters.console.event import MessageEvent as ConsoleEvent
@dataclass(frozen=True)
class PermEntity:
platform: str
entity_type: str
external_id: str
async def get_entity_chain(event: Event) -> list[PermEntity]: # pragma: no cover
entities = [PermEntity("sys", "global", "global")]
if isinstance(event, OB11Event):
entities.append(PermEntity("ob11", "global", "global"))
if isinstance(event, OB11GroupEvent):
entities.append(PermEntity("ob11", "group", str(event.group_id)))
entities.append(PermEntity("ob11", "user", str(event.user_id)))
if isinstance(event, OB11PrivateEvent):
entities.append(PermEntity("ob11", "user", str(event.user_id)))
if isinstance(event, DiscordEvent):
entities.append(PermEntity("discord", "global", "global"))
if isinstance(event, DiscordGMEvent):
entities.append(PermEntity("discord", "guilt", str(event.guild_id)))
entities.append(PermEntity("discord", "channel", str(event.channel_id)))
entities.append(PermEntity("discord", "user", str(event.user_id)))
if isinstance(event, DiscordDMEvent):
entities.append(PermEntity("discord", "channel", str(event.channel_id)))
entities.append(PermEntity("discord", "user", str(event.user_id)))
if isinstance(event, MinecraftMessageEvent):
entities.append(PermEntity("minecraft", "global", "global"))
entities.append(PermEntity("minecraft", "server", event.server_name))
player_uuid = event.player.uuid
if player_uuid is not None:
entities.append(PermEntity("minecraft", "player", player_uuid.hex))
if isinstance(event, ConsoleEvent):
entities.append(PermEntity("console", "global", "global"))
entities.append(PermEntity("console", "channel", event.channel.id))
entities.append(PermEntity("console", "user", event.user.id))
return entities[::-1]

View File

@ -0,0 +1,81 @@
from dataclasses import dataclass
from pathlib import Path
import aiosqlite
from loguru import logger
from konabot.common.database import DatabaseManager
from konabot.common.path import DATA_PATH
PATH_THISFOLDER = Path(__file__).parent
SQL_CHECK_EXISTS = (PATH_THISFOLDER / "./check_migrate_version_exists.sql").read_text()
SQL_CREATE_TABLE = (PATH_THISFOLDER / "./create_migrate_version_table.sql").read_text()
SQL_GET_MIGRATE_VERSION = (PATH_THISFOLDER / "get_migrate_version.sql").read_text()
SQL_UPDATE_VERSION = (PATH_THISFOLDER / "./update_migrate_version.sql").read_text()
db = DatabaseManager(DATA_PATH / "perm.sqlite3")
@dataclass
class Migration:
upgrade: str | Path
downgrade: str | Path
def get_upgrade_script(self) -> str:
if isinstance(self.upgrade, Path):
return self.upgrade.read_text()
return self.upgrade
def get_downgrade_script(self) -> str:
if isinstance(self.downgrade, Path):
return self.downgrade.read_text()
return self.downgrade
migrations = [
Migration(
PATH_THISFOLDER / "./mu1_create_permsys_table.sql",
PATH_THISFOLDER / "./md1_remove_permsys_table.sql",
),
]
TARGET_VERSION = len(migrations)
async def get_current_version(conn: aiosqlite.Connection) -> int:
cursor = await conn.execute(SQL_CHECK_EXISTS)
count = await cursor.fetchone()
assert count is not None
if count[0] < 1:
logger.info("权限系统数据表不存在,现在创建表")
await conn.executescript(SQL_CREATE_TABLE)
await conn.commit()
return 0
cursor = await conn.execute(SQL_GET_MIGRATE_VERSION)
row = await cursor.fetchone()
if row is None:
return 0
return row[0]
async def execute_migration(
conn: aiosqlite.Connection,
version: int = TARGET_VERSION,
migrations: list[Migration] = migrations,
):
now_version = await get_current_version(conn)
while now_version < version:
migration = migrations[now_version]
await conn.executescript(migration.get_upgrade_script())
now_version += 1
await conn.execute(SQL_UPDATE_VERSION, (now_version,))
await conn.commit()
while now_version > version:
migration = migrations[now_version - 1]
await conn.executescript(migration.get_downgrade_script())
now_version -= 1
await conn.execute(SQL_UPDATE_VERSION, (now_version,))
await conn.commit()

View File

@ -0,0 +1,7 @@
SELECT
COUNT(*)
FROM
sqlite_master
WHERE
type = 'table'
AND name = 'migrate_version'

View File

@ -0,0 +1,3 @@
CREATE TABLE migrate_version(version INT PRIMARY KEY);
INSERT INTO migrate_version(version)
VALUES(0);

View File

@ -0,0 +1,4 @@
SELECT
version
FROM
migrate_version;

View File

@ -0,0 +1,2 @@
DROP TABLE IF EXISTS perm_entity;
DROP TABLE IF EXISTS perm_info;

View File

@ -0,0 +1,30 @@
CREATE TABLE perm_entity(
id INTEGER PRIMARY KEY AUTOINCREMENT,
platform TEXT NOT NULL,
entity_type TEXT NOT NULL,
external_id TEXT NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE UNIQUE INDEX idx_perm_entity_lookup
ON perm_entity(platform, entity_type, external_id);
CREATE TABLE perm_info(
entity_id INTEGER NOT NULL,
config_key TEXT NOT NULL,
value BOOLEAN,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
-- 联合主键
PRIMARY KEY (entity_id, config_key)
);
CREATE TRIGGER perm_entity_update AFTER UPDATE
ON perm_entity BEGIN
UPDATE perm_entity SET updated_at=CURRENT_TIMESTAMP WHERE id=old.id;
END;
CREATE TRIGGER perm_info_update AFTER UPDATE
ON perm_info BEGIN
UPDATE perm_info SET updated_at=CURRENT_TIMESTAMP WHERE entity_id=old.entity_id;
END;

View File

@ -0,0 +1,2 @@
UPDATE migrate_version
SET version = ?;

View File

@ -0,0 +1,180 @@
from dataclasses import dataclass
from pathlib import Path
import aiosqlite
from .entity import PermEntity
def s(p: str):
"""读取 SQL 文件内容。
Args:
p: SQL 文件名(相对于当前文件所在目录的 sql/ 子目录)。
Returns:
SQL 文件的内容字符串。
"""
return (Path(__file__).parent / "./sql/" / p).read_text()
@dataclass
class PermRepo:
"""权限实体存储库,负责与数据库交互管理权限实体。
Attributes:
conn: aiosqlite 数据库连接对象。
"""
conn: aiosqlite.Connection
async def create_entity(self, entity: PermEntity) -> int:
"""创建新的权限实体并返回其 ID。
Args:
entity: 要创建的权限实体对象。
Returns:
新创建实体的数据库 ID。
Raises:
AssertionError: 如果创建后无法获取实体 ID。
"""
await self.conn.execute(
s("create_entity.sql"),
(entity.platform, entity.entity_type, entity.external_id),
)
await self.conn.commit()
eid = await self._get_entity_id_or_none(entity)
assert eid is not None
return eid
async def _get_entity_id_or_none(self, entity: PermEntity) -> int | None:
"""查询实体 ID如果不存在则返回 None。
Args:
entity: 要查询的权限实体对象。
Returns:
实体 ID如果不存在则返回 None。
"""
res = await self.conn.execute(
s("get_entity_id.sql"),
(entity.platform, entity.entity_type, entity.external_id),
)
row = await res.fetchone()
if row is None:
return None
return row[0]
async def get_entity_id(self, entity: PermEntity) -> int:
"""获取实体 ID如果不存在则自动创建。
Args:
entity: 权限实体对象。
Returns:
实体的数据库 ID。
"""
eid = await self._get_entity_id_or_none(entity)
if eid is None:
return await self.create_entity(entity)
return eid
async def get_perm_info(self, entity: PermEntity, config_key: str) -> bool | None:
"""获取实体的权限配置信息。
Args:
entity: 权限实体对象。
config_key: 配置项的键名。
Returns:
配置值True/False如果不存在则返回 None。
"""
eid = await self.get_entity_id(entity)
res = await self.conn.execute(
s("get_perm_info.sql"),
(eid, config_key),
)
row = await res.fetchone()
if row is None:
return None
return bool(row[0])
async def update_perm_info(
self, entity: PermEntity, config_key: str, value: bool | None
):
"""更新实体的权限配置信息。
Args:
entity: 权限实体对象。
config_key: 配置项的键名。
value: 要设置的配置值True/False/None
"""
eid = await self.get_entity_id(entity)
await self.conn.execute(s("update_perm_info.sql"), (eid, config_key, value))
await self.conn.commit()
async def get_entity_id_batch(
self, entities: list[PermEntity]
) -> dict[PermEntity, int]:
"""批量获取 Entity 的 eneity_id
Args:
entities: PermEntity 列表
Returns:
字典,键为 PermEntity值为对应的 ID
"""
for entity in entities:
await self.conn.execute(
s("create_entity.sql"),
(entity.platform, entity.entity_type, entity.external_id),
)
await self.conn.commit()
val_placeholders = ", ".join(["(?, ?, ?)"] * len(entities))
params = []
for e in entities:
params.extend([e.platform, e.entity_type, e.external_id])
cursor = await self.conn.execute(
f"""
SELECT id, platform, entity_type, external_id
FROM perm_entity
WHERE (platform, entity_type, external_id) IN (VALUES {val_placeholders});
""",
params,
)
rows = await cursor.fetchall()
return {PermEntity(row[1], row[2], row[3]): row[0] for row in rows}
async def get_perm_info_batch(
self, entities: list[PermEntity], config_keys: list[str]
) -> dict[tuple[PermEntity, str], bool]:
"""批量获取权限信息
Args:
entities: PermEntity 列表
config_keys: 查询的键列表
Returns:
字典,键是 PermEntity 和 config_key 的元组,值是布尔,过滤掉所有空值
"""
entity_ids = {
v: k for k, v in (await self.get_entity_id_batch(entities)).items()
}
placeholders1 = ", ".join("?" * len(entity_ids))
placeholders2 = ", ".join("?" * len(config_keys))
sql = f"""
SELECT entity_id, config_key, value
FROM perm_info
WHERE entity_id IN ({placeholders1})
AND config_key IN ({placeholders2})
AND value IS NOT NULL;
"""
params = tuple(entity_ids.keys()) + tuple(config_keys)
cursor = await self.conn.execute(sql, params)
rows = await cursor.fetchall()
return {(entity_ids[row[0]], row[1]): bool(row[2]) for row in rows}

View File

@ -0,0 +1,11 @@
INSERT
OR IGNORE INTO perm_entity(
platform,
entity_type,
external_id
)
VALUES(
?,
?,
?
);

View File

@ -0,0 +1,8 @@
SELECT
id
FROM
perm_entity
WHERE
perm_entity.platform = ?
AND perm_entity.entity_type = ?
AND perm_entity.external_id = ?;

View File

@ -0,0 +1,7 @@
SELECT
VALUE
FROM
perm_info
WHERE
entity_id = ?
AND config_key = ?;

View File

@ -0,0 +1,4 @@
INSERT INTO perm_info (entity_id, config_key, value)
VALUES (?, ?, ?)
ON CONFLICT(entity_id, config_key)
DO UPDATE SET value=excluded.value;

2723
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -34,6 +34,9 @@ dependencies = [
"shapely (>=2.1.2,<3.0.0)",
"mcstatus (>=12.2.1,<13.0.0)",
"borax (>=4.1.3,<5.0.0)",
"pytest (>=8.0.0,<9.0.0)",
"nonebug (>=0.4.3,<0.5.0)",
"pytest-cov (>=7.0.0,<8.0.0)",
]
[tool.poetry]
@ -52,8 +55,15 @@ priority = "primary"
[dependency-groups]
dev = [
"rust-just (>=1.43.0,<2.0.0)",
"pytest (>=9.0.1,<10.0.0)",
"pytest-asyncio (>=1.3.0,<2.0.0)"
]
dev = ["rust-just (>=1.43.0,<2.0.0)", "pytest-asyncio (>=1.3.0,<2.0.0)"]
[tool.pytest.ini_options]
testpaths = "tests"
python_files = "test_*.py"
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "session"
addopts = "--cov=./konabot/"
[tool.nonebot]
# plugin_dirs = ["konabot/plugins/"]
plugin_dirs = []

28
tests/conftest.py Normal file
View File

@ -0,0 +1,28 @@
# 文件内容来源:
# https://nonebot.dev/docs/best-practice/testing/
# 保证 nonebug 测试框架正常运作
import pytest
import nonebot
from pytest_asyncio import is_async_test
from nonebot.adapters.console import Adapter as ConsoleAdapter
from nonebug import NONEBOT_START_LIFESPAN
def pytest_collection_modifyitems(items: list[pytest.Item]):
pytest_asyncio_tests = (item for item in items if is_async_test(item))
session_scope_marker = pytest.mark.asyncio(loop_scope="session")
for async_test in pytest_asyncio_tests:
async_test.add_marker(session_scope_marker, append=False)
@pytest.fixture(scope="session", autouse=True)
async def after_nonebot_init(after_nonebot_init: None):
driver = nonebot.get_driver()
driver.register_adapter(ConsoleAdapter)
nonebot.load_from_toml("pyproject.toml")
def pytest_configure(config: pytest.Config):
config.stash[NONEBOT_START_LIFESPAN] = True

View File

@ -1,4 +1,3 @@
import asyncio
import os
import tempfile
from pathlib import Path
@ -12,13 +11,13 @@ from konabot.common.database import DatabaseManager
async def test_database_manager():
"""测试数据库管理器的基本功能"""
# 创建临时数据库文件
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_file:
db_path = tmp_file.name
try:
# 初始化数据库管理器
db_manager = DatabaseManager(db_path)
# 创建测试表
create_table_sql = """
CREATE TABLE IF NOT EXISTS test_users (
@ -28,26 +27,27 @@ async def test_database_manager():
);
"""
await db_manager.execute(create_table_sql)
# 插入测试数据
insert_sql = "INSERT INTO test_users (name, email) VALUES (?, ?)"
await db_manager.execute(insert_sql, ("张三", "zhangsan@example.com"))
await db_manager.execute(insert_sql, ("李四", "lisi@example.com"))
# 查询数据
select_sql = "SELECT * FROM test_users WHERE name = ?"
results = await db_manager.query(select_sql, ("张三",))
assert len(results) == 1
assert results[0]["name"] == "张三"
assert results[0]["email"] == "zhangsan@example.com"
# 测试使用Path对象
results = await db_manager.query_by_sql_file(Path(__file__), ("李四",))
# results = await db_manager.query_by_sql_file(Path(__file__), ("李四",))
# 注意这里只是测试参数传递实际SQL文件内容不是有效的SQL
## ^^^ 卧了个槽的坏枪,你让 AI 写单元测试不检查一下吗
# 关闭所有连接
await db_manager.close_all_connections()
finally:
# 清理临时文件
if os.path.exists(db_path):
@ -58,13 +58,13 @@ async def test_database_manager():
async def test_execute_script():
"""测试执行SQL脚本功能"""
# 创建临时数据库文件
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_file:
db_path = tmp_file.name
try:
# 初始化数据库管理器
db_manager = DatabaseManager(db_path)
# 创建测试表的脚本
script = """
CREATE TABLE IF NOT EXISTS test_products (
@ -75,19 +75,19 @@ async def test_execute_script():
INSERT INTO test_products (name, price) VALUES ('苹果', 5.0);
INSERT INTO test_products (name, price) VALUES ('香蕉', 3.0);
"""
await db_manager.execute_script(script)
# 查询数据
results = await db_manager.query("SELECT * FROM test_products ORDER BY name")
assert len(results) == 2
assert results[0]["name"] == "苹果"
assert results[1]["name"] == "香蕉"
# 关闭所有连接
await db_manager.close_all_connections()
finally:
# 清理临时文件
if os.path.exists(db_path):
os.unlink(db_path)
os.unlink(db_path)

105
tests/test_permsys.py Normal file
View File

@ -0,0 +1,105 @@
from contextlib import asynccontextmanager
from pathlib import Path
from tempfile import TemporaryDirectory
import pytest
from konabot.common.database import DatabaseManager
from konabot.common.permsys import PermManager
from konabot.common.permsys.entity import PermEntity
from konabot.common.permsys.migrates import execute_migration, get_current_version
from konabot.common.permsys.repo import PermRepo
@asynccontextmanager
async def tempdb():
with TemporaryDirectory() as _tempdir:
tempdir = Path(_tempdir)
db = DatabaseManager(tempdir / "perm.sqlite3")
yield db
await db.close_all_connections()
@pytest.mark.asyncio
async def test_get_db_version():
async with tempdb() as db:
async with db.get_conn() as conn:
v = await get_current_version(conn)
assert v == 0
v = await get_current_version(conn)
assert v == 0
await execute_migration(conn, version=1)
v = await get_current_version(conn)
assert v == 1
await execute_migration(conn, version=0)
v = await get_current_version(conn)
assert v == 0
@pytest.mark.asyncio
async def test_perm():
async with tempdb() as db:
async with db.get_conn() as conn:
await execute_migration(conn)
service = PermManager(db)
entity_global = PermEntity("sys", "global", "global")
entity1 = PermEntity("nonexist-platform", "user", "passthem")
chain1 = [entity1, entity_global]
entity2 = PermEntity("nonexist-platform", "user", "jack")
chain2 = [entity2, entity_global]
async with db.get_conn() as conn:
repo = PermRepo(conn)
# 测试使用内置方法会创建 Entity 在数据库
assert await repo._get_entity_id_or_none(entity1) is None
assert await repo.get_entity_id(entity1) is not None
assert await repo._get_entity_id_or_none(entity1) is not None
# 测试使用内置方法获得 perm_info
assert await repo.get_perm_info(entity1, "module1") is None
assert not await service.check_has_permission(chain1, "*")
await service.update_permission(entity1, "*", True)
assert await service.check_has_permission(chain1, "*")
assert await service.check_has_permission(chain1, "module1")
assert await service.check_has_permission(chain1, "module1.pack1")
assert not await service.check_has_permission(chain2, "*")
assert not await service.check_has_permission(chain2, "module1")
assert not await service.check_has_permission(chain2, "module1.pack1")
await service.update_permission(entity2, "module1", True)
assert not await service.check_has_permission(chain2, "*")
assert await service.check_has_permission(chain2, "module1")
assert await service.check_has_permission(chain2, "module1.pack1")
assert await service.check_has_permission(chain2, "module1.pack2")
assert not await service.check_has_permission(chain2, "module2")
assert not await service.check_has_permission(chain2, "module2.pack1")
assert not await service.check_has_permission(chain2, "module2.pack2")
await service.update_permission(entity2, "module1.pack2", False)
assert not await service.check_has_permission(chain2, "*")
assert await service.check_has_permission(chain2, "module1")
assert await service.check_has_permission(chain2, "module1.pack1")
assert not await service.check_has_permission(chain2, "module1.pack2")
assert not await service.check_has_permission(chain2, "module2")
assert not await service.check_has_permission(chain2, "module2.pack1")
assert not await service.check_has_permission(chain2, "module2.pack2")
await service.update_permission(entity_global, "module2", True)
assert not await service.check_has_permission(chain2, "*")
assert await service.check_has_permission(chain2, "module1")
assert await service.check_has_permission(chain2, "module1.pack1")
assert not await service.check_has_permission(chain2, "module1.pack2")
assert await service.check_has_permission(chain2, "module2")
assert await service.check_has_permission(chain2, "module2.pack1")
assert await service.check_has_permission(chain2, "module2.pack2")
assert not await service.check_has_permission(entity2, "module2.pack2")
assert await service.check_has_permission(entity_global, "module2.pack2")
async with db.get_conn() as conn:
repo = PermRepo(conn)
assert await repo.get_perm_info(entity2, "module1") is True
assert await repo.get_perm_info(entity2, "module1.pack2") is False