From 0d540eea4ce7e267328e0cc77145f7b20094080b Mon Sep 17 00:00:00 2001 From: passthem Date: Tue, 18 Nov 2025 23:55:31 +0800 Subject: [PATCH] =?UTF-8?q?=E6=88=91=E6=8B=BF=20AI=20=E6=94=B9=E5=9D=8F?= =?UTF-8?q?=E6=9E=AA=E4=BB=A3=E7=A0=81=EF=BC=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 - README.md | 8 + bot.py | 5 +- konabot/common/database/__init__.py | 165 ++++++++++++------ konabot/core/preinit.py | 15 -- konabot/plugins/air_conditioner/__init__.py | 66 ++++--- .../plugins/air_conditioner/__preinit__.py | 9 - konabot/plugins/air_conditioner/ac.py | 29 +-- .../air_conditioner/sql/create_table.sql | 2 +- .../plugins/air_conditioner/sql/insert_ac.sql | 2 +- .../plugins/air_conditioner/sql/update_ac.sql | 2 +- konabot/plugins/idiomgame/__init__.py | 96 +++++----- poetry.lock | 26 ++- pyproject.toml | 1 + scripts/test_plugin_load.py | 8 + tests/test_database.py | 93 ++++++++++ 16 files changed, 367 insertions(+), 161 deletions(-) delete mode 100644 konabot/core/preinit.py delete mode 100644 konabot/plugins/air_conditioner/__preinit__.py create mode 100644 tests/test_database.py diff --git a/.gitignore b/.gitignore index 8337d30..12b97cd 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,3 @@ __pycache__ -*.db \ No newline at end of file diff --git a/README.md b/README.md index 1437dcc..70520af 100644 --- a/README.md +++ b/README.md @@ -71,6 +71,10 @@ code . 详见[konabot-web 配置文档](/docs/konabot-web.md) +#### 数据库配置 + +本项目使用SQLite作为数据库,默认数据库文件位于`./data/database.db`。可以通过设置`DATABASE_PATH`环境变量来指定其他位置。 + ### 运行 使用命令行手动启动 Bot: @@ -91,3 +95,7 @@ poetry run python bot.py - [事件响应器](https://nonebot.dev/docs/tutorial/matcher) - [事件处理](https://nonebot.dev/docs/tutorial/handler) - [Alconna 插件](https://nonebot.dev/docs/best-practice/alconna/) + +## 数据库模块 + +本项目的数据库模块已更新为异步实现,使用连接池来提高性能,并支持现代的`pathlib.Path`参数类型。详细使用方法请参考`konabot/common/database/__init__.py`文件中的实现。 diff --git a/bot.py b/bot.py index e4c56ca..c45b4f1 100644 --- a/bot.py +++ b/bot.py @@ -10,7 +10,7 @@ from nonebot.adapters.onebot.v11 import Adapter as OnebotAdapter from konabot.common.log import init_logger from konabot.common.nb.exc import BotExceptionMessage from konabot.common.path import LOG_PATH -from konabot.core.preinit import preinit + dotenv.load_dotenv() env = os.environ.get("ENVIRONMENT", "prod") @@ -49,9 +49,6 @@ def main(): nonebot.load_plugins("konabot/plugins") nonebot.load_plugin("nonebot_plugin_analysis_bilibili") - # 预加载 - preinit("konabot/plugins") - nonebot.run() if __name__ == "__main__": diff --git a/konabot/common/database/__init__.py b/konabot/common/database/__init__.py index 2a44469..b933bb4 100644 --- a/konabot/common/database/__init__.py +++ b/konabot/common/database/__init__.py @@ -1,64 +1,127 @@ import os -import sqlite3 -from typing import List, Dict, Any, Optional +import asyncio +from pathlib import Path +from typing import List, Dict, Any, Optional, Union + +import aiosqlite + class DatabaseManager: - """超级无敌神奇的数据库!""" - - @classmethod - def query(cls, query: str, params: Optional[tuple] = None) -> List[Dict[str, Any]]: + """异步数据库管理器""" + + def __init__(self, db_path: Optional[Union[str, Path]] = None): + """ + 初始化数据库管理器 + + Args: + db_path: 数据库文件路径,支持str和Path类型 + """ + if db_path is None: + self.db_path = os.environ.get("DATABASE_PATH", "./data/database.db") + else: + self.db_path = str(db_path) if isinstance(db_path, Path) else db_path + + # 连接池 + self._connection_pool = [] + self._pool_size = 5 + self._lock = asyncio.Lock() + + async def _get_connection(self) -> aiosqlite.Connection: + """从连接池获取连接""" + async with self._lock: + if self._connection_pool: + return self._connection_pool.pop() + + # 如果连接池为空,创建新连接 + conn = await aiosqlite.connect(self.db_path) + await conn.execute("PRAGMA foreign_keys = ON") + return conn + + async def _return_connection(self, conn: aiosqlite.Connection) -> None: + """将连接返回到连接池""" + async with self._lock: + if len(self._connection_pool) < self._pool_size: + self._connection_pool.append(conn) + else: + await conn.close() + + async def query( + self, query: str, params: Optional[tuple] = None + ) -> List[Dict[str, Any]]: """执行查询语句并返回结果""" - conn = sqlite3.connect(os.environ.get('DATABASE_PATH', './data/database.db')) - cursor = conn.cursor() - cursor.execute(query, params or ()) - columns = [description[0] for description in cursor.description] - results = [dict(zip(columns, row)) for row in cursor.fetchall()] - cursor.close() - conn.close() - return results - - @classmethod - def query_by_sql_file(cls, file_path: str, params: Optional[tuple] = None) -> List[Dict[str, Any]]: + conn = await self._get_connection() + try: + cursor = await conn.execute(query, params or ()) + columns = [description[0] for description in cursor.description] + rows = await cursor.fetchall() + results = [dict(zip(columns, row)) for row in rows] + await cursor.close() + return results + finally: + await self._return_connection(conn) + + async def query_by_sql_file( + self, file_path: Union[str, Path], params: Optional[tuple] = None + ) -> List[Dict[str, Any]]: """从 SQL 文件中读取查询语句并执行""" - with open(file_path, 'r', encoding='utf-8') as f: + path = str(file_path) if isinstance(file_path, Path) else file_path + with open(path, "r", encoding="utf-8") as f: query = f.read() - return cls.query(query, params) + return await self.query(query, params) - @classmethod - def execute(cls, command: str, params: Optional[tuple] = None) -> None: + async def execute(self, command: str, params: Optional[tuple] = None) -> None: """执行非查询语句""" - conn = sqlite3.connect(os.environ.get('DATABASE_PATH', './data/database.db')) - cursor = conn.cursor() - cursor.execute(command, params or ()) - conn.commit() - cursor.close() - conn.close() + conn = await self._get_connection() + try: + await conn.execute(command, params or ()) + await conn.commit() + finally: + await self._return_connection(conn) - @classmethod - def execute_by_sql_file(cls, file_path: str, params: Optional[tuple] = None) -> None: + async def execute_script(self, script: str) -> None: + """执行SQL脚本""" + conn = await self._get_connection() + try: + await conn.executescript(script) + await conn.commit() + finally: + await self._return_connection(conn) + + async def execute_by_sql_file( + self, file_path: Union[str, Path], params: Optional[tuple] = None + ) -> None: """从 SQL 文件中读取非查询语句并执行""" - with open(file_path, 'r', encoding='utf-8') as f: - command = f.read() - # 按照需要执行多条语句 - commands = command.split(';') - for cmd in commands: - cmd = cmd.strip() - if cmd: - cls.execute(cmd, params) - - @classmethod - def execute_many(cls, command: str, seq_of_params: List[tuple]) -> None: - """执行多条非查询语句""" - conn = sqlite3.connect(os.environ.get('DATABASE_PATH', './data/database.db')) - cursor = conn.cursor() - cursor.executemany(command, seq_of_params) - conn.commit() - cursor.close() - conn.close() + path = str(file_path) if isinstance(file_path, Path) else file_path + with open(path, "r", encoding="utf-8") as f: + script = f.read() + # 如果有参数,使用execute方法而不是execute_script + if params: + await self.execute(script, params) + else: + await self.execute_script(script) - @classmethod - def execute_many_values_by_sql_file(cls, file_path: str, seq_of_params: List[tuple]) -> None: + async def execute_many(self, command: str, seq_of_params: List[tuple]) -> None: + """执行多条非查询语句""" + conn = await self._get_connection() + try: + await conn.executemany(command, seq_of_params) + await conn.commit() + finally: + await self._return_connection(conn) + + async def execute_many_values_by_sql_file( + self, file_path: Union[str, Path], seq_of_params: List[tuple] + ) -> None: """从 SQL 文件中读取一条语句,但是被不同值同时执行""" - with open(file_path, 'r', encoding='utf-8') as f: + path = str(file_path) if isinstance(file_path, Path) else file_path + with open(path, "r", encoding="utf-8") as f: command = f.read() - cls.execute_many(command, seq_of_params) \ No newline at end of file + await self.execute_many(command, seq_of_params) + + async def close_all_connections(self) -> None: + """关闭所有连接""" + async with self._lock: + for conn in self._connection_pool: + await conn.close() + self._connection_pool.clear() + diff --git a/konabot/core/preinit.py b/konabot/core/preinit.py deleted file mode 100644 index ccfd3f7..0000000 --- a/konabot/core/preinit.py +++ /dev/null @@ -1,15 +0,0 @@ -from pathlib import Path - -from nonebot import logger - -def preinit(path: str): - # 执行预初始化,递归找到位于对应路径内文件名为 __preinit__.py 的所有文件都会被执行 - dir_path = Path(path) - for item in dir_path.iterdir(): - if item.is_dir(): - preinit(item) - elif item.is_file() and item.name == "__preinit__.py": - # 动态导入该文件以执行预初始化代码 - module_path = str(item.with_suffix("")).replace("/", ".").replace("\\", ".") - __import__(module_path) - logger.info(f"Preinitialized module: {module_path}") \ No newline at end of file diff --git a/konabot/plugins/air_conditioner/__init__.py b/konabot/plugins/air_conditioner/__init__.py index e148954..30a14dc 100644 --- a/konabot/plugins/air_conditioner/__init__.py +++ b/konabot/plugins/air_conditioner/__init__.py @@ -1,6 +1,7 @@ from io import BytesIO from typing import Optional, Union import cv2 +import nonebot from nonebot.adapters import Event as BaseEvent from nonebot.adapters.console.event import MessageEvent as ConsoleMessageEvent from nonebot.adapters.discord.event import MessageEvent as DiscordMessageEvent @@ -18,8 +19,11 @@ import math ROOT_PATH = Path(__file__).resolve().parent -def get_ac(id: str) -> AirConditioner: - ac = AirConditioner.get_ac(id) +# 创建全局数据库管理器实例 +db_manager = DatabaseManager() + +async def get_ac(id: str) -> AirConditioner: + ac = await AirConditioner.get_ac(id) if ac is None: ac = AirConditioner(id) return ac @@ -46,14 +50,26 @@ async def send_ac_image(event: type[AlconnaMatcher], ac: AirConditioner): ac_image = await generate_ac_image(ac) await event.send(await UniMessage().image(raw=ac_image).export()) + +driver = nonebot.get_driver() + + +@driver.on_startup +async def register_startup_hook(): + """注册启动时需要执行的函数""" + # 初始化数据库表 + await db_manager.execute_by_sql_file( + Path(__file__).resolve().parent / "sql" / "create_table.sql" + ) + evt = on_alconna(Alconna( "群空调" ), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True) @evt.handle() -async def _(event: BaseEvent, target: DepLongTaskTarget): +async def _(target: DepLongTaskTarget): id = target.channel_id - ac = get_ac(id) + ac = await get_ac(id) await send_ac_image(evt, ac) evt = on_alconna(Alconna( @@ -61,10 +77,10 @@ evt = on_alconna(Alconna( ), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True) @evt.handle() -async def _(event: BaseEvent, target: DepLongTaskTarget): +async def _(target: DepLongTaskTarget): id = target.channel_id - ac = get_ac(id) - ac.update_ac(state=True) + ac = await get_ac(id) + await ac.update_ac(state=True) await send_ac_image(evt, ac) evt = on_alconna(Alconna( @@ -72,10 +88,10 @@ evt = on_alconna(Alconna( ), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True) @evt.handle() -async def _(event: BaseEvent, target: DepLongTaskTarget): +async def _(target: DepLongTaskTarget): id = target.channel_id - ac = get_ac(id) - ac.update_ac(state=False) + ac = await get_ac(id) + await ac.update_ac(state=False) await send_ac_image(evt, ac) evt = on_alconna(Alconna( @@ -84,17 +100,17 @@ evt = on_alconna(Alconna( ), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True) @evt.handle() -async def _(event: BaseEvent, target: DepLongTaskTarget, temp: Optional[Union[int, float]] = 1): +async def _(target: DepLongTaskTarget, temp: Optional[Union[int, float]] = 1): if temp is None: temp = 1 if temp <= 0: return id = target.channel_id - ac = get_ac(id) + ac = await get_ac(id) if not ac.on or ac.burnt == True or ac.frozen == True: await send_ac_image(evt, ac) return - ac.update_ac(temperature_delta=temp) + await ac.update_ac(temperature_delta=temp) if ac.temperature > 40: # 根据温度随机出是否爆炸,40度开始,呈指数增长 possibility = -math.e ** ((40-ac.temperature) / 50) + 1 @@ -108,7 +124,7 @@ async def _(event: BaseEvent, target: DepLongTaskTarget, temp: Optional[Union[in pil_frames[0].save(output, format="GIF", save_all=True, append_images=pil_frames[1:], loop=0, duration=35, disposal=2) output.seek(0) await evt.send(await UniMessage().image(raw=output).export()) - ac.broke_ac(CrashType.BURNT) + await ac.broke_ac(CrashType.BURNT) await evt.send("太热啦,空调炸了!") return await send_ac_image(evt, ac) @@ -125,16 +141,16 @@ async def _(event: BaseEvent, target: DepLongTaskTarget, temp: Optional[Union[in if temp <= 0: return id = target.channel_id - ac = get_ac(id) + ac = await get_ac(id) if not ac.on or ac.burnt == True or ac.frozen == True: await send_ac_image(evt, ac) return - ac.update_ac(temperature_delta=-temp) + await ac.update_ac(temperature_delta=-temp) if ac.temperature < 0: # 根据温度随机出是否冻结,0度开始,呈指数增长 possibility = -math.e ** (ac.temperature / 50) + 1 if random.random() < possibility: - ac.broke_ac(CrashType.FROZEN) + await ac.broke_ac(CrashType.FROZEN) await send_ac_image(evt, ac) evt = on_alconna(Alconna( @@ -144,19 +160,21 @@ evt = on_alconna(Alconna( @evt.handle() async def _(event: BaseEvent, target: DepLongTaskTarget): id = target.channel_id - ac = get_ac(id) - ac.change_ac() + ac = await get_ac(id) + await ac.change_ac() await send_ac_image(evt, ac) -def query_number_ranking(id: str) -> tuple[int, int]: - result = DatabaseManager.query_by_sql_file( +async def query_number_ranking(id: str) -> tuple[int, int]: + result = await db_manager.query_by_sql_file( ROOT_PATH / "sql" / "query_crash_and_rank.sql", (id,id) ) if len(result) == 0: return 0, 0 else: - return result[0].values() + # 将字典转换为值的元组 + values = list(result[0].values()) + return values[0], values[1] evt = on_alconna(Alconna( "空调炸炸排行榜", @@ -167,7 +185,7 @@ async def _(event: BaseEvent, target: DepLongTaskTarget): id = target.channel_id # ac = get_ac(id) # number, ranking = ac.get_crashes_and_ranking() - number, ranking = query_number_ranking(id) + number, ranking = await query_number_ranking(id) params = { "number": number, "ranking": ranking @@ -177,4 +195,4 @@ async def _(event: BaseEvent, target: DepLongTaskTarget): target=".box", params=params ) - await evt.send(await UniMessage().image(raw=image).export()) \ No newline at end of file + await evt.send(await UniMessage().image(raw=image).export()) diff --git a/konabot/plugins/air_conditioner/__preinit__.py b/konabot/plugins/air_conditioner/__preinit__.py deleted file mode 100644 index 67054a0..0000000 --- a/konabot/plugins/air_conditioner/__preinit__.py +++ /dev/null @@ -1,9 +0,0 @@ -# 预初始化,只要是导入本插件包就会执行这里的代码 -from pathlib import Path - -from konabot.common.database import DatabaseManager - -# 初始化数据库表 -DatabaseManager.execute_by_sql_file( - Path(__file__).resolve().parent / "sql" / "create_table.sql" -) diff --git a/konabot/plugins/air_conditioner/ac.py b/konabot/plugins/air_conditioner/ac.py index 9be7619..0133613 100644 --- a/konabot/plugins/air_conditioner/ac.py +++ b/konabot/plugins/air_conditioner/ac.py @@ -13,16 +13,19 @@ import json ROOT_PATH = Path(__file__).resolve().parent +# 创建全局数据库管理器实例 +db_manager = DatabaseManager() + class CrashType(Enum): BURNT = 0 FROZEN = 1 class AirConditioner: @classmethod - def get_ac(cls, id: str) -> 'AirConditioner': - result = DatabaseManager.query_by_sql_file(ROOT_PATH / "sql" / "query_ac.sql", (id,)) + async def get_ac(cls, id: str) -> 'AirConditioner': + result = await db_manager.query_by_sql_file(ROOT_PATH / "sql" / "query_ac.sql", (id,)) if len(result) == 0: - ac = cls.create_ac(id) + ac = await cls.create_ac(id) return ac ac_data = result[0] ac = AirConditioner(id) @@ -33,15 +36,15 @@ class AirConditioner: return ac @classmethod - def create_ac(cls, id: str) -> 'AirConditioner': + async def create_ac(cls, id: str) -> 'AirConditioner': ac = AirConditioner(id) - DatabaseManager.execute_by_sql_file( + await db_manager.execute_by_sql_file( ROOT_PATH / "sql" / "insert_ac.sql", (id, ac.on, ac.temperature, ac.burnt, ac.frozen) ) return ac - def update_ac(self, state: bool = None, temperature_delta: float = None, burnt: bool = None, frozen: bool = None) -> 'AirConditioner': + async def update_ac(self, state: bool = None, temperature_delta: float = None, burnt: bool = None, frozen: bool = None) -> 'AirConditioner': if state is not None: self.on = state if temperature_delta is not None: @@ -50,18 +53,18 @@ class AirConditioner: self.burnt = burnt if frozen is not None: self.frozen = frozen - DatabaseManager.execute_by_sql_file( + await db_manager.execute_by_sql_file( ROOT_PATH / "sql" / "update_ac.sql", (self.on, self.temperature, self.burnt, self.frozen, self.id) ) return self - def change_ac(self) -> 'AirConditioner': + async def change_ac(self) -> 'AirConditioner': self.on = False self.temperature = 24 self.burnt = False self.frozen = False - DatabaseManager.execute_by_sql_file( + await db_manager.execute_by_sql_file( ROOT_PATH / "sql" / "update_ac.sql", (self.on, self.temperature, self.burnt, self.frozen, self.id) ) @@ -74,17 +77,17 @@ class AirConditioner: self.burnt = False self.frozen = False - def broke_ac(self, crash_type: CrashType): + async def broke_ac(self, crash_type: CrashType): ''' 让空调坏掉 :param crash_type: CrashType 枚举,表示空调坏掉的类型 ''' match crash_type: case CrashType.BURNT: - self.update_ac(burnt=True) + await self.update_ac(burnt=True) case CrashType.FROZEN: - self.update_ac(frozen=True) - DatabaseManager.execute_by_sql_file( + await self.update_ac(frozen=True) + await db_manager.execute_by_sql_file( ROOT_PATH / "sql" / "insert_crash.sql", (self.id, crash_type.value) ) diff --git a/konabot/plugins/air_conditioner/sql/create_table.sql b/konabot/plugins/air_conditioner/sql/create_table.sql index 5203e23..fed346c 100644 --- a/konabot/plugins/air_conditioner/sql/create_table.sql +++ b/konabot/plugins/air_conditioner/sql/create_table.sql @@ -1,7 +1,7 @@ -- 创建所有表 CREATE TABLE IF NOT EXISTS air_conditioner ( id VARCHAR(128) PRIMARY KEY, - 'on' BOOLEAN NOT NULL, + "on" BOOLEAN NOT NULL, temperature REAL NOT NULL, burnt BOOLEAN NOT NULL, frozen BOOLEAN NOT NULL diff --git a/konabot/plugins/air_conditioner/sql/insert_ac.sql b/konabot/plugins/air_conditioner/sql/insert_ac.sql index 3fb1c76..96e80ac 100644 --- a/konabot/plugins/air_conditioner/sql/insert_ac.sql +++ b/konabot/plugins/air_conditioner/sql/insert_ac.sql @@ -1,3 +1,3 @@ -- 插入一台新空调 -INSERT INTO air_conditioner (id, 'on', temperature, burnt, frozen) +INSERT INTO air_conditioner (id, "on", temperature, burnt, frozen) VALUES (?, ?, ?, ?, ?); \ No newline at end of file diff --git a/konabot/plugins/air_conditioner/sql/update_ac.sql b/konabot/plugins/air_conditioner/sql/update_ac.sql index df9145e..474d71b 100644 --- a/konabot/plugins/air_conditioner/sql/update_ac.sql +++ b/konabot/plugins/air_conditioner/sql/update_ac.sql @@ -1,4 +1,4 @@ -- 更新空调状态 UPDATE air_conditioner -SET 'on' = ?, temperature = ?, burnt = ?, frozen = ? +SET "on" = ?, temperature = ?, burnt = ?, frozen = ? WHERE id = ?; \ No newline at end of file diff --git a/konabot/plugins/idiomgame/__init__.py b/konabot/plugins/idiomgame/__init__.py index 36710aa..7f1909f 100644 --- a/konabot/plugins/idiomgame/__init__.py +++ b/konabot/plugins/idiomgame/__init__.py @@ -8,6 +8,7 @@ from typing import Optional from loguru import logger from nonebot import on_message +import nonebot from nonebot.adapters import Event as BaseEvent from nonebot.adapters.discord.event import MessageEvent as DiscordMessageEvent from nonebot_plugin_alconna import ( @@ -32,6 +33,9 @@ DATA_FILE_PATH = ( DATA_DIR / "idiom_banned.json" ) +# 创建全局数据库管理器实例 +db_manager = DatabaseManager() + def load_banned_ids() -> list[str]: if not DATA_FILE_PATH.exists(): return [] @@ -61,6 +65,15 @@ def remove_banned_id(group_id: str): DATA_FILE_PATH.write_text(json.dumps(banned_ids, ensure_ascii=False, indent=4), "utf-8") +driver = nonebot.get_driver() + + +@driver.on_startup +async def register_startup_hook(): + """注册启动时需要执行的函数""" + await IdiomGame.init_lexicon() + + class TryStartState(Enum): STARTED = 0 ALREADY_PLAYING = 1 @@ -98,7 +111,7 @@ class IdiomGameLLM: @classmethod async def storage_idiom(cls, idiom: str): # 将 idiom 存入数据库 - DatabaseManager.execute_by_sql_file( + await db_manager.execute_by_sql_file( ROOT_PATH / "sql" / "insert_custom_word.sql", (idiom,) ) @@ -130,11 +143,11 @@ class IdiomGame: IdiomGame.INSTANCE_LIST[group_id] = self @classmethod - def append_into_word_list(cls, word: str): + async def append_into_word_list(cls, word: str): ''' 将一个新词加入到词语列表中 ''' - DatabaseManager.execute_by_sql_file( + await db_manager.execute_by_sql_file( ROOT_PATH / "sql" / "insert_custom_word.sql", (word,) ) @@ -149,26 +162,27 @@ class IdiomGame: return False @staticmethod - def random_idiom() -> str: - return DatabaseManager.query_by_sql_file( + async def random_idiom() -> str: + result = await db_manager.query_by_sql_file( ROOT_PATH / "sql" / "random_choose_idiom.sql" - )[0]["idiom"] + ) + return result[0]["idiom"] - def choose_start_idiom(self) -> str: + async def choose_start_idiom(self) -> str: """ 随机选择一个成语作为起始成语 """ - self.last_idiom = IdiomGame.random_idiom() + self.last_idiom = await IdiomGame.random_idiom() self.last_char = self.last_idiom[-1] - if not self.is_nextable(self.last_char): - self.choose_start_idiom() + if not await self.is_nextable(self.last_char): + await self.choose_start_idiom() else: self.add_history_idiom(self.last_idiom, new_chain=True) return self.last_idiom @classmethod - def try_start_game(cls, group_id: str, force: bool = False) -> TryStartState: - cls.init_lexicon() + async def try_start_game(cls, group_id: str, force: bool = False) -> TryStartState: + await cls.init_lexicon() if not cls.INSTANCE_LIST.get(group_id): cls(group_id) instance = cls.INSTANCE_LIST[group_id] @@ -179,10 +193,10 @@ class IdiomGame: instance.now_playing = True return TryStartState.STARTED - def start_game(self, rounds: int = 100): + async def start_game(self, rounds: int = 100): self.now_playing = True self.remain_rounds = rounds - self.choose_start_idiom() + await self.choose_start_idiom() @classmethod def try_stop_game(cls, group_id: str) -> TryStopState: @@ -212,20 +226,20 @@ class IdiomGame: 跳过当前成语,选择下一个成语 """ async with self.lock: - self._skip_idiom_async() + await self._skip_idiom_async() self.add_buff_score(buff_score) return self.last_idiom - def _skip_idiom_async(self) -> str: - self.last_idiom = IdiomGame.random_idiom() + async def _skip_idiom_async(self) -> str: + self.last_idiom = await IdiomGame.random_idiom() self.last_char = self.last_idiom[-1] - if not self.is_nextable(self.last_char): - self._skip_idiom_async() + if not await self.is_nextable(self.last_char): + await self._skip_idiom_async() else: self.add_history_idiom(self.last_idiom, new_chain=True) return self.last_idiom - async def try_verify_idiom(self, idiom: str, user_id: str) -> TryVerifyState: + async def try_verify_idiom(self, idiom: str, user_id: str) -> list[TryVerifyState]: """ 用户发送成语 """ @@ -233,14 +247,15 @@ class IdiomGame: state = await self._verify_idiom(idiom, user_id) return state - def is_nextable(self, last_char: str) -> bool: + async def is_nextable(self, last_char: str) -> bool: """ 判断是否有成语可以接 """ - return DatabaseManager.query_by_sql_file( + result = await db_manager.query_by_sql_file( ROOT_PATH / "sql" / "is_nextable.sql", (last_char,) - )[0]["DEED"] == 1 + ) + return result[0]["DEED"] == 1 def add_already_idiom(self, idiom: str): if idiom in self.already_idioms: @@ -272,11 +287,12 @@ class IdiomGame: state.append(TryVerifyState.WRONG_FIRST_CHAR) return state # 成语是否存在 - result = DatabaseManager.query_by_sql_file( + result = await db_manager.query_by_sql_file( ROOT_PATH / "sql" / "query_idiom.sql", (idiom, idiom, idiom) - )[0]["status"] - if result == -1: + ) + status_result = result[0]["status"] + if status_result == -1: logger.info(f"用户 {user_id} 发送了未知词语 {idiom},正在使用 LLM 进行验证") try: if not await IdiomGameLLM.verify_idiom_with_llm(idiom): @@ -298,16 +314,16 @@ class IdiomGame: self.last_idiom = idiom self.last_char = idiom[-1] self.add_score(user_id, 1 * score_k) # 先加 1 分 - if result == 1: + if status_result == 1: state.append(TryVerifyState.VERIFIED_AND_REAL) self.add_score(user_id, 4 * score_k) # 再加 4 分 self.remain_rounds -= 1 if self.remain_rounds <= 0: self.now_playing = False state.append(TryVerifyState.GAME_END) - if not self.is_nextable(self.last_char): + if not await self.is_nextable(self.last_char): # 没有成语可以接了,自动跳过 - self._skip_idiom_async() + await self._skip_idiom_async() self.add_buff_score(-100) state.append(TryVerifyState.BUT_NO_NEXT) return state @@ -334,9 +350,9 @@ class IdiomGame: return self.last_char @classmethod - def random_idiom_starting_with(cls, first_char: str) -> Optional[str]: - cls.init_lexicon() - result = DatabaseManager.query_by_sql_file( + async def random_idiom_starting_with(cls, first_char: str) -> Optional[str]: + await cls.init_lexicon() + result = await db_manager.query_by_sql_file( ROOT_PATH / "sql" / "query_idiom_start_with.sql", (first_char,) ) @@ -345,10 +361,10 @@ class IdiomGame: return result[0]["idiom"] @classmethod - def init_lexicon(cls): + async def init_lexicon(cls): if cls.__inited: return - DatabaseManager.execute_by_sql_file( + await db_manager.execute_by_sql_file( ROOT_PATH / "sql" / "create_table.sql" ) # 确保数据库初始化 cls.__inited = True @@ -417,7 +433,7 @@ class IdiomGame: ALL_IDIOMS = [idiom["word"] for idiom in ALL_IDIOMS_INFOS] + THUOCL_IDIOMS ALL_IDIOMS = list(set(ALL_IDIOMS)) # 去重 # 批量插入数据库 - DatabaseManager.execute_many_values_by_sql_file( + await db_manager.execute_many_values_by_sql_file( ROOT_PATH / "sql" / "insert_idiom.sql", [(idiom,) for idiom in ALL_IDIOMS] ) @@ -430,13 +446,13 @@ class IdiomGame: + COMMON_WORDS ) # 插入数据库 - DatabaseManager.execute_many_values_by_sql_file( + await db_manager.execute_many_values_by_sql_file( ROOT_PATH / "sql" / "insert_word.sql", [(word,) for word in ALL_WORDS] ) # 自定义词语 LOCAL_LLM_WORDS 插入数据库,兼容用 - DatabaseManager.execute_many_values_by_sql_file( + await db_manager.execute_many_values_by_sql_file( ROOT_PATH / "sql" / "insert_custom_word.sql", [(word,) for word in LOCAL_LLM_WORDS] ) @@ -483,7 +499,7 @@ async def play_game( if rounds <= 0: await evt.send(await UniMessage().text("干什么!你想玩负数局吗?").export()) return - state = IdiomGame.try_start_game(group_id, force) + state = await IdiomGame.try_start_game(group_id, force) if state == TryStartState.ALREADY_PLAYING: await evt.send( await UniMessage() @@ -502,7 +518,7 @@ async def play_game( .export() ) instance = IdiomGame.INSTANCE_LIST[group_id] - instance.start_game(rounds) + await instance.start_game(rounds) # 发布成语 await evt.send( await UniMessage() @@ -595,7 +611,7 @@ async def _(target: DepLongTaskTarget): instance = IdiomGame.INSTANCE_LIST.get(group_id) if not instance or not instance.get_playing_state(): return - avaliable_idiom = IdiomGame.random_idiom_starting_with(instance.get_last_char()) + avaliable_idiom = await IdiomGame.random_idiom_starting_with(instance.get_last_char()) # 发送哈哈狗图片 with open(ASSETS_PATH / "img" / "dog" / "haha_dog.jpg", "rb") as f: img_data = f.read() diff --git a/poetry.lock b/poetry.lock index 640bd0d..164736f 100644 --- a/poetry.lock +++ b/poetry.lock @@ -209,6 +209,30 @@ type = "legacy" url = "https://pypi.tuna.tsinghua.edu.cn/simple" reference = "mirrors" +[[package]] +name = "aiosqlite" +version = "0.21.0" +description = "asyncio bridge to the standard sqlite3 module" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0"}, + {file = "aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3"}, +] + +[package.dependencies] +typing_extensions = ">=4.0" + +[package.extras] +dev = ["attribution (==1.7.1)", "black (==24.3.0)", "build (>=1.2)", "coverage[toml] (==7.6.10)", "flake8 (==7.0.0)", "flake8-bugbear (==24.12.12)", "flit (==3.10.1)", "mypy (==1.14.1)", "ufmt (==2.5.1)", "usort (==1.0.8.post1)"] +docs = ["sphinx (==8.1.3)", "sphinx-mdinclude (==0.6.1)"] + +[package.source] +type = "legacy" +url = "https://pypi.tuna.tsinghua.edu.cn/simple" +reference = "mirrors" + [[package]] name = "annotated-doc" version = "0.0.3" @@ -4528,4 +4552,4 @@ reference = "mirrors" [metadata] lock-version = "2.1" python-versions = ">=3.12,<4.0" -content-hash = "478bd59d60d3b73397241c6ed552434486bd26d56cc3805ef34d1cfa1be7006e" +content-hash = "5597aa165095a11fa08e4b6e1a1f4d3396711b684ed363ae0ced2dd59a09ec5d" diff --git a/pyproject.toml b/pyproject.toml index 0c11e66..a0595c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "playwright (>=1.55.0,<2.0.0)", "openai (>=2.7.1,<3.0.0)", "imageio (>=2.37.2,<3.0.0)", + "aiosqlite (>=0.20.0,<1.0.0)", ] [tool.poetry] diff --git a/scripts/test_plugin_load.py b/scripts/test_plugin_load.py index 810aa8e..7f43c1a 100644 --- a/scripts/test_plugin_load.py +++ b/scripts/test_plugin_load.py @@ -22,3 +22,11 @@ logger.info(f"已经加载的插件数量 {len(plugins)}") logger.info(f"期待加载的插件数量 {len_requires}") assert len(plugins) == len_requires + +# 测试数据库模块是否可以正确导入 +try: + from konabot.common.database import DatabaseManager + logger.info("数据库模块导入成功") +except Exception as e: + logger.error(f"数据库模块导入失败: {e}") + raise diff --git a/tests/test_database.py b/tests/test_database.py new file mode 100644 index 0000000..6e14251 --- /dev/null +++ b/tests/test_database.py @@ -0,0 +1,93 @@ +import asyncio +import os +import tempfile +from pathlib import Path + +import pytest + +from konabot.common.database import DatabaseManager + + +@pytest.mark.asyncio +async def test_database_manager(): + """测试数据库管理器的基本功能""" + # 创建临时数据库文件 + with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file: + db_path = tmp_file.name + + try: + # 初始化数据库管理器 + db_manager = DatabaseManager(db_path) + + # 创建测试表 + create_table_sql = """ + CREATE TABLE IF NOT EXISTS test_users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + email TEXT UNIQUE + ); + """ + await db_manager.execute(create_table_sql) + + # 插入测试数据 + insert_sql = "INSERT INTO test_users (name, email) VALUES (?, ?)" + await db_manager.execute(insert_sql, ("张三", "zhangsan@example.com")) + await db_manager.execute(insert_sql, ("李四", "lisi@example.com")) + + # 查询数据 + select_sql = "SELECT * FROM test_users WHERE name = ?" + results = await db_manager.query(select_sql, ("张三",)) + assert len(results) == 1 + assert results[0]["name"] == "张三" + assert results[0]["email"] == "zhangsan@example.com" + + # 测试使用Path对象 + results = await db_manager.query_by_sql_file(Path(__file__), ("李四",)) + # 注意:这里只是测试参数传递,实际SQL文件内容不是有效的SQL + + # 关闭所有连接 + await db_manager.close_all_connections() + + finally: + # 清理临时文件 + if os.path.exists(db_path): + os.unlink(db_path) + + +@pytest.mark.asyncio +async def test_execute_script(): + """测试执行SQL脚本功能""" + # 创建临时数据库文件 + with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file: + db_path = tmp_file.name + + try: + # 初始化数据库管理器 + db_manager = DatabaseManager(db_path) + + # 创建测试表的脚本 + script = """ + CREATE TABLE IF NOT EXISTS test_products ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + price REAL + ); + INSERT INTO test_products (name, price) VALUES ('苹果', 5.0); + INSERT INTO test_products (name, price) VALUES ('香蕉', 3.0); + """ + + await db_manager.execute_script(script) + + # 查询数据 + results = await db_manager.query("SELECT * FROM test_products ORDER BY name") + assert len(results) == 2 + assert results[0]["name"] == "苹果" + assert results[1]["name"] == "香蕉" + + # 关闭所有连接 + await db_manager.close_all_connections() + + finally: + # 清理临时文件 + if os.path.exists(db_path): + os.unlink(db_path) \ No newline at end of file