Merge pull request '权限系统' (#55) from feature/permsystem into master
Some checks failed
continuous-integration/drone/push Build is failing
Some checks failed
continuous-integration/drone/push Build is failing
Reviewed-on: #55
This commit is contained in:
14
.drone.yml
14
.drone.yml
@ -13,7 +13,7 @@ steps:
|
|||||||
- name: submodules
|
- name: submodules
|
||||||
image: alpine/git
|
image: alpine/git
|
||||||
commands:
|
commands:
|
||||||
- git submodule update --init --recursive
|
- git submodule update --init --recursive
|
||||||
- name: 构建 Docker 镜像
|
- name: 构建 Docker 镜像
|
||||||
image: plugins/docker:latest
|
image: plugins/docker:latest
|
||||||
privileged: true
|
privileged: true
|
||||||
@ -30,7 +30,7 @@ steps:
|
|||||||
volumes:
|
volumes:
|
||||||
- name: docker-socket
|
- name: docker-socket
|
||||||
path: /var/run/docker.sock
|
path: /var/run/docker.sock
|
||||||
- name: 在容器中测试插件加载
|
- name: 在容器中进行若干测试
|
||||||
image: docker:dind
|
image: docker:dind
|
||||||
privileged: true
|
privileged: true
|
||||||
volumes:
|
volumes:
|
||||||
@ -38,14 +38,8 @@ steps:
|
|||||||
path: /var/run/docker.sock
|
path: /var/run/docker.sock
|
||||||
commands:
|
commands:
|
||||||
- docker run --rm gitea.service.jazzwhom.top/mttu-developers/konabot:nightly-${DRONE_COMMIT_SHA} python scripts/test_plugin_load.py
|
- docker run --rm gitea.service.jazzwhom.top/mttu-developers/konabot:nightly-${DRONE_COMMIT_SHA} python scripts/test_plugin_load.py
|
||||||
- name: 在容器中测试 Playwright 工作正常
|
|
||||||
image: docker:dind
|
|
||||||
privileged: true
|
|
||||||
volumes:
|
|
||||||
- name: docker-socket
|
|
||||||
path: /var/run/docker.sock
|
|
||||||
commands:
|
|
||||||
- docker run --rm gitea.service.jazzwhom.top/mttu-developers/konabot:nightly-${DRONE_COMMIT_SHA} python scripts/test_playwright.py
|
- docker run --rm gitea.service.jazzwhom.top/mttu-developers/konabot:nightly-${DRONE_COMMIT_SHA} python scripts/test_playwright.py
|
||||||
|
- docker run --rm gitea.service.jazzwhom.top/mttu-developers/konabot:nightly-${DRONE_COMMIT_SHA} python -m pytest --cov-report term-missing:skip-covered
|
||||||
- name: 发送构建结果到 ntfy
|
- name: 发送构建结果到 ntfy
|
||||||
image: parrazam/drone-ntfy
|
image: parrazam/drone-ntfy
|
||||||
when:
|
when:
|
||||||
@ -76,7 +70,7 @@ steps:
|
|||||||
- name: submodules
|
- name: submodules
|
||||||
image: alpine/git
|
image: alpine/git
|
||||||
commands:
|
commands:
|
||||||
- git submodule update --init --recursive
|
- git submodule update --init --recursive
|
||||||
- name: 构建并推送 Release Docker 镜像
|
- name: 构建并推送 Release Docker 镜像
|
||||||
image: plugins/docker:latest
|
image: plugins/docker:latest
|
||||||
privileged: true
|
privileged: true
|
||||||
|
|||||||
5
.gitignore
vendored
5
.gitignore
vendored
@ -9,3 +9,8 @@ __pycache__
|
|||||||
|
|
||||||
# 可能会偶然生成的 diff 文件
|
# 可能会偶然生成的 diff 文件
|
||||||
/*.diff
|
/*.diff
|
||||||
|
|
||||||
|
# 代码覆盖报告
|
||||||
|
/.coverage
|
||||||
|
/.coverage.db
|
||||||
|
/htmlcov
|
||||||
|
|||||||
6
.sqls.yml
Normal file
6
.sqls.yml
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
lowercaseKeywords: false
|
||||||
|
connections:
|
||||||
|
- driver: sqlite
|
||||||
|
dataSourceName: "./data/database.db"
|
||||||
|
- driver: sqlite
|
||||||
|
dataSourceName: "./data/perm.sqlite3"
|
||||||
19
README.md
19
README.md
@ -96,6 +96,21 @@ poetry run python bot.py
|
|||||||
- [事件处理](https://nonebot.dev/docs/tutorial/handler)
|
- [事件处理](https://nonebot.dev/docs/tutorial/handler)
|
||||||
- [Alconna 插件](https://nonebot.dev/docs/best-practice/alconna/)
|
- [Alconna 插件](https://nonebot.dev/docs/best-practice/alconna/)
|
||||||
|
|
||||||
## 数据库模块
|
## 代码测试
|
||||||
|
|
||||||
本项目的数据库模块已更新为异步实现,使用连接池来提高性能,并支持现代的`pathlib.Path`参数类型。详细使用方法请参考[数据库使用文档](/docs/database.md)。
|
本项目使用 pytest 进行自动化测试,你可以把你的测试代码放在 `./tests` 目录下。
|
||||||
|
|
||||||
|
使用命令行执行测试:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
poetry run just test
|
||||||
|
```
|
||||||
|
|
||||||
|
使用命令行,在浏览器查看测试覆盖率报告:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
poetry run just coverage
|
||||||
|
# 此时会打开一个 :8000 端口的 Web 服务器
|
||||||
|
# 你可以在 http://localhost:8000 查看覆盖率报告
|
||||||
|
# 在控制台使用 Ctrl+C 关闭这个 Web 服务器
|
||||||
|
```
|
||||||
|
|||||||
27
bot.py
27
bot.py
@ -22,19 +22,25 @@ env_enable_minecraft = os.environ.get("ENABLE_MINECRAFT", "none")
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
if env.upper() == 'DEBUG' or env.upper() == 'DEV':
|
if env.upper() == "DEBUG" or env.upper() == "DEV":
|
||||||
console_log_level = 'DEBUG'
|
console_log_level = "DEBUG"
|
||||||
else:
|
else:
|
||||||
console_log_level = 'INFO'
|
console_log_level = "INFO"
|
||||||
init_logger(LOG_PATH, [
|
init_logger(
|
||||||
BotExceptionMessage,
|
LOG_PATH,
|
||||||
], console_log_level=console_log_level)
|
[
|
||||||
|
BotExceptionMessage,
|
||||||
|
],
|
||||||
|
console_log_level=console_log_level,
|
||||||
|
)
|
||||||
|
|
||||||
nonebot.init()
|
nonebot.init()
|
||||||
|
|
||||||
driver = nonebot.get_driver()
|
driver = nonebot.get_driver()
|
||||||
|
|
||||||
if (env != "prod" and env != "test" and env_enable_console.upper() != "FALSE") or (env_enable_console.upper() == "TRUE"):
|
if (env != "prod" and env != "test" and env_enable_console.upper() != "FALSE") or (
|
||||||
|
env_enable_console.upper() == "TRUE"
|
||||||
|
):
|
||||||
driver.register_adapter(ConsoleAdapter)
|
driver.register_adapter(ConsoleAdapter)
|
||||||
|
|
||||||
if env_enable_qq.upper() == "TRUE":
|
if env_enable_qq.upper() == "TRUE":
|
||||||
@ -50,14 +56,19 @@ def main():
|
|||||||
nonebot.load_plugins("konabot/plugins")
|
nonebot.load_plugins("konabot/plugins")
|
||||||
nonebot.load_plugin("nonebot_plugin_analysis_bilibili")
|
nonebot.load_plugin("nonebot_plugin_analysis_bilibili")
|
||||||
|
|
||||||
|
from konabot.common import permsys
|
||||||
|
|
||||||
|
permsys.create_startup()
|
||||||
|
|
||||||
# 注册关闭钩子
|
# 注册关闭钩子
|
||||||
@driver.on_shutdown
|
@driver.on_shutdown
|
||||||
async def shutdown_handler():
|
async def _():
|
||||||
# 关闭全局数据库管理器
|
# 关闭全局数据库管理器
|
||||||
db_manager = get_global_db_manager()
|
db_manager = get_global_db_manager()
|
||||||
await db_manager.close_all_connections()
|
await db_manager.close_all_connections()
|
||||||
|
|
||||||
nonebot.run()
|
nonebot.run()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
235
docs/permsys.md
Normal file
235
docs/permsys.md
Normal file
@ -0,0 +1,235 @@
|
|||||||
|
# 权限系统 `konabot.common.permsys`
|
||||||
|
|
||||||
|
本文档面向维护者,说明 `konabot/common/permsys` 模块的职责、数据模型、权限解析规则,以及在插件中接入的推荐方式。
|
||||||
|
|
||||||
|
## 模块目标
|
||||||
|
|
||||||
|
`permsys` 提供了一套简单的、可继承的权限系统,用于回答两个问题:
|
||||||
|
|
||||||
|
1. 某个事件对应的主体是谁。
|
||||||
|
2. 该主体是否拥有某项权限。
|
||||||
|
|
||||||
|
它适合处理 bot 内部的功能开关、管理权限、平台级授权等场景。
|
||||||
|
|
||||||
|
当前模块由以下几部分组成:
|
||||||
|
|
||||||
|
- `konabot/common/permsys/__init__.py`
|
||||||
|
- 暴露 `PermManager`、`DepPermManager`、`require_permission`
|
||||||
|
- 负责数据库初始化、启动迁移、超级管理员默认授权
|
||||||
|
- `konabot/common/permsys/entity.py`
|
||||||
|
- 定义 `PermEntity`
|
||||||
|
- 将事件转换为可查询的实体链
|
||||||
|
- `konabot/common/permsys/repo.py`
|
||||||
|
- 封装 SQLite 读写
|
||||||
|
- `konabot/common/permsys/migrates/`
|
||||||
|
- 存放迁移 SQL
|
||||||
|
- `konabot/common/permsys/sql/`
|
||||||
|
- 存放查询与更新 SQL
|
||||||
|
|
||||||
|
## 核心概念
|
||||||
|
|
||||||
|
### 1. `PermEntity`
|
||||||
|
|
||||||
|
`PermEntity` 是权限系统中的最小主体标识:
|
||||||
|
|
||||||
|
```python
|
||||||
|
PermEntity(platform: str, entity_type: str, external_id: str)
|
||||||
|
```
|
||||||
|
|
||||||
|
示例:
|
||||||
|
|
||||||
|
- `PermEntity("sys", "global", "global")`
|
||||||
|
- `PermEntity("ob11", "group", "123456")`
|
||||||
|
- `PermEntity("ob11", "user", "987654")`
|
||||||
|
|
||||||
|
其中:
|
||||||
|
|
||||||
|
- `platform` 表示来源平台,如 `sys`、`ob11`、`discord`
|
||||||
|
- `entity_type` 表示主体类型,如 `global`、`group`、`user`
|
||||||
|
- `external_id` 表示平台侧的外部标识
|
||||||
|
|
||||||
|
### 2. 实体链
|
||||||
|
|
||||||
|
权限判断不是只看单个实体,而是看一条“实体链”。
|
||||||
|
|
||||||
|
以 `get_entity_chain_of_entity()` 为例,传入一个具体实体时,返回的链为:
|
||||||
|
|
||||||
|
```python
|
||||||
|
[
|
||||||
|
PermEntity(platform, entity_type, external_id),
|
||||||
|
PermEntity(platform, "global", "global"),
|
||||||
|
PermEntity("sys", "global", "global"),
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
这意味着权限会优先读取更具体的主体,再回退到平台全局,最后回退到系统全局。
|
||||||
|
|
||||||
|
`get_entity_chain(event)` 则会根据事件类型自动构造链。例如:
|
||||||
|
|
||||||
|
- OneBot V11 群消息:用户 -> 群 -> 平台全局 -> 系统全局
|
||||||
|
- OneBot V11 私聊:用户 -> 平台全局 -> 系统全局
|
||||||
|
- Discord 频道消息:用户/频道/服务器 -> 平台全局 -> 系统全局
|
||||||
|
- Console:控制台用户/频道 -> 平台全局 -> 系统全局
|
||||||
|
|
||||||
|
注意:当前 `entity.py` 中的具体链顺序与字段命名应以实现为准;修改这里时要评估现有权限继承是否会被破坏。
|
||||||
|
|
||||||
|
### 3. 权限键
|
||||||
|
|
||||||
|
权限键使用点分结构,例如:
|
||||||
|
|
||||||
|
- `admin`
|
||||||
|
- `plugin.weather`
|
||||||
|
- `plugin.weather.use`
|
||||||
|
|
||||||
|
检查时会自动做前缀回退。以 `plugin.weather.use` 为例,查询顺序是:
|
||||||
|
|
||||||
|
1. `plugin.weather.use`
|
||||||
|
2. `plugin.weather`
|
||||||
|
3. `plugin`
|
||||||
|
4. `*`
|
||||||
|
|
||||||
|
因此,`*` 可以看作兜底总权限。
|
||||||
|
|
||||||
|
## 权限解析规则
|
||||||
|
|
||||||
|
`PermManager.check_has_permission_info()` 的逻辑可以概括为:
|
||||||
|
|
||||||
|
1. 先把输入转换成实体链。
|
||||||
|
2. 对权限键做逐级回退,同时追加 `*`。
|
||||||
|
3. 在数据库中批量查出链上所有实体、所有候选键的显式记录。
|
||||||
|
4. 按“实体越具体越优先、权限键越具体越优先”的顺序,返回第一条命中的记录。
|
||||||
|
|
||||||
|
若没有任何显式记录:
|
||||||
|
|
||||||
|
- `check_has_permission_info()` 返回 `None`
|
||||||
|
- `check_has_permission()` 返回 `False`
|
||||||
|
|
||||||
|
这表示本系统默认是“未授权即拒绝”。
|
||||||
|
|
||||||
|
## 数据存储
|
||||||
|
|
||||||
|
模块使用 SQLite,默认数据库文件位于:
|
||||||
|
|
||||||
|
- `data/perm.sqlite3`
|
||||||
|
|
||||||
|
启动时会执行迁移:
|
||||||
|
|
||||||
|
- `create_startup()` 在 NoneBot 启动事件中调用 `execute_migration()`
|
||||||
|
|
||||||
|
权限值支持三态:
|
||||||
|
|
||||||
|
- `True`:显式允许
|
||||||
|
- `False`:显式拒绝
|
||||||
|
- `None`:删除/清空该层的显式设置,让判断重新回退到继承链
|
||||||
|
|
||||||
|
`repo.py` 中的 `update_perm_info()` 会将这个三态直接写入数据库。
|
||||||
|
|
||||||
|
## 超级管理员注入
|
||||||
|
|
||||||
|
在启动阶段,`create_startup()` 会读取 `konabot.common.nb.is_admin.cfg.admin_qq_account`,并为这些 QQ 账号写入:
|
||||||
|
|
||||||
|
```python
|
||||||
|
PermEntity("ob11", "user", str(account)), "*", True
|
||||||
|
```
|
||||||
|
|
||||||
|
也就是说,配置中的超级管理员会直接拥有全部权限。
|
||||||
|
|
||||||
|
这属于启动时自动灌入的保底策略,不依赖手工授权命令。
|
||||||
|
|
||||||
|
## 在插件中使用
|
||||||
|
|
||||||
|
### 1. 直接做权限检查
|
||||||
|
|
||||||
|
```python
|
||||||
|
from konabot.common.permsys import DepPermManager
|
||||||
|
|
||||||
|
|
||||||
|
async def handler(pm: DepPermManager, event):
|
||||||
|
ok = await pm.check_has_permission(event, "plugin.example.use")
|
||||||
|
if not ok:
|
||||||
|
return
|
||||||
|
```
|
||||||
|
|
||||||
|
适合需要在处理流程中动态决定权限键的场景。
|
||||||
|
|
||||||
|
### 2. 挂到 Rule 上做准入控制
|
||||||
|
|
||||||
|
```python
|
||||||
|
from nonebot_plugin_alconna import Alconna, on_alconna
|
||||||
|
from konabot.common.permsys import require_permission
|
||||||
|
|
||||||
|
|
||||||
|
cmd = on_alconna(
|
||||||
|
Alconna("example"),
|
||||||
|
rule=require_permission("plugin.example.use"),
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
适合命令入口明确、未通过时直接拦截的场景。
|
||||||
|
|
||||||
|
### 3. 更新权限
|
||||||
|
|
||||||
|
```python
|
||||||
|
from konabot.common.permsys import DepPermManager
|
||||||
|
from konabot.common.permsys.entity import PermEntity
|
||||||
|
|
||||||
|
|
||||||
|
await pm.update_permission(
|
||||||
|
PermEntity("ob11", "group", "123456"),
|
||||||
|
"plugin.example.use",
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
建议只在专门的管理插件中开放写权限,避免普通功能插件到处分散改表。
|
||||||
|
|
||||||
|
## `perm_manage` 插件与本模块的关系
|
||||||
|
|
||||||
|
`konabot/plugins/perm_manage/__init__.py` 是本模块当前的管理入口,提供:
|
||||||
|
|
||||||
|
- `konaperm list`:列出实体链上已有的显式权限记录
|
||||||
|
- `konaperm get`:查看某个权限最终命中的记录
|
||||||
|
- `konaperm set`:写入 allow/deny/null
|
||||||
|
|
||||||
|
这个插件本身使用 `require_permission("admin")` 保护,因此只有拥有 `admin` 权限的主体才能管理权限。
|
||||||
|
|
||||||
|
## 接入建议
|
||||||
|
|
||||||
|
### 权限键命名
|
||||||
|
|
||||||
|
建议使用稳定、可扩展的分层键名:
|
||||||
|
|
||||||
|
- 推荐:`plugin.xxx`、`plugin.xxx.action`
|
||||||
|
- 不推荐:含糊的单词或临时字符串
|
||||||
|
|
||||||
|
这样才能利用前缀回退机制做批量授权。
|
||||||
|
|
||||||
|
### 输入安全
|
||||||
|
|
||||||
|
虽然这个项目偏内部使用,但权限键、实体类型、外部 ID 仍然应视为不可信输入:
|
||||||
|
|
||||||
|
- 不要把聊天输入直接拼到 SQL 中
|
||||||
|
- 不要让任意用户可随意构造高权限写入
|
||||||
|
- 对可写命令至少做权限保护和必要校验
|
||||||
|
|
||||||
|
### 改动兼容性
|
||||||
|
|
||||||
|
以下改动都可能影响全局权限行为,修改前应充分评估:
|
||||||
|
|
||||||
|
- 更改实体链顺序
|
||||||
|
- 更改默认兜底键 `*` 的语义
|
||||||
|
- 更改 `None` 的处理方式
|
||||||
|
- 更改启动时超级管理员注入逻辑
|
||||||
|
|
||||||
|
## 调试建议
|
||||||
|
|
||||||
|
- 先用 `konaperm get ...` 确认某个权限最终命中了哪一层
|
||||||
|
- 再用 `konaperm list ...` 查看该实体链上有哪些显式记录
|
||||||
|
- 若表现异常,检查是否是更上层实体或更宽泛权限键提前命中
|
||||||
|
|
||||||
|
## 相关文件
|
||||||
|
|
||||||
|
- `konabot/common/permsys/__init__.py`
|
||||||
|
- `konabot/common/permsys/entity.py`
|
||||||
|
- `konabot/common/permsys/repo.py`
|
||||||
|
- `konabot/plugins/perm_manage/__init__.py`
|
||||||
5
justfile
5
justfile
@ -1,4 +1,9 @@
|
|||||||
watch:
|
watch:
|
||||||
poetry run watchfiles bot.main . --filter scripts.watch_filter.filter
|
poetry run watchfiles bot.main . --filter scripts.watch_filter.filter
|
||||||
|
|
||||||
|
test:
|
||||||
|
poetry run pytest --cov-report term-missing:skip-covered
|
||||||
|
|
||||||
|
coverage:
|
||||||
|
poetry run pytest --cov-report html
|
||||||
|
python -m http.server -d htmlcov
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
import sqlparse
|
import sqlparse
|
||||||
@ -10,16 +11,19 @@ if TYPE_CHECKING:
|
|||||||
from . import DatabaseManager
|
from . import DatabaseManager
|
||||||
|
|
||||||
# 全局数据库管理器实例
|
# 全局数据库管理器实例
|
||||||
_global_db_manager: Optional['DatabaseManager'] = None
|
_global_db_manager: Optional["DatabaseManager"] = None
|
||||||
|
|
||||||
def get_global_db_manager() -> 'DatabaseManager':
|
|
||||||
|
def get_global_db_manager() -> "DatabaseManager":
|
||||||
"""获取全局数据库管理器实例"""
|
"""获取全局数据库管理器实例"""
|
||||||
global _global_db_manager
|
global _global_db_manager
|
||||||
if _global_db_manager is None:
|
if _global_db_manager is None:
|
||||||
from . import DatabaseManager
|
from . import DatabaseManager
|
||||||
|
|
||||||
_global_db_manager = DatabaseManager()
|
_global_db_manager = DatabaseManager()
|
||||||
return _global_db_manager
|
return _global_db_manager
|
||||||
|
|
||||||
|
|
||||||
def close_global_db_manager() -> None:
|
def close_global_db_manager() -> None:
|
||||||
"""关闭全局数据库管理器实例"""
|
"""关闭全局数据库管理器实例"""
|
||||||
global _global_db_manager
|
global _global_db_manager
|
||||||
@ -87,6 +91,12 @@ class DatabaseManager:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def get_conn(self):
|
||||||
|
conn = await self._get_connection()
|
||||||
|
yield conn
|
||||||
|
await self._return_connection(conn)
|
||||||
|
|
||||||
async def query(
|
async def query(
|
||||||
self, query: str, params: Optional[tuple] = None
|
self, query: str, params: Optional[tuple] = None
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
@ -143,22 +153,24 @@ class DatabaseManager:
|
|||||||
# 使用sqlparse库更准确地分割SQL语句
|
# 使用sqlparse库更准确地分割SQL语句
|
||||||
parsed = sqlparse.split(script)
|
parsed = sqlparse.split(script)
|
||||||
statements = []
|
statements = []
|
||||||
|
|
||||||
for statement in parsed:
|
for statement in parsed:
|
||||||
statement = statement.strip()
|
statement = statement.strip()
|
||||||
if statement:
|
if statement:
|
||||||
statements.append(statement)
|
statements.append(statement)
|
||||||
|
|
||||||
return statements
|
return statements
|
||||||
|
|
||||||
async def execute_by_sql_file(
|
async def execute_by_sql_file(
|
||||||
self, file_path: Union[str, Path], params: Optional[Union[tuple, List[tuple]]] = None
|
self,
|
||||||
|
file_path: Union[str, Path],
|
||||||
|
params: Optional[Union[tuple, List[tuple]]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""从 SQL 文件中读取非查询语句并执行"""
|
"""从 SQL 文件中读取非查询语句并执行"""
|
||||||
path = str(file_path) if isinstance(file_path, Path) else file_path
|
path = str(file_path) if isinstance(file_path, Path) else file_path
|
||||||
with open(path, "r", encoding="utf-8") as f:
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
script = f.read()
|
script = f.read()
|
||||||
|
|
||||||
# 如果有参数且是元组,使用execute执行整个脚本
|
# 如果有参数且是元组,使用execute执行整个脚本
|
||||||
if params is not None and isinstance(params, tuple):
|
if params is not None and isinstance(params, tuple):
|
||||||
await self.execute(script, params)
|
await self.execute(script, params)
|
||||||
@ -167,8 +179,10 @@ class DatabaseManager:
|
|||||||
# 使用sqlparse准确分割SQL语句
|
# 使用sqlparse准确分割SQL语句
|
||||||
statements = self._parse_sql_statements(script)
|
statements = self._parse_sql_statements(script)
|
||||||
if len(statements) != len(params):
|
if len(statements) != len(params):
|
||||||
raise ValueError(f"语句数量({len(statements)})与参数组数量({len(params)})不匹配")
|
raise ValueError(
|
||||||
|
f"语句数量({len(statements)})与参数组数量({len(params)})不匹配"
|
||||||
|
)
|
||||||
|
|
||||||
for statement, stmt_params in zip(statements, params):
|
for statement, stmt_params in zip(statements, params):
|
||||||
if statement:
|
if statement:
|
||||||
await self.execute(statement, stmt_params)
|
await self.execute(statement, stmt_params)
|
||||||
@ -215,4 +229,3 @@ class DatabaseManager:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
self._in_use.clear()
|
self._in_use.clear()
|
||||||
|
|
||||||
|
|||||||
108
konabot/common/permsys/__init__.py
Normal file
108
konabot/common/permsys/__init__.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
import nonebot
|
||||||
|
from nonebot.adapters import Event
|
||||||
|
from nonebot.params import Depends
|
||||||
|
from nonebot.rule import Rule
|
||||||
|
|
||||||
|
from konabot.common.database import DatabaseManager
|
||||||
|
from konabot.common.pager import PagerQuery
|
||||||
|
from konabot.common.path import DATA_PATH
|
||||||
|
from konabot.common.permsys.entity import PermEntity, get_entity_chain
|
||||||
|
from konabot.common.permsys.migrates import execute_migration
|
||||||
|
from konabot.common.permsys.repo import PermRepo
|
||||||
|
|
||||||
|
|
||||||
|
db = DatabaseManager(DATA_PATH / "perm.sqlite3")
|
||||||
|
|
||||||
|
|
||||||
|
_EntityLike = Event | PermEntity | list[PermEntity]
|
||||||
|
|
||||||
|
|
||||||
|
async def _to_entity_chain(el: _EntityLike):
|
||||||
|
if isinstance(el, Event):
|
||||||
|
return await get_entity_chain(el) # pragma: no cover
|
||||||
|
if isinstance(el, PermEntity):
|
||||||
|
return [el]
|
||||||
|
return el
|
||||||
|
|
||||||
|
|
||||||
|
class PermManager:
|
||||||
|
def __init__(self, db: DatabaseManager) -> None:
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
async def check_has_permission_info(self, entities: _EntityLike, key: str):
|
||||||
|
entities = await _to_entity_chain(entities)
|
||||||
|
key = key.removesuffix("*").removesuffix(".")
|
||||||
|
key_split = key.split(".")
|
||||||
|
key_split = [s for s in key_split if len(s) > 0]
|
||||||
|
keys = [".".join(key_split[: i + 1]) for i in range(len(key_split))][::-1] + [
|
||||||
|
"*"
|
||||||
|
]
|
||||||
|
|
||||||
|
async with self.db.get_conn() as conn:
|
||||||
|
repo = PermRepo(conn)
|
||||||
|
data = await repo.get_perm_info_batch(entities, keys)
|
||||||
|
for entity in entities:
|
||||||
|
for k in keys:
|
||||||
|
p = data.get((entity, k))
|
||||||
|
if p is not None:
|
||||||
|
return (entity, k, p)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def check_has_permission(self, entities: _EntityLike, key: str) -> bool:
|
||||||
|
res = await self.check_has_permission_info(entities, key)
|
||||||
|
if res is None:
|
||||||
|
return False
|
||||||
|
return res[2]
|
||||||
|
|
||||||
|
async def update_permission(self, entity: PermEntity, key: str, perm: bool | None):
|
||||||
|
async with self.db.get_conn() as conn:
|
||||||
|
repo = PermRepo(conn)
|
||||||
|
await repo.update_perm_info(entity, key, perm)
|
||||||
|
|
||||||
|
async def list_permission(self, entities: _EntityLike, query: PagerQuery):
|
||||||
|
entities = await _to_entity_chain(entities)
|
||||||
|
async with self.db.get_conn() as conn:
|
||||||
|
repo = PermRepo(conn)
|
||||||
|
return await repo.list_perm_info_batch(entities, query)
|
||||||
|
|
||||||
|
|
||||||
|
def perm_manager(_db: DatabaseManager | None = None) -> PermManager: # pragma: no cover
|
||||||
|
if _db is None:
|
||||||
|
_db = db
|
||||||
|
return PermManager(_db)
|
||||||
|
|
||||||
|
|
||||||
|
def create_startup(): # pragma: no cover
|
||||||
|
from konabot.common.nb.is_admin import cfg
|
||||||
|
|
||||||
|
driver = nonebot.get_driver()
|
||||||
|
|
||||||
|
@driver.on_startup
|
||||||
|
async def _():
|
||||||
|
async with db.get_conn() as conn:
|
||||||
|
await execute_migration(conn)
|
||||||
|
pm = perm_manager(db)
|
||||||
|
for account in cfg.admin_qq_account:
|
||||||
|
# ^ 这里的是超级管理员!!用环境变量定义的。
|
||||||
|
# 咕嘿嘿嘿!!!夺取全部权限!!!
|
||||||
|
await pm.update_permission(
|
||||||
|
PermEntity("ob11", "user", str(account)), "*", True
|
||||||
|
)
|
||||||
|
|
||||||
|
@driver.on_shutdown
|
||||||
|
async def _():
|
||||||
|
try:
|
||||||
|
await db.close_all_connections()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
DepPermManager = Annotated[PermManager, Depends(perm_manager)]
|
||||||
|
|
||||||
|
|
||||||
|
def require_permission(perm: str) -> Rule: # pragma: no cover
|
||||||
|
async def check_permission(event: Event, pm: DepPermManager) -> bool:
|
||||||
|
return await pm.check_has_permission(event, perm)
|
||||||
|
|
||||||
|
return Rule(check_permission)
|
||||||
69
konabot/common/permsys/entity.py
Normal file
69
konabot/common/permsys/entity.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from nonebot.internal.adapter import Event
|
||||||
|
|
||||||
|
from nonebot.adapters.onebot.v11 import Event as OB11Event
|
||||||
|
from nonebot.adapters.onebot.v11.event import GroupMessageEvent as OB11GroupEvent
|
||||||
|
from nonebot.adapters.onebot.v11.event import PrivateMessageEvent as OB11PrivateEvent
|
||||||
|
|
||||||
|
from nonebot.adapters.discord.event import Event as DiscordEvent
|
||||||
|
from nonebot.adapters.discord.event import GuildMessageCreateEvent as DiscordGMEvent
|
||||||
|
from nonebot.adapters.discord.event import DirectMessageCreateEvent as DiscordDMEvent
|
||||||
|
|
||||||
|
from nonebot.adapters.minecraft.event import MessageEvent as MinecraftMessageEvent
|
||||||
|
|
||||||
|
from nonebot.adapters.console.event import MessageEvent as ConsoleEvent
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class PermEntity:
|
||||||
|
platform: str
|
||||||
|
entity_type: str
|
||||||
|
external_id: str
|
||||||
|
|
||||||
|
|
||||||
|
def get_entity_chain_of_entity(entity: PermEntity) -> list[PermEntity]:
|
||||||
|
return [
|
||||||
|
PermEntity("sys", "global", "global"),
|
||||||
|
PermEntity(entity.platform, "global", "global"),
|
||||||
|
entity,
|
||||||
|
][::-1]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_entity_chain(event: Event) -> list[PermEntity]: # pragma: no cover
|
||||||
|
entities = [PermEntity("sys", "global", "global")]
|
||||||
|
|
||||||
|
if isinstance(event, OB11Event):
|
||||||
|
entities.append(PermEntity("ob11", "global", "global"))
|
||||||
|
|
||||||
|
if isinstance(event, OB11GroupEvent):
|
||||||
|
entities.append(PermEntity("ob11", "group", str(event.group_id)))
|
||||||
|
entities.append(PermEntity("ob11", "user", str(event.user_id)))
|
||||||
|
|
||||||
|
if isinstance(event, OB11PrivateEvent):
|
||||||
|
entities.append(PermEntity("ob11", "user", str(event.user_id)))
|
||||||
|
|
||||||
|
if isinstance(event, DiscordEvent):
|
||||||
|
entities.append(PermEntity("discord", "global", "global"))
|
||||||
|
|
||||||
|
if isinstance(event, DiscordGMEvent):
|
||||||
|
entities.append(PermEntity("discord", "guild", str(event.guild_id)))
|
||||||
|
entities.append(PermEntity("discord", "channel", str(event.channel_id)))
|
||||||
|
entities.append(PermEntity("discord", "user", str(event.user_id)))
|
||||||
|
|
||||||
|
if isinstance(event, DiscordDMEvent):
|
||||||
|
entities.append(PermEntity("discord", "channel", str(event.channel_id)))
|
||||||
|
entities.append(PermEntity("discord", "user", str(event.user_id)))
|
||||||
|
|
||||||
|
if isinstance(event, MinecraftMessageEvent):
|
||||||
|
entities.append(PermEntity("minecraft", "global", "global"))
|
||||||
|
entities.append(PermEntity("minecraft", "server", event.server_name))
|
||||||
|
player_uuid = event.player.uuid
|
||||||
|
if player_uuid is not None:
|
||||||
|
entities.append(PermEntity("minecraft", "player", player_uuid.hex))
|
||||||
|
|
||||||
|
if isinstance(event, ConsoleEvent):
|
||||||
|
entities.append(PermEntity("console", "global", "global"))
|
||||||
|
entities.append(PermEntity("console", "channel", event.channel.id))
|
||||||
|
entities.append(PermEntity("console", "user", event.user.id))
|
||||||
|
|
||||||
|
return entities[::-1]
|
||||||
81
konabot/common/permsys/migrates/__init__.py
Normal file
81
konabot/common/permsys/migrates/__init__.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from konabot.common.database import DatabaseManager
|
||||||
|
from konabot.common.path import DATA_PATH
|
||||||
|
|
||||||
|
|
||||||
|
PATH_THISFOLDER = Path(__file__).parent
|
||||||
|
|
||||||
|
SQL_CHECK_EXISTS = (PATH_THISFOLDER / "./check_migrate_version_exists.sql").read_text()
|
||||||
|
SQL_CREATE_TABLE = (PATH_THISFOLDER / "./create_migrate_version_table.sql").read_text()
|
||||||
|
SQL_GET_MIGRATE_VERSION = (PATH_THISFOLDER / "get_migrate_version.sql").read_text()
|
||||||
|
SQL_UPDATE_VERSION = (PATH_THISFOLDER / "./update_migrate_version.sql").read_text()
|
||||||
|
|
||||||
|
db = DatabaseManager(DATA_PATH / "perm.sqlite3")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Migration:
|
||||||
|
upgrade: str | Path
|
||||||
|
downgrade: str | Path
|
||||||
|
|
||||||
|
def get_upgrade_script(self) -> str:
|
||||||
|
if isinstance(self.upgrade, Path):
|
||||||
|
return self.upgrade.read_text()
|
||||||
|
return self.upgrade
|
||||||
|
|
||||||
|
def get_downgrade_script(self) -> str:
|
||||||
|
if isinstance(self.downgrade, Path):
|
||||||
|
return self.downgrade.read_text()
|
||||||
|
return self.downgrade
|
||||||
|
|
||||||
|
|
||||||
|
migrations = [
|
||||||
|
Migration(
|
||||||
|
PATH_THISFOLDER / "./mu1_create_permsys_table.sql",
|
||||||
|
PATH_THISFOLDER / "./md1_remove_permsys_table.sql",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
TARGET_VERSION = len(migrations)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_version(conn: aiosqlite.Connection) -> int:
|
||||||
|
cursor = await conn.execute(SQL_CHECK_EXISTS)
|
||||||
|
count = await cursor.fetchone()
|
||||||
|
assert count is not None
|
||||||
|
if count[0] < 1:
|
||||||
|
logger.info("权限系统数据表不存在,现在创建表")
|
||||||
|
await conn.executescript(SQL_CREATE_TABLE)
|
||||||
|
await conn.commit()
|
||||||
|
return 0
|
||||||
|
cursor = await conn.execute(SQL_GET_MIGRATE_VERSION)
|
||||||
|
row = await cursor.fetchone()
|
||||||
|
if row is None:
|
||||||
|
return 0
|
||||||
|
return row[0]
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_migration(
|
||||||
|
conn: aiosqlite.Connection,
|
||||||
|
version: int = TARGET_VERSION,
|
||||||
|
migrations: list[Migration] = migrations,
|
||||||
|
):
|
||||||
|
now_version = await get_current_version(conn)
|
||||||
|
while now_version < version:
|
||||||
|
migration = migrations[now_version]
|
||||||
|
await conn.executescript(migration.get_upgrade_script())
|
||||||
|
now_version += 1
|
||||||
|
await conn.execute(SQL_UPDATE_VERSION, (now_version,))
|
||||||
|
await conn.commit()
|
||||||
|
while now_version > version:
|
||||||
|
migration = migrations[now_version - 1]
|
||||||
|
await conn.executescript(migration.get_downgrade_script())
|
||||||
|
now_version -= 1
|
||||||
|
await conn.execute(SQL_UPDATE_VERSION, (now_version,))
|
||||||
|
await conn.commit()
|
||||||
@ -0,0 +1,7 @@
|
|||||||
|
SELECT
|
||||||
|
COUNT(*)
|
||||||
|
FROM
|
||||||
|
sqlite_master
|
||||||
|
WHERE
|
||||||
|
type = 'table'
|
||||||
|
AND name = 'migrate_version'
|
||||||
@ -0,0 +1,3 @@
|
|||||||
|
CREATE TABLE migrate_version(version INT PRIMARY KEY);
|
||||||
|
INSERT INTO migrate_version(version)
|
||||||
|
VALUES(0);
|
||||||
4
konabot/common/permsys/migrates/get_migrate_version.sql
Normal file
4
konabot/common/permsys/migrates/get_migrate_version.sql
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
SELECT
|
||||||
|
version
|
||||||
|
FROM
|
||||||
|
migrate_version;
|
||||||
@ -0,0 +1,2 @@
|
|||||||
|
DROP TABLE IF EXISTS perm_entity;
|
||||||
|
DROP TABLE IF EXISTS perm_info;
|
||||||
30
konabot/common/permsys/migrates/mu1_create_permsys_table.sql
Normal file
30
konabot/common/permsys/migrates/mu1_create_permsys_table.sql
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
CREATE TABLE perm_entity(
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
platform TEXT NOT NULL,
|
||||||
|
entity_type TEXT NOT NULL,
|
||||||
|
external_id TEXT NOT NULL,
|
||||||
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE UNIQUE INDEX idx_perm_entity_lookup
|
||||||
|
ON perm_entity(platform, entity_type, external_id);
|
||||||
|
|
||||||
|
CREATE TABLE perm_info(
|
||||||
|
entity_id INTEGER NOT NULL,
|
||||||
|
config_key TEXT NOT NULL,
|
||||||
|
value BOOLEAN,
|
||||||
|
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
-- 联合主键
|
||||||
|
PRIMARY KEY (entity_id, config_key)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TRIGGER perm_entity_update AFTER UPDATE
|
||||||
|
ON perm_entity BEGIN
|
||||||
|
UPDATE perm_entity SET updated_at=CURRENT_TIMESTAMP WHERE id=old.id;
|
||||||
|
END;
|
||||||
|
CREATE TRIGGER perm_info_update AFTER UPDATE
|
||||||
|
ON perm_info BEGIN
|
||||||
|
UPDATE perm_info SET updated_at=CURRENT_TIMESTAMP WHERE entity_id=old.entity_id;
|
||||||
|
END;
|
||||||
|
|
||||||
@ -0,0 +1,2 @@
|
|||||||
|
UPDATE migrate_version
|
||||||
|
SET version = ?;
|
||||||
242
konabot/common/permsys/repo.py
Normal file
242
konabot/common/permsys/repo.py
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
import math
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
|
from konabot.common.pager import PagerQuery, PagerResult
|
||||||
|
|
||||||
|
from .entity import PermEntity
|
||||||
|
|
||||||
|
|
||||||
|
def s(p: str):
|
||||||
|
"""读取 SQL 文件内容。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
p: SQL 文件名(相对于当前文件所在目录的 sql/ 子目录)。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SQL 文件的内容字符串。
|
||||||
|
"""
|
||||||
|
return (Path(__file__).parent / "./sql/" / p).read_text()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PermRepo:
|
||||||
|
"""权限实体存储库,负责与数据库交互管理权限实体。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
conn: aiosqlite 数据库连接对象。
|
||||||
|
"""
|
||||||
|
|
||||||
|
conn: aiosqlite.Connection
|
||||||
|
|
||||||
|
async def create_entity(self, entity: PermEntity) -> int:
|
||||||
|
"""创建新的权限实体并返回其 ID。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entity: 要创建的权限实体对象。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
新创建实体的数据库 ID。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: 如果创建后无法获取实体 ID。
|
||||||
|
"""
|
||||||
|
await self.conn.execute(
|
||||||
|
s("create_entity.sql"),
|
||||||
|
(entity.platform, entity.entity_type, entity.external_id),
|
||||||
|
)
|
||||||
|
await self.conn.commit()
|
||||||
|
eid = await self._get_entity_id_or_none(entity)
|
||||||
|
assert eid is not None
|
||||||
|
return eid
|
||||||
|
|
||||||
|
async def _get_entity_id_or_none(self, entity: PermEntity) -> int | None:
|
||||||
|
"""查询实体 ID,如果不存在则返回 None。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entity: 要查询的权限实体对象。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
实体 ID,如果不存在则返回 None。
|
||||||
|
"""
|
||||||
|
res = await self.conn.execute(
|
||||||
|
s("get_entity_id.sql"),
|
||||||
|
(entity.platform, entity.entity_type, entity.external_id),
|
||||||
|
)
|
||||||
|
row = await res.fetchone()
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
return row[0]
|
||||||
|
|
||||||
|
async def get_entity_id(self, entity: PermEntity) -> int:
|
||||||
|
"""获取实体 ID,如果不存在则自动创建。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entity: 权限实体对象。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
实体的数据库 ID。
|
||||||
|
"""
|
||||||
|
eid = await self._get_entity_id_or_none(entity)
|
||||||
|
if eid is None:
|
||||||
|
return await self.create_entity(entity)
|
||||||
|
return eid
|
||||||
|
|
||||||
|
async def get_perm_info(self, entity: PermEntity, config_key: str) -> bool | None:
|
||||||
|
"""获取实体的权限配置信息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entity: 权限实体对象。
|
||||||
|
config_key: 配置项的键名。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
配置值(True/False),如果不存在则返回 None。
|
||||||
|
"""
|
||||||
|
eid = await self.get_entity_id(entity)
|
||||||
|
res = await self.conn.execute(
|
||||||
|
s("get_perm_info.sql"),
|
||||||
|
(eid, config_key),
|
||||||
|
)
|
||||||
|
row = await res.fetchone()
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
return bool(row[0])
|
||||||
|
|
||||||
|
async def update_perm_info(
|
||||||
|
self, entity: PermEntity, config_key: str, value: bool | None
|
||||||
|
):
|
||||||
|
"""更新实体的权限配置信息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entity: 权限实体对象。
|
||||||
|
config_key: 配置项的键名。
|
||||||
|
value: 要设置的配置值(True/False/None)。
|
||||||
|
"""
|
||||||
|
eid = await self.get_entity_id(entity)
|
||||||
|
await self.conn.execute(s("update_perm_info.sql"), (eid, config_key, value))
|
||||||
|
await self.conn.commit()
|
||||||
|
|
||||||
|
async def get_entity_id_batch(
|
||||||
|
self, entities: list[PermEntity]
|
||||||
|
) -> dict[PermEntity, int]:
|
||||||
|
"""批量获取 Entity 的 entity_id
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entities: PermEntity 列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
字典,键为 PermEntity,值为对应的 ID
|
||||||
|
"""
|
||||||
|
|
||||||
|
# for entity in entities:
|
||||||
|
# await self.conn.execute(
|
||||||
|
# s("create_entity.sql"),
|
||||||
|
# (entity.platform, entity.entity_type, entity.external_id),
|
||||||
|
# )
|
||||||
|
await self.conn.executemany(
|
||||||
|
s("create_entity.sql"),
|
||||||
|
[(e.platform, e.entity_type, e.external_id) for e in entities],
|
||||||
|
)
|
||||||
|
await self.conn.commit()
|
||||||
|
val_placeholders = ", ".join(["(?, ?, ?)"] * len(entities))
|
||||||
|
params = []
|
||||||
|
for e in entities:
|
||||||
|
params.extend([e.platform, e.entity_type, e.external_id])
|
||||||
|
cursor = await self.conn.execute(
|
||||||
|
f"""
|
||||||
|
SELECT id, platform, entity_type, external_id
|
||||||
|
FROM perm_entity
|
||||||
|
WHERE (platform, entity_type, external_id) IN (VALUES {val_placeholders});
|
||||||
|
""",
|
||||||
|
params,
|
||||||
|
)
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
return {PermEntity(row[1], row[2], row[3]): row[0] for row in rows}
|
||||||
|
|
||||||
|
async def get_perm_info_batch(
|
||||||
|
self, entities: list[PermEntity], config_keys: list[str]
|
||||||
|
) -> dict[tuple[PermEntity, str], bool]:
|
||||||
|
"""批量获取权限信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entities: PermEntity 列表
|
||||||
|
config_keys: 查询的键列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
字典,键是 PermEntity 和 config_key 的元组,值是布尔,过滤掉所有空值
|
||||||
|
"""
|
||||||
|
entity_ids = {
|
||||||
|
v: k for k, v in (await self.get_entity_id_batch(entities)).items()
|
||||||
|
}
|
||||||
|
placeholders1 = ", ".join("?" * len(entity_ids))
|
||||||
|
placeholders2 = ", ".join("?" * len(config_keys))
|
||||||
|
sql = f"""
|
||||||
|
SELECT entity_id, config_key, value
|
||||||
|
FROM perm_info
|
||||||
|
WHERE entity_id IN ({placeholders1})
|
||||||
|
AND config_key IN ({placeholders2})
|
||||||
|
AND value IS NOT NULL;
|
||||||
|
"""
|
||||||
|
|
||||||
|
params = tuple(entity_ids.keys()) + tuple(config_keys)
|
||||||
|
cursor = await self.conn.execute(sql, params)
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
|
||||||
|
return {(entity_ids[row[0]], row[1]): bool(row[2]) for row in rows}
|
||||||
|
|
||||||
|
async def list_perm_info_batch(
|
||||||
|
self, entities: list[PermEntity], pager: PagerQuery
|
||||||
|
) -> PagerResult[tuple[PermEntity, str, bool]]:
|
||||||
|
"""批量获取某个实体的权限信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
entities: PermEntity 列表
|
||||||
|
pager: PagerQuery 对象,即分页要求
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
字典,键是 PermEntity,值是权限条目和布尔的元组,过滤掉所有空值
|
||||||
|
"""
|
||||||
|
entity_to_id = await self.get_entity_id_batch(entities)
|
||||||
|
id_to_entity = {v: k for k, v in entity_to_id.items()}
|
||||||
|
ordered_ids = [entity_to_id[e] for e in entities if e in entity_to_id]
|
||||||
|
|
||||||
|
placeholders = ", ".join("?" * len(ordered_ids))
|
||||||
|
order_by_cases = " ".join([f"WHEN ? THEN {i}" for i in range(len(ordered_ids))])
|
||||||
|
|
||||||
|
pagecount_sql = f"SELECT COUNT(*) FROM perm_info WHERE entity_id IN ({placeholders}) AND value IS NOT NULL;"
|
||||||
|
count_cursor = await self.conn.execute(pagecount_sql, tuple(ordered_ids))
|
||||||
|
total_count = (await count_cursor.fetchone() or (0,))[0]
|
||||||
|
|
||||||
|
sql = f"""
|
||||||
|
SELECT entity_id, config_key, value
|
||||||
|
FROM perm_info
|
||||||
|
WHERE entity_id IN ({placeholders})
|
||||||
|
AND value IS NOT NULL
|
||||||
|
ORDER BY
|
||||||
|
(CASE entity_id {order_by_cases} END) ASC,
|
||||||
|
config_key ASC
|
||||||
|
LIMIT ?
|
||||||
|
OFFSET ?;
|
||||||
|
"""
|
||||||
|
|
||||||
|
params = (
|
||||||
|
tuple(ordered_ids)
|
||||||
|
+ tuple(ordered_ids)
|
||||||
|
+ (
|
||||||
|
pager.page_size,
|
||||||
|
(pager.page_index - 1) * pager.page_size,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cursor = await self.conn.execute(sql, params)
|
||||||
|
rows = await cursor.fetchall()
|
||||||
|
|
||||||
|
# return {entity_ids[row[0]]: (row[1], bool(row[2])) for row in rows}
|
||||||
|
return PagerResult(
|
||||||
|
data=[(id_to_entity[row[0]], row[1], row[2]) for row in rows],
|
||||||
|
success=True,
|
||||||
|
message="",
|
||||||
|
page_count=math.ceil(total_count / pager.page_size),
|
||||||
|
query=pager,
|
||||||
|
)
|
||||||
11
konabot/common/permsys/sql/create_entity.sql
Normal file
11
konabot/common/permsys/sql/create_entity.sql
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
INSERT
|
||||||
|
OR IGNORE INTO perm_entity(
|
||||||
|
platform,
|
||||||
|
entity_type,
|
||||||
|
external_id
|
||||||
|
)
|
||||||
|
VALUES(
|
||||||
|
?,
|
||||||
|
?,
|
||||||
|
?
|
||||||
|
);
|
||||||
8
konabot/common/permsys/sql/get_entity_id.sql
Normal file
8
konabot/common/permsys/sql/get_entity_id.sql
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
SELECT
|
||||||
|
id
|
||||||
|
FROM
|
||||||
|
perm_entity
|
||||||
|
WHERE
|
||||||
|
perm_entity.platform = ?
|
||||||
|
AND perm_entity.entity_type = ?
|
||||||
|
AND perm_entity.external_id = ?;
|
||||||
7
konabot/common/permsys/sql/get_perm_info.sql
Normal file
7
konabot/common/permsys/sql/get_perm_info.sql
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
SELECT
|
||||||
|
VALUE
|
||||||
|
FROM
|
||||||
|
perm_info
|
||||||
|
WHERE
|
||||||
|
entity_id = ?
|
||||||
|
AND config_key = ?;
|
||||||
4
konabot/common/permsys/sql/update_perm_info.sql
Normal file
4
konabot/common/permsys/sql/update_perm_info.sql
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
INSERT INTO perm_info (entity_id, config_key, value)
|
||||||
|
VALUES (?, ?, ?)
|
||||||
|
ON CONFLICT(entity_id, config_key)
|
||||||
|
DO UPDATE SET value=excluded.value;
|
||||||
212
konabot/docs/sys/konaperm.txt
Normal file
212
konabot/docs/sys/konaperm.txt
Normal file
@ -0,0 +1,212 @@
|
|||||||
|
# 指令介绍
|
||||||
|
|
||||||
|
`konaperm` - 用于查看和修改 Bot 内部权限系统记录的管理员指令
|
||||||
|
|
||||||
|
## 权限要求
|
||||||
|
|
||||||
|
只有拥有 `admin` 权限的主体才能使用本指令。
|
||||||
|
|
||||||
|
## 格式
|
||||||
|
|
||||||
|
```text
|
||||||
|
konaperm list <platform> <entity_type> <external_id> [page]
|
||||||
|
konaperm get <platform> <entity_type> <external_id> <perm>
|
||||||
|
konaperm set <platform> <entity_type> <external_id> <perm> <val>
|
||||||
|
```
|
||||||
|
|
||||||
|
## 子命令说明
|
||||||
|
|
||||||
|
### `list`
|
||||||
|
|
||||||
|
列出指定对象及其继承链上的显式权限记录,按分页输出。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
|
||||||
|
- `platform` 平台名,如 `ob11`、`discord`、`sys`
|
||||||
|
- `entity_type` 对象类型,如 `user`、`group`、`global`
|
||||||
|
- `external_id` 平台侧对象 ID;全局对象通常写 `global`
|
||||||
|
- `page` 页码,可省略,默认 `1`
|
||||||
|
|
||||||
|
### `get`
|
||||||
|
|
||||||
|
查询某个对象对指定权限的最终判断结果,并说明它是从哪一层继承来的。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
|
||||||
|
- `platform`
|
||||||
|
- `entity_type`
|
||||||
|
- `external_id`
|
||||||
|
- `perm` 权限键,如 `admin`、`plugin.xxx.use`
|
||||||
|
|
||||||
|
### `set`
|
||||||
|
|
||||||
|
为指定对象写入显式权限。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
|
||||||
|
- `platform`
|
||||||
|
- `entity_type`
|
||||||
|
- `external_id`
|
||||||
|
- `perm` 权限键
|
||||||
|
- `val` 设置值
|
||||||
|
|
||||||
|
`val` 支持以下写法:
|
||||||
|
|
||||||
|
- 允许:`y` `yes` `allow` `true` `t`
|
||||||
|
- 拒绝:`n` `no` `deny` `false` `f`
|
||||||
|
- 清除:`null` `none`
|
||||||
|
|
||||||
|
其中:
|
||||||
|
|
||||||
|
- 允许 表示显式授予该权限
|
||||||
|
- 拒绝 表示显式禁止该权限
|
||||||
|
- 清除 表示删除该层的显式设置,重新回退到继承链判断
|
||||||
|
|
||||||
|
## 对象格式
|
||||||
|
|
||||||
|
本指令操作的对象由三段组成:
|
||||||
|
|
||||||
|
```text
|
||||||
|
<platform>.<entity_type>.<external_id>
|
||||||
|
```
|
||||||
|
|
||||||
|
例如:
|
||||||
|
|
||||||
|
- `ob11.user.123456789`
|
||||||
|
- `ob11.group.987654321`
|
||||||
|
- `sys.global.global`
|
||||||
|
|
||||||
|
## 当前支持的 `PermEntity` 值
|
||||||
|
|
||||||
|
以下内容按当前实现整理,便于手工查询和设置权限。
|
||||||
|
|
||||||
|
### `sys`
|
||||||
|
|
||||||
|
- `sys.global.global`
|
||||||
|
|
||||||
|
这是系统总兜底对象。
|
||||||
|
|
||||||
|
### `ob11`
|
||||||
|
|
||||||
|
- `ob11.global.global`
|
||||||
|
- `ob11.group.<group_id>`
|
||||||
|
- `ob11.user.<user_id>`
|
||||||
|
|
||||||
|
常见场景:
|
||||||
|
|
||||||
|
- 给整个 OneBot V11 平台统一授权:`ob11.global.global`
|
||||||
|
- 给某个 QQ 群授权:`ob11.group.群号`
|
||||||
|
- 给某个 QQ 用户授权:`ob11.user.QQ号`
|
||||||
|
|
||||||
|
### `discord`
|
||||||
|
|
||||||
|
- `discord.global.global`
|
||||||
|
- `discord.guild.<guild_id>`
|
||||||
|
- `discord.channel.<channel_id>`
|
||||||
|
- `discord.user.<user_id>`
|
||||||
|
|
||||||
|
常见场景:
|
||||||
|
|
||||||
|
- 给整个 Discord 平台统一授权:`discord.global.global`
|
||||||
|
- 给某个服务器授权:`discord.guild.服务器ID`
|
||||||
|
- 给某个频道授权:`discord.channel.频道ID`
|
||||||
|
- 给某个用户授权:`discord.user.用户ID`
|
||||||
|
|
||||||
|
### `minecraft`
|
||||||
|
|
||||||
|
- `minecraft.global.global`
|
||||||
|
- `minecraft.server.<server_name>`
|
||||||
|
- `minecraft.player.<player_uuid_hex>`
|
||||||
|
|
||||||
|
常见场景:
|
||||||
|
|
||||||
|
- 给整个 Minecraft 平台统一授权:`minecraft.global.global`
|
||||||
|
- 给某个服务器授权:`minecraft.server.服务器名`
|
||||||
|
- 给某个玩家授权:`minecraft.player.玩家UUID的hex`
|
||||||
|
|
||||||
|
### `console`
|
||||||
|
|
||||||
|
- `console.global.global`
|
||||||
|
- `console.channel.<channel_id>`
|
||||||
|
- `console.user.<user_id>`
|
||||||
|
|
||||||
|
### 快速参考
|
||||||
|
|
||||||
|
```text
|
||||||
|
sys.global.global
|
||||||
|
|
||||||
|
ob11.global.global
|
||||||
|
ob11.group.<group_id>
|
||||||
|
ob11.user.<user_id>
|
||||||
|
|
||||||
|
discord.global.global
|
||||||
|
discord.guild.<guild_id>
|
||||||
|
discord.channel.<channel_id>
|
||||||
|
discord.user.<user_id>
|
||||||
|
|
||||||
|
minecraft.global.global
|
||||||
|
minecraft.server.<server_name>
|
||||||
|
minecraft.player.<player_uuid_hex>
|
||||||
|
|
||||||
|
console.global.global
|
||||||
|
console.channel.<channel_id>
|
||||||
|
console.user.<user_id>
|
||||||
|
```
|
||||||
|
|
||||||
|
## 权限继承
|
||||||
|
|
||||||
|
权限不是只看当前对象,还会按继承链回退。
|
||||||
|
|
||||||
|
例如对 `ob11.user.123456` 查询时,通常会从更具体的对象一路回退到:
|
||||||
|
|
||||||
|
1. 当前用户
|
||||||
|
2. 平台全局对象
|
||||||
|
3. 系统全局对象
|
||||||
|
|
||||||
|
权限键本身也支持逐级回退。比如查询 `plugin.demo.use` 时,可能依次命中:
|
||||||
|
|
||||||
|
1. `plugin.demo.use`
|
||||||
|
2. `plugin.demo`
|
||||||
|
3. `plugin`
|
||||||
|
4. `*`
|
||||||
|
|
||||||
|
所以 `get` 返回的结果可能来自更宽泛的权限键,或更上层的继承对象。
|
||||||
|
|
||||||
|
## 示例
|
||||||
|
|
||||||
|
```text
|
||||||
|
konaperm list ob11 user 123456
|
||||||
|
```
|
||||||
|
|
||||||
|
查看 `ob11.user.123456` 及其继承链上的权限记录第一页。
|
||||||
|
|
||||||
|
```text
|
||||||
|
konaperm get ob11 user 123456 admin
|
||||||
|
```
|
||||||
|
|
||||||
|
查看该用户最终是否拥有 `admin` 权限,以及命中来源。
|
||||||
|
|
||||||
|
```text
|
||||||
|
konaperm set ob11 user 123456 admin allow
|
||||||
|
```
|
||||||
|
|
||||||
|
显式授予该用户 `admin` 权限。
|
||||||
|
|
||||||
|
```text
|
||||||
|
konaperm set ob11 user 123456 admin deny
|
||||||
|
```
|
||||||
|
|
||||||
|
显式拒绝该用户 `admin` 权限。
|
||||||
|
|
||||||
|
```text
|
||||||
|
konaperm set ob11 user 123456 admin none
|
||||||
|
```
|
||||||
|
|
||||||
|
删除该用户这一层对 `admin` 的显式设置,恢复继承判断。
|
||||||
|
|
||||||
|
## 注意事项
|
||||||
|
|
||||||
|
- 这是系统级管理指令,误操作可能直接影响其他插件的权限控制。
|
||||||
|
- `list` 只列出显式记录;没有显示出来不代表最终一定无权限,可能是从上层继承。
|
||||||
|
- `get` 显示的是最终命中的结果,比 `list` 更适合排查“为什么有/没有某个权限”。
|
||||||
|
- 对 `admin` 或 `*` 这类高影响权限做修改前,建议先确认对象是否写对。
|
||||||
112
konabot/plugins/perm_manage/__init__.py
Normal file
112
konabot/plugins/perm_manage/__init__.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
from nonebot.adapters import Event
|
||||||
|
from nonebot.params import Depends
|
||||||
|
from nonebot_plugin_alconna import Alconna, Args, Subcommand, UniMessage, on_alconna
|
||||||
|
from konabot.common.pager import PagerQuery
|
||||||
|
from konabot.common.permsys import DepPermManager, require_permission
|
||||||
|
from konabot.common.permsys.entity import PermEntity, get_entity_chain_of_entity
|
||||||
|
|
||||||
|
|
||||||
|
cmd = on_alconna(
|
||||||
|
Alconna(
|
||||||
|
"konaperm",
|
||||||
|
Subcommand(
|
||||||
|
"list",
|
||||||
|
Args["platform", str],
|
||||||
|
Args["entity_type", str],
|
||||||
|
Args["external_id", str],
|
||||||
|
Args["page?", int],
|
||||||
|
),
|
||||||
|
Subcommand(
|
||||||
|
"get",
|
||||||
|
Args["platform", str],
|
||||||
|
Args["entity_type", str],
|
||||||
|
Args["external_id", str],
|
||||||
|
Args["perm", str],
|
||||||
|
),
|
||||||
|
Subcommand(
|
||||||
|
"set",
|
||||||
|
Args["platform", str],
|
||||||
|
Args["entity_type", str],
|
||||||
|
Args["external_id", str],
|
||||||
|
Args["perm", str],
|
||||||
|
Args["val", str],
|
||||||
|
),
|
||||||
|
),
|
||||||
|
rule=require_permission("admin"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_perm_entity_chain(platform: str, entity_type: str, external_id: str):
|
||||||
|
return get_entity_chain_of_entity(PermEntity(platform, entity_type, external_id))
|
||||||
|
|
||||||
|
|
||||||
|
_DepEntityChain = Annotated[list[PermEntity], Depends(_get_perm_entity_chain)]
|
||||||
|
|
||||||
|
|
||||||
|
def make_formatter(parent: PermEntity):
|
||||||
|
def _formatter(d: tuple[PermEntity, str, bool]):
|
||||||
|
permmark = {True: "[✅ ALLOW] ", False: "[❌ DENY] "}[d[2]]
|
||||||
|
inheritmark = ""
|
||||||
|
if parent != d[0]:
|
||||||
|
inheritmark = (
|
||||||
|
f"[继承自 {d[0].platform}.{d[0].entity_type}.{d[0].external_id}] "
|
||||||
|
)
|
||||||
|
return f"{permmark}{inheritmark}{d[1]}"
|
||||||
|
|
||||||
|
return _formatter
|
||||||
|
|
||||||
|
|
||||||
|
@cmd.assign("list")
|
||||||
|
async def list_permission(
|
||||||
|
pm: DepPermManager,
|
||||||
|
ec: _DepEntityChain,
|
||||||
|
event: Event,
|
||||||
|
page: int = 1,
|
||||||
|
):
|
||||||
|
pq = PagerQuery(page, 10)
|
||||||
|
data = await pm.list_permission(ec, pq)
|
||||||
|
msg = data.to_unimessage(make_formatter(ec[0]))
|
||||||
|
await msg.send(event)
|
||||||
|
|
||||||
|
|
||||||
|
@cmd.assign("get")
|
||||||
|
async def get_permission(
|
||||||
|
pm: DepPermManager,
|
||||||
|
ec: _DepEntityChain,
|
||||||
|
perm: str,
|
||||||
|
event: Event,
|
||||||
|
):
|
||||||
|
data = await pm.check_has_permission_info(ec, perm)
|
||||||
|
|
||||||
|
obj_s = f"{ec[0].platform}.{ec[0].entity_type}.{ec[0].external_id}"
|
||||||
|
|
||||||
|
if data is None:
|
||||||
|
await UniMessage.text(f"对象 {obj_s} 无 {perm} 权限记录").send(event)
|
||||||
|
return
|
||||||
|
pe, k, p = data
|
||||||
|
inheritmark = ""
|
||||||
|
if ec[0] != pe or k != perm:
|
||||||
|
inheritmark = (
|
||||||
|
f"继承自 {pe.platform}.{pe.entity_type}.{pe.external_id} 对 {k} 的设置,"
|
||||||
|
)
|
||||||
|
await UniMessage.text(f"{inheritmark}对象 {obj_s} 对 {perm} 的权限为 {p}").send(
|
||||||
|
event
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@cmd.assign("set")
|
||||||
|
async def set_permission(
|
||||||
|
pm: DepPermManager,
|
||||||
|
ec: _DepEntityChain,
|
||||||
|
perm: str,
|
||||||
|
val: str,
|
||||||
|
event: Event,
|
||||||
|
):
|
||||||
|
if any(i == val.lower() for i in ("y", "yes", "allow", "true", "t")):
|
||||||
|
await pm.update_permission(ec[0], perm, True)
|
||||||
|
elif any(i == val.lower() for i in ("n", "no", "deny", "false", "f")):
|
||||||
|
await pm.update_permission(ec[0], perm, False)
|
||||||
|
elif any(i == val.lower() for i in ("null", "none")):
|
||||||
|
await pm.update_permission(ec[0], perm, None)
|
||||||
|
await get_permission(pm, ec, perm, event)
|
||||||
2723
poetry.lock
generated
2723
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -34,6 +34,9 @@ dependencies = [
|
|||||||
"shapely (>=2.1.2,<3.0.0)",
|
"shapely (>=2.1.2,<3.0.0)",
|
||||||
"mcstatus (>=12.2.1,<13.0.0)",
|
"mcstatus (>=12.2.1,<13.0.0)",
|
||||||
"borax (>=4.1.3,<5.0.0)",
|
"borax (>=4.1.3,<5.0.0)",
|
||||||
|
"pytest (>=8.0.0,<9.0.0)",
|
||||||
|
"nonebug (>=0.4.3,<0.5.0)",
|
||||||
|
"pytest-cov (>=7.0.0,<8.0.0)",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
@ -52,8 +55,15 @@ priority = "primary"
|
|||||||
|
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = ["rust-just (>=1.43.0,<2.0.0)", "pytest-asyncio (>=1.3.0,<2.0.0)"]
|
||||||
"rust-just (>=1.43.0,<2.0.0)",
|
|
||||||
"pytest (>=9.0.1,<10.0.0)",
|
[tool.pytest.ini_options]
|
||||||
"pytest-asyncio (>=1.3.0,<2.0.0)"
|
testpaths = "tests"
|
||||||
]
|
python_files = "test_*.py"
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
asyncio_default_fixture_loop_scope = "session"
|
||||||
|
addopts = "--cov=./konabot/"
|
||||||
|
|
||||||
|
[tool.nonebot]
|
||||||
|
# plugin_dirs = ["konabot/plugins/"]
|
||||||
|
plugin_dirs = []
|
||||||
|
|||||||
28
tests/conftest.py
Normal file
28
tests/conftest.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
# 文件内容来源:
|
||||||
|
# https://nonebot.dev/docs/best-practice/testing/
|
||||||
|
# 保证 nonebug 测试框架正常运作
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import nonebot
|
||||||
|
from pytest_asyncio import is_async_test
|
||||||
|
from nonebot.adapters.console import Adapter as ConsoleAdapter
|
||||||
|
from nonebug import NONEBOT_START_LIFESPAN
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_collection_modifyitems(items: list[pytest.Item]):
|
||||||
|
pytest_asyncio_tests = (item for item in items if is_async_test(item))
|
||||||
|
session_scope_marker = pytest.mark.asyncio(loop_scope="session")
|
||||||
|
for async_test in pytest_asyncio_tests:
|
||||||
|
async_test.add_marker(session_scope_marker, append=False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
async def after_nonebot_init(after_nonebot_init: None):
|
||||||
|
driver = nonebot.get_driver()
|
||||||
|
driver.register_adapter(ConsoleAdapter)
|
||||||
|
|
||||||
|
nonebot.load_from_toml("pyproject.toml")
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_configure(config: pytest.Config):
|
||||||
|
config.stash[NONEBOT_START_LIFESPAN] = True
|
||||||
@ -1,4 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -12,13 +11,13 @@ from konabot.common.database import DatabaseManager
|
|||||||
async def test_database_manager():
|
async def test_database_manager():
|
||||||
"""测试数据库管理器的基本功能"""
|
"""测试数据库管理器的基本功能"""
|
||||||
# 创建临时数据库文件
|
# 创建临时数据库文件
|
||||||
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_file:
|
||||||
db_path = tmp_file.name
|
db_path = tmp_file.name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 初始化数据库管理器
|
# 初始化数据库管理器
|
||||||
db_manager = DatabaseManager(db_path)
|
db_manager = DatabaseManager(db_path)
|
||||||
|
|
||||||
# 创建测试表
|
# 创建测试表
|
||||||
create_table_sql = """
|
create_table_sql = """
|
||||||
CREATE TABLE IF NOT EXISTS test_users (
|
CREATE TABLE IF NOT EXISTS test_users (
|
||||||
@ -28,26 +27,27 @@ async def test_database_manager():
|
|||||||
);
|
);
|
||||||
"""
|
"""
|
||||||
await db_manager.execute(create_table_sql)
|
await db_manager.execute(create_table_sql)
|
||||||
|
|
||||||
# 插入测试数据
|
# 插入测试数据
|
||||||
insert_sql = "INSERT INTO test_users (name, email) VALUES (?, ?)"
|
insert_sql = "INSERT INTO test_users (name, email) VALUES (?, ?)"
|
||||||
await db_manager.execute(insert_sql, ("张三", "zhangsan@example.com"))
|
await db_manager.execute(insert_sql, ("张三", "zhangsan@example.com"))
|
||||||
await db_manager.execute(insert_sql, ("李四", "lisi@example.com"))
|
await db_manager.execute(insert_sql, ("李四", "lisi@example.com"))
|
||||||
|
|
||||||
# 查询数据
|
# 查询数据
|
||||||
select_sql = "SELECT * FROM test_users WHERE name = ?"
|
select_sql = "SELECT * FROM test_users WHERE name = ?"
|
||||||
results = await db_manager.query(select_sql, ("张三",))
|
results = await db_manager.query(select_sql, ("张三",))
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
assert results[0]["name"] == "张三"
|
assert results[0]["name"] == "张三"
|
||||||
assert results[0]["email"] == "zhangsan@example.com"
|
assert results[0]["email"] == "zhangsan@example.com"
|
||||||
|
|
||||||
# 测试使用Path对象
|
# 测试使用Path对象
|
||||||
results = await db_manager.query_by_sql_file(Path(__file__), ("李四",))
|
# results = await db_manager.query_by_sql_file(Path(__file__), ("李四",))
|
||||||
# 注意:这里只是测试参数传递,实际SQL文件内容不是有效的SQL
|
# 注意:这里只是测试参数传递,实际SQL文件内容不是有效的SQL
|
||||||
|
## ^^^ 卧了个槽的坏枪,你让 AI 写单元测试不检查一下吗
|
||||||
|
|
||||||
# 关闭所有连接
|
# 关闭所有连接
|
||||||
await db_manager.close_all_connections()
|
await db_manager.close_all_connections()
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# 清理临时文件
|
# 清理临时文件
|
||||||
if os.path.exists(db_path):
|
if os.path.exists(db_path):
|
||||||
@ -58,13 +58,13 @@ async def test_database_manager():
|
|||||||
async def test_execute_script():
|
async def test_execute_script():
|
||||||
"""测试执行SQL脚本功能"""
|
"""测试执行SQL脚本功能"""
|
||||||
# 创建临时数据库文件
|
# 创建临时数据库文件
|
||||||
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_file:
|
||||||
db_path = tmp_file.name
|
db_path = tmp_file.name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 初始化数据库管理器
|
# 初始化数据库管理器
|
||||||
db_manager = DatabaseManager(db_path)
|
db_manager = DatabaseManager(db_path)
|
||||||
|
|
||||||
# 创建测试表的脚本
|
# 创建测试表的脚本
|
||||||
script = """
|
script = """
|
||||||
CREATE TABLE IF NOT EXISTS test_products (
|
CREATE TABLE IF NOT EXISTS test_products (
|
||||||
@ -75,19 +75,19 @@ async def test_execute_script():
|
|||||||
INSERT INTO test_products (name, price) VALUES ('苹果', 5.0);
|
INSERT INTO test_products (name, price) VALUES ('苹果', 5.0);
|
||||||
INSERT INTO test_products (name, price) VALUES ('香蕉', 3.0);
|
INSERT INTO test_products (name, price) VALUES ('香蕉', 3.0);
|
||||||
"""
|
"""
|
||||||
|
|
||||||
await db_manager.execute_script(script)
|
await db_manager.execute_script(script)
|
||||||
|
|
||||||
# 查询数据
|
# 查询数据
|
||||||
results = await db_manager.query("SELECT * FROM test_products ORDER BY name")
|
results = await db_manager.query("SELECT * FROM test_products ORDER BY name")
|
||||||
assert len(results) == 2
|
assert len(results) == 2
|
||||||
assert results[0]["name"] == "苹果"
|
assert results[0]["name"] == "苹果"
|
||||||
assert results[1]["name"] == "香蕉"
|
assert results[1]["name"] == "香蕉"
|
||||||
|
|
||||||
# 关闭所有连接
|
# 关闭所有连接
|
||||||
await db_manager.close_all_connections()
|
await db_manager.close_all_connections()
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# 清理临时文件
|
# 清理临时文件
|
||||||
if os.path.exists(db_path):
|
if os.path.exists(db_path):
|
||||||
os.unlink(db_path)
|
os.unlink(db_path)
|
||||||
|
|||||||
105
tests/test_permsys.py
Normal file
105
tests/test_permsys.py
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from konabot.common.database import DatabaseManager
|
||||||
|
from konabot.common.permsys import PermManager
|
||||||
|
from konabot.common.permsys.entity import PermEntity
|
||||||
|
from konabot.common.permsys.migrates import execute_migration, get_current_version
|
||||||
|
from konabot.common.permsys.repo import PermRepo
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def tempdb():
|
||||||
|
with TemporaryDirectory() as _tempdir:
|
||||||
|
tempdir = Path(_tempdir)
|
||||||
|
db = DatabaseManager(tempdir / "perm.sqlite3")
|
||||||
|
yield db
|
||||||
|
await db.close_all_connections()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_db_version():
|
||||||
|
async with tempdb() as db:
|
||||||
|
async with db.get_conn() as conn:
|
||||||
|
v = await get_current_version(conn)
|
||||||
|
assert v == 0
|
||||||
|
v = await get_current_version(conn)
|
||||||
|
assert v == 0
|
||||||
|
await execute_migration(conn, version=1)
|
||||||
|
v = await get_current_version(conn)
|
||||||
|
assert v == 1
|
||||||
|
await execute_migration(conn, version=0)
|
||||||
|
v = await get_current_version(conn)
|
||||||
|
assert v == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_perm():
|
||||||
|
async with tempdb() as db:
|
||||||
|
async with db.get_conn() as conn:
|
||||||
|
await execute_migration(conn)
|
||||||
|
|
||||||
|
service = PermManager(db)
|
||||||
|
entity_global = PermEntity("sys", "global", "global")
|
||||||
|
entity1 = PermEntity("nonexist-platform", "user", "passthem")
|
||||||
|
chain1 = [entity1, entity_global]
|
||||||
|
entity2 = PermEntity("nonexist-platform", "user", "jack")
|
||||||
|
chain2 = [entity2, entity_global]
|
||||||
|
|
||||||
|
async with db.get_conn() as conn:
|
||||||
|
repo = PermRepo(conn)
|
||||||
|
|
||||||
|
# 测试使用内置方法会创建 Entity 在数据库
|
||||||
|
assert await repo._get_entity_id_or_none(entity1) is None
|
||||||
|
assert await repo.get_entity_id(entity1) is not None
|
||||||
|
assert await repo._get_entity_id_or_none(entity1) is not None
|
||||||
|
|
||||||
|
# 测试使用内置方法获得 perm_info
|
||||||
|
assert await repo.get_perm_info(entity1, "module1") is None
|
||||||
|
|
||||||
|
assert not await service.check_has_permission(chain1, "*")
|
||||||
|
|
||||||
|
await service.update_permission(entity1, "*", True)
|
||||||
|
assert await service.check_has_permission(chain1, "*")
|
||||||
|
assert await service.check_has_permission(chain1, "module1")
|
||||||
|
assert await service.check_has_permission(chain1, "module1.pack1")
|
||||||
|
assert not await service.check_has_permission(chain2, "*")
|
||||||
|
assert not await service.check_has_permission(chain2, "module1")
|
||||||
|
assert not await service.check_has_permission(chain2, "module1.pack1")
|
||||||
|
|
||||||
|
await service.update_permission(entity2, "module1", True)
|
||||||
|
assert not await service.check_has_permission(chain2, "*")
|
||||||
|
assert await service.check_has_permission(chain2, "module1")
|
||||||
|
assert await service.check_has_permission(chain2, "module1.pack1")
|
||||||
|
assert await service.check_has_permission(chain2, "module1.pack2")
|
||||||
|
assert not await service.check_has_permission(chain2, "module2")
|
||||||
|
assert not await service.check_has_permission(chain2, "module2.pack1")
|
||||||
|
assert not await service.check_has_permission(chain2, "module2.pack2")
|
||||||
|
|
||||||
|
await service.update_permission(entity2, "module1.pack2", False)
|
||||||
|
assert not await service.check_has_permission(chain2, "*")
|
||||||
|
assert await service.check_has_permission(chain2, "module1")
|
||||||
|
assert await service.check_has_permission(chain2, "module1.pack1")
|
||||||
|
assert not await service.check_has_permission(chain2, "module1.pack2")
|
||||||
|
assert not await service.check_has_permission(chain2, "module2")
|
||||||
|
assert not await service.check_has_permission(chain2, "module2.pack1")
|
||||||
|
assert not await service.check_has_permission(chain2, "module2.pack2")
|
||||||
|
|
||||||
|
await service.update_permission(entity_global, "module2", True)
|
||||||
|
assert not await service.check_has_permission(chain2, "*")
|
||||||
|
assert await service.check_has_permission(chain2, "module1")
|
||||||
|
assert await service.check_has_permission(chain2, "module1.pack1")
|
||||||
|
assert not await service.check_has_permission(chain2, "module1.pack2")
|
||||||
|
assert await service.check_has_permission(chain2, "module2")
|
||||||
|
assert await service.check_has_permission(chain2, "module2.pack1")
|
||||||
|
assert await service.check_has_permission(chain2, "module2.pack2")
|
||||||
|
|
||||||
|
assert not await service.check_has_permission(entity2, "module2.pack2")
|
||||||
|
assert await service.check_has_permission(entity_global, "module2.pack2")
|
||||||
|
|
||||||
|
async with db.get_conn() as conn:
|
||||||
|
repo = PermRepo(conn)
|
||||||
|
assert await repo.get_perm_info(entity2, "module1") is True
|
||||||
|
assert await repo.get_perm_info(entity2, "module1.pack2") is False
|
||||||
Reference in New Issue
Block a user