Merge pull request '为此方 Bot 接入数据库' (#49) from database into master
All checks were successful
continuous-integration/drone/push Build is passing

Reviewed-on: #49
This commit is contained in:
2025-11-19 00:52:01 +08:00
27 changed files with 1095 additions and 129 deletions

View File

@ -1,4 +1,4 @@
ENVIRONMENT=dev
PORT=21333
DATABASE_PATH="./data/database.db"
ENABLE_CONSOLE=true

3
.gitignore vendored
View File

@ -1,4 +1,5 @@
/.env
/data
__pycache__
__pycache__
/*.diff

View File

@ -71,6 +71,10 @@ code .
详见[konabot-web 配置文档](/docs/konabot-web.md)
#### 数据库配置
本项目使用SQLite作为数据库默认数据库文件位于`./data/database.db`。可以通过设置`DATABASE_PATH`环境变量来指定其他位置。
### 运行
使用命令行手动启动 Bot
@ -91,3 +95,7 @@ poetry run python bot.py
- [事件响应器](https://nonebot.dev/docs/tutorial/matcher)
- [事件处理](https://nonebot.dev/docs/tutorial/handler)
- [Alconna 插件](https://nonebot.dev/docs/best-practice/alconna/)
## 数据库模块
本项目的数据库模块已更新为异步实现,使用连接池来提高性能,并支持现代的`pathlib.Path`参数类型。详细使用方法请参考[数据库使用文档](/docs/database.md)。

9
bot.py
View File

@ -10,6 +10,8 @@ from nonebot.adapters.onebot.v11 import Adapter as OnebotAdapter
from konabot.common.log import init_logger
from konabot.common.nb.exc import BotExceptionMessage
from konabot.common.path import LOG_PATH
from konabot.common.database import get_global_db_manager
dotenv.load_dotenv()
env = os.environ.get("ENVIRONMENT", "prod")
@ -48,6 +50,13 @@ def main():
nonebot.load_plugins("konabot/plugins")
nonebot.load_plugin("nonebot_plugin_analysis_bilibili")
# 注册关闭钩子
@driver.on_shutdown
async def shutdown_handler():
# 关闭全局数据库管理器
db_manager = get_global_db_manager()
await db_manager.close_all_connections()
nonebot.run()
if __name__ == "__main__":

223
docs/database.md Normal file
View File

@ -0,0 +1,223 @@
# 数据库系统使用文档
本文档详细介绍了本项目中使用的异步数据库系统,包括其架构设计、使用方法和最佳实践。
## 系统概述
本项目的数据库系统基于 `aiosqlite` 库构建,提供了异步的 SQLite 数据库访问接口。系统主要特性包括:
1. **异步操作**:完全支持异步/await模式适配NoneBot2框架
2. **连接池**:内置连接池机制,提高数据库访问性能
3. **参数化查询**支持安全的参数化查询防止SQL注入
4. **SQL文件支持**可以直接执行SQL文件中的脚本
5. **类型支持**:支持 `pathlib.Path``str` 类型的路径参数
## 核心类和方法
### DatabaseManager 类
`DatabaseManager` 是数据库操作的核心类,提供了以下主要方法:
#### 初始化
```python
from konabot.common.database import DatabaseManager
from pathlib import Path
# 使用默认数据库路径
db = DatabaseManager()
# 指定了义数据库路径
db = DatabaseManager("./data/myapp.db")
db = DatabaseManager(Path("./data/myapp.db"))
```
#### 查询操作
```python
# 执行查询语句并返回结果
results = await db.query("SELECT * FROM users WHERE age > ?", (18,))
# 从SQL文件执行查询
results = await db.query_by_sql_file("./sql/get_users.sql", (18,))
```
#### 执行操作
```python
# 执行非查询语句
await db.execute("INSERT INTO users (name, email) VALUES (?, ?)", ("张三", "zhangsan@example.com"))
# 执行SQL脚本不带参数
await db.execute_script("""
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
email TEXT UNIQUE
);
INSERT INTO users (name, email) VALUES ('测试用户', 'test@example.com');
""")
# 从SQL文件执行非查询语句
await db.execute_by_sql_file("./sql/create_tables.sql")
# 带参数执行SQL文件
await db.execute_by_sql_file("./sql/insert_user.sql", ("张三", "zhangsan@example.com"))
# 执行多条语句(每条语句使用相同参数)
await db.execute_many("INSERT INTO users (name, email) VALUES (?, ?)", [
("张三", "zhangsan@example.com"),
("李四", "lisi@example.com"),
("王五", "wangwu@example.com")
])
# 从SQL文件执行多条语句每条语句使用相同参数
await db.execute_many_values_by_sql_file("./sql/batch_insert.sql", [
("张三", "zhangsan@example.com"),
("李四", "lisi@example.com")
])
```
## SQL文件处理机制
### 单语句SQL文件
```sql
-- insert_user.sql
INSERT INTO users (name, email) VALUES (?, ?);
```
```python
# 使用方式
await db.execute_by_sql_file("./sql/insert_user.sql", ("张三", "zhangsan@example.com"))
```
### 多语句SQL文件
```sql
-- setup.sql
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
email TEXT UNIQUE
);
CREATE TABLE IF NOT EXISTS profiles (
user_id INTEGER,
age INTEGER,
FOREIGN KEY (user_id) REFERENCES users(id)
);
```
```python
# 使用方式
await db.execute_by_sql_file("./sql/setup.sql")
```
### 多语句带不同参数的SQL文件
```sql
-- batch_operations.sql
INSERT INTO users (name, email) VALUES (?, ?);
INSERT INTO profiles (user_id, age) VALUES (?, ?);
```
```python
# 使用方式
await db.execute_by_sql_file("./sql/batch_operations.sql", [
("张三", "zhangsan@example.com"), # 第一条语句的参数
(1, 25) # 第二条语句的参数
])
```
## 最佳实践
### 1. 数据库表设计
```sql
-- 推荐的表设计实践
CREATE TABLE IF NOT EXISTS example_table (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
updated_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
```
### 2. SQL文件组织
建议按照功能模块组织SQL文件
```
plugin/
├── sql/
│ ├── create_tables.sql
│ ├── insert_data.sql
│ ├── update_data.sql
│ └── query_data.sql
└── __init__.py
```
### 3. 错误处理
```python
try:
results = await db.query("SELECT * FROM users WHERE id = ?", (user_id,))
except Exception as e:
logger.error(f"数据库查询失败: {e}")
# 处理错误情况
```
### 4. 连接管理
```python
# 在应用启动时初始化
db_manager = DatabaseManager()
# 在应用关闭时清理连接
async def shutdown():
await db_manager.close_all_connections()
```
## 高级特性
### 连接池配置
```python
class DatabaseManager:
def __init__(self, db_path: Optional[Union[str, Path]] = None):
# 连接池大小配置
self._pool_size = 5 # 可根据需要调整
```
### 事务支持
```python
# 通过execute方法的自动提交机制支持事务
await db.execute("BEGIN TRANSACTION")
try:
await db.execute("INSERT INTO users (name) VALUES (?)", ("张三",))
await db.execute("INSERT INTO profiles (user_id, age) VALUES (?, ?)", (1, 25))
await db.execute("COMMIT")
except Exception:
await db.execute("ROLLBACK")
raise
```
## 注意事项
1. **异步环境**:所有数据库操作都必须在异步环境中执行
2. **参数安全**始终使用参数化查询避免SQL注入
3. **资源管理**:确保在应用关闭时调用 `close_all_connections()`
4. **SQL解析**:使用 `sqlparse` 库准确解析SQL语句正确处理包含分号的字符串和注释
5. **错误处理**:适当处理数据库操作可能抛出的异常
## 常见问题
### Q: 如何处理数据库约束错误?
A: 确保SQL语句中的字段名正确引用特别是保留字需要使用双引号包围
```sql
CREATE TABLE air_conditioner (
id VARCHAR(128) PRIMARY KEY,
"on" BOOLEAN NOT NULL, -- 使用双引号包围保留字
temperature REAL NOT NULL
);
```
### Q: 如何处理多个语句和参数的匹配?
A: 当SQL文件包含多个语句时参数应该是参数列表每个语句对应一个参数元组
```python
await db.execute_by_sql_file("./sql/batch.sql", [
("参数1", "参数2"), # 第一个语句的参数
("参数3", "参数4") # 第二个语句的参数
])
```
通过遵循这些指南和最佳实践,您可以充分利用本项目的异步数据库系统,构建高性能、安全的数据库应用。

View File

@ -0,0 +1,218 @@
import os
import asyncio
import sqlparse
from pathlib import Path
from typing import List, Dict, Any, Optional, Union, TYPE_CHECKING
import aiosqlite
if TYPE_CHECKING:
from . import DatabaseManager
# 全局数据库管理器实例
_global_db_manager: Optional['DatabaseManager'] = None
def get_global_db_manager() -> 'DatabaseManager':
"""获取全局数据库管理器实例"""
global _global_db_manager
if _global_db_manager is None:
from . import DatabaseManager
_global_db_manager = DatabaseManager()
return _global_db_manager
def close_global_db_manager() -> None:
"""关闭全局数据库管理器实例"""
global _global_db_manager
if _global_db_manager is not None:
# 注意这个函数应该在async环境中调用close_all_connections
_global_db_manager = None
class DatabaseManager:
"""异步数据库管理器"""
def __init__(self, db_path: Optional[Union[str, Path]] = None, pool_size: int = 5):
"""
初始化数据库管理器
Args:
db_path: 数据库文件路径支持str和Path类型
pool_size: 连接池大小
"""
if db_path is None:
self.db_path = os.environ.get("DATABASE_PATH", "./data/database.db")
else:
self.db_path = str(db_path) if isinstance(db_path, Path) else db_path
# 连接池
self._connection_pool = []
self._pool_size = pool_size
self._lock = asyncio.Lock()
self._in_use = set() # 跟踪正在使用的连接
async def _get_connection(self) -> aiosqlite.Connection:
"""从连接池获取连接"""
async with self._lock:
# 尝试从池中获取现有连接
while self._connection_pool:
conn = self._connection_pool.pop()
# 检查连接是否仍然有效
try:
await conn.execute("SELECT 1")
self._in_use.add(conn)
return conn
except:
# 连接已失效,关闭它
try:
await conn.close()
except:
pass
# 如果连接池为空,创建新连接
conn = await aiosqlite.connect(self.db_path)
await conn.execute("PRAGMA foreign_keys = ON")
self._in_use.add(conn)
return conn
async def _return_connection(self, conn: aiosqlite.Connection) -> None:
"""将连接返回到连接池"""
async with self._lock:
self._in_use.discard(conn)
if len(self._connection_pool) < self._pool_size:
self._connection_pool.append(conn)
else:
# 池已满,直接关闭连接
try:
await conn.close()
except:
pass
async def query(
self, query: str, params: Optional[tuple] = None
) -> List[Dict[str, Any]]:
"""执行查询语句并返回结果"""
conn = await self._get_connection()
try:
cursor = await conn.execute(query, params or ())
columns = [description[0] for description in cursor.description]
rows = await cursor.fetchall()
results = [dict(zip(columns, row)) for row in rows]
await cursor.close()
return results
except Exception as e:
# 记录错误但重新抛出,让调用者处理
raise Exception(f"数据库查询失败: {str(e)}") from e
finally:
await self._return_connection(conn)
async def query_by_sql_file(
self, file_path: Union[str, Path], params: Optional[tuple] = None
) -> List[Dict[str, Any]]:
"""从 SQL 文件中读取查询语句并执行"""
path = str(file_path) if isinstance(file_path, Path) else file_path
with open(path, "r", encoding="utf-8") as f:
query = f.read()
return await self.query(query, params)
async def execute(self, command: str, params: Optional[tuple] = None) -> None:
"""执行非查询语句"""
conn = await self._get_connection()
try:
await conn.execute(command, params or ())
await conn.commit()
except Exception as e:
# 记录错误但重新抛出,让调用者处理
raise Exception(f"数据库执行失败: {str(e)}") from e
finally:
await self._return_connection(conn)
async def execute_script(self, script: str) -> None:
"""执行SQL脚本"""
conn = await self._get_connection()
try:
await conn.executescript(script)
await conn.commit()
except Exception as e:
# 记录错误但重新抛出,让调用者处理
raise Exception(f"数据库脚本执行失败: {str(e)}") from e
finally:
await self._return_connection(conn)
def _parse_sql_statements(self, script: str) -> List[str]:
"""解析SQL脚本分割成独立的语句"""
# 使用sqlparse库更准确地分割SQL语句
parsed = sqlparse.split(script)
statements = []
for statement in parsed:
statement = statement.strip()
if statement:
statements.append(statement)
return statements
async def execute_by_sql_file(
self, file_path: Union[str, Path], params: Optional[Union[tuple, List[tuple]]] = None
) -> None:
"""从 SQL 文件中读取非查询语句并执行"""
path = str(file_path) if isinstance(file_path, Path) else file_path
with open(path, "r", encoding="utf-8") as f:
script = f.read()
# 如果有参数且是元组使用execute执行整个脚本
if params is not None and isinstance(params, tuple):
await self.execute(script, params)
# 如果有参数且是列表,分别执行每个语句
elif params is not None and isinstance(params, list):
# 使用sqlparse准确分割SQL语句
statements = self._parse_sql_statements(script)
if len(statements) != len(params):
raise ValueError(f"语句数量({len(statements)})与参数组数量({len(params)})不匹配")
for statement, stmt_params in zip(statements, params):
if statement:
await self.execute(statement, stmt_params)
# 如果无参数使用executescript
else:
await self.execute_script(script)
async def execute_many(self, command: str, seq_of_params: List[tuple]) -> None:
"""执行多条非查询语句"""
conn = await self._get_connection()
try:
await conn.executemany(command, seq_of_params)
await conn.commit()
except Exception as e:
# 记录错误但重新抛出,让调用者处理
raise Exception(f"数据库批量执行失败: {str(e)}") from e
finally:
await self._return_connection(conn)
async def execute_many_values_by_sql_file(
self, file_path: Union[str, Path], seq_of_params: List[tuple]
) -> None:
"""从 SQL 文件中读取一条语句,但是被不同值同时执行"""
path = str(file_path) if isinstance(file_path, Path) else file_path
with open(path, "r", encoding="utf-8") as f:
command = f.read()
await self.execute_many(command, seq_of_params)
async def close_all_connections(self) -> None:
"""关闭所有连接"""
async with self._lock:
# 关闭池中的连接
for conn in self._connection_pool:
try:
await conn.close()
except:
pass
self._connection_pool.clear()
# 关闭正在使用的连接
for conn in self._in_use.copy():
try:
await conn.close()
except:
pass
self._in_use.clear()

View File

@ -1,22 +1,29 @@
from io import BytesIO
from typing import Optional, Union
import cv2
import nonebot
from nonebot.adapters import Event as BaseEvent
from nonebot.adapters.console.event import MessageEvent as ConsoleMessageEvent
from nonebot.adapters.discord.event import MessageEvent as DiscordMessageEvent
from nonebot_plugin_alconna import Alconna, AlconnaMatcher, Args, UniMessage, on_alconna
from PIL import Image
import numpy as np
from konabot.common.database import DatabaseManager
from konabot.common.longtask import DepLongTaskTarget
from konabot.common.path import ASSETS_PATH
from konabot.common.web_render import WebRenderer
from konabot.plugins.air_conditioner.ac import AirConditioner, CrashType, generate_ac_image, wiggle_transform
from pathlib import Path
import random
import math
def get_ac(id: str) -> AirConditioner:
ac = AirConditioner.air_conditioners.get(id)
ROOT_PATH = Path(__file__).resolve().parent
# 创建全局数据库管理器实例
db_manager = DatabaseManager()
async def get_ac(id: str) -> AirConditioner:
ac = await AirConditioner.get_ac(id)
if ac is None:
ac = AirConditioner(id)
return ac
@ -43,14 +50,32 @@ async def send_ac_image(event: type[AlconnaMatcher], ac: AirConditioner):
ac_image = await generate_ac_image(ac)
await event.send(await UniMessage().image(raw=ac_image).export())
driver = nonebot.get_driver()
@driver.on_startup
async def register_startup_hook():
"""注册启动时需要执行的函数"""
# 初始化数据库表
await db_manager.execute_by_sql_file(
Path(__file__).resolve().parent / "sql" / "create_table.sql"
)
@driver.on_shutdown
async def register_shutdown_hook():
"""注册关闭时需要执行的函数"""
# 关闭所有数据库连接
await db_manager.close_all_connections()
evt = on_alconna(Alconna(
"群空调"
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
@evt.handle()
async def _(event: BaseEvent, target: DepLongTaskTarget):
async def _(target: DepLongTaskTarget):
id = target.channel_id
ac = get_ac(id)
ac = await get_ac(id)
await send_ac_image(evt, ac)
evt = on_alconna(Alconna(
@ -58,10 +83,10 @@ evt = on_alconna(Alconna(
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
@evt.handle()
async def _(event: BaseEvent, target: DepLongTaskTarget):
async def _(target: DepLongTaskTarget):
id = target.channel_id
ac = get_ac(id)
ac.on = True
ac = await get_ac(id)
await ac.update_ac(state=True)
await send_ac_image(evt, ac)
evt = on_alconna(Alconna(
@ -69,10 +94,10 @@ evt = on_alconna(Alconna(
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
@evt.handle()
async def _(event: BaseEvent, target: DepLongTaskTarget):
async def _(target: DepLongTaskTarget):
id = target.channel_id
ac = get_ac(id)
ac.on = False
ac = await get_ac(id)
await ac.update_ac(state=False)
await send_ac_image(evt, ac)
evt = on_alconna(Alconna(
@ -81,15 +106,17 @@ evt = on_alconna(Alconna(
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
@evt.handle()
async def _(event: BaseEvent, target: DepLongTaskTarget, temp: Optional[Union[int, float]] = 1):
async def _(target: DepLongTaskTarget, temp: Optional[Union[int, float]] = 1):
if temp is None:
temp = 1
if temp <= 0:
return
id = target.channel_id
ac = get_ac(id)
ac = await get_ac(id)
if not ac.on or ac.burnt == True or ac.frozen == True:
await send_ac_image(evt, ac)
return
ac.temperature += temp
await ac.update_ac(temperature_delta=temp)
if ac.temperature > 40:
# 根据温度随机出是否爆炸40度开始呈指数增长
possibility = -math.e ** ((40-ac.temperature) / 50) + 1
@ -103,7 +130,7 @@ async def _(event: BaseEvent, target: DepLongTaskTarget, temp: Optional[Union[in
pil_frames[0].save(output, format="GIF", save_all=True, append_images=pil_frames[1:], loop=0, duration=35, disposal=2)
output.seek(0)
await evt.send(await UniMessage().image(raw=output).export())
ac.broke_ac(CrashType.BURNT)
await ac.broke_ac(CrashType.BURNT)
await evt.send("太热啦,空调炸了!")
return
await send_ac_image(evt, ac)
@ -114,20 +141,22 @@ evt = on_alconna(Alconna(
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
@evt.handle()
async def _(event: BaseEvent, target: DepLongTaskTarget, temp: Optional[Union[int, float]] = 1):
async def _(target: DepLongTaskTarget, temp: Optional[Union[int, float]] = 1):
if temp is None:
temp = 1
if temp <= 0:
return
id = target.channel_id
ac = get_ac(id)
ac = await get_ac(id)
if not ac.on or ac.burnt == True or ac.frozen == True:
await send_ac_image(evt, ac)
return
ac.temperature -= temp
await ac.update_ac(temperature_delta=-temp)
if ac.temperature < 0:
# 根据温度随机出是否冻结0度开始呈指数增长
possibility = -math.e ** (ac.temperature / 50) + 1
if random.random() < possibility:
ac.broke_ac(CrashType.FROZEN)
await ac.broke_ac(CrashType.FROZEN)
await send_ac_image(evt, ac)
evt = on_alconna(Alconna(
@ -135,21 +164,34 @@ evt = on_alconna(Alconna(
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
@evt.handle()
async def _(event: BaseEvent, target: DepLongTaskTarget):
async def _(target: DepLongTaskTarget):
id = target.channel_id
ac = get_ac(id)
ac.change_ac()
ac = await get_ac(id)
await ac.change_ac()
await send_ac_image(evt, ac)
async def query_number_ranking(id: str) -> tuple[int, int]:
result = await db_manager.query_by_sql_file(
ROOT_PATH / "sql" / "query_crash_and_rank.sql",
(id,id)
)
if len(result) == 0:
return 0, 0
else:
# 将字典转换为值的元组
values = list(result[0].values())
return values[0], values[1]
evt = on_alconna(Alconna(
"空调炸炸排行榜",
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
@evt.handle()
async def _(event: BaseEvent, target: DepLongTaskTarget):
async def _(target: DepLongTaskTarget):
id = target.channel_id
ac = get_ac(id)
number, ranking = ac.get_crashes_and_ranking()
# ac = get_ac(id)
# number, ranking = ac.get_crashes_and_ranking()
number, ranking = await query_number_ranking(id)
params = {
"number": number,
"ranking": ranking
@ -159,4 +201,4 @@ async def _(event: BaseEvent, target: DepLongTaskTarget):
target=".box",
params=params
)
await evt.send(await UniMessage().image(raw=image).export())
await evt.send(await UniMessage().image(raw=image).export())

View File

@ -1,20 +1,74 @@
from enum import Enum
from io import BytesIO
from pathlib import Path
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from konabot.common.database import DatabaseManager
from konabot.common.path import ASSETS_PATH, FONTS_PATH
from konabot.common.path import DATA_PATH
import json
ROOT_PATH = Path(__file__).resolve().parent
# 创建全局数据库管理器实例
db_manager = DatabaseManager()
class CrashType(Enum):
BURNT = 0
FROZEN = 1
class AirConditioner:
air_conditioners: dict[str, "AirConditioner"] = {}
@classmethod
async def get_ac(cls, id: str) -> 'AirConditioner':
result = await db_manager.query_by_sql_file(ROOT_PATH / "sql" / "query_ac.sql", (id,))
if len(result) == 0:
ac = await cls.create_ac(id)
return ac
ac_data = result[0]
ac = AirConditioner(id)
ac.on = bool(ac_data["on"])
ac.temperature = float(ac_data["temperature"])
ac.burnt = bool(ac_data["burnt"])
ac.frozen = bool(ac_data["frozen"])
return ac
@classmethod
async def create_ac(cls, id: str) -> 'AirConditioner':
ac = AirConditioner(id)
await db_manager.execute_by_sql_file(
ROOT_PATH / "sql" / "insert_ac.sql",
(id, ac.on, ac.temperature, ac.burnt, ac.frozen)
)
return ac
async def update_ac(self, state: bool = None, temperature_delta: float = None, burnt: bool = None, frozen: bool = None) -> 'AirConditioner':
if state is not None:
self.on = state
if temperature_delta is not None:
self.temperature += temperature_delta
if burnt is not None:
self.burnt = burnt
if frozen is not None:
self.frozen = frozen
await db_manager.execute_by_sql_file(
ROOT_PATH / "sql" / "update_ac.sql",
(self.on, self.temperature, self.burnt, self.frozen, self.id)
)
return self
async def change_ac(self) -> 'AirConditioner':
self.on = False
self.temperature = 24
self.burnt = False
self.frozen = False
await db_manager.execute_by_sql_file(
ROOT_PATH / "sql" / "update_ac.sql",
(self.on, self.temperature, self.burnt, self.frozen, self.id)
)
return self
def __init__(self, id: str) -> None:
self.id = id
@ -22,45 +76,40 @@ class AirConditioner:
self.temperature = 24 # 默认温度
self.burnt = False
self.frozen = False
AirConditioner.air_conditioners[id] = self
def change_ac(self):
self.burnt = False
self.frozen = False
self.on = False
self.temperature = 24 # 重置为默认温度
def broke_ac(self, crash_type: CrashType):
async def broke_ac(self, crash_type: CrashType):
'''
让空调坏掉,并保存数据
让空调坏掉
:param crash_type: CrashType 枚举,表示空调坏掉的类型
'''
match crash_type:
case CrashType.BURNT:
self.burnt = True
await self.update_ac(burnt=True)
case CrashType.FROZEN:
self.frozen = True
self.save_crash_data(crash_type)
await self.update_ac(frozen=True)
await db_manager.execute_by_sql_file(
ROOT_PATH / "sql" / "insert_crash.sql",
(self.id, crash_type.value)
)
def save_crash_data(self, crash_type: CrashType):
'''
如果空调爆炸了,就往本地的 ac_crash_data.json 里该 id 的记录加一
'''
data_file = DATA_PATH / "ac_crash_data.json"
crash_data = {}
if data_file.exists():
with open(data_file, "r", encoding="utf-8") as f:
crash_data = json.load(f)
if self.id not in crash_data:
crash_data[self.id] = {"burnt": 0, "frozen": 0}
match crash_type:
case CrashType.BURNT:
crash_data[self.id]["burnt"] += 1
case CrashType.FROZEN:
crash_data[self.id]["frozen"] += 1
with open(data_file, "w", encoding="utf-8") as f:
json.dump(crash_data, f, ensure_ascii=False, indent=4)
# def save_crash_data(self, crash_type: CrashType):
# '''
# 如果空调爆炸了,就往本地的 ac_crash_data.json 里该 id 的记录加一
# '''
# data_file = DATA_PATH / "ac_crash_data.json"
# crash_data = {}
# if data_file.exists():
# with open(data_file, "r", encoding="utf-8") as f:
# crash_data = json.load(f)
# if self.id not in crash_data:
# crash_data[self.id] = {"burnt": 0, "frozen": 0}
# match crash_type:
# case CrashType.BURNT:
# crash_data[self.id]["burnt"] += 1
# case CrashType.FROZEN:
# crash_data[self.id]["frozen"] += 1
# with open(data_file, "w", encoding="utf-8") as f:
# json.dump(crash_data, f, ensure_ascii=False, indent=4)
def get_crashes_and_ranking(self) -> tuple[int, int]:
'''

View File

@ -0,0 +1,15 @@
-- 创建所有表
CREATE TABLE IF NOT EXISTS air_conditioner (
id VARCHAR(128) PRIMARY KEY,
"on" BOOLEAN NOT NULL,
temperature REAL NOT NULL,
burnt BOOLEAN NOT NULL,
frozen BOOLEAN NOT NULL
);
CREATE TABLE IF NOT EXISTS air_conditioner_crash_log (
id VARCHAR(128) NOT NULL,
crash_type INT NOT NULL,
timestamp DATETIME NOT NULL,
FOREIGN KEY (id) REFERENCES air_conditioner(id)
);

View File

@ -0,0 +1,3 @@
-- 插入一台新空调
INSERT INTO air_conditioner (id, "on", temperature, burnt, frozen)
VALUES (?, ?, ?, ?, ?);

View File

@ -0,0 +1,3 @@
-- 插入一条空调爆炸记录
INSERT INTO air_conditioner_crash_log (id, crash_type, timestamp)
VALUES (?, ?, CURRENT_TIMESTAMP);

View File

@ -0,0 +1,4 @@
-- 查询空调状态,如果没有就插入一条新的记录
SELECT *
FROM air_conditioner
WHERE id = ?;

View File

@ -0,0 +1,23 @@
-- 从 air_conditioner_crash_log 表中获取指定 id 损坏的次数以及损坏次数的排名
SELECT crash_count, crash_rank
FROM (
SELECT id,
COUNT(*) AS crash_count,
RANK() OVER (ORDER BY COUNT(*) DESC) AS crash_rank
FROM air_conditioner_crash_log
GROUP BY id
) AS ranked_data
WHERE id = ?
-- 如果该 id 没有损坏记录,则返回 0 次损坏和对应的最后一名
UNION
SELECT 0 AS crash_count,
(SELECT COUNT(DISTINCT id) + 1 FROM air_conditioner_crash_log) AS crash_rank
FROM (
SELECT DISTINCT id
FROM air_conditioner_crash_log
) AS ranked_data
WHERE NOT EXISTS (
SELECT 1
FROM air_conditioner_crash_log
WHERE id = ?
);

View File

@ -0,0 +1,4 @@
-- 更新空调状态
UPDATE air_conditioner
SET "on" = ?, temperature = ?, burnt = ?, frozen = ?
WHERE id = ?;

View File

@ -8,6 +8,7 @@ from typing import Optional
from loguru import logger
from nonebot import on_message
import nonebot
from nonebot.adapters import Event as BaseEvent
from nonebot.adapters.discord.event import MessageEvent as DiscordMessageEvent
from nonebot_plugin_alconna import (
@ -18,17 +19,23 @@ from nonebot_plugin_alconna import (
on_alconna,
)
from konabot.common.database import DatabaseManager
from konabot.common.longtask import DepLongTaskTarget
from konabot.common.path import ASSETS_PATH
from konabot.common.llm import get_llm
ROOT_PATH = Path(__file__).resolve().parent
DATA_DIR = Path(__file__).parent.parent.parent.parent / "data"
DATA_FILE_PATH = (
DATA_DIR / "idiom_banned.json"
)
# 创建全局数据库管理器实例
db_manager = DatabaseManager()
def load_banned_ids() -> list[str]:
if not DATA_FILE_PATH.exists():
return []
@ -58,6 +65,21 @@ def remove_banned_id(group_id: str):
DATA_FILE_PATH.write_text(json.dumps(banned_ids, ensure_ascii=False, indent=4), "utf-8")
driver = nonebot.get_driver()
@driver.on_startup
async def register_startup_hook():
"""注册启动时需要执行的函数"""
await IdiomGame.init_lexicon()
@driver.on_shutdown
async def register_shutdown_hook():
"""注册关闭时需要执行的函数"""
# 关闭所有数据库连接
await db_manager.close_all_connections()
class TryStartState(Enum):
STARTED = 0
ALREADY_PLAYING = 1
@ -94,18 +116,19 @@ class IdiomGameLLM:
@classmethod
async def storage_idiom(cls, idiom: str):
# 将 idiom 存入本地文件以备后续分析
with open(DATA_DIR / "idiom_llm_storage.txt", "a", encoding="utf-8") as f:
f.write(idiom + "\n")
IdiomGame.append_into_word_list(idiom)
# 将 idiom 存入数据库
await db_manager.execute_by_sql_file(
ROOT_PATH / "sql" / "insert_custom_word.sql",
(idiom,)
)
class IdiomGame:
ALL_WORDS = [] # 所有四字词语
ALL_IDIOMS = [] # 所有成语
# ALL_WORDS = [] # 所有四字词语
# ALL_IDIOMS = [] # 所有成语
INSTANCE_LIST: dict[str, "IdiomGame"] = {} # 群号对应的游戏实例
IDIOM_FIRST_CHAR = {} # 所有成语包括词语的首字字典
AVALIABLE_IDIOM_FIRST_CHAR = {} # 真正有效的成语首字字典
# IDIOM_FIRST_CHAR = {} # 所有成语包括词语的首字字典
# AVALIABLE_IDIOM_FIRST_CHAR = {} # 真正有效的成语首字字典
__inited = False
@ -126,15 +149,14 @@ class IdiomGame:
IdiomGame.INSTANCE_LIST[group_id] = self
@classmethod
def append_into_word_list(cls, word: str):
async def append_into_word_list(cls, word: str):
'''
将一个新词加入到词语列表中
'''
if word not in cls.ALL_WORDS:
cls.ALL_WORDS.append(word)
if word[0] not in cls.IDIOM_FIRST_CHAR:
cls.IDIOM_FIRST_CHAR[word[0]] = []
cls.IDIOM_FIRST_CHAR[word[0]].append(word)
await db_manager.execute_by_sql_file(
ROOT_PATH / "sql" / "insert_custom_word.sql",
(word,)
)
def be_able_to_play(self) -> bool:
if self.last_play_date != datetime.date.today():
@ -145,21 +167,28 @@ class IdiomGame:
return True
return False
def choose_start_idiom(self) -> str:
@staticmethod
async def random_idiom() -> str:
result = await db_manager.query_by_sql_file(
ROOT_PATH / "sql" / "random_choose_idiom.sql"
)
return result[0]["idiom"]
async def choose_start_idiom(self) -> str:
"""
随机选择一个成语作为起始成语
"""
self.last_idiom = secrets.choice(IdiomGame.ALL_IDIOMS)
self.last_idiom = await IdiomGame.random_idiom()
self.last_char = self.last_idiom[-1]
if not self.is_nextable(self.last_char):
self.choose_start_idiom()
if not await self.is_nextable(self.last_char):
await self.choose_start_idiom()
else:
self.add_history_idiom(self.last_idiom, new_chain=True)
return self.last_idiom
@classmethod
def try_start_game(cls, group_id: str, force: bool = False) -> TryStartState:
cls.init_lexicon()
async def try_start_game(cls, group_id: str, force: bool = False) -> TryStartState:
await cls.init_lexicon()
if not cls.INSTANCE_LIST.get(group_id):
cls(group_id)
instance = cls.INSTANCE_LIST[group_id]
@ -170,10 +199,10 @@ class IdiomGame:
instance.now_playing = True
return TryStartState.STARTED
def start_game(self, rounds: int = 100):
async def start_game(self, rounds: int = 100):
self.now_playing = True
self.remain_rounds = rounds
self.choose_start_idiom()
await self.choose_start_idiom()
@classmethod
def try_stop_game(cls, group_id: str) -> TryStopState:
@ -203,20 +232,20 @@ class IdiomGame:
跳过当前成语,选择下一个成语
"""
async with self.lock:
self._skip_idiom_async()
await self._skip_idiom_async()
self.add_buff_score(buff_score)
return self.last_idiom
def _skip_idiom_async(self) -> str:
self.last_idiom = secrets.choice(IdiomGame.ALL_IDIOMS)
async def _skip_idiom_async(self) -> str:
self.last_idiom = await IdiomGame.random_idiom()
self.last_char = self.last_idiom[-1]
if not self.is_nextable(self.last_char):
self._skip_idiom_async()
if not await self.is_nextable(self.last_char):
await self._skip_idiom_async()
else:
self.add_history_idiom(self.last_idiom, new_chain=True)
return self.last_idiom
async def try_verify_idiom(self, idiom: str, user_id: str) -> TryVerifyState:
async def try_verify_idiom(self, idiom: str, user_id: str) -> list[TryVerifyState]:
"""
用户发送成语
"""
@ -224,12 +253,16 @@ class IdiomGame:
state = await self._verify_idiom(idiom, user_id)
return state
def is_nextable(self, last_char: str) -> bool:
async def is_nextable(self, last_char: str) -> bool:
"""
判断是否有成语可以接
"""
return last_char in IdiomGame.AVALIABLE_IDIOM_FIRST_CHAR
result = await db_manager.query_by_sql_file(
ROOT_PATH / "sql" / "is_nextable.sql",
(last_char,)
)
return result[0]["DEED"] == 1
def add_already_idiom(self, idiom: str):
if idiom in self.already_idioms:
self.already_idioms[idiom] += 1
@ -259,7 +292,13 @@ class IdiomGame:
if idiom[0] != self.last_char:
state.append(TryVerifyState.WRONG_FIRST_CHAR)
return state
if idiom not in IdiomGame.ALL_IDIOMS and idiom not in IdiomGame.ALL_WORDS:
# 成语是否存在
result = await db_manager.query_by_sql_file(
ROOT_PATH / "sql" / "query_idiom.sql",
(idiom, idiom, idiom)
)
status_result = result[0]["status"]
if status_result == -1:
logger.info(f"用户 {user_id} 发送了未知词语 {idiom},正在使用 LLM 进行验证")
try:
if not await IdiomGameLLM.verify_idiom_with_llm(idiom):
@ -281,16 +320,16 @@ class IdiomGame:
self.last_idiom = idiom
self.last_char = idiom[-1]
self.add_score(user_id, 1 * score_k) # 先加 1 分
if idiom in IdiomGame.ALL_IDIOMS:
if status_result == 1:
state.append(TryVerifyState.VERIFIED_AND_REAL)
self.add_score(user_id, 4 * score_k) # 再加 4 分
self.remain_rounds -= 1
if self.remain_rounds <= 0:
self.now_playing = False
state.append(TryVerifyState.GAME_END)
if not self.is_nextable(self.last_char):
if not await self.is_nextable(self.last_char):
# 没有成语可以接了,自动跳过
self._skip_idiom_async()
await self._skip_idiom_async()
self.add_buff_score(-100)
state.append(TryVerifyState.BUT_NO_NEXT)
return state
@ -317,16 +356,23 @@ class IdiomGame:
return self.last_char
@classmethod
def random_idiom_starting_with(cls, first_char: str) -> Optional[str]:
cls.init_lexicon()
if first_char not in cls.AVALIABLE_IDIOM_FIRST_CHAR:
async def random_idiom_starting_with(cls, first_char: str) -> Optional[str]:
await cls.init_lexicon()
result = await db_manager.query_by_sql_file(
ROOT_PATH / "sql" / "query_idiom_start_with.sql",
(first_char,)
)
if len(result) == 0:
return None
return secrets.choice(cls.AVALIABLE_IDIOM_FIRST_CHAR[first_char])
return result[0]["idiom"]
@classmethod
def init_lexicon(cls):
async def init_lexicon(cls):
if cls.__inited:
return
await db_manager.execute_by_sql_file(
ROOT_PATH / "sql" / "create_table.sql"
) # 确保数据库初始化
cls.__inited = True
# 成语大表
@ -334,11 +380,12 @@ class IdiomGame:
ALL_IDIOMS_INFOS = json.load(f)
# 词语大表
ALL_WORDS = []
with open(ASSETS_PATH / "lexicon" / "ci.json", "r", encoding="utf-8") as f:
jsonData = json.load(f)
cls.ALL_WORDS = [item["ci"] for item in jsonData]
logger.debug(f"Loaded {len(cls.ALL_WORDS)} words from ci.json")
logger.debug(f"Sample words: {cls.ALL_WORDS[:5]}")
ALL_WORDS = [item["ci"] for item in jsonData]
logger.debug(f"Loaded {len(ALL_WORDS)} words from ci.json")
logger.debug(f"Sample words: {ALL_WORDS[:5]}")
COMMON_WORDS = []
# 读取 COMMON 词语大表
@ -389,29 +436,44 @@ class IdiomGame:
logger.debug(f"Loaded additional {len(LOCAL_LLM_WORDS)} words from idiom_llm_storage.txt")
# 只有成语的大表
cls.ALL_IDIOMS = [idiom["word"] for idiom in ALL_IDIOMS_INFOS] + THUOCL_IDIOMS
cls.ALL_IDIOMS = list(set(cls.ALL_IDIOMS)) # 去重
ALL_IDIOMS = [idiom["word"] for idiom in ALL_IDIOMS_INFOS] + THUOCL_IDIOMS
ALL_IDIOMS = list(set(ALL_IDIOMS)) # 去重
# 批量插入数据库
await db_manager.execute_many_values_by_sql_file(
ROOT_PATH / "sql" / "insert_idiom.sql",
[(idiom,) for idiom in ALL_IDIOMS]
)
# 其他四字词语表,仅表示可以有这个词
cls.ALL_WORDS = (
[word for word in cls.ALL_WORDS if len(word) == 4]
ALL_WORDS = (
[word for word in ALL_WORDS if len(word) == 4]
+ THUOCL_WORDS
+ COMMON_WORDS
+ LOCAL_LLM_WORDS
)
cls.ALL_WORDS = list(set(cls.ALL_WORDS)) # 去重
# 插入数据库
await db_manager.execute_many_values_by_sql_file(
ROOT_PATH / "sql" / "insert_word.sql",
[(word,) for word in ALL_WORDS]
)
# 根据成语大表,划分出成语首字字典
for idiom in cls.ALL_IDIOMS + cls.ALL_WORDS:
if idiom[0] not in cls.IDIOM_FIRST_CHAR:
cls.IDIOM_FIRST_CHAR[idiom[0]] = []
cls.IDIOM_FIRST_CHAR[idiom[0]].append(idiom)
# 自定义词语 LOCAL_LLM_WORDS 插入数据库,兼容用
await db_manager.execute_many_values_by_sql_file(
ROOT_PATH / "sql" / "insert_custom_word.sql",
[(word,) for word in LOCAL_LLM_WORDS]
)
# 根据真正的成语大表,划分出有效成语首字字典
for idiom in cls.ALL_IDIOMS:
if idiom[0] not in cls.AVALIABLE_IDIOM_FIRST_CHAR:
cls.AVALIABLE_IDIOM_FIRST_CHAR[idiom[0]] = []
cls.AVALIABLE_IDIOM_FIRST_CHAR[idiom[0]].append(idiom)
# # 根据成语大表,划分出成语首字字典
# for idiom in cls.ALL_IDIOMS + cls.ALL_WORDS:
# if idiom[0] not in cls.IDIOM_FIRST_CHAR:
# cls.IDIOM_FIRST_CHAR[idiom[0]] = []
# cls.IDIOM_FIRST_CHAR[idiom[0]].append(idiom)
# # 根据真正的成语大表,划分出有效成语首字字典
# for idiom in cls.ALL_IDIOMS:
# if idiom[0] not in cls.AVALIABLE_IDIOM_FIRST_CHAR:
# cls.AVALIABLE_IDIOM_FIRST_CHAR[idiom[0]] = []
# cls.AVALIABLE_IDIOM_FIRST_CHAR[idiom[0]].append(idiom)
evt = on_alconna(
@ -443,7 +505,7 @@ async def play_game(
if rounds <= 0:
await evt.send(await UniMessage().text("干什么!你想玩负数局吗?").export())
return
state = IdiomGame.try_start_game(group_id, force)
state = await IdiomGame.try_start_game(group_id, force)
if state == TryStartState.ALREADY_PLAYING:
await evt.send(
await UniMessage()
@ -462,7 +524,7 @@ async def play_game(
.export()
)
instance = IdiomGame.INSTANCE_LIST[group_id]
instance.start_game(rounds)
await instance.start_game(rounds)
# 发布成语
await evt.send(
await UniMessage()
@ -514,7 +576,9 @@ async def end_game(event: BaseEvent, group_id: str):
for line in history_lines:
result_text += line + "\n"
await evt.send(await result_text.export())
instance.clear_score_board()
# instance.clear_score_board()
# 将实例删除
del IdiomGame.INSTANCE_LIST[group_id]
evt = on_alconna(
@ -553,7 +617,7 @@ async def _(target: DepLongTaskTarget):
instance = IdiomGame.INSTANCE_LIST.get(group_id)
if not instance or not instance.get_playing_state():
return
avaliable_idiom = IdiomGame.random_idiom_starting_with(instance.get_last_char())
avaliable_idiom = await IdiomGame.random_idiom_starting_with(instance.get_last_char())
# 发送哈哈狗图片
with open(ASSETS_PATH / "img" / "dog" / "haha_dog.jpg", "rb") as f:
img_data = f.read()

View File

@ -0,0 +1,15 @@
-- 创建成语大表
CREATE TABLE IF NOT EXISTS all_idioms (
id INTEGER PRIMARY KEY AUTOINCREMENT,
idiom VARCHAR(128) NOT NULL UNIQUE
);
CREATE TABLE IF NOT EXISTS all_words (
id INTEGER PRIMARY KEY AUTOINCREMENT,
word VARCHAR(128) NOT NULL UNIQUE
);
CREATE TABLE IF NOT EXISTS custom_words (
id INTEGER PRIMARY KEY AUTOINCREMENT,
word VARCHAR(128) NOT NULL UNIQUE
);

View File

@ -0,0 +1,3 @@
-- 插入自定义词
INSERT OR IGNORE INTO custom_words (word)
VALUES (?);

View File

@ -0,0 +1,3 @@
-- 插入成语大表,避免重复插入
INSERT OR IGNORE INTO all_idioms (idiom)
VALUES (?);

View File

@ -0,0 +1,3 @@
-- 插入词
INSERT OR IGNORE INTO all_words (word)
VALUES (?);

View File

@ -0,0 +1,5 @@
-- 查询是否有以 xx 开头的成语,有则返回真,否则假
SELECT EXISTS(
SELECT 1 FROM all_idioms
WHERE idiom LIKE ? || '%'
) AS DEED;

View File

@ -0,0 +1,7 @@
-- 查询成语是否在 all_idioms 中,如果存在则返回 1否则再判断是否在 custom_words 或 all_words 中,存在则返回 0否则返回 -1
SELECT
CASE
WHEN EXISTS (SELECT 1 FROM all_idioms WHERE idiom = ?) THEN 1
WHEN EXISTS (SELECT 1 FROM custom_words WHERE word = ?) OR EXISTS (SELECT 1 FROM all_words WHERE word = ?) THEN 0
ELSE -1
END AS status;

View File

@ -0,0 +1,4 @@
-- 查询以 xx 开头的成语,随机打乱后只取第一个
SELECT idiom FROM all_idioms
WHERE idiom LIKE ? || '%'
ORDER BY RANDOM() LIMIT 1;

View File

@ -0,0 +1,2 @@
-- 随机从 all_idioms 表中选择一个成语
SELECT idiom FROM all_idioms ORDER BY RANDOM() LIMIT 1;

163
poetry.lock generated
View File

@ -209,6 +209,30 @@ type = "legacy"
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
reference = "mirrors"
[[package]]
name = "aiosqlite"
version = "0.21.0"
description = "asyncio bridge to the standard sqlite3 module"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0"},
{file = "aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3"},
]
[package.dependencies]
typing_extensions = ">=4.0"
[package.extras]
dev = ["attribution (==1.7.1)", "black (==24.3.0)", "build (>=1.2)", "coverage[toml] (==7.6.10)", "flake8 (==7.0.0)", "flake8-bugbear (==24.12.12)", "flit (==3.10.1)", "mypy (==1.14.1)", "ufmt (==2.5.1)", "usort (==1.0.8.post1)"]
docs = ["sphinx (==8.1.3)", "sphinx-mdinclude (==0.6.1)"]
[package.source]
type = "legacy"
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
reference = "mirrors"
[[package]]
name = "annotated-doc"
version = "0.0.3"
@ -946,12 +970,12 @@ version = "0.4.6"
description = "Cross-platform colored terminal text."
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
groups = ["main"]
markers = "sys_platform == \"win32\" or platform_system == \"Windows\""
groups = ["main", "dev"]
files = [
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]
markers = {main = "sys_platform == \"win32\" or platform_system == \"Windows\"", dev = "sys_platform == \"win32\""}
[package.source]
type = "legacy"
@ -1568,6 +1592,23 @@ type = "legacy"
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
reference = "mirrors"
[[package]]
name = "iniconfig"
version = "2.3.0"
description = "brain-dead simple config-ini parsing"
optional = false
python-versions = ">=3.10"
groups = ["dev"]
files = [
{file = "iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12"},
{file = "iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730"},
]
[package.source]
type = "legacy"
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
reference = "mirrors"
[[package]]
name = "jiter"
version = "0.11.1"
@ -2679,6 +2720,23 @@ type = "legacy"
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
reference = "mirrors"
[[package]]
name = "packaging"
version = "25.0"
description = "Core utilities for Python packages"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
files = [
{file = "packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484"},
{file = "packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f"},
]
[package.source]
type = "legacy"
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
reference = "mirrors"
[[package]]
name = "pillow"
version = "11.3.0"
@ -2858,6 +2916,27 @@ type = "legacy"
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
reference = "mirrors"
[[package]]
name = "pluggy"
version = "1.6.0"
description = "plugin and hook calling mechanisms for python"
optional = false
python-versions = ">=3.9"
groups = ["dev"]
files = [
{file = "pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746"},
{file = "pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3"},
]
[package.extras]
dev = ["pre-commit", "tox"]
testing = ["coverage", "pytest", "pytest-benchmark"]
[package.source]
type = "legacy"
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
reference = "mirrors"
[[package]]
name = "propcache"
version = "0.4.1"
@ -3344,7 +3423,7 @@ version = "2.19.2"
description = "Pygments is a syntax highlighting package written in Python."
optional = false
python-versions = ">=3.8"
groups = ["main"]
groups = ["main", "dev"]
files = [
{file = "pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b"},
{file = "pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887"},
@ -3375,6 +3454,58 @@ type = "legacy"
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
reference = "mirrors"
[[package]]
name = "pytest"
version = "9.0.1"
description = "pytest: simple powerful testing with Python"
optional = false
python-versions = ">=3.10"
groups = ["dev"]
files = [
{file = "pytest-9.0.1-py3-none-any.whl", hash = "sha256:67be0030d194df2dfa7b556f2e56fb3c3315bd5c8822c6951162b92b32ce7dad"},
{file = "pytest-9.0.1.tar.gz", hash = "sha256:3e9c069ea73583e255c3b21cf46b8d3c56f6e3a1a8f6da94ccb0fcf57b9d73c8"},
]
[package.dependencies]
colorama = {version = ">=0.4", markers = "sys_platform == \"win32\""}
iniconfig = ">=1.0.1"
packaging = ">=22"
pluggy = ">=1.5,<2"
pygments = ">=2.7.2"
[package.extras]
dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests", "setuptools", "xmlschema"]
[package.source]
type = "legacy"
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
reference = "mirrors"
[[package]]
name = "pytest-asyncio"
version = "1.3.0"
description = "Pytest support for asyncio"
optional = false
python-versions = ">=3.10"
groups = ["dev"]
files = [
{file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"},
{file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"},
]
[package.dependencies]
pytest = ">=8.2,<10"
typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""}
[package.extras]
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"]
testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
[package.source]
type = "legacy"
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
reference = "mirrors"
[[package]]
name = "python-dotenv"
version = "1.2.1"
@ -3699,6 +3830,27 @@ type = "legacy"
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
reference = "mirrors"
[[package]]
name = "sqlparse"
version = "0.5.3"
description = "A non-validating SQL parser."
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "sqlparse-0.5.3-py3-none-any.whl", hash = "sha256:cf2196ed3418f3ba5de6af7e82c694a9fbdbfecccdfc72e281548517081f16ca"},
{file = "sqlparse-0.5.3.tar.gz", hash = "sha256:09f67787f56a0b16ecdbde1bfc7f5d9c3371ca683cfeaa8e6ff60b4807ec9272"},
]
[package.extras]
dev = ["build", "hatch"]
doc = ["sphinx"]
[package.source]
type = "legacy"
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
reference = "mirrors"
[[package]]
name = "starlette"
version = "0.49.3"
@ -3902,11 +4054,12 @@ version = "4.15.0"
description = "Backported and Experimental Type Hints for Python 3.9+"
optional = false
python-versions = ">=3.9"
groups = ["main"]
groups = ["main", "dev"]
files = [
{file = "typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548"},
{file = "typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466"},
]
markers = {dev = "python_version == \"3.12\""}
[package.source]
type = "legacy"
@ -4528,4 +4681,4 @@ reference = "mirrors"
[metadata]
lock-version = "2.1"
python-versions = ">=3.12,<4.0"
content-hash = "478bd59d60d3b73397241c6ed552434486bd26d56cc3805ef34d1cfa1be7006e"
content-hash = "2c341fdc0d5b29ad3b24516c46e036b2eff4c11e244047d114971039255c2ac4"

View File

@ -27,6 +27,8 @@ dependencies = [
"playwright (>=1.55.0,<2.0.0)",
"openai (>=2.7.1,<3.0.0)",
"imageio (>=2.37.2,<3.0.0)",
"aiosqlite (>=0.20.0,<1.0.0)",
"sqlparse (>=0.5.0,<1.0.0)",
]
[tool.poetry]
@ -46,5 +48,7 @@ priority = "primary"
[dependency-groups]
dev = [
"rust-just (>=1.43.0,<2.0.0)"
"rust-just (>=1.43.0,<2.0.0)",
"pytest (>=9.0.1,<10.0.0)",
"pytest-asyncio (>=1.3.0,<2.0.0)"
]

View File

@ -22,3 +22,11 @@ logger.info(f"已经加载的插件数量 {len(plugins)}")
logger.info(f"期待加载的插件数量 {len_requires}")
assert len(plugins) == len_requires
# 测试数据库模块是否可以正确导入
try:
from konabot.common.database import DatabaseManager
logger.info("数据库模块导入成功")
except Exception as e:
logger.error(f"数据库模块导入失败: {e}")
raise

93
tests/test_database.py Normal file
View File

@ -0,0 +1,93 @@
import asyncio
import os
import tempfile
from pathlib import Path
import pytest
from konabot.common.database import DatabaseManager
@pytest.mark.asyncio
async def test_database_manager():
"""测试数据库管理器的基本功能"""
# 创建临时数据库文件
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
db_path = tmp_file.name
try:
# 初始化数据库管理器
db_manager = DatabaseManager(db_path)
# 创建测试表
create_table_sql = """
CREATE TABLE IF NOT EXISTS test_users (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
email TEXT UNIQUE
);
"""
await db_manager.execute(create_table_sql)
# 插入测试数据
insert_sql = "INSERT INTO test_users (name, email) VALUES (?, ?)"
await db_manager.execute(insert_sql, ("张三", "zhangsan@example.com"))
await db_manager.execute(insert_sql, ("李四", "lisi@example.com"))
# 查询数据
select_sql = "SELECT * FROM test_users WHERE name = ?"
results = await db_manager.query(select_sql, ("张三",))
assert len(results) == 1
assert results[0]["name"] == "张三"
assert results[0]["email"] == "zhangsan@example.com"
# 测试使用Path对象
results = await db_manager.query_by_sql_file(Path(__file__), ("李四",))
# 注意这里只是测试参数传递实际SQL文件内容不是有效的SQL
# 关闭所有连接
await db_manager.close_all_connections()
finally:
# 清理临时文件
if os.path.exists(db_path):
os.unlink(db_path)
@pytest.mark.asyncio
async def test_execute_script():
"""测试执行SQL脚本功能"""
# 创建临时数据库文件
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
db_path = tmp_file.name
try:
# 初始化数据库管理器
db_manager = DatabaseManager(db_path)
# 创建测试表的脚本
script = """
CREATE TABLE IF NOT EXISTS test_products (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
price REAL
);
INSERT INTO test_products (name, price) VALUES ('苹果', 5.0);
INSERT INTO test_products (name, price) VALUES ('香蕉', 3.0);
"""
await db_manager.execute_script(script)
# 查询数据
results = await db_manager.query("SELECT * FROM test_products ORDER BY name")
assert len(results) == 2
assert results[0]["name"] == "苹果"
assert results[1]["name"] == "香蕉"
# 关闭所有连接
await db_manager.close_all_connections()
finally:
# 清理临时文件
if os.path.exists(db_path):
os.unlink(db_path)