Compare commits
147 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c2161635a8 | |||
| f21b7067df | |||
| f7212d6f67 | |||
| b87c58485c | |||
| 51c0bf4229 | |||
| 5b1c6d446c | |||
| 717b7a95e8 | |||
| 9bac2b8cdf | |||
| bfb8ebab29 | |||
| 805e60a9ff | |||
| 1331f8f893 | |||
| 00f42dbdf1 | |||
| d37c4870d8 | |||
| 23b9f101b3 | |||
|
8c1651ad3d
|
|||
| ff60642c62 | |||
| 69b5908445 | |||
| a542ed1fd9 | |||
| e86a385448 | |||
| d4bb36a074 | |||
| 1a2a3c0468 | |||
| 67502cb932 | |||
| f9a312b80a | |||
| 1980f8a895 | |||
|
d273ed4b1a
|
|||
|
265e9cc583
|
|||
|
8f5061ba41
|
|||
|
b3c3c77f3c
|
|||
|
6a84ce2cd8
|
|||
|
392c699b33
|
|||
|
72e21cd9aa
|
|||
|
f3389ff2b9
|
|||
|
e59d3c2e4b
|
|||
|
31d19b7ec0
|
|||
|
c2f677911d
|
|||
|
f5b81319f8
|
|||
|
870e2383d8
|
|||
| 7e8fa45f36 | |||
|
abb864ec70
|
|||
|
b38dde1b70
|
|||
|
8f40572a38
|
|||
|
230705f689
|
|||
|
e605527900
|
|||
|
9064b31fe9
|
|||
|
27e53c7acd
|
|||
|
ca1db103b5
|
|||
|
7f1035ff43
|
|||
|
5e0e39bfc3
|
|||
|
88861f4264
|
|||
|
a1c9f9bccb
|
|||
|
f6601f807a
|
|||
|
f7cea196ec
|
|||
|
d4826e9e8b
|
|||
|
33934ef7b5
|
|||
|
f9f8ae4e67
|
|||
| 94db34037b | |||
|
df409a13a9
|
|||
|
34175e8c17
|
|||
| c66576e12b | |||
|
91769f93ae
|
|||
|
27841b8422
|
|||
|
48282ceb6c
|
|||
| 00c0202720 | |||
|
3ddf81e7de
|
|||
|
ba15841836
|
|||
|
014e9c9a71
|
|||
| 32cabc9452 | |||
|
19e83dea01
|
|||
|
9210f85300
|
|||
|
74052594c3
|
|||
|
31ad8dac3e
|
|||
|
c46b88060b
|
|||
|
02018cd11d
|
|||
|
d4cde42bdc
|
|||
|
58ff8f02da
|
|||
|
b32ddcaf38
|
|||
|
1eb7e62cfe
|
|||
|
c44e29a907
|
|||
|
24457ff7cd
|
|||
|
0d36bea3ca
|
|||
|
bf8504d432
|
|||
|
16a55ae69a
|
|||
|
3adbd38d65
|
|||
|
420630e35c
|
|||
|
36a564547c
|
|||
|
eb8bf16346
|
|||
| 67884f7133 | |||
| f18d94670e | |||
| 6e86a6987f | |||
| 9c9496efbd | |||
| 770d7567fb | |||
| 7026337a43 | |||
|
ef617e1c85
|
|||
|
bd71a8d75f
|
|||
| 605407549b | |||
| 5e01e086f2 | |||
| 1f887aeaf6 | |||
| 5de4b72a6b | |||
| 1861cd4f1a | |||
| 9148073095 | |||
|
ef3404b096
|
|||
| 14feae943e | |||
| 1d763dfc3c | |||
| a829f035b3 | |||
| 9904653cc6 | |||
| de04fcbec1 | |||
| 70e3565e44 | |||
| 6b10c99c7a | |||
| 54fae88914 | |||
|
cdfb822f42
|
|||
|
73aad89f57
|
|||
|
e1b5f9cfc9
|
|||
| 35f411fb3a | |||
| eed21e6223 | |||
| bf5c10b7a7 | |||
| 274ca0fa9a | |||
| c72cdd6a6b | |||
|
16b0451133
|
|||
|
cb34813c4b
|
|||
| 2de3be271e | |||
| f7d2168dac | |||
| 40be5ce335 | |||
| 8e6131473d | |||
|
26e10be4ec
|
|||
|
78bda5fc0a
|
|||
|
97658a6c56
|
|||
|
3fedc685a9
|
|||
|
d1a3e44c45
|
|||
|
f637778173
|
|||
|
145bfedf67
|
|||
|
61b9d733a5
|
|||
| ae59c20e2f | |||
| 0b7d21aeb0 | |||
|
d6ede3e6cd
|
|||
|
07ace8e6e9
|
|||
|
6f08c22b5b
|
|||
|
3e5c1941c8
|
|||
| f6e7dfcd93 | |||
| 1233677eea | |||
| 00bdb90e3c | |||
| 988965451b | |||
| f6fadb7226 | |||
| 0d540eea4c | |||
| f21da657db | |||
| a8a7b62f76 | |||
| 789500842c | |||
| 2f22f11d57 |
@ -13,7 +13,7 @@ steps:
|
||||
- name: submodules
|
||||
image: alpine/git
|
||||
commands:
|
||||
- git submodule update --init --recursive
|
||||
- git submodule update --init --recursive
|
||||
- name: 构建 Docker 镜像
|
||||
image: plugins/docker:latest
|
||||
privileged: true
|
||||
@ -30,7 +30,7 @@ steps:
|
||||
volumes:
|
||||
- name: docker-socket
|
||||
path: /var/run/docker.sock
|
||||
- name: 在容器中测试插件加载
|
||||
- name: 在容器中进行若干测试
|
||||
image: docker:dind
|
||||
privileged: true
|
||||
volumes:
|
||||
@ -38,6 +38,8 @@ steps:
|
||||
path: /var/run/docker.sock
|
||||
commands:
|
||||
- docker run --rm gitea.service.jazzwhom.top/mttu-developers/konabot:nightly-${DRONE_COMMIT_SHA} python scripts/test_plugin_load.py
|
||||
- docker run --rm gitea.service.jazzwhom.top/mttu-developers/konabot:nightly-${DRONE_COMMIT_SHA} python scripts/test_playwright.py
|
||||
- docker run --rm gitea.service.jazzwhom.top/mttu-developers/konabot:nightly-${DRONE_COMMIT_SHA} python -m pytest --cov=./konabot/ --cov-report term-missing:skip-covered
|
||||
- name: 发送构建结果到 ntfy
|
||||
image: parrazam/drone-ntfy
|
||||
when:
|
||||
@ -68,7 +70,7 @@ steps:
|
||||
- name: submodules
|
||||
image: alpine/git
|
||||
commands:
|
||||
- git submodule update --init --recursive
|
||||
- git submodule update --init --recursive
|
||||
- name: 构建并推送 Release Docker 镜像
|
||||
image: plugins/docker:latest
|
||||
privileged: true
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
ENVIRONMENT=dev
|
||||
PORT=21333
|
||||
|
||||
DATABASE_PATH="./data/database.db"
|
||||
ENABLE_CONSOLE=true
|
||||
|
||||
24
.gitignore
vendored
24
.gitignore
vendored
@ -1,4 +1,26 @@
|
||||
# 基本的数据文件,以及环境用文件
|
||||
/.env
|
||||
/data
|
||||
/pyrightconfig.json
|
||||
/pyrightconfig.toml
|
||||
/uv.lock
|
||||
|
||||
__pycache__
|
||||
# 缓存文件
|
||||
__pycache__
|
||||
/.ruff_cache
|
||||
/.pytest_cache
|
||||
/.mypy_cache
|
||||
/.black_cache
|
||||
|
||||
# 可能会偶然生成的 diff 文件
|
||||
/*.diff
|
||||
|
||||
# 代码覆盖报告
|
||||
/.coverage
|
||||
/.coverage.db
|
||||
/htmlcov
|
||||
|
||||
# 对手动创建虚拟环境的人
|
||||
/.venv
|
||||
/venv
|
||||
*.egg-info
|
||||
|
||||
3
.gitmodules
vendored
3
.gitmodules
vendored
@ -1,3 +1,6 @@
|
||||
[submodule "assets/lexicon/THUOCL"]
|
||||
path = assets/lexicon/THUOCL
|
||||
url = https://github.com/thunlp/THUOCL.git
|
||||
[submodule "assets/oracle"]
|
||||
path = assets/oracle
|
||||
url = https://gitea.service.jazzwhom.top/mttu-developers/oracle-source.git
|
||||
|
||||
6
.sqls.yml
Normal file
6
.sqls.yml
Normal file
@ -0,0 +1,6 @@
|
||||
lowercaseKeywords: false
|
||||
connections:
|
||||
- driver: sqlite
|
||||
dataSourceName: "./data/database.db"
|
||||
- driver: sqlite
|
||||
dataSourceName: "./data/perm.sqlite3"
|
||||
4
.vscode/settings.json
vendored
4
.vscode/settings.json
vendored
@ -1,3 +1,5 @@
|
||||
{
|
||||
"python.REPL.enableREPLSmartSend": false
|
||||
"python.REPL.enableREPLSmartSend": false,
|
||||
"python-envs.defaultEnvManager": "ms-python.python:poetry",
|
||||
"python-envs.defaultPackageManager": "ms-python.python:poetry"
|
||||
}
|
||||
188
AGENTS.md
Normal file
188
AGENTS.md
Normal file
@ -0,0 +1,188 @@
|
||||
# AGENTS.md
|
||||
|
||||
本文件面向两类协作者:
|
||||
|
||||
- 手写代码的人类朋友
|
||||
- 会在此仓库中协助开发的 AI Agents
|
||||
|
||||
这个项目以手写为主,欢迎协作,但请先理解这里的约束和结构,再开始改动。
|
||||
|
||||
## 项目定位
|
||||
|
||||
- 这是一个娱乐性质的、私域使用的 QQ Bot 项目。
|
||||
- 虽然主要用于熟人环境,但依然要按“不信任输入”的标准写代码。
|
||||
- 不要因为使用场景偏内部,就默认消息内容、安全边界、调用参数一定可靠。
|
||||
|
||||
## 基本原则
|
||||
|
||||
### 1. 默认不信任用户输入
|
||||
|
||||
所有来自聊天消息、命令参数、平台事件等的输入,都应视为不可信。
|
||||
|
||||
开发时至少注意以下几点:
|
||||
|
||||
- 不假设输入类型正确,先校验再使用。
|
||||
- 不假设输入长度合理,注意超长文本、大量参数、异常嵌套结构。
|
||||
- 不假设输入内容安全,避免直接拼接到文件路径、SQL、shell 参数、HTML 或模板中。
|
||||
- 不假设用户一定按预期使用命令,错误输入要能优雅失败。
|
||||
- 对任何外部请求、文件读写、渲染、执行型逻辑,都要先考虑滥用风险。
|
||||
|
||||
### 2. 优先保持现有风格
|
||||
|
||||
- 这是一个以人工维护为主的项目,改动应尽量贴近现有写法。
|
||||
- 除非有明确收益,不要为了“看起来更现代”而大规模重构。
|
||||
- 新增能力时,优先复用已有通用模块,而不是重复造轮子。
|
||||
|
||||
### 3. 小步修改,影响清晰
|
||||
|
||||
- 尽量做局部、明确、可解释的改动。
|
||||
- 修改插件时,避免顺手改动无关插件。
|
||||
- 如果要调整公共模块,先确认是否会影响大量插件行为。
|
||||
|
||||
## 仓库结构
|
||||
|
||||
### `konabot/`
|
||||
|
||||
核心代码目录。
|
||||
|
||||
#### `konabot/common/`
|
||||
|
||||
通用模块目录。
|
||||
|
||||
- 放置可复用的基础能力、工具模块、公共逻辑。
|
||||
- 如果某段逻辑可能被多个插件共享,应优先考虑放到这里。
|
||||
- 修改这里的代码时,要额外关注兼容性,因为它可能被很多插件依赖。
|
||||
|
||||
#### `konabot/docs/`
|
||||
|
||||
Bot 内部文档系统使用的文档目录。
|
||||
|
||||
- 这是给用户看的文档来源。
|
||||
- 文档会通过 `man` 指令被触发和展示。
|
||||
- 虽然文档文件通常使用 `.txt` 后缀,但内容可以按 markdown 风格书写。
|
||||
- `.md` 后缀文件会被忽略,因此 `.md` 更适合只留给仓库维护者阅读的附加说明。
|
||||
- 文档文件名就是用户查询时使用的指令名,应保持简洁、稳定、易理解。
|
||||
|
||||
补充说明:
|
||||
|
||||
- `konabot/docs/user/` 是直接面向用户检索的文档。
|
||||
- `konabot/docs/lib/` 更偏向维护者参考。
|
||||
- `konabot/docs/concepts/` 用于记录概念。
|
||||
- `konabot/docs/sys/` 用于特定范围可见的系统文档。
|
||||
|
||||
#### `konabot/plugins/`
|
||||
|
||||
插件目录。
|
||||
|
||||
- 插件数量很多,是本项目最主要的功能承载位置。
|
||||
- 插件可以是单文件,也可以是文件夹形式。
|
||||
- 新增插件或修改插件时,请先观察相邻插件的组织方式,再决定采用单文件还是目录结构。
|
||||
- 如果逻辑已经明显超出单文件可维护范围,应拆成目录插件,不要把一个文件堆得过大。
|
||||
|
||||
## 根目录文档
|
||||
|
||||
### `docs/`
|
||||
|
||||
仓库根目录下的 `docs/` 主要用于记录一些可以通用的模块说明和开发文档。
|
||||
|
||||
- 这里的内容主要面向开发和维护。
|
||||
- 适合放公共模块说明、集成说明、配置说明、开发笔记。
|
||||
- 不要把面向 `man` 指令直接展示给用户的文档放到这里;那类内容应放在 `konabot/docs/` 下。
|
||||
|
||||
## 对 AI Agents 的具体要求
|
||||
|
||||
如果你是 AI Agent,请遵守以下约定:
|
||||
|
||||
### 修改前
|
||||
|
||||
- 先阅读将要修改的文件以及相关上下文,不要只凭文件名猜用途。
|
||||
- 先判断目标逻辑属于公共模块、用户文档,还是某个具体插件。
|
||||
- 如果需求可以在局部完成,就不要扩大改动范围。
|
||||
|
||||
### 修改时
|
||||
|
||||
- 优先延续现有命名、目录结构和编码风格。
|
||||
- 不要因为“顺手”而批量格式化整个项目。
|
||||
- 不要擅自重命名大量文件、移动目录、替换现有架构。
|
||||
- 涉及用户输入、路径、网络、数据库、渲染时,主动补上必要的校验与防御。
|
||||
- 如果要新增 `konabot/common/` 或其他会被多处依赖的模块,优先考虑 NoneBot2 框架下的依赖注入方式,而不是把全局状态或硬编码依赖散落到调用方。
|
||||
- 写文档时,区分清楚是给 `man` 系统看的,还是给仓库维护者看的。
|
||||
|
||||
### 修改后
|
||||
|
||||
- 检查改动是否误伤其他插件或公共模块。
|
||||
- 如果新增了用户可见功能,考虑是否需要补充 `konabot/docs/` 下对应文档。
|
||||
- 如果新增或调整了通用能力,考虑是否需要补充根目录 `docs/` 下的说明。
|
||||
|
||||
## 插件开发建议
|
||||
|
||||
- 单个插件内部优先保持自洽,不要把特定业务逻辑过早抽成公共模块。
|
||||
- 当多个插件开始重复同类逻辑时,再考虑上移到 `konabot/common/`。
|
||||
- 插件应尽量对异常输入有稳定反馈,而不是直接抛出难理解的错误。
|
||||
- 如果插件会访问外部服务,要考虑超时、失败降级和返回内容校验。
|
||||
|
||||
### 最基本的用户交互书写建议
|
||||
|
||||
- 先用清晰、可收敛的规则匹配消息,再进入处理逻辑,不要一上来就在 handler 里兜底解析所有输入。
|
||||
- 在 handler 里尽早提取纯文本、拆分命令和参数,并对缺失参数、非法参数、异常格式给出稳定反馈。
|
||||
- 如果用户输入只允许有限枚举值,先定义允许集合,再进行归一化和校验。
|
||||
- 输出优先保持简单直接;能一句话说明问题时,不要返回难懂的异常堆栈或过度技术化提示。
|
||||
- 涉及渲染、网络请求、图片生成等较重操作时,先确认输入合理,再执行昂贵逻辑。
|
||||
- 如果插件只是做单一交互,优先保持 handler 简短,把渲染、请求、转换等逻辑拆成独立函数。
|
||||
- 倾向于使用 `UniMessage` / `UniMsg` 这一套消息抽象来组织收发消息,而不是把平台细节和文本拼接散落在各处。
|
||||
- 倾向于显式构造返回消息并发送,而不是大量依赖 NoneBot2 原生的 `.finish()` 作为主要输出路径,除非该场景确实更简单清晰。
|
||||
|
||||
### 关于公共能力的依赖方式
|
||||
|
||||
- 新建通用能力时,优先设计成可注入、可替换、可测试的接口。
|
||||
- 如果一个模块未来可能被多个插件依赖,优先考虑 NoneBot2 的依赖注入,而不是让调用方手动维护重复的初始化流程。
|
||||
- 除非确有必要,不要让插件直接依赖隐藏的全局副作用。
|
||||
- 如果使用单例、缓存或全局管理器,要明确其生命周期、并发行为以及关闭时机。
|
||||
|
||||
## 运行环境与部署限制
|
||||
|
||||
这个项目默认会跑在 Docker 环境里,修改功能时请先意识到运行环境不是一台“什么都有”的开发机。
|
||||
|
||||
### 容器环境
|
||||
|
||||
- 运行时基础镜像是 `python:3.13-slim`,不是完整桌面 Linux;很多系统库默认不存在。
|
||||
- 项目运行依赖 Playwright Chromium、字体库、图形相关库,以及部分额外二进制工具。
|
||||
- 构建阶段和运行阶段是分离的;不要假设在 builder 里装过的系统包,runtime 里也一定可用。
|
||||
- 额外制品目前通过多阶段构建放进镜像,例如 `typst`。
|
||||
|
||||
### Docker 相关要求
|
||||
|
||||
- 如果你新增的 Python 依赖背后还需要 Linux 动态库、字体、图形库、编译工具或其他系统包,必须同步检查并在 `Dockerfile` 中补齐。
|
||||
- 不要只让本地虚拟环境能跑;要默认以容器可运行作为完成标准之一。
|
||||
- 如果新功能依赖系统命令、共享库、浏览器能力或字体,请在提交说明里明确写出原因。
|
||||
- `.dockerignore` 当前会排除 `/.env`、`/.git`、`/data` 等内容;不要依赖这些文件被复制进镜像。
|
||||
- 关于额外制品的管理,优先先阅读根目录文档 `docs/artifact.md`;适合统一管理的二进制或外部资源,倾向于复用 `konabot/common/artifact.py`,而不是在各插件里各自处理下载和校验。
|
||||
|
||||
### 本地运行
|
||||
|
||||
- 本地开发可参考 `justfile`,当前主要入口是 `just watch`。
|
||||
- 如果你的改动影响启动方式、依赖准备方式或运行命令,记得同步更新对应文档或脚本。
|
||||
|
||||
## 分支与协作流程
|
||||
|
||||
- 本项目托管在个人 Gitea 实例:`https://gitea.service.jazzwhom.top/mttu-developers/konabot`。
|
||||
- 如果需要创建 Pull Request,优先倾向使用 `tea` CLI:`https://gitea.com/gitea/tea`。
|
||||
- Pull Request 创建后,当前主要会有自动机器人做初步评审,项目维护者会手动查看;不要催促立即合并,也不要默认会马上进主分支。
|
||||
- 如果当前是在仓库本体上直接开发、而不是在 fork 上工作,尽量提醒用户不要直接在主分支持续改动,优先使用功能分支。
|
||||
- 除非用户明确要求,否则不要擅自把改动直接合并到主分支。
|
||||
|
||||
## 文档编写建议
|
||||
|
||||
### 面向 `man` 的文档
|
||||
|
||||
- 放在 `konabot/docs/` 对应子目录。
|
||||
- 文件名直接对应用户查询名称。
|
||||
- 建议内容简洁,优先说明“做什么、怎么用、示例、注意事项”。
|
||||
- 使用 `.txt` 后缀;内容可以写成接近 markdown 的可读格式。
|
||||
|
||||
### 面向开发者的文档
|
||||
|
||||
- 放在仓库根目录 `docs/`。
|
||||
- 主要描述公共模块、配置方法、设计说明、维护经验。
|
||||
- 可以使用 `.md`。
|
||||
|
||||
22
Dockerfile
22
Dockerfile
@ -1,3 +1,16 @@
|
||||
FROM alpine:latest AS artifacts
|
||||
|
||||
RUN apk add --no-cache curl xz
|
||||
WORKDIR /tmp
|
||||
|
||||
RUN mkdir -p /artifacts
|
||||
RUN curl -L -o typst.tar.xz "https://github.com/typst/typst/releases/download/v0.14.2/typst-x86_64-unknown-linux-musl.tar.xz" \
|
||||
&& tar -xJf typst.tar.xz \
|
||||
&& mv typst-x86_64-unknown-linux-musl/typst /artifacts
|
||||
|
||||
RUN chmod -R +x /artifacts/
|
||||
|
||||
|
||||
FROM python:3.13-slim AS base
|
||||
|
||||
ENV VIRTUAL_ENV=/app/.venv \
|
||||
@ -18,11 +31,6 @@ RUN apt-get update && \
|
||||
fonts-noto-color-emoji \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN pip install --no-cache-dir playwright \
|
||||
&& python -m playwright install chromium \
|
||||
&& pip uninstall -y playwright
|
||||
|
||||
|
||||
|
||||
FROM base AS builder
|
||||
|
||||
@ -43,13 +51,17 @@ RUN uv sync --no-install-project
|
||||
FROM base AS runtime
|
||||
|
||||
COPY --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
|
||||
COPY --from=artifacts /artifacts/ /usr/local/bin/
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN python -m playwright install chromium
|
||||
|
||||
COPY bot.py pyproject.toml .env.prod .env.test ./
|
||||
COPY assets ./assets
|
||||
COPY scripts ./scripts
|
||||
COPY konabot ./konabot
|
||||
COPY tests ./tests
|
||||
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
|
||||
23
README.md
23
README.md
@ -71,6 +71,10 @@ code .
|
||||
|
||||
详见[konabot-web 配置文档](/docs/konabot-web.md)
|
||||
|
||||
#### 数据库配置
|
||||
|
||||
本项目使用SQLite作为数据库,默认数据库文件位于`./data/database.db`。可以通过设置`DATABASE_PATH`环境变量来指定其他位置。
|
||||
|
||||
### 运行
|
||||
|
||||
使用命令行手动启动 Bot:
|
||||
@ -91,3 +95,22 @@ poetry run python bot.py
|
||||
- [事件响应器](https://nonebot.dev/docs/tutorial/matcher)
|
||||
- [事件处理](https://nonebot.dev/docs/tutorial/handler)
|
||||
- [Alconna 插件](https://nonebot.dev/docs/best-practice/alconna/)
|
||||
|
||||
## 代码测试
|
||||
|
||||
本项目使用 pytest 进行自动化测试,你可以把你的测试代码放在 `./tests` 目录下。
|
||||
|
||||
使用命令行执行测试:
|
||||
|
||||
```bash
|
||||
poetry run just test
|
||||
```
|
||||
|
||||
使用命令行,在浏览器查看测试覆盖率报告:
|
||||
|
||||
```bash
|
||||
poetry run just coverage
|
||||
# 此时会打开一个 :8000 端口的 Web 服务器
|
||||
# 你可以在 http://localhost:8000 查看覆盖率报告
|
||||
# 在控制台使用 Ctrl+C 关闭这个 Web 服务器
|
||||
```
|
||||
|
||||
9856
assets/old_font/symtable.csv
Normal file
9856
assets/old_font/symtable.csv
Normal file
File diff suppressed because it is too large
Load Diff
1
assets/oracle
Submodule
1
assets/oracle
Submodule
Submodule assets/oracle added at 9f3c08c5d2
33
bot.py
33
bot.py
@ -7,9 +7,12 @@ from nonebot.adapters.discord import Adapter as DiscordAdapter
|
||||
from nonebot.adapters.minecraft import Adapter as MinecraftAdapter
|
||||
from nonebot.adapters.onebot.v11 import Adapter as OnebotAdapter
|
||||
|
||||
from konabot.common.appcontext import run_afterinit_functions
|
||||
from konabot.common.log import init_logger
|
||||
from konabot.common.nb.exc import BotExceptionMessage
|
||||
from konabot.common.path import LOG_PATH
|
||||
from konabot.common.database import get_global_db_manager
|
||||
|
||||
|
||||
dotenv.load_dotenv()
|
||||
env = os.environ.get("ENVIRONMENT", "prod")
|
||||
@ -20,19 +23,25 @@ env_enable_minecraft = os.environ.get("ENABLE_MINECRAFT", "none")
|
||||
|
||||
|
||||
def main():
|
||||
if env.upper() == 'DEBUG' or env.upper() == 'DEV':
|
||||
console_log_level = 'DEBUG'
|
||||
if env.upper() == "DEBUG" or env.upper() == "DEV":
|
||||
console_log_level = "DEBUG"
|
||||
else:
|
||||
console_log_level = 'INFO'
|
||||
init_logger(LOG_PATH, [
|
||||
BotExceptionMessage,
|
||||
], console_log_level=console_log_level)
|
||||
console_log_level = "INFO"
|
||||
init_logger(
|
||||
LOG_PATH,
|
||||
[
|
||||
BotExceptionMessage,
|
||||
],
|
||||
console_log_level=console_log_level,
|
||||
)
|
||||
|
||||
nonebot.init()
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
if (env != "prod" and env != "test" and env_enable_console.upper() != "FALSE") or (env_enable_console.upper() == "TRUE"):
|
||||
if (env != "prod" and env != "test" and env_enable_console.upper() != "FALSE") or (
|
||||
env_enable_console.upper() == "TRUE"
|
||||
):
|
||||
driver.register_adapter(ConsoleAdapter)
|
||||
|
||||
if env_enable_qq.upper() == "TRUE":
|
||||
@ -48,7 +57,17 @@ def main():
|
||||
nonebot.load_plugins("konabot/plugins")
|
||||
nonebot.load_plugin("nonebot_plugin_analysis_bilibili")
|
||||
|
||||
run_afterinit_functions()
|
||||
|
||||
# 注册关闭钩子
|
||||
@driver.on_shutdown
|
||||
async def _():
|
||||
# 关闭全局数据库管理器
|
||||
db_manager = get_global_db_manager()
|
||||
await db_manager.close_all_connections()
|
||||
|
||||
nonebot.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
26
docs/artifact.md
Normal file
26
docs/artifact.md
Normal file
@ -0,0 +1,26 @@
|
||||
# artifact 模块说明
|
||||
|
||||
`konabot/common/artifact.py` 用于管理项目运行过程中依赖的额外制品,尤其是二进制文件、外部工具和按平台区分的运行时资源。
|
||||
|
||||
## 适用场景
|
||||
|
||||
- 某个插件或公共模块依赖额外下载的可执行文件或二进制资源。
|
||||
- 依赖需要按操作系统或架构区分。
|
||||
- 希望在启动时统一检测、按需下载并校验哈希。
|
||||
|
||||
如果额外制品适合在镜像构建阶段直接打包进 Docker 镜像,也可以在 `Dockerfile` 中通过多阶段构建处理;但对于需要在运行环境按平台管理、懒下载或统一校验的资源,优先考虑复用 `artifact.py`。
|
||||
|
||||
## 推荐做法
|
||||
|
||||
- 新增额外制品时,先判断它更适合放进镜像构建阶段,还是更适合交给 `artifact.py` 管理。
|
||||
- 如果该资源会被多个插件或环境复用,倾向于统一通过 `ArtifactDepends` 和 `register_artifacts(...)` 管理。
|
||||
- 为下载资源提供稳定来源,并填写 `sha256` 校验值,不要只校验“能不能下载下来”。
|
||||
- 使用 `required_os` 和 `required_arch` 限制平台,避免无意义下载。
|
||||
- 需要代理时,确认其行为与当前 NoneBot2 配置兼容。
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 不要把是否存在额外制品的判断散落到多个插件里各自实现。
|
||||
- 不要跳过哈希校验,除非该资源确实无法提供稳定校验值,并且有明确理由。
|
||||
- 如果一个新能力除了额外制品,还依赖 Linux 动态库、字体、浏览器或系统命令,仍然需要同步检查并更新 `Dockerfile`。
|
||||
- 如果镜像构建和运行阶段都依赖该制品,要分别确认 builder 和 runtime 的可用性。
|
||||
223
docs/database.md
Normal file
223
docs/database.md
Normal file
@ -0,0 +1,223 @@
|
||||
# 数据库系统使用文档
|
||||
|
||||
本文档详细介绍了本项目中使用的异步数据库系统,包括其架构设计、使用方法和最佳实践。
|
||||
|
||||
## 系统概述
|
||||
|
||||
本项目的数据库系统基于 `aiosqlite` 库构建,提供了异步的 SQLite 数据库访问接口。系统主要特性包括:
|
||||
|
||||
1. **异步操作**:完全支持异步/await模式,适配NoneBot2框架
|
||||
2. **连接池**:内置连接池机制,提高数据库访问性能
|
||||
3. **参数化查询**:支持安全的参数化查询,防止SQL注入
|
||||
4. **SQL文件支持**:可以直接执行SQL文件中的脚本
|
||||
5. **类型支持**:支持 `pathlib.Path` 和 `str` 类型的路径参数
|
||||
|
||||
## 核心类和方法
|
||||
|
||||
### DatabaseManager 类
|
||||
|
||||
`DatabaseManager` 是数据库操作的核心类,提供了以下主要方法:
|
||||
|
||||
#### 初始化
|
||||
```python
|
||||
from konabot.common.database import DatabaseManager
|
||||
from pathlib import Path
|
||||
|
||||
# 使用默认数据库路径
|
||||
db = DatabaseManager()
|
||||
|
||||
# 指定了义数据库路径
|
||||
db = DatabaseManager("./data/myapp.db")
|
||||
db = DatabaseManager(Path("./data/myapp.db"))
|
||||
```
|
||||
|
||||
#### 查询操作
|
||||
```python
|
||||
# 执行查询语句并返回结果
|
||||
results = await db.query("SELECT * FROM users WHERE age > ?", (18,))
|
||||
|
||||
# 从SQL文件执行查询
|
||||
results = await db.query_by_sql_file("./sql/get_users.sql", (18,))
|
||||
```
|
||||
|
||||
#### 执行操作
|
||||
```python
|
||||
# 执行非查询语句
|
||||
await db.execute("INSERT INTO users (name, email) VALUES (?, ?)", ("张三", "zhangsan@example.com"))
|
||||
|
||||
# 执行SQL脚本(不带参数)
|
||||
await db.execute_script("""
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
email TEXT UNIQUE
|
||||
);
|
||||
INSERT INTO users (name, email) VALUES ('测试用户', 'test@example.com');
|
||||
""")
|
||||
|
||||
# 从SQL文件执行非查询语句
|
||||
await db.execute_by_sql_file("./sql/create_tables.sql")
|
||||
|
||||
# 带参数执行SQL文件
|
||||
await db.execute_by_sql_file("./sql/insert_user.sql", ("张三", "zhangsan@example.com"))
|
||||
|
||||
# 执行多条语句(每条语句使用相同参数)
|
||||
await db.execute_many("INSERT INTO users (name, email) VALUES (?, ?)", [
|
||||
("张三", "zhangsan@example.com"),
|
||||
("李四", "lisi@example.com"),
|
||||
("王五", "wangwu@example.com")
|
||||
])
|
||||
|
||||
# 从SQL文件执行多条语句(每条语句使用相同参数)
|
||||
await db.execute_many_values_by_sql_file("./sql/batch_insert.sql", [
|
||||
("张三", "zhangsan@example.com"),
|
||||
("李四", "lisi@example.com")
|
||||
])
|
||||
```
|
||||
|
||||
## SQL文件处理机制
|
||||
|
||||
### 单语句SQL文件
|
||||
```sql
|
||||
-- insert_user.sql
|
||||
INSERT INTO users (name, email) VALUES (?, ?);
|
||||
```
|
||||
|
||||
```python
|
||||
# 使用方式
|
||||
await db.execute_by_sql_file("./sql/insert_user.sql", ("张三", "zhangsan@example.com"))
|
||||
```
|
||||
|
||||
### 多语句SQL文件
|
||||
```sql
|
||||
-- setup.sql
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
email TEXT UNIQUE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS profiles (
|
||||
user_id INTEGER,
|
||||
age INTEGER,
|
||||
FOREIGN KEY (user_id) REFERENCES users(id)
|
||||
);
|
||||
```
|
||||
|
||||
```python
|
||||
# 使用方式
|
||||
await db.execute_by_sql_file("./sql/setup.sql")
|
||||
```
|
||||
|
||||
### 多语句带不同参数的SQL文件
|
||||
```sql
|
||||
-- batch_operations.sql
|
||||
INSERT INTO users (name, email) VALUES (?, ?);
|
||||
INSERT INTO profiles (user_id, age) VALUES (?, ?);
|
||||
```
|
||||
|
||||
```python
|
||||
# 使用方式
|
||||
await db.execute_by_sql_file("./sql/batch_operations.sql", [
|
||||
("张三", "zhangsan@example.com"), # 第一条语句的参数
|
||||
(1, 25) # 第二条语句的参数
|
||||
])
|
||||
```
|
||||
|
||||
## 最佳实践
|
||||
|
||||
### 1. 数据库表设计
|
||||
```sql
|
||||
-- 推荐的表设计实践
|
||||
CREATE TABLE IF NOT EXISTS example_table (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
```
|
||||
|
||||
### 2. SQL文件组织
|
||||
建议按照功能模块组织SQL文件:
|
||||
```
|
||||
plugin/
|
||||
├── sql/
|
||||
│ ├── create_tables.sql
|
||||
│ ├── insert_data.sql
|
||||
│ ├── update_data.sql
|
||||
│ └── query_data.sql
|
||||
└── __init__.py
|
||||
```
|
||||
|
||||
### 3. 错误处理
|
||||
```python
|
||||
try:
|
||||
results = await db.query("SELECT * FROM users WHERE id = ?", (user_id,))
|
||||
except Exception as e:
|
||||
logger.error(f"数据库查询失败: {e}")
|
||||
# 处理错误情况
|
||||
```
|
||||
|
||||
### 4. 连接管理
|
||||
```python
|
||||
# 在应用启动时初始化
|
||||
db_manager = DatabaseManager()
|
||||
|
||||
# 在应用关闭时清理连接
|
||||
async def shutdown():
|
||||
await db_manager.close_all_connections()
|
||||
```
|
||||
|
||||
## 高级特性
|
||||
|
||||
### 连接池配置
|
||||
```python
|
||||
class DatabaseManager:
|
||||
def __init__(self, db_path: Optional[Union[str, Path]] = None):
|
||||
# 连接池大小配置
|
||||
self._pool_size = 5 # 可根据需要调整
|
||||
```
|
||||
|
||||
### 事务支持
|
||||
```python
|
||||
# 通过execute方法的自动提交机制支持事务
|
||||
await db.execute("BEGIN TRANSACTION")
|
||||
try:
|
||||
await db.execute("INSERT INTO users (name) VALUES (?)", ("张三",))
|
||||
await db.execute("INSERT INTO profiles (user_id, age) VALUES (?, ?)", (1, 25))
|
||||
await db.execute("COMMIT")
|
||||
except Exception:
|
||||
await db.execute("ROLLBACK")
|
||||
raise
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **异步环境**:所有数据库操作都必须在异步环境中执行
|
||||
2. **参数安全**:始终使用参数化查询,避免SQL注入
|
||||
3. **资源管理**:确保在应用关闭时调用 `close_all_connections()`
|
||||
4. **SQL解析**:使用 `sqlparse` 库准确解析SQL语句,正确处理包含分号的字符串和注释
|
||||
5. **错误处理**:适当处理数据库操作可能抛出的异常
|
||||
|
||||
## 常见问题
|
||||
|
||||
### Q: 如何处理数据库约束错误?
|
||||
A: 确保SQL语句中的字段名正确引用,特别是保留字需要使用双引号包围:
|
||||
```sql
|
||||
CREATE TABLE air_conditioner (
|
||||
id VARCHAR(128) PRIMARY KEY,
|
||||
"on" BOOLEAN NOT NULL, -- 使用双引号包围保留字
|
||||
temperature REAL NOT NULL
|
||||
);
|
||||
```
|
||||
|
||||
### Q: 如何处理多个语句和参数的匹配?
|
||||
A: 当SQL文件包含多个语句时,参数应该是参数列表,每个语句对应一个参数元组:
|
||||
```python
|
||||
await db.execute_by_sql_file("./sql/batch.sql", [
|
||||
("参数1", "参数2"), # 第一个语句的参数
|
||||
("参数3", "参数4") # 第二个语句的参数
|
||||
])
|
||||
```
|
||||
|
||||
通过遵循这些指南和最佳实践,您可以充分利用本项目的异步数据库系统,构建高性能、安全的数据库应用。
|
||||
244
docs/permsys.md
Normal file
244
docs/permsys.md
Normal file
@ -0,0 +1,244 @@
|
||||
# 权限系统 `konabot.common.permsys`
|
||||
|
||||
本文档面向维护者,说明 `konabot/common/permsys` 模块的职责、数据模型、权限解析规则,以及在插件中接入的推荐方式。
|
||||
|
||||
## 模块目标
|
||||
|
||||
`permsys` 提供了一套简单的、可继承的权限系统,用于回答两个问题:
|
||||
|
||||
1. 某个事件对应的主体是谁。
|
||||
2. 该主体是否拥有某项权限。
|
||||
|
||||
它适合处理 bot 内部的功能开关、管理权限、平台级授权等场景。
|
||||
|
||||
当前模块由以下几部分组成:
|
||||
|
||||
- `konabot/common/permsys/__init__.py`
|
||||
- 暴露 `PermManager`、`DepPermManager`、`require_permission`
|
||||
- 负责数据库初始化、启动迁移、超级管理员默认授权
|
||||
- 提供 `register_default_allow_permission()` 用于注册“启动时默认放行”的权限键
|
||||
- `konabot/common/permsys/entity.py`
|
||||
- 定义 `PermEntity`
|
||||
- 将事件转换为可查询的实体链
|
||||
- `konabot/common/permsys/repo.py`
|
||||
- 封装 SQLite 读写
|
||||
- `konabot/common/permsys/migrates/`
|
||||
- 存放迁移 SQL
|
||||
- `konabot/common/permsys/sql/`
|
||||
- 存放查询与更新 SQL
|
||||
|
||||
## 核心概念
|
||||
|
||||
### 1. `PermEntity`
|
||||
|
||||
`PermEntity` 是权限系统中的最小主体标识:
|
||||
|
||||
```python
|
||||
PermEntity(platform: str, entity_type: str, external_id: str)
|
||||
```
|
||||
|
||||
示例:
|
||||
|
||||
- `PermEntity("sys", "global", "global")`
|
||||
- `PermEntity("ob11", "group", "123456")`
|
||||
- `PermEntity("ob11", "user", "987654")`
|
||||
|
||||
其中:
|
||||
|
||||
- `platform` 表示来源平台,如 `sys`、`ob11`、`discord`
|
||||
- `entity_type` 表示主体类型,如 `global`、`group`、`user`
|
||||
- `external_id` 表示平台侧的外部标识
|
||||
|
||||
### 2. 实体链
|
||||
|
||||
权限判断不是只看单个实体,而是看一条“实体链”。
|
||||
|
||||
以 `get_entity_chain_of_entity()` 为例,传入一个具体实体时,返回的链为:
|
||||
|
||||
```python
|
||||
[
|
||||
PermEntity(platform, entity_type, external_id),
|
||||
PermEntity(platform, "global", "global"),
|
||||
PermEntity("sys", "global", "global"),
|
||||
]
|
||||
```
|
||||
|
||||
这意味着权限会优先读取更具体的主体,再回退到平台全局,最后回退到系统全局。
|
||||
|
||||
`get_entity_chain(event)` 则会根据事件类型自动构造链。例如:
|
||||
|
||||
- OneBot V11 群消息:用户 -> 群 -> 平台全局 -> 系统全局
|
||||
- OneBot V11 私聊:用户 -> 平台全局 -> 系统全局
|
||||
- Discord 频道消息:用户/频道/服务器 -> 平台全局 -> 系统全局
|
||||
- Console:控制台用户/频道 -> 平台全局 -> 系统全局
|
||||
|
||||
注意:当前 `entity.py` 中的具体链顺序与字段命名应以实现为准;修改这里时要评估现有权限继承是否会被破坏。
|
||||
|
||||
### 3. 权限键
|
||||
|
||||
权限键使用点分结构,例如:
|
||||
|
||||
- `admin`
|
||||
- `plugin.weather`
|
||||
- `plugin.weather.use`
|
||||
|
||||
检查时会自动做前缀回退。以 `plugin.weather.use` 为例,查询顺序是:
|
||||
|
||||
1. `plugin.weather.use`
|
||||
2. `plugin.weather`
|
||||
3. `plugin`
|
||||
4. `*`
|
||||
|
||||
因此,`*` 可以看作兜底总权限。
|
||||
|
||||
## 权限解析规则
|
||||
|
||||
`PermManager.check_has_permission_info()` 的逻辑可以概括为:
|
||||
|
||||
1. 先把输入转换成实体链。
|
||||
2. 对权限键做逐级回退,同时追加 `*`。
|
||||
3. 在数据库中批量查出链上所有实体、所有候选键的显式记录。
|
||||
4. 按“实体越具体越优先、权限键越具体越优先”的顺序,返回第一条命中的记录。
|
||||
|
||||
若没有任何显式记录:
|
||||
|
||||
- `check_has_permission_info()` 返回 `None`
|
||||
- `check_has_permission()` 返回 `False`
|
||||
|
||||
这表示本系统默认是“未授权即拒绝”。
|
||||
|
||||
## 数据存储
|
||||
|
||||
模块使用 SQLite,默认数据库文件位于:
|
||||
|
||||
- `data/perm.sqlite3`
|
||||
|
||||
启动时会执行迁移:
|
||||
|
||||
- `create_startup()` 在 NoneBot 启动事件中调用 `execute_migration()`
|
||||
|
||||
权限值支持三态:
|
||||
|
||||
- `True`:显式允许
|
||||
- `False`:显式拒绝
|
||||
- `None`:删除/清空该层的显式设置,让判断重新回退到继承链
|
||||
|
||||
`repo.py` 中的 `update_perm_info()` 会将这个三态直接写入数据库。
|
||||
|
||||
## 超级管理员注入
|
||||
|
||||
在启动阶段,`create_startup()` 会读取 `konabot.common.nb.is_admin.cfg.admin_qq_account`,并为这些 QQ 账号写入:
|
||||
|
||||
```python
|
||||
PermEntity("ob11", "user", str(account)), "*", True
|
||||
```
|
||||
|
||||
也就是说,配置中的超级管理员会直接拥有全部权限。
|
||||
|
||||
此外,模块也支持插件在导入阶段通过 `register_default_allow_permission("some.key")` 注册默认放行的权限键;这些键会在启动时被写入到:
|
||||
|
||||
```python
|
||||
PermEntity("sys", "global", "global"), "some.key", True
|
||||
```
|
||||
|
||||
这适合“默认所有人可用,但仍希望后续能被权限系统单独关闭”的功能。
|
||||
|
||||
这属于启动时自动灌入的保底策略,不依赖手工授权命令。
|
||||
|
||||
## 在插件中使用
|
||||
|
||||
### 1. 直接做权限检查
|
||||
|
||||
```python
|
||||
from konabot.common.permsys import DepPermManager
|
||||
|
||||
|
||||
async def handler(pm: DepPermManager, event):
|
||||
ok = await pm.check_has_permission(event, "plugin.example.use")
|
||||
if not ok:
|
||||
return
|
||||
```
|
||||
|
||||
适合需要在处理流程中动态决定权限键的场景。
|
||||
|
||||
### 2. 挂到 Rule 上做准入控制
|
||||
|
||||
```python
|
||||
from nonebot_plugin_alconna import Alconna, on_alconna
|
||||
from konabot.common.permsys import require_permission
|
||||
|
||||
|
||||
cmd = on_alconna(
|
||||
Alconna("example"),
|
||||
rule=require_permission("plugin.example.use"),
|
||||
)
|
||||
```
|
||||
|
||||
适合命令入口明确、未通过时直接拦截的场景。
|
||||
|
||||
### 3. 更新权限
|
||||
|
||||
```python
|
||||
from konabot.common.permsys import DepPermManager
|
||||
from konabot.common.permsys.entity import PermEntity
|
||||
|
||||
|
||||
await pm.update_permission(
|
||||
PermEntity("ob11", "group", "123456"),
|
||||
"plugin.example.use",
|
||||
True,
|
||||
)
|
||||
```
|
||||
|
||||
建议只在专门的管理插件中开放写权限,避免普通功能插件到处分散改表。
|
||||
|
||||
## `perm_manage` 插件与本模块的关系
|
||||
|
||||
`konabot/plugins/perm_manage/__init__.py` 是本模块当前的管理入口,提供:
|
||||
|
||||
- `konaperm list`:列出实体链上已有的显式权限记录
|
||||
- `konaperm get`:查看某个权限最终命中的记录
|
||||
- `konaperm set`:写入 allow/deny/null
|
||||
|
||||
这个插件本身使用 `require_permission("admin")` 保护,因此只有拥有 `admin` 权限的主体才能管理权限。
|
||||
|
||||
## 接入建议
|
||||
|
||||
### 权限键命名
|
||||
|
||||
建议使用稳定、可扩展的分层键名:
|
||||
|
||||
- 推荐:`plugin.xxx`、`plugin.xxx.action`
|
||||
- 不推荐:含糊的单词或临时字符串
|
||||
|
||||
这样才能利用前缀回退机制做批量授权。
|
||||
|
||||
### 输入安全
|
||||
|
||||
虽然这个项目偏内部使用,但权限键、实体类型、外部 ID 仍然应视为不可信输入:
|
||||
|
||||
- 不要把聊天输入直接拼到 SQL 中
|
||||
- 不要让任意用户可随意构造高权限写入
|
||||
- 对可写命令至少做权限保护和必要校验
|
||||
|
||||
### 改动兼容性
|
||||
|
||||
以下改动都可能影响全局权限行为,修改前应充分评估:
|
||||
|
||||
- 更改实体链顺序
|
||||
- 更改默认兜底键 `*` 的语义
|
||||
- 更改 `None` 的处理方式
|
||||
- 更改启动时超级管理员注入逻辑
|
||||
|
||||
## 调试建议
|
||||
|
||||
- 先用 `konaperm get ...` 确认某个权限最终命中了哪一层
|
||||
- 再用 `konaperm list ...` 查看该实体链上有哪些显式记录
|
||||
- 若表现异常,检查是否是更上层实体或更宽泛权限键提前命中
|
||||
|
||||
## 相关文件
|
||||
|
||||
- `konabot/common/permsys/__init__.py`
|
||||
- `konabot/common/permsys/entity.py`
|
||||
- `konabot/common/permsys/repo.py`
|
||||
- `konabot/plugins/perm_manage/__init__.py`
|
||||
37
docs/subscribe.md
Normal file
37
docs/subscribe.md
Normal file
@ -0,0 +1,37 @@
|
||||
# subscribe 模块
|
||||
|
||||
一套统一的接口,让用户可以订阅一些延迟或者定时消息。
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from konabot.common.subscribe import register_poster_info, broadcast, PosterInfo
|
||||
from nonebot_plugin_alconna import UniMessage
|
||||
|
||||
# 注册了服务信息,用户可以用「查询可用订阅」指令了解可用的订阅清单。
|
||||
# 用户可以使用「订阅 某某服务通知」或者「订阅 某某服务」来订阅消息。
|
||||
# 如果用户在群聊发起订阅,则会在 QQ 群订阅,不然会在私聊订阅
|
||||
register_poster_info("某某服务通知", PosterInfo(
|
||||
aliases={"某某服务"},
|
||||
description="告诉你关于某某的最新资讯等信息",
|
||||
))
|
||||
|
||||
async def main():
|
||||
while True:
|
||||
# 这里的服务 channel 名字必须填写该服务的名字,不可以是 alias
|
||||
# 这会给所有订阅了该通道的用户发送「向大家发送纯文本通知」
|
||||
await broadcast("某某服务通知", "向大家发送纯文本通知")
|
||||
|
||||
# 也可以发送 UniMessage 对象,可以构造包含图片的通知等
|
||||
data = Path('image.png').read_bytes()
|
||||
await broadcast(
|
||||
"某某服务通知",
|
||||
UniMessage.text("很遗憾告诉大家,我们倒闭了:").image(raw=data),
|
||||
)
|
||||
|
||||
await asyncio.sleep(114.514)
|
||||
```
|
||||
|
||||
该模块的代码请查阅 `/konabot/common/subscribe/` 下的文件。
|
||||
5
justfile
5
justfile
@ -1,4 +1,9 @@
|
||||
watch:
|
||||
poetry run watchfiles bot.main . --filter scripts.watch_filter.filter
|
||||
|
||||
test:
|
||||
poetry run pytest --cov-report term-missing:skip-covered
|
||||
|
||||
coverage:
|
||||
poetry run pytest --cov-report html
|
||||
python -m http.server -d htmlcov
|
||||
|
||||
92
konabot/common/apis/ali_content_safety.py
Normal file
92
konabot/common/apis/ali_content_safety.py
Normal file
@ -0,0 +1,92 @@
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from alibabacloud_green20220302.client import Client as AlibabaGreenClient
|
||||
from alibabacloud_green20220302.models import TextModerationPlusRequest
|
||||
from alibabacloud_tea_openapi.models import Config as AlibabaTeaConfig
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
import nonebot
|
||||
|
||||
|
||||
class AlibabaGreenPluginConfig(BaseModel):
|
||||
module_aligreen_enable: bool = False
|
||||
module_aligreen_access_key_id: str = ""
|
||||
module_aligreen_access_key_secret: str = ""
|
||||
module_aligreen_region_id: str = "cn-shenzhen"
|
||||
module_aligreen_endpoint: str = "green-cip.cn-shenzhen.aliyuncs.com"
|
||||
module_aligreen_service: str = "llm_query_moderation"
|
||||
|
||||
|
||||
class AlibabaGreen:
    """Module-level wrapper around the Alibaba content-safety SDK client.

    `init()` must run (at startup) before `detect()` is called; when the
    module is disabled, detection is skipped and everything is allowed.
    """

    _client: AlibabaGreenClient | None = None
    _config: AlibabaGreenPluginConfig | None = None

    @staticmethod
    def get_client() -> AlibabaGreenClient:
        # init() is expected to have created the client already.
        assert AlibabaGreen._client is not None
        return AlibabaGreen._client

    @staticmethod
    def get_config() -> AlibabaGreenPluginConfig:
        assert AlibabaGreen._config is not None
        return AlibabaGreen._config

    @staticmethod
    def init():
        """Load plugin config and build the SDK client (no-op when disabled)."""
        cfg = nonebot.get_plugin_config(AlibabaGreenPluginConfig)
        AlibabaGreen._config = cfg
        if not cfg.module_aligreen_enable:
            logger.info("该环境未启用阿里内容审查,跳过初始化")
            return
        sdk_config = AlibabaTeaConfig(
            access_key_id=cfg.module_aligreen_access_key_id,
            access_key_secret=cfg.module_aligreen_access_key_secret,
            connect_timeout=10000,
            read_timeout=3000,
            region_id=cfg.module_aligreen_region_id,
            endpoint=cfg.module_aligreen_endpoint,
        )
        AlibabaGreen._client = AlibabaGreenClient(sdk_config)

    @staticmethod
    def _detect_sync(content: str) -> bool:
        """Return True when `content` is acceptable.

        Fails open: disabled module, API errors and exceptions all yield True;
        only an explicit "high" risk level yields False.
        """
        if not content:
            return True
        if not AlibabaGreen.get_config().module_aligreen_enable:
            logger.debug("该环境未启用阿里内容审查,直接跳过")
            return True

        client = AlibabaGreen.get_client()
        try:
            request = TextModerationPlusRequest(
                service=AlibabaGreen.get_config().module_aligreen_service,
                service_parameters=json.dumps({"content": content}),
            )
            response = client.text_moderation_plus(request)
            if response.status_code != 200:
                logger.error(f"检测违规内容 API 调用失败:{response}")
                return True
            result = response.body
            logger.info(f"检测违规内容 API 调用成功:{result}")
            risk_level: str = result.data.risk_level or "none"
            return risk_level != "high"
        except Exception as e:
            logger.error("检测违规内容 API 调用失败")
            logger.exception(e)
            return True

    @staticmethod
    async def detect(content: str) -> bool:
        """Async entry point: run the blocking SDK call in a worker thread."""
        return await asyncio.to_thread(AlibabaGreen._detect_sync, content)
|
||||
|
||||
|
||||
driver = nonebot.get_driver()


@driver.on_startup
async def _():
    # Build the moderation client once the bot is up.
    AlibabaGreen.init()
|
||||
|
||||
281
konabot/common/apis/wolfx.py
Normal file
281
konabot/common/apis/wolfx.py
Normal file
@ -0,0 +1,281 @@
|
||||
"""
|
||||
Wolfx 防灾免费 API
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Literal, TypeVar, cast
|
||||
import aiohttp
|
||||
from aiosignal import Signal
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, RootModel
|
||||
import pydantic
|
||||
|
||||
from konabot.common.appcontext import after_init
|
||||
|
||||
|
||||
class ScEewReport(BaseModel):
    """Sichuan Earthquake Administration EEW (early-warning) report."""

    ID: str  # EEW report id
    EventID: str  # EEW event id
    ReportTime: str  # report issue time (UTC+8)
    ReportNum: int  # report sequence number
    OriginTime: str  # origin time of the quake (UTC+8)
    HypoCenter: str  # hypocenter place name
    Latitude: float  # hypocenter latitude
    Longitude: float  # hypocenter longitude
    Magnitude: float  # magnitude
    Depth: float | None  # focal depth
    MaxIntensity: float  # maximum intensity
|
||||
|
||||
|
||||
class CencEewReport(BaseModel):
    """China Earthquake Networks Center EEW (early-warning) report."""

    ID: str  # EEW report id
    EventID: str  # EEW event id
    ReportTime: str  # report issue time (UTC+8)
    ReportNum: int  # report sequence number
    OriginTime: str  # origin time of the quake (UTC+8)
    HypoCenter: str  # hypocenter place name
    Latitude: float  # hypocenter latitude
    Longitude: float  # hypocenter longitude
    Magnitude: float  # magnitude
    Depth: float | None  # focal depth
    MaxIntensity: float  # maximum intensity
|
||||
|
||||
|
||||
class CencEqReport(BaseModel):
    """China Earthquake Networks Center catalogued earthquake entry."""

    type: str  # report type
    EventID: str  # event id
    time: str  # origin time (UTC+8)
    location: str  # place name
    magnitude: str  # magnitude
    depth: str  # focal depth
    latitude: str  # latitude
    # NOTE: misspelling kept on purpose — the field name must match the
    # API's JSON key for pydantic validation; do not rename.
    longtitude: str  # longitude
    intensity: str  # intensity
|
||||
|
||||
|
||||
class CencEqlist(RootModel):
    """Mapping of entry key -> CencEqReport, as the eqlist endpoint returns it."""

    root: dict[str, CencEqReport]
|
||||
|
||||
|
||||
class WolfxWebSocket:
    """Auto-reconnecting WebSocket client for one Wolfx API endpoint.

    Non-heartbeat payloads are forwarded, as raw bytes, to `signal`;
    heartbeats are logged and dropped.
    """

    def __init__(self, url: str) -> None:
        self.url = url
        self.signal: Signal[bytes] = Signal(self)
        self._running = False
        self._task: asyncio.Task | None = None
        self._session: aiohttp.ClientSession | None = None
        self._ws: aiohttp.ClientWebSocketResponse | None = None

    @property
    def session(self) -> aiohttp.ClientSession:  # pragma: no cover
        # start() must have created the session first.
        assert self._session is not None
        return self._session

    async def start(self):  # pragma: no cover
        """Spawn the background read loop (idempotent)."""
        if self._running:
            return
        self._running = True
        self._session = aiohttp.ClientSession()
        self._task = asyncio.create_task(self._run())
        self.signal.freeze()

    async def stop(self):  # pragma: no cover
        """Cancel the read loop and close the HTTP session."""
        self._running = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass
        if self._session:
            await self._session.close()

    async def _run(self):  # pragma: no cover
        """Connect/read loop with exponential backoff reconnect (capped at 60s)."""
        retry_delay = 1

        while self._running:
            try:
                async with self.session.ws_connect(self.url) as ws:
                    self._ws = ws
                    logger.info(f"Wolfx API 服务连接上了 {self.url} 的 WebSocket")
                    async for msg in ws:
                        if msg.type == aiohttp.WSMsgType.TEXT:
                            await self.handle(cast(str, msg.data).encode())
                        elif msg.type == aiohttp.WSMsgType.BINARY:
                            await self.handle(cast(bytes, msg.data))
                        elif msg.type in (
                            aiohttp.WSMsgType.CLOSED,
                            aiohttp.WSMsgType.ERROR,
                        ):
                            # Both terminal frame types end this connection.
                            break
            except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                logger.warning("连接 WebSocket 时发生错误")
                logger.exception(e)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error("Wolfx API 发生未知错误")
                logger.exception(e)
            self._ws = None

            if self._running:
                logger.info(f"Wolfx API 准备断线重连 {self.url}")
                await asyncio.sleep(retry_delay)
                retry_delay = min(retry_delay * 2, 60)

    async def handle(self, data: bytes):
        """Parse one frame; drop heartbeats, forward everything else to `signal`."""
        try:
            obj = json.loads(data)
        except json.JSONDecodeError as e:
            # Fixed log typo: "Wolfs" -> "Wolfx".
            logger.warning("解析 Wolfx API 时出错")
            logger.exception(e)
            return

        if obj.get("type") in ("heartbeat", "pong"):
            logger.debug(f"Wolfx API 收到了来自 {self.url} 的心跳: {obj}")
        else:
            await self.signal.send(data)
|
||||
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)


class WolfxAPIService:
    """Aggregates the three Wolfx earthquake feeds behind typed signals."""

    sc_eew: Signal[ScEewReport]
    # Sichuan Earthquake Administration EEW feed

    cenc_eew: Signal[CencEewReport]
    # China Earthquake Networks Center EEW feed

    cenc_eqlist: Signal[CencEqReport]
    # China Earthquake Networks Center earthquake-list feed

    def __init__(self) -> None:
        self.sc_eew = Signal(self)
        self._sc_eew_ws = WolfxWebSocket("wss://ws-api.wolfx.jp/sc_eew")
        WolfxAPIService.bind(self.sc_eew, self._sc_eew_ws, ScEewReport)

        self.cenc_eew = Signal(self)
        self._cenc_eew_ws = WolfxWebSocket("wss://ws-api.wolfx.jp/cenc_eew")
        WolfxAPIService.bind(self.cenc_eew, self._cenc_eew_ws, CencEewReport)

        self.cenc_eqlist = Signal(self)
        self._cenc_eqlist_ws = WolfxWebSocket("wss://ws-api.wolfx.jp/cenc_eqlist")
        WolfxAPIService.bind(self.cenc_eqlist, self._cenc_eqlist_ws, CencEqReport)

    @staticmethod
    def bind(signal: Signal[T], ws: WolfxWebSocket, t: type[T]):
        """Decode raw frames from `ws` as model `t` and re-emit them on `signal`."""

        async def _forward(data: bytes):
            try:
                report = t.model_validate_json(data)
                logger.info(f"接收到来自 Wolfx API 的信息:{data}")
                await signal.send(report)
            except pydantic.ValidationError as e:
                logger.warning(f"解析 Wolfx API 时出错 URL={ws.url}")
                logger.error(e)

        ws.signal.append(_forward)

    async def start(self):  # pragma: no cover
        """Freeze subscriber lists, then connect only the feeds with listeners."""
        self.cenc_eew.freeze()
        self.sc_eew.freeze()
        self.cenc_eqlist.freeze()
        feeds = (
            (self.cenc_eew, self._cenc_eew_ws),
            (self.sc_eew, self._sc_eew_ws),
            (self.cenc_eqlist, self._cenc_eqlist_ws),
        )
        async with asyncio.TaskGroup() as task_group:
            for sig, ws in feeds:
                if len(sig) > 0:
                    task_group.create_task(ws.start())

    async def stop(self):  # pragma: no cover
        """Disconnect all feeds concurrently."""
        async with asyncio.TaskGroup() as task_group:
            for ws in (self._cenc_eew_ws, self._sc_eew_ws, self._cenc_eqlist_ws):
                task_group.create_task(ws.stop())
|
||||
|
||||
|
||||
wolfx_api = WolfxAPIService()


@after_init
def init():  # pragma: no cover
    """Wire the Wolfx service into the bot's startup/shutdown lifecycle."""
    import nonebot

    driver = nonebot.get_driver()

    @driver.on_startup
    async def _():
        await wolfx_api.start()

    @driver.on_shutdown
    async def _():
        await wolfx_api.stop()
|
||||
15
konabot/common/appcontext.py
Normal file
15
konabot/common/appcontext.py
Normal file
@ -0,0 +1,15 @@
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
# Signature of a deferred-initialization callback.
AFTER_INIT_FUNCTION = Callable[[], Any]

# Callbacks registered via @after_init, run later by run_afterinit_functions().
_after_init_functions: list[AFTER_INIT_FUNCTION] = []


def after_init(func: AFTER_INIT_FUNCTION) -> AFTER_INIT_FUNCTION:
    """Register `func` to run after app init; usable as a decorator.

    Fixed: previously returned None, so `@after_init` rebound the decorated
    name to None. Returning `func` keeps the name usable.
    """
    _after_init_functions.append(func)
    return func


def run_afterinit_functions():  # pragma: no cover
    """Invoke every registered callback, in registration order."""
    for f in _after_init_functions:
        f()
|
||||
112
konabot/common/artifact.py
Normal file
112
konabot/common/artifact.py
Normal file
@ -0,0 +1,112 @@
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import hashlib
|
||||
import platform
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import nonebot
|
||||
from loguru import logger
|
||||
from nonebot.adapters.discord.config import Config as DiscordConfig
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArtifactDepends:
|
||||
url: str
|
||||
sha256: str
|
||||
target: Path
|
||||
|
||||
required_os: str | None = None
|
||||
"示例值:Windows, Linux, Darwin"
|
||||
|
||||
required_arch: str | None = None
|
||||
"示例值:AMD64, x86_64, arm64"
|
||||
|
||||
use_proxy: bool = True
|
||||
"网络问题,赫赫;使用的是 Discord 模块配置的 proxy"
|
||||
|
||||
def is_corresponding_platform(self) -> bool:
|
||||
if self.required_os is not None:
|
||||
if self.required_os.lower() != platform.system().lower():
|
||||
return False
|
||||
if self.required_arch is not None:
|
||||
if self.required_arch.lower() != platform.machine().lower():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class Config(BaseModel):
    """Artifact-module settings."""

    prefetch_artifact: bool = False  # download binary dependencies at startup
|
||||
|
||||
|
||||
# Artifacts registered via register_artifacts(), prefetched at startup.
artifact_list: list[ArtifactDepends] = []


driver = nonebot.get_driver()
config = nonebot.get_plugin_config(Config)


@driver.on_startup
async def _():
    """Prefetch all registered artifacts, at most 10 downloads in flight."""
    if not config.prefetch_artifact:
        return
    logger.info("启动检测中:正在检测需求的二进制是否下载")
    semaphore = asyncio.Semaphore(10)

    async def _task(artifact: ArtifactDepends):
        async with semaphore:
            await ensure_artifact(artifact)

    # gather() wraps the coroutines in tasks itself; constructing
    # asyncio.Task directly (as before) is discouraged by the asyncio docs.
    await asyncio.gather(*(_task(a) for a in artifact_list))
    logger.info("检测好了")
|
||||
|
||||
|
||||
async def download_artifact(artifact: ArtifactDepends):
    """Download `artifact` to its target path and verify its sha256.

    A hash mismatch or a non-200 response is logged as a warning, not raised —
    download is best-effort.
    """
    proxy = None
    if artifact.use_proxy:
        discord_config = nonebot.get_plugin_config(DiscordConfig)
        proxy = discord_config.discord_proxy

    if proxy is not None:
        logger.info(f"正在使用 Proxy 下载 TARGET={artifact.target} PROXY={proxy}")
    else:
        logger.info(f"正在下载 TARGET={artifact.target}")

    async with aiohttp.ClientSession(proxy=proxy) as client:
        # Use the response as a context manager so it is always released.
        async with client.get(artifact.url) as result:
            if result.status != 200:
                logger.warning(f"已经下载了二进制,但是注意服务器没有返回 200! URL={artifact.url} TARGET={artifact.target} CODE={result.status}")
            data = await result.read()

    artifact.target.write_bytes(data)
    if platform.system().lower() != 'windows':
        # Binaries need the executable bit on POSIX systems.
        artifact.target.chmod(0o755)

    logger.info(f"下载好了 TARGET={artifact.target} URL={artifact.url}")
    # Hash the bytes already in memory instead of re-reading the file.
    digest = hashlib.sha256(data).hexdigest()
    if digest.lower() != artifact.sha256.lower():
        logger.warning(f"下载到的二进制的 sha256 与需求不同 TARGET={artifact.target} REQUESTED={artifact.sha256} ACTUAL={digest}")
|
||||
|
||||
|
||||
async def ensure_artifact(artifact: ArtifactDepends):
    """Make sure `artifact` exists on disk with the expected hash.

    Skips artifacts that do not apply to this platform; re-downloads when
    the file is missing or its sha256 does not match.
    """
    if not artifact.is_corresponding_platform():
        return

    if artifact.target.exists():
        digest = hashlib.sha256(artifact.target.read_bytes()).hexdigest()
        if digest.lower() != artifact.sha256.lower():
            logger.info(f"二进制依赖 {artifact.target} 的哈希无法对应需求的哈希,准备重新下载")
            artifact.target.unlink()
            await download_artifact(artifact)
    else:
        logger.info(f"二进制依赖 {artifact.target} 不存在")
        if not artifact.target.parent.exists():
            artifact.target.parent.mkdir(parents=True, exist_ok=True)
        await download_artifact(artifact)
|
||||
|
||||
|
||||
def register_artifacts(*artifacts: ArtifactDepends):
    """Queue binary dependencies for the startup prefetch check."""
    for artifact in artifacts:
        artifact_list.append(artifact)
|
||||
|
||||
231
konabot/common/database/__init__.py
Normal file
231
konabot/common/database/__init__.py
Normal file
@ -0,0 +1,231 @@
|
||||
from contextlib import asynccontextmanager
|
||||
import os
|
||||
import asyncio
|
||||
import sqlparse
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Union, TYPE_CHECKING
|
||||
|
||||
import aiosqlite
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import DatabaseManager
|
||||
|
||||
# Process-wide DatabaseManager singleton.
_global_db_manager: Optional["DatabaseManager"] = None


def get_global_db_manager() -> "DatabaseManager":
    """Return the global DatabaseManager, creating it on first use."""
    global _global_db_manager
    if _global_db_manager is None:
        # DatabaseManager is defined later in this module; referencing the
        # global name directly replaces the previous redundant self-import
        # (`from . import DatabaseManager`, which imported this very module).
        _global_db_manager = DatabaseManager()
    return _global_db_manager


def close_global_db_manager() -> None:
    """Drop the global DatabaseManager reference.

    NOTE: this only clears the reference; close_all_connections() must be
    awaited from async code beforehand.
    """
    global _global_db_manager
    _global_db_manager = None
|
||||
|
||||
|
||||
class DatabaseManager:
    """Async SQLite helper with a small hand-rolled connection pool."""

    def __init__(self, db_path: Optional[Union[str, Path]] = None, pool_size: int = 5):
        """
        Args:
            db_path: database file path (str or Path). Defaults to the
                DATABASE_PATH environment variable, then ./data/database.db.
            pool_size: maximum number of idle pooled connections.
        """
        if db_path is None:
            self.db_path = os.environ.get("DATABASE_PATH", "./data/database.db")
        else:
            self.db_path = str(db_path) if isinstance(db_path, Path) else db_path

        self._connection_pool = []  # idle connections
        self._pool_size = pool_size
        self._lock = asyncio.Lock()
        self._in_use = set()  # connections currently handed out

    async def _get_connection(self) -> aiosqlite.Connection:
        """Check out a connection, validating pooled ones with `SELECT 1`."""
        async with self._lock:
            # Reuse a pooled connection if a live one is available.
            while self._connection_pool:
                conn = self._connection_pool.pop()
                try:
                    await conn.execute("SELECT 1")
                    self._in_use.add(conn)
                    return conn
                except Exception:
                    # Stale connection: discard it (was a bare `except:`).
                    try:
                        await conn.close()
                    except Exception:
                        pass

            # Pool empty — open a fresh connection.
            conn = await aiosqlite.connect(self.db_path)
            await conn.execute("PRAGMA foreign_keys = ON")
            self._in_use.add(conn)
            return conn

    async def _return_connection(self, conn: aiosqlite.Connection) -> None:
        """Return a checked-out connection; close it when the pool is full."""
        async with self._lock:
            self._in_use.discard(conn)
            if len(self._connection_pool) < self._pool_size:
                self._connection_pool.append(conn)
            else:
                try:
                    await conn.close()
                except Exception:
                    pass

    @asynccontextmanager
    async def get_conn(self):
        """Context-manage a pooled connection.

        Fixed: the connection is now returned to the pool even when the
        `async with` body raises (previously it leaked on exception).
        """
        conn = await self._get_connection()
        try:
            yield conn
        finally:
            await self._return_connection(conn)

    async def query(
        self, query: str, params: Optional[tuple] = None
    ) -> List[Dict[str, Any]]:
        """Run a SELECT and return rows as column-name -> value dicts."""
        conn = await self._get_connection()
        try:
            cursor = await conn.execute(query, params or ())
            columns = [description[0] for description in cursor.description]
            rows = await cursor.fetchall()
            results = [dict(zip(columns, row)) for row in rows]
            await cursor.close()
            return results
        except Exception as e:
            # Re-raise with context; the caller decides how to handle it.
            raise Exception(f"数据库查询失败: {str(e)}") from e
        finally:
            await self._return_connection(conn)

    async def query_by_sql_file(
        self, file_path: Union[str, Path], params: Optional[tuple] = None
    ) -> List[Dict[str, Any]]:
        """Read a query from a SQL file and execute it."""
        path = str(file_path) if isinstance(file_path, Path) else file_path
        with open(path, "r", encoding="utf-8") as f:
            query = f.read()
        return await self.query(query, params)

    async def execute(self, command: str, params: Optional[tuple] = None) -> None:
        """Run a single non-query statement and commit."""
        conn = await self._get_connection()
        try:
            await conn.execute(command, params or ())
            await conn.commit()
        except Exception as e:
            raise Exception(f"数据库执行失败: {str(e)}") from e
        finally:
            await self._return_connection(conn)

    async def execute_script(self, script: str) -> None:
        """Run a multi-statement SQL script and commit."""
        conn = await self._get_connection()
        try:
            await conn.executescript(script)
            await conn.commit()
        except Exception as e:
            raise Exception(f"数据库脚本执行失败: {str(e)}") from e
        finally:
            await self._return_connection(conn)

    def _parse_sql_statements(self, script: str) -> List[str]:
        """Split a SQL script into individual non-empty statements."""
        # sqlparse handles semicolons inside literals/comments correctly.
        parsed = sqlparse.split(script)
        statements = []

        for statement in parsed:
            statement = statement.strip()
            if statement:
                statements.append(statement)

        return statements

    async def execute_by_sql_file(
        self,
        file_path: Union[str, Path],
        params: Optional[Union[tuple, List[tuple]]] = None,
    ) -> None:
        """Read non-query SQL from a file and execute it.

        A tuple `params` is applied to the whole script via execute();
        a list of tuples pairs each statement with one parameter tuple;
        no params runs the script via executescript().
        """
        path = str(file_path) if isinstance(file_path, Path) else file_path
        with open(path, "r", encoding="utf-8") as f:
            script = f.read()

        if params is not None and isinstance(params, tuple):
            await self.execute(script, params)
        elif params is not None and isinstance(params, list):
            statements = self._parse_sql_statements(script)
            if len(statements) != len(params):
                raise ValueError(
                    f"语句数量({len(statements)})与参数组数量({len(params)})不匹配"
                )

            for statement, stmt_params in zip(statements, params):
                if statement:
                    await self.execute(statement, stmt_params)
        else:
            await self.execute_script(script)

    async def execute_many(self, command: str, seq_of_params: List[tuple]) -> None:
        """Run one statement against many parameter tuples and commit."""
        conn = await self._get_connection()
        try:
            await conn.executemany(command, seq_of_params)
            await conn.commit()
        except Exception as e:
            raise Exception(f"数据库批量执行失败: {str(e)}") from e
        finally:
            await self._return_connection(conn)

    async def execute_many_values_by_sql_file(
        self, file_path: Union[str, Path], seq_of_params: List[tuple]
    ) -> None:
        """Read one statement from a SQL file and execute it for each tuple."""
        path = str(file_path) if isinstance(file_path, Path) else file_path
        with open(path, "r", encoding="utf-8") as f:
            command = f.read()
        await self.execute_many(command, seq_of_params)

    async def close_all_connections(self) -> None:
        """Close every pooled and in-use connection, best-effort."""
        async with self._lock:
            for conn in self._connection_pool:
                try:
                    await conn.close()
                except Exception:
                    pass
            self._connection_pool.clear()

            # Also close connections still checked out.
            for conn in self._in_use.copy():
                try:
                    await conn.close()
                except Exception:
                    pass
            self._in_use.clear()
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
import openai
|
||||
|
||||
from loguru import logger
|
||||
@ -26,14 +26,14 @@ class LLMInfo(BaseModel):
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
messages: list[ChatCompletionMessageParam] | list[dict[str, Any]],
|
||||
timeout: float | None = 30.0,
|
||||
max_tokens: int | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatCompletionMessage:
|
||||
logger.info(f"调用 LLM: BASE_URL={self.base_url} MODEL_NAME={self.model_name}")
|
||||
completion: ChatCompletion = await self.get_openai_client().chat.completions.create(
|
||||
messages=messages,
|
||||
messages=cast(Any, messages),
|
||||
model=self.model_name,
|
||||
max_tokens=max_tokens,
|
||||
timeout=timeout,
|
||||
@ -59,6 +59,9 @@ def get_llm(llm_model: str | None = None):
|
||||
if llm_model is None:
|
||||
llm_model = llm_config.default_llm
|
||||
if llm_model not in llm_config.llms:
|
||||
raise NotImplementedError("LLM 未配置,该功能无法使用")
|
||||
if llm_config.default_llm in llm_config.llms:
|
||||
logger.warning(f"[LLM] 需求的 LLM 不存在,回退到默认模型 REQUIRED={llm_model}")
|
||||
return llm_config.llms[llm_config.default_llm]
|
||||
raise NotImplementedError("[LLM] LLM 未配置,该功能无法使用")
|
||||
return llm_config.llms[llm_model]
|
||||
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
import httpx
|
||||
@ -19,15 +20,21 @@ from PIL import UnidentifiedImageError
|
||||
from pydantic import BaseModel
|
||||
from returns.result import Failure, Result, Success
|
||||
|
||||
from konabot.common.path import ASSETS_PATH
|
||||
|
||||
|
||||
discordConfig = nonebot.get_plugin_config(DiscordConfig)
|
||||
|
||||
|
||||
class ExtractImageConfig(BaseModel):
|
||||
module_extract_image_no_download: bool = False
|
||||
"要不要算了,不下载了,直接爆炸算了,适用于一些比较奇怪的网络环境,无法从协议端下载文件"
|
||||
"""
|
||||
要不要算了,不下载了,直接爆炸算了,
|
||||
适用于一些比较奇怪的网络环境,无法从协议端下载文件
|
||||
"""
|
||||
|
||||
module_extract_image_target: str = './assets/img/other/boom.jpg'
|
||||
"""
|
||||
使用哪个图片呢
|
||||
"""
|
||||
|
||||
|
||||
module_config = nonebot.get_plugin_config(ExtractImageConfig)
|
||||
@ -37,7 +44,7 @@ async def download_image_bytes(url: str, proxy: str | None = None) -> Result[byt
|
||||
# if "/matcha/cache/" in url:
|
||||
# url = url.replace('127.0.0.1', '10.126.126.101')
|
||||
if module_config.module_extract_image_no_download:
|
||||
return Success((ASSETS_PATH / "img" / "other" / "boom.jpg").read_bytes())
|
||||
return Success(Path(module_config.module_extract_image_target).read_bytes())
|
||||
logger.debug(f"开始从 {url} 下载图片")
|
||||
async with httpx.AsyncClient(proxy=proxy) as c:
|
||||
try:
|
||||
@ -70,15 +77,22 @@ def bytes_to_pil(raw_data: bytes | BytesIO) -> Result[PIL.Image.Image, str]:
|
||||
return Failure("图像无法读取,可能是网络存在问题orz")
|
||||
|
||||
|
||||
async def unimsg_img_to_pil(image: Image) -> Result[PIL.Image.Image, str]:
|
||||
async def unimsg_img_to_bytes(image: Image) -> Result[bytes, str]:
|
||||
if image.url is not None:
|
||||
raw_result = await download_image_bytes(image.url)
|
||||
elif image.raw is not None:
|
||||
raw_result = Success(image.raw)
|
||||
if isinstance(image.raw, bytes):
|
||||
raw_result = Success(image.raw)
|
||||
else:
|
||||
raw_result = Success(image.raw.getvalue())
|
||||
else:
|
||||
return Failure("由于一些内部问题,下载图片失败了orz")
|
||||
|
||||
return raw_result.bind(bytes_to_pil)
|
||||
return raw_result
|
||||
|
||||
|
||||
async def unimsg_img_to_pil(image: Image) -> Result[PIL.Image.Image, str]:
|
||||
return (await unimsg_img_to_bytes(image)).bind(bytes_to_pil)
|
||||
|
||||
|
||||
async def extract_image_from_qq_message(
|
||||
@ -86,7 +100,7 @@ async def extract_image_from_qq_message(
|
||||
evt: OnebotV11MessageEvent,
|
||||
bot: OnebotV11Bot,
|
||||
allow_reply: bool = True,
|
||||
) -> Result[PIL.Image.Image, str]:
|
||||
) -> Result[bytes, str]:
|
||||
if allow_reply and (reply := evt.reply) is not None:
|
||||
return await extract_image_from_qq_message(
|
||||
reply.message,
|
||||
@ -118,18 +132,17 @@ async def extract_image_from_qq_message(
|
||||
url = seg.data.get("url")
|
||||
if url is None:
|
||||
return Failure("无法下载图片,可能有一些网络问题")
|
||||
data = await download_image_bytes(url)
|
||||
return data.bind(bytes_to_pil)
|
||||
return await download_image_bytes(url)
|
||||
|
||||
return Failure("请在消息中包含图片,或者引用一个含有图片的消息")
|
||||
|
||||
|
||||
async def extract_image_from_message(
|
||||
async def extract_image_data_from_message(
|
||||
msg: Message,
|
||||
evt: Event,
|
||||
bot: Bot,
|
||||
allow_reply: bool = True,
|
||||
) -> Result[PIL.Image.Image, str]:
|
||||
) -> Result[bytes, str]:
|
||||
if (
|
||||
isinstance(bot, OnebotV11Bot)
|
||||
and isinstance(msg, OnebotV11Message)
|
||||
@ -145,18 +158,18 @@ async def extract_image_from_message(
|
||||
if "image/" not in a.content_type:
|
||||
continue
|
||||
url = a.proxy_url
|
||||
return (await download_image_bytes(url, discordConfig.discord_proxy)).bind(bytes_to_pil)
|
||||
return await download_image_bytes(url, discordConfig.discord_proxy)
|
||||
|
||||
for seg in UniMessage.of(msg, bot):
|
||||
logger.info(seg)
|
||||
if isinstance(seg, Image):
|
||||
return await unimsg_img_to_pil(seg)
|
||||
return await unimsg_img_to_bytes(seg)
|
||||
elif isinstance(seg, Reply) and allow_reply:
|
||||
msg2 = seg.msg
|
||||
logger.debug(f"深入搜索引用的消息:{msg2}")
|
||||
if msg2 is None or isinstance(msg2, str):
|
||||
continue
|
||||
return await extract_image_from_message(msg2, evt, bot, False)
|
||||
return await extract_image_data_from_message(msg2, evt, bot, False)
|
||||
elif isinstance(seg, RefNode) and allow_reply:
|
||||
if isinstance(bot, DiscordBot):
|
||||
return Failure("暂时不支持在 Discord 中通过引用的方式获取图片")
|
||||
@ -165,12 +178,12 @@ async def extract_image_from_message(
|
||||
return Failure("请在消息中包含图片,或者引用一个含有图片的消息")
|
||||
|
||||
|
||||
async def _ext_img(
|
||||
async def _ext_img_data(
|
||||
evt: Event,
|
||||
bot: Bot,
|
||||
matcher: Matcher,
|
||||
) -> PIL.Image.Image | None:
|
||||
match await extract_image_from_message(evt.get_message(), evt, bot):
|
||||
) -> bytes | None:
|
||||
match await extract_image_data_from_message(evt.get_message(), evt, bot):
|
||||
case Success(img):
|
||||
return img
|
||||
case Failure(err):
|
||||
@ -180,4 +193,35 @@ async def _ext_img(
|
||||
assert False
|
||||
|
||||
|
||||
PIL_Image = Annotated[PIL.Image.Image, nonebot.params.Depends(_ext_img)]
|
||||
async def _ext_img(
|
||||
evt: Event,
|
||||
bot: Bot,
|
||||
matcher: Matcher,
|
||||
) -> PIL.Image.Image | None:
|
||||
r = await _ext_img_data(evt, bot, matcher)
|
||||
if r:
|
||||
match bytes_to_pil(r):
|
||||
case Success(img):
|
||||
return img
|
||||
case Failure(msg):
|
||||
await matcher.send(await UniMessage.text(msg).export())
|
||||
return None
|
||||
|
||||
async def _try_ext_img(
|
||||
evt: Event,
|
||||
bot: Bot,
|
||||
matcher: Matcher,
|
||||
) -> bytes | None:
|
||||
match await extract_image_data_from_message(evt.get_message(), evt, bot):
|
||||
case Success(img):
|
||||
return img
|
||||
case Failure(err):
|
||||
# raise BotExceptionMessage(err)
|
||||
# await matcher.send(await UniMessage().text(err).export())
|
||||
return None
|
||||
assert False
|
||||
|
||||
DepImageBytes = Annotated[bytes, nonebot.params.Depends(_ext_img_data)]
|
||||
DepPILImage = Annotated[PIL.Image.Image, nonebot.params.Depends(_ext_img)]
|
||||
|
||||
DepImageBytesOrNone = Annotated[bytes | None, nonebot.params.Depends(_try_ext_img)]
|
||||
|
||||
@ -5,8 +5,10 @@ FONTS_PATH = ASSETS_PATH / "fonts"
|
||||
|
||||
SRC_PATH = Path(__file__).resolve().parent.parent
|
||||
DATA_PATH = SRC_PATH.parent / "data"
|
||||
TMP_PATH = DATA_PATH / "tmp"
|
||||
LOG_PATH = DATA_PATH / "logs"
|
||||
CONFIG_PATH = DATA_PATH / "config"
|
||||
BINARY_PATH = DATA_PATH / "bin"
|
||||
|
||||
DOCS_PATH = SRC_PATH / "docs"
|
||||
DOCS_PATH_MAN1 = DOCS_PATH / "user"
|
||||
@ -21,4 +23,6 @@ if not LOG_PATH.exists():
|
||||
LOG_PATH.mkdir()
|
||||
|
||||
CONFIG_PATH.mkdir(exist_ok=True)
|
||||
TMP_PATH.mkdir(exist_ok=True)
|
||||
BINARY_PATH.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
119
konabot/common/permsys/__init__.py
Normal file
119
konabot/common/permsys/__init__.py
Normal file
@ -0,0 +1,119 @@
|
||||
from typing import Annotated
|
||||
import nonebot
|
||||
from nonebot.adapters import Event
|
||||
from nonebot.params import Depends
|
||||
from nonebot.rule import Rule
|
||||
|
||||
from konabot.common.appcontext import after_init
|
||||
from konabot.common.database import DatabaseManager
|
||||
from konabot.common.pager import PagerQuery
|
||||
from konabot.common.path import DATA_PATH
|
||||
from konabot.common.permsys.entity import PermEntity, get_entity_chain
|
||||
from konabot.common.permsys.migrates import execute_migration
|
||||
from konabot.common.permsys.repo import PermRepo
|
||||
|
||||
|
||||
db = DatabaseManager(DATA_PATH / "perm.sqlite3")
|
||||
_default_allow_permissions: set[str] = set()
|
||||
|
||||
|
||||
_EntityLike = Event | PermEntity | list[PermEntity]
|
||||
|
||||
|
||||
async def _to_entity_chain(el: _EntityLike):
|
||||
if isinstance(el, Event):
|
||||
return await get_entity_chain(el) # pragma: no cover
|
||||
if isinstance(el, PermEntity):
|
||||
return [el]
|
||||
return el
|
||||
|
||||
|
||||
class PermManager:
    """Permission lookups and updates backed by the perm database."""

    def __init__(self, db: DatabaseManager) -> None:
        self.db = db

    async def check_has_permission_info(self, entities: _EntityLike, key: str):
        """Find the most specific (entity, key, allowed) rule matching `key`.

        The key is matched from the most to the least specific dotted prefix,
        with "*" as the final wildcard; entities are checked in chain order.
        Returns None when no rule matches.
        """
        chain = await _to_entity_chain(entities)
        normalized = key.removesuffix("*").removesuffix(".")
        parts = [p for p in normalized.split(".") if p]
        prefixes = [".".join(parts[: i + 1]) for i in range(len(parts))]
        keys = list(reversed(prefixes)) + ["*"]

        async with self.db.get_conn() as conn:
            repo = PermRepo(conn)
            data = await repo.get_perm_info_batch(chain, keys)
            for entity in chain:
                for k in keys:
                    perm = data.get((entity, k))
                    if perm is not None:
                        return (entity, k, perm)
            return None

    async def check_has_permission(self, entities: _EntityLike, key: str) -> bool:
        """True when a matching rule exists and allows access."""
        info = await self.check_has_permission_info(entities, key)
        return False if info is None else info[2]

    async def update_permission(self, entity: PermEntity, key: str, perm: bool | None):
        """Set (True/False) or clear (None) a single permission entry."""
        async with self.db.get_conn() as conn:
            repo = PermRepo(conn)
            await repo.update_perm_info(entity, key, perm)

    async def list_permission(self, entities: _EntityLike, query: PagerQuery):
        """Page through all rules attached to the entity chain."""
        chain = await _to_entity_chain(entities)
        async with self.db.get_conn() as conn:
            repo = PermRepo(conn)
            return await repo.list_perm_info_batch(chain, query)
|
||||
|
||||
|
||||
def perm_manager(_db: DatabaseManager | None = None) -> PermManager:  # pragma: no cover
    """Build a PermManager, defaulting to the module-level database."""
    return PermManager(db if _db is None else _db)
|
||||
|
||||
|
||||
@after_init
def create_startup():  # pragma: no cover
    """Register startup/shutdown hooks that seed and close the perm DB."""
    from konabot.common.nb.is_admin import cfg

    driver = nonebot.get_driver()

    @driver.on_startup
    async def _():
        async with db.get_conn() as conn:
            await execute_migration(conn)
        pm = perm_manager(db)
        # Superadmins configured via environment get the "*" wildcard.
        for account in cfg.admin_qq_account:
            await pm.update_permission(
                PermEntity("ob11", "user", str(account)), "*", True
            )
        # Apply module-registered default-allow keys to the global entity.
        for key in _default_allow_permissions:
            await pm.update_permission(
                PermEntity("sys", "global", "global"), key, True
            )

    @driver.on_shutdown
    async def _():
        # Best-effort cleanup: never let DB teardown break shutdown.
        try:
            await db.close_all_connections()
        except Exception:
            pass
|
||||
|
||||
|
||||
# Dependency-injection alias: resolves to a PermManager built by perm_manager().
DepPermManager = Annotated[PermManager, Depends(perm_manager)]
|
||||
|
||||
|
||||
def register_default_allow_permission(key: str):
    """Register *key* to be granted on the sys/global entity at startup."""
    _default_allow_permissions.add(key)
|
||||
|
||||
|
||||
def require_permission(perm: str) -> Rule:  # pragma: no cover
    """Build a NoneBot Rule that passes only when the event holds *perm*.

    Args:
        perm: Permission key to check.

    Returns:
        A Rule wrapping an async checker whose PermManager is injected.
    """

    async def check_permission(event: Event, pm: DepPermManager) -> bool:
        return await pm.check_has_permission(event, perm)

    return Rule(check_permission)
|
||||
69
konabot/common/permsys/entity.py
Normal file
69
konabot/common/permsys/entity.py
Normal file
@ -0,0 +1,69 @@
|
||||
from dataclasses import dataclass
|
||||
from nonebot.internal.adapter import Event
|
||||
|
||||
from nonebot.adapters.onebot.v11 import Event as OB11Event
|
||||
from nonebot.adapters.onebot.v11.event import GroupMessageEvent as OB11GroupEvent
|
||||
from nonebot.adapters.onebot.v11.event import PrivateMessageEvent as OB11PrivateEvent
|
||||
|
||||
from nonebot.adapters.discord.event import Event as DiscordEvent
|
||||
from nonebot.adapters.discord.event import GuildMessageCreateEvent as DiscordGMEvent
|
||||
from nonebot.adapters.discord.event import DirectMessageCreateEvent as DiscordDMEvent
|
||||
|
||||
from nonebot.adapters.minecraft.event import MessageEvent as MinecraftMessageEvent
|
||||
|
||||
from nonebot.adapters.console.event import MessageEvent as ConsoleEvent
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class PermEntity:
    """Immutable identifier of a permission subject.

    Frozen so instances are hashable and usable as dict keys
    (e.g. in PermRepo's batch lookups).
    """

    # Adapter name, e.g. "ob11", "discord", "minecraft", "console"; "sys" for the global scope.
    platform: str
    # Kind of subject, e.g. "user", "group", "guild", "channel", "server", "global".
    entity_type: str
    # Platform-specific identifier; "global" for scope-wide entities.
    external_id: str
|
||||
|
||||
|
||||
def get_entity_chain_of_entity(entity: PermEntity) -> list[PermEntity]:
    """Return the lookup chain for *entity*, most specific first.

    The chain is: the entity itself, its platform-global entity, then the
    sys-global entity.
    """
    return [
        entity,
        PermEntity(entity.platform, "global", "global"),
        PermEntity("sys", "global", "global"),
    ]
|
||||
|
||||
|
||||
async def get_entity_chain(event: Event) -> list[PermEntity]:  # pragma: no cover
    """Build the permission lookup chain for a NoneBot *event*.

    The chain always contains the sys-global entity; platform, group/guild/
    channel, and user entities are added when the event type provides them.
    The list is reversed before returning, so the most specific entity
    comes first.
    """
    entities = [PermEntity("sys", "global", "global")]

    if isinstance(event, OB11Event):
        entities.append(PermEntity("ob11", "global", "global"))

        if isinstance(event, OB11GroupEvent):
            entities.append(PermEntity("ob11", "group", str(event.group_id)))
            entities.append(PermEntity("ob11", "user", str(event.user_id)))

        if isinstance(event, OB11PrivateEvent):
            entities.append(PermEntity("ob11", "user", str(event.user_id)))

    if isinstance(event, DiscordEvent):
        entities.append(PermEntity("discord", "global", "global"))

        if isinstance(event, DiscordGMEvent):
            entities.append(PermEntity("discord", "guild", str(event.guild_id)))
            entities.append(PermEntity("discord", "channel", str(event.channel_id)))
            entities.append(PermEntity("discord", "user", str(event.user_id)))

        if isinstance(event, DiscordDMEvent):
            entities.append(PermEntity("discord", "channel", str(event.channel_id)))
            entities.append(PermEntity("discord", "user", str(event.user_id)))

    if isinstance(event, MinecraftMessageEvent):
        entities.append(PermEntity("minecraft", "global", "global"))
        entities.append(PermEntity("minecraft", "server", event.server_name))
        # The player UUID may be absent; only add the player entity when known.
        player_uuid = event.player.uuid
        if player_uuid is not None:
            entities.append(PermEntity("minecraft", "player", player_uuid.hex))

    if isinstance(event, ConsoleEvent):
        entities.append(PermEntity("console", "global", "global"))
        entities.append(PermEntity("console", "channel", event.channel.id))
        entities.append(PermEntity("console", "user", event.user.id))

    # Most specific entity first.
    return entities[::-1]
|
||||
81
konabot/common/permsys/migrates/__init__.py
Normal file
81
konabot/common/permsys/migrates/__init__.py
Normal file
@ -0,0 +1,81 @@
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
from loguru import logger
|
||||
|
||||
from konabot.common.database import DatabaseManager
|
||||
from konabot.common.path import DATA_PATH
|
||||
|
||||
|
||||
# Directory holding this package's migration SQL files.
PATH_THISFOLDER = Path(__file__).parent

# Bookkeeping SQL for the migrate_version table.
SQL_CHECK_EXISTS = (PATH_THISFOLDER / "./check_migrate_version_exists.sql").read_text()
SQL_CREATE_TABLE = (PATH_THISFOLDER / "./create_migrate_version_table.sql").read_text()
SQL_GET_MIGRATE_VERSION = (PATH_THISFOLDER / "get_migrate_version.sql").read_text()
SQL_UPDATE_VERSION = (PATH_THISFOLDER / "./update_migrate_version.sql").read_text()

# Dedicated SQLite database for the permission system.
db = DatabaseManager(DATA_PATH / "perm.sqlite3")
|
||||
|
||||
|
||||
@dataclass
|
||||
class Migration:
|
||||
upgrade: str | Path
|
||||
downgrade: str | Path
|
||||
|
||||
def get_upgrade_script(self) -> str:
|
||||
if isinstance(self.upgrade, Path):
|
||||
return self.upgrade.read_text()
|
||||
return self.upgrade
|
||||
|
||||
def get_downgrade_script(self) -> str:
|
||||
if isinstance(self.downgrade, Path):
|
||||
return self.downgrade.read_text()
|
||||
return self.downgrade
|
||||
|
||||
|
||||
# Ordered migration steps; migrations[i] upgrades schema version i -> i + 1.
migrations = [
    Migration(
        PATH_THISFOLDER / "./mu1_create_permsys_table.sql",
        PATH_THISFOLDER / "./md1_remove_permsys_table.sql",
    ),
]


# The newest schema version equals the number of registered migrations.
TARGET_VERSION = len(migrations)
|
||||
|
||||
|
||||
async def get_current_version(conn: aiosqlite.Connection) -> int:
    """Return the stored schema version, bootstrapping the table if absent.

    When the migrate_version table does not exist yet, it is created and
    the version is reported as 0.
    """
    cursor = await conn.execute(SQL_CHECK_EXISTS)
    exists_row = await cursor.fetchone()
    assert exists_row is not None
    if exists_row[0] < 1:
        logger.info("权限系统数据表不存在,现在创建表")
        await conn.executescript(SQL_CREATE_TABLE)
        await conn.commit()
        return 0
    cursor = await conn.execute(SQL_GET_MIGRATE_VERSION)
    version_row = await cursor.fetchone()
    return 0 if version_row is None else version_row[0]
|
||||
|
||||
|
||||
async def execute_migration(
    conn: aiosqlite.Connection,
    version: int = TARGET_VERSION,
    migrations: list[Migration] = migrations,
):
    """Migrate the schema up or down until it reaches *version*.

    Each step is committed individually together with its version bump, so
    an interruption leaves the database at a consistent intermediate version.
    """
    current = await get_current_version(conn)
    while current < version:
        step = migrations[current]
        await conn.executescript(step.get_upgrade_script())
        current += 1
        await conn.execute(SQL_UPDATE_VERSION, (current,))
        await conn.commit()
    while current > version:
        step = migrations[current - 1]
        await conn.executescript(step.get_downgrade_script())
        current -= 1
        await conn.execute(SQL_UPDATE_VERSION, (current,))
        await conn.commit()
|
||||
@ -0,0 +1,7 @@
|
||||
-- Returns 1 when the migrate_version bookkeeping table exists, else 0.
SELECT
    COUNT(*)
FROM
    sqlite_master
WHERE
    type = 'table'
    AND name = 'migrate_version'
|
||||
@ -0,0 +1,3 @@
|
||||
-- Bootstrap the single-row schema-version table at version 0.
CREATE TABLE migrate_version(version INT PRIMARY KEY);
INSERT INTO migrate_version(version)
VALUES(0);
|
||||
4
konabot/common/permsys/migrates/get_migrate_version.sql
Normal file
4
konabot/common/permsys/migrates/get_migrate_version.sql
Normal file
@ -0,0 +1,4 @@
|
||||
-- Read the current schema version (migrate_version holds one row).
SELECT
    version
FROM
    migrate_version;
|
||||
@ -0,0 +1,2 @@
|
||||
-- Downgrade 1 -> 0: drop all permission-system tables.
DROP TABLE IF EXISTS perm_entity;
DROP TABLE IF EXISTS perm_info;
|
||||
30
konabot/common/permsys/migrates/mu1_create_permsys_table.sql
Normal file
30
konabot/common/permsys/migrates/mu1_create_permsys_table.sql
Normal file
@ -0,0 +1,30 @@
|
||||
-- Upgrade 0 -> 1: create the permission-system schema.

-- One row per permission subject (user/group/channel/...).
CREATE TABLE perm_entity(
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    platform TEXT NOT NULL,
    entity_type TEXT NOT NULL,
    external_id TEXT NOT NULL,
    created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
    updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);

-- Uniqueness also backs the INSERT OR IGNORE used by create_entity.sql.
CREATE UNIQUE INDEX idx_perm_entity_lookup
ON perm_entity(platform, entity_type, external_id);

-- One row per (entity, permission key); NULL value means "unset".
CREATE TABLE perm_info(
    entity_id INTEGER NOT NULL,
    config_key TEXT NOT NULL,
    value BOOLEAN,
    updated_at DATETIME DEFAULT CURRENT_TIMESTAMP,
    -- composite primary key
    PRIMARY KEY (entity_id, config_key)
);

-- Keep updated_at fresh on modification.
CREATE TRIGGER perm_entity_update AFTER UPDATE
ON perm_entity BEGIN
UPDATE perm_entity SET updated_at=CURRENT_TIMESTAMP WHERE id=old.id;
END;
-- Fix: constrain by config_key as well — the original WHERE matched only
-- entity_id, so updating one key touched updated_at on ALL of the
-- entity's permission rows.
CREATE TRIGGER perm_info_update AFTER UPDATE
ON perm_info BEGIN
UPDATE perm_info SET updated_at=CURRENT_TIMESTAMP
WHERE entity_id=old.entity_id AND config_key=old.config_key;
END;
|
||||
|
||||
@ -0,0 +1,2 @@
|
||||
-- Persist the new schema version (bound parameter: version).
UPDATE migrate_version
SET version = ?;
|
||||
242
konabot/common/permsys/repo.py
Normal file
242
konabot/common/permsys/repo.py
Normal file
@ -0,0 +1,242 @@
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from konabot.common.pager import PagerQuery, PagerResult
|
||||
|
||||
from .entity import PermEntity
|
||||
|
||||
|
||||
def s(p: str):
    """Return the text of a SQL file from this package's sql/ directory.

    Args:
        p: SQL file name, relative to the sql/ subdirectory next to this file.

    Returns:
        The file's contents as a string.
    """
    sql_dir = Path(__file__).parent / "./sql/"
    return (sql_dir / p).read_text()
|
||||
|
||||
|
||||
@dataclass
class PermRepo:
    """Repository mediating all database access for permission entities.

    Attributes:
        conn: Open aiosqlite database connection.
    """

    conn: aiosqlite.Connection

    async def create_entity(self, entity: PermEntity) -> int:
        """Create a permission entity (idempotently) and return its id.

        Args:
            entity: The permission entity to create.

        Returns:
            The database id of the (new or pre-existing) entity.

        Raises:
            AssertionError: If the id cannot be read back after the insert.
        """
        # INSERT OR IGNORE — safe to call when the entity already exists.
        await self.conn.execute(
            s("create_entity.sql"),
            (entity.platform, entity.entity_type, entity.external_id),
        )
        await self.conn.commit()
        eid = await self._get_entity_id_or_none(entity)
        assert eid is not None
        return eid

    async def _get_entity_id_or_none(self, entity: PermEntity) -> int | None:
        """Look up an entity's id; return None when it does not exist.

        Args:
            entity: The permission entity to look up.

        Returns:
            The entity's id, or None when absent.
        """
        res = await self.conn.execute(
            s("get_entity_id.sql"),
            (entity.platform, entity.entity_type, entity.external_id),
        )
        row = await res.fetchone()
        if row is None:
            return None
        return row[0]

    async def get_entity_id(self, entity: PermEntity) -> int:
        """Return an entity's id, creating the entity on first use.

        Args:
            entity: The permission entity.

        Returns:
            The entity's database id.
        """
        eid = await self._get_entity_id_or_none(entity)
        if eid is None:
            return await self.create_entity(entity)
        return eid

    async def get_perm_info(self, entity: PermEntity, config_key: str) -> bool | None:
        """Return the stored permission value for (entity, config_key).

        Args:
            entity: The permission entity.
            config_key: Permission key.

        Returns:
            The stored value, or None when no row exists.
        """
        eid = await self.get_entity_id(entity)
        res = await self.conn.execute(
            s("get_perm_info.sql"),
            (eid, config_key),
        )
        row = await res.fetchone()
        if row is None:
            return None
        # NOTE(review): a row holding SQL NULL maps to False here, while the
        # batch query filters NULLs out entirely — confirm this asymmetry is
        # intended before relying on it.
        return bool(row[0])

    async def update_perm_info(
        self, entity: PermEntity, config_key: str, value: bool | None
    ):
        """Upsert a permission value; None stores an explicit NULL.

        Args:
            entity: The permission entity.
            config_key: Permission key.
            value: Value to store (True/False/None).
        """
        eid = await self.get_entity_id(entity)
        await self.conn.execute(s("update_perm_info.sql"), (eid, config_key, value))
        await self.conn.commit()

    async def get_entity_id_batch(
        self, entities: list[PermEntity]
    ) -> dict[PermEntity, int]:
        """Bulk-resolve entity ids, creating missing entities first.

        Args:
            entities: Entities to resolve.

        Returns:
            Mapping from each entity to its database id.
        """
        # Guard: an empty list would otherwise render "IN (VALUES )" below,
        # which is a SQL syntax error.
        if not entities:
            return {}
        await self.conn.executemany(
            s("create_entity.sql"),
            [(e.platform, e.entity_type, e.external_id) for e in entities],
        )
        await self.conn.commit()
        val_placeholders = ", ".join(["(?, ?, ?)"] * len(entities))
        params: list[str] = []
        for e in entities:
            params.extend((e.platform, e.entity_type, e.external_id))
        cursor = await self.conn.execute(
            f"""
            SELECT id, platform, entity_type, external_id
            FROM perm_entity
            WHERE (platform, entity_type, external_id) IN (VALUES {val_placeholders});
            """,
            params,
        )
        rows = await cursor.fetchall()
        return {PermEntity(row[1], row[2], row[3]): row[0] for row in rows}

    async def get_perm_info_batch(
        self, entities: list[PermEntity], config_keys: list[str]
    ) -> dict[tuple[PermEntity, str], bool]:
        """Bulk-fetch permission values, omitting NULL entries.

        Args:
            entities: Entities to query.
            config_keys: Permission keys to query.

        Returns:
            Mapping from (entity, config_key) to the stored boolean; pairs
            whose value is NULL or missing are absent from the result.
        """
        entity_ids = {
            v: k for k, v in (await self.get_entity_id_batch(entities)).items()
        }
        # Guard against "IN ()" SQL syntax errors on empty input.
        if not entity_ids or not config_keys:
            return {}
        placeholders1 = ", ".join("?" * len(entity_ids))
        placeholders2 = ", ".join("?" * len(config_keys))
        sql = f"""
            SELECT entity_id, config_key, value
            FROM perm_info
            WHERE entity_id IN ({placeholders1})
            AND config_key IN ({placeholders2})
            AND value IS NOT NULL;
        """
        params = tuple(entity_ids.keys()) + tuple(config_keys)
        cursor = await self.conn.execute(sql, params)
        rows = await cursor.fetchall()
        return {(entity_ids[row[0]], row[1]): bool(row[2]) for row in rows}

    async def list_perm_info_batch(
        self, entities: list[PermEntity], pager: PagerQuery
    ) -> PagerResult[tuple[PermEntity, str, bool]]:
        """Page through the non-NULL permissions of the given entities.

        Results are ordered by the position of each entity in *entities*,
        then by config_key.

        Args:
            entities: Entities to list, in priority order.
            pager: Pagination parameters (1-based page_index, page_size).

        Returns:
            A PagerResult of (entity, config_key, value) tuples.
        """
        entity_to_id = await self.get_entity_id_batch(entities)
        id_to_entity = {v: k for k, v in entity_to_id.items()}
        ordered_ids = [entity_to_id[e] for e in entities if e in entity_to_id]
        # Guard: with no ids both queries below would be syntactically invalid
        # ("IN ()" and a CASE with no WHEN arms).
        if not ordered_ids:
            return PagerResult(
                data=[], success=True, message="", page_count=0, query=pager
            )

        placeholders = ", ".join("?" * len(ordered_ids))
        # CASE maps each entity_id to its position, preserving chain order.
        order_by_cases = " ".join(f"WHEN ? THEN {i}" for i in range(len(ordered_ids)))

        pagecount_sql = f"SELECT COUNT(*) FROM perm_info WHERE entity_id IN ({placeholders}) AND value IS NOT NULL;"
        count_cursor = await self.conn.execute(pagecount_sql, tuple(ordered_ids))
        total_count = (await count_cursor.fetchone() or (0,))[0]

        sql = f"""
            SELECT entity_id, config_key, value
            FROM perm_info
            WHERE entity_id IN ({placeholders})
            AND value IS NOT NULL
            ORDER BY
            (CASE entity_id {order_by_cases} END) ASC,
            config_key ASC
            LIMIT ?
            OFFSET ?;
        """
        params = (
            tuple(ordered_ids)
            + tuple(ordered_ids)
            + (
                pager.page_size,
                (pager.page_index - 1) * pager.page_size,
            )
        )
        cursor = await self.conn.execute(sql, params)
        rows = await cursor.fetchall()
        return PagerResult(
            data=[(id_to_entity[row[0]], row[1], row[2]) for row in rows],
            success=True,
            message="",
            page_count=math.ceil(total_count / pager.page_size),
            query=pager,
        )
|
||||
11
konabot/common/permsys/sql/create_entity.sql
Normal file
11
konabot/common/permsys/sql/create_entity.sql
Normal file
@ -0,0 +1,11 @@
|
||||
-- Idempotent insert: silently ignored when the unique
-- (platform, entity_type, external_id) index already holds this entity.
INSERT
    OR IGNORE INTO perm_entity(
        platform,
        entity_type,
        external_id
    )
VALUES(
    ?,
    ?,
    ?
);
|
||||
8
konabot/common/permsys/sql/get_entity_id.sql
Normal file
8
konabot/common/permsys/sql/get_entity_id.sql
Normal file
@ -0,0 +1,8 @@
|
||||
-- Resolve an entity's id by its natural key (platform, entity_type, external_id).
SELECT
    id
FROM
    perm_entity
WHERE
    perm_entity.platform = ?
    AND perm_entity.entity_type = ?
    AND perm_entity.external_id = ?;
|
||||
7
konabot/common/permsys/sql/get_perm_info.sql
Normal file
7
konabot/common/permsys/sql/get_perm_info.sql
Normal file
@ -0,0 +1,7 @@
|
||||
-- Fetch a single permission value; "VALUE" is the perm_info.value column
-- (SQLite identifiers are case-insensitive).
SELECT
    VALUE
FROM
    perm_info
WHERE
    entity_id = ?
    AND config_key = ?;
|
||||
4
konabot/common/permsys/sql/update_perm_info.sql
Normal file
4
konabot/common/permsys/sql/update_perm_info.sql
Normal file
@ -0,0 +1,4 @@
|
||||
-- Upsert keyed on the (entity_id, config_key) composite primary key.
INSERT INTO perm_info (entity_id, config_key, value)
VALUES (?, ?, ?)
ON CONFLICT(entity_id, config_key)
DO UPDATE SET value=excluded.value;
|
||||
3
konabot/common/ptimeparse/README.md
Normal file
3
konabot/common/ptimeparse/README.md
Normal file
@ -0,0 +1,3 @@
|
||||
# 已废弃
|
||||
|
||||
坏枪用简单的 LLM + 提示词工程,完成了这 200 块的 `qwen3-coder-plus` 都搞不定的 nb 功能
|
||||
@ -1,653 +1,58 @@
|
||||
import re
|
||||
import datetime
|
||||
from typing import Tuple, Optional, Dict, Any
|
||||
"""
|
||||
Professional time parsing module for Chinese and English time expressions.
|
||||
|
||||
from .err import MultipleSpecificationException, TokenUnhandledException
|
||||
This module provides a robust parser for natural language time expressions,
|
||||
supporting both Chinese and English formats with proper whitespace handling.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
from .expression import TimeExpression
|
||||
|
||||
|
||||
def parse(text: str, now: Optional[datetime.datetime] = None) -> datetime.datetime:
    """
    Parse a time expression and return a datetime object.

    Thin module-level wrapper that delegates to TimeExpression.

    Args:
        text: The time expression to parse
        now: The reference time (defaults to current time)

    Returns:
        A datetime object representing the parsed time

    Raises:
        TokenUnhandledException: If the input cannot be parsed
    """
    return TimeExpression.parse(text, now)
|
||||
|
||||
|
||||
class Parser:
|
||||
"""
|
||||
Parser for time expressions with backward compatibility.
|
||||
|
||||
Maintains the original interface:
|
||||
>>> parser = Parser()
|
||||
>>> result = parser.parse("10分钟后")
|
||||
"""
|
||||
|
||||
    def __init__(self, now: Optional[datetime.datetime] = None):
        """Anchor the parser at *now*; defaults to the current local time."""
        self.now = now or datetime.datetime.now()
|
||||
|
||||
    def digest_chinese_number(self, text: str) -> Tuple[str, int]:
        """Consume a leading Chinese numeral from *text*.

        Returns:
            (remaining_text, value). When no numeral is found, *text* is
            returned unchanged with value 0 — callers cannot distinguish a
            literal 零 prefix from "no number" by the value alone.
        """
        if not text:
            return text, 0
        # Handle "两" at start: it means 2 only when standing alone or
        # directly before a magnitude character (十/百/千/万/亿).
        if text.startswith("两"):
            next_char = text[1] if len(text) > 1 else ''
            if not next_char or next_char in "十百千万亿":
                return text[1:], 2
        s = "零一二三四五六七八九"
        digits = {c: i for i, c in enumerate(s)}
        # Take the longest run of numeral/magnitude characters.
        i = 0
        while i < len(text) and text[i] in s + "十百千万亿":
            i += 1
        if i == 0:
            return text, 0
        num_str = text[:i]
        rest = text[i:]

        def parse(s):
            # Recursive evaluation: split on the largest magnitude first
            # (亿 = 10^8, then 万 = 10^4), then fold 十/百/千 positionally.
            if not s:
                return 0
            if s == "零":
                return 0
            if "亿" in s:
                a, b = s.split("亿", 1)
                return parse(a) * 100000000 + parse(b)
            if "万" in s:
                a, b = s.split("万", 1)
                return parse(a) * 10000 + parse(b)
            n = 0
            t = 0
            for c in s:
                if c == "零":
                    continue
                if c in digits:
                    t = digits[c]
                elif c == "十":
                    # Bare 十 means "one ten" (e.g. 十五 = 15).
                    if t == 0:
                        t = 1
                    n += t * 10
                    t = 0
                elif c == "百":
                    if t == 0:
                        t = 1
                    n += t * 100
                    t = 0
                elif c == "千":
                    if t == 0:
                        t = 1
                    n += t * 1000
                    t = 0
            n += t
            return n

        return rest, parse(num_str)
|
||||
|
||||
|
||||
    def parse(self, text: str) -> datetime.datetime:
        """Parse *text* into a datetime relative to ``self.now``.

        Runs every sub-parser over the input until nothing more is consumed,
        then combines the accumulated context into a datetime.

        Raises:
            TokenUnhandledException: On empty input or leftover tokens.
        """
        text = text.strip()
        if not text:
            raise TokenUnhandledException("Empty input")

        # Parsing context, filled in by the sub-parsers.
        ctx = {
            "date": None,
            "time": None,
            "relative_delta": None,
            "am_pm": None,
            "period_word": None,
            "has_time": False,
            "has_date": False,
            "ambiguous_hour": False,
            "is_24hour": False,
            "has_relative_date": False,
        }

        rest = self._parse_all(text, ctx)
        if rest.strip():
            raise TokenUnhandledException(f"Unparsed tokens: {rest.strip()}")

        # _apply_context is defined outside this excerpt — presumably it
        # resolves ctx into a concrete datetime; verify in full file.
        return self._apply_context(ctx)
|
||||
|
||||
def _parse_all(self, text: str, ctx: Dict[str, Any]) -> str:
|
||||
rest = text.lstrip()
|
||||
while True:
|
||||
for parser in [
|
||||
self._parse_absolute_date,
|
||||
self._parse_relative_date,
|
||||
self._parse_relative_time,
|
||||
self._parse_period,
|
||||
self._parse_time,
|
||||
]:
|
||||
new_rest = parser(rest, ctx)
|
||||
if new_rest != rest:
|
||||
rest = new_rest.lstrip()
|
||||
break
|
||||
else:
|
||||
break
|
||||
return rest
|
||||
|
||||
def _add_delta(self, ctx, delta):
|
||||
if ctx["relative_delta"] is None:
|
||||
ctx["relative_delta"] = delta
|
||||
else:
|
||||
ctx["relative_delta"] += delta
|
||||
|
||||
    def _parse_absolute_date(self, text: str, ctx: Dict[str, Any]) -> str:
        """Consume a leading absolute date (and optionally time) from *text*.

        Supported forms, tried in order: ISO "YYYY-MM-DDTHH:MM",
        "YYYY-MM-DD", "YYYY/MM/DD", "YYYY年MM月DD日/号", "MM月DD日/号"
        (digits), and "X月Y日/号" with Chinese numerals.
        Returns the remainder, or *text* unchanged when nothing matches.
        """
        text = text.lstrip()
        # ISO date-time: sets both date and time, in 24-hour mode.
        m = re.match(r"^(\d{4})-(\d{1,2})-(\d{1,2})T(\d{1,2}):(\d{2})", text)
        if m:
            y, mth, d, h, minute = map(int, m.groups())
            ctx["date"] = datetime.date(y, mth, d)
            ctx["time"] = datetime.time(h, minute)
            ctx["has_date"] = True
            ctx["has_time"] = True
            ctx["is_24hour"] = True
            return text[m.end():]
        m = re.match(r"^(\d{4})-(\d{1,2})-(\d{1,2})", text)
        if m:
            y, mth, d = map(int, m.groups())
            ctx["date"] = datetime.date(y, mth, d)
            ctx["has_date"] = True
            return text[m.end():]
        m = re.match(r"^(\d{4})/(\d{1,2})/(\d{1,2})", text)
        if m:
            y, mth, d = map(int, m.groups())
            ctx["date"] = datetime.date(y, mth, d)
            ctx["has_date"] = True
            return text[m.end():]
        m = re.match(r"^(\d{4})年(\d{1,2})月(\d{1,2})[日号]", text)
        if m:
            y, mth, d = map(int, m.groups())
            ctx["date"] = datetime.date(y, mth, d)
            ctx["has_date"] = True
            return text[m.end():]
        # Month/day without a year: assume the reference year.
        m = re.match(r"^(\d{1,2})月(\d{1,2})[日号]", text)
        if m:
            mth, d = map(int, m.groups())
            ctx["date"] = datetime.date(self.now.year, mth, d)
            ctx["has_date"] = True
            return text[m.end():]
        # Chinese-numeral month/day, e.g. 三月五号.
        m = re.match(r"^(.{1,3})月(.{1,3})[日号]", text)
        if m:
            m_str, d_str = m.groups()
            _, mth = self.digest_chinese_number(m_str)
            _, d = self.digest_chinese_number(d_str)
            # digest returns 0 when nothing parses; clamp to 1 to stay valid.
            if mth == 0:
                mth = 1
            if d == 0:
                d = 1
            ctx["date"] = datetime.date(self.now.year, mth, d)
            ctx["has_date"] = True
            return text[m.end():]
        return text
|
||||
|
||||
    def _parse_relative_date(self, text: str, ctx: Dict[str, Any]) -> str:
        # NOTE(review): this body appears to be merge/diff residue. The
        # triple-quoted literal below swallows the whole "今天/今晚" handling
        # as a bare string expression, and the early `return
        # TimeExpression.parse(...)` makes everything after it dead code.
        # Effective behavior as written: delegate to TimeExpression and
        # never reach the 明天/后天/周X logic. Confirm against the intended
        # version in history before relying on this.
        text = text.lstrip()
        """
        Parse a time expression and return a datetime object.
        This maintains backward compatibility with the original interface.

        # Handle "今天", "今晚", "今早", etc.
        today_variants = [
            ("今晚上", "PM"),
            ("今晚", "PM"),
            ("今早", "AM"),
            ("今天早上", "AM"),
            ("今天早晨", "AM"),
            ("今天上午", "AM"),
            ("今天下午", "PM"),
            ("今天晚上", "PM"),
            ("今天", None),
        ]
        for variant, period in today_variants:
            if text.startswith(variant):
                self._add_delta(ctx, datetime.timedelta(days=0))
                ctx["has_relative_date"] = True
                rest = text[len(variant):]
                if period is not None and ctx["am_pm"] is None:
                    ctx["am_pm"] = period
                    ctx["period_word"] = variant
                return rest
        Args:
            text: The time expression to parse

        Returns:
            A datetime object representing the parsed time

        Raises:
            TokenUnhandledException: If the input cannot be parsed
        """
        return TimeExpression.parse(text, self.now)

        # --- unreachable below this point (see NOTE above) ---
        mapping = {
            "明天": 1,
            "后天": 2,
            "大后天": 3,
            "昨天": -1,
            "前天": -2,
            "大前天": -3,
        }
        for word, days in mapping.items():
            if text.startswith(word):
                self._add_delta(ctx, datetime.timedelta(days=days))
                ctx["has_relative_date"] = True
                return text[len(word):]
        m = re.match(r"^(\d+|[零一二三四五六七八九十两]+)天(后|前|以后|之后)", text)
        if m:
            num_str, direction = m.groups()
            if num_str.isdigit():
                n = int(num_str)
            else:
                _, n = self.digest_chinese_number(num_str)
            days = n if direction in ("后", "以后", "之后") else -n
            self._add_delta(ctx, datetime.timedelta(days=days))
            ctx["has_relative_date"] = True
            return text[m.end():]
        m = re.match(r"^(本|上|下)周([一二三四五六日])", text)
        if m:
            scope, day = m.groups()
            weekday_map = {"一": 0, "二": 1, "三": 2, "四": 3, "五": 4, "六": 5, "日": 6}
            target = weekday_map[day]
            current = self.now.weekday()
            if scope == "本":
                delta = target - current
            elif scope == "上":
                delta = target - current - 7
            else:
                delta = target - current + 7
            self._add_delta(ctx, datetime.timedelta(days=delta))
            ctx["has_relative_date"] = True
            return text[m.end():]
        return text
|
||||
|
||||
def _parse_period(self, text: str, ctx: Dict[str, Any]) -> str:
|
||||
text = text.lstrip()
|
||||
period_mapping = {
|
||||
"上午": "AM",
|
||||
"早晨": "AM",
|
||||
"早上": "AM",
|
||||
"早": "AM",
|
||||
"中午": "PM",
|
||||
"下午": "PM",
|
||||
"晚上": "PM",
|
||||
"晚": "PM",
|
||||
"凌晨": "AM",
|
||||
}
|
||||
for word, tag in period_mapping.items():
|
||||
if text.startswith(word):
|
||||
if ctx["am_pm"] is not None:
|
||||
raise MultipleSpecificationException("Multiple periods")
|
||||
ctx["am_pm"] = tag
|
||||
ctx["period_word"] = word
|
||||
return text[len(word):]
|
||||
return text
|
||||
|
||||
    def _parse_time(self, text: str, ctx: Dict[str, Any]) -> str:
        """Consume a leading time-of-day expression from *text*.

        Handles H:MM, Chinese/digit hours with 点/时, minute suffixes
        (分/半/一刻/整), and bare hours when an AM/PM period is already set.
        Returns the remainder, or *text* unchanged when nothing matches.
        """
        # Only one time-of-day may be consumed per expression.
        if ctx["has_time"]:
            return text
        text = text.lstrip()

        # 1. H:MM pattern
        m = re.match(r"^(\d{1,2}):(\d{2})", text)
        if m:
            h, minute = int(m.group(1)), int(m.group(2))
            if 0 <= h <= 23 and 0 <= minute <= 59:
                ctx["time"] = datetime.time(h, minute)
                ctx["has_time"] = True
                ctx["ambiguous_hour"] = 1 <= h <= 12
                ctx["is_24hour"] = h > 12 or h == 0
                return text[m.end():]

        # 2. Parse hour part
        hour = None
        rest_after_hour = text
        is_24hour_format = False

        # Try Chinese number + 点/时 (时 implies 24-hour reading).
        temp_rest, num = self.digest_chinese_number(text)
        if num >= 0:
            temp_rest_stripped = temp_rest.lstrip()
            if temp_rest_stripped.startswith("点"):
                hour = num
                is_24hour_format = False
                rest_after_hour = temp_rest_stripped[1:]
            elif temp_rest_stripped.startswith("时"):
                hour = num
                is_24hour_format = True
                rest_after_hour = temp_rest_stripped[1:]

        # Digit hour + 点/时.
        if hour is None:
            m = re.match(r"^(\d{1,2})\s*([点时])", text)
            if m:
                hour = int(m.group(1))
                is_24hour_format = m.group(2) == "时"
                rest_after_hour = text[m.end():]

        # Bare hour (no 点/时): only accepted when a period word already
        # disambiguated AM/PM.
        # NOTE(review): the nesting of the digit fallback below was inferred
        # from flattened source — confirm against the original file.
        if hour is None:
            if ctx.get("am_pm") is not None:
                temp_rest, num = self.digest_chinese_number(text)
                if 0 <= num <= 23:
                    hour = num
                    is_24hour_format = False
                    rest_after_hour = temp_rest.lstrip()
                else:
                    m = re.match(r"^(\d{1,2})", text)
                    if m:
                        h_val = int(m.group(1))
                        if 0 <= h_val <= 23:
                            hour = h_val
                            is_24hour_format = False
                            rest_after_hour = text[m.end():].lstrip()

        if hour is None:
            return text

        if not (0 <= hour <= 23):
            return text

        # Parse minutes
        rest = rest_after_hour.lstrip()
        minute = 0
        minute_spec_count = 0

        if rest.startswith("钟"):
            rest = rest[1:].lstrip()

        has_zheng = False
        if rest.startswith("整"):
            has_zheng = True
            rest = rest[1:].lstrip()

        # 半 = :30
        if rest.startswith("半"):
            minute = 30
            minute_spec_count += 1
            rest = rest[1:].lstrip()
            if rest.startswith("钟"):
                rest = rest[1:].lstrip()
            if rest.startswith("整"):
                rest = rest[1:].lstrip()

        # 一刻 = :15
        if rest.startswith("一刻"):
            minute = 15
            minute_spec_count += 1
            rest = rest[2:].lstrip()
            if rest.startswith("钟"):
                rest = rest[1:].lstrip()

        # 过一刻 = :15
        if rest.startswith("过一刻"):
            minute = 15
            minute_spec_count += 1
            rest = rest[3:].lstrip()
            if rest.startswith("钟"):
                rest = rest[1:].lstrip()

        # Explicit NN分 / Chinese-numeral 分.
        m = re.match(r"^(\d+|[零一二三四五六七八九十]+)分", rest)
        if m:
            minute_spec_count += 1
            m_str = m.group(1)
            if m_str.isdigit():
                minute = int(m_str)
            else:
                _, minute = self.digest_chinese_number(m_str)
            rest = rest[m.end():].lstrip()

        # Bare trailing number as minutes (e.g. 三点十五).
        if minute_spec_count == 0:
            temp_rest, num = self.digest_chinese_number(rest)
            if num > 0 and num <= 59:
                minute = num
                minute_spec_count += 1
                rest = temp_rest.lstrip()
            else:
                m = re.match(r"^(\d{1,2})", rest)
                if m:
                    m_val = int(m.group(1))
                    if 0 <= m_val <= 59:
                        minute = m_val
                        minute_spec_count += 1
                        rest = rest[m.end():].lstrip()

        # 整 with no other minute spec counts as an explicit ":00".
        if has_zheng and minute_spec_count == 0:
            minute_spec_count = 1

        if minute_spec_count > 1:
            raise MultipleSpecificationException("Multiple minute specifications")

        if not (0 <= minute <= 59):
            return text

        # Hours 13-23 are always 24-hour, even with "点"
        if hour >= 13:
            is_24hour_format = True

        ctx["time"] = datetime.time(hour, minute)
        ctx["has_time"] = True
        ctx["ambiguous_hour"] = 1 <= hour <= 12 and not is_24hour_format
        ctx["is_24hour"] = is_24hour_format

        return rest
|
||||
|
||||
def _parse_relative_time(self, text: str, ctx: Dict[str, Any]) -> str:
    """Consume a leading relative-time phrase ("半小时后", "三个半小时前",
    "五分钟后", "十秒后", ...) and record the offset in ``ctx``.

    Returns the remaining text, or ``text`` unchanged when no phrase (or an
    invalid amount) is found at the start.

    Note: the original implementation carried six duplicate branches that
    were unreachable — each was shadowed by an earlier, more general regex
    (e.g. ``小时(后|前)`` after ``小时?(后|前|以后|之后)``, and ``一个半``
    after ``X个半``). They are removed here; match order and semantics of
    the live branches are preserved.
    """
    text = text.lstrip()

    # Character class of a numeric amount (ASCII digits or Chinese numerals).
    num_cls = r"[0-9零一二三四五六七八九十两]+"
    # Direction words that mean "in the future"; everything else ("前") is past.
    forward = ("后", "以后", "之后")

    def read_amount(num_str):
        # Convert the matched amount; None means "reject the whole phrase"
        # (zero or unparseable), matching the original early-return checks.
        if num_str.isdigit():
            value = int(num_str)
        else:
            _, value = self.digest_chinese_number(num_str)
        return value if value > 0 else None

    # "半(个)小时后" → ±30 minutes.
    m = re.match(r"^(半)(?:个)?小时?(后|前|以后|之后)", text)
    if m:
        sign = 1 if m.group(2) in forward else -1
        self._add_delta(ctx, datetime.timedelta(hours=sign * 0.5))
        return text[m.end():]

    # "X个半(小时)后" → ±(X + 0.5) hours.  Also covers "一个半小时后".
    m = re.match(r"^(" + num_cls + r")个半(?:小时?)?(后|前|以后|之后)", text)
    if m:
        amount = read_amount(m.group(1))
        if amount is None:
            return text
        sign = 1 if m.group(2) in forward else -1
        self._add_delta(ctx, datetime.timedelta(hours=sign * (amount + 0.5)))
        return text[m.end():]

    # Plain "X<unit>后/前" for hours, minutes and seconds.  Order matters:
    # the hour pattern must run before the minute pattern, as in the original.
    unit_patterns = (
        (r"(?:个)?小时?", "hours"),
        (r"分钟?", "minutes"),
        (r"秒", "seconds"),
    )
    for unit_re, unit_name in unit_patterns:
        m = re.match(r"^(" + num_cls + r")" + unit_re + r"(后|前|以后|之后)", text)
        if m:
            amount = read_amount(m.group(1))
            if amount is None:
                return text
            sign = 1 if m.group(2) in forward else -1
            self._add_delta(ctx, datetime.timedelta(**{unit_name: sign * amount}))
            return text[m.end():]

    return text
|
||||
|
||||
def _apply_context(self, ctx: Dict[str, Any]) -> datetime.datetime:
    """Materialize the parse context ``ctx`` into a concrete datetime.

    Combines, in order: any relative delta, any absolute date, and any
    time-of-day (with AM/PM inference).  When no time is given but a date
    (absolute or relative) is, the result is normalized to midnight.
    """
    result = self.now
    has_date = ctx["has_date"]
    has_time = ctx["has_time"]
    has_delta = ctx["relative_delta"] is not None
    # NOTE(review): this local is read in the date-less branch below, but the
    # time branch re-reads ctx["has_relative_date"] directly — same value,
    # inconsistent style.
    has_relative_date = ctx["has_relative_date"]

    # Apply the relative offset first so a later absolute date/time can
    # override the corresponding fields.
    if has_delta:
        result = result + ctx["relative_delta"]

    if has_date:
        result = result.replace(
            year=ctx["date"].year,
            month=ctx["date"].month,
            day=ctx["date"].day,
        )

    if has_time:
        h = ctx["time"].hour
        m = ctx["time"].minute

        if ctx["is_24hour"]:
            # "10 时" → 10:00, no conversion
            pass

        elif ctx["am_pm"] == "AM":
            # 12 AM is midnight.
            if h == 12:
                h = 0

        elif ctx["am_pm"] == "PM":
            if h == 12:
                # "晚上十二点" is treated as midnight of the NEXT day;
                # "中午/下午十二点" stays noon.
                if ctx.get("period_word") in ("晚上", "晚"):
                    h = 0
                    result += datetime.timedelta(days=1)
                else:
                    h = 12
            elif 1 <= h <= 11:
                h += 12

        else:
            # No period and not 24-hour (i.e., "点" format)
            if ctx["has_relative_date"]:
                # "明天五点" → 05:00 AM
                if h == 12:
                    h = 0
                # keep h as AM hour (1-11 unchanged)
            else:
                # Infer from current time
                am_hour = 0 if h == 12 else h
                candidate_am = result.replace(hour=am_hour, minute=m, second=0, microsecond=0)
                if candidate_am < self.now:
                    # AM time is in the past, so use PM
                    # NOTE(review): the h == 12 branch is a no-op (12 → 12);
                    # it only exists to skip the +12 below — confirm intended.
                    if h == 12:
                        h = 12
                    else:
                        h += 12
                # else: keep as AM (h unchanged)

        # Safety clamp: wrap any over-shifted hour back into 0-23.
        if h > 23:
            h = h % 24

        result = result.replace(hour=h, minute=m, second=0, microsecond=0)

    else:
        # Date-only expressions resolve to midnight of that day.
        if has_date or (has_relative_date and not has_time):
            result = result.replace(hour=0, minute=0, second=0, microsecond=0)

    return result
|
||||
|
||||
|
||||
def parse(text: str) -> datetime.datetime:
    """Convenience wrapper: parse *text* with a freshly constructed Parser."""
    parser = Parser()
    return parser.parse(text)
|
||||
|
||||
133
konabot/common/ptimeparse/chinese_number.py
Normal file
133
konabot/common/ptimeparse/chinese_number.py
Normal file
@ -0,0 +1,133 @@
|
||||
"""
|
||||
Chinese number parser for the time expression parser.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
class ChineseNumberParser:
    """Parser for Chinese numerals (零-九, multipliers 十/百/千/万/亿, and 两)."""

    def __init__(self):
        # Plain digit characters.  "两" is NOT listed here: it only means 2
        # in counting contexts and is special-cased in digest()/parse().
        self.digits = {"零": 0, "一": 1, "二": 2, "三": 3, "四": 4,
                       "五": 5, "六": 6, "七": 7, "八": 8, "九": 9}
        self.units = {"十": 10, "百": 100, "千": 1000, "万": 10000, "亿": 100000000}

    def digest(self, text: str) -> Tuple[str, int]:
        """
        Parse a Chinese number from the beginning of text and return the rest
        and the parsed number.

        Args:
            text: Text that may start with a Chinese number

        Returns:
            Tuple of (remaining_text, parsed_number); ``(text, 0)`` when no
            number is found at the start.
        """
        if not text:
            return text, 0

        # "两" prefix: as a counter it means 2 and the following time unit is
        # left for the caller to consume ("两小时" → ("小时", 2)).
        if text.startswith("两"):
            if len(text) == 1:
                # Standalone "两".
                return "", 2
            lookahead = text[1:].lstrip()  # tolerate "两 小时"
            for unit in ("小时", "分钟", "秒"):
                if lookahead.startswith(unit):
                    return lookahead, 2
            if lookahead and lookahead[0] in "时分秒":
                return lookahead, 2
            # Fix: "两" followed by a multiplier ("两百") or a separator
            # ("两点") previously failed to parse because the scan below
            # excluded "两".  It now falls through to the general scan,
            # whose character set includes "两".

        # Greedy scan of consecutive numeral characters.
        numeral_chars = "零一二三四五六七八九两十百千万亿"
        i = 0
        while i < len(text) and text[i] in numeral_chars:
            i += 1
        if i == 0:
            return text, 0
        return text[i:], self.parse(text[:i])

    def parse(self, text: str) -> int:
        """
        Parse a Chinese number string and return its integer value.

        Args:
            text: Chinese number string

        Returns:
            Integer value of the Chinese number
        """
        if not text:
            return 0
        if text == "零":
            return 0
        if text == "两":
            return 2
        # Bare "十" means 10 ("十三" is handled by the loop below).
        if text == "十":
            return 10

        # Recurse on the large anchors first: 亿, then 万.
        if "亿" in text:
            a, b = text.split("亿", 1)
            return self.parse(a) * 100000000 + self.parse(b)
        if "万" in text:
            a, b = text.split("万", 1)
            return self.parse(a) * 10000 + self.parse(b)

        result = 0
        temp = 0  # pending digit awaiting its multiplier
        for char in text:
            if char == "零":
                continue  # 零 is a positional placeholder ("三百零五")
            elif char == "两":
                temp = 2
            elif char in self.digits:
                temp = self.digits[char]
            elif char in self.units:
                unit = self.units[char]
                if unit == 10 and temp == 0:
                    # Leading 十 means 1 * 10 ("十三" → 13).
                    temp = 1
                result += temp * unit
                temp = 0

        # Trailing digit with no multiplier ("二十三" → +3).
        result += temp
        return result
|
||||
63
konabot/common/ptimeparse/expression.py
Normal file
63
konabot/common/ptimeparse/expression.py
Normal file
@ -0,0 +1,63 @@
|
||||
"""
|
||||
Main time expression parser class that integrates all components.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
from .lexer import Lexer
|
||||
from .parser import Parser
|
||||
from .semantic import SemanticAnalyzer
|
||||
from .ptime_ast import TimeExpressionNode
|
||||
from .err import TokenUnhandledException
|
||||
|
||||
|
||||
class TimeExpression:
    """Main class for parsing time expressions.

    Facade that wires the Lexer, Parser and SemanticAnalyzer together:
    construction tokenizes and parses ``text`` into an AST; ``evaluate()``
    turns the AST into a concrete datetime.
    """

    def __init__(self, text: str, now: Optional[datetime.datetime] = None):
        self.text = text.strip()
        self.now = now or datetime.datetime.now()

        if not self.text:
            raise TokenUnhandledException("Empty input")

        # Initialize components
        self.lexer = Lexer(self.text, self.now)
        self.parser = Parser(self.text, self.now)
        self.semantic_analyzer = SemanticAnalyzer(self.now)

        # Parse eagerly so construction fails fast on bad input.
        self.ast = self._parse()

    def _parse(self) -> TimeExpressionNode:
        """Parse the time expression and return the AST."""
        try:
            return self.parser.parse()
        except Exception as e:
            # Chain the cause (PEP 3134) so the original traceback survives.
            raise TokenUnhandledException(f"Failed to parse '{self.text}': {str(e)}") from e

    def evaluate(self) -> datetime.datetime:
        """Evaluate the time expression and return the datetime."""
        try:
            return self.semantic_analyzer.evaluate(self.ast)
        except Exception as e:
            raise TokenUnhandledException(f"Failed to evaluate '{self.text}': {str(e)}") from e

    @classmethod
    def parse(cls, text: str, now: Optional[datetime.datetime] = None) -> datetime.datetime:
        """
        Parse a time expression and return a datetime object.

        Args:
            text: The time expression to parse
            now: The reference time (defaults to current time)

        Returns:
            A datetime object representing the parsed time

        Raises:
            TokenUnhandledException: If the input cannot be parsed
        """
        expression = cls(text, now)
        return expression.evaluate()
|
||||
225
konabot/common/ptimeparse/lexer.py
Normal file
225
konabot/common/ptimeparse/lexer.py
Normal file
@ -0,0 +1,225 @@
|
||||
"""
|
||||
Lexical analyzer for time expressions.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Iterator, Optional
|
||||
import datetime
|
||||
|
||||
from .ptime_token import Token, TokenType
|
||||
from .chinese_number import ChineseNumberParser
|
||||
|
||||
|
||||
class Lexer:
    """Lexical analyzer for time expressions.

    Scans the input left-to-right, trying each regex in ``token_patterns``
    in order (first match wins), then falling back to Chinese-numeral
    parsing, then skipping one character.
    """

    def __init__(self, text: str, now: Optional[datetime.datetime] = None):
        self.text = text
        self.pos = 0
        self.current_char = self.text[self.pos] if self.text else None
        self.now = now or datetime.datetime.now()
        self.chinese_parser = ChineseNumberParser()

        # Define token patterns.  ORDER IS SIGNIFICANT: patterns are tried
        # top to bottom and the first match wins.
        self.token_patterns = [
            # Whitespace
            # NOTE(review): get_next_token() skips whitespace before pattern
            # matching, so WHITESPACE tokens are never actually emitted.
            (r'^\s+', TokenType.WHITESPACE),

            # Time separators
            (r'^:', TokenType.TIME_SEPARATOR),
            (r'^点', TokenType.TIME_SEPARATOR),
            (r'^时', TokenType.TIME_SEPARATOR),
            # NOTE(review): these two shadow the later MINUTE ("分钟") and
            # SECOND ("秒") patterns, which therefore can never match —
            # confirm the parser relies on this.
            (r'^分', TokenType.TIME_SEPARATOR),
            (r'^秒', TokenType.TIME_SEPARATOR),

            # Special time markers
            (r'^半', TokenType.HALF),
            (r'^一刻', TokenType.QUARTER),
            (r'^整', TokenType.ZHENG),
            (r'^钟', TokenType.ZHONG),

            # Period indicators (must come before relative time patterns to avoid conflicts)
            (r'^(上午|早晨|早上|清晨|早(?!\d))', TokenType.PERIOD_AM),
            (r'^(中午|下午|晚上|晚(?!\d)|凌晨|午夜)', TokenType.PERIOD_PM),

            # Week scope (more specific patterns first)
            (r'^本周', TokenType.WEEK_SCOPE_CURRENT),
            (r'^上周', TokenType.WEEK_SCOPE_LAST),
            (r'^下周', TokenType.WEEK_SCOPE_NEXT),

            # Relative directions
            # NOTE(review): "^(后|...)" shadows the later "后天" pattern and
            # "^(前|...)" shadows "前天" — those day tokens are unreachable.
            (r'^(后|以后|之后)', TokenType.RELATIVE_DIRECTION_FORWARD),
            (r'^(前|以前|之前)', TokenType.RELATIVE_DIRECTION_BACKWARD),

            # Extended relative time
            (r'^明年', TokenType.RELATIVE_NEXT),
            (r'^去年', TokenType.RELATIVE_LAST),
            (r'^今年', TokenType.RELATIVE_THIS),
            (r'^下(?![午年月周])', TokenType.RELATIVE_NEXT),
            (r'^(上|去)(?![午年月周])', TokenType.RELATIVE_LAST),
            (r'^这', TokenType.RELATIVE_THIS),
            (r'^本(?![周月年])', TokenType.RELATIVE_THIS),  # Match "本" but not "本周", "本月", "本年"

            # Week scope (fallback for standalone terms)
            # NOTE(review): all three are shadowed by entries above ("本周"/
            # "上周" handled earlier; "^上" by RELATIVE_LAST; the "^下" entry
            # is an exact duplicate of RELATIVE_NEXT above) — dead patterns.
            (r'^本', TokenType.WEEK_SCOPE_CURRENT),
            (r'^上', TokenType.WEEK_SCOPE_LAST),
            (r'^下(?![午年月周])', TokenType.WEEK_SCOPE_NEXT),

            # Week days (order matters - longer patterns first)
            (r'^周一', TokenType.WEEKDAY_MONDAY),
            (r'^周二', TokenType.WEEKDAY_TUESDAY),
            (r'^周三', TokenType.WEEKDAY_WEDNESDAY),
            (r'^周四', TokenType.WEEKDAY_THURSDAY),
            (r'^周五', TokenType.WEEKDAY_FRIDAY),
            (r'^周六', TokenType.WEEKDAY_SATURDAY),
            (r'^周日', TokenType.WEEKDAY_SUNDAY),
            # Single character weekdays should be matched after numbers
            # (r'^一', TokenType.WEEKDAY_MONDAY),
            # (r'^二', TokenType.WEEKDAY_TUESDAY),
            # (r'^三', TokenType.WEEKDAY_WEDNESDAY),
            # (r'^四', TokenType.WEEKDAY_THURSDAY),
            # (r'^五', TokenType.WEEKDAY_FRIDAY),
            # (r'^六', TokenType.WEEKDAY_SATURDAY),
            # (r'^日', TokenType.WEEKDAY_SUNDAY),

            # Student-friendly time expressions
            (r'^早(?=\d)', TokenType.EARLY_MORNING),
            (r'^晚(?=\d)', TokenType.LATE_NIGHT),

            # Relative today variants
            (r'^今晚上', TokenType.RELATIVE_TODAY),
            (r'^今晚', TokenType.RELATIVE_TODAY),
            (r'^今早', TokenType.RELATIVE_TODAY),
            (r'^今天早上', TokenType.RELATIVE_TODAY),
            (r'^今天早晨', TokenType.RELATIVE_TODAY),
            (r'^今天上午', TokenType.RELATIVE_TODAY),
            (r'^今天下午', TokenType.RELATIVE_TODAY),
            (r'^今天晚上', TokenType.RELATIVE_TODAY),
            (r'^今天', TokenType.RELATIVE_TODAY),

            # Relative days
            (r'^明天', TokenType.RELATIVE_TOMORROW),
            (r'^后天', TokenType.RELATIVE_DAY_AFTER_TOMORROW),
            (r'^大后天', TokenType.RELATIVE_THREE_DAYS_AFTER_TOMORROW),
            (r'^昨天', TokenType.RELATIVE_YESTERDAY),
            (r'^前天', TokenType.RELATIVE_DAY_BEFORE_YESTERDAY),
            (r'^大前天', TokenType.RELATIVE_THREE_DAYS_BEFORE_YESTERDAY),

            # Digits
            (r'^\d+', TokenType.INTEGER),

            # Time units (must come after date separators to avoid conflicts)
            (r'^年(?![月日号])', TokenType.YEAR),
            (r'^月(?![日号])', TokenType.MONTH),
            (r'^[日号](?![月年])', TokenType.DAY),
            (r'^天', TokenType.DAY),
            (r'^周', TokenType.WEEK),
            (r'^小时', TokenType.HOUR),
            # NOTE(review): unreachable — shadowed by the TIME_SEPARATOR
            # entries for "分" and "秒" near the top of the list.
            (r'^分钟', TokenType.MINUTE),
            (r'^秒', TokenType.SECOND),

            # Date separators (fallback patterns)
            (r'^年', TokenType.DATE_SEPARATOR),
            (r'^月', TokenType.DATE_SEPARATOR),
            (r'^[日号]', TokenType.DATE_SEPARATOR),
            (r'^[-/]', TokenType.DATE_SEPARATOR),
        ]

    def advance(self):
        """Advance the position pointer and set the current character."""
        self.pos += 1
        if self.pos >= len(self.text):
            self.current_char = None
        else:
            self.current_char = self.text[self.pos]

    def skip_whitespace(self):
        """Skip whitespace characters."""
        while self.current_char is not None and self.current_char.isspace():
            self.advance()

    def integer(self) -> int:
        """Parse an integer from the input."""
        result = ''
        while self.current_char is not None and self.current_char.isdigit():
            result += self.current_char
            self.advance()
        return int(result)

    def chinese_number(self) -> int:
        """Parse a Chinese number from the input.

        Tries prefixes longest-first; advances past whatever digest()
        consumed and returns the parsed value (0 when nothing matched).
        """
        # Find the longest prefix that can be parsed as a Chinese number
        for i in range(len(self.text) - self.pos, 0, -1):
            prefix = self.text[self.pos:self.pos + i]
            try:
                # Use digest to get both the remaining text and the parsed value
                remaining, value = self.chinese_parser.digest(prefix)
                # Check if we actually consumed part of the prefix
                consumed_length = len(prefix) - len(remaining)
                if consumed_length > 0:
                    # Advance position by the length of the consumed text
                    for _ in range(consumed_length):
                        self.advance()
                    return value
            except ValueError:
                # NOTE(review): digest() as written never raises ValueError,
                # so this handler looks dead — confirm before removing.
                continue
        # If no Chinese number found, just return 0
        return 0

    def get_next_token(self) -> Token:
        """Lexical analyzer that breaks the sentence into tokens."""
        while self.current_char is not None:
            # Skip whitespace
            if self.current_char.isspace():
                self.skip_whitespace()
                continue

            # Try to match each pattern
            text_remaining = self.text[self.pos:]
            for pattern, token_type in self.token_patterns:
                match = re.match(pattern, text_remaining)
                if match:
                    value = match.group(0)
                    position = self.pos

                    # Advance position
                    for _ in range(len(value)):
                        self.advance()

                    # Special handling for some tokens
                    if token_type == TokenType.INTEGER:
                        value = int(value)
                    # Remap "today + period" phrases to plain period tokens.
                    # NOTE(review): "今早上" can never be produced (the "今早"
                    # pattern matches first) — confirm the intended value.
                    elif token_type == TokenType.RELATIVE_TODAY and value in [
                        "今早上", "今天早上", "今天早晨", "今天上午"
                    ]:
                        token_type = TokenType.PERIOD_AM
                    elif token_type == TokenType.RELATIVE_TODAY and value in [
                        "今晚上", "今天下午", "今天晚上"
                    ]:
                        token_type = TokenType.PERIOD_PM

                    return Token(token_type, value, position)

            # Try to parse Chinese numbers
            chinese_start_pos = self.pos
            try:
                chinese_value = self.chinese_number()
                # NOTE(review): when chinese_number() consumes "零" (value 0)
                # the position has already advanced, and the fallback
                # advance() below then skips one more character — confirm.
                if chinese_value > 0:
                    # We successfully parsed a Chinese number
                    return Token(TokenType.CHINESE_NUMBER, chinese_value, chinese_start_pos)
            except ValueError:
                pass

            # If no pattern matches, skip the character and continue
            self.advance()

        # End of file
        return Token(TokenType.EOF, None, self.pos)

    def tokenize(self) -> Iterator[Token]:
        """Generate all tokens from the input."""
        while True:
            token = self.get_next_token()
            yield token
            if token.type == TokenType.EOF:
                break
|
||||
846
konabot/common/ptimeparse/parser.py
Normal file
846
konabot/common/ptimeparse/parser.py
Normal file
@ -0,0 +1,846 @@
|
||||
"""
|
||||
Parser for time expressions that builds an Abstract Syntax Tree (AST).
|
||||
"""
|
||||
|
||||
from typing import Iterator, Optional, List
|
||||
import datetime
|
||||
|
||||
from .ptime_token import Token, TokenType
|
||||
from .ptime_ast import (
|
||||
ASTNode, NumberNode, DateNode, TimeNode,
|
||||
RelativeDateNode, RelativeTimeNode, WeekdayNode, TimeExpressionNode
|
||||
)
|
||||
from .lexer import Lexer
|
||||
|
||||
|
||||
class ParserError(Exception):
    """Raised when the token stream cannot be parsed as a time expression."""
|
||||
|
||||
|
||||
class Parser:
|
||||
"""Parser for time expressions that builds an AST."""
|
||||
|
||||
def __init__(self, text: str, now: Optional[datetime.datetime] = None):
    """Tokenize *text* eagerly and position the cursor at the first token."""
    self.now = now or datetime.datetime.now()
    self.pos = 0
    self.lexer = Lexer(text, now)
    self.tokens: List[Token] = list(self.lexer.tokenize())
||||
|
||||
@property
def current_token(self) -> Token:
    """The token under the cursor, or a synthetic EOF token past the end."""
    try:
        return self.tokens[self.pos]
    except IndexError:
        return Token(TokenType.EOF, None, len(self.tokens))
|
||||
|
||||
def eat(self, token_type: TokenType) -> Token:
    """Consume and return the current token, which must be of *token_type*.

    Raises ParserError when the current token has a different type.
    """
    token = self.current_token
    if token.type != token_type:
        raise ParserError(
            f"Expected token {token_type}, got {token.type} "
            f"at position {token.position}"
        )
    self.pos += 1
    return token
|
||||
|
||||
def peek(self, offset: int = 1) -> Token:
    """Look ahead *offset* tokens without consuming; EOF token past the end."""
    target = self.pos + offset
    if target >= len(self.tokens):
        return Token(TokenType.EOF, None, len(self.tokens))
    return self.tokens[target]
|
||||
|
||||
def parse_number(self) -> NumberNode:
    """Parse an INTEGER or CHINESE_NUMBER token into a NumberNode.

    Raises ParserError when the current token is not numeric.
    """
    token = self.current_token
    if token.type not in (TokenType.INTEGER, TokenType.CHINESE_NUMBER):
        raise ParserError(
            f"Expected number, got {token.type} at position {token.position}"
        )
    self.eat(token.type)
    return NumberNode(value=token.value)
|
||||
|
||||
def parse_date(self) -> DateNode:
    """Parse a date specification.

    Tries, in order: YYYY-MM-DD / YYYY/MM/DD, YYYY年MM月DD[日号],
    MM月DD[日号] (no year), and a Chinese-numeral month/day form.
    Raises ParserError when none of them match.
    """
    year_node = None
    month_node = None
    day_node = None

    # Try YYYY-MM-DD or YYYY/MM/DD format
    if (self.current_token.type == TokenType.INTEGER and
            self.peek().type == TokenType.DATE_SEPARATOR and
            self.peek().value in ['-', '/'] and
            self.peek(2).type == TokenType.INTEGER and
            self.peek(3).type == TokenType.DATE_SEPARATOR and
            self.peek(3).value in ['-', '/'] and
            self.peek(4).type == TokenType.INTEGER):

        year_token = self.current_token
        self.eat(TokenType.INTEGER)
        # NOTE(review): separator1/separator2 are captured but never used.
        separator1 = self.eat(TokenType.DATE_SEPARATOR).value

        month_token = self.current_token
        self.eat(TokenType.INTEGER)

        separator2 = self.eat(TokenType.DATE_SEPARATOR).value

        day_token = self.current_token
        self.eat(TokenType.INTEGER)

        year_node = NumberNode(value=year_token.value)
        month_node = NumberNode(value=month_token.value)
        day_node = NumberNode(value=day_token.value)

        return DateNode(year=year_node, month=month_node, day=day_node)

    # Try YYYY年MM月DD[日号] format
    if (self.current_token.type == TokenType.INTEGER and
            self.peek().type in [TokenType.DATE_SEPARATOR, TokenType.YEAR] and
            self.peek(2).type == TokenType.INTEGER and
            self.peek(3).type in [TokenType.DATE_SEPARATOR, TokenType.MONTH] and
            self.peek(4).type == TokenType.INTEGER):

        year_token = self.current_token
        self.eat(TokenType.INTEGER)
        self.eat(self.current_token.type)  # 年 (could be DATE_SEPARATOR or YEAR)

        month_token = self.current_token
        self.eat(TokenType.INTEGER)
        self.eat(self.current_token.type)  # 月 (could be DATE_SEPARATOR or MONTH)

        day_token = self.current_token
        self.eat(TokenType.INTEGER)
        # Optional 日 or 号
        if self.current_token.type in [TokenType.DATE_SEPARATOR, TokenType.DAY]:
            self.eat(self.current_token.type)

        year_node = NumberNode(value=year_token.value)
        month_node = NumberNode(value=month_token.value)
        day_node = NumberNode(value=day_token.value)

        return DateNode(year=year_node, month=month_node, day=day_node)

    # Try MM月DD[日号] format (without year)
    if (self.current_token.type in [TokenType.INTEGER, TokenType.CHINESE_NUMBER] and
            self.peek().type in [TokenType.DATE_SEPARATOR, TokenType.MONTH] and
            self.peek().value == '月' and
            self.peek(2).type in [TokenType.INTEGER, TokenType.CHINESE_NUMBER]):

        month_token = self.current_token
        self.eat(month_token.type)
        self.eat(self.current_token.type)  # 月 (could be DATE_SEPARATOR or MONTH)

        day_token = self.current_token
        self.eat(day_token.type)
        # Optional 日 or 号
        if self.current_token.type in [TokenType.DATE_SEPARATOR, TokenType.DAY]:
            self.eat(self.current_token.type)

        month_node = NumberNode(value=month_token.value)
        day_node = NumberNode(value=day_token.value)

        return DateNode(year=None, month=month_node, day=day_node)

    # Try Chinese MM月DD[日号] format
    # NOTE(review): unreachable — the branch above already accepts
    # CHINESE_NUMBER months, so this condition can never be reached with a
    # fresh match. Confirm before removing.
    if (self.current_token.type == TokenType.CHINESE_NUMBER and
            self.peek().type == TokenType.DATE_SEPARATOR and
            self.peek().value == '月' and
            self.peek(2).type in [TokenType.INTEGER, TokenType.CHINESE_NUMBER]):

        month_token = self.current_token
        self.eat(TokenType.CHINESE_NUMBER)
        self.eat(TokenType.DATE_SEPARATOR)  # 月

        day_token = self.current_token
        self.eat(day_token.type)
        # Optional 日 or 号
        if self.current_token.type == TokenType.DATE_SEPARATOR:
            self.eat(TokenType.DATE_SEPARATOR)

        month_node = NumberNode(value=month_token.value)
        day_node = NumberNode(value=day_token.value)

        return DateNode(year=None, month=month_node, day=day_node)

    raise ParserError(
        f"Unable to parse date at position {self.current_token.position}"
    )
|
||||
|
||||
def parse_time(self) -> TimeNode:
    """Parse a time specification.

    Handles HH:MM[:SS] (always 24-hour), optional AM/PM period words,
    the "早8"/"晚10" shorthands, and the Chinese X点/X时[X分][半|一刻|整][钟]
    forms. Raises ParserError when no time is found.
    """
    hour_node = None
    minute_node = None
    second_node = None
    is_24hour = False
    period = None

    # Try HH:MM format
    if (self.current_token.type == TokenType.INTEGER and
            self.peek().type == TokenType.TIME_SEPARATOR and
            self.peek().value == ':'):

        hour_token = self.current_token
        self.eat(TokenType.INTEGER)
        self.eat(TokenType.TIME_SEPARATOR)  # :

        minute_token = self.current_token
        self.eat(TokenType.INTEGER)

        hour_node = NumberNode(value=hour_token.value)
        minute_node = NumberNode(value=minute_token.value)
        is_24hour = True  # HH:MM is always interpreted as 24-hour

        # Optional :SS
        if (self.current_token.type == TokenType.TIME_SEPARATOR and
                self.peek().type == TokenType.INTEGER):

            self.eat(TokenType.TIME_SEPARATOR)  # :
            second_token = self.current_token
            self.eat(TokenType.INTEGER)
            second_node = NumberNode(value=second_token.value)

        return TimeNode(
            hour=hour_node,
            minute=minute_node,
            second=second_node,
            is_24hour=is_24hour,
            period=period
        )

    # Try Chinese time format (X点X分)
    # First check for period indicators
    # NOTE(review): this reassignment of period is redundant (still None here).
    period = None
    if self.current_token.type in [TokenType.PERIOD_AM, TokenType.PERIOD_PM]:
        if self.current_token.type == TokenType.PERIOD_AM:
            period = "AM"
        else:
            period = "PM"
        self.eat(self.current_token.type)

    if self.current_token.type in [TokenType.INTEGER, TokenType.CHINESE_NUMBER, TokenType.EARLY_MORNING, TokenType.LATE_NIGHT]:
        if self.current_token.type == TokenType.EARLY_MORNING:
            # "早8" shorthand: AM, 24-hour flag set.
            self.eat(TokenType.EARLY_MORNING)
            is_24hour = True
            period = "AM"

            # Expect a number next
            if self.current_token.type in [TokenType.INTEGER, TokenType.CHINESE_NUMBER]:
                hour_token = self.current_token
                self.eat(hour_token.type)
                hour_node = NumberNode(value=hour_token.value)

                # "早八" should be interpreted as 08:00
                # If hour is greater than 12, treat as 24-hour
                if hour_node.value > 12:
                    is_24hour = True
                    period = None
            else:
                raise ParserError(
                    f"Expected number after '早', got {self.current_token.type} "
                    f"at position {self.current_token.position}"
                )
        elif self.current_token.type == TokenType.LATE_NIGHT:
            # "晚10" shorthand: convert to 22:00 directly.
            self.eat(TokenType.LATE_NIGHT)
            is_24hour = True
            period = "PM"

            # Expect a number next
            if self.current_token.type in [TokenType.INTEGER, TokenType.CHINESE_NUMBER]:
                hour_token = self.current_token
                self.eat(hour_token.type)
                hour_node = NumberNode(value=hour_token.value)

                # "晚十" should be interpreted as 22:00
                # Adjust hour to 24-hour format
                if hour_node.value <= 12:
                    hour_node.value += 12
                is_24hour = True
                period = None
            else:
                raise ParserError(
                    f"Expected number after '晚', got {self.current_token.type} "
                    f"at position {self.current_token.position}"
                )
        else:
            # Regular time parsing
            hour_token = self.current_token
            self.eat(hour_token.type)

            # Check for 点 or 时
            if self.current_token.type == TokenType.TIME_SEPARATOR:
                separator = self.current_token.value
                self.eat(TokenType.TIME_SEPARATOR)

                # "点" is colloquial (ambiguous 12-hour); "时" is formal 24-hour.
                if separator == '点':
                    is_24hour = False
                elif separator == '时':
                    is_24hour = True

                hour_node = NumberNode(value=hour_token.value)

                # Optional minutes
                if self.current_token.type in [TokenType.INTEGER, TokenType.CHINESE_NUMBER]:
                    minute_token = self.current_token
                    self.eat(minute_token.type)

                    # Optional 分
                    if self.current_token.type == TokenType.TIME_SEPARATOR and \
                            self.current_token.value == '分':
                        self.eat(TokenType.TIME_SEPARATOR)

                    minute_node = NumberNode(value=minute_token.value)

                # Handle special markers
                if self.current_token.type == TokenType.HALF:
                    self.eat(TokenType.HALF)
                    minute_node = NumberNode(value=30)  # X点半 → :30
                elif self.current_token.type == TokenType.QUARTER:
                    self.eat(TokenType.QUARTER)
                    minute_node = NumberNode(value=15)  # X点一刻 → :15
                elif self.current_token.type == TokenType.ZHENG:
                    self.eat(TokenType.ZHENG)
                    # X点整 → :00 (only if minutes weren't given explicitly)
                    if minute_node is None:
                        minute_node = NumberNode(value=0)

                # Optional 钟
                if self.current_token.type == TokenType.ZHONG:
                    self.eat(TokenType.ZHONG)
            else:
                # If no separator, treat as hour-only time (like "三点")
                hour_node = NumberNode(value=hour_token.value)
                is_24hour = False

        return TimeNode(
            hour=hour_node,
            minute=minute_node,
            second=second_node,
            is_24hour=is_24hour,
            period=period
        )

    raise ParserError(
        f"Unable to parse time at position {self.current_token.position}"
    )
|
||||
|
||||
def parse_relative_date(self) -> RelativeDateNode:
    """Parse a relative date specification into a RelativeDateNode.

    Recognizes:
      * fixed day words: 今天/明天/后天/大后天/昨天/前天/大前天
      * year/month/week words: 今年/明年/去年, 这个月/下个月/上个月, 下周/上周
      * numeric offsets such as "三天后" or "两个月前"

    Encoding quirks consumed by the semantic analyzer (kept for
    compatibility with ``SemanticAnalyzer.evaluate_relative_date``):
      * "明年X月" stores the absolute month as ``months = X - 100``.
      * "下个月X号" stores the target day-of-month in ``weeks``.

    Returns:
        RelativeDateNode with signed year/month/week/day offsets.

    Raises:
        ParserError: if the tokens look like an absolute date, or no
            relative-date pattern matches (token position is restored
            before re-raising).
    """
    years = 0
    months = 0
    weeks = 0
    days = 0

    # Fixed day words map directly onto a day offset.
    if self.current_token.type == TokenType.RELATIVE_TODAY:
        self.eat(TokenType.RELATIVE_TODAY)
        days = 0
    elif self.current_token.type == TokenType.RELATIVE_TOMORROW:
        self.eat(TokenType.RELATIVE_TOMORROW)
        days = 1
    elif self.current_token.type == TokenType.RELATIVE_DAY_AFTER_TOMORROW:
        self.eat(TokenType.RELATIVE_DAY_AFTER_TOMORROW)
        days = 2
    elif self.current_token.type == TokenType.RELATIVE_THREE_DAYS_AFTER_TOMORROW:
        self.eat(TokenType.RELATIVE_THREE_DAYS_AFTER_TOMORROW)
        days = 3
    elif self.current_token.type == TokenType.RELATIVE_YESTERDAY:
        self.eat(TokenType.RELATIVE_YESTERDAY)
        days = -1
    elif self.current_token.type == TokenType.RELATIVE_DAY_BEFORE_YESTERDAY:
        self.eat(TokenType.RELATIVE_DAY_BEFORE_YESTERDAY)
        days = -2
    elif self.current_token.type == TokenType.RELATIVE_THREE_DAYS_BEFORE_YESTERDAY:
        self.eat(TokenType.RELATIVE_THREE_DAYS_BEFORE_YESTERDAY)
        days = -3
    else:
        # Bail out early if this looks like "MM月DD[日号]" — an absolute
        # date (e.g. "6月20日") that parse_date() should handle instead.
        if (self.pos + 2 < len(self.tokens) and
                self.tokens[self.pos].type in [TokenType.INTEGER, TokenType.CHINESE_NUMBER] and
                self.tokens[self.pos + 1].type in [TokenType.DATE_SEPARATOR, TokenType.MONTH] and
                self.tokens[self.pos + 1].value == '月' and
                self.tokens[self.pos + 2].type in [TokenType.INTEGER, TokenType.CHINESE_NUMBER]):
            raise ParserError("Looks like absolute date format")

        # Extended relative expressions: 明年, 去年, 下个月, 上个月, etc.
        original_pos = self.pos
        try:
            # 今年 / 明年 / 去年 as two tokens (modifier + 年).
            if self.current_token.type == TokenType.RELATIVE_THIS and self.peek().type == TokenType.YEAR:
                self.eat(TokenType.RELATIVE_THIS)
                self.eat(TokenType.YEAR)
                years = 0  # Current year
            elif self.current_token.type == TokenType.RELATIVE_NEXT and self.peek().type == TokenType.YEAR:
                self.eat(TokenType.RELATIVE_NEXT)
                self.eat(TokenType.YEAR)
                years = 1  # Next year
            elif self.current_token.type == TokenType.RELATIVE_LAST and self.peek().type == TokenType.YEAR:
                self.eat(TokenType.RELATIVE_LAST)
                self.eat(TokenType.YEAR)
                years = -1  # Last year
            # Single-token year words lexed as one unit.
            elif self.current_token.type == TokenType.RELATIVE_NEXT and self.current_token.value == "明年":
                self.eat(TokenType.RELATIVE_NEXT)
                years = 1  # Next year
                # Optional absolute month after 明年, e.g. "明年五月".
                if (self.current_token.type in [TokenType.INTEGER, TokenType.CHINESE_NUMBER] and
                        self.peek().type == TokenType.MONTH):
                    month_node = self.parse_number()
                    self.eat(TokenType.MONTH)  # consume the 月 token
                    # Encode the absolute month as (month - 100); decoded in
                    # semantic analysis. TODO: replace with an explicit field.
                    months = month_node.value - 100
            elif self.current_token.type == TokenType.RELATIVE_LAST and self.current_token.value == "去年":
                self.eat(TokenType.RELATIVE_LAST)
                years = -1  # Last year
            elif self.current_token.type == TokenType.RELATIVE_THIS and self.current_token.value == "今年":
                self.eat(TokenType.RELATIVE_THIS)
                years = 0  # Current year

            # 这个月 / 下个月 / 上个月 (modifier + 月).
            elif self.current_token.type == TokenType.RELATIVE_THIS and self.peek().type == TokenType.MONTH:
                self.eat(TokenType.RELATIVE_THIS)
                self.eat(TokenType.MONTH)
                months = 0  # Current month
            elif self.current_token.type == TokenType.RELATIVE_NEXT and self.peek().type == TokenType.MONTH:
                self.eat(TokenType.RELATIVE_NEXT)
                self.eat(TokenType.MONTH)
                months = 1  # Next month
                # Optional day-of-month, e.g. "下个月五号".
                if (self.current_token.type in [TokenType.INTEGER, TokenType.CHINESE_NUMBER] and
                        self.peek().type == TokenType.DAY):
                    day_node = self.parse_number()
                    self.eat(TokenType.DAY)  # consume the 号 token
                    days = 0
                    # HACK: smuggle the target day-of-month through the weeks
                    # field; semantic analysis decodes it. TODO: dedicated field.
                    weeks = day_node.value
            elif self.current_token.type == TokenType.RELATIVE_LAST and self.peek().type == TokenType.MONTH:
                self.eat(TokenType.RELATIVE_LAST)
                self.eat(TokenType.MONTH)
                months = -1  # Last month

            # 下周 / 上周 (modifier + 周).
            elif self.current_token.type == TokenType.RELATIVE_NEXT and self.peek().type == TokenType.WEEK:
                self.eat(TokenType.RELATIVE_NEXT)
                self.eat(TokenType.WEEK)
                weeks = 1  # Next week
            elif self.current_token.type == TokenType.RELATIVE_LAST and self.peek().type == TokenType.WEEK:
                self.eat(TokenType.RELATIVE_LAST)
                self.eat(TokenType.WEEK)
                weeks = -1  # Last week

            # Numeric offsets: "X年后", "X个月前", "X周后", "X天前", ...
            elif self.current_token.type in [TokenType.INTEGER, TokenType.CHINESE_NUMBER]:
                # Guard: a leading number may instead start an absolute date
                # like "2025年11月21日" or "2025-11-21". The single lookahead
                # below covers both -/ / separators and 年…月; the MM月DD
                # shape was already rejected at the top of this else-branch.
                lookahead_pos = self.pos
                if (lookahead_pos + 4 < len(self.tokens) and
                        self.tokens[lookahead_pos].type in [TokenType.INTEGER, TokenType.CHINESE_NUMBER] and
                        self.tokens[lookahead_pos + 1].type in [TokenType.DATE_SEPARATOR, TokenType.YEAR] and
                        self.tokens[lookahead_pos + 1].value in ['-', '/', '年'] and
                        self.tokens[lookahead_pos + 2].type in [TokenType.INTEGER, TokenType.CHINESE_NUMBER] and
                        self.tokens[lookahead_pos + 3].type in [TokenType.DATE_SEPARATOR, TokenType.MONTH] and
                        self.tokens[lookahead_pos + 3].value in ['-', '/', '月']):
                    raise ParserError("Looks like absolute date format")

                number_node = self.parse_number()
                number_value = number_node.value

                # Unit of the offset.
                if self.current_token.type == TokenType.YEAR:
                    self.eat(TokenType.YEAR)
                    years = number_value
                elif self.current_token.type == TokenType.MONTH:
                    self.eat(TokenType.MONTH)
                    months = number_value
                elif self.current_token.type == TokenType.WEEK:
                    self.eat(TokenType.WEEK)
                    weeks = number_value
                elif self.current_token.type == TokenType.DAY:
                    self.eat(TokenType.DAY)
                    days = number_value
                else:
                    raise ParserError(
                        f"Expected time unit, got {self.current_token.type} "
                        f"at position {self.current_token.position}"
                    )

                # Direction: 后 keeps the offset positive, 前 negates it.
                if self.current_token.type == TokenType.RELATIVE_DIRECTION_FORWARD:
                    self.eat(TokenType.RELATIVE_DIRECTION_FORWARD)
                elif self.current_token.type == TokenType.RELATIVE_DIRECTION_BACKWARD:
                    self.eat(TokenType.RELATIVE_DIRECTION_BACKWARD)
                    years, months, weeks, days = -years, -months, -weeks, -days

        except ParserError:
            # Restore the position so the caller can try other parsers.
            self.pos = original_pos
            raise ParserError(
                f"Expected relative date, got {self.current_token.type} "
                f"at position {self.current_token.position}"
            )

    return RelativeDateNode(years=years, months=months, weeks=weeks, days=days)
|
||||
|
||||
def parse_weekday(self) -> WeekdayNode:
    """Parse a weekday specification (e.g. 本周三, 下周五, 周1).

    Returns:
        WeekdayNode with ``weekday`` 0=Monday..6=Sunday and ``scope`` in
        {"current", "last", "next"}.

    Raises:
        ParserError: on an out-of-range numeric weekday or a non-weekday
            token.
    """
    # Week scope prefix (本 / 上 / 下); defaults to the current week.
    scope = "current"
    if self.current_token.type == TokenType.WEEK_SCOPE_CURRENT:
        self.eat(TokenType.WEEK_SCOPE_CURRENT)
        scope = "current"
    elif self.current_token.type == TokenType.WEEK_SCOPE_LAST:
        self.eat(TokenType.WEEK_SCOPE_LAST)
        scope = "last"
    elif self.current_token.type == TokenType.WEEK_SCOPE_NEXT:
        self.eat(TokenType.WEEK_SCOPE_NEXT)
        scope = "next"

    # Numeric weekday: 1=Monday .. 7=Sunday.
    if self.current_token.type == TokenType.CHINESE_NUMBER:
        weekday_num = self.current_token.value
        if 1 <= weekday_num <= 7:
            weekday = weekday_num - 1  # convert to 0-based (0=Monday)
            self.eat(TokenType.CHINESE_NUMBER)
            return WeekdayNode(weekday=weekday, scope=scope)
        raise ParserError(
            f"Invalid weekday number: {weekday_num} "
            f"at position {self.current_token.position}"
        )

    # Named weekday tokens. (A dead lambda entry for CHINESE_NUMBER was
    # removed: the numeric case is fully handled above and the lambda
    # value was never invoked.)
    weekday_map = {
        TokenType.WEEKDAY_MONDAY: 0,
        TokenType.WEEKDAY_TUESDAY: 1,
        TokenType.WEEKDAY_WEDNESDAY: 2,
        TokenType.WEEKDAY_THURSDAY: 3,
        TokenType.WEEKDAY_FRIDAY: 4,
        TokenType.WEEKDAY_SATURDAY: 5,
        TokenType.WEEKDAY_SUNDAY: 6,
    }
    if self.current_token.type in weekday_map:
        weekday = weekday_map[self.current_token.type]
        self.eat(self.current_token.type)
        return WeekdayNode(weekday=weekday, scope=scope)

    raise ParserError(
        f"Expected weekday, got {self.current_token.type} "
        f"at position {self.current_token.position}"
    )
|
||||
|
||||
def parse_relative_time(self) -> RelativeTimeNode:
    """Parse a relative time specification (e.g. 半小时后, 十分钟前).

    Accumulates signed hour/minute/second offsets from a sequence of
    "number + unit + direction" fragments. Consumes nothing conclusive
    and returns an all-zero node when the tokens turn out not to be a
    relative time (the caller is expected to reset ``self.pos``).

    NOTE: a truncated duplicate definition of this method (dead code,
    immediately shadowed by this one) was removed.

    Returns:
        RelativeTimeNode with float hour/minute/second offsets.
    """
    hours = 0.0
    minutes = 0.0
    seconds = 0.0

    # Consume fragments while the stream still looks like a relative time.
    while self.current_token.type in [
        TokenType.INTEGER, TokenType.CHINESE_NUMBER,
        TokenType.HALF, TokenType.QUARTER
    ] or (self.current_token.type == TokenType.RELATIVE_DIRECTION_FORWARD or
          self.current_token.type == TokenType.RELATIVE_DIRECTION_BACKWARD):

        # 半[个][小时] = half an hour.
        if self.current_token.type == TokenType.HALF:
            self.eat(TokenType.HALF)
            # Optional measure word 个.
            if (self.current_token.type == TokenType.INTEGER and
                    self.current_token.value == "个"):
                self.eat(TokenType.INTEGER)
            # Optional 小时.
            if self.current_token.type == TokenType.HOUR:
                self.eat(TokenType.HOUR)
            hours += 0.5
            # Direction suffix.
            if self.current_token.type == TokenType.RELATIVE_DIRECTION_FORWARD:
                self.eat(TokenType.RELATIVE_DIRECTION_FORWARD)
            elif self.current_token.type == TokenType.RELATIVE_DIRECTION_BACKWARD:
                self.eat(TokenType.RELATIVE_DIRECTION_BACKWARD)
                # NOTE: negates the running total, not just this fragment —
                # preserved for backward compatibility with chained input.
                hours = -hours
            continue

        # 一刻[钟] = 15 minutes.
        if self.current_token.type == TokenType.QUARTER:
            self.eat(TokenType.QUARTER)
            # Optional 钟.
            if self.current_token.type == TokenType.ZHONG:
                self.eat(TokenType.ZHONG)
            minutes += 15
            # Direction suffix.
            if self.current_token.type == TokenType.RELATIVE_DIRECTION_FORWARD:
                self.eat(TokenType.RELATIVE_DIRECTION_FORWARD)
            elif self.current_token.type == TokenType.RELATIVE_DIRECTION_BACKWARD:
                self.eat(TokenType.RELATIVE_DIRECTION_BACKWARD)
                # NOTE: same running-total negation as above.
                minutes = -minutes
            continue

        # Numeric fragment: "<number> <unit> [direction]".
        if self.current_token.type in [TokenType.INTEGER, TokenType.CHINESE_NUMBER]:
            number_node = self.parse_number()
            number_value = number_node.value

            unit = None
            direction = 1  # forward (后) by default

            # Unit token.
            if self.current_token.type == TokenType.HOUR:
                self.eat(TokenType.HOUR)
                # Optional measure word 个.
                if (self.current_token.type == TokenType.INTEGER and
                        self.current_token.value == "个"):
                    self.eat(TokenType.INTEGER)
                unit = "hour"
            elif self.current_token.type == TokenType.MINUTE:
                self.eat(TokenType.MINUTE)
                unit = "minute"
            elif self.current_token.type == TokenType.SECOND:
                self.eat(TokenType.SECOND)
                unit = "second"
            elif self.current_token.type == TokenType.TIME_SEPARATOR:
                # "X点" / "X分" / "X秒" forms.
                sep_value = self.current_token.value
                self.eat(TokenType.TIME_SEPARATOR)
                if sep_value == "点":
                    unit = "hour"
                    # Optional 钟.
                    if self.current_token.type == TokenType.ZHONG:
                        self.eat(TokenType.ZHONG)
                    # "X点" without a following direction is an absolute
                    # time ("三点"), not a relative one — stop here so the
                    # caller can retry with parse_time().
                    if not (self.current_token.type == TokenType.RELATIVE_DIRECTION_FORWARD or
                            self.current_token.type == TokenType.RELATIVE_DIRECTION_BACKWARD):
                        break
                elif sep_value == "分":
                    unit = "minute"
                    # Optional 钟.
                    if self.current_token.type == TokenType.ZHONG:
                        self.eat(TokenType.ZHONG)
                elif sep_value == "秒":
                    unit = "second"
            else:
                # Bare number: only meaningful if a direction follows
                # (then assume hours); otherwise this is not relative time.
                if (self.current_token.type == TokenType.RELATIVE_DIRECTION_FORWARD or
                        self.current_token.type == TokenType.RELATIVE_DIRECTION_BACKWARD):
                    unit = "hour"
                else:
                    break

            # Direction suffix (后 = forward, 前 = backward).
            if self.current_token.type == TokenType.RELATIVE_DIRECTION_FORWARD:
                self.eat(TokenType.RELATIVE_DIRECTION_FORWARD)
                direction = 1
            elif self.current_token.type == TokenType.RELATIVE_DIRECTION_BACKWARD:
                self.eat(TokenType.RELATIVE_DIRECTION_BACKWARD)
                direction = -1

            # Accumulate the signed fragment.
            if unit == "hour":
                hours += number_value * direction
            elif unit == "minute":
                minutes += number_value * direction
            elif unit == "second":
                seconds += number_value * direction
            continue

        # Current token handled by none of the branches: stop.
        break

    return RelativeTimeNode(hours=hours, minutes=minutes, seconds=seconds)
|
||||
|
||||
def parse_time_expression(self) -> TimeExpressionNode:
    """Parse a complete time expression.

    Repeatedly tries each sub-parser (absolute date, relative date,
    relative time, absolute time, weekday) until EOF, filling each slot
    at most once. Unrecognized tokens are skipped. A verbatim duplicate
    of the absolute-time section was removed, and an empty relative-time
    result is now discarded instead of permanently occupying its slot.

    Returns:
        TimeExpressionNode with any subset of its components populated.
    """
    date_node = None
    time_node = None
    relative_date_node = None
    relative_time_node = None
    weekday_node = None

    while self.current_token.type != TokenType.EOF:
        # Absolute dates take precedence over other numeric patterns.
        if self.current_token.type in [TokenType.INTEGER, TokenType.CHINESE_NUMBER]:
            if date_node is None:
                original_pos = self.pos
                try:
                    date_node = self.parse_date()
                    continue
                except ParserError:
                    self.pos = original_pos

        # Relative dates (今天, 明天, X天后, 下个月, ...).
        if self.current_token.type in [
            TokenType.RELATIVE_TODAY, TokenType.RELATIVE_TOMORROW,
            TokenType.RELATIVE_DAY_AFTER_TOMORROW, TokenType.RELATIVE_THREE_DAYS_AFTER_TOMORROW,
            TokenType.RELATIVE_YESTERDAY, TokenType.RELATIVE_DAY_BEFORE_YESTERDAY,
            TokenType.RELATIVE_THREE_DAYS_BEFORE_YESTERDAY,
            TokenType.INTEGER, TokenType.CHINESE_NUMBER,  # e.g. "X年后", "X个月后"
            TokenType.RELATIVE_NEXT, TokenType.RELATIVE_LAST, TokenType.RELATIVE_THIS
        ]:
            if relative_date_node is None:
                original_pos = self.pos
                try:
                    relative_date_node = self.parse_relative_date()
                    continue
                except ParserError:
                    self.pos = original_pos

        # Relative times (半小时后, 十分钟前, ...) — tried before absolute
        # times because both can start with a number.
        if self.current_token.type in [
            TokenType.INTEGER, TokenType.CHINESE_NUMBER,
            TokenType.HALF, TokenType.QUARTER,
            TokenType.RELATIVE_DIRECTION_FORWARD, TokenType.RELATIVE_DIRECTION_BACKWARD
        ]:
            if relative_time_node is None:
                original_pos = self.pos
                try:
                    relative_time_node = self.parse_relative_time()
                    # Only keep the node if it actually carries an offset.
                    if (relative_time_node.hours != 0 or
                            relative_time_node.minutes != 0 or
                            relative_time_node.seconds != 0):
                        continue
                    # Empty parse: discard the node so a later iteration may
                    # still fill this slot, and rewind the stream.
                    relative_time_node = None
                    self.pos = original_pos
                except ParserError:
                    self.pos = original_pos

        # Absolute times (三点半, 下午2点, ...).
        if self.current_token.type in [TokenType.INTEGER, TokenType.CHINESE_NUMBER,
                                       TokenType.TIME_SEPARATOR,
                                       TokenType.PERIOD_AM, TokenType.PERIOD_PM]:
            if time_node is None:
                original_pos = self.pos
                try:
                    time_node = self.parse_time()
                    continue
                except ParserError:
                    self.pos = original_pos

        # Weekdays (本周三, 下周五, ...).
        if self.current_token.type in [
            TokenType.WEEK_SCOPE_CURRENT, TokenType.WEEK_SCOPE_LAST, TokenType.WEEK_SCOPE_NEXT,
            TokenType.WEEKDAY_MONDAY, TokenType.WEEKDAY_TUESDAY, TokenType.WEEKDAY_WEDNESDAY,
            TokenType.WEEKDAY_THURSDAY, TokenType.WEEKDAY_FRIDAY, TokenType.WEEKDAY_SATURDAY,
            TokenType.WEEKDAY_SUNDAY
        ]:
            if weekday_node is None:
                original_pos = self.pos
                try:
                    weekday_node = self.parse_weekday()
                    continue
                except ParserError:
                    self.pos = original_pos

        # Nothing matched at this position: skip the token and move on.
        self.pos += 1

    return TimeExpressionNode(
        date=date_node,
        time=time_node,
        relative_date=relative_date_node,
        relative_time=relative_time_node,
        weekday=weekday_node
    )
|
||||
|
||||
def parse(self) -> TimeExpressionNode:
    """Entry point: parse the whole token stream into an AST root."""
    ast_root = self.parse_time_expression()
    return ast_root
|
||||
71
konabot/common/ptimeparse/ptime_ast.py
Normal file
71
konabot/common/ptimeparse/ptime_ast.py
Normal file
@ -0,0 +1,71 @@
|
||||
"""
|
||||
Abstract Syntax Tree (AST) nodes for the time expression parser.
|
||||
"""
|
||||
|
||||
from abc import ABC
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class ASTNode(ABC):
    """Base class for all AST nodes."""
    # Abstract marker base with no fields of its own; concrete node
    # types (NumberNode, DateNode, ...) add their payloads.
    pass
|
||||
|
||||
|
||||
@dataclass
class NumberNode(ASTNode):
    """Represents a numeric value."""
    # Integer value of the parsed number token.
    value: int
|
||||
|
||||
|
||||
@dataclass
class DateNode(ASTNode):
    """Represents a date specification."""
    # Each component is a number node when present; None means the
    # component was not given and the semantic analyzer substitutes a
    # default (current year / January / the 1st).
    year: Optional[ASTNode]
    month: Optional[ASTNode]
    day: Optional[ASTNode]
|
||||
|
||||
|
||||
@dataclass
class TimeNode(ASTNode):
    """Represents a time specification."""
    # Components are number nodes when present; None defaults to 0.
    hour: Optional[ASTNode]
    minute: Optional[ASTNode]
    second: Optional[ASTNode]
    # True when the hour was given on a 24-hour clock (e.g. "时" form).
    is_24hour: bool = False
    # "AM" or "PM" for 12-hour readings; None when unspecified.
    period: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class RelativeDateNode(ASTNode):
    """Represents a relative date specification (offsets from 'now')."""
    years: int = 0
    # NOTE: the parser encodes an absolute month (e.g. 明年五月) as
    # (month - 100); the semantic analyzer decodes it.
    months: int = 0
    # NOTE: for patterns like 下个月五号 the parser temporarily stores the
    # target day-of-month here instead of a week offset.
    weeks: int = 0
    days: int = 0
|
||||
|
||||
|
||||
@dataclass
class RelativeTimeNode(ASTNode):
    """Represents a relative time specification."""
    # Signed offsets; floats allow fractional units such as 半小时 (0.5 h).
    hours: float = 0.0
    minutes: float = 0.0
    seconds: float = 0.0
|
||||
|
||||
|
||||
@dataclass
class WeekdayNode(ASTNode):
    """Represents a weekday specification."""
    weekday: int  # 0=Monday, 6=Sunday
    scope: str  # "current", "last" or "next" week
|
||||
|
||||
|
||||
@dataclass
class TimeExpressionNode(ASTNode):
    """Represents a complete time expression."""
    # Any subset of components may be present; unparsed parts stay None.
    date: Optional[DateNode] = None
    time: Optional[TimeNode] = None
    relative_date: Optional[RelativeDateNode] = None
    relative_time: Optional[RelativeTimeNode] = None
    weekday: Optional[WeekdayNode] = None
|
||||
95
konabot/common/ptimeparse/ptime_token.py
Normal file
95
konabot/common/ptimeparse/ptime_token.py
Normal file
@ -0,0 +1,95 @@
|
||||
"""
|
||||
Token definitions for the time parser.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class TokenType(Enum):
    """Types of tokens recognized by the lexer."""
    # NOTE: every value equals its member name so token reprs stay readable.

    # Numbers
    INTEGER = "INTEGER"
    CHINESE_NUMBER = "CHINESE_NUMBER"

    # Time units
    YEAR = "YEAR"
    MONTH = "MONTH"
    DAY = "DAY"
    WEEK = "WEEK"
    HOUR = "HOUR"
    MINUTE = "MINUTE"
    SECOND = "SECOND"

    # Date separators
    DATE_SEPARATOR = "DATE_SEPARATOR"  # -, /, 年, 月, 日, 号

    # Time separators
    TIME_SEPARATOR = "TIME_SEPARATOR"  # :, 点, 时, 分, 秒

    # Period indicators (AM/PM words)
    PERIOD_AM = "PERIOD_AM"  # 上午, 早上, 早晨, etc.
    PERIOD_PM = "PERIOD_PM"  # 下午, 晚上, 中午, etc.

    # Relative time (fixed day words and direction suffixes)
    RELATIVE_TODAY = "RELATIVE_TODAY"  # 今天, 今晚, 今早, etc.
    RELATIVE_TOMORROW = "RELATIVE_TOMORROW"  # 明天
    RELATIVE_DAY_AFTER_TOMORROW = "RELATIVE_DAY_AFTER_TOMORROW"  # 后天
    RELATIVE_THREE_DAYS_AFTER_TOMORROW = "RELATIVE_THREE_DAYS_AFTER_TOMORROW"  # 大后天
    RELATIVE_YESTERDAY = "RELATIVE_YESTERDAY"  # 昨天
    RELATIVE_DAY_BEFORE_YESTERDAY = "RELATIVE_DAY_BEFORE_YESTERDAY"  # 前天
    RELATIVE_THREE_DAYS_BEFORE_YESTERDAY = "RELATIVE_THREE_DAYS_BEFORE_YESTERDAY"  # 大前天
    RELATIVE_DIRECTION_FORWARD = "RELATIVE_DIRECTION_FORWARD"  # 后, 以后, 之后
    RELATIVE_DIRECTION_BACKWARD = "RELATIVE_DIRECTION_BACKWARD"  # 前, 以前, 之前

    # Extended relative time (modifiers combined with 年/月/周)
    RELATIVE_NEXT = "RELATIVE_NEXT"  # 下
    RELATIVE_LAST = "RELATIVE_LAST"  # 上, 去
    RELATIVE_THIS = "RELATIVE_THIS"  # 这, 本

    # Week days
    WEEKDAY_MONDAY = "WEEKDAY_MONDAY"
    WEEKDAY_TUESDAY = "WEEKDAY_TUESDAY"
    WEEKDAY_WEDNESDAY = "WEEKDAY_WEDNESDAY"
    WEEKDAY_THURSDAY = "WEEKDAY_THURSDAY"
    WEEKDAY_FRIDAY = "WEEKDAY_FRIDAY"
    WEEKDAY_SATURDAY = "WEEKDAY_SATURDAY"
    WEEKDAY_SUNDAY = "WEEKDAY_SUNDAY"

    # Week scope (prefix before a weekday)
    WEEK_SCOPE_CURRENT = "WEEK_SCOPE_CURRENT"  # 本
    WEEK_SCOPE_LAST = "WEEK_SCOPE_LAST"  # 上
    WEEK_SCOPE_NEXT = "WEEK_SCOPE_NEXT"  # 下

    # Special time markers
    HALF = "HALF"  # 半
    QUARTER = "QUARTER"  # 一刻
    ZHENG = "ZHENG"  # 整
    ZHONG = "ZHONG"  # 钟

    # Student-friendly time expressions
    EARLY_MORNING = "EARLY_MORNING"  # 早X
    LATE_NIGHT = "LATE_NIGHT"  # 晚X

    # Whitespace
    WHITESPACE = "WHITESPACE"

    # End of input
    EOF = "EOF"
|
||||
|
||||
|
||||
@dataclass
class Token:
    """A single lexeme produced by the lexer, tagged with its source position."""

    type: TokenType  # token category
    value: Union[str, int]  # raw text or parsed integer value
    position: int  # character offset in the original input

    def __str__(self):
        return f"Token({self.type.value}, {self.value!r}, {self.position})"

    def __repr__(self):
        return str(self)
|
||||
369
konabot/common/ptimeparse/semantic.py
Normal file
369
konabot/common/ptimeparse/semantic.py
Normal file
@ -0,0 +1,369 @@
|
||||
"""
|
||||
Semantic analyzer for time expressions that evaluates the AST and produces datetime objects.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import calendar
|
||||
from typing import Optional
|
||||
|
||||
from .ptime_ast import (
|
||||
TimeExpressionNode, DateNode, TimeNode,
|
||||
RelativeDateNode, RelativeTimeNode, WeekdayNode, NumberNode
|
||||
)
|
||||
from .err import TokenUnhandledException
|
||||
|
||||
|
||||
class SemanticAnalyzer:
|
||||
"""Semantic analyzer that evaluates time expression ASTs."""
|
||||
|
||||
def __init__(self, now: Optional[datetime.datetime] = None):
    """Create an analyzer anchored at *now* (defaults to the wall clock)."""
    self.now = datetime.datetime.now() if now is None else now
|
||||
|
||||
def evaluate_number(self, node: NumberNode) -> int:
    """Return the integer value carried by a number node."""
    return node.value
|
||||
|
||||
def evaluate_date(self, node: DateNode) -> datetime.date:
    """Evaluate a date node into a concrete date.

    Missing components default to the current year, January, and the 1st.
    """
    year = self.evaluate_number(node.year) if node.year is not None else self.now.year
    month = self.evaluate_number(node.month) if node.month is not None else 1
    day = self.evaluate_number(node.day) if node.day is not None else 1
    return datetime.date(year, month, day)
|
||||
|
||||
def evaluate_time(self, node: TimeNode) -> datetime.time:
    """Evaluate a time node into a concrete time of day.

    Converts 12-hour readings (with an AM/PM period) to 24-hour values
    and validates all components.

    Raises:
        TokenUnhandledException: when an hour/minute/second is out of range.
    """
    hour = self.evaluate_number(node.hour) if node.hour is not None else 0
    minute = self.evaluate_number(node.minute) if node.minute is not None else 0
    second = self.evaluate_number(node.second) if node.second is not None else 0

    # 12-hour clock adjustment: 12 AM is midnight; PM adds 12 except for noon.
    if not node.is_24hour and node.period is not None:
        if node.period == "AM" and hour == 12:
            hour = 0
        elif node.period == "PM" and hour != 12 and hour <= 12:
            hour += 12

    # Range validation.
    if not (0 <= hour <= 23):
        raise TokenUnhandledException(f"Invalid hour: {hour}")
    if not (0 <= minute <= 59):
        raise TokenUnhandledException(f"Invalid minute: {minute}")
    if not (0 <= second <= 59):
        raise TokenUnhandledException(f"Invalid second: {second}")

    return datetime.time(hour, minute, second)
|
||||
|
||||
def evaluate_relative_date(self, node: RelativeDateNode) -> datetime.timedelta:
    """Evaluate a relative date node into a timedelta from ``self.now``.

    Understands two encoding hacks produced by the parser:
      * ``months = X - 100`` (i.e. months in [-99, -88]) means "absolute
        month X", e.g. from 明年五月.
      * When ``months`` is a genuine offset and ``weeks`` is 1..31, the
        weeks field carries a target day-of-month, e.g. 下个月五号.

    Fixes over the previous version:
      * The day-of-month special case is now gated on ``months != 0`` so
        genuine week offsets (下周 -> weeks=1) are no longer swallowed
        and returned as a zero delta.
      * The absolute-month marker check is restricted to its encoded
        range instead of "any negative months", so genuine negative
        offsets (上个月 -> months=-1) are applied instead of dropped.
    """
    result = self.now

    # 下个月X号-style: weeks carries the target day, months the offset.
    if node.months != 0 and 1 <= node.weeks <= 31:
        target_day = node.weeks
        total_months = result.month + node.months - 1
        new_year = result.year + total_months // 12
        new_month = total_months % 12 + 1

        # Clamp so e.g. the 31st of a 30-day month stays valid.
        max_day_in_target_month = calendar.monthrange(new_year, new_month)[1]
        target_day = min(target_day, max_day_in_target_month)
        result = result.replace(year=new_year, month=new_month, day=target_day)
        return result - self.now

    # Apply years; replace() only fails for Feb 29 in a non-leap target year.
    if node.years != 0:
        new_year = result.year + node.years
        try:
            result = result.replace(year=new_year)
        except ValueError:
            result = result.replace(year=new_year, month=2, day=28)

    # Apply months.
    if node.months != 0:
        if -99 <= node.months <= -88:
            # Absolute-month marker (value = month - 100), e.g. 明年五月.
            absolute_month = node.months + 100
            new_day = min(result.day, calendar.monthrange(result.year, absolute_month)[1])
            result = result.replace(month=absolute_month, day=new_day)
        else:
            # Genuine month offset; may be negative (上个月 -> -1).
            total_months = result.month + node.months - 1
            new_year = result.year + total_months // 12
            new_month = total_months % 12 + 1
            # Clamp the day (e.g. Jan 31 + 1 month -> Feb 28/29).
            new_day = min(result.day, calendar.monthrange(new_year, new_month)[1])
            result = result.replace(year=new_year, month=new_month, day=new_day)

    # Apply weeks and days as plain day arithmetic.
    if node.weeks != 0 or node.days != 0:
        result = result + datetime.timedelta(days=node.weeks * 7 + node.days)

    return result - self.now
|
||||
|
||||
def evaluate_relative_time(self, node: RelativeTimeNode) -> datetime.timedelta:
    """Convert a relative time node into a :class:`datetime.timedelta`.

    The hours/minutes/seconds components are folded into one total
    seconds value so that ``timedelta`` handles all carry-over uniformly.
    """
    seconds = node.seconds
    seconds += node.minutes * 60
    seconds += node.hours * 60 * 60
    return datetime.timedelta(seconds=seconds)
|
||||
|
||||
def evaluate_weekday(self, node: WeekdayNode) -> datetime.timedelta:
    """Return the day offset from ``self.now`` to the requested weekday.

    ``node.scope`` selects the week: ``"last"`` shifts one week back,
    ``"next"`` one week forward; ``"current"`` (and any unknown scope)
    stays within the current week. Weekdays use Python's convention
    (0=Monday .. 6=Sunday), matching ``datetime.weekday()``.
    """
    # Map the scope to a whole-week shift; unknown scopes fall back to 0,
    # which matches the "current week" behavior.
    week_shift = {"last": -7, "next": 7}.get(node.scope, 0)
    days = node.weekday - self.now.weekday() + week_shift
    return datetime.timedelta(days=days)
|
||||
|
||||
def infer_smart_time(self, hour: int, minute: int = 0, second: int = 0, base_time: Optional[datetime.datetime] = None) -> datetime.datetime:
    """
    Smart time inference based on current time.

    Given a bare clock hour (where 1-11 is ambiguous between AM and PM),
    pick the concrete datetime the user most likely meant, preferring a
    moment in the future relative to *base_time* (or ``self.now`` when
    *base_time* is omitted).

    For example:
    - If now is 14:30 and user says "3点", interpret as 15:00
    - If now is 14:30 and user says "1点", interpret as next day 01:00
    - If now is 8:00 and user says "3点", interpret as 15:00
    - If now is 8:00 and user says "9点", interpret as 09:00
    """
    # Use base_time if provided, otherwise use self.now
    now = base_time if base_time is not None else self.now

    # Handle 24-hour format directly (13-23): unambiguous, so just roll
    # to tomorrow if that moment has already passed.
    if 13 <= hour <= 23:
        candidate = now.replace(hour=hour, minute=minute, second=second, microsecond=0)
        if candidate <= now:
            candidate += datetime.timedelta(days=1)
        return candidate

    # Handle 12: ambiguous between noon (12:00) and midnight (00:00).
    if hour == 12:
        # For 12 specifically, we need to be more careful.
        # Build both candidates and choose between them below.
        noon_candidate = now.replace(hour=12, minute=minute, second=second, microsecond=0)
        midnight_candidate = now.replace(hour=0, minute=minute, second=second, microsecond=0)

        # Special case: If it's afternoon or evening, "十二点" likely means next day midnight
        if now.hour >= 12:
            result = midnight_candidate + datetime.timedelta(days=1)
            return result

        # If noon is in the future and closer than midnight, use it
        if noon_candidate > now and (midnight_candidate <= now or noon_candidate < midnight_candidate):
            return noon_candidate
        # If midnight is in the future, use it
        elif midnight_candidate > now:
            return midnight_candidate
        # Both are in the past, use the closer one
        elif noon_candidate > midnight_candidate:
            return noon_candidate
        # Otherwise use midnight next day
        else:
            result = midnight_candidate + datetime.timedelta(days=1)
            return result

    # Handle 1-11 (12-hour format): ambiguous between AM and PM.
    if 1 <= hour <= 11:
        # Calculate 12-hour format candidates
        pm_hour = hour + 12
        pm_candidate = now.replace(hour=pm_hour, minute=minute, second=second, microsecond=0)
        am_candidate = now.replace(hour=hour, minute=minute, second=second, microsecond=0)

        # Special case: If it's afternoon (12:00-18:00) and the hour is 1-6,
        # user might mean either PM today or AM tomorrow.
        # But if PM is in the future, that's more likely what they mean.
        if 12 <= now.hour <= 18 and 1 <= hour <= 6:
            if pm_candidate > now:
                return pm_candidate
            else:
                # PM is in the past, so use AM tomorrow
                result = am_candidate + datetime.timedelta(days=1)
                return result

        # Special case: If it's late evening (after 22:00) and user specifies
        # early morning hours (1-5), user likely means next day early morning
        if now.hour >= 22 and 1 <= hour <= 5:
            result = am_candidate + datetime.timedelta(days=1)
            return result

        # Special case: In the morning (before 12:00)
        if now.hour < 12:
            # In the morning, for hours 1-11, generally prefer AM interpretation
            # unless it's a very early hour that's much earlier than current time.
            # Only push to next day for very early hours (1-2) that are at least
            # 6 hours earlier than now.
            if hour <= 2 and hour < now.hour and now.hour - hour >= 6:
                # Very early morning hour that's significantly earlier, use next day
                result = am_candidate + datetime.timedelta(days=1)
                return result
            else:
                # For morning, generally prefer AM if it's in the future
                if am_candidate > now:
                    return am_candidate
                # If PM is in the future, use it
                elif pm_candidate > now:
                    return pm_candidate
                # Both are in the past, prefer AM if it's closer
                elif am_candidate > pm_candidate:
                    return am_candidate
                # Otherwise use PM next day
                else:
                    result = pm_candidate + datetime.timedelta(days=1)
                    return result
        else:
            # General case: choose the one that's in the future and closer
            if pm_candidate > now and (am_candidate <= now or pm_candidate < am_candidate):
                return pm_candidate
            elif am_candidate > now:
                return am_candidate
            # Both are in the past, use the closer one
            elif pm_candidate > am_candidate:
                return pm_candidate
            # Otherwise use AM next day
            else:
                result = am_candidate + datetime.timedelta(days=1)
                return result

    # Handle 0 (midnight): unambiguous, roll to tomorrow if already passed.
    if hour == 0:
        candidate = now.replace(hour=0, minute=minute, second=second, microsecond=0)
        if candidate <= now:
            candidate += datetime.timedelta(days=1)
        return candidate

    # Default case (should not happen with valid input, i.e. hour in 0-23)
    candidate = now.replace(hour=hour, minute=minute, second=second, microsecond=0)
    if candidate <= now:
        candidate += datetime.timedelta(days=1)
    return candidate
|
||||
|
||||
def evaluate(self, node: TimeExpressionNode) -> datetime.datetime:
    """Evaluate a complete time expression node against ``self.now``.

    Applies, in order: relative date, weekday, relative time, absolute
    date, then the time-of-day component. Date-only expressions have
    their clock zeroed to 00:00:00; time components are resolved either
    literally (24-hour or explicit AM/PM) or via smart inference.

    Raises:
        TokenUnhandledException: if an explicit-period hour falls
            outside 0-23 after AM/PM conversion.
    """
    result = self.now

    # Apply relative date (e.g. "明天", "下个月")
    if node.relative_date is not None:
        delta = self.evaluate_relative_date(node.relative_date)
        result = result + delta
        # For pure day-level relative dates ("今天", "明天"), zero the clock.
        # Year/month offsets keep the current time-of-day.
        if (node.date is None and node.time is None and node.weekday is None and
                node.relative_date.years == 0 and node.relative_date.months == 0):
            result = result.replace(hour=0, minute=0, second=0, microsecond=0)

    # Apply weekday (e.g. "下周三")
    if node.weekday is not None:
        delta = self.evaluate_weekday(node.weekday)
        result = result + delta
        # A bare weekday means the start of that day.
        if node.date is None and node.time is None:
            result = result.replace(hour=0, minute=0, second=0, microsecond=0)

    # Apply relative time (e.g. "三小时后")
    if node.relative_time is not None:
        delta = self.evaluate_relative_time(node.relative_time)
        result = result + delta

    # Apply absolute date
    if node.date is not None:
        date = self.evaluate_date(node.date)
        result = result.replace(year=date.year, month=date.month, day=date.day)
        # An absolute date without a time means the start of that day.
        if node.time is None:
            result = result.replace(hour=0, minute=0, second=0, microsecond=0)

    # Apply time-of-day
    if node.time is not None:
        time = self.evaluate_time(node.time)

        if node.time.is_24hour or node.time.period is not None:
            if not node.time.is_24hour and node.time.period is not None:
                # Explicit AM/PM period: convert to 24-hour form.
                hour = time.hour
                minute = time.minute
                second = time.second
                # Fix: explicit flag instead of the fragile
                # locals().get('skip_general_replacement', ...) probing.
                skip_general_replacement = False

                if node.time.period == "AM":
                    if hour == 12:
                        hour = 0  # "上午十二点" -> 00:00
                elif node.time.period == "PM":
                    # Special case: "晚上十二点" means next day 00:00
                    if hour == 12 and minute == 0 and second == 0:
                        result = result.replace(hour=0, minute=0, second=0, microsecond=0) + datetime.timedelta(days=1)
                        skip_general_replacement = True
                    else:
                        # For other PM times, convert to 24-hour format
                        if hour != 12 and hour <= 12:
                            hour += 12

                # Validate hour after conversion
                if not (0 <= hour <= 23):
                    raise TokenUnhandledException(f"Invalid hour: {hour}")

                # Only do the general replacement if the midnight special
                # case above has not already set the result.
                if not skip_general_replacement:
                    result = result.replace(hour=hour, minute=minute, second=second, microsecond=0)
            else:
                # Already in 24-hour format
                result = result.replace(hour=time.hour, minute=time.minute, second=time.second, microsecond=0)
        else:
            # No explicit period. With an explicit (relative or absolute)
            # date, treat the time literally as 24-hour; otherwise use
            # smart inference to pick the most likely future moment.
            if node.date is not None or node.relative_date is not None:
                result = result.replace(hour=time.hour, minute=time.minute or 0, second=time.second or 0, microsecond=0)
            else:
                smart_time = self.infer_smart_time(time.hour, time.minute, time.second, base_time=result)
                result = smart_time

    return result
|
||||
34
konabot/common/render_error_message.py
Normal file
34
konabot/common/render_error_message.py
Normal file
@ -0,0 +1,34 @@
|
||||
from typing import Any
|
||||
from loguru import logger
|
||||
from nonebot_plugin_alconna import UniMessage
|
||||
import playwright.async_api
|
||||
from playwright.async_api import Page
|
||||
|
||||
from konabot.common.web_render import WebRenderer, konaweb
|
||||
|
||||
|
||||
async def render_error_message(message: str) -> UniMessage[Any]:
    """Render a plain-text error message into an error-report image.

    Falls back to sending the raw text when the web renderer cannot be
    reached or fails while rendering.
    """

    async def _inject_message(page: Page):
        # Wait until the page has exposed setContent, then feed it the text.
        await page.wait_for_function("typeof setContent === 'function'", timeout=3000)
        await page.evaluate(
            """(message) => {return setContent(message);}""",
            message,
        )

    try:
        image_bytes = await WebRenderer.render(
            url=konaweb("error_report"),
            target="#main",
            other_function=_inject_message,
        )
        return UniMessage.image(raw=image_bytes)
    except (playwright.async_api.Error, ConnectionError) as err:
        logger.warning("渲染报错信息图片时出错了,回退到文本 ERR={}", err)
        return UniMessage.text(message)
|
||||
|
||||
11
konabot/common/subscribe/__init__.py
Normal file
11
konabot/common/subscribe/__init__.py
Normal file
@ -0,0 +1,11 @@
|
||||
"""
|
||||
Subscribe 模块,用于向一些订阅的频道广播消息
|
||||
"""
|
||||
|
||||
from .service import broadcast as broadcast
|
||||
from .service import dep_poster_service as dep_poster_service
|
||||
from .service import DepPosterService as DepPosterService
|
||||
from .service import PosterService as PosterService
|
||||
from .subscribe_info import PosterInfo as PosterInfo
|
||||
from .subscribe_info import POSTER_INFO_DATA as POSTER_INFO_DATA
|
||||
from .subscribe_info import register_poster_info as register_poster_info
|
||||
@ -6,7 +6,8 @@ from pydantic import BaseModel, ValidationError
|
||||
from konabot.common.longtask import LongTaskTarget
|
||||
from konabot.common.pager import PagerQuery, PagerResult
|
||||
from konabot.common.path import DATA_PATH
|
||||
from konabot.plugins.poster.repository import IPosterRepo
|
||||
|
||||
from .repository import IPosterRepo
|
||||
|
||||
|
||||
class ChannelData(BaseModel):
|
||||
@ -18,9 +19,9 @@ class PosterData(BaseModel):
|
||||
|
||||
|
||||
def is_the_same_target(target1: LongTaskTarget, target2: LongTaskTarget) -> bool:
|
||||
if (target1.is_private_chat and not target2.is_private_chat):
|
||||
if target1.is_private_chat and not target2.is_private_chat:
|
||||
return False
|
||||
if (target2.is_private_chat and not target1.is_private_chat):
|
||||
if target2.is_private_chat and not target1.is_private_chat:
|
||||
return False
|
||||
if target1.platform != target2.platform:
|
||||
return False
|
||||
@ -58,7 +59,9 @@ class LocalPosterRepo(IPosterRepo):
|
||||
len1 = len(self.data.channels[channel].targets)
|
||||
return len0 != len1
|
||||
|
||||
async def get_subscribed_channels(self, target: LongTaskTarget, pager: PagerQuery) -> PagerResult[str]:
|
||||
async def get_subscribed_channels(
|
||||
self, target: LongTaskTarget, pager: PagerQuery
|
||||
) -> PagerResult[str]:
|
||||
channels: list[str] = []
|
||||
for channel_id, channel in self.data.channels.items():
|
||||
for t in channel.targets:
|
||||
@ -95,7 +98,9 @@ async def local_poster_data():
|
||||
data = PosterData()
|
||||
else:
|
||||
try:
|
||||
data = PosterData.model_validate_json(LOCAL_POSTER_DATA_PATH.read_text())
|
||||
data = PosterData.model_validate_json(
|
||||
LOCAL_POSTER_DATA_PATH.read_text()
|
||||
)
|
||||
except ValidationError:
|
||||
data = PosterData()
|
||||
yield data
|
||||
@ -109,4 +114,3 @@ async def local_poster():
|
||||
|
||||
|
||||
DepLocalPosterRepo = Annotated[LocalPosterRepo, Depends(local_poster)]
|
||||
|
||||
@ -4,9 +4,10 @@ from nonebot.params import Depends
|
||||
from nonebot_plugin_alconna import UniMessage
|
||||
from konabot.common.longtask import LongTaskTarget
|
||||
from konabot.common.pager import PagerQuery, PagerResult
|
||||
from konabot.plugins.poster.poster_info import POSTER_INFO_DATA
|
||||
from konabot.plugins.poster.repo_local_data import local_poster
|
||||
from konabot.plugins.poster.repository import IPosterRepo
|
||||
|
||||
from .subscribe_info import POSTER_INFO_DATA
|
||||
from .repo_local_data import local_poster
|
||||
from .repository import IPosterRepo
|
||||
|
||||
|
||||
class PosterService:
|
||||
@ -27,7 +28,9 @@ class PosterService:
|
||||
channel = self.parse_channel_id(channel)
|
||||
return await self.repo.remove_channel_target(channel, target)
|
||||
|
||||
async def broadcast(self, channel: str, message: UniMessage[Any] | str) -> list[LongTaskTarget]:
|
||||
async def broadcast(
|
||||
self, channel: str, message: UniMessage[Any] | str
|
||||
) -> list[LongTaskTarget]:
|
||||
channel = self.parse_channel_id(channel)
|
||||
targets = await self.repo.get_channel_targets(channel)
|
||||
for target in targets:
|
||||
@ -35,7 +38,9 @@ class PosterService:
|
||||
await target.send_message(message, at=False)
|
||||
return targets
|
||||
|
||||
async def get_channels(self, target: LongTaskTarget, pager: PagerQuery) -> PagerResult[str]:
|
||||
async def get_channels(
|
||||
self, target: LongTaskTarget, pager: PagerQuery
|
||||
) -> PagerResult[str]:
|
||||
return await self.repo.get_subscribed_channels(target, pager)
|
||||
|
||||
async def fix_data(self):
|
||||
@ -56,4 +61,3 @@ async def broadcast(channel: str, message: UniMessage[Any] | str):
|
||||
|
||||
|
||||
DepPosterService = Annotated[PosterService, Depends(dep_poster_service)]
|
||||
|
||||
@ -4,7 +4,7 @@ from dataclasses import dataclass, field
|
||||
@dataclass
|
||||
class PosterInfo:
|
||||
aliases: set[str] = field(default_factory=set)
|
||||
description: str = field(default='')
|
||||
description: str = field(default="")
|
||||
|
||||
|
||||
POSTER_INFO_DATA: dict[str, PosterInfo] = {}
|
||||
@ -12,4 +12,3 @@ POSTER_INFO_DATA: dict[str, PosterInfo] = {}
|
||||
|
||||
def register_poster_info(channel: str, info: PosterInfo):
|
||||
POSTER_INFO_DATA[channel] = info
|
||||
|
||||
212
konabot/docs/sys/konaperm.txt
Normal file
212
konabot/docs/sys/konaperm.txt
Normal file
@ -0,0 +1,212 @@
|
||||
# 指令介绍
|
||||
|
||||
`konaperm` - 用于查看和修改 Bot 内部权限系统记录的管理员指令
|
||||
|
||||
## 权限要求
|
||||
|
||||
只有拥有 `admin` 权限的主体才能使用本指令。
|
||||
|
||||
## 格式
|
||||
|
||||
```text
|
||||
konaperm list <platform> <entity_type> <external_id> [page]
|
||||
konaperm get <platform> <entity_type> <external_id> <perm>
|
||||
konaperm set <platform> <entity_type> <external_id> <perm> <val>
|
||||
```
|
||||
|
||||
## 子命令说明
|
||||
|
||||
### `list`
|
||||
|
||||
列出指定对象及其继承链上的显式权限记录,按分页输出。
|
||||
|
||||
参数:
|
||||
|
||||
- `platform` 平台名,如 `ob11`、`discord`、`sys`
|
||||
- `entity_type` 对象类型,如 `user`、`group`、`global`
|
||||
- `external_id` 平台侧对象 ID;全局对象通常写 `global`
|
||||
- `page` 页码,可省略,默认 `1`
|
||||
|
||||
### `get`
|
||||
|
||||
查询某个对象对指定权限的最终判断结果,并说明它是从哪一层继承来的。
|
||||
|
||||
参数:
|
||||
|
||||
- `platform`
|
||||
- `entity_type`
|
||||
- `external_id`
|
||||
- `perm` 权限键,如 `admin`、`plugin.xxx.use`
|
||||
|
||||
### `set`
|
||||
|
||||
为指定对象写入显式权限。
|
||||
|
||||
参数:
|
||||
|
||||
- `platform`
|
||||
- `entity_type`
|
||||
- `external_id`
|
||||
- `perm` 权限键
|
||||
- `val` 设置值
|
||||
|
||||
`val` 支持以下写法:
|
||||
|
||||
- 允许:`y` `yes` `allow` `true` `t`
|
||||
- 拒绝:`n` `no` `deny` `false` `f`
|
||||
- 清除:`null` `none`
|
||||
|
||||
其中:
|
||||
|
||||
- 允许 表示显式授予该权限
|
||||
- 拒绝 表示显式禁止该权限
|
||||
- 清除 表示删除该层的显式设置,重新回退到继承链判断
|
||||
|
||||
## 对象格式
|
||||
|
||||
本指令操作的对象由三段组成:
|
||||
|
||||
```text
|
||||
<platform>.<entity_type>.<external_id>
|
||||
```
|
||||
|
||||
例如:
|
||||
|
||||
- `ob11.user.123456789`
|
||||
- `ob11.group.987654321`
|
||||
- `sys.global.global`
|
||||
|
||||
## 当前支持的 `PermEntity` 值
|
||||
|
||||
以下内容按当前实现整理,便于手工查询和设置权限。
|
||||
|
||||
### `sys`
|
||||
|
||||
- `sys.global.global`
|
||||
|
||||
这是系统总兜底对象。
|
||||
|
||||
### `ob11`
|
||||
|
||||
- `ob11.global.global`
|
||||
- `ob11.group.<group_id>`
|
||||
- `ob11.user.<user_id>`
|
||||
|
||||
常见场景:
|
||||
|
||||
- 给整个 OneBot V11 平台统一授权:`ob11.global.global`
|
||||
- 给某个 QQ 群授权:`ob11.group.群号`
|
||||
- 给某个 QQ 用户授权:`ob11.user.QQ号`
|
||||
|
||||
### `discord`
|
||||
|
||||
- `discord.global.global`
|
||||
- `discord.guild.<guild_id>`
|
||||
- `discord.channel.<channel_id>`
|
||||
- `discord.user.<user_id>`
|
||||
|
||||
常见场景:
|
||||
|
||||
- 给整个 Discord 平台统一授权:`discord.global.global`
|
||||
- 给某个服务器授权:`discord.guild.服务器ID`
|
||||
- 给某个频道授权:`discord.channel.频道ID`
|
||||
- 给某个用户授权:`discord.user.用户ID`
|
||||
|
||||
### `minecraft`
|
||||
|
||||
- `minecraft.global.global`
|
||||
- `minecraft.server.<server_name>`
|
||||
- `minecraft.player.<player_uuid_hex>`
|
||||
|
||||
常见场景:
|
||||
|
||||
- 给整个 Minecraft 平台统一授权:`minecraft.global.global`
|
||||
- 给某个服务器授权:`minecraft.server.服务器名`
|
||||
- 给某个玩家授权:`minecraft.player.玩家UUID的hex`
|
||||
|
||||
### `console`
|
||||
|
||||
- `console.global.global`
|
||||
- `console.channel.<channel_id>`
|
||||
- `console.user.<user_id>`
|
||||
|
||||
### 快速参考
|
||||
|
||||
```text
|
||||
sys.global.global
|
||||
|
||||
ob11.global.global
|
||||
ob11.group.<group_id>
|
||||
ob11.user.<user_id>
|
||||
|
||||
discord.global.global
|
||||
discord.guild.<guild_id>
|
||||
discord.channel.<channel_id>
|
||||
discord.user.<user_id>
|
||||
|
||||
minecraft.global.global
|
||||
minecraft.server.<server_name>
|
||||
minecraft.player.<player_uuid_hex>
|
||||
|
||||
console.global.global
|
||||
console.channel.<channel_id>
|
||||
console.user.<user_id>
|
||||
```
|
||||
|
||||
## 权限继承
|
||||
|
||||
权限不是只看当前对象,还会按继承链回退。
|
||||
|
||||
例如对 `ob11.user.123456` 查询时,通常会从更具体的对象一路回退到:
|
||||
|
||||
1. 当前用户
|
||||
2. 平台全局对象
|
||||
3. 系统全局对象
|
||||
|
||||
权限键本身也支持逐级回退。比如查询 `plugin.demo.use` 时,可能依次命中:
|
||||
|
||||
1. `plugin.demo.use`
|
||||
2. `plugin.demo`
|
||||
3. `plugin`
|
||||
4. `*`
|
||||
|
||||
所以 `get` 返回的结果可能来自更宽泛的权限键,或更上层的继承对象。
|
||||
|
||||
## 示例
|
||||
|
||||
```text
|
||||
konaperm list ob11 user 123456
|
||||
```
|
||||
|
||||
查看 `ob11.user.123456` 及其继承链上的权限记录第一页。
|
||||
|
||||
```text
|
||||
konaperm get ob11 user 123456 admin
|
||||
```
|
||||
|
||||
查看该用户最终是否拥有 `admin` 权限,以及命中来源。
|
||||
|
||||
```text
|
||||
konaperm set ob11 user 123456 admin allow
|
||||
```
|
||||
|
||||
显式授予该用户 `admin` 权限。
|
||||
|
||||
```text
|
||||
konaperm set ob11 user 123456 admin deny
|
||||
```
|
||||
|
||||
显式拒绝该用户 `admin` 权限。
|
||||
|
||||
```text
|
||||
konaperm set ob11 user 123456 admin none
|
||||
```
|
||||
|
||||
删除该用户这一层对 `admin` 的显式设置,恢复继承判断。
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 这是系统级管理指令,误操作可能直接影响其他插件的权限控制。
|
||||
- `list` 只列出显式记录;没有显示出来不代表最终一定无权限,可能是从上层继承。
|
||||
- `get` 显示的是最终命中的结果,比 `list` 更适合排查“为什么有/没有某个权限”。
|
||||
- 对 `admin` 或 `*` 这类高影响权限做修改前,建议先确认对象是否写对。
|
||||
4
konabot/docs/sys/宾几人.txt
Normal file
4
konabot/docs/sys/宾几人.txt
Normal file
@ -0,0 +1,4 @@
|
||||
# 宾几人
|
||||
|
||||
查询 Bingo 有几个人。直接发送给 Bot 即可。
|
||||
|
||||
38
konabot/docs/user/celeste.txt
Normal file
38
konabot/docs/user/celeste.txt
Normal file
@ -0,0 +1,38 @@
|
||||
# Celeste
|
||||
|
||||
爬山小游戏,移植自 Ccleste,是 Celeste Classic(即 PICO-8)版。
|
||||
|
||||
使用 `wasdxc` 和数字进行操作。
|
||||
|
||||
## 操作说明
|
||||
|
||||
`wsad` 是上下左右摇杆方向,或者是方向键。`c` 是跳跃键,`x` 是冲刺键。
|
||||
|
||||
使用空格分隔每一个操作,每个操作持续一帧。如果后面跟着数字,则持续那么多帧。
|
||||
|
||||
### 例子 1
|
||||
|
||||
```
|
||||
xc 180
|
||||
```
|
||||
|
||||
按下 xc 一帧,然后空置 180 帧。
|
||||
|
||||
### 例子 2
|
||||
|
||||
```
|
||||
d10 cd d10 xdw d20
|
||||
```
|
||||
|
||||
向右走 10 帧,向右跳一帧,再继续按下右 10 帧,按下向右上冲刺一帧,再按下右 20 帧。
|
||||
|
||||
## 指令使用说明
|
||||
|
||||
直接说 `celeste` 会开启一个新的游戏。但是,你需要在后面跟有操作,才能够渲染 gif 图出来。
|
||||
|
||||
一个常见的开始操作是直接发送 `celeste xc 130`,即按下 xc 两个按键触发 PICO 版的开始游戏,然后等待 130 帧动画播放完毕。
|
||||
|
||||
对于一个已经存在而且时间不是非常久远的 gif 图,只要是由 bot 自己发送出来的,就可以在它的基础上继续游戏。回复这条消息,可以继续游戏。
|
||||
|
||||
一种很常见的技巧是回复一个已经存在的 gif 图 `celeste 1`,此时会空操作一帧并且渲染画面。你可以用这种方法查看一个 gif 图的游戏目前的状态。
|
||||
|
||||
109
konabot/docs/user/fx.txt
Normal file
109
konabot/docs/user/fx.txt
Normal file
@ -0,0 +1,109 @@
|
||||
## 指令介绍
|
||||
|
||||
`fx` - 用于对图片应用各种滤镜效果的指令
|
||||
|
||||
## 格式
|
||||
|
||||
```
|
||||
fx [滤镜名称] <参数1> <参数2> ...
|
||||
```
|
||||
|
||||
## 示例
|
||||
|
||||
- `fx 模糊`
|
||||
- `fx 阈值 150`
|
||||
- `fx 缩放 2.0`
|
||||
- `fx 色彩 1.8`
|
||||
- `fx 色键 rgb(0,255,0) 50`
|
||||
|
||||
## 可用滤镜列表
|
||||
|
||||
### 基础滤镜
|
||||
* ```fx 轮廓```
|
||||
* ```fx 锐化```
|
||||
* ```fx 边缘增强```
|
||||
* ```fx 浮雕```
|
||||
* ```fx 查找边缘```
|
||||
* ```fx 平滑```
|
||||
* ```fx 暗角 <半径=1.5>```
|
||||
* ```fx 发光 <强度=0.5> <模糊半径=15>```
|
||||
* ```fx 噪点 <数量=0.05>```
|
||||
* ```fx 素描```
|
||||
* ```fx 阴影 <偏移量X=10> <偏移量Y=10> <模糊量=10> <不透明度=0.5> <阴影颜色=black>```
|
||||
|
||||
### 模糊滤镜
|
||||
* ```fx 模糊 <半径=10>```
|
||||
* ```fx 马赛克 <像素大小=10>```
|
||||
* ```fx 径向模糊 <强度=3.0> <采样量=6>```
|
||||
* ```fx 旋转模糊 <强度=30.0> <采样量=6>```
|
||||
* ```fx 方向模糊 <角度=0.0> <距离=20> <采样量=6>```
|
||||
* ```fx 缩放模糊 <强度=0.1> <采样量=6>```
|
||||
* ```fx 边缘模糊 <半径=10.0>```
|
||||
|
||||
### 色彩处理滤镜
|
||||
* ```fx 反色```
|
||||
* ```fx 黑白```
|
||||
* ```fx 阈值 <阈值=128>```
|
||||
* ```fx 对比度 <因子=1.5>```
|
||||
* ```fx 亮度 <因子=1.5>```
|
||||
* ```fx 色彩 <因子=1.5>```
|
||||
* ```fx 色调 <颜色="rgb(255,0,0)">```
|
||||
* ```fx RGB分离 <偏移量=5>```
|
||||
* ```fx 叠加颜色 <颜色列表=[rgb(255,0,0)|(0,0)+rgb(0,255,0)|(0,100)+rgb(0,0,255)|(50,100)]> <叠加模式=overlay>```
|
||||
* ```fx 像素抖动 <最大偏移量=2>```
|
||||
* ```fx 半调 <半径=5>```
|
||||
* ```fx 描边 <半径=5> <颜色=black>```
|
||||
* ```fx 形状描边 <半径=5> <颜色=black> <粗糙度=None>```
|
||||
|
||||
### 几何变换滤镜
|
||||
* ```fx 平移 <x偏移量=10> <y偏移量=10>```
|
||||
* ```fx 缩放 <比例(X)=1.5> <比例Y=None>```
|
||||
* ```fx 旋转 <角度=45>```
|
||||
* ```fx 透视变换 <变换矩阵>```
|
||||
* ```fx 裁剪 <左=0> <上=0> <右=100> <下=100>(百分比)```
|
||||
* ```fx 拓展边缘 <拓展量=10>```
|
||||
* ```fx 波纹 <振幅=5> <波长=20>```
|
||||
* ```fx 光学补偿 <数量=100> <反转=false>```
|
||||
* ```fx 球面化 <强度=0.5>```
|
||||
* ```fx 镜像 <角度=90>```
|
||||
* ```fx 水平翻转```
|
||||
* ```fx 垂直翻转```
|
||||
* ```fx 复制 <目标位置=(100,100)> <缩放=1.0> <源区域=(0,0,100,100)>(百分比)```
|
||||
|
||||
### 特殊效果滤镜
|
||||
* ```fx 设置通道 <通道=A>```
|
||||
* 可用 R、G、B、A。
|
||||
* ```fx 设置遮罩```
|
||||
* ```fx 色键 <目标颜色="rgb(255,0,0)"> <容差=60>```
|
||||
* ```fx 晃动 <最大偏移量=5> <运动模糊=False>```
|
||||
* ```fx JPEG损坏 <质量=10>```
|
||||
* 质量范围建议为 1~95,数值越低,压缩痕迹越重、效果越搞笑。
|
||||
* ```fx 动图 <帧率=10>```
|
||||
|
||||
### 多图像处理器
|
||||
* ```fx 存入图像 <目标名称>```
|
||||
* 目标名称是图像的代名词,图像最长可存 12 小时,如果公用容量满了图像也会被删除。
|
||||
* 该项仅可于首项使用。
|
||||
* ```fx 读取图像 <目标名称>```
|
||||
* 该项仅可于首项使用。
|
||||
* ```fx 暂存图像```
|
||||
* 此项默认插入存储在暂存列表中第一张图像的后面。
|
||||
* ```fx 交换图像 <交换项=2> <交换项=1>```
|
||||
* ```fx 删除图像 <删除索引=1>```
|
||||
* ```fx 选择图像 <目标索引=2>```
|
||||
|
||||
### 多图像混合
|
||||
* ```fx 混合图像 <模式=normal> <alpha=0.5>```
|
||||
* ```fx 覆盖图像```
|
||||
|
||||
### 生成类
|
||||
* ```fx 覆加颜色 <颜色列表=[rgb(255,0,0)|(0,0)+rgb(0,255,0)|(0,100)+rgb(0,0,255)|(50,100)]>```
|
||||
* ```fx 生成图层 <宽度=512> <高度=512>```
|
||||
* ```fx 生成文本 <文本内容=请输入文本> <字体大小=32> <文字颜色=black> <字体文件=HarmonyOS_Sans_SC_Regular.ttf>```
|
||||
|
||||
## 颜色名称支持
|
||||
- **格式**:颜色列表采用 ```[颜色|位置+颜色|位置+颜色|位置]``` 的格式,位置是形如```(x百分比,y百分比)```的元组。
|
||||
- **基本颜色**:红、绿、蓝、黄、紫、黑、白、橙、粉、灰、青、靛、棕
|
||||
- **修饰词**:浅、深、亮、暗(可组合使用,如`浅红`、`深蓝`)
|
||||
- **RGB格式**:`rgb(255,0,0)`、`rgb(0,255,0)`、`(255,0,0)` 等
|
||||
- **HEX格式**:`#66ccff`等
|
||||
@ -71,6 +71,14 @@ giftool [图片] [选项]
|
||||
|
||||
- 调整 GIF 图的速度。若为负数,则代表倒放。
|
||||
|
||||
### `--pingpong`(可选)
|
||||
|
||||
- 开启乒乓模式,生成正放-倒放拼接的 GIF 图。
|
||||
- 即播放完正向后,会倒放回去,形成往复循环效果。
|
||||
- 可与 `--speed` 配合使用,调整播放速度。
|
||||
- 示例:`giftool [图片] --pingpong`
|
||||
- 示例:`giftool [图片] --pingpong --speed 2.0`
|
||||
|
||||
## 使用方式
|
||||
|
||||
1. 发送指令前,请确保:
|
||||
|
||||
10
konabot/docs/user/k8x12S.txt
Normal file
10
konabot/docs/user/k8x12S.txt
Normal file
@ -0,0 +1,10 @@
|
||||
# 指令介绍
|
||||
|
||||
根据文字生成 k8x12S
|
||||
|
||||
> 「现在还不知道k8x12S是什么的可以开除界隈籍了」—— Louis, 2025/12/31
|
||||
|
||||
## 使用指南
|
||||
|
||||
`k8x12S 安心をしてください`
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
## 指令介绍
|
||||
**`ntfy`** - 配置使用 [ntfy](https://ntfy.sh/) 来更好地为你通知此方 BOT 的代办事项。
|
||||
**`ntfy`** - 配置使用 [ntfy](https://ntfy.sh/) 来更好地为你通知此方 BOT 的待办事项。
|
||||
|
||||
## 指令示例
|
||||
|
||||
- **`ntfy 创建`**
|
||||
创建一个随机的 ntfy 订阅主题来提醒代办。此方 Bot 将会给你使用指引。你可以前往 [https://ntfy.sh/](https://ntfy.sh/) 官网下载 ntfy APP,或者使用网页版 ntfy。
|
||||
创建一个随机的 ntfy 订阅主题来提醒待办。此方 Bot 将会给你使用指引。你可以前往 [https://ntfy.sh/](https://ntfy.sh/) 官网下载 ntfy APP,或者使用网页版 ntfy。
|
||||
|
||||
- **`ntfy 创建 kagami-notice`**
|
||||
创建一个名称包含 `kagami-notice` 的 ntfy 订阅主题。
|
||||
|
||||
53
konabot/docs/user/roll.txt
Normal file
53
konabot/docs/user/roll.txt
Normal file
@ -0,0 +1,53 @@
|
||||
**roll** - 面向跑团的文本骰子指令
|
||||
|
||||
## 用法
|
||||
|
||||
`roll 表达式`
|
||||
|
||||
支持常见骰子写法:
|
||||
|
||||
- `roll 3d6`
|
||||
- `roll d20+5`
|
||||
- `roll 2d8+1d4+3`
|
||||
- `roll d%`
|
||||
- `roll 4dF`
|
||||
|
||||
## 说明
|
||||
|
||||
- `NdM` 表示掷 N 个 M 面骰,例如 `3d6`
|
||||
- `d20` 等价于 `1d20`
|
||||
- `d%` 表示百分骰,范围 1 到 100
|
||||
- `dF` 表示 Fate/Fudge 骰,单骰结果为 -1、0、+1
|
||||
- 支持用 `+`、`-` 连接多个项,也支持常数修正
|
||||
|
||||
## 返回格式
|
||||
|
||||
会返回总结果,以及每一项的明细。
|
||||
|
||||
例如:
|
||||
|
||||
- `roll 3d6`
|
||||
可能返回:
|
||||
- `3d6 = 11`
|
||||
- `+3d6=[2, 4, 5]`
|
||||
|
||||
- `roll d20+5`
|
||||
可能返回:
|
||||
- `d20+5 = 19`
|
||||
- `+1d20=[14] +5=5`
|
||||
|
||||
## 限制
|
||||
|
||||
为防止刷屏和滥用,当前实现会限制:
|
||||
|
||||
- 单项最多 100 个骰子
|
||||
- 单个骰子最多 1000 面
|
||||
- 一次表达式最多 20 项
|
||||
- 一次表达式最多实际掷 200 个骰子
|
||||
- 结果过长时会直接拒绝
|
||||
|
||||
## 权限
|
||||
|
||||
需要 `trpg.roll` 权限。
|
||||
|
||||
默认启动时会给系统全局授予允许,因此通常所有人都能用;如有需要可再用权限系统单独关闭。
|
||||
258
konabot/docs/user/textfx.txt
Normal file
258
konabot/docs/user/textfx.txt
Normal file
@ -0,0 +1,258 @@
|
||||
# 文字处理机器人使用手册(小白友好版)
|
||||
|
||||
欢迎使用文字处理机器人!你不需要懂编程,只要会打字,就能用它完成各种文字操作——比如加密、解密、打乱顺序、进制转换、排版整理等。
|
||||
|
||||
---
|
||||
|
||||
## 一、基础演示
|
||||
|
||||
在 QQ 群里这样使用:
|
||||
|
||||
1. **直接输入命令**(适合短文本):
|
||||
```
|
||||
/textfx reverse 你好世界
|
||||
```
|
||||
→ 机器人回复:`界世好你`
|
||||
|
||||
2. **先发一段文字,再用命令处理它**(适合长文本):
|
||||
- 先发送:`Hello, World!`
|
||||
- 回复这条消息,输入:
|
||||
```
|
||||
/textfx b64 encode
|
||||
```
|
||||
→ 机器人返回:`SGVsbG8sIFdvcmxkIQ==`
|
||||
|
||||
> 命令可写为 `/textfx`、`/处理文字` 或 `/处理文本`。
|
||||
> 若不回复消息,命令会处理当前行后面的文本。
|
||||
|
||||
---
|
||||
|
||||
## 二、流水线语法(超简单)
|
||||
|
||||
- 用 `|` 连接多个操作,前一个的输出自动作为后一个的输入。
|
||||
- 用 `;` 分隔多条独立指令,它们各自产生输出,最终合并显示。
|
||||
- 用 `&&` / `||` 做最小 shell 风格条件执行:
|
||||
- `cmd1 && cmd2`:仅当 `cmd1` 成功时执行 `cmd2`
|
||||
- `cmd1 || cmd2`:仅当 `cmd1` 失败时执行 `cmd2`
|
||||
- 用 `!` 对一条 pipeline 的成功/失败取反。
|
||||
- 支持最小 bash-like `if ... then ... else ... fi` 语句。
|
||||
- 支持最小 bash-like `while ... do ... done` 循环。
|
||||
- 可使用内建真假命令:`true` / `false`。
|
||||
- 为避免滥用与卡死:
|
||||
- 同一用户同时只能运行 **一个** textfx 脚本
|
||||
- 单个脚本最长执行时间为 **60 秒**
|
||||
|
||||
**例子**:把"HELLO"先反转,再转成摩斯电码:(转换为摩斯电码功能暂未实现)
|
||||
```
|
||||
textfx reverse HELLO | morse en
|
||||
```
|
||||
→ 输出:`--- .-.. .-.. . ....`
|
||||
|
||||
**例子**:失败后兜底执行:
|
||||
```
|
||||
textfx test a = b || echo 不相等
|
||||
```
|
||||
→ 输出:`不相等`
|
||||
|
||||
**例子**:成功后继续执行:
|
||||
```
|
||||
textfx [ 2 -gt 1 ] && echo 条件成立
|
||||
```
|
||||
→ 输出:`条件成立`
|
||||
|
||||
**例子**:真正的 if 语句:
|
||||
```
|
||||
textfx if test a = b; then echo yes; else echo no; fi
|
||||
```
|
||||
→ 输出:`no`
|
||||
|
||||
**例子**:对条件取反:
|
||||
```
|
||||
textfx ! test a = b && echo 条件不成立
|
||||
```
|
||||
→ 输出:`条件不成立`
|
||||
|
||||
**例子**:while 循环:
|
||||
```
|
||||
textfx while false; do echo 不会执行; done
|
||||
```
|
||||
→ 输出为空
|
||||
|
||||
**例子**:多条指令各自输出:
|
||||
```
|
||||
textfx echo 你好; echo 世界
|
||||
```
|
||||
→ 输出:
|
||||
```
|
||||
你好
|
||||
世界
|
||||
```
|
||||
|
||||
**例子**:重定向的指令不输出,其余正常输出:
|
||||
```
|
||||
textfx echo 1; echo 2 > a; echo 3
|
||||
```
|
||||
→ 输出:
|
||||
```
|
||||
1
|
||||
3
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 三、功能清单(含示例)
|
||||
|
||||
### reverse(或 rev、反转)
|
||||
反转文字。
|
||||
示例:`/textfx reverse 爱你一万年` → `年万一你爱`
|
||||
|
||||
### b64(或 base64)
|
||||
Base64 编码或解码。
|
||||
示例:`/textfx b64 encode 你好` → `5L2g5aW9`
|
||||
示例:`/textfx b64 decode 5L2g5aW9` → `你好`
|
||||
|
||||
### caesar(或 凯撒、rot)
|
||||
凯撒密码(仅对英文字母有效)。
|
||||
示例:`/textfx caesar 3 ABC` → `DEF`
|
||||
示例:`/textfx caesar -3 DEF` → `ABC`
|
||||
|
||||
### morse(或 摩斯)
|
||||
将摩斯电码解码为文字(支持英文和日文)。字符间用空格,单词间用 `/`。
|
||||
示例:`/textfx morse en .... . .-.. .-.. ---` → `HELLO`
|
||||
示例:`/textfx morse jp -... --.-- -.. --.. ..- ..` → `ハアホフウイ`
|
||||
|
||||
### baseconv(或 进制转换)
|
||||
在不同进制之间转换数字。
|
||||
示例:`/textfx baseconv 2 10 1101` → `13`
|
||||
示例:`/textfx baseconv 10 16 255` → `FF`
|
||||
|
||||
### shuffle(或 打乱)
|
||||
随机打乱文字顺序。
|
||||
示例:`/textfx shuffle abcdef` → `fcbade`(每次结果不同)
|
||||
|
||||
### sort(或 排序)
|
||||
将文字按字符顺序排列。
|
||||
示例:`/textfx sort dbca` → `abcd`
|
||||
|
||||
### b64hex
|
||||
在 Base64 和十六进制之间互转。
|
||||
示例:`/textfx b64hex dec SGVsbG8=` → `48656c6c6f`
|
||||
示例:`/textfx b64hex enc 48656c6c6f` → `SGVsbG8=`
|
||||
|
||||
### align(或 format、排版)
|
||||
按指定格式分组排版文字。
|
||||
示例:`/textfx align 2 4 0123456789abcdef` →
|
||||
```
|
||||
01 23 45 67
|
||||
89 ab cd ef
|
||||
```
|
||||
|
||||
### echo
|
||||
输出指定文字。
|
||||
示例:`/textfx echo 你好` → `你好`
|
||||
|
||||
### cat
|
||||
读取并拼接缓存内容,类似 Unix cat 命令。
|
||||
- 无参数时直接传递标准输入(管道输入或回复的消息)。
|
||||
- 使用 `-` 代表标准输入,可与缓存名混合使用。
|
||||
- 支持多个参数,按顺序拼接输出。
|
||||
|
||||
示例:
|
||||
- 传递输入:`/textfx echo 你好 | cat` → `你好`
|
||||
- 读取缓存:`/textfx cat mytext` → 输出 mytext 的内容
|
||||
- 拼接多个缓存:`/textfx cat a b c` → 依次拼接缓存 a、b、c
|
||||
- 混合标准输入和缓存:`/textfx echo 前缀 | cat - mytext` → 拼接标准输入与缓存 mytext
|
||||
|
||||
### 缓存操作(保存中间结果)
|
||||
- 保存:`/textfx reverse 你好 > mytext`(不输出,存入 mytext)
|
||||
- 读取:`/textfx cat mytext` → `好你`
|
||||
- 追加:`/textfx echo world >> mytext`
|
||||
- 删除:`/textfx rm mytext`
|
||||
|
||||
> 缓存仅在当前对话中有效,重启后清空。
|
||||
|
||||
### true / false / test / [
|
||||
最小 shell 风格条件命令。通常配合 `if`、`&&`、`||`、`!` 使用。
|
||||
|
||||
支持:
|
||||
- `true`:总是成功
|
||||
- `false`:总是失败
|
||||
- 字符串非空:`test foo`
|
||||
- `-n` / `-z`:`test -n foo`、`test -z ""`
|
||||
- 字符串比较:`test a = a`、`test a != b`
|
||||
- 整数比较:`test 2 -gt 1`、`test 3 -le 5`
|
||||
- 方括号别名:`[ 2 -gt 1 ]`
|
||||
|
||||
示例:
|
||||
- `/textfx true && echo 一定执行`
|
||||
- `/textfx false || echo 兜底执行`
|
||||
- `/textfx test hello && echo 有内容`
|
||||
- `/textfx test a = b || echo 不相等`
|
||||
- `/textfx [ 3 -ge 2 ] && echo yes`
|
||||
|
||||
### if / then / else / fi
|
||||
支持最小 bash-like 条件语句。
|
||||
|
||||
示例:
|
||||
- `/textfx if test a = a; then echo yes; else echo no; fi`
|
||||
- `/textfx if [ 2 -gt 1 ]; then echo 成立; fi`
|
||||
- `/textfx if test a = a; then if test b = c; then echo x; else echo y; fi; fi`
|
||||
|
||||
说明:
|
||||
- `if` 后面跟一个条件链,可配合 `test`、`[`、`!`、`&&`、`||`
|
||||
- `then` 和 `else` 后面都可以写多条以 `;` 分隔的 textfx 语句
|
||||
- `else` 可省略
|
||||
|
||||
### while / do / done
|
||||
支持最小 bash-like 循环语句。
|
||||
|
||||
示例:
|
||||
- `/textfx while false; do echo 不会执行; done`
|
||||
- `/textfx while ! false; do false; done`
|
||||
- `/textfx while ! false; do if true; then false; fi; done`
|
||||
|
||||
说明:
|
||||
- `while` 后面跟一个条件链,返回成功就继续循环
|
||||
- `do` 后面可写多条以 `;` 分隔的 textfx 语句
|
||||
- 为避免 bot 死循环,内置最大循环次数限制;超限会报错
|
||||
|
||||
### replace(或 替换、sed)
|
||||
替换文字(支持正则表达式)。
|
||||
示例(普通):`/textfx replace 世界 宇宙 你好世界` → `你好宇宙`
|
||||
示例(正则):`/textfx replace \d+ [数字] 我有123个苹果` → `我有[数字]个苹果`
|
||||
|
||||
### trim(或 strip、去空格)
|
||||
去除文本首尾空白字符。
|
||||
示例:`/textfx trim " 你好 "` → `你好`
|
||||
示例:`/textfx echo " hello " | trim` → `hello`
|
||||
|
||||
### ltrim(或 lstrip)
|
||||
去除文本左侧空白字符。
|
||||
示例:`/textfx ltrim " 你好 "` → `你好 `
|
||||
|
||||
### rtrim(或 rstrip)
|
||||
去除文本右侧空白字符。
|
||||
示例:`/textfx rtrim " 你好 "` → ` 你好`
|
||||
|
||||
### squeeze(或 压缩空白)
|
||||
将连续的空白字符(空格、制表符)压缩为单个空格。
|
||||
示例:`/textfx squeeze "你好 世界"` → `你好 世界`
|
||||
|
||||
### lines(或 行处理)
|
||||
按行处理文本,支持以下子命令:
|
||||
- `lines trim` — 去除每行首尾空白
|
||||
- `lines empty` — 去除所有空行
|
||||
- `lines squeeze` — 将连续空行压缩为一行
|
||||
|
||||
示例:`/textfx echo " hello\n\n\n world " | lines trim` → `hello\n\n\n world`
|
||||
示例:`/textfx echo "a\n\n\nb" | lines squeeze` → `a\n\nb`
|
||||
|
||||
---
|
||||
|
||||
## 常见问题
|
||||
|
||||
- **没反应?** 可能内容被安全系统拦截,机器人会提示“内容被拦截”。
|
||||
- **只支持纯文字**,暂不支持图片或文件。
|
||||
- 命令拼错时,机器人会提示“不存在名为 xxx 的函数”,请检查名称。
|
||||
|
||||
快去试试吧!用法核心:**`/textfx` + 你的操作**
|
||||
24
konabot/docs/user/tqszm.txt
Normal file
24
konabot/docs/user/tqszm.txt
Normal file
@ -0,0 +1,24 @@
|
||||
# tqszm
|
||||
|
||||
引用一条消息,让此方帮你提取首字母。
|
||||
|
||||
例子:
|
||||
|
||||
```
|
||||
John: 11-28 16:50:37
|
||||
谁来总结一下今天的工作?
|
||||
|
||||
Jack: 11-28 16:50:55
|
||||
[引用John的消息] @此方Bot tqszm
|
||||
|
||||
此方Bot: 11-28 16:50:56
|
||||
slzjyxjtdgz?
|
||||
```
|
||||
|
||||
或者,你也可以直接以正常指令的方式调用:
|
||||
|
||||
```
|
||||
@此方Bot 提取首字母 中山大学
|
||||
> zsdx
|
||||
```
|
||||
|
||||
4
konabot/docs/user/typst.txt
Normal file
4
konabot/docs/user/typst.txt
Normal file
@ -0,0 +1,4 @@
|
||||
# Typst 渲染
|
||||
|
||||
只需使用 `typst ...` 就可以渲染 Typst 了
|
||||
|
||||
@ -1,7 +0,0 @@
|
||||
## 指令介绍
|
||||
|
||||
**黑白** - 将图片经过一个黑白滤镜的处理
|
||||
|
||||
## 示例
|
||||
|
||||
引用一个带有图片的消息,或者消息本身携带图片,然后发送「黑白」即可
|
||||
52
konabot/plugins/ai_extract_text/__init__.py
Normal file
52
konabot/plugins/ai_extract_text/__init__.py
Normal file
@ -0,0 +1,52 @@
|
||||
from io import BytesIO
|
||||
import base64
|
||||
import re
|
||||
from loguru import logger
|
||||
from nonebot import on_message
|
||||
from nonebot.rule import Rule
|
||||
|
||||
from konabot.common.apis.ali_content_safety import AlibabaGreen
|
||||
from konabot.common.llm import get_llm
|
||||
from konabot.common.longtask import DepLongTaskTarget
|
||||
from konabot.common.nb.extract_image import DepPILImage
|
||||
from konabot.common.nb.match_keyword import match_keyword
|
||||
|
||||
|
||||
cmd = on_message(rule=Rule(match_keyword(re.compile(r"^千问识图\s*$"))))
|
||||
|
||||
|
||||
@cmd.handle()
|
||||
async def _(img: DepPILImage, target: DepLongTaskTarget):
|
||||
if 1:
|
||||
return #TODO:这里还没写完,还有 Bug 要修
|
||||
jpeg_data = BytesIO()
|
||||
if img.width > 2160:
|
||||
img = img.resize((2160, img.height * 2160 // img.width))
|
||||
if img.height > 2160:
|
||||
img = img.resize((img.width * 2160 // img.height, 2160))
|
||||
img = img.convert("RGB")
|
||||
img.save(jpeg_data, format="jpeg", optimize=True, quality=85)
|
||||
data_url = "data:image/jpeg;base64,"
|
||||
data_url += base64.b64encode(jpeg_data.getvalue()).decode('ascii')
|
||||
|
||||
llm = get_llm("qwen3-vl-plus")
|
||||
res = await llm.chat([
|
||||
{ "role": "user", "content": [
|
||||
{ "type": "image_url", "image_url": {
|
||||
"url": data_url
|
||||
} },
|
||||
{ "type": "text", "text": "请你提取这张图片中的所有文字,并尽量按照原图的排版输出,不需要其他内容" },
|
||||
] }
|
||||
])
|
||||
result = res.content
|
||||
logger.info(res)
|
||||
if result is None:
|
||||
await target.send_message("提取失败:可能存在网络异常")
|
||||
return
|
||||
|
||||
if not await AlibabaGreen.detect(result):
|
||||
await target.send_message("提取失败:图片中可能存在一些不合适的内容")
|
||||
return
|
||||
|
||||
await target.send_message(result, at=False)
|
||||
|
||||
@ -1,22 +1,29 @@
|
||||
from io import BytesIO
|
||||
from typing import Optional, Union
|
||||
import cv2
|
||||
import nonebot
|
||||
from nonebot.adapters import Event as BaseEvent
|
||||
from nonebot.adapters.console.event import MessageEvent as ConsoleMessageEvent
|
||||
from nonebot.adapters.discord.event import MessageEvent as DiscordMessageEvent
|
||||
from nonebot_plugin_alconna import Alconna, AlconnaMatcher, Args, UniMessage, on_alconna
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from konabot.common.database import DatabaseManager
|
||||
from konabot.common.longtask import DepLongTaskTarget
|
||||
from konabot.common.path import ASSETS_PATH
|
||||
from konabot.common.web_render import WebRenderer
|
||||
from konabot.plugins.air_conditioner.ac import AirConditioner, CrashType, generate_ac_image, wiggle_transform
|
||||
|
||||
from pathlib import Path
|
||||
import random
|
||||
import math
|
||||
|
||||
def get_ac(id: str) -> AirConditioner:
|
||||
ac = AirConditioner.air_conditioners.get(id)
|
||||
ROOT_PATH = Path(__file__).resolve().parent
|
||||
|
||||
# 创建全局数据库管理器实例
|
||||
db_manager = DatabaseManager()
|
||||
|
||||
async def get_ac(id: str) -> AirConditioner:
|
||||
ac = await AirConditioner.get_ac(id)
|
||||
if ac is None:
|
||||
ac = AirConditioner(id)
|
||||
return ac
|
||||
@ -43,14 +50,32 @@ async def send_ac_image(event: type[AlconnaMatcher], ac: AirConditioner):
|
||||
ac_image = await generate_ac_image(ac)
|
||||
await event.send(await UniMessage().image(raw=ac_image).export())
|
||||
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
async def register_startup_hook():
|
||||
"""注册启动时需要执行的函数"""
|
||||
# 初始化数据库表
|
||||
await db_manager.execute_by_sql_file(
|
||||
Path(__file__).resolve().parent / "sql" / "create_table.sql"
|
||||
)
|
||||
|
||||
@driver.on_shutdown
|
||||
async def register_shutdown_hook():
|
||||
"""注册关闭时需要执行的函数"""
|
||||
# 关闭所有数据库连接
|
||||
await db_manager.close_all_connections()
|
||||
|
||||
evt = on_alconna(Alconna(
|
||||
"群空调"
|
||||
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
|
||||
|
||||
@evt.handle()
|
||||
async def _(event: BaseEvent, target: DepLongTaskTarget):
|
||||
async def _(target: DepLongTaskTarget):
|
||||
id = target.channel_id
|
||||
ac = get_ac(id)
|
||||
ac = await get_ac(id)
|
||||
await send_ac_image(evt, ac)
|
||||
|
||||
evt = on_alconna(Alconna(
|
||||
@ -58,10 +83,10 @@ evt = on_alconna(Alconna(
|
||||
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
|
||||
|
||||
@evt.handle()
|
||||
async def _(event: BaseEvent, target: DepLongTaskTarget):
|
||||
async def _(target: DepLongTaskTarget):
|
||||
id = target.channel_id
|
||||
ac = get_ac(id)
|
||||
ac.on = True
|
||||
ac = await get_ac(id)
|
||||
await ac.update_ac(state=True)
|
||||
await send_ac_image(evt, ac)
|
||||
|
||||
evt = on_alconna(Alconna(
|
||||
@ -69,10 +94,10 @@ evt = on_alconna(Alconna(
|
||||
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
|
||||
|
||||
@evt.handle()
|
||||
async def _(event: BaseEvent, target: DepLongTaskTarget):
|
||||
async def _(target: DepLongTaskTarget):
|
||||
id = target.channel_id
|
||||
ac = get_ac(id)
|
||||
ac.on = False
|
||||
ac = await get_ac(id)
|
||||
await ac.update_ac(state=False)
|
||||
await send_ac_image(evt, ac)
|
||||
|
||||
evt = on_alconna(Alconna(
|
||||
@ -81,31 +106,29 @@ evt = on_alconna(Alconna(
|
||||
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
|
||||
|
||||
@evt.handle()
|
||||
async def _(event: BaseEvent, target: DepLongTaskTarget, temp: Optional[Union[int, float]] = 1):
|
||||
async def _(target: DepLongTaskTarget, temp: Optional[Union[int, float]] = 1):
|
||||
if temp is None:
|
||||
temp = 1
|
||||
if temp <= 0:
|
||||
return
|
||||
id = target.channel_id
|
||||
ac = get_ac(id)
|
||||
ac = await get_ac(id)
|
||||
if not ac.on or ac.burnt == True or ac.frozen == True:
|
||||
await send_ac_image(evt, ac)
|
||||
return
|
||||
ac.temperature += temp
|
||||
if ac.temperature > 40:
|
||||
# 根据温度随机出是否爆炸,40度开始,呈指数增长
|
||||
possibility = -math.e ** ((40-ac.temperature) / 50) + 1
|
||||
if random.random() < possibility:
|
||||
# 打开爆炸图片
|
||||
with open(ASSETS_PATH / "img" / "other" / "boom.jpg", "rb") as f:
|
||||
output = BytesIO()
|
||||
# 爆炸抖动
|
||||
frames = wiggle_transform(np.array(Image.open(f)), intensity=5)
|
||||
pil_frames = [Image.fromarray(frame) for frame in frames]
|
||||
pil_frames[0].save(output, format="GIF", save_all=True, append_images=pil_frames[1:], loop=0, duration=35, disposal=2)
|
||||
output.seek(0)
|
||||
await evt.send(await UniMessage().image(raw=output).export())
|
||||
ac.broke_ac(CrashType.BURNT)
|
||||
await evt.send("太热啦,空调炸了!")
|
||||
return
|
||||
await ac.update_ac(temperature_delta=temp)
|
||||
if ac.burnt:
|
||||
# 打开爆炸图片
|
||||
with open(ASSETS_PATH / "img" / "other" / "boom.jpg", "rb") as f:
|
||||
output = BytesIO()
|
||||
# 爆炸抖动
|
||||
frames = wiggle_transform(np.array(Image.open(f)), intensity=5)
|
||||
pil_frames = [Image.fromarray(frame) for frame in frames]
|
||||
pil_frames[0].save(output, format="GIF", save_all=True, append_images=pil_frames[1:], loop=0, duration=35, disposal=2)
|
||||
output.seek(0)
|
||||
await evt.send(await UniMessage().image(raw=output).export())
|
||||
await evt.send("太热啦,空调炸了!")
|
||||
return
|
||||
await send_ac_image(evt, ac)
|
||||
|
||||
evt = on_alconna(Alconna(
|
||||
@ -114,20 +137,17 @@ evt = on_alconna(Alconna(
|
||||
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
|
||||
|
||||
@evt.handle()
|
||||
async def _(event: BaseEvent, target: DepLongTaskTarget, temp: Optional[Union[int, float]] = 1):
|
||||
async def _(target: DepLongTaskTarget, temp: Optional[Union[int, float]] = 1):
|
||||
if temp is None:
|
||||
temp = 1
|
||||
if temp <= 0:
|
||||
return
|
||||
id = target.channel_id
|
||||
ac = get_ac(id)
|
||||
ac = await get_ac(id)
|
||||
if not ac.on or ac.burnt == True or ac.frozen == True:
|
||||
await send_ac_image(evt, ac)
|
||||
return
|
||||
ac.temperature -= temp
|
||||
if ac.temperature < 0:
|
||||
# 根据温度随机出是否冻结,0度开始,呈指数增长
|
||||
possibility = -math.e ** (ac.temperature / 50) + 1
|
||||
if random.random() < possibility:
|
||||
ac.broke_ac(CrashType.FROZEN)
|
||||
await ac.update_ac(temperature_delta=-temp)
|
||||
await send_ac_image(evt, ac)
|
||||
|
||||
evt = on_alconna(Alconna(
|
||||
@ -135,21 +155,34 @@ evt = on_alconna(Alconna(
|
||||
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
|
||||
|
||||
@evt.handle()
|
||||
async def _(event: BaseEvent, target: DepLongTaskTarget):
|
||||
async def _(target: DepLongTaskTarget):
|
||||
id = target.channel_id
|
||||
ac = get_ac(id)
|
||||
ac.change_ac()
|
||||
ac = await get_ac(id)
|
||||
await ac.change_ac()
|
||||
await send_ac_image(evt, ac)
|
||||
|
||||
async def query_number_ranking(id: str) -> tuple[int, int]:
|
||||
result = await db_manager.query_by_sql_file(
|
||||
ROOT_PATH / "sql" / "query_crash_and_rank.sql",
|
||||
(id,id)
|
||||
)
|
||||
if len(result) == 0:
|
||||
return 0, 0
|
||||
else:
|
||||
# 将字典转换为值的元组
|
||||
values = list(result[0].values())
|
||||
return values[0], values[1]
|
||||
|
||||
evt = on_alconna(Alconna(
|
||||
"空调炸炸排行榜",
|
||||
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
|
||||
|
||||
@evt.handle()
|
||||
async def _(event: BaseEvent, target: DepLongTaskTarget):
|
||||
async def _(target: DepLongTaskTarget):
|
||||
id = target.channel_id
|
||||
ac = get_ac(id)
|
||||
number, ranking = ac.get_crashes_and_ranking()
|
||||
# ac = get_ac(id)
|
||||
# number, ranking = ac.get_crashes_and_ranking()
|
||||
number, ranking = await query_number_ranking(id)
|
||||
params = {
|
||||
"number": number,
|
||||
"ranking": ranking
|
||||
@ -159,4 +192,37 @@ async def _(event: BaseEvent, target: DepLongTaskTarget):
|
||||
target=".box",
|
||||
params=params
|
||||
)
|
||||
await evt.send(await UniMessage().image(raw=image).export())
|
||||
await evt.send(await UniMessage().image(raw=image).export())
|
||||
|
||||
evt = on_alconna(Alconna(
|
||||
"空调最高峰",
|
||||
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
|
||||
|
||||
@evt.handle()
|
||||
async def _(target: DepLongTaskTarget):
|
||||
result = await db_manager.query_by_sql_file(
|
||||
ROOT_PATH / "sql" / "query_peak.sql"
|
||||
)
|
||||
if len(result) == 0:
|
||||
await evt.send("没有空调记录!")
|
||||
return
|
||||
max_temp = result[0].get("max")
|
||||
min_temp = result[0].get("min")
|
||||
his_max = result[0].get("his_max")
|
||||
his_min = result[0].get("his_min")
|
||||
# 再从内存里的空调池中获取最高温度和最低温度
|
||||
for ac in AirConditioner.InstancesPool.values():
|
||||
if ac.on and not ac.burnt and not ac.frozen:
|
||||
if max_temp is None or min_temp is None:
|
||||
max_temp = ac.temperature
|
||||
min_temp = ac.temperature
|
||||
max_temp = max(max_temp, ac.temperature)
|
||||
min_temp = min(min_temp, ac.temperature)
|
||||
if max_temp is None or min_temp is None:
|
||||
await evt.send(f"目前全部空调都被炸掉了!")
|
||||
else:
|
||||
await evt.send(f"全球在线空调最高温度为 {'%.1f' % max_temp}°C,最低温度为 {'%.1f' % min_temp}°C!")
|
||||
if his_max is None or his_min is None:
|
||||
pass
|
||||
else:
|
||||
await evt.send(f"历史最高温度为 {'%.1f' % his_max}°C,最低温度为 {'%.1f' % his_min}°C!\n(要进入历史记录,温度需至少保持 5 分钟)")
|
||||
@ -1,20 +1,193 @@
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
import math
|
||||
from pathlib import Path
|
||||
import random
|
||||
import signal
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from nonebot import logger
|
||||
|
||||
from konabot.common.database import DatabaseManager
|
||||
from konabot.common.path import ASSETS_PATH, FONTS_PATH
|
||||
from konabot.common.path import DATA_PATH
|
||||
import nonebot
|
||||
import json
|
||||
|
||||
ROOT_PATH = Path(__file__).resolve().parent
|
||||
|
||||
# 创建全局数据库管理器实例
|
||||
db_manager = DatabaseManager()
|
||||
|
||||
class CrashType(Enum):
|
||||
BURNT = 0
|
||||
FROZEN = 1
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
@driver.on_startup
|
||||
async def register_startup_hook():
|
||||
await ac_manager.start_auto_save()
|
||||
|
||||
@driver.on_shutdown
|
||||
async def register_shutdown_hook():
|
||||
"""注册关闭时需要执行的函数"""
|
||||
# 停止自动保存任务
|
||||
if ac_manager:
|
||||
await ac_manager.stop_auto_save()
|
||||
|
||||
class AirConditionerManager:
|
||||
def __init__(self, save_interval: int = 300): # 默认5分钟保存一次
|
||||
self.save_interval = save_interval
|
||||
self._save_task = None
|
||||
self._running = False
|
||||
|
||||
async def start_auto_save(self):
|
||||
"""启动自动保存任务"""
|
||||
self._running = True
|
||||
self._save_task = asyncio.create_task(self._auto_save_loop())
|
||||
|
||||
logger.info(f"自动保存任务已启动,间隔: {self.save_interval}秒")
|
||||
|
||||
async def stop_auto_save(self):
|
||||
"""停止自动保存任务"""
|
||||
if self._save_task:
|
||||
self._running = False
|
||||
self._save_task.cancel()
|
||||
try:
|
||||
await self._save_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("自动保存任务已停止")
|
||||
else:
|
||||
logger.warning("没有正在运行的自动保存任务")
|
||||
|
||||
async def _auto_save_loop(self):
|
||||
"""自动保存循环"""
|
||||
while self._running:
|
||||
try:
|
||||
await asyncio.sleep(self.save_interval)
|
||||
await self.save_all_instances()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"定时保存失败: {e}")
|
||||
|
||||
async def save_all_instances(self):
|
||||
save_time = time.time()
|
||||
to_remove = []
|
||||
"""保存所有实例到数据库"""
|
||||
for ac_id, ac_instance in AirConditioner.InstancesPool.items():
|
||||
try:
|
||||
await db_manager.execute_by_sql_file(
|
||||
ROOT_PATH / "sql" / "update_ac.sql",
|
||||
[(ac_instance.on, ac_instance.temperature,
|
||||
ac_instance.burnt, ac_instance.frozen, ac_id),(ac_id,)]
|
||||
)
|
||||
if(save_time - ac_instance.instance_get_time >= 300): # 5 分钟
|
||||
to_remove.append(ac_id)
|
||||
except Exception as e:
|
||||
logger.error(f"保存空调 {ac_id} 失败: {e}")
|
||||
|
||||
logger.info(f"定时保存完成,共保存 {len(AirConditioner.InstancesPool)} 个空调实例")
|
||||
|
||||
# 删除时间过长实例
|
||||
for ac_id in to_remove:
|
||||
del AirConditioner.InstancesPool[ac_id]
|
||||
|
||||
logger.info(f"清理长期不活跃的空调实例完成,目前池内共有 {len(AirConditioner.InstancesPool)} 个实例")
|
||||
|
||||
ac_manager = AirConditionerManager(save_interval=300) # 5分钟
|
||||
|
||||
class AirConditioner:
|
||||
air_conditioners: dict[str, "AirConditioner"] = {}
|
||||
InstancesPool: dict[str, 'AirConditioner'] = {}
|
||||
|
||||
@classmethod
|
||||
async def refresh_ac(cls, id: str):
|
||||
cls.InstancesPool[id].instance_get_time = time.time()
|
||||
|
||||
@classmethod
|
||||
async def storage_ac(cls, id: str, ac: 'AirConditioner'):
|
||||
cls.InstancesPool[id] = ac
|
||||
|
||||
@classmethod
|
||||
async def get_ac(cls, id: str) -> 'AirConditioner':
|
||||
if(id in cls.InstancesPool):
|
||||
await cls.refresh_ac(id)
|
||||
return cls.InstancesPool[id]
|
||||
# 如果没有,那么从数据库重新实例化一个 AC 出来
|
||||
result = await db_manager.query_by_sql_file(ROOT_PATH / "sql" / "query_ac.sql", (id,))
|
||||
if len(result) == 0:
|
||||
ac = await cls.create_ac(id)
|
||||
return ac
|
||||
ac_data = result[0]
|
||||
ac = AirConditioner(id)
|
||||
ac.on = bool(ac_data["on"])
|
||||
ac.temperature = float(ac_data["temperature"])
|
||||
ac.burnt = bool(ac_data["burnt"])
|
||||
ac.frozen = bool(ac_data["frozen"])
|
||||
await cls.storage_ac(id, ac)
|
||||
return ac
|
||||
|
||||
@classmethod
|
||||
async def create_ac(cls, id: str) -> 'AirConditioner':
|
||||
ac = AirConditioner(id)
|
||||
await db_manager.execute_by_sql_file(
|
||||
ROOT_PATH / "sql" / "insert_ac.sql",
|
||||
[(id, ac.on, ac.temperature, ac.burnt, ac.frozen),(id,)]
|
||||
)
|
||||
await cls.storage_ac(id, ac)
|
||||
return ac
|
||||
|
||||
async def change_ac_temp(self, temperature_delta: float) -> None:
|
||||
'''
|
||||
改变空调的温度
|
||||
:param temperature_delta: float 温度变化量
|
||||
'''
|
||||
changed_temp = self.temperature + temperature_delta
|
||||
random_poss = random.random()
|
||||
if temperature_delta < 0 and changed_temp < 0:
|
||||
# 根据温度随机出是否冻结,0度开始,呈指数增长
|
||||
possibility = -math.e ** (changed_temp / 50) + 1
|
||||
if random_poss < possibility:
|
||||
await self.broke_ac(CrashType.FROZEN)
|
||||
elif temperature_delta > 0 and changed_temp > 40:
|
||||
# 根据温度随机出是否烧坏,40度开始,呈指数增长
|
||||
possibility = -math.e ** ((40-changed_temp) / 50) + 1
|
||||
if random_poss < possibility:
|
||||
await self.broke_ac(CrashType.BURNT)
|
||||
self.temperature = changed_temp
|
||||
|
||||
async def update_ac(self, state: bool = None, temperature_delta: float = None, burnt: bool = None, frozen: bool = None) -> 'AirConditioner':
|
||||
if state is not None:
|
||||
self.on = state
|
||||
if temperature_delta is not None:
|
||||
await self.change_ac_temp(temperature_delta)
|
||||
if burnt is not None:
|
||||
self.burnt = burnt
|
||||
if frozen is not None:
|
||||
self.frozen = frozen
|
||||
# await db_manager.execute_by_sql_file(
|
||||
# ROOT_PATH / "sql" / "update_ac.sql",
|
||||
# (self.on, self.temperature, self.burnt, self.frozen, self.id)
|
||||
# )
|
||||
return self
|
||||
|
||||
async def change_ac(self) -> 'AirConditioner':
|
||||
self.on = False
|
||||
self.temperature = 24
|
||||
self.burnt = False
|
||||
self.frozen = False
|
||||
# await db_manager.execute_by_sql_file(
|
||||
# ROOT_PATH / "sql" / "update_ac.sql",
|
||||
# (self.on, self.temperature, self.burnt, self.frozen, self.id)
|
||||
# )
|
||||
return self
|
||||
|
||||
def __init__(self, id: str) -> None:
|
||||
self.id = id
|
||||
@ -22,45 +195,42 @@ class AirConditioner:
|
||||
self.temperature = 24 # 默认温度
|
||||
self.burnt = False
|
||||
self.frozen = False
|
||||
AirConditioner.air_conditioners[id] = self
|
||||
|
||||
def change_ac(self):
|
||||
self.burnt = False
|
||||
self.frozen = False
|
||||
self.on = False
|
||||
self.temperature = 24 # 重置为默认温度
|
||||
self.instance_get_time = time.time()
|
||||
|
||||
def broke_ac(self, crash_type: CrashType):
|
||||
async def broke_ac(self, crash_type: CrashType):
|
||||
'''
|
||||
让空调坏掉,并保存数据
|
||||
|
||||
让空调坏掉
|
||||
:param crash_type: CrashType 枚举,表示空调坏掉的类型
|
||||
'''
|
||||
match crash_type:
|
||||
case CrashType.BURNT:
|
||||
self.burnt = True
|
||||
await self.update_ac(burnt=True)
|
||||
case CrashType.FROZEN:
|
||||
self.frozen = True
|
||||
self.save_crash_data(crash_type)
|
||||
await self.update_ac(frozen=True)
|
||||
await db_manager.execute_by_sql_file(
|
||||
ROOT_PATH / "sql" / "insert_crash.sql",
|
||||
(self.id, crash_type.value)
|
||||
)
|
||||
|
||||
def save_crash_data(self, crash_type: CrashType):
|
||||
'''
|
||||
如果空调爆炸了,就往本地的 ac_crash_data.json 里该 id 的记录加一
|
||||
'''
|
||||
data_file = DATA_PATH / "ac_crash_data.json"
|
||||
crash_data = {}
|
||||
if data_file.exists():
|
||||
with open(data_file, "r", encoding="utf-8") as f:
|
||||
crash_data = json.load(f)
|
||||
if self.id not in crash_data:
|
||||
crash_data[self.id] = {"burnt": 0, "frozen": 0}
|
||||
match crash_type:
|
||||
case CrashType.BURNT:
|
||||
crash_data[self.id]["burnt"] += 1
|
||||
case CrashType.FROZEN:
|
||||
crash_data[self.id]["frozen"] += 1
|
||||
with open(data_file, "w", encoding="utf-8") as f:
|
||||
json.dump(crash_data, f, ensure_ascii=False, indent=4)
|
||||
# def save_crash_data(self, crash_type: CrashType):
|
||||
# '''
|
||||
# 如果空调爆炸了,就往本地的 ac_crash_data.json 里该 id 的记录加一
|
||||
# '''
|
||||
# data_file = DATA_PATH / "ac_crash_data.json"
|
||||
# crash_data = {}
|
||||
# if data_file.exists():
|
||||
# with open(data_file, "r", encoding="utf-8") as f:
|
||||
# crash_data = json.load(f)
|
||||
# if self.id not in crash_data:
|
||||
# crash_data[self.id] = {"burnt": 0, "frozen": 0}
|
||||
# match crash_type:
|
||||
# case CrashType.BURNT:
|
||||
# crash_data[self.id]["burnt"] += 1
|
||||
# case CrashType.FROZEN:
|
||||
# crash_data[self.id]["frozen"] += 1
|
||||
# with open(data_file, "w", encoding="utf-8") as f:
|
||||
# json.dump(crash_data, f, ensure_ascii=False, indent=4)
|
||||
|
||||
def get_crashes_and_ranking(self) -> tuple[int, int]:
|
||||
'''
|
||||
|
||||
26
konabot/plugins/air_conditioner/sql/create_table.sql
Normal file
26
konabot/plugins/air_conditioner/sql/create_table.sql
Normal file
@ -0,0 +1,26 @@
|
||||
-- 创建所有表
|
||||
CREATE TABLE IF NOT EXISTS air_conditioner (
|
||||
id VARCHAR(128) PRIMARY KEY,
|
||||
"on" BOOLEAN NOT NULL,
|
||||
temperature REAL NOT NULL,
|
||||
burnt BOOLEAN NOT NULL,
|
||||
frozen BOOLEAN NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS air_conditioner_log (
|
||||
log_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
log_time DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
id VARCHAR(128),
|
||||
"on" BOOLEAN NOT NULL,
|
||||
temperature REAL NOT NULL,
|
||||
burnt BOOLEAN NOT NULL,
|
||||
frozen BOOLEAN NOT NULL,
|
||||
FOREIGN KEY (id) REFERENCES air_conditioner(id)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS air_conditioner_crash_log (
|
||||
id VARCHAR(128) NOT NULL,
|
||||
crash_type INT NOT NULL,
|
||||
timestamp DATETIME NOT NULL,
|
||||
FOREIGN KEY (id) REFERENCES air_conditioner(id)
|
||||
);
|
||||
8
konabot/plugins/air_conditioner/sql/insert_ac.sql
Normal file
8
konabot/plugins/air_conditioner/sql/insert_ac.sql
Normal file
@ -0,0 +1,8 @@
|
||||
-- 插入一台新空调
|
||||
INSERT INTO air_conditioner (id, "on", temperature, burnt, frozen)
|
||||
VALUES (?, ?, ?, ?, ?);
|
||||
-- 使用返回的数据插入日志
|
||||
INSERT INTO air_conditioner_log (id, "on", temperature, burnt, frozen)
|
||||
SELECT id, "on", temperature, burnt, frozen
|
||||
FROM air_conditioner
|
||||
WHERE id = ?;
|
||||
3
konabot/plugins/air_conditioner/sql/insert_crash.sql
Normal file
3
konabot/plugins/air_conditioner/sql/insert_crash.sql
Normal file
@ -0,0 +1,3 @@
|
||||
-- 插入一条空调爆炸记录
|
||||
INSERT INTO air_conditioner_crash_log (id, crash_type, timestamp)
|
||||
VALUES (?, ?, CURRENT_TIMESTAMP);
|
||||
4
konabot/plugins/air_conditioner/sql/query_ac.sql
Normal file
4
konabot/plugins/air_conditioner/sql/query_ac.sql
Normal file
@ -0,0 +1,4 @@
|
||||
-- 查询空调状态
|
||||
SELECT *
|
||||
FROM air_conditioner
|
||||
WHERE id = ?;
|
||||
23
konabot/plugins/air_conditioner/sql/query_crash_and_rank.sql
Normal file
23
konabot/plugins/air_conditioner/sql/query_crash_and_rank.sql
Normal file
@ -0,0 +1,23 @@
|
||||
-- 从 air_conditioner_crash_log 表中获取指定 id 损坏的次数以及损坏次数的排名
|
||||
SELECT crash_count, crash_rank
|
||||
FROM (
|
||||
SELECT id,
|
||||
COUNT(*) AS crash_count,
|
||||
RANK() OVER (ORDER BY COUNT(*) DESC) AS crash_rank
|
||||
FROM air_conditioner_crash_log
|
||||
GROUP BY id
|
||||
) AS ranked_data
|
||||
WHERE id = ?
|
||||
-- 如果该 id 没有损坏记录,则返回 0 次损坏和对应的最后一名
|
||||
UNION
|
||||
SELECT 0 AS crash_count,
|
||||
(SELECT COUNT(DISTINCT id) + 1 FROM air_conditioner_crash_log) AS crash_rank
|
||||
FROM (
|
||||
SELECT DISTINCT id
|
||||
FROM air_conditioner_crash_log
|
||||
) AS ranked_data
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM air_conditioner_crash_log
|
||||
WHERE id = ?
|
||||
);
|
||||
13
konabot/plugins/air_conditioner/sql/query_peak.sql
Normal file
13
konabot/plugins/air_conditioner/sql/query_peak.sql
Normal file
@ -0,0 +1,13 @@
|
||||
-- 查询目前所有空调中的最高温度与最低温度与历史最高低温
|
||||
SELECT
|
||||
(SELECT MAX(temperature) FROM air_conditioner
|
||||
WHERE "on" = TRUE AND NOT frozen AND NOT burnt) AS max,
|
||||
|
||||
(SELECT MIN(temperature) FROM air_conditioner
|
||||
WHERE "on" = TRUE AND NOT frozen AND NOT burnt) AS min,
|
||||
|
||||
(SELECT MAX(temperature) FROM air_conditioner_log
|
||||
WHERE "on" = TRUE AND NOT frozen AND NOT burnt) AS his_max,
|
||||
|
||||
(SELECT MIN(temperature) FROM air_conditioner_log
|
||||
WHERE "on" = TRUE AND NOT frozen AND NOT burnt) AS his_min;
|
||||
10
konabot/plugins/air_conditioner/sql/update_ac.sql
Normal file
10
konabot/plugins/air_conditioner/sql/update_ac.sql
Normal file
@ -0,0 +1,10 @@
|
||||
-- 更新空调状态
|
||||
UPDATE air_conditioner
|
||||
SET "on" = ?, temperature = ?, burnt = ?, frozen = ?
|
||||
WHERE id = ?;
|
||||
|
||||
-- 插入日志记录(从更新后的数据获取)
|
||||
INSERT INTO air_conditioner_log (id, "on", temperature, burnt, frozen)
|
||||
SELECT id, "on", temperature, burnt, frozen
|
||||
FROM air_conditioner
|
||||
WHERE id = ?;
|
||||
@ -1,39 +1,39 @@
|
||||
import re
|
||||
|
||||
from nonebot import on_message
|
||||
from nonebot import get_plugin_config, on_message
|
||||
from nonebot.rule import Rule
|
||||
from nonebot_plugin_alconna import Reference, Reply, UniMsg
|
||||
|
||||
from nonebot.adapters import Event
|
||||
from pydantic import BaseModel
|
||||
|
||||
from konabot.common.permsys import require_permission
|
||||
|
||||
|
||||
matcher_fix = on_message()
|
||||
class Config(BaseModel):
|
||||
bilifetch_enabled_groups: list[int] = []
|
||||
|
||||
|
||||
config = get_plugin_config(Config)
|
||||
pattern = (
|
||||
r"^(?:(?:av|cv)\d+|BV[a-zA-Z0-9]{10})|"
|
||||
r"(?:b23\.tv|bili(?:22|23|33|2233)\.cn|\.bilibili\.com|QQ小程序(?:&#93;|]|\])哔哩哔哩).{0,500}"
|
||||
)
|
||||
|
||||
|
||||
@matcher_fix.handle()
|
||||
async def _(msg: UniMsg, event: Event):
|
||||
def _rule(msg: UniMsg) -> bool:
|
||||
to_search = msg.exclude(Reply, Reference).dump(json=True)
|
||||
to_search2 = msg.exclude(Reply, Reference).extract_plain_text()
|
||||
if not re.search(pattern, to_search) and not re.search(pattern, to_search2):
|
||||
return
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
matcher_fix = on_message(rule=Rule(_rule) & require_permission("bilifetch"))
|
||||
|
||||
|
||||
@matcher_fix.handle()
|
||||
async def _(event: Event):
|
||||
from nonebot_plugin_analysis_bilibili import handle_analysis
|
||||
|
||||
await handle_analysis(event)
|
||||
|
||||
# b_url: str
|
||||
# b_page: str | None
|
||||
# b_time: str | None
|
||||
#
|
||||
# from nonebot_plugin_analysis_bilibili.analysis_bilibili import extract as bilibili_extract
|
||||
#
|
||||
# b_url, b_page, b_time = bilibili_extract(to_search)
|
||||
# if b_url is None:
|
||||
# return
|
||||
#
|
||||
# await matcher_fix.send(await UniMessage().text(b_url).export())
|
||||
|
||||
|
||||
154
konabot/plugins/celeste_classic/__init__.py
Normal file
154
konabot/plugins/celeste_classic/__init__.py
Normal file
@ -0,0 +1,154 @@
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
import tempfile
|
||||
from typing import Any
|
||||
from loguru import logger
|
||||
from nonebot import on_message
|
||||
from pydantic import BaseModel
|
||||
|
||||
from nonebot.adapters import Event, Bot
|
||||
from nonebot_plugin_alconna import UniMessage, UniMsg
|
||||
from nonebot.adapters.onebot.v11.event import MessageEvent as OB11MessageEvent
|
||||
|
||||
from konabot.common.artifact import ArtifactDepends, ensure_artifact, register_artifacts
|
||||
from konabot.common.data_man import DataManager
|
||||
from konabot.common.path import BINARY_PATH, DATA_PATH
|
||||
|
||||
|
||||
arti_ccleste_wrap_linux = ArtifactDepends(
|
||||
url="https://github.com/Passthem-desu/pt-ccleste-wrap/releases/download/v0.1.5/ccleste-wrap",
|
||||
sha256="ba4118c6465d1ca1547cdd1bd11c6b9e6a6a98ea8967b55485aeb6b77bb7e921",
|
||||
target=BINARY_PATH / "ccleste-wrap",
|
||||
required_os="Linux",
|
||||
required_arch="x86_64",
|
||||
)
|
||||
arti_ccleste_wrap_windows = ArtifactDepends(
|
||||
url="https://github.com/Passthem-desu/pt-ccleste-wrap/releases/download/v0.1.5/ccleste-wrap.exe",
|
||||
sha256="7df382486a452485cdcf2115eabd7f772339ece470ab344074dc163fc7981feb",
|
||||
target=BINARY_PATH / "ccleste-wrap.exe",
|
||||
required_os="Windows",
|
||||
required_arch="AMD64",
|
||||
)
|
||||
|
||||
|
||||
register_artifacts(arti_ccleste_wrap_linux)
|
||||
register_artifacts(arti_ccleste_wrap_windows)
|
||||
|
||||
|
||||
class CelesteStatus(BaseModel):
|
||||
records: dict[str, str] = {}
|
||||
|
||||
|
||||
celeste_status = DataManager(CelesteStatus, DATA_PATH / "celeste-status.json")
|
||||
|
||||
|
||||
# ↓ 这里的 Type Hinting 是为了能 fit 进去 set[str | tuple[str, ...]]
|
||||
aliases: set[Any] = {"celeste", "蔚蓝", "爬山", "鳌太线"}
|
||||
ALLOW_CHARS = "wasdxc0123456789 \t\n\r"
|
||||
|
||||
|
||||
async def get_prev(evt: Event, bot: Bot) -> str | None:
|
||||
prev = None
|
||||
if isinstance(evt, OB11MessageEvent):
|
||||
if evt.reply is not None:
|
||||
prev = f"QQ:{bot.self_id}:" + str(evt.reply.message_id)
|
||||
else:
|
||||
for seg in evt.get_message():
|
||||
if seg.type == 'reply':
|
||||
msgid = seg.get('id')
|
||||
prev = f"QQ:{bot.self_id}:" + str(msgid)
|
||||
if prev is not None:
|
||||
async with celeste_status.get_data() as data:
|
||||
prev = data.records.get(prev)
|
||||
return prev
|
||||
|
||||
|
||||
async def match_celeste(evt: Event, bot: Bot, msg: UniMsg) -> bool:
|
||||
prev = await get_prev(evt, bot)
|
||||
text = msg.extract_plain_text().strip()
|
||||
if any(text.startswith(a) for a in aliases):
|
||||
return True
|
||||
if prev is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# cmd = on_command(cmd="celeste", aliases=aliases)
|
||||
cmd = on_message(rule=match_celeste)
|
||||
|
||||
|
||||
@cmd.handle()
|
||||
async def _(msg: UniMsg, evt: Event, bot: Bot):
|
||||
prev = await get_prev(evt, bot)
|
||||
actions = msg.extract_plain_text().strip()
|
||||
for alias in aliases:
|
||||
actions = actions.removeprefix(alias)
|
||||
actions = actions.strip()
|
||||
if len(actions) == 0:
|
||||
return
|
||||
if any((c not in ALLOW_CHARS) for c in actions):
|
||||
return
|
||||
|
||||
await ensure_artifact(arti_ccleste_wrap_linux)
|
||||
await ensure_artifact(arti_ccleste_wrap_windows)
|
||||
|
||||
bin: Path | None = None
|
||||
for arti in (
|
||||
arti_ccleste_wrap_linux,
|
||||
arti_ccleste_wrap_windows,
|
||||
):
|
||||
if not arti.is_corresponding_platform():
|
||||
continue
|
||||
bin = arti.target
|
||||
if not bin.exists():
|
||||
continue
|
||||
break
|
||||
|
||||
if bin is None:
|
||||
logger.warning("Celeste 模块没有找到该系统需要的二进制文件")
|
||||
return
|
||||
|
||||
if prev is not None:
|
||||
prev_append = ["-p", prev]
|
||||
else:
|
||||
prev_append = []
|
||||
|
||||
try:
|
||||
with tempfile.TemporaryDirectory() as _tempdir:
|
||||
tempdir = Path(_tempdir)
|
||||
gif_path = tempdir / "render.gif"
|
||||
cmd_celeste = [
|
||||
bin,
|
||||
"-a",
|
||||
actions,
|
||||
"-o",
|
||||
gif_path,
|
||||
] + prev_append
|
||||
logger.info(f"执行指令调用 celeste: CMD={cmd_celeste}")
|
||||
res = subprocess.run(cmd_celeste, timeout=5, capture_output=True)
|
||||
if res.returncode != 0:
|
||||
logger.warning(f"渲染 Celeste 时的输出不是 0 CODE={res.returncode} STDOUT={res.stdout} STDERR={res.stderr}")
|
||||
await UniMessage.text(f"渲染 Celeste 时出错啦!下面是输出:\n\n{res.stdout.decode()}{res.stderr.decode()}").send(evt, bot, at_sender=True)
|
||||
return
|
||||
if not gif_path.exists():
|
||||
logger.warning("没有找到 Celeste 渲染的文件")
|
||||
await UniMessage.text("渲染 Celeste 时出错啦!").send(evt, bot, at_sender=True)
|
||||
return
|
||||
gif_data = gif_path.read_bytes()
|
||||
except TimeoutError:
|
||||
logger.warning("在渲染 Celeste 时超时了")
|
||||
await UniMessage("渲染 Celeste 时超时了!请检查你的操作清单,不能太长").send(evt, bot, at_sender=True)
|
||||
return
|
||||
|
||||
receipt = await UniMessage.image(raw=gif_data).send(evt, bot)
|
||||
async with celeste_status.get_data() as data:
|
||||
if prev:
|
||||
actions = prev + "\n" + actions
|
||||
if isinstance(evt, OB11MessageEvent):
|
||||
for _msgid in receipt.msg_ids:
|
||||
msgid = _msgid["message_id"]
|
||||
data.records[f"QQ:{bot.self_id}:{msgid}"] = actions
|
||||
else:
|
||||
for msgid in receipt.msg_ids:
|
||||
data.records[f"DISCORD:{bot.self_id}:{msgid}"] = actions
|
||||
|
||||
277
konabot/plugins/fx_process/__init__.py
Normal file
277
konabot/plugins/fx_process/__init__.py
Normal file
@ -0,0 +1,277 @@
|
||||
import asyncio as asynkio
|
||||
from io import BytesIO
|
||||
|
||||
from inspect import signature
|
||||
import random
|
||||
|
||||
from konabot.common.longtask import DepLongTaskTarget
|
||||
from konabot.common.nb.exc import BotExceptionMessage
|
||||
from konabot.common.nb.extract_image import DepImageBytesOrNone
|
||||
from nonebot.adapters import Event as BaseEvent
|
||||
from nonebot import on_message, logger
|
||||
|
||||
from nonebot_plugin_alconna import (
|
||||
UniMessage,
|
||||
UniMsg
|
||||
)
|
||||
|
||||
from konabot.plugins.fx_process.fx_handle import ImageFilterStorage
|
||||
from konabot.plugins.fx_process.fx_manager import ImageFilterManager
|
||||
|
||||
from PIL import Image, ImageSequence
|
||||
|
||||
from konabot.plugins.fx_process.types import FilterItem, ImageRequireSignal, ImagesListRequireSignal, SenderInfo, StoredInfo
|
||||
|
||||
def try_convert_type(param_type, input_param, sender_info: SenderInfo | None = None) -> tuple[bool, object]:
    """Convert one raw string token into the type a filter parameter expects.

    Returns ``(consumed, value)``:
      * ``consumed`` is True only when ``input_param`` was actually used up,
        so the caller knows to pop it from its pending-argument queue.
      * signal types (single image, image list, sender info) produce a value
        WITHOUT consuming the token — the real object is injected later.

    The previous annotation ``tuple[bool, any]`` referenced the builtin
    ``any()`` function, not a type; ``object`` states the contract honestly.
    """
    converted_value = None
    try:
        if param_type is float:
            converted_value = float(input_param)
        elif param_type is int:
            converted_value = int(input_param)
        elif param_type is bool:
            # Accept common truthy spellings (English + Chinese); anything else is False.
            converted_value = input_param.lower() in ['true', '1', 'yes', '是', '开']
        elif param_type is Image.Image:
            # Placeholder: the caller substitutes an actual image later.
            converted_value = ImageRequireSignal()
            return False, converted_value
        elif param_type is SenderInfo:
            converted_value = sender_info
            return False, converted_value
        elif param_type == list[Image.Image]:
            converted_value = ImagesListRequireSignal()
            return False, converted_value
        elif param_type is str:
            if input_param is None:
                return False, None
            converted_value = str(input_param)
        else:
            # Unknown annotation: leave the token unconsumed.
            return False, None
    except Exception:
        # Conversion failed (e.g. int('abc'), or input_param is None):
        # report "not consumed" so the parameter's default can kick in.
        return False, None
    return True, converted_value
|
||||
|
||||
def prase_input_args(input_str: str, sender_info: SenderInfo | None = None) -> list[FilterItem]:
    """Parse the user's command string into a list of FilterItem.

    Filter specs are separated by ';' or newlines.  Within one spec the
    first whitespace-separated token is the filter name; the remaining
    tokens are matched greedily against the filter function's signature
    (skipping its first ``image`` parameter), falling back to each
    parameter's default when a token does not convert.
    """
    items: list[FilterItem] = []
    for part in input_str.replace('\n', ';').split(';'):
        part = part.strip()
        if not part:
            continue
        tokens = part.split()
        filter_name = tokens[0]
        if not ImageFilterManager.has_filter(filter_name):
            continue
        filter_func = ImageFilterManager.get_filter(filter_name)
        pending_args = tokens[1:]
        sig = signature(filter_func)
        # Hoisted out of the loop (previously rebuilt per parameter);
        # [1:] skips the leading image parameter.
        params = list(sig.parameters.values())[1:]
        func_args = []
        for param in params:
            # Try to convert the next raw token to this parameter's type; on
            # mismatch keep the token for the next parameter and use the default.
            raw = pending_args[0] if pending_args else None
            consumed, converted = try_convert_type(param.annotation, raw, sender_info)
            if consumed:
                pending_args.pop(0)
            if converted is None and param.default != param.empty:
                converted = param.default
            func_args.append(converted)
        items.append(FilterItem(name=filter_name, filter=filter_func, args=func_args))
    return items
|
||||
|
||||
def handle_filters_to_image(images: list[Image.Image], filters: list[FilterItem]) -> Image.Image:
    """Apply every filter in *filters* to ``images[0]`` in order, returning the result.

    Signal placeholders left in a filter's argument list by the parser are
    resolved here:
      * ImageRequireSignal  -> replaced by ``images[1]``, ``images[2]``, ... in order
      * ImagesListRequireSignal -> replaced by the whole ``images`` list

    Mutates ``images[0]`` with each filter's output.
    Raises BotExceptionMessage when a filter needs more images than supplied.
    """
    for filter_item in filters:
        logger.debug(f"{filter_item}")
        filter_func = filter_item.filter
        func_args = filter_item.args
        # If any argument is an ImageRequireSignal, substitute extra images for them.
        if any(isinstance(arg, ImageRequireSignal) for arg in func_args):
            # Replace each ImageRequireSignal with the next image from the list.
            actual_args = []
            img_signal_count = 1  # extra images start at images[1] (images[0] is the working image)
            for arg in func_args:
                if isinstance(arg, ImageRequireSignal):
                    if img_signal_count >= len(images):
                        raise BotExceptionMessage("图像数量不足,无法满足滤镜需求!")
                    actual_args.append(images[img_signal_count])
                    img_signal_count += 1
                else:
                    actual_args.append(arg)
            func_args = actual_args
        # If any argument is an ImagesListRequireSignal, pass the entire image list.
        if any(isinstance(arg, ImagesListRequireSignal) for arg in func_args):
            actual_args = []
            for arg in func_args:
                if isinstance(arg, ImagesListRequireSignal):
                    actual_args.append(images)
                else:
                    actual_args.append(arg)
            func_args = actual_args

        logger.debug(f"Applying filter: {filter_item.name} with args: {func_args}")

        images[0] = filter_func(images[0], *func_args)
    return images[0]
|
||||
|
||||
def copy_images_by_index(images: list[Image.Image], index: int) -> list[Image.Image]:
    """Return independent copies of *images* for one animation frame.

    Animated sources are seeked to frame ``index % n_frames`` before
    copying, so each frame of the output animation samples the matching
    (wrapped) frame of every input; static images are copied as-is.
    """
    copies = []
    for source in images:
        if getattr(source, "is_animated", False):
            # Wrap the requested index around this animation's length.
            source.seek(index % source.n_frames)
        copies.append(source.copy())
    return copies
|
||||
|
||||
def generate_image(images: list[Image.Image], filters: list[FilterItem]) -> None:
    """Run any leading generator filters (e.g. 生成图层 / 生成文本) in place.

    Pops generator filters off the front of *filters* and lets each one
    append to / mutate *images*.  Returns None — the previous annotation
    claimed ``Image.Image`` but no value was ever returned.
    """
    while filters and filters[0].name.strip() in ImageFilterManager.generate_filter_map:
        gen_filter = filters.pop(0)
        gen_func = gen_filter.filter
        # args[0] is the parsed list placeholder; the real list is passed explicitly.
        func_args = gen_filter.args[1:]
        gen_func(None, images, *func_args)
|
||||
|
||||
def save_or_load_image(images: list[Image.Image], filters: list[FilterItem], sender_info: SenderInfo) -> StoredInfo | None:
    """Handle leading 读取图像 ("load image") / 存入图像 ("store image") operations.

    Pops those filters off the front of *filters*, performing loads into
    *images* and stores from ``images[0]``.  Any further store/load entries
    deeper in the pipeline are stripped to prevent mid-pipeline storage ops.

    Returns:
        StoredInfo for the last store performed, or None.
    """
    stored_info = None
    if not filters:
        return None
    while filters and filters[0].name.strip() in ["读取图像", "存入图像"]:
        if filters[0].name.strip() == "读取图像":
            load_filter = filters.pop(0)
            path = load_filter.args[0] if load_filter.args else ""
            ImageFilterStorage.load_image(None, path, images, sender_info)
        elif filters[0].name.strip() == "存入图像":
            store_filter = filters.pop(0)
            # BUGFIX: previously indexed args[0] in the condition itself
            # (`args[0] if args[0] else ...`), raising IndexError when no
            # name was given.  Guard the list first, then fall back to a
            # random 5-digit name when the name is missing or falsy.
            if store_filter.args and store_filter.args[0]:
                name = store_filter.args[0]
            else:
                name = str(random.randint(10000, 99999))
            stored_info = ImageFilterStorage.store_image(images[0], name, sender_info)
    # Strip remaining store/load entries so they cannot run mid-pipeline.
    filters[:] = [f for f in filters if f.name.strip() not in ["读取图像", "存入图像"]]
    return stored_info
|
||||
|
||||
async def apply_filters_to_images(images: list[Image.Image], filters: list[FilterItem], sender_info: SenderInfo) -> BytesIO | StoredInfo:
    """Run the full filter pipeline over *images* and encode the result.

    Order of operations:
      1. leading store/load ops (may return a StoredInfo immediately),
      2. leading generator filters (may create images),
      3. the remaining filters, applied per frame for animated output.

    Returns:
        BytesIO with GIF (animated) or PNG (static) data, or StoredInfo when
        the pipeline only stored an image.

    Raises:
        BotExceptionMessage: when there is no image to process.
    """
    # Store/load and generator filters run first and may mutate both lists.
    stored_info = save_or_load_image(images, filters, sender_info)
    generate_image(images, filters)

    if stored_info and len(filters) <= 0:
        return stored_info

    if len(images) <= 0:
        raise BotExceptionMessage("没有可处理的图像!")

    # "动图" forces a static source to be rendered as an animation.
    frozen_to_move = any(
        filter_item.name == "动图"
        for filter_item in filters
    )
    static_fps = 10
    # Pick up the fps argument of the first "动图" filter, if given.
    if frozen_to_move:
        for filter_item in filters:
            if filter_item.name == "动图" and filter_item.args:
                try:
                    static_fps = int(filter_item.args[0])
                except Exception:
                    static_fps = 10
                break
    # Animated sources (or forced animation) are processed frame by frame.
    img = images[0]
    logger.debug("开始图像处理")
    output = BytesIO()
    if getattr(img, "is_animated", False) or frozen_to_move:
        frames = []
        append_images = []
        if getattr(img, "is_animated", False):
            logger.debug("处理动图帧")
        else:
            # Treat the static image as a short animation: duplicate 10 frames.
            logger.debug("处理静态图为多帧动图")
            append_images = [img.copy() for _ in range(10)]
            img.info['duration'] = int(1000 / static_fps)

        async def process_single_frame(frame_images: list[Image.Image], frame_idx: int) -> Image.Image:
            """Apply the filter chain to one frame's image set in a worker thread."""
            logger.debug(f"开始处理帧 {frame_idx}")
            result = await asynkio.to_thread(handle_filters_to_image, frame_images, filters)
            logger.debug(f"完成处理帧 {frame_idx}")
            return result

        # Process all frames concurrently.
        tasks = []
        all_frames = []
        for i, frame in enumerate(list(ImageSequence.Iterator(img)) + append_images):
            all_frames.append(frame.copy())
            images_copy = copy_images_by_index(images, i)
            tasks.append(process_single_frame(images_copy, i))

        # BUGFIX: with return_exceptions=False the first failing frame raised
        # out of gather and the fallback loop below was dead code.  Collect
        # exceptions instead so failed frames fall back to the original frame.
        frames = await asynkio.gather(*tasks, return_exceptions=True)

        # Replace any failed frame with its unprocessed original.
        for i, result in enumerate(frames):
            if isinstance(result, Exception):
                logger.error(f"帧 {i} 处理失败: {result}")
                frames[i] = all_frames[i]

        logger.debug("保存动图")
        frames[0].save(
            output,
            format="GIF",
            save_all=True,
            append_images=frames[1:],
            loop=img.info.get("loop", 0),
            disposal=img.info.get("disposal", 2),
            duration=img.info.get("duration", 100),
        )
        logger.debug("Animated image saved")
    else:
        img = handle_filters_to_image(images=images, filters=filters)
        img.save(output, format="PNG")
    logger.debug("Image processing completed")
    output.seek(0)
    return output
|
||||
|
||||
def is_fx_mentioned(evt: BaseEvent, msg: UniMsg) -> bool:
    """Matcher rule: fire when 'fx' appears in the first three characters of the text."""
    head = msg.extract_plain_text()[:3].lower()
    return "fx" in head
|
||||
|
||||
# Matcher that fires on any message whose first three characters contain "fx".
fx_on = on_message(rule=is_fx_mentioned)
|
||||
|
||||
@fx_on.handle()
async def _(msg: UniMsg, event: BaseEvent, target: DepLongTaskTarget, image_data: DepImageBytesOrNone):
    """Entry point for the 'fx ...' command: parse filters, run them, reply.

    Sends either a confirmation text (when the pipeline stored an image)
    or the processed image.  Silently returns when no filter args follow
    the 'fx' marker.
    """
    preload_imgs = []
    # Extract the attached image, if any; filters may also load/generate images,
    # so a missing image is not an error here.
    try:
        preload_imgs.append(Image.open(BytesIO(image_data)))
    except Exception:
        logger.info("No image found in message for FX processing.")
    args = msg.extract_plain_text().split()
    if len(args) < 2:
        return

    sender_info = SenderInfo(
        group_id=target.channel_id,
        qq_id=target.target_id
    )

    # Strip the leading "fx" marker before parsing the filter list.
    # (Removed dead commented-out early-return; an empty filter list is
    # handled inside apply_filters_to_images.)
    filters = prase_input_args(msg.extract_plain_text()[2:], sender_info=sender_info)
    output = await apply_filters_to_images(preload_imgs, filters, sender_info)
    if isinstance(output, StoredInfo):
        await fx_on.send(await UniMessage().text(f"图像已存为「{output.name}」!").export())
    elif isinstance(output, BytesIO):
        await fx_on.send(await UniMessage().image(raw=output).export())
|
||||
|
||||
50
konabot/plugins/fx_process/color_handle.py
Normal file
50
konabot/plugins/fx_process/color_handle.py
Normal file
@ -0,0 +1,50 @@
|
||||
from typing import Optional
|
||||
from PIL import ImageColor
|
||||
|
||||
class ColorHandle:
    """Color parsing helpers: CSS/rgb() strings plus Chinese color-name blending."""

    # Base RGB values for single-character Chinese color names.
    color_name_map = {
        "红": (255, 0, 0),
        "绿": (0, 255, 0),
        "蓝": (0, 0, 255),
        "黄": (255, 255, 0),
        "紫": (128, 0, 128),
        "黑": (0, 0, 0),
        "白": (255, 255, 255),
        "橙": (255, 165, 0),
        "粉": (255, 192, 203),
        "灰": (128, 128, 128),
        "青": (0, 255, 255),
        "靛": (75, 0, 130),
        "棕": (165, 42, 42),
        "浅": (200, 200, 200),
        "深": (50, 50, 50),
        "亮": (255, 255, 224),
        "暗": (47, 79, 79),
    }

    @staticmethod
    def set_or_blend_color(ori_color: Optional[tuple], target_color: tuple) -> tuple:
        """Return target_color when there is no base yet, else the channel-wise average."""
        if ori_color is None:
            return target_color
        return tuple((base + new) // 2 for base, new in zip(ori_color, target_color))

    @staticmethod
    def parse_color(color_str: str) -> tuple:
        """Parse a color string to an RGB tuple.

        Tries PIL's parser first (a bare "(r,g,b)" gets the 'rgb' prefix it
        expects), then blends any matching Chinese color-name characters,
        and finally falls back to white.
        """
        if color_str.startswith('(') and color_str.endswith(')'):
            color_str = 'rgb' + color_str
        try:
            return ImageColor.getrgb(color_str)
        except ValueError:
            pass
        # Drop the generic suffix "色" then blend every recognized name character.
        stripped = color_str.replace('色', '')
        blended = None
        for char_name, rgb in ColorHandle.color_name_map.items():
            if char_name in stripped:
                blended = ColorHandle.set_or_blend_color(blended, rgb)
        return blended if blended is not None else (255, 255, 255)  # default: white
|
||||
1421
konabot/plugins/fx_process/fx_handle.py
Normal file
1421
konabot/plugins/fx_process/fx_handle.py
Normal file
File diff suppressed because it is too large
Load Diff
88
konabot/plugins/fx_process/fx_manager.py
Normal file
88
konabot/plugins/fx_process/fx_manager.py
Normal file
@ -0,0 +1,88 @@
|
||||
from typing import Optional
|
||||
from konabot.plugins.fx_process.fx_handle import ImageFilterEmpty, ImageFilterImplement, ImageFilterStorage
|
||||
|
||||
class ImageFilterManager:
    """Registry mapping user-facing (Chinese) filter names to filter callables.

    ``filter_map`` holds per-image filters; ``generate_filter_map`` holds
    generator filters that create new images and are handled separately
    by the pipeline (see generate_image()).
    """

    # name -> callable(image, *args) applied to the working image
    filter_map = {
        "模糊": ImageFilterImplement.apply_blur,
        "马赛克": ImageFilterImplement.apply_mosaic,
        "轮廓": ImageFilterImplement.apply_contour,
        "锐化": ImageFilterImplement.apply_sharpen,
        "边缘增强": ImageFilterImplement.apply_edge_enhance,
        "浮雕": ImageFilterImplement.apply_emboss,
        "查找边缘": ImageFilterImplement.apply_find_edges,
        "平滑": ImageFilterImplement.apply_smooth,
        "反色": ImageFilterImplement.apply_invert,
        "黑白": ImageFilterImplement.apply_black_white,
        "阈值": ImageFilterImplement.apply_threshold,
        "对比度": ImageFilterImplement.apply_contrast,
        "亮度": ImageFilterImplement.apply_brightness,
        "色彩": ImageFilterImplement.apply_color,
        "色调": ImageFilterImplement.apply_to_color,
        "缩放": ImageFilterImplement.apply_resize,
        "波纹": ImageFilterImplement.apply_wave,
        "色键": ImageFilterImplement.apply_color_key,
        "暗角": ImageFilterImplement.apply_vignette,
        "发光": ImageFilterImplement.apply_glow,
        "RGB分离": ImageFilterImplement.apply_rgb_split,
        "光学补偿": ImageFilterImplement.apply_optical_compensation,
        "球面化": ImageFilterImplement.apply_spherize,
        "旋转": ImageFilterImplement.apply_rotate,
        "透视变换": ImageFilterImplement.apply_perspective_transform,
        "裁剪": ImageFilterImplement.apply_crop,
        "噪点": ImageFilterImplement.apply_noise,
        "平移": ImageFilterImplement.apply_translate,
        "拓展边缘": ImageFilterImplement.apply_expand_edges,
        "素描": ImageFilterImplement.apply_sketch,
        "叠加颜色": ImageFilterImplement.apply_gradient_overlay,
        "阴影": ImageFilterImplement.apply_shadow,
        "径向模糊": ImageFilterImplement.apply_radial_blur,
        "旋转模糊": ImageFilterImplement.apply_spin_blur,
        "方向模糊": ImageFilterImplement.apply_directional_blur,
        "边缘模糊": ImageFilterImplement.apply_focus_blur,
        "缩放模糊": ImageFilterImplement.apply_zoom_blur,
        "镜像": ImageFilterImplement.apply_mirror_half,
        "水平翻转": ImageFilterImplement.apply_flip_horizontal,
        "垂直翻转": ImageFilterImplement.apply_flip_vertical,
        "复制": ImageFilterImplement.copy_area,
        "晃动": ImageFilterImplement.apply_random_wiggle,
        # "动图" only marks animated rendering; the fps arg is read by the pipeline.
        "动图": ImageFilterEmpty.empty_filter_param,
        "像素抖动": ImageFilterImplement.apply_pixel_jitter,
        "描边": ImageFilterImplement.apply_stroke,
        "形状描边": ImageFilterImplement.apply_shape_stroke,
        "半调": ImageFilterImplement.apply_halftone,
        "JPEG损坏": ImageFilterImplement.apply_jpeg_damage,
        "设置通道": ImageFilterImplement.apply_set_channel,
        "设置遮罩": ImageFilterImplement.apply_set_mask,
        # image storage operations
        "存入图像": ImageFilterStorage.store_image,
        "读取图像": ImageFilterStorage.load_image,
        "暂存图像": ImageFilterStorage.temp_store_image,
        "交换图像": ImageFilterStorage.swap_image_index,
        "删除图像": ImageFilterStorage.delete_image_by_index,
        "选择图像": ImageFilterStorage.select_image_by_index,
        # multi-image operations
        "混合图像": ImageFilterImplement.apply_blend,
        "覆盖图像": ImageFilterImplement.apply_overlay,
        # generative-style overlay
        "覆加颜色": ImageFilterImplement.generate_solid,
    }

    # Generator filters: create new images instead of transforming one.
    generate_filter_map = {
        "生成图层": ImageFilterImplement.generate_empty,
        "生成文本": ImageFilterImplement.generate_text
    }

    @classmethod
    def get_filter(cls, name: str) -> Optional[callable]:
        """Return the callable registered under *name*, or None if unknown."""
        if name in cls.filter_map:
            return cls.filter_map[name]
        elif name in cls.generate_filter_map:
            return cls.generate_filter_map[name]
        else:
            return None

    @classmethod
    def has_filter(cls, name: str) -> bool:
        """True when *name* is registered in either map."""
        return name in cls.filter_map or name in cls.generate_filter_map
|
||||
|
||||
|
||||
344
konabot/plugins/fx_process/gradient.py
Normal file
344
konabot/plugins/fx_process/gradient.py
Normal file
@ -0,0 +1,344 @@
|
||||
import re
|
||||
from konabot.plugins.fx_process.color_handle import ColorHandle
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw
|
||||
from typing import List, Tuple, Dict, Optional
|
||||
|
||||
class GradientGenerator:
    """Gradient generator: builds linear / radial / diagonal gradient images."""

    def __init__(self):
        # NOTE(review): np is imported unconditionally at module level, so this
        # is effectively always True; kept as the runtime switch the PIL
        # fallback paths check.
        self.has_numpy = hasattr(np, '__version__')

    def parse_color_list(self, color_list_str: str) -> List[Dict]:
        """Parse a gradient color-list string.

        Args:
            color_list_str: e.g. '[rgb(255,0,0)|(0,0)+rgb(0,255,0)|(0,100)+rgb(0,0,255)|(50,100)]'

        Returns:
            list: dicts with 'color' (RGB tuple) and 'position' ((x, y) in 0..1)
        """
        color_nodes = []
        color_list_str = color_list_str.strip('[]').strip()
        matches = color_list_str.split('+')

        for single_str in matches:
            color_str = single_str.split('|')[0]
            pos_str = single_str.split('|')[1] if '|' in single_str else '0,0'

            color = ColorHandle.parse_color(color_str.strip())

            try:
                # Positions are "(x%, y%)"-style percentages, clamped to 0..100.
                pos_str = pos_str.replace('(', '').replace(')', '')
                x_str, y_str = pos_str.split(',')
                x_percent = float(x_str.strip().replace('%', ''))
                y_percent = float(y_str.strip().replace('%', ''))
                x_percent = max(0, min(100, x_percent))
                y_percent = max(0, min(100, y_percent))
            except:
                # NOTE(review): bare except — malformed positions fall back to (0, 0).
                x_percent = 0
                y_percent = 0

            color_nodes.append({
                'color': color,
                'position': (x_percent / 100.0, y_percent / 100.0)
            })

        if not color_nodes:
            # Default: red (top-left) to blue (bottom-right).
            color_nodes = [
                {'color': (255, 0, 0), 'position': (0, 0)},
                {'color': (0, 0, 255), 'position': (1, 1)}
            ]

        return color_nodes

    def create_gradient(self, width: int, height: int, color_nodes: List[Dict]) -> Image.Image:
        """Create a gradient image.

        Args:
            width: image width
            height: image height
            color_nodes: list of color-node dicts (see parse_color_list)

        Returns:
            Image.Image: the gradient image (solid for 1 node, linear for 2,
            radial nearest-two-node blend for 3+)
        """
        if len(color_nodes) == 1:
            return Image.new('RGB', (width, height), color_nodes[0]['color'])
        elif len(color_nodes) == 2:
            return self._create_linear_gradient(width, height, color_nodes)
        else:
            return self._create_radial_gradient(width, height, color_nodes)

    def _create_linear_gradient(self, width: int, height: int, color_nodes: List[Dict]) -> Image.Image:
        """Create a two-node linear gradient (numpy fast path, PIL fallback)."""
        color1 = color_nodes[0]['color']
        color2 = color_nodes[1]['color']
        pos1 = color_nodes[0]['position']
        pos2 = color_nodes[1]['position']

        if self.has_numpy:
            return self._create_linear_gradient_numpy(width, height, color1, color2, pos1, pos2)
        else:
            return self._create_linear_gradient_pil(width, height, color1, color2, pos1, pos2)

    def _create_linear_gradient_numpy(self, width: int, height: int,
                                      color1: Tuple, color2: Tuple,
                                      pos1: Tuple, pos2: Tuple) -> Image.Image:
        """Create a linear gradient using numpy."""
        # Normalized coordinate grid over the image.
        x = np.linspace(0, 1, width)
        y = np.linspace(0, 1, height)
        xx, yy = np.meshgrid(x, y)

        # Gradient direction vector.
        dx = pos2[0] - pos1[0]
        dy = pos2[1] - pos1[1]
        length_sq = dx * dx + dy * dy

        if length_sq > 0:
            # Projection parameter of each pixel onto the gradient axis, clamped.
            t = ((xx - pos1[0]) * dx + (yy - pos1[1]) * dy) / length_sq
            t = np.clip(t, 0, 1)
        else:
            t = np.zeros_like(xx)

        # Interpolate each channel.
        r = color1[0] + (color2[0] - color1[0]) * t
        g = color1[1] + (color2[1] - color1[1]) * t
        b = color1[2] + (color2[2] - color1[2]) * t

        # Assemble the image.
        gradient_array = np.stack([r, g, b], axis=-1).astype(np.uint8)
        return Image.fromarray(gradient_array)

    def _create_linear_gradient_pil(self, width: int, height: int,
                                    color1: Tuple, color2: Tuple,
                                    pos1: Tuple, pos2: Tuple) -> Image.Image:
        """Create a linear gradient with PIL only (used when numpy is unavailable)."""
        gradient = Image.new('RGB', (width, height))
        draw = ImageDraw.Draw(gradient)

        # Pick the gradient orientation from the node positions.
        if abs(pos1[0] - pos2[0]) < 0.01:  # vertical gradient
            y1 = int(pos1[1] * (height - 1))
            y2 = int(pos2[1] * (height - 1))

            if y2 < y1:
                y1, y2 = y2, y1
                color1, color2 = color2, color1

            if y2 > y1:
                for y in range(height):
                    if y <= y1:
                        fill_color = color1
                    elif y >= y2:
                        fill_color = color2
                    else:
                        ratio = (y - y1) / (y2 - y1)
                        r = int(color1[0] + (color2[0] - color1[0]) * ratio)
                        g = int(color1[1] + (color2[1] - color1[1]) * ratio)
                        b = int(color1[2] + (color2[2] - color1[2]) * ratio)
                        fill_color = (r, g, b)

                    draw.line([(0, y), (width, y)], fill=fill_color)
            else:
                draw.rectangle([0, 0, width, height], fill=color1)

        elif abs(pos1[1] - pos2[1]) < 0.01:  # horizontal gradient
            x1 = int(pos1[0] * (width - 1))
            x2 = int(pos2[0] * (width - 1))

            if x2 < x1:
                x1, x2 = x2, x1
                color1, color2 = color2, color1

            if x2 > x1:
                for x in range(width):
                    if x <= x1:
                        fill_color = color1
                    elif x >= x2:
                        fill_color = color2
                    else:
                        ratio = (x - x1) / (x2 - x1)
                        r = int(color1[0] + (color2[0] - color1[0]) * ratio)
                        g = int(color1[1] + (color2[1] - color1[1]) * ratio)
                        b = int(color1[2] + (color2[2] - color1[2]) * ratio)
                        fill_color = (r, g, b)

                    draw.line([(x, 0), (x, height)], fill=fill_color)
            else:
                draw.rectangle([0, 0, width, height], fill=color1)

        else:  # diagonal gradient (simplified as top-left to bottom-right)
            for y in range(height):
                for x in range(width):
                    distance = (x/width + y/height) / 2
                    r = int(color1[0] + (color2[0] - color1[0]) * distance)
                    g = int(color1[1] + (color2[1] - color1[1]) * distance)
                    b = int(color1[2] + (color2[2] - color1[2]) * distance)
                    draw.point((x, y), fill=(r, g, b))

        return gradient

    def _create_radial_gradient(self, width: int, height: int, color_nodes: List[Dict]) -> Image.Image:
        """Create a radial gradient (numpy multi-node blend, simplified fallback)."""
        if self.has_numpy and len(color_nodes) > 2:
            return self._create_radial_gradient_numpy(width, height, color_nodes)
        else:
            return self._create_simple_gradient(width, height, color_nodes)

    def _create_radial_gradient_numpy(self, width: int, height: int, color_nodes: List[Dict]) -> Image.Image:
        """Create a multi-color radial gradient with numpy.

        Each pixel blends the colors of its two nearest nodes, weighted by
        inverse distance.
        """
        # Normalized coordinate grid.
        x = np.linspace(0, 1, width)
        y = np.linspace(0, 1, height)
        xx, yy = np.meshgrid(x, y)

        # Node colors and positions as arrays.
        positions = np.array([node['position'] for node in color_nodes])
        colors = np.array([node['color'] for node in color_nodes])

        # Distance from every pixel to every node.
        distances = np.sqrt((xx[:, :, np.newaxis] - positions[np.newaxis, np.newaxis, :, 0]) ** 2 +
                            (yy[:, :, np.newaxis] - positions[np.newaxis, np.newaxis, :, 1]) ** 2)

        # Indices of the two closest nodes per pixel.
        sorted_indices = np.argsort(distances, axis=2)
        nearest_idx = sorted_indices[:, :, 0]
        second_idx = sorted_indices[:, :, 1]

        # Their colors...
        nearest_colors = colors[nearest_idx]
        second_colors = colors[second_idx]

        # ...and their distances, from which the blend weights follow.
        nearest_dist = np.take_along_axis(distances, np.expand_dims(nearest_idx, axis=2), axis=2)[:, :, 0]
        second_dist = np.take_along_axis(distances, np.expand_dims(second_idx, axis=2), axis=2)[:, :, 0]

        total_dist = nearest_dist + second_dist
        mask = total_dist > 0
        weight1 = np.zeros_like(nearest_dist)
        weight1[mask] = second_dist[mask] / total_dist[mask]
        weight2 = 1 - weight1

        # Interpolate each channel.
        r = nearest_colors[:, :, 0] * weight1 + second_colors[:, :, 0] * weight2
        g = nearest_colors[:, :, 1] * weight1 + second_colors[:, :, 1] * weight2
        b = nearest_colors[:, :, 2] * weight1 + second_colors[:, :, 2] * weight2

        gradient_array = np.stack([r, g, b], axis=-1).astype(np.uint8)
        return Image.fromarray(gradient_array)

    def _create_simple_gradient(self, width: int, height: int, color_nodes: List[Dict]) -> Image.Image:
        """Create a simplified gradient (fallback without numpy or for many nodes)."""
        gradient = Image.new('RGB', (width, height))
        draw = ImageDraw.Draw(gradient)

        if len(color_nodes) >= 2:
            # Use only the first and last colors for a simple two-color gradient.
            color1 = color_nodes[0]['color']
            color2 = color_nodes[-1]['color']

            # Infer the orientation from the node positions.
            x_positions = [node['position'][0] for node in color_nodes]
            y_positions = [node['position'][1] for node in color_nodes]

            if all(abs(x - x_positions[0]) < 0.01 for x in x_positions):
                # vertical gradient
                for y in range(height):
                    ratio = y / (height - 1) if height > 1 else 0
                    r = int(color1[0] + (color2[0] - color1[0]) * ratio)
                    g = int(color1[1] + (color2[1] - color1[1]) * ratio)
                    b = int(color1[2] + (color2[2] - color1[2]) * ratio)
                    draw.line([(0, y), (width, y)], fill=(r, g, b))
            else:
                # horizontal gradient
                for x in range(width):
                    ratio = x / (width - 1) if width > 1 else 0
                    r = int(color1[0] + (color2[0] - color1[0]) * ratio)
                    g = int(color1[1] + (color2[1] - color1[1]) * ratio)
                    b = int(color1[2] + (color2[2] - color1[2]) * ratio)
                    draw.line([(x, 0), (x, height)], fill=(r, g, b))
        else:
            # single color
            draw.rectangle([0, 0, width, height], fill=color_nodes[0]['color'])

        return gradient

    def create_simple_gradient(self, width: int, height: int,
                               start_color: Tuple, end_color: Tuple,
                               direction: str = 'vertical') -> Image.Image:
        """Create a simple two-color gradient.

        Args:
            width: image width
            height: image height
            start_color: starting color
            end_color: ending color
            direction: 'vertical', 'horizontal' or 'diagonal'

        Returns:
            Image.Image: the gradient image
        """
        if direction == 'vertical':
            return self._create_vertical_gradient(width, height, start_color, end_color)
        elif direction == 'horizontal':
            return self._create_horizontal_gradient(width, height, start_color, end_color)
        else:  # diagonal
            return self._create_diagonal_gradient(width, height, start_color, end_color)

    def _create_vertical_gradient(self, width: int, height: int,
                                  color1: Tuple, color2: Tuple) -> Image.Image:
        """Create a vertical gradient (one horizontal line per row)."""
        gradient = Image.new('RGB', (width, height))
        draw = ImageDraw.Draw(gradient)

        for y in range(height):
            ratio = y / (height - 1) if height > 1 else 0
            r = int(color1[0] + (color2[0] - color1[0]) * ratio)
            g = int(color1[1] + (color2[1] - color1[1]) * ratio)
            b = int(color1[2] + (color2[2] - color1[2]) * ratio)
            draw.line([(0, y), (width, y)], fill=(r, g, b))

        return gradient

    def _create_horizontal_gradient(self, width: int, height: int,
                                    color1: Tuple, color2: Tuple) -> Image.Image:
        """Create a horizontal gradient (one vertical line per column)."""
        gradient = Image.new('RGB', (width, height))
        draw = ImageDraw.Draw(gradient)

        for x in range(width):
            ratio = x / (width - 1) if width > 1 else 0
            r = int(color1[0] + (color2[0] - color1[0]) * ratio)
            g = int(color1[1] + (color2[1] - color1[1]) * ratio)
            b = int(color1[2] + (color2[2] - color1[2]) * ratio)
            draw.line([(x, 0), (x, height)], fill=(r, g, b))

        return gradient

    def _create_diagonal_gradient(self, width: int, height: int,
                                  color1: Tuple, color2: Tuple) -> Image.Image:
        """Create a diagonal gradient (numpy only; degrades to horizontal)."""
        if self.has_numpy:
            return self._create_diagonal_gradient_numpy(width, height, color1, color2)
        else:
            return self._create_horizontal_gradient(width, height, color1, color2)  # degrade to horizontal

    def _create_diagonal_gradient_numpy(self, width: int, height: int,
                                        color1: Tuple, color2: Tuple) -> Image.Image:
        """Create a diagonal (top-left to bottom-right) gradient with numpy."""
        x = np.linspace(0, 1, width)
        y = np.linspace(0, 1, height)
        xx, yy = np.meshgrid(x, y)

        # 0 at the top-left corner, 1 at the bottom-right.
        distance = (xx + yy) / 2.0

        r = color1[0] + (color2[0] - color1[0]) * distance
        g = color1[1] + (color2[1] - color1[1]) * distance
        b = color1[2] + (color2[2] - color1[2]) * distance

        gradient_array = np.stack([r, g, b], axis=-1).astype(np.uint8)
        return Image.fromarray(gradient_array)
|
||||
182
konabot/plugins/fx_process/image_storage.py
Normal file
182
konabot/plugins/fx_process/image_storage.py
Normal file
@ -0,0 +1,182 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from hashlib import md5
|
||||
import time
|
||||
|
||||
from nonebot import logger
|
||||
from nonebot_plugin_apscheduler import driver
|
||||
from konabot.common.path import DATA_PATH
|
||||
import os
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
|
||||
IMAGE_PATH = DATA_PATH / "temp" / "images"
|
||||
|
||||
@dataclass
class ImageResource:
    """One stored image file and its expiry."""
    # File name under IMAGE_PATH (md5 hash of the bytes + original extension).
    filename: str
    # Unix timestamp after which this resource may be cleaned up.
    expire: int
|
||||
|
||||
@dataclass
class StorageImage:
    """All stored copies of one named image, keyed by group then user."""
    # User-facing image name.
    name: str
    # {group_id: {qq_id: ImageResource}}
    resources: dict[str, dict[str, ImageResource]]
|
||||
|
||||
class ImageStorager:
    """File-backed pool of user-submitted images.

    ``images_pool`` maps a display name to a ``StorageImage`` whose
    resources are nested as ``{group_id: {qq_id: ImageResource}}``.
    The actual files live in ``IMAGE_PATH`` under a content-hash
    filename (see :meth:`save_image`).
    """

    images_pool: dict[str, StorageImage] = {}

    max_storage: int = 10 * 1024 * 1024  # max size of one image (10 MB)
    max_image_count: int = 200  # max number of stored images

    @staticmethod
    def init():
        """Create the image directory if it does not exist yet."""
        if not IMAGE_PATH.exists():
            IMAGE_PATH.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def delete_path_image(name: str):
        """Remove a file under IMAGE_PATH by its on-disk filename, if present."""
        resource_path = IMAGE_PATH / name
        if resource_path.exists():
            os.remove(resource_path)

    @staticmethod
    async def clear_all_image():
        """Delete every file in the image (temp) directory — a full reset."""
        for file in os.listdir(IMAGE_PATH):
            file_path = IMAGE_PATH / file
            if file_path.is_file():
                os.remove(file_path)

    @classmethod
    async def clear_expire_image(cls):
        """Drop expired resources; if too many remain, evict those expiring soonest."""
        remaining_images = []
        current_time = time.time()
        for name, storage_image in list(cls.images_pool.items()):
            for group_id, resources in list(storage_image.resources.items()):
                for qq_id, resource in list(resources.items()):
                    if resource.expire < current_time:
                        del storage_image.resources[group_id][qq_id]
                        # BUG FIX: files are saved under the hashed
                        # resource.filename (see save_image), not under the
                        # display name, so deleting by `name` never removed
                        # the file from disk.
                        cls.delete_path_image(resource.filename)
                    else:
                        remaining_images.append((name, group_id, qq_id, resource.expire))
            if not storage_image.resources:
                del cls.images_pool[name]
        # If more than max_image_count survive, evict the soonest-expiring.
        if len(remaining_images) > cls.max_image_count:
            remaining_images.sort(key=lambda x: x[3])  # sort by expire time
            to_delete = len(remaining_images) - cls.max_image_count
            for i in range(to_delete):
                name, group_id, qq_id, _ = remaining_images[i]
                resource = cls.images_pool[name].resources[group_id][qq_id]
                del cls.images_pool[name].resources[group_id][qq_id]
                # BUG FIX: same as above — delete the hashed filename.
                cls.delete_path_image(resource.filename)
        logger.info("过期图片清理完成")

    @classmethod
    def _add_to_pool(cls, filename: str, name: str, group_id: str, qq_id: str, expire: int = 36000):
        """Register `filename` under (name, group_id, qq_id) with a TTL in seconds."""
        expire_time = time.time() + expire
        if name not in cls.images_pool:
            cls.images_pool[name] = StorageImage(name=name, resources={})
        if group_id not in cls.images_pool[name].resources:
            cls.images_pool[name].resources[group_id] = {}
        cls.images_pool[name].resources[group_id][qq_id] = ImageResource(
            filename=filename, expire=expire_time
        )
        logger.debug(f"{cls.images_pool}")

    @classmethod
    def save_image(cls, image: bytes, name: str, group_id: str, qq_id: str) -> None:
        """Persist raw image bytes under an MD5-hash filename and register them.

        Raises:
            ValueError: if the image exceeds ``max_storage`` bytes.
        """
        if len(image) > cls.max_storage:
            raise ValueError("图片大小超过 10 MB 限制")
        # Content-addressed filename keeps the original extension.
        hash_name = md5(image).hexdigest()
        ext = os.path.splitext(name)[1]
        file_name = f"{hash_name}{ext}"
        full_path = IMAGE_PATH / file_name
        with open(full_path, "wb") as f:
            f.write(image)
        logger.debug(f"Image saved: {file_name} for group {group_id}, qq {qq_id}")
        cls._add_to_pool(file_name, name, group_id, qq_id)

    @classmethod
    def save_image_by_pil(cls, image: Image.Image, name: str, group_id: str, qq_id: str) -> None:
        """Serialize a PIL image (animated ones as GIF) and store via save_image."""
        img_byte_arr = BytesIO()
        if getattr(image, "is_animated", False):
            image.save(img_byte_arr, format="GIF", save_all=True, loop=0)
        else:
            image.save(img_byte_arr, format=image.format or "PNG")
        img_bytes = img_byte_arr.getvalue()
        cls.save_image(img_bytes, name, group_id, qq_id)

    @classmethod
    def load_image(cls, name: str, group_id: str, qq_id: str) -> Image.Image | None:
        """Open the stored image for (name, group_id, qq_id).

        Falls back to any resource in the same group when the exact qq_id
        has none; returns None when nothing matches.
        """
        logger.debug(f"Loading image: {name} for group {group_id}, qq {qq_id}")
        if name not in cls.images_pool:
            logger.debug(f"Image {name} not found in pool")
            return None
        if group_id not in cls.images_pool[name].resources:
            logger.debug(f"No resources for group {group_id} in image {name}")
            return None
        if qq_id not in cls.images_pool[name].resources[group_id]:
            first_qq_id = next(iter(cls.images_pool[name].resources[group_id]))
            qq_id = first_qq_id
        resource = cls.images_pool[name].resources[group_id][qq_id]
        resource_path = IMAGE_PATH / resource.filename
        logger.debug(f"Image path: {resource_path}")
        return Image.open(resource_path)
|
||||
class ImageStoragerManager:
    """Owns a background asyncio task that periodically purges expired images."""

    def __init__(self, interval: int = 300):  # default: run every 5 minutes
        self.interval = interval
        self._clear_task = None
        self._running = False

    async def start_auto_clear(self):
        """Wipe the image directory once, then launch the periodic cleanup task."""
        # Start from a clean slate before the loop takes over.
        await ImageStorager.clear_all_image()
        self._running = True
        self._clear_task = asyncio.create_task(self._auto_clear_loop())

        logger.info(f"自动清理任务已启动,间隔: {self.interval}秒")

    async def stop_auto_clear(self):
        """Cancel the cleanup task, waiting for it to actually finish."""
        if not self._clear_task:
            logger.warning("没有正在运行的自动清理任务")
            return
        self._running = False
        self._clear_task.cancel()
        try:
            await self._clear_task
        except asyncio.CancelledError:
            pass  # expected on cancellation
        logger.info("自动清理任务已停止")

    async def _auto_clear_loop(self):
        """Sleep `interval` seconds, clean expired images, repeat until stopped."""
        while self._running:
            try:
                await asyncio.sleep(self.interval)
                await ImageStorager.clear_expire_image()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the loop alive across sporadic failures.
                logger.error(f"定时清理失败: {e}")
|
||||
# Module-level manager: one cleanup task for the whole process.
image_manager = ImageStoragerManager(interval=300)  # clean up every 5 minutes


@driver.on_startup
async def init_image_storage():
    # Ensure the image directory exists, then start the periodic
    # expired-image cleanup task.
    ImageStorager.init()
    await image_manager.start_auto_clear()
|
||||
125
konabot/plugins/fx_process/math_helper.py
Normal file
125
konabot/plugins/fx_process/math_helper.py
Normal file
@ -0,0 +1,125 @@
|
||||
import cv2
|
||||
from nonebot import logger
|
||||
import numpy as np
|
||||
|
||||
from shapely.geometry import Polygon
|
||||
from shapely.ops import unary_union
|
||||
|
||||
def fix_with_shapely(contours: list) -> list:
    """Repair self-intersecting contours with Shapely and merge them into one.

    Each OpenCV contour is converted to a Polygon; invalid (self-intersecting)
    polygons are healed with ``buffer(0)``.  All polygons are then unioned;
    for a MultiPolygon result the largest piece is kept.

    Returns:
        A one-element list holding an ``(N, 1, 2)`` int32 contour, or a list
        with a zero-length contour when no valid geometry remains.
    """
    empty_result = [np.array([], dtype=np.int32).reshape(0, 1, 2)]
    fixed_polygons = []
    for contour in contours:
        # Flatten the OpenCV (N, 1, 2) layout into (N, 2) points.
        contour_array = contour.reshape(-1, 2)
        polygon = Polygon(contour_array)
        if not polygon.is_valid:
            # buffer(0) is the standard trick to heal self-intersections.
            polygon = polygon.buffer(0)
        fixed_polygons.append(polygon)
    if not fixed_polygons:
        logger.warning("No valid contours found after fixing with Shapely.")
        return empty_result
    # Merge all contours into a single geometry.
    merged_polygon = unary_union(fixed_polygons)
    if merged_polygon.geom_type == 'Polygon':
        merged_points = np.array(merged_polygon.exterior.coords, dtype=np.int32)
    elif merged_polygon.geom_type == 'MultiPolygon':
        largest = max(merged_polygon.geoms, key=lambda p: p.area)
        merged_points = np.array(largest.exterior.coords, dtype=np.int32)
    else:
        # BUG FIX: unary_union can yield other geometry types (e.g. an
        # empty GeometryCollection); previously this path left
        # `merged_points` unbound and raised NameError.
        logger.warning("No valid contours found after fixing with Shapely.")
        return empty_result
    return [merged_points.reshape(-1, 1, 2)]
|
||||
|
||||
def expand_contours(contours, stroke_width):
    """
    Expand contours outward by a fixed width (miter-style offset).

    Parameters:
        contours: list of OpenCV contours
        stroke_width: expansion width in pixels

    Returns:
        The expanded contour list, passed through fix_with_shapely to heal
        any self-intersections the offset introduces.
    """
    expanded_contours = []

    for cnt in contours:
        # Flatten the contour into an (N, 2) float point list.
        points = cnt.reshape(-1, 2).astype(np.float32)
        n = len(points)

        if n < 3:
            continue  # need at least 3 points to form a polygon

        expanded_points = []

        for i in range(n):
            # Current vertex plus its neighbours (wrapping around the ring).
            p_curr = points[i]
            p_prev = points[(i - 1) % n]
            p_next = points[(i + 1) % n]

            # Edge vectors into and out of the vertex.
            v1 = p_curr - p_prev  # incoming edge (prev -> curr)
            v2 = p_next - p_curr  # outgoing edge (curr -> next)

            # Edge lengths, used for normalisation.
            norm1 = np.linalg.norm(v1)
            norm2 = np.linalg.norm(v2)

            if norm1 == 0 or norm2 == 0:
                # Degenerate vertex (repeated point): push along the
                # normal of whichever edge is non-zero, if any.
                edge_dir = np.array([0, 0])
                if norm1 > 0:
                    edge_dir = v1 / norm1
                elif norm2 > 0:
                    edge_dir = v2 / norm2
                normal = np.array([-edge_dir[1], edge_dir[0]])
                expanded_point = p_curr + normal * stroke_width
            else:
                # Unit vectors along both edges.
                v1_norm = v1 / norm1
                v2_norm = v2 / norm2

                # Per-edge unit normals (assumed to point outward for this
                # contour winding — TODO confirm orientation).
                n1 = np.array([-v1_norm[1], v1_norm[0]])
                n2 = np.array([-v2_norm[1], v2_norm[0]])

                # Angle-bisector direction (sum of the two normals).
                bisector = n1 + n2

                # Length of the bisector vector.
                bisector_norm = np.linalg.norm(bisector)

                if bisector_norm == 0:
                    # Edges are parallel/antiparallel: either normal works.
                    expanded_point = p_curr + n1 * stroke_width
                else:
                    # Normalise the bisector.
                    bisector_normalized = bisector / bisector_norm

                    # Offset distance must grow as the corner sharpens
                    # (miter join): stroke_width / sin(angle / 2).
                    cos_angle = np.dot(v1_norm, v2_norm)
                    angle = np.arccos(np.clip(cos_angle, -1.0, 1.0))

                    if abs(np.pi - angle) < 1e-6:  # nearly a straight angle
                        # Effectively a straight line: plain offset.
                        offset_distance = stroke_width
                    else:
                        # Sharper corner -> larger offset along the bisector.
                        offset_distance = stroke_width / np.sin(angle / 2)

                    # Place the offset point along the bisector.
                    expanded_point = p_curr + bisector_normalized * offset_distance

            expanded_points.append(expanded_point)

        # Back to OpenCV contour layout, as integer coordinates.
        expanded_cnt = np.array(expanded_points, dtype=np.float32).reshape(-1, 1, 2)
        expanded_contours.append(expanded_cnt.astype(np.int32))

    # Heal any self-intersections introduced by the expansion.
    expanded_contours = fix_with_shapely(expanded_contours)

    return expanded_contours
|
||||
23
konabot/plugins/fx_process/types.py
Normal file
23
konabot/plugins/fx_process/types.py
Normal file
@ -0,0 +1,23 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class FilterItem:
    # A named image filter: `filter` is the callable to apply and `args`
    # its extra positional arguments.
    # NOTE(review): `callable` here is the builtin used as an annotation;
    # typing.Callable would be the precise type.
    name: str
    filter: callable
    args: list


class ImageRequireSignal:
    """Sentinel type: signals that a single image input is required."""
    pass


class ImagesListRequireSignal:
    """Sentinel type: signals that a list of images is required."""
    pass


@dataclass
class StoredInfo:
    # Reference to a stored image by its display name.
    name: str


@dataclass
class SenderInfo:
    # Identifies who sent a message and where (group id + QQ id).
    group_id: str
    qq_id: str
|
||||
174
konabot/plugins/handle_text/__init__.py
Normal file
174
konabot/plugins/handle_text/__init__.py
Normal file
@ -0,0 +1,174 @@
|
||||
from typing import cast
|
||||
import asyncio
|
||||
from loguru import logger
|
||||
from nonebot import on_command
|
||||
import nonebot
|
||||
from nonebot.adapters import Event, Bot
|
||||
from nonebot_plugin_alconna import UniMessage, UniMsg
|
||||
from nonebot.adapters.onebot.v11.event import MessageEvent as OB11MessageEvent
|
||||
from nonebot.adapters.onebot.v11.bot import Bot as OB11Bot
|
||||
from nonebot.adapters.onebot.v11.message import Message as OB11Message
|
||||
|
||||
from konabot.common.apis.ali_content_safety import AlibabaGreen
|
||||
from konabot.common.longtask import DepLongTaskTarget
|
||||
from konabot.common.render_error_message import render_error_message
|
||||
from konabot.plugins.handle_text.base import (
|
||||
PipelineRunner,
|
||||
TextHandlerEnvironment,
|
||||
register_text_handlers,
|
||||
)
|
||||
from konabot.plugins.handle_text.handlers.ai_handlers import THQwen
|
||||
from konabot.plugins.handle_text.handlers.encoding_handlers import (
|
||||
THAlign,
|
||||
THAlphaConv,
|
||||
THB64Hex,
|
||||
THBase64,
|
||||
THBaseConv,
|
||||
THCaesar,
|
||||
THMorse,
|
||||
THReverse,
|
||||
)
|
||||
from konabot.plugins.handle_text.handlers.random_handlers import THShuffle, THSorted
|
||||
from konabot.plugins.handle_text.handlers.unix_handlers import (
|
||||
THCat,
|
||||
THEcho,
|
||||
THFalse,
|
||||
THReplace,
|
||||
THRm,
|
||||
THTest,
|
||||
THTrue,
|
||||
)
|
||||
from konabot.plugins.handle_text.handlers.whitespace_handlers import (
|
||||
THLines,
|
||||
THLTrim,
|
||||
THRTrim,
|
||||
THSqueeze,
|
||||
THTrim,
|
||||
)
|
||||
|
||||
|
||||
TEXTFX_MAX_RUNTIME_SECONDS = 60
|
||||
_textfx_running_users: set[str] = set()
|
||||
|
||||
|
||||
def _get_textfx_user_key(evt: Event) -> str:
    """Build a stable per-user key for the textfx concurrency guard.

    Preference order: (self, group, user) -> (self, private, user) ->
    session id -> identity of the event object itself.
    """
    uid = getattr(evt, "user_id", None)
    sid = getattr(evt, "self_id", None)
    gid = getattr(evt, "group_id", None)
    if uid is not None:
        scope = "private" if gid is None else gid
        return f"{sid}:{scope}:{uid}"
    getter = getattr(evt, "get_session_id", None)
    if callable(getter):
        try:
            return f"session:{getter()}"
        except Exception:
            pass  # fall through to the identity-based key
    return f"event:{evt.__class__.__name__}:{id(evt)}"
|
||||
|
||||
|
||||
cmd = on_command(cmd="textfx", aliases={"处理文字", "处理文本"})


@cmd.handle()
async def _(msg: UniMsg, evt: Event, bot: Bot, target: DepLongTaskTarget):
    """Run a textfx pipeline script, using the replied-to message as stdin."""
    # Per-user concurrency guard: one running script at a time.
    user_key = _get_textfx_user_key(evt)
    if user_key in _textfx_running_users:
        await target.send_message("你当前已有一个 textfx 脚本正在运行,请等待它结束后再试。")
        return

    # Resolve the input stream: plain text of the message being replied to.
    istream = ""
    if isinstance(evt, OB11MessageEvent):
        if evt.reply is not None:
            istream = evt.reply.message.extract_plain_text()
        else:
            # Some clients deliver the reply as a "reply" segment instead.
            for seg in evt.get_message():
                if seg.type == "reply":
                    msgid = seg.get("id")
                    if msgid is not None:
                        msg2data = await cast(OB11Bot, bot).get_msg(message_id=msgid)
                        istream = OB11Message(
                            msg2data.get("message")
                        ).extract_plain_text()

    # BUG FIX: also strip the "处理文本" alias; previously it leaked into
    # the script text and was parsed as an unknown command.
    script = (
        msg.extract_plain_text()
        .removeprefix("textfx")
        .removeprefix("处理文字")
        .removeprefix("处理文本")
    )
    runner = PipelineRunner.get_runner()
    res = runner.parse_pipeline(script)

    if isinstance(res, str):  # a parse-error message
        await target.send_message(res)
        return

    env = TextHandlerEnvironment(is_trusted=False, event=evt)

    _textfx_running_users.add(user_key)
    try:
        results = await asyncio.wait_for(
            runner.run_pipeline(res, istream or None, env),
            timeout=TEXTFX_MAX_RUNTIME_SECONDS,
        )
    except asyncio.TimeoutError:
        rendered = await render_error_message(
            f"处理指令时出现问题:脚本执行超时(超过 {TEXTFX_MAX_RUNTIME_SECONDS} 秒)"
        )
        await target.send_message(rendered)
        return
    finally:
        _textfx_running_users.discard(user_key)

    # Any non-zero exit code aborts with a rendered error.
    for r in results:
        if r.code != 0:
            message = f"处理指令时出现问题:{r.ostream}"
            rendered = await render_error_message(message)
            await target.send_message(rendered)
            return

    ostreams = [r.ostream for r in results if r.ostream is not None]
    attachments = [r.attachment for r in results if r.attachment is not None]

    if ostreams:
        txt = "\n".join(ostreams)
        # NOTE(review): a falsy return from detect() is treated as
        # "content blocked" here — confirm AlibabaGreen.detect's contract.
        err = await AlibabaGreen.detect(txt)
        if not err:
            await target.send_message(
                "处理指令时出现问题:内容被拦截!请你检查你的内容是否合理!"
            )
            return
        await target.send_message(txt, at=False)

    for att in attachments:
        await target.send_message(UniMessage.image(raw=att), at=False)
|
||||
|
||||
|
||||
driver = nonebot.get_driver()


@driver.on_startup
async def _():
    # Register every built-in text handler exactly once at startup.
    register_text_handlers(
        THCat(),
        THEcho(),
        THRm(),
        THTrue(),
        THFalse(),
        THTest(),
        THShuffle(),
        THReplace(),
        THBase64(),
        THCaesar(),
        THReverse(),
        THBaseConv(),
        THAlphaConv(),
        THB64Hex(),
        THAlign(),
        THSorted(),
        THMorse(),
        THQwen(),
        THTrim(),
        THLTrim(),
        THRTrim(),
        THSqueeze(),
        THLines(),
    )
    logger.info(f"注册了 TextHandler:{PipelineRunner.get_runner().handlers}")
|
||||
587
konabot/plugins/handle_text/base.py
Normal file
587
konabot/plugins/handle_text/base.py
Normal file
@ -0,0 +1,587 @@
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import cast
|
||||
|
||||
from loguru import logger
|
||||
from nonebot.adapters import Event
|
||||
|
||||
|
||||
MAX_WHILE_ITERATIONS = 100
|
||||
|
||||
|
||||
@dataclass
class TextHandlerEnvironment:
    # Shared execution context for one pipeline run.
    is_trusted: bool  # whether the invoker is trusted (gates privileged handlers)
    event: Event | None = None  # originating bot event, if any
    buffers: dict[str, str] = field(default_factory=dict)  # named redirect buffers
|
||||
|
||||
|
||||
@dataclass
class TextHandleResult:
    # Result of one command: shell-style exit code (0 = success), textual
    # output, and an optional binary attachment (e.g. an image).
    code: int
    ostream: str | None
    attachment: bytes | None = None
|
||||
|
||||
|
||||
class TextHandler(ABC):
    """Base class for one textfx command.

    ``name`` is the primary command name; ``keywords`` are accepted aliases.
    """

    name: str = ""
    keywords: list[str] = []

    @abstractmethod
    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult: ...

    def __repr__(self) -> str:
        # Compact debug form: class, command name, joined alias string.
        return f"<{self.__class__.__name__}: {self.name} [{''.join(self.keywords)}]>"
|
||||
|
||||
|
||||
class TextHandlerSync(TextHandler):
    """TextHandler variant for blocking implementations.

    Subclasses implement handle_sync(); handle() runs it in a worker
    thread so the event loop is not blocked.
    """

    @abstractmethod
    def handle_sync(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult: ...

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        # to_thread forwards positional args, so no wrapper closure is needed.
        return await asyncio.to_thread(self.handle_sync, env, istream, args)
|
||||
|
||||
|
||||
@dataclass
class Redirect:
    # Output redirection into a named buffer: `>` overwrites, `>>` appends.
    target: str
    append: bool = False


@dataclass
class CommandNode:
    # One command invocation: resolved handler, its args, and redirections.
    name: str
    handler: TextHandler
    args: list[str]
    redirects: list[Redirect] = field(default_factory=list)


@dataclass
class PipelineNode:
    # Commands chained with `|`; `negate` flips the exit status (leading `!`).
    commands: list[CommandNode] = field(default_factory=list)
    negate: bool = False


@dataclass
class ConditionalPipeline:
    # A pipeline guarded by the `&&` / `||` operator before it (None = first).
    op: str | None
    pipeline: PipelineNode


@dataclass
class CommandGroup:
    # A `&&` / `||` chain of pipelines.
    chains: list[ConditionalPipeline] = field(default_factory=list)


@dataclass
class IfNode:
    # if <condition>; then <then_body> [else <else_body>] fi
    condition: CommandGroup
    then_body: "Script"
    else_body: "Script | None" = None


@dataclass
class WhileNode:
    # while <condition>; do <body> done
    condition: CommandGroup
    body: "Script"


@dataclass
class Script:
    # Top-level sequence of parsed statements.
    statements: list[CommandGroup | IfNode | WhileNode] = field(default_factory=list)


class TokenKind(Enum):
    # Lexer token categories.
    WORD = "word"
    OP = "op"


@dataclass
class Token:
    # A single lexer token: its kind and raw text value.
    kind: TokenKind
    value: str
|
||||
|
||||
|
||||
class PipelineRunner:
    """Parser and interpreter for the textfx mini shell language.

    Grammar (roughly): commands joined by ``|`` form pipelines, optionally
    negated with a leading ``!``; pipelines joined by ``&&`` / ``||`` form
    groups; statements (groups, ``if/then/else/fi``, ``while/do/done``) are
    separated by ``;``.  Parse errors are reported as plain strings.
    """

    handlers: list[TextHandler]

    # Process-wide singleton instance; see get_runner().
    _instance = None

    def __init__(self) -> None:
        self.handlers = []

    @staticmethod
    def get_runner() -> "PipelineRunner":
        """Return the shared PipelineRunner, creating it on first use.

        BUG FIX: the singleton used to be stashed inside
        ``PipelineRunner.__annotations__``, which abuses the annotation
        dict (meant only for type annotations); it now lives in a regular
        class attribute.
        """
        if PipelineRunner._instance is None:
            PipelineRunner._instance = PipelineRunner()
        return PipelineRunner._instance

    def register(self, handler: TextHandler):
        """Add a handler to the command lookup table."""
        self.handlers.append(handler)

    def _resolve_handler(self, cmd_name: str) -> TextHandler | str:
        """Return the handler whose name/keywords match, or an error string."""
        matched = [
            h for h in self.handlers if cmd_name == h.name or cmd_name in h.keywords
        ]
        if not matched:
            return f"不存在名为 {cmd_name} 的函数"
        if len(matched) > 1:
            logger.warning(
                f"指令能对应超过一个文本处理器 CMD={cmd_name} handlers={self.handlers}"
            )
        return matched[0]

    def tokenize(self, script: str) -> list[Token] | str:
        """Split a script into WORD/OP tokens; returns an error string on failure.

        Supports single/double quotes, backslash escapes (inside and outside
        quotes), and the operators | ; > >> && || !.
        """
        tokens: list[Token] = []
        buf = ""
        quote: str | None = None
        escape = False
        i = 0
        operators = {"|", ";", ">", "&&", "||", ">>", "!"}
        escape_map = {
            "n": "\n",
            "r": "\r",
            "t": "\t",
            "0": "\0",
            "a": "\a",
            "b": "\b",
            "f": "\f",
            "v": "\v",
            "\\": "\\",
            '"': '"',
            "'": "'",
        }

        def flush_word(force: bool = False):
            # Emit the accumulated word; `force` emits even an empty one so
            # quoted empty strings ("") survive as empty WORD tokens.
            nonlocal buf
            if buf or force:
                tokens.append(Token(TokenKind.WORD, buf))
                buf = ""

        while i < len(script):
            c = script[i]

            if quote is not None:
                if escape:
                    buf += escape_map.get(c, c)
                    escape = False
                elif c == "\\":
                    escape = True
                elif c == quote:
                    quote = None
                    flush_word(force=True)  # closing quote: force flush
                else:
                    buf += c
                i += 1
                continue

            if c in "'\"":
                quote = c
                i += 1
                continue

            if c.isspace():
                flush_word()
                i += 1
                continue

            # Two-character operators take precedence over one-character ones.
            two = script[i : i + 2]
            if two in operators:
                flush_word()
                tokens.append(Token(TokenKind.OP, two))
                i += 2
                continue

            if c in {"|", ";", ">", "!"}:
                flush_word()
                tokens.append(Token(TokenKind.OP, c))
                i += 1
                continue

            if c == "\\":
                # Unquoted escape: take the next character literally
                # (or map it through escape_map).
                if i + 1 < len(script):
                    i += 1
                    buf += escape_map.get(script[i], script[i])
                else:
                    buf += c
                i += 1
                continue

            buf += c
            i += 1

        if quote is not None:
            return "存在未闭合的引号"
        if escape:
            buf += "\\"

        flush_word()
        return tokens

    def parse_pipeline(self, script: str) -> Script | str:
        """Parse a script into a Script AST, or return an error string."""
        tokens = self.tokenize(script)
        if isinstance(tokens, str):
            return tokens
        if not tokens:
            return Script()

        pos = 0

        def peek(offset: int = 0) -> Token | None:
            idx = pos + offset
            return tokens[idx] if idx < len(tokens) else None

        def consume() -> Token:
            nonlocal pos
            tok = tokens[pos]
            pos += 1
            return tok

        def consume_if_op(value: str) -> bool:
            tok = peek()
            if tok is not None and tok.kind == TokenKind.OP and tok.value == value:
                consume()
                return True
            return False

        def consume_if_word(value: str) -> bool:
            tok = peek()
            if tok is not None and tok.kind == TokenKind.WORD and tok.value == value:
                consume()
                return True
            return False

        def expect_word(msg: str) -> Token | str:
            tok = peek()
            if tok is None or tok.kind != TokenKind.WORD:
                return msg
            return consume()

        def parse_command() -> CommandNode | str:
            # command := NAME arg* (('>'|'>>') BUFFER)*
            first = expect_word("缺少指令名")
            if isinstance(first, str):
                return first

            handler = self._resolve_handler(first.value)
            if isinstance(handler, str):
                return handler

            args: list[str] = []
            redirects: list[Redirect] = []

            while True:
                tok = peek()
                if tok is None:
                    break
                if tok.kind == TokenKind.OP and tok.value in {"|", ";", "&&", "||"}:
                    break
                if tok.kind == TokenKind.OP and tok.value in {">", ">>"}:
                    op_tok = consume()
                    target = expect_word("重定向操作符后面需要缓存名")
                    if isinstance(target, str):
                        return target
                    redirects.append(
                        Redirect(target=target.value, append=op_tok.value == ">>")
                    )
                    continue
                if tok.kind != TokenKind.WORD:
                    return f"无法解析的 token: {tok.value}"
                args.append(consume().value)

            return CommandNode(
                name=first.value,
                handler=handler,
                args=args,
                redirects=redirects,
            )

        def parse_pipe() -> PipelineNode | str:
            # pipeline := '!'* command ('|' command)*
            negate = False
            while consume_if_op("!"):
                negate = not negate

            pipeline = PipelineNode(negate=negate)
            command = parse_command()
            if isinstance(command, str):
                return command
            pipeline.commands.append(command)

            while True:
                tok = peek()
                if tok is None or tok.kind != TokenKind.OP or tok.value != "|":
                    break
                consume()
                next_command = parse_command()
                if isinstance(next_command, str):
                    return next_command
                pipeline.commands.append(next_command)

            return pipeline

        def parse_chain() -> CommandGroup | str:
            # group := pipeline (('&&'|'||') pipeline)*
            group = CommandGroup()
            first_pipeline = parse_pipe()
            if isinstance(first_pipeline, str):
                return first_pipeline
            group.chains.append(ConditionalPipeline(op=None, pipeline=first_pipeline))

            while True:
                tok = peek()
                if tok is None or tok.kind != TokenKind.OP or tok.value not in {"&&", "||"}:
                    break
                op = consume().value
                next_pipeline = parse_pipe()
                if isinstance(next_pipeline, str):
                    return next_pipeline
                group.chains.append(ConditionalPipeline(op=op, pipeline=next_pipeline))

            return group

        def parse_if() -> IfNode | str:
            # if := 'if' group ';'? 'then' script ('else' script)? 'fi'
            if not consume_if_word("if"):
                return "缺少 if"

            condition = parse_chain()
            if isinstance(condition, str):
                return condition

            consume_if_op(";")
            if not consume_if_word("then"):
                return "if 语句缺少 then"

            then_body = parse_script(stop_words={"else", "fi"})
            if isinstance(then_body, str):
                return then_body

            else_body: Script | None = None
            if consume_if_word("else"):
                else_body = parse_script(stop_words={"fi"})
                if isinstance(else_body, str):
                    return else_body

            if not consume_if_word("fi"):
                return "if 语句缺少 fi"

            return IfNode(condition=condition, then_body=then_body, else_body=else_body)

        def parse_while() -> WhileNode | str:
            # while := 'while' group ';'? 'do' script 'done'
            if not consume_if_word("while"):
                return "缺少 while"

            condition = parse_chain()
            if isinstance(condition, str):
                return condition

            consume_if_op(";")
            if not consume_if_word("do"):
                return "while 语句缺少 do"

            body = parse_script(stop_words={"done"})
            if isinstance(body, str):
                return body

            if not consume_if_word("done"):
                return "while 语句缺少 done"

            return WhileNode(condition=condition, body=body)

        def parse_statement() -> CommandGroup | IfNode | WhileNode | str:
            tok = peek()
            if tok is not None and tok.kind == TokenKind.WORD:
                if tok.value == "if":
                    return parse_if()
                if tok.value == "while":
                    return parse_while()
            return parse_chain()

        def parse_script(stop_words: set[str] | None = None) -> Script | str:
            parsed = Script()
            nonlocal pos

            while pos < len(tokens):
                tok = peek()
                if tok is None:
                    break

                if stop_words and tok.kind == TokenKind.WORD and tok.value in stop_words:
                    break

                # Skip stray statement separators.
                if tok.kind == TokenKind.OP and tok.value == ";":
                    consume()
                    continue

                statement = parse_statement()
                if isinstance(statement, str):
                    return statement
                parsed.statements.append(statement)

                tok = peek()
                if tok is not None and tok.kind == TokenKind.OP and tok.value == ";":
                    consume()

            return parsed

        parsed = parse_script()
        if isinstance(parsed, str):
            return parsed
        if pos != len(tokens):
            tok = tokens[pos]
            return f"无法解析的 token: {tok.value}"
        return parsed

    async def _execute_command(
        self,
        command: CommandNode,
        istream: str | None,
        env: TextHandlerEnvironment,
    ) -> TextHandleResult:
        """Run one command; apply its redirects (which swallow the ostream)."""
        logger.debug(
            f"Executing: {command.name} args={command.args} redirects={command.redirects}"
        )
        result = await command.handler.handle(env, istream, command.args)

        if result.code != 0:
            return result

        if command.redirects:
            content = result.ostream or ""
            for redirect in command.redirects:
                if redirect.append:
                    old_content = env.buffers.get(redirect.target, "")
                    env.buffers[redirect.target] = old_content + content
                else:
                    env.buffers[redirect.target] = content
            # Redirected output is consumed; only the attachment passes on.
            return TextHandleResult(code=0, ostream=None, attachment=result.attachment)

        return result

    async def _execute_pipeline(
        self,
        pipeline: PipelineNode,
        istream: str | None,
        env: TextHandlerEnvironment,
    ) -> TextHandleResult:
        """Run the `|`-chained commands, feeding each ostream to the next."""
        current_stream = istream
        last_result = TextHandleResult(code=0, ostream=None)

        for command in pipeline.commands:
            try:
                last_result = await self._execute_command(command, current_stream, env)
            except Exception as e:
                logger.error(f"Pipeline execution failed at {command.name}")
                logger.exception(e)
                return TextHandleResult(code=-1, ostream="处理流水线时出现 python 错误")

            if last_result.code != 0:
                if pipeline.negate:
                    # `!` inverts the failure into a (silent) success.
                    return TextHandleResult(code=0, ostream=None)
                return last_result
            current_stream = last_result.ostream

        if pipeline.negate:
            # `!` inverts the success into a failure.
            return TextHandleResult(code=1, ostream=None)
        return last_result

    async def _execute_group(
        self,
        group: CommandGroup,
        istream: str | None,
        env: TextHandlerEnvironment,
    ) -> TextHandleResult:
        """Run a `&&` / `||` chain with shell short-circuit semantics."""
        last_result = TextHandleResult(code=0, ostream=None)

        for chain in group.chains:
            should_run = True
            if chain.op == "&&":
                should_run = last_result.code == 0
            elif chain.op == "||":
                should_run = last_result.code != 0

            if should_run:
                last_result = await self._execute_pipeline(chain.pipeline, istream, env)

        return last_result

    async def _execute_if(
        self,
        if_node: IfNode,
        istream: str | None,
        env: TextHandlerEnvironment,
    ) -> TextHandleResult:
        """Run an if/else: exit code 0 of the condition selects the then-body."""
        condition_result = await self._execute_group(if_node.condition, istream, env)
        if condition_result.code == 0:
            results = await self.run_pipeline(if_node.then_body, istream, env)
        else:
            results = (
                await self.run_pipeline(if_node.else_body, istream, env)
                if if_node.else_body is not None
                else [TextHandleResult(code=0, ostream=None)]
            )
        return results[-1] if results else TextHandleResult(code=0, ostream=None)

    async def _execute_while(
        self,
        while_node: WhileNode,
        istream: str | None,
        env: TextHandlerEnvironment,
    ) -> TextHandleResult:
        """Run a bounded while loop; hitting the iteration cap is an error."""
        last_result = TextHandleResult(code=0, ostream=None)

        for _ in range(MAX_WHILE_ITERATIONS):
            condition_result = await self._execute_group(while_node.condition, istream, env)
            if condition_result.code != 0:
                return last_result

            body_results = await self.run_pipeline(while_node.body, istream, env)
            if body_results:
                last_result = body_results[-1]
                if last_result.code != 0:
                    return last_result

        return TextHandleResult(
            code=2,
            ostream=f"while 循环超过最大迭代次数限制({MAX_WHILE_ITERATIONS})",
        )

    async def run_pipeline(
        self,
        pipeline: Script,
        istream: str | None,
        env: TextHandlerEnvironment | None = None,
    ) -> list[TextHandleResult]:
        """Execute every statement of a Script, collecting one result each.

        Stops early (returning what it has so far) when a statement raises
        an unexpected Python exception.
        """
        if env is None:
            env = TextHandlerEnvironment(is_trusted=False, event=None, buffers={})

        results: list[TextHandleResult] = []

        for statement in pipeline.statements:
            try:
                if isinstance(statement, IfNode):
                    results.append(await self._execute_if(statement, istream, env))
                elif isinstance(statement, WhileNode):
                    results.append(await self._execute_while(statement, istream, env))
                else:
                    results.append(await self._execute_group(statement, istream, env))
            except Exception as e:
                logger.error(f"Pipeline execution failed: {e}")
                logger.exception(e)
                results.append(
                    TextHandleResult(code=-1, ostream="处理流水线时出现 python 错误")
                )
                return results

        return results
|
||||
|
||||
|
||||
def register_text_handlers(*handlers: TextHandler):
    """Register each given handler on the process-wide PipelineRunner."""
    runner = PipelineRunner.get_runner()
    for h in handlers:
        runner.register(h)
|
||||
61
konabot/plugins/handle_text/handlers/ai_handlers.py
Normal file
61
konabot/plugins/handle_text/handlers/ai_handlers.py
Normal file
@ -0,0 +1,61 @@
|
||||
from typing import Any, cast
|
||||
from konabot.common.llm import get_llm
|
||||
from konabot.common.permsys import perm_manager
|
||||
from konabot.plugins.handle_text.base import (
|
||||
TextHandler,
|
||||
TextHandlerEnvironment,
|
||||
TextHandleResult,
|
||||
)
|
||||
|
||||
|
||||
class THQwen(TextHandler):
    """textfx handler that forwards the input/arguments to the Qwen LLM."""

    name = "qwen"

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        # Permission-gated: requires the "textfx.qwen" permission node.
        pm = perm_manager()
        if env.event is None or not await pm.check_has_permission(
            env.event, "textfx.qwen"
        ):
            return TextHandleResult(
                code=1,
                ostream="你或当前环境没有使用 qwen 的权限。如有疑问请联系管理员",
            )

        llm = get_llm()
        messages = []

        # Piped input and literal arguments each become a user message.
        if istream is not None:
            messages.append({"role": "user", "content": istream})
        if len(args) > 0:
            message = " ".join(args)
            messages.append(
                {
                    "role": "user",
                    "content": message,
                }
            )
        if len(messages) == 0:
            # Neither piped input nor args: show usage.
            return TextHandleResult(
                code=1,
                ostream="使用方法:qwen <提示词>",
            )

        # Prepend a system prompt asking for short, plain-text replies.
        messages = [
            {
                "role": "system",
                "content": "除非用户要求,请尽可能短点回答。另外,当前环境不支持 Markdown 语法,如果可以,请使用纯文本回答",
            }
        ] + messages
        result = await llm.chat(cast(Any, messages))
        content = result.content
        if content is None:
            return TextHandleResult(
                code=500,
                ostream="问 AI 的时候发生了未知的错误",
            )
        return TextHandleResult(
            code=0,
            ostream=content,
        )
|
||||
346
konabot/plugins/handle_text/handlers/encoding_handlers.py
Normal file
346
konabot/plugins/handle_text/handlers/encoding_handlers.py
Normal file
@ -0,0 +1,346 @@
|
||||
import base64
|
||||
from konabot.plugins.handle_text.base import (
|
||||
TextHandleResult,
|
||||
TextHandler,
|
||||
TextHandlerEnvironment,
|
||||
)
|
||||
|
||||
|
||||
class THBase64(TextHandler):
    """Encode/decode text as Base64.

    Usage: ``b64 <encode|decode> [encoding, default utf-8] [text]``;
    piped input takes precedence over inline text arguments.
    """

    name = "b64"
    keywords = ["base64"]

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        if istream is None and not args:
            return TextHandleResult(
                1, "用法:b64 <encode|decode> [编码, 默认utf-8] [文本]"
            )

        mode = args[0].lower() if args else "encode"
        charset = args[1] if len(args) > 1 else "utf-8"

        # Determine the payload: pipe wins over inline arguments.
        if istream is not None:
            payload = istream
        elif len(args) > 2:
            payload = " ".join(args[2:])
        else:
            payload = ""
        if not payload:
            return TextHandleResult(1, "输入文本为空")

        try:
            if mode == "encode":
                out = base64.b64encode(payload.encode(charset, "replace")).decode("ascii")
            else:
                out = base64.b64decode(payload.encode("ascii")).decode(charset, "replace")
            return TextHandleResult(0, out)
        except Exception as e:
            return TextHandleResult(1, f"Base64 转换失败: {str(e)}")
|
||||
|
||||
|
||||
class THCaesar(TextHandler):
    """Apply a Caesar (ROT-n) shift to ASCII letters; other characters pass through.

    Usage: ``caesar <shift> [text]`` — shift defaults to 13 (ROT13); text may
    also be piped in.
    """

    name = "caesar"
    keywords = ["凯撒", "rot"]

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        # Bug fix: int(args[0]) previously raised an uncaught ValueError when
        # the first argument was not an integer; report usage instead.
        if args:
            try:
                shift = int(args[0])
            except ValueError:
                return TextHandleResult(1, "用法:caesar <位移(整数)> [文本]")
        else:
            shift = 13
        text = (
            istream
            if istream is not None
            else (" ".join(args[1:]) if len(args) > 1 else "")
        )

        def _shift(char: str) -> str:
            # Only letters rotate; Python's % keeps the result in range even
            # for negative shifts. NOTE(review): isalpha() is also true for
            # non-ASCII letters, which then wrap oddly — same as the original.
            if not char.isalpha():
                return char
            start = ord("A") if char.isupper() else ord("a")
            return chr((ord(char) - start + shift) % 26 + start)

        return TextHandleResult(0, "".join(_shift(c) for c in text))
|
||||
|
||||
|
||||
class THReverse(TextHandler):
    """Reverse the input text (pipe input preferred over arguments)."""

    name = "reverse"
    keywords = ["rev", "反转"]

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        if istream is not None:
            source = istream
        else:
            source = " ".join(args) if args else ""
        return TextHandleResult(0, source[::-1])
|
||||
|
||||
|
||||
class THMorse(TextHandler):
    """Decode Morse code into text.

    Usage: ``morse <en|jp> <code>`` — symbols separated by spaces, words by
    ``/``. Example: ``morse en .... . .-.. .-.. ---``.
    """

    name = "morse"
    keywords = ["摩斯", "decode_morse"]

    # International Morse code table (partial).
    MORSE_EN = {
        ".-": "A",
        "-...": "B",
        "-.-.": "C",
        "-..": "D",
        ".": "E",
        "..-.": "F",
        "--.": "G",
        "....": "H",
        "..": "I",
        ".---": "J",
        "-.-": "K",
        ".-..": "L",
        "--": "M",
        "-.": "N",
        "---": "O",
        ".--.": "P",
        "--.-": "Q",
        ".-.": "R",
        "...": "S",
        "-": "T",
        "..-": "U",
        "...-": "V",
        ".--": "W",
        "-..-": "X",
        "-.--": "Y",
        "--..": "Z",
        "-----": "0",
        ".----": "1",
        "..---": "2",
        "...--": "3",
        "....-": "4",
        ".....": "5",
        "-....": "6",
        "--...": "7",
        "---..": "8",
        "----.": "9",
        "/": " ",
    }

    # Japanese Wabun code table.
    MORSE_JP = {
        "--.--": "ア",
        ".-": "イ",
        "..-": "ウ",
        "-.---": "エ",
        ".-...": "オ",
        ".-..": "カ",
        "-.-..": "キ",
        "...-": "ク",
        "-.--": "ケ",
        "----": "コ",
        "-.-.-": "サ",
        "--.-.": "シ",
        "---.-": "ス",
        ".---.": "セ",
        "---.": "ソ",
        "-.": "タ",
        "..-.": "チ",
        ".--.": "ツ",
        ".-.--": "テ",
        "..-..": "ト",
        ".-.": "ナ",
        "-.-.": "ニ",
        "....": "ヌ",
        "--.-": "ネ",
        "..--": "ノ",
        "-...": "ハ",
        "--..-": "ヒ",
        "--..": "フ",
        ".": "ヘ",
        "-..": "ホ",
        "-..-": "マ",
        "..-.-": "ミ",
        "-": "ム",
        "-...-": "メ",
        "-..-.": "モ",
        ".--": "ヤ",
        "-..--": "ユ",
        "--": "ヨ",
        "...": "ラ",
        "--.": "リ",
        "-.--.": "ル",
        "---": "レ",
        ".-.-": "ロ",
        "-.-": "ワ",
        ".-..-": "ヰ",
        ".--..": "ヱ",
        ".---": "ヲ",
        ".-.-.": "ン",
        "-..-.--.": "ッ",
        "-..-.--": "ャ",
        "-..--..--": "ュ",
        "-..---": "ョ",
        "-..---.--": "ァ",
        "-..-.-": "ィ",
        "-..-..-": "ゥ",
        "-..--.---": "ェ",
        "-..-.-...": "ォ",
        "-..-.-..": "ヵ",
        "-..--.--": "ヶ",
        "..": "゛",
        "..--.": "゜",
        ".--.-": "ー",
        ".-.-.-": "、",
        ".-.-..": "。",
        "-.--.-": "(",
        ".-..-.": ")",
    }

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        if istream is None and not args:
            return TextHandleResult(
                1, "用法:morse <en|jp> <电码>。使用空格分隔字符,/ 分隔单词。"
            )

        mode = args[0].lower() if args else "en"
        if istream is not None:
            code_text = istream
        else:
            code_text = " ".join(args[1:]) if len(args) > 1 else ""

        if not code_text:
            return TextHandleResult(1, "请输入电码内容")

        # Pick the table: Japanese Wabun or international.
        table = self.MORSE_JP if mode == "jp" else self.MORSE_EN

        try:
            pieces: list[str] = []
            # Split on spaces, ignoring truly empty slots; whitespace-only
            # tokens (e.g. a stray newline) strip to "" and decode as "[?]",
            # matching the original behavior.
            for token in code_text.split(" "):
                if not token:
                    continue
                symbol = token.strip()
                pieces.append(table.get(symbol, "[?]"))
            return TextHandleResult(0, "".join(pieces))
        except Exception as e:
            return TextHandleResult(1, f"摩斯电码解析出错: {str(e)}")
|
||||
|
||||
|
||||
class THBaseConv(TextHandler):
    """Convert an integer between numeral bases.

    Usage: ``baseconv <src_base> <dst_base> [text]`` — the number may also be
    piped in; the two base arguments are always required.
    """

    name = "baseconv"
    keywords = ["进制转换"]

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        # Bug fix: the original guard was `len(args) < 2 and istream is None`,
        # so piping a number in without the two base arguments crashed with an
        # IndexError on args[0]. Both bases are mandatory.
        if len(args) < 2:
            return TextHandleResult(1, "用法:baseconv <原进制> <目标进制> [文本]")

        val_str = istream if istream is not None else "".join(args[2:])

        try:
            src_base = int(args[0])
            dst_base = int(args[1])
            # Parse via a base-10 intermediate, then re-emit in the target base.
            decimal_val = int(val_str, src_base)

            if dst_base == 10:
                res = str(decimal_val)
            elif dst_base == 16:
                # Bug fix: hex(x)[2:] mangled negative values ("-0x1a" -> "x1a");
                # format() keeps the sign.
                res = format(decimal_val, "x")
            else:
                chars = "0123456789abcdefghijklmnopqrstuvwxyz"
                # Bug fix: dst_base == 1 previously looped forever (temp //= 1
                # never shrinks); reject unsupported bases up front.
                if not 2 <= dst_base <= len(chars):
                    return TextHandleResult(1, f"转换失败: 不支持的目标进制 {dst_base}")
                sign = "-" if decimal_val < 0 else ""
                temp = abs(decimal_val)
                res = ""
                while temp > 0:
                    res = chars[temp % dst_base] + res
                    temp //= dst_base
                # Bug fix: negative inputs used to silently become "0".
                res = sign + (res or "0")

            return TextHandleResult(0, res.upper() if dst_base == 16 else res)
        except Exception as e:
            return TextHandleResult(1, f"转换失败: {str(e)}")
|
||||
|
||||
|
||||
class THAlphaConv(TextHandler):
    """Convert between a custom positional alphabet (base-N) and hexadecimal.

    Usage: ``alphaconv <alphabet> <to_hex|from_hex> [text]``.
    """

    name = "alphaconv"
    keywords = ["字母表转换"]

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        if len(args) < 2:
            return TextHandleResult(1, "用法:alphaconv <字母表> <to_hex|from_hex> [文本]")

        alphabet = args[0]
        mode = args[1].lower()
        base = len(alphabet)
        payload = istream if istream is not None else "".join(args[2:])

        try:
            if mode == "to_hex":
                # Custom alphabet -> integer -> hex. index() raises ValueError
                # for characters outside the alphabet, reported below.
                number = 0
                for symbol in payload:
                    number = number * base + alphabet.index(symbol)
                return TextHandleResult(0, hex(number)[2:])
            # Any other mode: hex -> integer -> custom alphabet.
            number = int(payload, 16)
            digits = ""
            while number > 0:
                digits = alphabet[number % base] + digits
                number //= base
            return TextHandleResult(0, digits or alphabet[0])
        except Exception as e:
            return TextHandleResult(1, f"字母表转换失败: {str(e)}")
|
||||
|
||||
|
||||
class THB64Hex(TextHandler):
    """Convert the same bytes between hexadecimal and Base64 representations.

    Usage: ``b64hex <enc|dec> [text]`` — ``enc``: hex -> Base64;
    ``dec`` (default): Base64 -> hex.
    """

    name = "b64hex"

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        mode = args[0] if args else "dec"
        payload = istream if istream is not None else "".join(args[1:])

        try:
            if mode == "enc":
                result = base64.b64encode(bytes.fromhex(payload)).decode()
            else:
                result = base64.b64decode(payload).hex()
            return TextHandleResult(0, result)
        except Exception as e:
            return TextHandleResult(1, f"Base64-Hex 转换失败: {str(e)}")
|
||||
|
||||
|
||||
class THAlign(TextHandler):
    """Re-flow text into fixed-size groups: n characters per group, m groups per line.

    Usage: ``align <n> <m> [text]`` — e.g. ``align 2 8`` groups 2 characters,
    8 groups per line (like ``0011 2233 ...``). Defaults: n=2, m=8.
    """

    name = "align"
    keywords = ["format", "排版"]

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        # Bug fix: non-numeric or non-positive n/m previously raised an
        # uncaught ValueError (from int(), or a zero step in range()).
        try:
            n = int(args[0]) if len(args) > 0 else 2
            m = int(args[1]) if len(args) > 1 else 8
        except ValueError:
            return TextHandleResult(1, "用法:align <每组长度> <每行组数> [文本]")
        if n <= 0 or m <= 0:
            return TextHandleResult(1, "用法:align <每组长度> <每行组数> [文本]")

        text = istream if istream is not None else "".join(args[2:])
        # Strip all existing whitespace so the text is re-flowed from scratch.
        text = "".join(text.split())

        chunks = [text[i:i + n] for i in range(0, len(text), n)]
        lines = [" ".join(chunks[i:i + m]) for i in range(0, len(chunks), m)]
        return TextHandleResult(0, "\n".join(lines))
|
||||
37
konabot/plugins/handle_text/handlers/random_handlers.py
Normal file
37
konabot/plugins/handle_text/handlers/random_handlers.py
Normal file
@ -0,0 +1,37 @@
|
||||
import random
|
||||
from konabot.plugins.handle_text.base import TextHandleResult, TextHandler, TextHandlerEnvironment, TextHandlerSync
|
||||
|
||||
|
||||
class THShuffle(TextHandler):
    """Randomly shuffle the characters of the input text.

    Input comes from the pipe, or from the first argument.
    """

    name: str = "shuffle"
    keywords: list = ["打乱"]

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        if istream is not None:
            source = istream
        elif not args:
            return TextHandleResult(1, "使用方法:打乱 <待打乱的文本>,或者使用管道符传入待打乱的文本")
        else:
            source = args[0]

        chars = list(source)
        random.shuffle(chars)
        return TextHandleResult(0, "".join(chars))
|
||||
|
||||
|
||||
class THSorted(TextHandlerSync):
    """Sort the characters of the input text in code-point order.

    Input comes from the pipe, or from the first argument.
    """

    name = "sort"
    keywords = ["排序"]

    def handle_sync(self, env: TextHandlerEnvironment, istream: str | None, args: list[str]) -> TextHandleResult:
        if istream is not None:
            w = istream
        elif len(args) == 0:
            # Bug fix: the usage text was copy-pasted from the shuffle handler
            # and said "待打乱"(to shuffle) where it meant "待排序"(to sort).
            return TextHandleResult(1, "使用方法:排序 <待排序的文本>,或者使用管道符传入待排序的文本")
        else:
            w = args[0]

        # sorted() over a str yields its characters, so the list spread the
        # original used is unnecessary.
        return TextHandleResult(0, "".join(sorted(w)))
|
||||
|
||||
161
konabot/plugins/handle_text/handlers/unix_handlers.py
Normal file
161
konabot/plugins/handle_text/handlers/unix_handlers.py
Normal file
@ -0,0 +1,161 @@
|
||||
import re
|
||||
|
||||
from konabot.plugins.handle_text.base import (
|
||||
TextHandleResult,
|
||||
TextHandler,
|
||||
TextHandlerEnvironment,
|
||||
)
|
||||
|
||||
|
||||
class THEcho(TextHandler):
    """Echo the arguments back; stdin is deliberately ignored.

    NOTE(review): arguments are joined with newlines, whereas Unix echo joins
    them with spaces — confirm this divergence is intentional.
    """

    name = "echo"

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        # With no arguments this yields an empty output line.
        return TextHandleResult(0, "\n".join(args))
|
||||
|
||||
|
||||
class THCat(TextHandler):
    """Concatenate named buffers and/or stdin ("-"), joined by newlines."""

    name = "cat"

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        # With no arguments, behave like plain `cat`: pass stdin through.
        if not args:
            if istream is None:
                return TextHandleResult(
                    1,
                    "cat 使用方法:cat [缓存名 ...]\n使用 - 代表标准输入,可拼接多个缓存",
                )
            return TextHandleResult(0, istream)

        pieces: list[str] = []
        for buf_name in args:
            if buf_name == "-":
                if istream is None:
                    return TextHandleResult(2, "标准输入为空(没有管道输入或回复消息)")
                pieces.append(istream)
                continue
            if buf_name not in env.buffers:
                return TextHandleResult(2, f"缓存 {buf_name} 不存在")
            pieces.append(env.buffers[buf_name])

        return TextHandleResult(0, "\n".join(pieces))
|
||||
|
||||
|
||||
class THRm(TextHandler):
    """Delete a named buffer from the environment.

    Usage: ``rm <缓存名>`` — "-" takes the buffer name from stdin.
    """

    name = "rm"

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        if len(args) != 1:
            return TextHandleResult(1, "rm 使用方法:rm <缓存名>")
        buf = args[0]
        if buf == "-":
            # Bug fix: with no pipe input, buf silently became None and the
            # handler reported the confusing "缓存 None 不存在".
            if istream is None:
                return TextHandleResult(2, "标准输入为空(没有管道输入或回复消息)")
            buf = istream
        if buf not in env.buffers:
            return TextHandleResult(2, f"缓存 {buf} 不存在")
        del env.buffers[buf]
        return TextHandleResult(0, None)
|
||||
|
||||
|
||||
class THReplace(TextHandler):
    """Regex search-and-replace over the input text (a minimal sed s///)."""

    name = "replace"
    keywords = ["sed", "替换"]

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        if len(args) < 2:
            return TextHandleResult(1, "用法:replace <正则> <替换内容> [文本]")

        pattern = args[0]
        replacement = args[1]
        if istream is not None:
            subject = istream
        elif len(args) > 2:
            subject = " ".join(args[2:])
        else:
            subject = ""

        try:
            return TextHandleResult(0, re.sub(pattern, replacement, subject))
        except Exception as e:
            return TextHandleResult(1, f"正则错误: {str(e)}")
|
||||
|
||||
|
||||
class THTrue(TextHandler):
    """Always succeed (code 0), passing stdin through unchanged."""

    name = "true"

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        passthrough = istream
        return TextHandleResult(0, passthrough)
|
||||
|
||||
|
||||
class THFalse(TextHandler):
    """Always fail (code 1) with no output, like Unix false."""

    name = "false"

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        return TextHandleResult(1, None)
|
||||
|
||||
|
||||
class THTest(TextHandler):
    """A subset of POSIX test(1): -n/-z, string =/!= and integer comparisons.

    Result code 0 means true, 1 means false, 2 means an unsupported expression.
    """

    name = "test"
    keywords = ["["]

    def _bool_result(self, value: bool) -> TextHandleResult:
        # POSIX convention: 0 = true, 1 = false; no output either way.
        return TextHandleResult(0 if value else 1, None)

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        expr = list(args)

        # Bracket form `[ expr ]`: drop the trailing bracket.
        if expr and expr[-1] == "]":
            expr.pop()

        # An empty expression is false.
        if not expr:
            return TextHandleResult(1, None)

        # A lone operand is true iff it is a non-empty string.
        if len(expr) == 1:
            return self._bool_result(bool(expr[0]))

        if len(expr) == 2:
            op, operand = expr
            if op == "-n":
                return self._bool_result(bool(operand))
            if op == "-z":
                return self._bool_result(not operand)
            return TextHandleResult(2, f"test 不支持的表达式: {' '.join(args)}")

        if len(expr) == 3:
            left, op, right = expr
            if op == "=":
                return self._bool_result(left == right)
            if op == "!=":
                return self._bool_result(left != right)
            if op in {"-eq", "-ne", "-gt", "-ge", "-lt", "-le"}:
                try:
                    li = int(left)
                    ri = int(right)
                except ValueError:
                    return TextHandleResult(2, "test 的数字比较参数必须是整数")
                outcomes = {
                    "-eq": li == ri,
                    "-ne": li != ri,
                    "-gt": li > ri,
                    "-ge": li >= ri,
                    "-lt": li < ri,
                    "-le": li <= ri,
                }
                return self._bool_result(outcomes[op])
            return TextHandleResult(2, f"test 不支持的操作符: {op}")

        # Four or more operands are not supported.
        return TextHandleResult(2, f"test 不支持的表达式: {' '.join(args)}")
|
||||
126
konabot/plugins/handle_text/handlers/whitespace_handlers.py
Normal file
126
konabot/plugins/handle_text/handlers/whitespace_handlers.py
Normal file
@ -0,0 +1,126 @@
|
||||
import re
|
||||
|
||||
from konabot.plugins.handle_text.base import (
|
||||
TextHandleResult,
|
||||
TextHandler,
|
||||
TextHandlerEnvironment,
|
||||
)
|
||||
|
||||
|
||||
def _get_text(istream: str | None, args: list[str]) -> str | None:
|
||||
"""从 istream 或 args 中获取待处理文本"""
|
||||
if istream is not None:
|
||||
return istream
|
||||
if args:
|
||||
return " ".join(args)
|
||||
return None
|
||||
|
||||
|
||||
class THTrim(TextHandler):
    """Strip leading and trailing whitespace from the input."""

    name = "trim"
    keywords = ["strip", "去空格"]

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        source = _get_text(istream, args)
        if source is None:
            return TextHandleResult(1, "trim 使用方法:trim [文本]\n去除首尾空白字符")
        return TextHandleResult(0, source.strip())
|
||||
|
||||
|
||||
class THLTrim(TextHandler):
    """Strip leading whitespace from the input."""

    name = "ltrim"
    keywords = ["lstrip"]

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        source = _get_text(istream, args)
        if source is None:
            return TextHandleResult(1, "ltrim 使用方法:ltrim [文本]\n去除左侧空白字符")
        return TextHandleResult(0, source.lstrip())
|
||||
|
||||
|
||||
class THRTrim(TextHandler):
    """Strip trailing whitespace from the input."""

    name = "rtrim"
    keywords = ["rstrip"]

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        source = _get_text(istream, args)
        if source is None:
            return TextHandleResult(1, "rtrim 使用方法:rtrim [文本]\n去除右侧空白字符")
        return TextHandleResult(0, source.rstrip())
|
||||
|
||||
|
||||
class THSqueeze(TextHandler):
    """Collapse runs of spaces/tabs into a single space (newlines untouched)."""

    name = "squeeze"
    keywords = ["压缩空白"]

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        source = _get_text(istream, args)
        if source is None:
            return TextHandleResult(
                1, "squeeze 使用方法:squeeze [文本]\n将连续空白字符压缩为单个空格"
            )
        return TextHandleResult(0, re.sub(r"[ \t]+", " ", source))
|
||||
|
||||
|
||||
class THLines(TextHandler):
    """Per-line cleanup: trim each line, drop empty lines, or squeeze blank runs.

    Usage: ``lines <trim|empty|squeeze> [text]``.
    """

    name = "lines"
    keywords = ["行处理"]

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        if len(args) < 1:
            return TextHandleResult(
                1,
                "lines 使用方法:lines <子命令> [文本]\n"
                "子命令:\n"
                " trim - 去除每行首尾空白\n"
                " empty - 去除所有空行\n"
                " squeeze - 将连续空行压缩为一行",
            )

        subcmd = args[0]
        if istream is not None:
            text = istream
        elif len(args) > 1:
            text = " ".join(args[1:])
        else:
            text = None
        if text is None:
            return TextHandleResult(1, "请提供需要处理的文本(通过管道或参数)")

        raw_lines = text.split("\n")

        if subcmd == "trim":
            result = "\n".join(line.strip() for line in raw_lines)
        elif subcmd == "empty":
            result = "\n".join(line for line in raw_lines if line.strip())
        elif subcmd == "squeeze":
            # Keep at most one blank line per run of blank lines.
            squeezed: list[str] = []
            previous_blank = False
            for line in raw_lines:
                if line.strip():
                    squeezed.append(line)
                    previous_blank = False
                else:
                    if not previous_blank:
                        squeezed.append("")
                    previous_blank = True
            result = "\n".join(squeezed)
        else:
            return TextHandleResult(
                1, f"未知子命令:{subcmd}\n可用:trim, empty, squeeze"
            )

        return TextHandleResult(0, result)
|
||||
@ -2,7 +2,6 @@ import random
|
||||
from typing import Optional
|
||||
import opencc
|
||||
|
||||
from nonebot import on_message
|
||||
from nonebot.adapters import Event as BaseEvent
|
||||
from nonebot.adapters.discord.event import MessageEvent as DiscordMessageEvent
|
||||
from nonebot_plugin_alconna import (
|
||||
@ -13,6 +12,10 @@ from nonebot_plugin_alconna import (
|
||||
on_alconna,
|
||||
)
|
||||
|
||||
from konabot.common.web_render import konaweb
|
||||
from konabot.common.web_render.core import WebRenderer
|
||||
from konabot.plugins.hanzi.er_data import ErFontData
|
||||
|
||||
convert_type = ["简","簡","繁","正","港","日"]
|
||||
|
||||
compiled_str = "|".join([f"{a}{mid}{b}" for mid in ["转","轉","転"] for a in convert_type for b in convert_type if a != b])
|
||||
@ -25,6 +28,7 @@ def hanzi_to_abbr(hanzi: str) -> str:
|
||||
"正": "t",
|
||||
"港": "hk",
|
||||
"日": "jp",
|
||||
"二": "er",
|
||||
}
|
||||
return mapping.get(hanzi, "")
|
||||
|
||||
@ -35,6 +39,9 @@ def check_valid_convert_type(convert_type: str) -> bool:
|
||||
return False
|
||||
|
||||
def convert(source, src_abbr, dst_abbr):
|
||||
if dst_abbr == "er":
|
||||
# 直接转换为二简
|
||||
return ErFontData.convert_text(source)
|
||||
convert_type_key = f"{src_abbr}2{dst_abbr}"
|
||||
if not check_valid_convert_type(convert_type_key):
|
||||
# 先转为繁体,再转为目标
|
||||
@ -98,12 +105,11 @@ async def _(msg: UniMsg, event: BaseEvent, source: Optional[str] = None):
|
||||
converted = convert(to_convert, src_abbr, dst_abbr)
|
||||
|
||||
converted_prefix = convert("转换结果", "s", dst_abbr)
|
||||
|
||||
await evt.send(await UniMessage().text(f"{converted_prefix}:{converted}").export())
|
||||
|
||||
shuo = ["说","說"]
|
||||
|
||||
full_name_type = ["简体","簡體","繁體","繁体","正體","正体","港話","港话","日文"]
|
||||
full_name_type = ["简体","簡體","繁體","繁体","正體","正体","港話","港话","日文","二简","二簡"]
|
||||
|
||||
combined_list = [f"{a}{b}" for a in shuo for b in full_name_type]
|
||||
|
||||
@ -151,20 +157,47 @@ async def _(msg: UniMsg, event: BaseEvent, source: Optional[str] = None):
|
||||
dst = "港"
|
||||
case "說日文" | "说日文":
|
||||
dst = "日"
|
||||
case "說二簡" | "说二简" | "說二簡" | "说二簡":
|
||||
dst = "二"
|
||||
dst_abbr = hanzi_to_abbr(dst)
|
||||
if not dst_abbr:
|
||||
notice = "不支持的转换类型,请使用“简体”、“繁體”、“正體”、“港話”、“日文”等。"
|
||||
notice = "不支持的转换类型,请使用“简体”、“繁體”、“正體”、“港話”、“日文”、“二简”等。"
|
||||
await evt.send(await UniMessage().text(notice).export())
|
||||
return
|
||||
# 循环,将源语言一次次转换为目标语言
|
||||
current_text = to_convert
|
||||
for src_abbr in ["s","hk","jp","tw","t"]:
|
||||
if src_abbr != dst_abbr:
|
||||
current_text = convert(current_text, src_abbr, dst_abbr)
|
||||
# 如果是二简,直接转换
|
||||
if dst_abbr == "er":
|
||||
current_text = ErFontData.convert_text(to_convert)
|
||||
else:
|
||||
# 循环,将源语言一次次转换为目标语言
|
||||
current_text = to_convert
|
||||
for src_abbr in ["s","hk","jp","tw","t"]:
|
||||
if src_abbr != dst_abbr:
|
||||
current_text = convert(current_text, src_abbr, dst_abbr)
|
||||
|
||||
converted_prefix = convert("转换结果", "s", dst_abbr)
|
||||
|
||||
await evt.send(await UniMessage().text(f"{converted_prefix}:{current_text}").export())
|
||||
if "span" in current_text:
|
||||
# 改为网页渲染
|
||||
render_result = await render_with_web_renderer(current_text)
|
||||
await evt.send(await UniMessage().image(raw=render_result).export())
|
||||
else:
|
||||
await evt.send(await UniMessage().text(f"{converted_prefix}:{current_text}").export())
|
||||
|
||||
async def render_with_web_renderer(text: str) -> bytes:
|
||||
async def page_function(page):
|
||||
# 找到id为content的文本框
|
||||
await page.wait_for_selector('textarea[name=content]')
|
||||
# 填入文本
|
||||
await page.locator('textarea[name=content]').fill(text)
|
||||
|
||||
out = await WebRenderer.render_with_persistent_page(
|
||||
"markdown_renderer",
|
||||
konaweb('old_font'),
|
||||
target='#main',
|
||||
other_function=page_function,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
def random_char(char: str) -> str:
|
||||
dst_abbr = random.choice(["s","t","hk","jp","tw"])
|
||||
@ -214,4 +247,19 @@ async def _(msg: UniMsg, event: BaseEvent, source: Optional[str] = None):
|
||||
final_text = random_string(to_convert)
|
||||
converted_prefix = convert(random_string("转换结果"), "s", "s")
|
||||
|
||||
await evt.send(await UniMessage().text(f"{converted_prefix}:{final_text}").export())
|
||||
await evt.send(await UniMessage().text(f"{converted_prefix}:{final_text}").export())
|
||||
|
||||
def get_char(char: str, abbr: str) -> str:
|
||||
output = ""
|
||||
for src_abbr in ["s","hk","jp","tw","t"]:
|
||||
if src_abbr != abbr:
|
||||
output += convert(char, src_abbr, abbr)
|
||||
return output
|
||||
|
||||
def get_all_variants(char: str) -> str:
|
||||
output = ""
|
||||
for abbr in ["s","hk","jp","tw","t"]:
|
||||
for src_abbr in ["s","hk","jp","tw","t"]:
|
||||
if src_abbr != abbr:
|
||||
output += convert(char, src_abbr, abbr)
|
||||
return output
|
||||
45
konabot/plugins/hanzi/er_data.py
Normal file
45
konabot/plugins/hanzi/er_data.py
Normal file
@ -0,0 +1,45 @@
|
||||
import csv
|
||||
from nonebot import logger
|
||||
from nonebot_plugin_apscheduler import driver
|
||||
from konabot.common.path import ASSETS_PATH
|
||||
|
||||
FONT_ASSETS_PATH = ASSETS_PATH / "old_font"
|
||||
|
||||
class ErFontData:
    """Lookup table mapping characters to their 二简 (second-round simplified)
    forms, loaded from assets/old_font/symtable.csv."""

    # key char -> {"char": glyph, "type": "ss05"/"ss06", "render": needs HTML span}
    data = {}
    temp_featured_fonts = {}

    @classmethod
    def init(cls):
        """Load symtable.csv into cls.data; a missing file leaves the table empty."""
        logger.info("加载二简字体数据...")
        table_path = FONT_ASSETS_PATH / "symtable.csv"
        if not table_path.exists():
            return
        with open(table_path, "r", encoding="utf-8-sig") as f:
            for row in csv.DictReader(f):
                # Prefer the ss05 feature column; fall back to ss06.
                for feature in ("ss05", "ss06"):
                    cell = row[feature]
                    if len(cell) > 0:
                        cls.data[row["trad"]] = {
                            "char": cell[0],
                            "type": feature,
                            # Entries whose cell contains "er" need the web renderer.
                            "render": "er" in cell,
                        }
                        break
        logger.info(f"二简字体数据加载完成,包含 {len(cls.data)} 个字。")

    @classmethod
    def get(cls, char: str) -> str:
        """Return the 二简 form of `char`, wrapped in a span when it must be rendered."""
        entry = cls.data.get(char)
        if entry is None:
            return char
        if entry["render"]:
            return f"<span class={entry['type']}>{entry['char']}</span>"
        return entry["char"]

    @classmethod
    def convert_text(cls, text: str) -> str:
        """Convert every character of `text` via get()."""
        return "".join(cls.get(c) for c in text)
|
||||
|
||||
@driver.on_startup
async def load_er_font_data():
    """Populate the ErFontData table once at bot startup."""
    ErFontData.init()
|
||||
@ -1,13 +1,15 @@
|
||||
import asyncio as asynkio
|
||||
import datetime
|
||||
from io import BytesIO
|
||||
import json
|
||||
import secrets
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
from loguru import logger
|
||||
from nonebot import on_message
|
||||
import nonebot
|
||||
from nonebot.adapters import Event as BaseEvent
|
||||
from nonebot.adapters.discord.event import MessageEvent as DiscordMessageEvent
|
||||
from nonebot_plugin_alconna import (
|
||||
@ -18,17 +20,23 @@ from nonebot_plugin_alconna import (
|
||||
on_alconna,
|
||||
)
|
||||
|
||||
from konabot.common.database import DatabaseManager
|
||||
from konabot.common.longtask import DepLongTaskTarget
|
||||
from konabot.common.path import ASSETS_PATH
|
||||
|
||||
from konabot.common.llm import get_llm
|
||||
|
||||
ROOT_PATH = Path(__file__).resolve().parent
|
||||
|
||||
DATA_DIR = Path(__file__).parent.parent.parent.parent / "data"
|
||||
|
||||
DATA_FILE_PATH = (
|
||||
DATA_DIR / "idiom_banned.json"
|
||||
)
|
||||
|
||||
# 创建全局数据库管理器实例
|
||||
db_manager = DatabaseManager()
|
||||
|
||||
def load_banned_ids() -> list[str]:
|
||||
if not DATA_FILE_PATH.exists():
|
||||
return []
|
||||
@ -58,6 +66,21 @@ def remove_banned_id(group_id: str):
|
||||
DATA_FILE_PATH.write_text(json.dumps(banned_ids, ensure_ascii=False, indent=4), "utf-8")
|
||||
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
|
||||
@driver.on_startup
|
||||
async def register_startup_hook():
|
||||
"""注册启动时需要执行的函数"""
|
||||
await IdiomGame.init_lexicon()
|
||||
|
||||
@driver.on_shutdown
|
||||
async def register_shutdown_hook():
|
||||
"""注册关闭时需要执行的函数"""
|
||||
# 关闭所有数据库连接
|
||||
await db_manager.close_all_connections()
|
||||
|
||||
|
||||
class TryStartState(Enum):
|
||||
STARTED = 0
|
||||
ALREADY_PLAYING = 1
|
||||
@ -94,6 +117,11 @@ class IdiomGameLLM:
|
||||
|
||||
@classmethod
|
||||
async def storage_idiom(cls, idiom: str):
|
||||
# 将 idiom 存入数据库
|
||||
# await db_manager.execute_by_sql_file(
|
||||
# ROOT_PATH / "sql" / "insert_custom_word.sql",
|
||||
# (idiom,)
|
||||
# )
|
||||
# 将 idiom 存入本地文件以备后续分析
|
||||
with open(DATA_DIR / "idiom_llm_storage.txt", "a", encoding="utf-8") as f:
|
||||
f.write(idiom + "\n")
|
||||
@ -126,7 +154,7 @@ class IdiomGame:
|
||||
IdiomGame.INSTANCE_LIST[group_id] = self
|
||||
|
||||
@classmethod
|
||||
def append_into_word_list(cls, word: str):
|
||||
async def append_into_word_list(cls, word: str):
|
||||
'''
|
||||
将一个新词加入到词语列表中
|
||||
'''
|
||||
@ -135,6 +163,10 @@ class IdiomGame:
|
||||
if word[0] not in cls.IDIOM_FIRST_CHAR:
|
||||
cls.IDIOM_FIRST_CHAR[word[0]] = []
|
||||
cls.IDIOM_FIRST_CHAR[word[0]].append(word)
|
||||
# await db_manager.execute_by_sql_file(
|
||||
# ROOT_PATH / "sql" / "insert_custom_word.sql",
|
||||
# (word,)
|
||||
# )
|
||||
|
||||
def be_able_to_play(self) -> bool:
|
||||
if self.last_play_date != datetime.date.today():
|
||||
@ -145,21 +177,29 @@ class IdiomGame:
|
||||
return True
|
||||
return False
|
||||
|
||||
def choose_start_idiom(self) -> str:
|
||||
@staticmethod
|
||||
async def random_idiom() -> str:
|
||||
# result = await db_manager.query_by_sql_file(
|
||||
# ROOT_PATH / "sql" / "random_choose_idiom.sql"
|
||||
# )
|
||||
# return result[0]["idiom"]
|
||||
return secrets.choice(IdiomGame.ALL_IDIOMS)
|
||||
|
||||
async def choose_start_idiom(self) -> str:
|
||||
"""
|
||||
随机选择一个成语作为起始成语
|
||||
"""
|
||||
self.last_idiom = secrets.choice(IdiomGame.ALL_IDIOMS)
|
||||
self.last_idiom = await IdiomGame.random_idiom()
|
||||
self.last_char = self.last_idiom[-1]
|
||||
if not self.is_nextable(self.last_char):
|
||||
self.choose_start_idiom()
|
||||
if not await self.is_nextable(self.last_char):
|
||||
await self.choose_start_idiom()
|
||||
else:
|
||||
self.add_history_idiom(self.last_idiom, new_chain=True)
|
||||
return self.last_idiom
|
||||
|
||||
@classmethod
|
||||
def try_start_game(cls, group_id: str, force: bool = False) -> TryStartState:
|
||||
cls.init_lexicon()
|
||||
async def try_start_game(cls, group_id: str, force: bool = False) -> TryStartState:
|
||||
await cls.init_lexicon()
|
||||
if not cls.INSTANCE_LIST.get(group_id):
|
||||
cls(group_id)
|
||||
instance = cls.INSTANCE_LIST[group_id]
|
||||
@ -170,10 +210,10 @@ class IdiomGame:
|
||||
instance.now_playing = True
|
||||
return TryStartState.STARTED
|
||||
|
||||
def start_game(self, rounds: int = 100):
|
||||
async def start_game(self, rounds: int = 100):
|
||||
self.now_playing = True
|
||||
self.remain_rounds = rounds
|
||||
self.choose_start_idiom()
|
||||
await self.choose_start_idiom()
|
||||
|
||||
@classmethod
|
||||
def try_stop_game(cls, group_id: str) -> TryStopState:
|
||||
@ -203,20 +243,20 @@ class IdiomGame:
|
||||
跳过当前成语,选择下一个成语
|
||||
"""
|
||||
async with self.lock:
|
||||
self._skip_idiom_async()
|
||||
await self._skip_idiom_async()
|
||||
self.add_buff_score(buff_score)
|
||||
return self.last_idiom
|
||||
|
||||
def _skip_idiom_async(self) -> str:
|
||||
self.last_idiom = secrets.choice(IdiomGame.ALL_IDIOMS)
|
||||
async def _skip_idiom_async(self) -> str:
|
||||
self.last_idiom = await IdiomGame.random_idiom()
|
||||
self.last_char = self.last_idiom[-1]
|
||||
if not self.is_nextable(self.last_char):
|
||||
self._skip_idiom_async()
|
||||
if not await self.is_nextable(self.last_char):
|
||||
await self._skip_idiom_async()
|
||||
else:
|
||||
self.add_history_idiom(self.last_idiom, new_chain=True)
|
||||
return self.last_idiom
|
||||
|
||||
async def try_verify_idiom(self, idiom: str, user_id: str) -> TryVerifyState:
|
||||
async def try_verify_idiom(self, idiom: str, user_id: str) -> list[TryVerifyState]:
|
||||
"""
|
||||
用户发送成语
|
||||
"""
|
||||
@ -224,12 +264,17 @@ class IdiomGame:
|
||||
state = await self._verify_idiom(idiom, user_id)
|
||||
return state
|
||||
|
||||
def is_nextable(self, last_char: str) -> bool:
|
||||
async def is_nextable(self, last_char: str) -> bool:
|
||||
"""
|
||||
判断是否有成语可以接
|
||||
"""
|
||||
# result = await db_manager.query_by_sql_file(
|
||||
# ROOT_PATH / "sql" / "is_nextable.sql",
|
||||
# (last_char,)
|
||||
# )
|
||||
# return result[0]["DEED"] == 1
|
||||
return last_char in IdiomGame.AVALIABLE_IDIOM_FIRST_CHAR
|
||||
|
||||
|
||||
def add_already_idiom(self, idiom: str):
|
||||
if idiom in self.already_idioms:
|
||||
self.already_idioms[idiom] += 1
|
||||
@ -259,6 +304,13 @@ class IdiomGame:
|
||||
if idiom[0] != self.last_char:
|
||||
state.append(TryVerifyState.WRONG_FIRST_CHAR)
|
||||
return state
|
||||
# 成语是否存在
|
||||
# result = await db_manager.query_by_sql_file(
|
||||
# ROOT_PATH / "sql" / "query_idiom.sql",
|
||||
# (idiom, idiom, idiom)
|
||||
# )
|
||||
# status_result = result[0]["status"]
|
||||
# if status_result == -1:
|
||||
if idiom not in IdiomGame.ALL_IDIOMS and idiom not in IdiomGame.ALL_WORDS:
|
||||
logger.info(f"用户 {user_id} 发送了未知词语 {idiom},正在使用 LLM 进行验证")
|
||||
try:
|
||||
@ -281,6 +333,7 @@ class IdiomGame:
|
||||
self.last_idiom = idiom
|
||||
self.last_char = idiom[-1]
|
||||
self.add_score(user_id, 1 * score_k) # 先加 1 分
|
||||
# if status_result == 1:
|
||||
if idiom in IdiomGame.ALL_IDIOMS:
|
||||
state.append(TryVerifyState.VERIFIED_AND_REAL)
|
||||
self.add_score(user_id, 4 * score_k) # 再加 4 分
|
||||
@ -288,9 +341,9 @@ class IdiomGame:
|
||||
if self.remain_rounds <= 0:
|
||||
self.now_playing = False
|
||||
state.append(TryVerifyState.GAME_END)
|
||||
if not self.is_nextable(self.last_char):
|
||||
if not await self.is_nextable(self.last_char):
|
||||
# 没有成语可以接了,自动跳过
|
||||
self._skip_idiom_async()
|
||||
await self._skip_idiom_async()
|
||||
self.add_buff_score(-100)
|
||||
state.append(TryVerifyState.BUT_NO_NEXT)
|
||||
return state
|
||||
@ -317,16 +370,27 @@ class IdiomGame:
|
||||
return self.last_char
|
||||
|
||||
@classmethod
|
||||
def random_idiom_starting_with(cls, first_char: str) -> Optional[str]:
|
||||
cls.init_lexicon()
|
||||
async def random_idiom_starting_with(cls, first_char: str) -> Optional[str]:
|
||||
# await cls.init_lexicon()
|
||||
# result = await db_manager.query_by_sql_file(
|
||||
# ROOT_PATH / "sql" / "query_idiom_start_with.sql",
|
||||
# (first_char,)
|
||||
# )
|
||||
# if len(result) == 0:
|
||||
# return None
|
||||
# return result[0]["idiom"]
|
||||
await cls.init_lexicon()
|
||||
if first_char not in cls.AVALIABLE_IDIOM_FIRST_CHAR:
|
||||
return None
|
||||
return secrets.choice(cls.AVALIABLE_IDIOM_FIRST_CHAR[first_char])
|
||||
|
||||
@classmethod
|
||||
def init_lexicon(cls):
|
||||
async def init_lexicon(cls):
|
||||
if cls.__inited:
|
||||
return
|
||||
# await db_manager.execute_by_sql_file(
|
||||
# ROOT_PATH / "sql" / "create_table.sql"
|
||||
# ) # 确保数据库初始化
|
||||
cls.__inited = True
|
||||
|
||||
# 成语大表
|
||||
@ -334,11 +398,12 @@ class IdiomGame:
|
||||
ALL_IDIOMS_INFOS = json.load(f)
|
||||
|
||||
# 词语大表
|
||||
ALL_WORDS = []
|
||||
with open(ASSETS_PATH / "lexicon" / "ci.json", "r", encoding="utf-8") as f:
|
||||
jsonData = json.load(f)
|
||||
cls.ALL_WORDS = [item["ci"] for item in jsonData]
|
||||
logger.debug(f"Loaded {len(cls.ALL_WORDS)} words from ci.json")
|
||||
logger.debug(f"Sample words: {cls.ALL_WORDS[:5]}")
|
||||
ALL_WORDS = [item["ci"] for item in jsonData]
|
||||
logger.debug(f"Loaded {len(ALL_WORDS)} words from ci.json")
|
||||
logger.debug(f"Sample words: {ALL_WORDS[:5]}")
|
||||
|
||||
COMMON_WORDS = []
|
||||
# 读取 COMMON 词语大表
|
||||
@ -389,17 +454,36 @@ class IdiomGame:
|
||||
logger.debug(f"Loaded additional {len(LOCAL_LLM_WORDS)} words from idiom_llm_storage.txt")
|
||||
|
||||
# 只有成语的大表
|
||||
cls.ALL_IDIOMS = [idiom["word"] for idiom in ALL_IDIOMS_INFOS] + THUOCL_IDIOMS
|
||||
cls.ALL_IDIOMS = list(set(cls.ALL_IDIOMS)) # 去重
|
||||
ALL_IDIOMS = [idiom["word"] for idiom in ALL_IDIOMS_INFOS] + THUOCL_IDIOMS
|
||||
ALL_IDIOMS = list(set(ALL_IDIOMS)) # 去重
|
||||
# 批量插入数据库
|
||||
# await db_manager.execute_many_values_by_sql_file(
|
||||
# ROOT_PATH / "sql" / "insert_idiom.sql",
|
||||
# [(idiom,) for idiom in ALL_IDIOMS]
|
||||
# )
|
||||
|
||||
|
||||
# 其他四字词语表,仅表示可以有这个词
|
||||
cls.ALL_WORDS = (
|
||||
[word for word in cls.ALL_WORDS if len(word) == 4]
|
||||
ALL_WORDS = (
|
||||
[word for word in ALL_WORDS if len(word) == 4]
|
||||
+ THUOCL_WORDS
|
||||
+ COMMON_WORDS
|
||||
+ LOCAL_LLM_WORDS
|
||||
)
|
||||
cls.ALL_WORDS = list(set(cls.ALL_WORDS)) # 去重
|
||||
|
||||
cls.ALL_WORDS = ALL_WORDS + LOCAL_LLM_WORDS
|
||||
cls.ALL_IDIOMS = ALL_IDIOMS
|
||||
|
||||
# 插入数据库
|
||||
# await db_manager.execute_many_values_by_sql_file(
|
||||
# ROOT_PATH / "sql" / "insert_word.sql",
|
||||
# [(word,) for word in ALL_WORDS]
|
||||
# )
|
||||
|
||||
# 自定义词语 LOCAL_LLM_WORDS 插入数据库,兼容用
|
||||
# await db_manager.execute_many_values_by_sql_file(
|
||||
# ROOT_PATH / "sql" / "insert_custom_word.sql",
|
||||
# [(word,) for word in LOCAL_LLM_WORDS]
|
||||
# )
|
||||
|
||||
# 根据成语大表,划分出成语首字字典
|
||||
for idiom in cls.ALL_IDIOMS + cls.ALL_WORDS:
|
||||
@ -443,7 +527,7 @@ async def play_game(
|
||||
if rounds <= 0:
|
||||
await evt.send(await UniMessage().text("干什么!你想玩负数局吗?").export())
|
||||
return
|
||||
state = IdiomGame.try_start_game(group_id, force)
|
||||
state = await IdiomGame.try_start_game(group_id, force)
|
||||
if state == TryStartState.ALREADY_PLAYING:
|
||||
await evt.send(
|
||||
await UniMessage()
|
||||
@ -462,7 +546,7 @@ async def play_game(
|
||||
.export()
|
||||
)
|
||||
instance = IdiomGame.INSTANCE_LIST[group_id]
|
||||
instance.start_game(rounds)
|
||||
await instance.start_game(rounds)
|
||||
# 发布成语
|
||||
await evt.send(
|
||||
await UniMessage()
|
||||
@ -514,7 +598,9 @@ async def end_game(event: BaseEvent, group_id: str):
|
||||
for line in history_lines:
|
||||
result_text += line + "\n"
|
||||
await evt.send(await result_text.export())
|
||||
instance.clear_score_board()
|
||||
# instance.clear_score_board()
|
||||
# 将实例删除
|
||||
del IdiomGame.INSTANCE_LIST[group_id]
|
||||
|
||||
|
||||
evt = on_alconna(
|
||||
@ -532,14 +618,23 @@ async def _(event: BaseEvent, target: DepLongTaskTarget):
|
||||
# 打开好吧狗本地文件
|
||||
with open(ASSETS_PATH / "img" / "dog" / "haoba_dog.jpg", "rb") as f:
|
||||
img_data = f.read()
|
||||
# 把好吧狗变成 GIF 格式以缩小尺寸
|
||||
img_data = await convert_image_to_gif(img_data)
|
||||
await evt.send(await UniMessage().image(raw=img_data).export())
|
||||
await end_game(event, group_id)
|
||||
else:
|
||||
await evt.send(
|
||||
await UniMessage().text("当前没有成语接龙游戏在进行中!").export()
|
||||
)
|
||||
# await evt.send(
|
||||
# await UniMessage().text("当前没有成语接龙游戏在进行中!").export()
|
||||
# )
|
||||
return
|
||||
|
||||
|
||||
async def convert_image_to_gif(image_data: bytes) -> bytes:
|
||||
with Image.open(BytesIO(image_data)) as img:
|
||||
with BytesIO() as output:
|
||||
img.save(output, format="GIF")
|
||||
return output.getvalue()
|
||||
|
||||
# 跳过
|
||||
evt = on_alconna(
|
||||
Alconna("跳过成语"), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True
|
||||
@ -553,10 +648,12 @@ async def _(target: DepLongTaskTarget):
|
||||
instance = IdiomGame.INSTANCE_LIST.get(group_id)
|
||||
if not instance or not instance.get_playing_state():
|
||||
return
|
||||
avaliable_idiom = IdiomGame.random_idiom_starting_with(instance.get_last_char())
|
||||
avaliable_idiom = await IdiomGame.random_idiom_starting_with(instance.get_last_char())
|
||||
# 发送哈哈狗图片
|
||||
with open(ASSETS_PATH / "img" / "dog" / "haha_dog.jpg", "rb") as f:
|
||||
img_data = f.read()
|
||||
# 把哈哈狗变成 GIF 格式以缩小尺寸
|
||||
img_data = await convert_image_to_gif(img_data)
|
||||
await evt.send(await UniMessage().image(raw=img_data).export())
|
||||
await evt.send(await UniMessage().text(f"你们太菜了,全部扣100分!明明还可以接「{avaliable_idiom}」的!").export())
|
||||
idiom = await instance.skip_idiom(-100)
|
||||
|
||||
15
konabot/plugins/idiomgame/sql/create_table.sql
Normal file
15
konabot/plugins/idiomgame/sql/create_table.sql
Normal file
@ -0,0 +1,15 @@
|
||||
-- 创建成语大表
|
||||
CREATE TABLE IF NOT EXISTS all_idioms (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
idiom VARCHAR(128) NOT NULL UNIQUE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS all_words (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
word VARCHAR(128) NOT NULL UNIQUE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS custom_words (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
word VARCHAR(128) NOT NULL UNIQUE
|
||||
);
|
||||
3
konabot/plugins/idiomgame/sql/insert_custom_word.sql
Normal file
3
konabot/plugins/idiomgame/sql/insert_custom_word.sql
Normal file
@ -0,0 +1,3 @@
|
||||
-- 插入自定义词
|
||||
INSERT OR IGNORE INTO custom_words (word)
|
||||
VALUES (?);
|
||||
3
konabot/plugins/idiomgame/sql/insert_idiom.sql
Normal file
3
konabot/plugins/idiomgame/sql/insert_idiom.sql
Normal file
@ -0,0 +1,3 @@
|
||||
-- 插入成语大表,避免重复插入
|
||||
INSERT OR IGNORE INTO all_idioms (idiom)
|
||||
VALUES (?);
|
||||
3
konabot/plugins/idiomgame/sql/insert_word.sql
Normal file
3
konabot/plugins/idiomgame/sql/insert_word.sql
Normal file
@ -0,0 +1,3 @@
|
||||
-- 插入词
|
||||
INSERT OR IGNORE INTO all_words (word)
|
||||
VALUES (?);
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user