diff --git a/.env.example b/.env.example index 7fde1d8..488632c 100644 --- a/.env.example +++ b/.env.example @@ -1,4 +1,4 @@ ENVIRONMENT=dev PORT=21333 - +DATABASE_PATH="./data/database.db" ENABLE_CONSOLE=true diff --git a/.gitignore b/.gitignore index 9f2daec..8bbe36e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ /.env /data -__pycache__ \ No newline at end of file +__pycache__ +/*.diff diff --git a/README.md b/README.md index 1437dcc..6e5cdc0 100644 --- a/README.md +++ b/README.md @@ -71,6 +71,10 @@ code . 详见[konabot-web 配置文档](/docs/konabot-web.md) +#### 数据库配置 + +本项目使用SQLite作为数据库,默认数据库文件位于`./data/database.db`。可以通过设置`DATABASE_PATH`环境变量来指定其他位置。 + ### 运行 使用命令行手动启动 Bot: @@ -91,3 +95,7 @@ poetry run python bot.py - [事件响应器](https://nonebot.dev/docs/tutorial/matcher) - [事件处理](https://nonebot.dev/docs/tutorial/handler) - [Alconna 插件](https://nonebot.dev/docs/best-practice/alconna/) + +## 数据库模块 + +本项目的数据库模块已更新为异步实现,使用连接池来提高性能,并支持现代的`pathlib.Path`参数类型。详细使用方法请参考[数据库使用文档](/docs/database.md)。 diff --git a/bot.py b/bot.py index 782c870..d0285d6 100644 --- a/bot.py +++ b/bot.py @@ -10,6 +10,8 @@ from nonebot.adapters.onebot.v11 import Adapter as OnebotAdapter from konabot.common.log import init_logger from konabot.common.nb.exc import BotExceptionMessage from konabot.common.path import LOG_PATH +from konabot.common.database import get_global_db_manager + dotenv.load_dotenv() env = os.environ.get("ENVIRONMENT", "prod") @@ -48,6 +50,13 @@ def main(): nonebot.load_plugins("konabot/plugins") nonebot.load_plugin("nonebot_plugin_analysis_bilibili") + # 注册关闭钩子 + @driver.on_shutdown + async def shutdown_handler(): + # 关闭全局数据库管理器 + db_manager = get_global_db_manager() + await db_manager.close_all_connections() + nonebot.run() if __name__ == "__main__": diff --git a/docs/database.md b/docs/database.md new file mode 100644 index 0000000..bfcdc63 --- /dev/null +++ b/docs/database.md @@ -0,0 +1,223 @@ +# 数据库系统使用文档 + +本文档详细介绍了本项目中使用的异步数据库系统,包括其架构设计、使用方法和最佳实践。 + +## 系统概述 + +本项目的数据库系统基于 `aiosqlite` 库构建,提供了异步的 SQLite 数据库访问接口。系统主要特性包括: + +1. **异步操作**:完全支持异步/await模式,适配NoneBot2框架 +2. **连接池**:内置连接池机制,提高数据库访问性能 +3. **参数化查询**:支持安全的参数化查询,防止SQL注入 +4. **SQL文件支持**:可以直接执行SQL文件中的脚本 +5. **类型支持**:支持 `pathlib.Path` 和 `str` 类型的路径参数 + +## 核心类和方法 + +### DatabaseManager 类 + +`DatabaseManager` 是数据库操作的核心类,提供了以下主要方法: + +#### 初始化 +```python +from konabot.common.database import DatabaseManager +from pathlib import Path + +# 使用默认数据库路径 +db = DatabaseManager() + +# 指定了义数据库路径 +db = DatabaseManager("./data/myapp.db") +db = DatabaseManager(Path("./data/myapp.db")) +``` + +#### 查询操作 +```python +# 执行查询语句并返回结果 +results = await db.query("SELECT * FROM users WHERE age > ?", (18,)) + +# 从SQL文件执行查询 +results = await db.query_by_sql_file("./sql/get_users.sql", (18,)) +``` + +#### 执行操作 +```python +# 执行非查询语句 +await db.execute("INSERT INTO users (name, email) VALUES (?, ?)", ("张三", "zhangsan@example.com")) + +# 执行SQL脚本(不带参数) +await db.execute_script(""" + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + email TEXT UNIQUE + ); + INSERT INTO users (name, email) VALUES ('测试用户', 'test@example.com'); +""") + +# 从SQL文件执行非查询语句 +await db.execute_by_sql_file("./sql/create_tables.sql") + +# 带参数执行SQL文件 +await db.execute_by_sql_file("./sql/insert_user.sql", ("张三", "zhangsan@example.com")) + +# 执行多条语句(每条语句使用相同参数) +await db.execute_many("INSERT INTO users (name, email) VALUES (?, ?)", [ + ("张三", "zhangsan@example.com"), + ("李四", "lisi@example.com"), + ("王五", "wangwu@example.com") +]) + +# 从SQL文件执行多条语句(每条语句使用相同参数) +await db.execute_many_values_by_sql_file("./sql/batch_insert.sql", [ + ("张三", "zhangsan@example.com"), + ("李四", "lisi@example.com") +]) +``` + +## SQL文件处理机制 + +### 单语句SQL文件 +```sql +-- insert_user.sql +INSERT INTO users (name, email) VALUES (?, ?); +``` + +```python +# 使用方式 +await db.execute_by_sql_file("./sql/insert_user.sql", ("张三", "zhangsan@example.com")) +``` + +### 多语句SQL文件 +```sql +-- setup.sql +CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + email TEXT UNIQUE +); + +CREATE TABLE IF NOT EXISTS profiles ( + user_id INTEGER, + age INTEGER, + FOREIGN KEY (user_id) REFERENCES users(id) +); +``` + +```python +# 使用方式 +await db.execute_by_sql_file("./sql/setup.sql") +``` + +### 多语句带不同参数的SQL文件 +```sql +-- batch_operations.sql +INSERT INTO users (name, email) VALUES (?, ?); +INSERT INTO profiles (user_id, age) VALUES (?, ?); +``` + +```python +# 使用方式 +await db.execute_by_sql_file("./sql/batch_operations.sql", [ + ("张三", "zhangsan@example.com"), # 第一条语句的参数 + (1, 25) # 第二条语句的参数 +]) +``` + +## 最佳实践 + +### 1. 数据库表设计 +```sql +-- 推荐的表设计实践 +CREATE TABLE IF NOT EXISTS example_table ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP +); +``` + +### 2. SQL文件组织 +建议按照功能模块组织SQL文件: +``` +plugin/ +├── sql/ +│ ├── create_tables.sql +│ ├── insert_data.sql +│ ├── update_data.sql +│ └── query_data.sql +└── __init__.py +``` + +### 3. 错误处理 +```python +try: + results = await db.query("SELECT * FROM users WHERE id = ?", (user_id,)) +except Exception as e: + logger.error(f"数据库查询失败: {e}") + # 处理错误情况 +``` + +### 4. 连接管理 +```python +# 在应用启动时初始化 +db_manager = DatabaseManager() + +# 在应用关闭时清理连接 +async def shutdown(): + await db_manager.close_all_connections() +``` + +## 高级特性 + +### 连接池配置 +```python +class DatabaseManager: + def __init__(self, db_path: Optional[Union[str, Path]] = None): + # 连接池大小配置 + self._pool_size = 5 # 可根据需要调整 +``` + +### 事务支持 +```python +# 通过execute方法的自动提交机制支持事务 +await db.execute("BEGIN TRANSACTION") +try: + await db.execute("INSERT INTO users (name) VALUES (?)", ("张三",)) + await db.execute("INSERT INTO profiles (user_id, age) VALUES (?, ?)", (1, 25)) + await db.execute("COMMIT") +except Exception: + await db.execute("ROLLBACK") + raise +``` + +## 注意事项 + +1. **异步环境**:所有数据库操作都必须在异步环境中执行 +2. **参数安全**:始终使用参数化查询,避免SQL注入 +3. **资源管理**:确保在应用关闭时调用 `close_all_connections()` +4. **SQL解析**:使用 `sqlparse` 库准确解析SQL语句,正确处理包含分号的字符串和注释 +5. **错误处理**:适当处理数据库操作可能抛出的异常 + +## 常见问题 + +### Q: 如何处理数据库约束错误? +A: 确保SQL语句中的字段名正确引用,特别是保留字需要使用双引号包围: +```sql +CREATE TABLE air_conditioner ( + id VARCHAR(128) PRIMARY KEY, + "on" BOOLEAN NOT NULL, -- 使用双引号包围保留字 + temperature REAL NOT NULL +); +``` + +### Q: 如何处理多个语句和参数的匹配? +A: 当SQL文件包含多个语句时,参数应该是参数列表,每个语句对应一个参数元组: +```python +await db.execute_by_sql_file("./sql/batch.sql", [ + ("参数1", "参数2"), # 第一个语句的参数 + ("参数3", "参数4") # 第二个语句的参数 +]) +``` + +通过遵循这些指南和最佳实践,您可以充分利用本项目的异步数据库系统,构建高性能、安全的数据库应用。 \ No newline at end of file diff --git a/konabot/common/database/__init__.py b/konabot/common/database/__init__.py new file mode 100644 index 0000000..03a5a5f --- /dev/null +++ b/konabot/common/database/__init__.py @@ -0,0 +1,218 @@ +import os +import asyncio +import sqlparse +from pathlib import Path +from typing import List, Dict, Any, Optional, Union, TYPE_CHECKING + +import aiosqlite + +if TYPE_CHECKING: + from . import DatabaseManager + +# 全局数据库管理器实例 +_global_db_manager: Optional['DatabaseManager'] = None + +def get_global_db_manager() -> 'DatabaseManager': + """获取全局数据库管理器实例""" + global _global_db_manager + if _global_db_manager is None: + from . import DatabaseManager + _global_db_manager = DatabaseManager() + return _global_db_manager + +def close_global_db_manager() -> None: + """关闭全局数据库管理器实例""" + global _global_db_manager + if _global_db_manager is not None: + # 注意:这个函数应该在async环境中调用close_all_connections + _global_db_manager = None + + +class DatabaseManager: + """异步数据库管理器""" + + def __init__(self, db_path: Optional[Union[str, Path]] = None, pool_size: int = 5): + """ + 初始化数据库管理器 + + Args: + db_path: 数据库文件路径,支持str和Path类型 + pool_size: 连接池大小 + """ + if db_path is None: + self.db_path = os.environ.get("DATABASE_PATH", "./data/database.db") + else: + self.db_path = str(db_path) if isinstance(db_path, Path) else db_path + + # 连接池 + self._connection_pool = [] + self._pool_size = pool_size + self._lock = asyncio.Lock() + self._in_use = set() # 跟踪正在使用的连接 + + async def _get_connection(self) -> aiosqlite.Connection: + """从连接池获取连接""" + async with self._lock: + # 尝试从池中获取现有连接 + while self._connection_pool: + conn = self._connection_pool.pop() + # 检查连接是否仍然有效 + try: + await conn.execute("SELECT 1") + self._in_use.add(conn) + return conn + except: + # 连接已失效,关闭它 + try: + await conn.close() + except: + pass + + # 如果连接池为空,创建新连接 + conn = await aiosqlite.connect(self.db_path) + await conn.execute("PRAGMA foreign_keys = ON") + self._in_use.add(conn) + return conn + + async def _return_connection(self, conn: aiosqlite.Connection) -> None: + """将连接返回到连接池""" + async with self._lock: + self._in_use.discard(conn) + if len(self._connection_pool) < self._pool_size: + self._connection_pool.append(conn) + else: + # 池已满,直接关闭连接 + try: + await conn.close() + except: + pass + + async def query( + self, query: str, params: Optional[tuple] = None + ) -> List[Dict[str, Any]]: + """执行查询语句并返回结果""" + conn = await self._get_connection() + try: + cursor = await conn.execute(query, params or ()) + columns = [description[0] for description in cursor.description] + rows = await cursor.fetchall() + results = [dict(zip(columns, row)) for row in rows] + await cursor.close() + return results + except Exception as e: + # 记录错误但重新抛出,让调用者处理 + raise Exception(f"数据库查询失败: {str(e)}") from e + finally: + await self._return_connection(conn) + + async def query_by_sql_file( + self, file_path: Union[str, Path], params: Optional[tuple] = None + ) -> List[Dict[str, Any]]: + """从 SQL 文件中读取查询语句并执行""" + path = str(file_path) if isinstance(file_path, Path) else file_path + with open(path, "r", encoding="utf-8") as f: + query = f.read() + return await self.query(query, params) + + async def execute(self, command: str, params: Optional[tuple] = None) -> None: + """执行非查询语句""" + conn = await self._get_connection() + try: + await conn.execute(command, params or ()) + await conn.commit() + except Exception as e: + # 记录错误但重新抛出,让调用者处理 + raise Exception(f"数据库执行失败: {str(e)}") from e + finally: + await self._return_connection(conn) + + async def execute_script(self, script: str) -> None: + """执行SQL脚本""" + conn = await self._get_connection() + try: + await conn.executescript(script) + await conn.commit() + except Exception as e: + # 记录错误但重新抛出,让调用者处理 + raise Exception(f"数据库脚本执行失败: {str(e)}") from e + finally: + await self._return_connection(conn) + + def _parse_sql_statements(self, script: str) -> List[str]: + """解析SQL脚本,分割成独立的语句""" + # 使用sqlparse库更准确地分割SQL语句 + parsed = sqlparse.split(script) + statements = [] + + for statement in parsed: + statement = statement.strip() + if statement: + statements.append(statement) + + return statements + + async def execute_by_sql_file( + self, file_path: Union[str, Path], params: Optional[Union[tuple, List[tuple]]] = None + ) -> None: + """从 SQL 文件中读取非查询语句并执行""" + path = str(file_path) if isinstance(file_path, Path) else file_path + with open(path, "r", encoding="utf-8") as f: + script = f.read() + + # 如果有参数且是元组,使用execute执行整个脚本 + if params is not None and isinstance(params, tuple): + await self.execute(script, params) + # 如果有参数且是列表,分别执行每个语句 + elif params is not None and isinstance(params, list): + # 使用sqlparse准确分割SQL语句 + statements = self._parse_sql_statements(script) + if len(statements) != len(params): + raise ValueError(f"语句数量({len(statements)})与参数组数量({len(params)})不匹配") + + for statement, stmt_params in zip(statements, params): + if statement: + await self.execute(statement, stmt_params) + # 如果无参数,使用executescript + else: + await self.execute_script(script) + + async def execute_many(self, command: str, seq_of_params: List[tuple]) -> None: + """执行多条非查询语句""" + conn = await self._get_connection() + try: + await conn.executemany(command, seq_of_params) + await conn.commit() + except Exception as e: + # 记录错误但重新抛出,让调用者处理 + raise Exception(f"数据库批量执行失败: {str(e)}") from e + finally: + await self._return_connection(conn) + + async def execute_many_values_by_sql_file( + self, file_path: Union[str, Path], seq_of_params: List[tuple] + ) -> None: + """从 SQL 文件中读取一条语句,但是被不同值同时执行""" + path = str(file_path) if isinstance(file_path, Path) else file_path + with open(path, "r", encoding="utf-8") as f: + command = f.read() + await self.execute_many(command, seq_of_params) + + async def close_all_connections(self) -> None: + """关闭所有连接""" + async with self._lock: + # 关闭池中的连接 + for conn in self._connection_pool: + try: + await conn.close() + except: + pass + self._connection_pool.clear() + + # 关闭正在使用的连接 + for conn in self._in_use.copy(): + try: + await conn.close() + except: + pass + self._in_use.clear() + diff --git a/konabot/plugins/air_conditioner/__init__.py b/konabot/plugins/air_conditioner/__init__.py index 4f921fe..27b88e6 100644 --- a/konabot/plugins/air_conditioner/__init__.py +++ b/konabot/plugins/air_conditioner/__init__.py @@ -1,22 +1,29 @@ from io import BytesIO from typing import Optional, Union import cv2 +import nonebot from nonebot.adapters import Event as BaseEvent from nonebot.adapters.console.event import MessageEvent as ConsoleMessageEvent from nonebot.adapters.discord.event import MessageEvent as DiscordMessageEvent from nonebot_plugin_alconna import Alconna, AlconnaMatcher, Args, UniMessage, on_alconna from PIL import Image import numpy as np +from konabot.common.database import DatabaseManager from konabot.common.longtask import DepLongTaskTarget from konabot.common.path import ASSETS_PATH from konabot.common.web_render import WebRenderer from konabot.plugins.air_conditioner.ac import AirConditioner, CrashType, generate_ac_image, wiggle_transform - +from pathlib import Path import random import math -def get_ac(id: str) -> AirConditioner: - ac = AirConditioner.air_conditioners.get(id) +ROOT_PATH = Path(__file__).resolve().parent + +# 创建全局数据库管理器实例 +db_manager = DatabaseManager() + +async def get_ac(id: str) -> AirConditioner: + ac = await AirConditioner.get_ac(id) if ac is None: ac = AirConditioner(id) return ac @@ -43,14 +50,32 @@ async def send_ac_image(event: type[AlconnaMatcher], ac: AirConditioner): ac_image = await generate_ac_image(ac) await event.send(await UniMessage().image(raw=ac_image).export()) + +driver = nonebot.get_driver() + + +@driver.on_startup +async def register_startup_hook(): + """注册启动时需要执行的函数""" + # 初始化数据库表 + await db_manager.execute_by_sql_file( + Path(__file__).resolve().parent / "sql" / "create_table.sql" + ) + +@driver.on_shutdown +async def register_shutdown_hook(): + """注册关闭时需要执行的函数""" + # 关闭所有数据库连接 + await db_manager.close_all_connections() + evt = on_alconna(Alconna( "群空调" ), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True) @evt.handle() -async def _(event: BaseEvent, target: DepLongTaskTarget): +async def _(target: DepLongTaskTarget): id = target.channel_id - ac = get_ac(id) + ac = await get_ac(id) await send_ac_image(evt, ac) evt = on_alconna(Alconna( @@ -58,10 +83,10 @@ evt = on_alconna(Alconna( ), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True) @evt.handle() -async def _(event: BaseEvent, target: DepLongTaskTarget): +async def _(target: DepLongTaskTarget): id = target.channel_id - ac = get_ac(id) - ac.on = True + ac = await get_ac(id) + await ac.update_ac(state=True) await send_ac_image(evt, ac) evt = on_alconna(Alconna( @@ -69,10 +94,10 @@ evt = on_alconna(Alconna( ), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True) @evt.handle() -async def _(event: BaseEvent, target: DepLongTaskTarget): +async def _(target: DepLongTaskTarget): id = target.channel_id - ac = get_ac(id) - ac.on = False + ac = await get_ac(id) + await ac.update_ac(state=False) await send_ac_image(evt, ac) evt = on_alconna(Alconna( @@ -81,15 +106,17 @@ evt = on_alconna(Alconna( ), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True) @evt.handle() -async def _(event: BaseEvent, target: DepLongTaskTarget, temp: Optional[Union[int, float]] = 1): +async def _(target: DepLongTaskTarget, temp: Optional[Union[int, float]] = 1): + if temp is None: + temp = 1 if temp <= 0: return id = target.channel_id - ac = get_ac(id) + ac = await get_ac(id) if not ac.on or ac.burnt == True or ac.frozen == True: await send_ac_image(evt, ac) return - ac.temperature += temp + await ac.update_ac(temperature_delta=temp) if ac.temperature > 40: # 根据温度随机出是否爆炸,40度开始,呈指数增长 possibility = -math.e ** ((40-ac.temperature) / 50) + 1 @@ -103,7 +130,7 @@ async def _(event: BaseEvent, target: DepLongTaskTarget, temp: Optional[Union[in pil_frames[0].save(output, format="GIF", save_all=True, append_images=pil_frames[1:], loop=0, duration=35, disposal=2) output.seek(0) await evt.send(await UniMessage().image(raw=output).export()) - ac.broke_ac(CrashType.BURNT) + await ac.broke_ac(CrashType.BURNT) await evt.send("太热啦,空调炸了!") return await send_ac_image(evt, ac) @@ -114,20 +141,22 @@ evt = on_alconna(Alconna( ), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True) @evt.handle() -async def _(event: BaseEvent, target: DepLongTaskTarget, temp: Optional[Union[int, float]] = 1): +async def _(target: DepLongTaskTarget, temp: Optional[Union[int, float]] = 1): + if temp is None: + temp = 1 if temp <= 0: return id = target.channel_id - ac = get_ac(id) + ac = await get_ac(id) if not ac.on or ac.burnt == True or ac.frozen == True: await send_ac_image(evt, ac) return - ac.temperature -= temp + await ac.update_ac(temperature_delta=-temp) if ac.temperature < 0: # 根据温度随机出是否冻结,0度开始,呈指数增长 possibility = -math.e ** (ac.temperature / 50) + 1 if random.random() < possibility: - ac.broke_ac(CrashType.FROZEN) + await ac.broke_ac(CrashType.FROZEN) await send_ac_image(evt, ac) evt = on_alconna(Alconna( @@ -135,21 +164,34 @@ evt = on_alconna(Alconna( ), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True) @evt.handle() -async def _(event: BaseEvent, target: DepLongTaskTarget): +async def _(target: DepLongTaskTarget): id = target.channel_id - ac = get_ac(id) - ac.change_ac() + ac = await get_ac(id) + await ac.change_ac() await send_ac_image(evt, ac) +async def query_number_ranking(id: str) -> tuple[int, int]: + result = await db_manager.query_by_sql_file( + ROOT_PATH / "sql" / "query_crash_and_rank.sql", + (id,id) + ) + if len(result) == 0: + return 0, 0 + else: + # 将字典转换为值的元组 + values = list(result[0].values()) + return values[0], values[1] + evt = on_alconna(Alconna( "空调炸炸排行榜", ), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True) @evt.handle() -async def _(event: BaseEvent, target: DepLongTaskTarget): +async def _(target: DepLongTaskTarget): id = target.channel_id - ac = get_ac(id) - number, ranking = ac.get_crashes_and_ranking() + # ac = get_ac(id) + # number, ranking = ac.get_crashes_and_ranking() + number, ranking = await query_number_ranking(id) params = { "number": number, "ranking": ranking @@ -159,4 +201,4 @@ async def _(event: BaseEvent, target: DepLongTaskTarget): target=".box", params=params ) - await evt.send(await UniMessage().image(raw=image).export()) \ No newline at end of file + await evt.send(await UniMessage().image(raw=image).export()) diff --git a/konabot/plugins/air_conditioner/ac.py b/konabot/plugins/air_conditioner/ac.py index 6614784..0133613 100644 --- a/konabot/plugins/air_conditioner/ac.py +++ b/konabot/plugins/air_conditioner/ac.py @@ -1,20 +1,74 @@ from enum import Enum from io import BytesIO +from pathlib import Path import cv2 import numpy as np from PIL import Image, ImageDraw, ImageFont +from konabot.common.database import DatabaseManager from konabot.common.path import ASSETS_PATH, FONTS_PATH from konabot.common.path import DATA_PATH import json +ROOT_PATH = Path(__file__).resolve().parent + +# 创建全局数据库管理器实例 +db_manager = DatabaseManager() + class CrashType(Enum): BURNT = 0 FROZEN = 1 class AirConditioner: - air_conditioners: dict[str, "AirConditioner"] = {} + @classmethod + async def get_ac(cls, id: str) -> 'AirConditioner': + result = await db_manager.query_by_sql_file(ROOT_PATH / "sql" / "query_ac.sql", (id,)) + if len(result) == 0: + ac = await cls.create_ac(id) + return ac + ac_data = result[0] + ac = AirConditioner(id) + ac.on = bool(ac_data["on"]) + ac.temperature = float(ac_data["temperature"]) + ac.burnt = bool(ac_data["burnt"]) + ac.frozen = bool(ac_data["frozen"]) + return ac + + @classmethod + async def create_ac(cls, id: str) -> 'AirConditioner': + ac = AirConditioner(id) + await db_manager.execute_by_sql_file( + ROOT_PATH / "sql" / "insert_ac.sql", + (id, ac.on, ac.temperature, ac.burnt, ac.frozen) + ) + return ac + + async def update_ac(self, state: bool = None, temperature_delta: float = None, burnt: bool = None, frozen: bool = None) -> 'AirConditioner': + if state is not None: + self.on = state + if temperature_delta is not None: + self.temperature += temperature_delta + if burnt is not None: + self.burnt = burnt + if frozen is not None: + self.frozen = frozen + await db_manager.execute_by_sql_file( + ROOT_PATH / "sql" / "update_ac.sql", + (self.on, self.temperature, self.burnt, self.frozen, self.id) + ) + return self + + async def change_ac(self) -> 'AirConditioner': + self.on = False + self.temperature = 24 + self.burnt = False + self.frozen = False + await db_manager.execute_by_sql_file( + ROOT_PATH / "sql" / "update_ac.sql", + (self.on, self.temperature, self.burnt, self.frozen, self.id) + ) + return self def __init__(self, id: str) -> None: self.id = id @@ -22,45 +76,40 @@ class AirConditioner: self.temperature = 24 # 默认温度 self.burnt = False self.frozen = False - AirConditioner.air_conditioners[id] = self - def change_ac(self): - self.burnt = False - self.frozen = False - self.on = False - self.temperature = 24 # 重置为默认温度 - - def broke_ac(self, crash_type: CrashType): + async def broke_ac(self, crash_type: CrashType): ''' - 让空调坏掉,并保存数据 - + 让空调坏掉 :param crash_type: CrashType 枚举,表示空调坏掉的类型 ''' match crash_type: case CrashType.BURNT: - self.burnt = True + await self.update_ac(burnt=True) case CrashType.FROZEN: - self.frozen = True - self.save_crash_data(crash_type) + await self.update_ac(frozen=True) + await db_manager.execute_by_sql_file( + ROOT_PATH / "sql" / "insert_crash.sql", + (self.id, crash_type.value) + ) - def save_crash_data(self, crash_type: CrashType): - ''' - 如果空调爆炸了,就往本地的 ac_crash_data.json 里该 id 的记录加一 - ''' - data_file = DATA_PATH / "ac_crash_data.json" - crash_data = {} - if data_file.exists(): - with open(data_file, "r", encoding="utf-8") as f: - crash_data = json.load(f) - if self.id not in crash_data: - crash_data[self.id] = {"burnt": 0, "frozen": 0} - match crash_type: - case CrashType.BURNT: - crash_data[self.id]["burnt"] += 1 - case CrashType.FROZEN: - crash_data[self.id]["frozen"] += 1 - with open(data_file, "w", encoding="utf-8") as f: - json.dump(crash_data, f, ensure_ascii=False, indent=4) + # def save_crash_data(self, crash_type: CrashType): + # ''' + # 如果空调爆炸了,就往本地的 ac_crash_data.json 里该 id 的记录加一 + # ''' + # data_file = DATA_PATH / "ac_crash_data.json" + # crash_data = {} + # if data_file.exists(): + # with open(data_file, "r", encoding="utf-8") as f: + # crash_data = json.load(f) + # if self.id not in crash_data: + # crash_data[self.id] = {"burnt": 0, "frozen": 0} + # match crash_type: + # case CrashType.BURNT: + # crash_data[self.id]["burnt"] += 1 + # case CrashType.FROZEN: + # crash_data[self.id]["frozen"] += 1 + # with open(data_file, "w", encoding="utf-8") as f: + # json.dump(crash_data, f, ensure_ascii=False, indent=4) def get_crashes_and_ranking(self) -> tuple[int, int]: ''' diff --git a/konabot/plugins/air_conditioner/sql/create_table.sql b/konabot/plugins/air_conditioner/sql/create_table.sql new file mode 100644 index 0000000..fed346c --- /dev/null +++ b/konabot/plugins/air_conditioner/sql/create_table.sql @@ -0,0 +1,15 @@ +-- 创建所有表 +CREATE TABLE IF NOT EXISTS air_conditioner ( + id VARCHAR(128) PRIMARY KEY, + "on" BOOLEAN NOT NULL, + temperature REAL NOT NULL, + burnt BOOLEAN NOT NULL, + frozen BOOLEAN NOT NULL +); + +CREATE TABLE IF NOT EXISTS air_conditioner_crash_log ( + id VARCHAR(128) NOT NULL, + crash_type INT NOT NULL, + timestamp DATETIME NOT NULL, + FOREIGN KEY (id) REFERENCES air_conditioner(id) +); \ No newline at end of file diff --git a/konabot/plugins/air_conditioner/sql/insert_ac.sql b/konabot/plugins/air_conditioner/sql/insert_ac.sql new file mode 100644 index 0000000..96e80ac --- /dev/null +++ b/konabot/plugins/air_conditioner/sql/insert_ac.sql @@ -0,0 +1,3 @@ +-- 插入一台新空调 +INSERT INTO air_conditioner (id, "on", temperature, burnt, frozen) +VALUES (?, ?, ?, ?, ?); \ No newline at end of file diff --git a/konabot/plugins/air_conditioner/sql/insert_crash.sql b/konabot/plugins/air_conditioner/sql/insert_crash.sql new file mode 100644 index 0000000..aae3898 --- /dev/null +++ b/konabot/plugins/air_conditioner/sql/insert_crash.sql @@ -0,0 +1,3 @@ +-- 插入一条空调爆炸记录 +INSERT INTO air_conditioner_crash_log (id, crash_type, timestamp) +VALUES (?, ?, CURRENT_TIMESTAMP); \ No newline at end of file diff --git a/konabot/plugins/air_conditioner/sql/query_ac.sql b/konabot/plugins/air_conditioner/sql/query_ac.sql new file mode 100644 index 0000000..db957d3 --- /dev/null +++ b/konabot/plugins/air_conditioner/sql/query_ac.sql @@ -0,0 +1,4 @@ +-- 查询空调状态,如果没有就插入一条新的记录 +SELECT * +FROM air_conditioner +WHERE id = ?; \ No newline at end of file diff --git a/konabot/plugins/air_conditioner/sql/query_crash_and_rank.sql b/konabot/plugins/air_conditioner/sql/query_crash_and_rank.sql new file mode 100644 index 0000000..c180638 --- /dev/null +++ b/konabot/plugins/air_conditioner/sql/query_crash_and_rank.sql @@ -0,0 +1,23 @@ +-- 从 air_conditioner_crash_log 表中获取指定 id 损坏的次数以及损坏次数的排名 +SELECT crash_count, crash_rank +FROM ( + SELECT id, + COUNT(*) AS crash_count, + RANK() OVER (ORDER BY COUNT(*) DESC) AS crash_rank + FROM air_conditioner_crash_log + GROUP BY id +) AS ranked_data +WHERE id = ? +-- 如果该 id 没有损坏记录,则返回 0 次损坏和对应的最后一名 +UNION +SELECT 0 AS crash_count, + (SELECT COUNT(DISTINCT id) + 1 FROM air_conditioner_crash_log) AS crash_rank +FROM ( + SELECT DISTINCT id + FROM air_conditioner_crash_log +) AS ranked_data +WHERE NOT EXISTS ( + SELECT 1 + FROM air_conditioner_crash_log + WHERE id = ? +); \ No newline at end of file diff --git a/konabot/plugins/air_conditioner/sql/update_ac.sql b/konabot/plugins/air_conditioner/sql/update_ac.sql new file mode 100644 index 0000000..474d71b --- /dev/null +++ b/konabot/plugins/air_conditioner/sql/update_ac.sql @@ -0,0 +1,4 @@ +-- 更新空调状态 +UPDATE air_conditioner +SET "on" = ?, temperature = ?, burnt = ?, frozen = ? +WHERE id = ?; \ No newline at end of file diff --git a/konabot/plugins/idiomgame/__init__.py b/konabot/plugins/idiomgame/__init__.py index ee4e26c..e652cb1 100644 --- a/konabot/plugins/idiomgame/__init__.py +++ b/konabot/plugins/idiomgame/__init__.py @@ -8,6 +8,7 @@ from typing import Optional from loguru import logger from nonebot import on_message +import nonebot from nonebot.adapters import Event as BaseEvent from nonebot.adapters.discord.event import MessageEvent as DiscordMessageEvent from nonebot_plugin_alconna import ( @@ -18,17 +19,23 @@ from nonebot_plugin_alconna import ( on_alconna, ) +from konabot.common.database import DatabaseManager from konabot.common.longtask import DepLongTaskTarget from konabot.common.path import ASSETS_PATH from konabot.common.llm import get_llm +ROOT_PATH = Path(__file__).resolve().parent + DATA_DIR = Path(__file__).parent.parent.parent.parent / "data" DATA_FILE_PATH = ( DATA_DIR / "idiom_banned.json" ) +# 创建全局数据库管理器实例 +db_manager = DatabaseManager() + def load_banned_ids() -> list[str]: if not DATA_FILE_PATH.exists(): return [] @@ -58,6 +65,21 @@ def remove_banned_id(group_id: str): DATA_FILE_PATH.write_text(json.dumps(banned_ids, ensure_ascii=False, indent=4), "utf-8") +driver = nonebot.get_driver() + + +@driver.on_startup +async def register_startup_hook(): + """注册启动时需要执行的函数""" + await IdiomGame.init_lexicon() + +@driver.on_shutdown +async def register_shutdown_hook(): + """注册关闭时需要执行的函数""" + # 关闭所有数据库连接 + await db_manager.close_all_connections() + + class TryStartState(Enum): STARTED = 0 ALREADY_PLAYING = 1 @@ -94,18 +116,19 @@ class IdiomGameLLM: @classmethod async def storage_idiom(cls, idiom: str): - # 将 idiom 存入本地文件以备后续分析 - with open(DATA_DIR / "idiom_llm_storage.txt", "a", encoding="utf-8") as f: - f.write(idiom + "\n") - IdiomGame.append_into_word_list(idiom) + # 将 idiom 存入数据库 + await db_manager.execute_by_sql_file( + ROOT_PATH / "sql" / "insert_custom_word.sql", + (idiom,) + ) class IdiomGame: - ALL_WORDS = [] # 所有四字词语 - ALL_IDIOMS = [] # 所有成语 + # ALL_WORDS = [] # 所有四字词语 + # ALL_IDIOMS = [] # 所有成语 INSTANCE_LIST: dict[str, "IdiomGame"] = {} # 群号对应的游戏实例 - IDIOM_FIRST_CHAR = {} # 所有成语包括词语的首字字典 - AVALIABLE_IDIOM_FIRST_CHAR = {} # 真正有效的成语首字字典 + # IDIOM_FIRST_CHAR = {} # 所有成语包括词语的首字字典 + # AVALIABLE_IDIOM_FIRST_CHAR = {} # 真正有效的成语首字字典 __inited = False @@ -126,15 +149,14 @@ class IdiomGame: IdiomGame.INSTANCE_LIST[group_id] = self @classmethod - def append_into_word_list(cls, word: str): + async def append_into_word_list(cls, word: str): ''' 将一个新词加入到词语列表中 ''' - if word not in cls.ALL_WORDS: - cls.ALL_WORDS.append(word) - if word[0] not in cls.IDIOM_FIRST_CHAR: - cls.IDIOM_FIRST_CHAR[word[0]] = [] - cls.IDIOM_FIRST_CHAR[word[0]].append(word) + await db_manager.execute_by_sql_file( + ROOT_PATH / "sql" / "insert_custom_word.sql", + (word,) + ) def be_able_to_play(self) -> bool: if self.last_play_date != datetime.date.today(): @@ -145,21 +167,28 @@ class IdiomGame: return True return False - def choose_start_idiom(self) -> str: + @staticmethod + async def random_idiom() -> str: + result = await db_manager.query_by_sql_file( + ROOT_PATH / "sql" / "random_choose_idiom.sql" + ) + return result[0]["idiom"] + + async def choose_start_idiom(self) -> str: """ 随机选择一个成语作为起始成语 """ - self.last_idiom = secrets.choice(IdiomGame.ALL_IDIOMS) + self.last_idiom = await IdiomGame.random_idiom() self.last_char = self.last_idiom[-1] - if not self.is_nextable(self.last_char): - self.choose_start_idiom() + if not await self.is_nextable(self.last_char): + await self.choose_start_idiom() else: self.add_history_idiom(self.last_idiom, new_chain=True) return self.last_idiom @classmethod - def try_start_game(cls, group_id: str, force: bool = False) -> TryStartState: - cls.init_lexicon() + async def try_start_game(cls, group_id: str, force: bool = False) -> TryStartState: + await cls.init_lexicon() if not cls.INSTANCE_LIST.get(group_id): cls(group_id) instance = cls.INSTANCE_LIST[group_id] @@ -170,10 +199,10 @@ class IdiomGame: instance.now_playing = True return TryStartState.STARTED - def start_game(self, rounds: int = 100): + async def start_game(self, rounds: int = 100): self.now_playing = True self.remain_rounds = rounds - self.choose_start_idiom() + await self.choose_start_idiom() @classmethod def try_stop_game(cls, group_id: str) -> TryStopState: @@ -203,20 +232,20 @@ class IdiomGame: 跳过当前成语,选择下一个成语 """ async with self.lock: - self._skip_idiom_async() + await self._skip_idiom_async() self.add_buff_score(buff_score) return self.last_idiom - def _skip_idiom_async(self) -> str: - self.last_idiom = secrets.choice(IdiomGame.ALL_IDIOMS) + async def _skip_idiom_async(self) -> str: + self.last_idiom = await IdiomGame.random_idiom() self.last_char = self.last_idiom[-1] - if not self.is_nextable(self.last_char): - self._skip_idiom_async() + if not await self.is_nextable(self.last_char): + await self._skip_idiom_async() else: self.add_history_idiom(self.last_idiom, new_chain=True) return self.last_idiom - async def try_verify_idiom(self, idiom: str, user_id: str) -> TryVerifyState: + async def try_verify_idiom(self, idiom: str, user_id: str) -> list[TryVerifyState]: """ 用户发送成语 """ @@ -224,12 +253,16 @@ class IdiomGame: state = await self._verify_idiom(idiom, user_id) return state - def is_nextable(self, last_char: str) -> bool: + async def is_nextable(self, last_char: str) -> bool: """ 判断是否有成语可以接 """ - return last_char in IdiomGame.AVALIABLE_IDIOM_FIRST_CHAR - + result = await db_manager.query_by_sql_file( + ROOT_PATH / "sql" / "is_nextable.sql", + (last_char,) + ) + return result[0]["DEED"] == 1 + def add_already_idiom(self, idiom: str): if idiom in self.already_idioms: self.already_idioms[idiom] += 1 @@ -259,7 +292,13 @@ class IdiomGame: if idiom[0] != self.last_char: state.append(TryVerifyState.WRONG_FIRST_CHAR) return state - if idiom not in IdiomGame.ALL_IDIOMS and idiom not in IdiomGame.ALL_WORDS: + # 成语是否存在 + result = await db_manager.query_by_sql_file( + ROOT_PATH / "sql" / "query_idiom.sql", + (idiom, idiom, idiom) + ) + status_result = result[0]["status"] + if status_result == -1: logger.info(f"用户 {user_id} 发送了未知词语 {idiom},正在使用 LLM 进行验证") try: if not await IdiomGameLLM.verify_idiom_with_llm(idiom): @@ -281,16 +320,16 @@ class IdiomGame: self.last_idiom = idiom self.last_char = idiom[-1] self.add_score(user_id, 1 * score_k) # 先加 1 分 - if idiom in IdiomGame.ALL_IDIOMS: + if status_result == 1: state.append(TryVerifyState.VERIFIED_AND_REAL) self.add_score(user_id, 4 * score_k) # 再加 4 分 self.remain_rounds -= 1 if self.remain_rounds <= 0: self.now_playing = False state.append(TryVerifyState.GAME_END) - if not self.is_nextable(self.last_char): + if not await self.is_nextable(self.last_char): # 没有成语可以接了,自动跳过 - self._skip_idiom_async() + await self._skip_idiom_async() self.add_buff_score(-100) state.append(TryVerifyState.BUT_NO_NEXT) return state @@ -317,16 +356,23 @@ class IdiomGame: return self.last_char @classmethod - def random_idiom_starting_with(cls, first_char: str) -> Optional[str]: - cls.init_lexicon() - if first_char not in cls.AVALIABLE_IDIOM_FIRST_CHAR: + async def random_idiom_starting_with(cls, first_char: str) -> Optional[str]: + await cls.init_lexicon() + result = await db_manager.query_by_sql_file( + ROOT_PATH / "sql" / "query_idiom_start_with.sql", + (first_char,) + ) + if len(result) == 0: return None - return secrets.choice(cls.AVALIABLE_IDIOM_FIRST_CHAR[first_char]) + return result[0]["idiom"] @classmethod - def init_lexicon(cls): + async def init_lexicon(cls): if cls.__inited: return + await db_manager.execute_by_sql_file( + ROOT_PATH / "sql" / "create_table.sql" + ) # 确保数据库初始化 cls.__inited = True # 成语大表 @@ -334,11 +380,12 @@ class IdiomGame: ALL_IDIOMS_INFOS = json.load(f) # 词语大表 + ALL_WORDS = [] with open(ASSETS_PATH / "lexicon" / "ci.json", "r", encoding="utf-8") as f: jsonData = json.load(f) - cls.ALL_WORDS = [item["ci"] for item in jsonData] - logger.debug(f"Loaded {len(cls.ALL_WORDS)} words from ci.json") - logger.debug(f"Sample words: {cls.ALL_WORDS[:5]}") + ALL_WORDS = [item["ci"] for item in jsonData] + logger.debug(f"Loaded {len(ALL_WORDS)} words from ci.json") + logger.debug(f"Sample words: {ALL_WORDS[:5]}") COMMON_WORDS = [] # 读取 COMMON 词语大表 @@ -389,29 +436,44 @@ class IdiomGame: logger.debug(f"Loaded additional {len(LOCAL_LLM_WORDS)} words from idiom_llm_storage.txt") # 只有成语的大表 - cls.ALL_IDIOMS = [idiom["word"] for idiom in ALL_IDIOMS_INFOS] + THUOCL_IDIOMS - cls.ALL_IDIOMS = list(set(cls.ALL_IDIOMS)) # 去重 + ALL_IDIOMS = [idiom["word"] for idiom in ALL_IDIOMS_INFOS] + THUOCL_IDIOMS + ALL_IDIOMS = list(set(ALL_IDIOMS)) # 去重 + # 批量插入数据库 + await db_manager.execute_many_values_by_sql_file( + ROOT_PATH / "sql" / "insert_idiom.sql", + [(idiom,) for idiom in ALL_IDIOMS] + ) + # 其他四字词语表,仅表示可以有这个词 - cls.ALL_WORDS = ( - [word for word in cls.ALL_WORDS if len(word) == 4] + ALL_WORDS = ( + [word for word in ALL_WORDS if len(word) == 4] + THUOCL_WORDS + COMMON_WORDS - + LOCAL_LLM_WORDS ) - cls.ALL_WORDS = list(set(cls.ALL_WORDS)) # 去重 + # 插入数据库 + await db_manager.execute_many_values_by_sql_file( + ROOT_PATH / "sql" / "insert_word.sql", + [(word,) for word in ALL_WORDS] + ) - # 根据成语大表,划分出成语首字字典 - for idiom in cls.ALL_IDIOMS + cls.ALL_WORDS: - if idiom[0] not in cls.IDIOM_FIRST_CHAR: - cls.IDIOM_FIRST_CHAR[idiom[0]] = [] - cls.IDIOM_FIRST_CHAR[idiom[0]].append(idiom) + # 自定义词语 LOCAL_LLM_WORDS 插入数据库,兼容用 + await db_manager.execute_many_values_by_sql_file( + ROOT_PATH / "sql" / "insert_custom_word.sql", + [(word,) for word in LOCAL_LLM_WORDS] + ) - # 根据真正的成语大表,划分出有效成语首字字典 - for idiom in cls.ALL_IDIOMS: - if idiom[0] not in cls.AVALIABLE_IDIOM_FIRST_CHAR: - cls.AVALIABLE_IDIOM_FIRST_CHAR[idiom[0]] = [] - cls.AVALIABLE_IDIOM_FIRST_CHAR[idiom[0]].append(idiom) + # # 根据成语大表,划分出成语首字字典 + # for idiom in cls.ALL_IDIOMS + cls.ALL_WORDS: + # if idiom[0] not in cls.IDIOM_FIRST_CHAR: + # cls.IDIOM_FIRST_CHAR[idiom[0]] = [] + # cls.IDIOM_FIRST_CHAR[idiom[0]].append(idiom) + + # # 根据真正的成语大表,划分出有效成语首字字典 + # for idiom in cls.ALL_IDIOMS: + # if idiom[0] not in cls.AVALIABLE_IDIOM_FIRST_CHAR: + # cls.AVALIABLE_IDIOM_FIRST_CHAR[idiom[0]] = [] + # cls.AVALIABLE_IDIOM_FIRST_CHAR[idiom[0]].append(idiom) evt = on_alconna( @@ -443,7 +505,7 @@ async def play_game( if rounds <= 0: await evt.send(await UniMessage().text("干什么!你想玩负数局吗?").export()) return - state = IdiomGame.try_start_game(group_id, force) + state = await IdiomGame.try_start_game(group_id, force) if state == TryStartState.ALREADY_PLAYING: await evt.send( await UniMessage() @@ -462,7 +524,7 @@ async def play_game( .export() ) instance = IdiomGame.INSTANCE_LIST[group_id] - instance.start_game(rounds) + await instance.start_game(rounds) # 发布成语 await evt.send( await UniMessage() @@ -514,7 +576,9 @@ async def end_game(event: BaseEvent, group_id: str): for line in history_lines: result_text += line + "\n" await evt.send(await result_text.export()) - instance.clear_score_board() + # instance.clear_score_board() + # 将实例删除 + del IdiomGame.INSTANCE_LIST[group_id] evt = on_alconna( @@ -553,7 +617,7 @@ async def _(target: DepLongTaskTarget): instance = IdiomGame.INSTANCE_LIST.get(group_id) if not instance or not instance.get_playing_state(): return - avaliable_idiom = IdiomGame.random_idiom_starting_with(instance.get_last_char()) + avaliable_idiom = await IdiomGame.random_idiom_starting_with(instance.get_last_char()) # 发送哈哈狗图片 with open(ASSETS_PATH / "img" / "dog" / "haha_dog.jpg", "rb") as f: img_data = f.read() diff --git a/konabot/plugins/idiomgame/sql/create_table.sql b/konabot/plugins/idiomgame/sql/create_table.sql new file mode 100644 index 0000000..5d38580 --- /dev/null +++ b/konabot/plugins/idiomgame/sql/create_table.sql @@ -0,0 +1,15 @@ +-- 创建成语大表 +CREATE TABLE IF NOT EXISTS all_idioms ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + idiom VARCHAR(128) NOT NULL UNIQUE +); + +CREATE TABLE IF NOT EXISTS all_words ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + word VARCHAR(128) NOT NULL UNIQUE +); + +CREATE TABLE IF NOT EXISTS custom_words ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + word VARCHAR(128) NOT NULL UNIQUE +); \ No newline at end of file diff --git a/konabot/plugins/idiomgame/sql/insert_custom_word.sql b/konabot/plugins/idiomgame/sql/insert_custom_word.sql new file mode 100644 index 0000000..212c8a2 --- /dev/null +++ b/konabot/plugins/idiomgame/sql/insert_custom_word.sql @@ -0,0 +1,3 @@ +-- 插入自定义词 +INSERT OR IGNORE INTO custom_words (word) +VALUES (?); \ No newline at end of file diff --git a/konabot/plugins/idiomgame/sql/insert_idiom.sql b/konabot/plugins/idiomgame/sql/insert_idiom.sql new file mode 100644 index 0000000..eaedae8 --- /dev/null +++ b/konabot/plugins/idiomgame/sql/insert_idiom.sql @@ -0,0 +1,3 @@ +-- 插入成语大表,避免重复插入 +INSERT OR IGNORE INTO all_idioms (idiom) +VALUES (?); \ No newline at end of file diff --git a/konabot/plugins/idiomgame/sql/insert_word.sql b/konabot/plugins/idiomgame/sql/insert_word.sql new file mode 100644 index 0000000..b085aab --- /dev/null +++ b/konabot/plugins/idiomgame/sql/insert_word.sql @@ -0,0 +1,3 @@ +-- 插入词 +INSERT OR IGNORE INTO all_words (word) +VALUES (?); \ No newline at end of file diff --git a/konabot/plugins/idiomgame/sql/is_nextable.sql b/konabot/plugins/idiomgame/sql/is_nextable.sql new file mode 100644 index 0000000..a7bbeb1 --- /dev/null +++ b/konabot/plugins/idiomgame/sql/is_nextable.sql @@ -0,0 +1,5 @@ +-- 查询是否有以 xx 开头的成语,有则返回真,否则假 +SELECT EXISTS( + SELECT 1 FROM all_idioms + WHERE idiom LIKE ? || '%' +) AS DEED; diff --git a/konabot/plugins/idiomgame/sql/query_idiom.sql b/konabot/plugins/idiomgame/sql/query_idiom.sql new file mode 100644 index 0000000..fa3bf93 --- /dev/null +++ b/konabot/plugins/idiomgame/sql/query_idiom.sql @@ -0,0 +1,7 @@ +-- 查询成语是否在 all_idioms 中,如果存在则返回 1,否则再判断是否在 custom_words 或 all_words 中,存在则返回 0,否则返回 -1 +SELECT + CASE + WHEN EXISTS (SELECT 1 FROM all_idioms WHERE idiom = ?) THEN 1 + WHEN EXISTS (SELECT 1 FROM custom_words WHERE word = ?) OR EXISTS (SELECT 1 FROM all_words WHERE word = ?) THEN 0 + ELSE -1 + END AS status; \ No newline at end of file diff --git a/konabot/plugins/idiomgame/sql/query_idiom_start_with.sql b/konabot/plugins/idiomgame/sql/query_idiom_start_with.sql new file mode 100644 index 0000000..a6e8fc6 --- /dev/null +++ b/konabot/plugins/idiomgame/sql/query_idiom_start_with.sql @@ -0,0 +1,4 @@ +-- 查询以 xx 开头的成语,随机打乱后只取第一个 +SELECT idiom FROM all_idioms +WHERE idiom LIKE ? || '%' +ORDER BY RANDOM() LIMIT 1; \ No newline at end of file diff --git a/konabot/plugins/idiomgame/sql/random_choose_idiom.sql b/konabot/plugins/idiomgame/sql/random_choose_idiom.sql new file mode 100644 index 0000000..f706092 --- /dev/null +++ b/konabot/plugins/idiomgame/sql/random_choose_idiom.sql @@ -0,0 +1,2 @@ +-- 随机从 all_idioms 表中选择一个成语 +SELECT idiom FROM all_idioms ORDER BY RANDOM() LIMIT 1; \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 640bd0d..af13460 100644 --- a/poetry.lock +++ b/poetry.lock @@ -209,6 +209,30 @@ type = "legacy" url = "https://pypi.tuna.tsinghua.edu.cn/simple" reference = "mirrors" +[[package]] +name = "aiosqlite" +version = "0.21.0" +description = "asyncio bridge to the standard sqlite3 module" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0"}, + {file = "aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3"}, +] + +[package.dependencies] +typing_extensions = ">=4.0" + +[package.extras] +dev = ["attribution (==1.7.1)", "black (==24.3.0)", "build (>=1.2)", "coverage[toml] (==7.6.10)", "flake8 (==7.0.0)", "flake8-bugbear (==24.12.12)", "flit (==3.10.1)", "mypy (==1.14.1)", "ufmt (==2.5.1)", "usort (==1.0.8.post1)"] +docs = ["sphinx (==8.1.3)", "sphinx-mdinclude (==0.6.1)"] + +[package.source] +type = "legacy" +url = "https://pypi.tuna.tsinghua.edu.cn/simple" +reference = "mirrors" + [[package]] name = "annotated-doc" version = "0.0.3" @@ -946,12 +970,12 @@ version = "0.4.6" description = "Cross-platform colored terminal text." optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" -groups = ["main"] -markers = "sys_platform == \"win32\" or platform_system == \"Windows\"" +groups = ["main", "dev"] files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +markers = {main = "sys_platform == \"win32\" or platform_system == \"Windows\"", dev = "sys_platform == \"win32\""} [package.source] type = "legacy" @@ -1568,6 +1592,23 @@ type = "legacy" url = "https://pypi.tuna.tsinghua.edu.cn/simple" reference = "mirrors" +[[package]] +name = "iniconfig" +version = "2.3.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.10" +groups = ["dev"] +files = [ + {file = "iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12"}, + {file = "iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730"}, +] + +[package.source] +type = "legacy" +url = "https://pypi.tuna.tsinghua.edu.cn/simple" +reference = "mirrors" + [[package]] name = "jiter" version = "0.11.1" @@ -2679,6 +2720,23 @@ type = "legacy" url = "https://pypi.tuna.tsinghua.edu.cn/simple" reference = "mirrors" +[[package]] +name = "packaging" +version = "25.0" +description = "Core utilities for Python packages" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484"}, + {file = "packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f"}, +] + +[package.source] +type = "legacy" +url = "https://pypi.tuna.tsinghua.edu.cn/simple" +reference = "mirrors" + [[package]] name = "pillow" version = "11.3.0" @@ -2858,6 +2916,27 @@ type = "legacy" url = "https://pypi.tuna.tsinghua.edu.cn/simple" reference = "mirrors" +[[package]] +name = "pluggy" +version = "1.6.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746"}, + {file = "pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["coverage", "pytest", "pytest-benchmark"] + +[package.source] +type = "legacy" +url = "https://pypi.tuna.tsinghua.edu.cn/simple" +reference = "mirrors" + [[package]] name = "propcache" version = "0.4.1" @@ -3344,7 +3423,7 @@ version = "2.19.2" description = "Pygments is a syntax highlighting package written in Python." optional = false python-versions = ">=3.8" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b"}, {file = "pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887"}, @@ -3375,6 +3454,58 @@ type = "legacy" url = "https://pypi.tuna.tsinghua.edu.cn/simple" reference = "mirrors" +[[package]] +name = "pytest" +version = "9.0.1" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.10" +groups = ["dev"] +files = [ + {file = "pytest-9.0.1-py3-none-any.whl", hash = "sha256:67be0030d194df2dfa7b556f2e56fb3c3315bd5c8822c6951162b92b32ce7dad"}, + {file = "pytest-9.0.1.tar.gz", hash = "sha256:3e9c069ea73583e255c3b21cf46b8d3c56f6e3a1a8f6da94ccb0fcf57b9d73c8"}, +] + +[package.dependencies] +colorama = {version = ">=0.4", markers = "sys_platform == \"win32\""} +iniconfig = ">=1.0.1" +packaging = ">=22" +pluggy = ">=1.5,<2" +pygments = ">=2.7.2" + +[package.extras] +dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests", "setuptools", "xmlschema"] + +[package.source] +type = "legacy" +url = "https://pypi.tuna.tsinghua.edu.cn/simple" +reference = "mirrors" + +[[package]] +name = "pytest-asyncio" +version = "1.3.0" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.10" +groups = ["dev"] +files = [ + {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, + {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, +] + +[package.dependencies] +pytest = ">=8.2,<10" +typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + +[package.source] +type = "legacy" +url = "https://pypi.tuna.tsinghua.edu.cn/simple" +reference = "mirrors" + [[package]] name = "python-dotenv" version = "1.2.1" @@ -3699,6 +3830,27 @@ type = "legacy" url = "https://pypi.tuna.tsinghua.edu.cn/simple" reference = "mirrors" +[[package]] +name = "sqlparse" +version = "0.5.3" +description = "A non-validating SQL parser." +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "sqlparse-0.5.3-py3-none-any.whl", hash = "sha256:cf2196ed3418f3ba5de6af7e82c694a9fbdbfecccdfc72e281548517081f16ca"}, + {file = "sqlparse-0.5.3.tar.gz", hash = "sha256:09f67787f56a0b16ecdbde1bfc7f5d9c3371ca683cfeaa8e6ff60b4807ec9272"}, +] + +[package.extras] +dev = ["build", "hatch"] +doc = ["sphinx"] + +[package.source] +type = "legacy" +url = "https://pypi.tuna.tsinghua.edu.cn/simple" +reference = "mirrors" + [[package]] name = "starlette" version = "0.49.3" @@ -3902,11 +4054,12 @@ version = "4.15.0" description = "Backported and Experimental Type Hints for Python 3.9+" optional = false python-versions = ">=3.9" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548"}, {file = "typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466"}, ] +markers = {dev = "python_version == \"3.12\""} [package.source] type = "legacy" @@ -4528,4 +4681,4 @@ reference = "mirrors" [metadata] lock-version = "2.1" python-versions = ">=3.12,<4.0" -content-hash = "478bd59d60d3b73397241c6ed552434486bd26d56cc3805ef34d1cfa1be7006e" +content-hash = "2c341fdc0d5b29ad3b24516c46e036b2eff4c11e244047d114971039255c2ac4" diff --git a/pyproject.toml b/pyproject.toml index 0c11e66..4691cd4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,8 @@ dependencies = [ "playwright (>=1.55.0,<2.0.0)", "openai (>=2.7.1,<3.0.0)", "imageio (>=2.37.2,<3.0.0)", + "aiosqlite (>=0.20.0,<1.0.0)", + "sqlparse (>=0.5.0,<1.0.0)", ] [tool.poetry] @@ -46,5 +48,7 @@ priority = "primary" [dependency-groups] dev = [ - "rust-just (>=1.43.0,<2.0.0)" + "rust-just (>=1.43.0,<2.0.0)", + "pytest (>=9.0.1,<10.0.0)", + "pytest-asyncio (>=1.3.0,<2.0.0)" ] diff --git a/scripts/test_plugin_load.py b/scripts/test_plugin_load.py index 810aa8e..7f43c1a 100644 --- a/scripts/test_plugin_load.py +++ b/scripts/test_plugin_load.py @@ -22,3 +22,11 @@ logger.info(f"已经加载的插件数量 {len(plugins)}") logger.info(f"期待加载的插件数量 {len_requires}") assert len(plugins) == len_requires + +# 测试数据库模块是否可以正确导入 +try: + from konabot.common.database import DatabaseManager + logger.info("数据库模块导入成功") +except Exception as e: + logger.error(f"数据库模块导入失败: {e}") + raise diff --git a/tests/test_database.py b/tests/test_database.py new file mode 100644 index 0000000..6e14251 --- /dev/null +++ b/tests/test_database.py @@ -0,0 +1,93 @@ +import asyncio +import os +import tempfile +from pathlib import Path + +import pytest + +from konabot.common.database import DatabaseManager + + +@pytest.mark.asyncio +async def test_database_manager(): + """测试数据库管理器的基本功能""" + # 创建临时数据库文件 + with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file: + db_path = tmp_file.name + + try: + # 初始化数据库管理器 + db_manager = DatabaseManager(db_path) + + # 创建测试表 + create_table_sql = """ + CREATE TABLE IF NOT EXISTS test_users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + email TEXT UNIQUE + ); + """ + await db_manager.execute(create_table_sql) + + # 插入测试数据 + insert_sql = "INSERT INTO test_users (name, email) VALUES (?, ?)" + await db_manager.execute(insert_sql, ("张三", "zhangsan@example.com")) + await db_manager.execute(insert_sql, ("李四", "lisi@example.com")) + + # 查询数据 + select_sql = "SELECT * FROM test_users WHERE name = ?" + results = await db_manager.query(select_sql, ("张三",)) + assert len(results) == 1 + assert results[0]["name"] == "张三" + assert results[0]["email"] == "zhangsan@example.com" + + # 测试使用Path对象 + results = await db_manager.query_by_sql_file(Path(__file__), ("李四",)) + # 注意:这里只是测试参数传递,实际SQL文件内容不是有效的SQL + + # 关闭所有连接 + await db_manager.close_all_connections() + + finally: + # 清理临时文件 + if os.path.exists(db_path): + os.unlink(db_path) + + +@pytest.mark.asyncio +async def test_execute_script(): + """测试执行SQL脚本功能""" + # 创建临时数据库文件 + with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file: + db_path = tmp_file.name + + try: + # 初始化数据库管理器 + db_manager = DatabaseManager(db_path) + + # 创建测试表的脚本 + script = """ + CREATE TABLE IF NOT EXISTS test_products ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + price REAL + ); + INSERT INTO test_products (name, price) VALUES ('苹果', 5.0); + INSERT INTO test_products (name, price) VALUES ('香蕉', 3.0); + """ + + await db_manager.execute_script(script) + + # 查询数据 + results = await db_manager.query("SELECT * FROM test_products ORDER BY name") + assert len(results) == 2 + assert results[0]["name"] == "苹果" + assert results[1]["name"] == "香蕉" + + # 关闭所有连接 + await db_manager.close_all_connections() + + finally: + # 清理临时文件 + if os.path.exists(db_path): + os.unlink(db_path) \ No newline at end of file