Compare commits
7 Commits
a1c9f9bccb
...
e605527900
| Author | SHA1 | Date | |
|---|---|---|---|
|
e605527900
|
|||
|
9064b31fe9
|
|||
|
27e53c7acd
|
|||
|
ca1db103b5
|
|||
|
7f1035ff43
|
|||
|
5e0e39bfc3
|
|||
|
88861f4264
|
@ -13,7 +13,7 @@ steps:
|
||||
- name: submodules
|
||||
image: alpine/git
|
||||
commands:
|
||||
- git submodule update --init --recursive
|
||||
- git submodule update --init --recursive
|
||||
- name: 构建 Docker 镜像
|
||||
image: plugins/docker:latest
|
||||
privileged: true
|
||||
@ -76,7 +76,7 @@ steps:
|
||||
- name: submodules
|
||||
image: alpine/git
|
||||
commands:
|
||||
- git submodule update --init --recursive
|
||||
- git submodule update --init --recursive
|
||||
- name: 构建并推送 Release Docker 镜像
|
||||
image: plugins/docker:latest
|
||||
privileged: true
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@ -9,3 +9,8 @@ __pycache__
|
||||
|
||||
# 可能会偶然生成的 diff 文件
|
||||
/*.diff
|
||||
|
||||
# 代码覆盖报告
|
||||
/.coverage
|
||||
/.coverage.db
|
||||
/htmlcov
|
||||
|
||||
6
.sqls.yml
Normal file
6
.sqls.yml
Normal file
@ -0,0 +1,6 @@
|
||||
lowercaseKeywords: false
|
||||
connections:
|
||||
- driver: sqlite
|
||||
dataSourceName: "./data/database.db"
|
||||
- driver: sqlite
|
||||
dataSourceName: "./data/perm.sqlite3"
|
||||
19
README.md
19
README.md
@ -96,6 +96,21 @@ poetry run python bot.py
|
||||
- [事件处理](https://nonebot.dev/docs/tutorial/handler)
|
||||
- [Alconna 插件](https://nonebot.dev/docs/best-practice/alconna/)
|
||||
|
||||
## 数据库模块
|
||||
## 代码测试
|
||||
|
||||
本项目的数据库模块已更新为异步实现,使用连接池来提高性能,并支持现代的`pathlib.Path`参数类型。详细使用方法请参考[数据库使用文档](/docs/database.md)。
|
||||
本项目使用 pytest 进行自动化测试,你可以把你的测试代码放在 `./tests` 目录下。
|
||||
|
||||
使用命令行执行测试:
|
||||
|
||||
```bash
|
||||
poetry run just test
|
||||
```
|
||||
|
||||
使用命令行,在浏览器查看测试覆盖率报告:
|
||||
|
||||
```bash
|
||||
poetry run just coverage
|
||||
# 此时会打开一个 :8000 端口的 Web 服务器
|
||||
# 你可以在 http://localhost:8000 查看覆盖率报告
|
||||
# 在控制台使用 Ctrl+C 关闭这个 Web 服务器
|
||||
```
|
||||
|
||||
24
bot.py
24
bot.py
@ -7,6 +7,7 @@ from nonebot.adapters.discord import Adapter as DiscordAdapter
|
||||
from nonebot.adapters.minecraft import Adapter as MinecraftAdapter
|
||||
from nonebot.adapters.onebot.v11 import Adapter as OnebotAdapter
|
||||
|
||||
from konabot.common import permsys
|
||||
from konabot.common.log import init_logger
|
||||
from konabot.common.nb.exc import BotExceptionMessage
|
||||
from konabot.common.path import LOG_PATH
|
||||
@ -22,19 +23,25 @@ env_enable_minecraft = os.environ.get("ENABLE_MINECRAFT", "none")
|
||||
|
||||
|
||||
def main():
|
||||
if env.upper() == 'DEBUG' or env.upper() == 'DEV':
|
||||
console_log_level = 'DEBUG'
|
||||
if env.upper() == "DEBUG" or env.upper() == "DEV":
|
||||
console_log_level = "DEBUG"
|
||||
else:
|
||||
console_log_level = 'INFO'
|
||||
init_logger(LOG_PATH, [
|
||||
BotExceptionMessage,
|
||||
], console_log_level=console_log_level)
|
||||
console_log_level = "INFO"
|
||||
init_logger(
|
||||
LOG_PATH,
|
||||
[
|
||||
BotExceptionMessage,
|
||||
],
|
||||
console_log_level=console_log_level,
|
||||
)
|
||||
|
||||
nonebot.init()
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
if (env != "prod" and env != "test" and env_enable_console.upper() != "FALSE") or (env_enable_console.upper() == "TRUE"):
|
||||
if (env != "prod" and env != "test" and env_enable_console.upper() != "FALSE") or (
|
||||
env_enable_console.upper() == "TRUE"
|
||||
):
|
||||
driver.register_adapter(ConsoleAdapter)
|
||||
|
||||
if env_enable_qq.upper() == "TRUE":
|
||||
@ -50,6 +57,8 @@ def main():
|
||||
nonebot.load_plugins("konabot/plugins")
|
||||
nonebot.load_plugin("nonebot_plugin_analysis_bilibili")
|
||||
|
||||
permsys.create_startup()
|
||||
|
||||
# 注册关闭钩子
|
||||
@driver.on_shutdown
|
||||
async def shutdown_handler():
|
||||
@ -59,5 +68,6 @@ def main():
|
||||
|
||||
nonebot.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
5
justfile
5
justfile
@ -1,4 +1,9 @@
|
||||
watch:
|
||||
poetry run watchfiles bot.main . --filter scripts.watch_filter.filter
|
||||
|
||||
test:
|
||||
poetry run pytest --cov-report term-missing:skip-covered
|
||||
|
||||
coverage:
|
||||
poetry run pytest --cov-report html
|
||||
python -m http.server -d htmlcov
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from contextlib import asynccontextmanager
|
||||
import os
|
||||
import asyncio
|
||||
import sqlparse
|
||||
@ -10,16 +11,19 @@ if TYPE_CHECKING:
|
||||
from . import DatabaseManager
|
||||
|
||||
# 全局数据库管理器实例
|
||||
_global_db_manager: Optional['DatabaseManager'] = None
|
||||
_global_db_manager: Optional["DatabaseManager"] = None
|
||||
|
||||
def get_global_db_manager() -> 'DatabaseManager':
|
||||
|
||||
def get_global_db_manager() -> "DatabaseManager":
|
||||
"""获取全局数据库管理器实例"""
|
||||
global _global_db_manager
|
||||
if _global_db_manager is None:
|
||||
from . import DatabaseManager
|
||||
|
||||
_global_db_manager = DatabaseManager()
|
||||
return _global_db_manager
|
||||
|
||||
|
||||
def close_global_db_manager() -> None:
|
||||
"""关闭全局数据库管理器实例"""
|
||||
global _global_db_manager
|
||||
@ -87,6 +91,12 @@ class DatabaseManager:
|
||||
except:
|
||||
pass
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_conn(self):
|
||||
conn = await self._get_connection()
|
||||
yield conn
|
||||
await self._return_connection(conn)
|
||||
|
||||
async def query(
|
||||
self, query: str, params: Optional[tuple] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
@ -143,22 +153,24 @@ class DatabaseManager:
|
||||
# 使用sqlparse库更准确地分割SQL语句
|
||||
parsed = sqlparse.split(script)
|
||||
statements = []
|
||||
|
||||
|
||||
for statement in parsed:
|
||||
statement = statement.strip()
|
||||
if statement:
|
||||
statements.append(statement)
|
||||
|
||||
|
||||
return statements
|
||||
|
||||
async def execute_by_sql_file(
|
||||
self, file_path: Union[str, Path], params: Optional[Union[tuple, List[tuple]]] = None
|
||||
self,
|
||||
file_path: Union[str, Path],
|
||||
params: Optional[Union[tuple, List[tuple]]] = None,
|
||||
) -> None:
|
||||
"""从 SQL 文件中读取非查询语句并执行"""
|
||||
path = str(file_path) if isinstance(file_path, Path) else file_path
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
script = f.read()
|
||||
|
||||
|
||||
# 如果有参数且是元组,使用execute执行整个脚本
|
||||
if params is not None and isinstance(params, tuple):
|
||||
await self.execute(script, params)
|
||||
@ -167,8 +179,10 @@ class DatabaseManager:
|
||||
# 使用sqlparse准确分割SQL语句
|
||||
statements = self._parse_sql_statements(script)
|
||||
if len(statements) != len(params):
|
||||
raise ValueError(f"语句数量({len(statements)})与参数组数量({len(params)})不匹配")
|
||||
|
||||
raise ValueError(
|
||||
f"语句数量({len(statements)})与参数组数量({len(params)})不匹配"
|
||||
)
|
||||
|
||||
for statement, stmt_params in zip(statements, params):
|
||||
if statement:
|
||||
await self.execute(statement, stmt_params)
|
||||
@ -215,4 +229,3 @@ class DatabaseManager:
|
||||
except:
|
||||
pass
|
||||
self._in_use.clear()
|
||||
|
||||
|
||||
70
konabot/common/permsys/__init__.py
Normal file
70
konabot/common/permsys/__init__.py
Normal file
@ -0,0 +1,70 @@
|
||||
import nonebot
|
||||
from nonebot.adapters import Event
|
||||
|
||||
from konabot.common.database import DatabaseManager
|
||||
from konabot.common.path import DATA_PATH
|
||||
from konabot.common.permsys.entity import PermEntity, get_entity_chain
|
||||
from konabot.common.permsys.migrates import execute_migration
|
||||
from konabot.common.permsys.repo import PermRepo
|
||||
|
||||
|
||||
db = DatabaseManager(DATA_PATH / "perm.sqlite3")
|
||||
|
||||
|
||||
class PermManager:
|
||||
def __init__(self, db: DatabaseManager) -> None:
|
||||
self.db = db
|
||||
|
||||
async def check_has_permission(
|
||||
self, entities: Event | PermEntity | list[PermEntity], key: str
|
||||
) -> bool:
|
||||
if isinstance(entities, Event):
|
||||
entities = await get_entity_chain(entities) # pragma: no cover
|
||||
if isinstance(entities, PermEntity):
|
||||
entities = [entities]
|
||||
|
||||
key = key.removesuffix("*").removesuffix(".")
|
||||
key_split = key.split(".")
|
||||
key_split = [s for s in key_split if len(s) > 0]
|
||||
keys = [".".join(key_split[: i + 1]) for i in range(len(key_split))][::-1] + [
|
||||
"*"
|
||||
]
|
||||
|
||||
async with self.db.get_conn() as conn:
|
||||
repo = PermRepo(conn)
|
||||
# for entity in entities:
|
||||
# for k in keys:
|
||||
# perm = await repo.get_perm_info(entity, k)
|
||||
# if perm is not None:
|
||||
# return perm
|
||||
data = await repo.get_perm_info_batch(entities, keys)
|
||||
for entity in entities:
|
||||
for k in keys:
|
||||
p = data.get((entity, k))
|
||||
if p is not None:
|
||||
return p
|
||||
return False
|
||||
|
||||
async def update_permission(self, entity: PermEntity, key: str, perm: bool | None):
|
||||
async with self.db.get_conn() as conn:
|
||||
repo = PermRepo(conn)
|
||||
await repo.update_perm_info(entity, key, perm)
|
||||
|
||||
|
||||
def perm_manager(_db: DatabaseManager | None = None) -> PermManager: # pragma: no cover
|
||||
if _db is None:
|
||||
_db = db
|
||||
return PermManager(_db)
|
||||
|
||||
|
||||
def create_startup(): # pragma: no cover
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
@driver.on_startup
|
||||
async def _():
|
||||
async with db.get_conn() as conn:
|
||||
await execute_migration(conn)
|
||||
|
||||
@driver.on_shutdown
|
||||
async def _():
|
||||
await db.close_all_connections()
|
||||
61
konabot/common/permsys/entity.py
Normal file
61
konabot/common/permsys/entity.py
Normal file
@ -0,0 +1,61 @@
|
||||
from dataclasses import dataclass
|
||||
from nonebot.internal.adapter import Event
|
||||
|
||||
from nonebot.adapters.onebot.v11 import Event as OB11Event
|
||||
from nonebot.adapters.onebot.v11.event import GroupMessageEvent as OB11GroupEvent
|
||||
from nonebot.adapters.onebot.v11.event import PrivateMessageEvent as OB11PrivateEvent
|
||||
|
||||
from nonebot.adapters.discord.event import Event as DiscordEvent
|
||||
from nonebot.adapters.discord.event import GuildMessageCreateEvent as DiscordGMEvent
|
||||
from nonebot.adapters.discord.event import DirectMessageCreateEvent as DiscordDMEvent
|
||||
|
||||
from nonebot.adapters.minecraft.event import MessageEvent as MinecraftMessageEvent
|
||||
|
||||
from nonebot.adapters.console.event import MessageEvent as ConsoleEvent
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PermEntity:
|
||||
platform: str
|
||||
entity_type: str
|
||||
external_id: str
|
||||
|
||||
|
||||
async def get_entity_chain(event: Event) -> list[PermEntity]: # pragma: no cover
|
||||
entities = [PermEntity("sys", "global", "global")]
|
||||
|
||||
if isinstance(event, OB11Event):
|
||||
entities.append(PermEntity("ob11", "global", "global"))
|
||||
|
||||
if isinstance(event, OB11GroupEvent):
|
||||
entities.append(PermEntity("ob11", "group", str(event.group_id)))
|
||||
entities.append(PermEntity("ob11", "user", str(event.user_id)))
|
||||
|
||||
if isinstance(event, OB11PrivateEvent):
|
||||
entities.append(PermEntity("ob11", "user", str(event.user_id)))
|
||||
|
||||
if isinstance(event, DiscordEvent):
|
||||
entities.append(PermEntity("discord", "global", "global"))
|
||||
|
||||
if isinstance(event, DiscordGMEvent):
|
||||
entities.append(PermEntity("discord", "guilt", str(event.guild_id)))
|
||||
entities.append(PermEntity("discord", "channel", str(event.channel_id)))
|
||||
entities.append(PermEntity("discord", "user", str(event.user_id)))
|
||||
|
||||
if isinstance(event, DiscordDMEvent):
|
||||
entities.append(PermEntity("discord", "channel", str(event.channel_id)))
|
||||
entities.append(PermEntity("discord", "user", str(event.user_id)))
|
||||
|
||||
if isinstance(event, MinecraftMessageEvent):
|
||||
entities.append(PermEntity("minecraft", "global", "global"))
|
||||
entities.append(PermEntity("minecraft", "server", event.server_name))
|
||||
player_uuid = event.player.uuid
|
||||
if player_uuid is not None:
|
||||
entities.append(PermEntity("minecraft", "player", player_uuid.hex))
|
||||
|
||||
if isinstance(event, ConsoleEvent):
|
||||
entities.append(PermEntity("console", "global", "global"))
|
||||
entities.append(PermEntity("console", "channel", event.channel.id))
|
||||
entities.append(PermEntity("console", "user", event.user.id))
|
||||
|
||||
return entities[::-1]
|
||||
81
konabot/common/permsys/migrates/__init__.py
Normal file
81
konabot/common/permsys/migrates/__init__.py
Normal file
@ -0,0 +1,81 @@
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
from loguru import logger
|
||||
|
||||
from konabot.common.database import DatabaseManager
|
||||
from konabot.common.path import DATA_PATH
|
||||
|
||||
|
||||
PATH_THISFOLDER = Path(__file__).parent
|
||||
|
||||
SQL_CHECK_EXISTS = (PATH_THISFOLDER / "./check_migrate_version_exists.sql").read_text()
|
||||
SQL_CREATE_TABLE = (PATH_THISFOLDER / "./create_migrate_version_table.sql").read_text()
|
||||
SQL_GET_MIGRATE_VERSION = (PATH_THISFOLDER / "get_migrate_version.sql").read_text()
|
||||
SQL_UPDATE_VERSION = (PATH_THISFOLDER / "./update_migrate_version.sql").read_text()
|
||||
|
||||
db = DatabaseManager(DATA_PATH / "perm.sqlite3")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Migration:
|
||||
upgrade: str | Path
|
||||
downgrade: str | Path
|
||||
|
||||
def get_upgrade_script(self) -> str:
|
||||
if isinstance(self.upgrade, Path):
|
||||
return self.upgrade.read_text()
|
||||
return self.upgrade
|
||||
|
||||
def get_downgrade_script(self) -> str:
|
||||
if isinstance(self.downgrade, Path):
|
||||
return self.downgrade.read_text()
|
||||
return self.downgrade
|
||||
|
||||
|
||||
migrations = [
|
||||
Migration(
|
||||
PATH_THISFOLDER / "./mu1_create_permsys_table.sql",
|
||||
PATH_THISFOLDER / "./md1_remove_permsys_table.sql",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
TARGET_VERSION = len(migrations)
|
||||
|
||||
|
||||
async def get_current_version(conn: aiosqlite.Connection) -> int:
|
||||
cursor = await conn.execute(SQL_CHECK_EXISTS)
|
||||
count = await cursor.fetchone()
|
||||
assert count is not None
|
||||
if count[0] < 1:
|
||||
logger.info("权限系统数据表不存在,现在创建表")
|
||||
await conn.executescript(SQL_CREATE_TABLE)
|
||||
await conn.commit()
|
||||
return 0
|
||||
cursor = await conn.execute(SQL_GET_MIGRATE_VERSION)
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
return 0
|
||||
return row[0]
|
||||
|
||||
|
||||
async def execute_migration(
|
||||
conn: aiosqlite.Connection,
|
||||
version: int = TARGET_VERSION,
|
||||
migrations: list[Migration] = migrations,
|
||||
):
|
||||
now_version = await get_current_version(conn)
|
||||
while now_version < version:
|
||||
migration = migrations[now_version]
|
||||
await conn.executescript(migration.get_upgrade_script())
|
||||
now_version += 1
|
||||
await conn.execute(SQL_UPDATE_VERSION, (now_version,))
|
||||
await conn.commit()
|
||||
while now_version > version:
|
||||
migration = migrations[now_version - 1]
|
||||
await conn.executescript(migration.get_downgrade_script())
|
||||
now_version -= 1
|
||||
await conn.execute(SQL_UPDATE_VERSION, (now_version,))
|
||||
await conn.commit()
|
||||
@ -0,0 +1,7 @@
|
||||
SELECT
|
||||
COUNT(*)
|
||||
FROM
|
||||
sqlite_master
|
||||
WHERE
|
||||
type = 'table'
|
||||
AND name = 'migrate_version'
|
||||
@ -0,0 +1,3 @@
|
||||
CREATE TABLE migrate_version(version INT PRIMARY KEY);
|
||||
INSERT INTO migrate_version(version)
|
||||
VALUES(0);
|
||||
4
konabot/common/permsys/migrates/get_migrate_version.sql
Normal file
4
konabot/common/permsys/migrates/get_migrate_version.sql
Normal file
@ -0,0 +1,4 @@
|
||||
SELECT
|
||||
version
|
||||
FROM
|
||||
migrate_version;
|
||||
@ -0,0 +1,2 @@
|
||||
DROP TABLE IF EXISTS perm_entity;
|
||||
DROP TABLE IF EXISTS perm_info;
|
||||
30
konabot/common/permsys/migrates/mu1_create_permsys_table.sql
Normal file
30
konabot/common/permsys/migrates/mu1_create_permsys_table.sql
Normal file
@ -0,0 +1,30 @@
|
||||
CREATE TABLE perm_entity(
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
platform TEXT NOT NULL,
|
||||
entity_type TEXT NOT NULL,
|
||||
external_id TEXT NOT NULL,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX idx_perm_entity_lookup
|
||||
ON perm_entity(platform, entity_type, external_id);
|
||||
|
||||
CREATE TABLE perm_info(
|
||||
entity_id INTEGER NOT NULL,
|
||||
config_key TEXT NOT NULL,
|
||||
value BOOLEAN,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
-- 联合主键
|
||||
PRIMARY KEY (entity_id, config_key)
|
||||
);
|
||||
|
||||
CREATE TRIGGER perm_entity_update AFTER UPDATE
|
||||
ON perm_entity BEGIN
|
||||
UPDATE perm_entity SET updated_at=CURRENT_TIMESTAMP WHERE id=old.id;
|
||||
END;
|
||||
CREATE TRIGGER perm_info_update AFTER UPDATE
|
||||
ON perm_info BEGIN
|
||||
UPDATE perm_info SET updated_at=CURRENT_TIMESTAMP WHERE entity_id=old.entity_id;
|
||||
END;
|
||||
|
||||
@ -0,0 +1,2 @@
|
||||
UPDATE migrate_version
|
||||
SET version = ?;
|
||||
180
konabot/common/permsys/repo.py
Normal file
180
konabot/common/permsys/repo.py
Normal file
@ -0,0 +1,180 @@
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from .entity import PermEntity
|
||||
|
||||
|
||||
def s(p: str):
|
||||
"""读取 SQL 文件内容。
|
||||
|
||||
Args:
|
||||
p: SQL 文件名(相对于当前文件所在目录的 sql/ 子目录)。
|
||||
|
||||
Returns:
|
||||
SQL 文件的内容字符串。
|
||||
"""
|
||||
return (Path(__file__).parent / "./sql/" / p).read_text()
|
||||
|
||||
|
||||
@dataclass
|
||||
class PermRepo:
|
||||
"""权限实体存储库,负责与数据库交互管理权限实体。
|
||||
|
||||
Attributes:
|
||||
conn: aiosqlite 数据库连接对象。
|
||||
"""
|
||||
|
||||
conn: aiosqlite.Connection
|
||||
|
||||
async def create_entity(self, entity: PermEntity) -> int:
|
||||
"""创建新的权限实体并返回其 ID。
|
||||
|
||||
Args:
|
||||
entity: 要创建的权限实体对象。
|
||||
|
||||
Returns:
|
||||
新创建实体的数据库 ID。
|
||||
|
||||
Raises:
|
||||
AssertionError: 如果创建后无法获取实体 ID。
|
||||
"""
|
||||
await self.conn.execute(
|
||||
s("create_entity.sql"),
|
||||
(entity.platform, entity.entity_type, entity.external_id),
|
||||
)
|
||||
await self.conn.commit()
|
||||
eid = await self._get_entity_id_or_none(entity)
|
||||
assert eid is not None
|
||||
return eid
|
||||
|
||||
async def _get_entity_id_or_none(self, entity: PermEntity) -> int | None:
|
||||
"""查询实体 ID,如果不存在则返回 None。
|
||||
|
||||
Args:
|
||||
entity: 要查询的权限实体对象。
|
||||
|
||||
Returns:
|
||||
实体 ID,如果不存在则返回 None。
|
||||
"""
|
||||
res = await self.conn.execute(
|
||||
s("get_entity_id.sql"),
|
||||
(entity.platform, entity.entity_type, entity.external_id),
|
||||
)
|
||||
row = await res.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return row[0]
|
||||
|
||||
async def get_entity_id(self, entity: PermEntity) -> int:
|
||||
"""获取实体 ID,如果不存在则自动创建。
|
||||
|
||||
Args:
|
||||
entity: 权限实体对象。
|
||||
|
||||
Returns:
|
||||
实体的数据库 ID。
|
||||
"""
|
||||
eid = await self._get_entity_id_or_none(entity)
|
||||
if eid is None:
|
||||
return await self.create_entity(entity)
|
||||
return eid
|
||||
|
||||
async def get_perm_info(self, entity: PermEntity, config_key: str) -> bool | None:
|
||||
"""获取实体的权限配置信息。
|
||||
|
||||
Args:
|
||||
entity: 权限实体对象。
|
||||
config_key: 配置项的键名。
|
||||
|
||||
Returns:
|
||||
配置值(True/False),如果不存在则返回 None。
|
||||
"""
|
||||
eid = await self.get_entity_id(entity)
|
||||
res = await self.conn.execute(
|
||||
s("get_perm_info.sql"),
|
||||
(eid, config_key),
|
||||
)
|
||||
row = await res.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return bool(row[0])
|
||||
|
||||
async def update_perm_info(
|
||||
self, entity: PermEntity, config_key: str, value: bool | None
|
||||
):
|
||||
"""更新实体的权限配置信息。
|
||||
|
||||
Args:
|
||||
entity: 权限实体对象。
|
||||
config_key: 配置项的键名。
|
||||
value: 要设置的配置值(True/False/None)。
|
||||
"""
|
||||
eid = await self.get_entity_id(entity)
|
||||
await self.conn.execute(s("update_perm_info.sql"), (eid, config_key, value))
|
||||
await self.conn.commit()
|
||||
|
||||
async def get_entity_id_batch(
|
||||
self, entities: list[PermEntity]
|
||||
) -> dict[PermEntity, int]:
|
||||
"""批量获取 Entity 的 eneity_id
|
||||
|
||||
Args:
|
||||
entities: PermEntity 列表
|
||||
|
||||
Returns:
|
||||
字典,键为 PermEntity,值为对应的 ID
|
||||
"""
|
||||
|
||||
for entity in entities:
|
||||
await self.conn.execute(
|
||||
s("create_entity.sql"),
|
||||
(entity.platform, entity.entity_type, entity.external_id),
|
||||
)
|
||||
await self.conn.commit()
|
||||
val_placeholders = ", ".join(["(?, ?, ?)"] * len(entities))
|
||||
params = []
|
||||
for e in entities:
|
||||
params.extend([e.platform, e.entity_type, e.external_id])
|
||||
cursor = await self.conn.execute(
|
||||
f"""
|
||||
SELECT id, platform, entity_type, external_id
|
||||
FROM perm_entity
|
||||
WHERE (platform, entity_type, external_id) IN (VALUES {val_placeholders});
|
||||
""",
|
||||
params,
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
return {PermEntity(row[1], row[2], row[3]): row[0] for row in rows}
|
||||
|
||||
async def get_perm_info_batch(
|
||||
self, entities: list[PermEntity], config_keys: list[str]
|
||||
) -> dict[tuple[PermEntity, str], bool]:
|
||||
"""批量获取权限信息
|
||||
|
||||
Args:
|
||||
entities: PermEntity 列表
|
||||
config_keys: 查询的键列表
|
||||
|
||||
Returns:
|
||||
字典,键是 PermEntity 和 config_key 的元组,值是布尔,过滤掉所有空值
|
||||
"""
|
||||
entity_ids = {
|
||||
v: k for k, v in (await self.get_entity_id_batch(entities)).items()
|
||||
}
|
||||
placeholders1 = ", ".join("?" * len(entity_ids))
|
||||
placeholders2 = ", ".join("?" * len(config_keys))
|
||||
sql = f"""
|
||||
SELECT entity_id, config_key, value
|
||||
FROM perm_info
|
||||
WHERE entity_id IN ({placeholders1})
|
||||
AND config_key IN ({placeholders2})
|
||||
AND value IS NOT NULL;
|
||||
"""
|
||||
|
||||
params = tuple(entity_ids.keys()) + tuple(config_keys)
|
||||
cursor = await self.conn.execute(sql, params)
|
||||
rows = await cursor.fetchall()
|
||||
|
||||
return {(entity_ids[row[0]], row[1]): bool(row[2]) for row in rows}
|
||||
11
konabot/common/permsys/sql/create_entity.sql
Normal file
11
konabot/common/permsys/sql/create_entity.sql
Normal file
@ -0,0 +1,11 @@
|
||||
INSERT
|
||||
OR IGNORE INTO perm_entity(
|
||||
platform,
|
||||
entity_type,
|
||||
external_id
|
||||
)
|
||||
VALUES(
|
||||
?,
|
||||
?,
|
||||
?
|
||||
);
|
||||
8
konabot/common/permsys/sql/get_entity_id.sql
Normal file
8
konabot/common/permsys/sql/get_entity_id.sql
Normal file
@ -0,0 +1,8 @@
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
perm_entity
|
||||
WHERE
|
||||
perm_entity.platform = ?
|
||||
AND perm_entity.entity_type = ?
|
||||
AND perm_entity.external_id = ?;
|
||||
7
konabot/common/permsys/sql/get_perm_info.sql
Normal file
7
konabot/common/permsys/sql/get_perm_info.sql
Normal file
@ -0,0 +1,7 @@
|
||||
SELECT
|
||||
VALUE
|
||||
FROM
|
||||
perm_info
|
||||
WHERE
|
||||
entity_id = ?
|
||||
AND config_key = ?;
|
||||
4
konabot/common/permsys/sql/update_perm_info.sql
Normal file
4
konabot/common/permsys/sql/update_perm_info.sql
Normal file
@ -0,0 +1,4 @@
|
||||
INSERT INTO perm_info (entity_id, config_key, value)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(entity_id, config_key)
|
||||
DO UPDATE SET value=excluded.value;
|
||||
2723
poetry.lock
generated
2723
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -34,6 +34,9 @@ dependencies = [
|
||||
"shapely (>=2.1.2,<3.0.0)",
|
||||
"mcstatus (>=12.2.1,<13.0.0)",
|
||||
"borax (>=4.1.3,<5.0.0)",
|
||||
"pytest (>=8.0.0,<9.0.0)",
|
||||
"nonebug (>=0.4.3,<0.5.0)",
|
||||
"pytest-cov (>=7.0.0,<8.0.0)",
|
||||
]
|
||||
|
||||
[tool.poetry]
|
||||
@ -52,8 +55,15 @@ priority = "primary"
|
||||
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"rust-just (>=1.43.0,<2.0.0)",
|
||||
"pytest (>=9.0.1,<10.0.0)",
|
||||
"pytest-asyncio (>=1.3.0,<2.0.0)"
|
||||
]
|
||||
dev = ["rust-just (>=1.43.0,<2.0.0)", "pytest-asyncio (>=1.3.0,<2.0.0)"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = "tests"
|
||||
python_files = "test_*.py"
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "session"
|
||||
addopts = "--cov=./konabot/"
|
||||
|
||||
[tool.nonebot]
|
||||
# plugin_dirs = ["konabot/plugins/"]
|
||||
plugin_dirs = []
|
||||
|
||||
28
tests/conftest.py
Normal file
28
tests/conftest.py
Normal file
@ -0,0 +1,28 @@
|
||||
# 文件内容来源:
|
||||
# https://nonebot.dev/docs/best-practice/testing/
|
||||
# 保证 nonebug 测试框架正常运作
|
||||
|
||||
import pytest
|
||||
import nonebot
|
||||
from pytest_asyncio import is_async_test
|
||||
from nonebot.adapters.console import Adapter as ConsoleAdapter
|
||||
from nonebug import NONEBOT_START_LIFESPAN
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(items: list[pytest.Item]):
|
||||
pytest_asyncio_tests = (item for item in items if is_async_test(item))
|
||||
session_scope_marker = pytest.mark.asyncio(loop_scope="session")
|
||||
for async_test in pytest_asyncio_tests:
|
||||
async_test.add_marker(session_scope_marker, append=False)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
async def after_nonebot_init(after_nonebot_init: None):
|
||||
driver = nonebot.get_driver()
|
||||
driver.register_adapter(ConsoleAdapter)
|
||||
|
||||
nonebot.load_from_toml("pyproject.toml")
|
||||
|
||||
|
||||
def pytest_configure(config: pytest.Config):
|
||||
config.stash[NONEBOT_START_LIFESPAN] = True
|
||||
@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
@ -12,13 +11,13 @@ from konabot.common.database import DatabaseManager
|
||||
async def test_database_manager():
|
||||
"""测试数据库管理器的基本功能"""
|
||||
# 创建临时数据库文件
|
||||
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_file:
|
||||
db_path = tmp_file.name
|
||||
|
||||
|
||||
try:
|
||||
# 初始化数据库管理器
|
||||
db_manager = DatabaseManager(db_path)
|
||||
|
||||
|
||||
# 创建测试表
|
||||
create_table_sql = """
|
||||
CREATE TABLE IF NOT EXISTS test_users (
|
||||
@ -28,26 +27,27 @@ async def test_database_manager():
|
||||
);
|
||||
"""
|
||||
await db_manager.execute(create_table_sql)
|
||||
|
||||
|
||||
# 插入测试数据
|
||||
insert_sql = "INSERT INTO test_users (name, email) VALUES (?, ?)"
|
||||
await db_manager.execute(insert_sql, ("张三", "zhangsan@example.com"))
|
||||
await db_manager.execute(insert_sql, ("李四", "lisi@example.com"))
|
||||
|
||||
|
||||
# 查询数据
|
||||
select_sql = "SELECT * FROM test_users WHERE name = ?"
|
||||
results = await db_manager.query(select_sql, ("张三",))
|
||||
assert len(results) == 1
|
||||
assert results[0]["name"] == "张三"
|
||||
assert results[0]["email"] == "zhangsan@example.com"
|
||||
|
||||
|
||||
# 测试使用Path对象
|
||||
results = await db_manager.query_by_sql_file(Path(__file__), ("李四",))
|
||||
# results = await db_manager.query_by_sql_file(Path(__file__), ("李四",))
|
||||
# 注意:这里只是测试参数传递,实际SQL文件内容不是有效的SQL
|
||||
|
||||
## ^^^ 卧了个槽的坏枪,你让 AI 写单元测试不检查一下吗
|
||||
|
||||
# 关闭所有连接
|
||||
await db_manager.close_all_connections()
|
||||
|
||||
|
||||
finally:
|
||||
# 清理临时文件
|
||||
if os.path.exists(db_path):
|
||||
@ -58,13 +58,13 @@ async def test_database_manager():
|
||||
async def test_execute_script():
|
||||
"""测试执行SQL脚本功能"""
|
||||
# 创建临时数据库文件
|
||||
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_file:
|
||||
db_path = tmp_file.name
|
||||
|
||||
|
||||
try:
|
||||
# 初始化数据库管理器
|
||||
db_manager = DatabaseManager(db_path)
|
||||
|
||||
|
||||
# 创建测试表的脚本
|
||||
script = """
|
||||
CREATE TABLE IF NOT EXISTS test_products (
|
||||
@ -75,19 +75,19 @@ async def test_execute_script():
|
||||
INSERT INTO test_products (name, price) VALUES ('苹果', 5.0);
|
||||
INSERT INTO test_products (name, price) VALUES ('香蕉', 3.0);
|
||||
"""
|
||||
|
||||
|
||||
await db_manager.execute_script(script)
|
||||
|
||||
|
||||
# 查询数据
|
||||
results = await db_manager.query("SELECT * FROM test_products ORDER BY name")
|
||||
assert len(results) == 2
|
||||
assert results[0]["name"] == "苹果"
|
||||
assert results[1]["name"] == "香蕉"
|
||||
|
||||
|
||||
# 关闭所有连接
|
||||
await db_manager.close_all_connections()
|
||||
|
||||
|
||||
finally:
|
||||
# 清理临时文件
|
||||
if os.path.exists(db_path):
|
||||
os.unlink(db_path)
|
||||
os.unlink(db_path)
|
||||
|
||||
105
tests/test_permsys.py
Normal file
105
tests/test_permsys.py
Normal file
@ -0,0 +1,105 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
import pytest
|
||||
|
||||
from konabot.common.database import DatabaseManager
|
||||
from konabot.common.permsys import PermManager
|
||||
from konabot.common.permsys.entity import PermEntity
|
||||
from konabot.common.permsys.migrates import execute_migration, get_current_version
|
||||
from konabot.common.permsys.repo import PermRepo
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def tempdb():
|
||||
with TemporaryDirectory() as _tempdir:
|
||||
tempdir = Path(_tempdir)
|
||||
db = DatabaseManager(tempdir / "perm.sqlite3")
|
||||
yield db
|
||||
await db.close_all_connections()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_db_version():
|
||||
async with tempdb() as db:
|
||||
async with db.get_conn() as conn:
|
||||
v = await get_current_version(conn)
|
||||
assert v == 0
|
||||
v = await get_current_version(conn)
|
||||
assert v == 0
|
||||
await execute_migration(conn, version=1)
|
||||
v = await get_current_version(conn)
|
||||
assert v == 1
|
||||
await execute_migration(conn, version=0)
|
||||
v = await get_current_version(conn)
|
||||
assert v == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perm():
|
||||
async with tempdb() as db:
|
||||
async with db.get_conn() as conn:
|
||||
await execute_migration(conn)
|
||||
|
||||
service = PermManager(db)
|
||||
entity_global = PermEntity("sys", "global", "global")
|
||||
entity1 = PermEntity("nonexist-platform", "user", "passthem")
|
||||
chain1 = [entity1, entity_global]
|
||||
entity2 = PermEntity("nonexist-platform", "user", "jack")
|
||||
chain2 = [entity2, entity_global]
|
||||
|
||||
async with db.get_conn() as conn:
|
||||
repo = PermRepo(conn)
|
||||
|
||||
# 测试使用内置方法会创建 Entity 在数据库
|
||||
assert await repo._get_entity_id_or_none(entity1) is None
|
||||
assert await repo.get_entity_id(entity1) is not None
|
||||
assert await repo._get_entity_id_or_none(entity1) is not None
|
||||
|
||||
# 测试使用内置方法获得 perm_info
|
||||
assert await repo.get_perm_info(entity1, "module1") is None
|
||||
|
||||
assert not await service.check_has_permission(chain1, "*")
|
||||
|
||||
await service.update_permission(entity1, "*", True)
|
||||
assert await service.check_has_permission(chain1, "*")
|
||||
assert await service.check_has_permission(chain1, "module1")
|
||||
assert await service.check_has_permission(chain1, "module1.pack1")
|
||||
assert not await service.check_has_permission(chain2, "*")
|
||||
assert not await service.check_has_permission(chain2, "module1")
|
||||
assert not await service.check_has_permission(chain2, "module1.pack1")
|
||||
|
||||
await service.update_permission(entity2, "module1", True)
|
||||
assert not await service.check_has_permission(chain2, "*")
|
||||
assert await service.check_has_permission(chain2, "module1")
|
||||
assert await service.check_has_permission(chain2, "module1.pack1")
|
||||
assert await service.check_has_permission(chain2, "module1.pack2")
|
||||
assert not await service.check_has_permission(chain2, "module2")
|
||||
assert not await service.check_has_permission(chain2, "module2.pack1")
|
||||
assert not await service.check_has_permission(chain2, "module2.pack2")
|
||||
|
||||
await service.update_permission(entity2, "module1.pack2", False)
|
||||
assert not await service.check_has_permission(chain2, "*")
|
||||
assert await service.check_has_permission(chain2, "module1")
|
||||
assert await service.check_has_permission(chain2, "module1.pack1")
|
||||
assert not await service.check_has_permission(chain2, "module1.pack2")
|
||||
assert not await service.check_has_permission(chain2, "module2")
|
||||
assert not await service.check_has_permission(chain2, "module2.pack1")
|
||||
assert not await service.check_has_permission(chain2, "module2.pack2")
|
||||
|
||||
await service.update_permission(entity_global, "module2", True)
|
||||
assert not await service.check_has_permission(chain2, "*")
|
||||
assert await service.check_has_permission(chain2, "module1")
|
||||
assert await service.check_has_permission(chain2, "module1.pack1")
|
||||
assert not await service.check_has_permission(chain2, "module1.pack2")
|
||||
assert await service.check_has_permission(chain2, "module2")
|
||||
assert await service.check_has_permission(chain2, "module2.pack1")
|
||||
assert await service.check_has_permission(chain2, "module2.pack2")
|
||||
|
||||
assert not await service.check_has_permission(entity2, "module2.pack2")
|
||||
assert await service.check_has_permission(entity_global, "module2.pack2")
|
||||
|
||||
async with db.get_conn() as conn:
|
||||
repo = PermRepo(conn)
|
||||
assert await repo.get_perm_info(entity2, "module1") is True
|
||||
assert await repo.get_perm_info(entity2, "module1.pack2") is False
|
||||
Reference in New Issue
Block a user