diff --git a/README.md b/README.md index 70520af..6e5cdc0 100644 --- a/README.md +++ b/README.md @@ -98,4 +98,4 @@ poetry run python bot.py ## 数据库模块 -本项目的数据库模块已更新为异步实现,使用连接池来提高性能,并支持现代的`pathlib.Path`参数类型。详细使用方法请参考`konabot/common/database/__init__.py`文件中的实现。 +本项目的数据库模块已更新为异步实现,使用连接池来提高性能,并支持现代的`pathlib.Path`参数类型。详细使用方法请参考[数据库使用文档](/docs/database.md)。 diff --git a/bot.py b/bot.py index c45b4f1..d0285d6 100644 --- a/bot.py +++ b/bot.py @@ -10,6 +10,7 @@ from nonebot.adapters.onebot.v11 import Adapter as OnebotAdapter from konabot.common.log import init_logger from konabot.common.nb.exc import BotExceptionMessage from konabot.common.path import LOG_PATH +from konabot.common.database import get_global_db_manager dotenv.load_dotenv() @@ -49,6 +50,13 @@ def main(): nonebot.load_plugins("konabot/plugins") nonebot.load_plugin("nonebot_plugin_analysis_bilibili") + # 注册关闭钩子 + @driver.on_shutdown + async def shutdown_handler(): + # 关闭全局数据库管理器 + db_manager = get_global_db_manager() + await db_manager.close_all_connections() + nonebot.run() if __name__ == "__main__": diff --git a/commit_f21da65.diff b/commit_f21da65.diff new file mode 100644 index 0000000..22e627d --- /dev/null +++ b/commit_f21da65.diff @@ -0,0 +1,848 @@ +commit f21da657dbc79c2d139265a69696e5ad213f5c53 +Author: MixBadGun <1059129006@qq.com> +Date: Tue Nov 18 19:36:05 2025 +0800 + + database 接入 + +diff --git a/.env.example b/.env.example +index 7fde1d8..488632c 100644 +--- a/.env.example ++++ b/.env.example +@@ -1,4 +1,4 @@ + ENVIRONMENT=dev + PORT=21333 +- ++DATABASE_PATH="./data/database.db" + ENABLE_CONSOLE=true +diff --git a/.gitignore b/.gitignore +index 9f2daec..8337d30 100644 +--- a/.gitignore ++++ b/.gitignore +@@ -1,4 +1,6 @@ + /.env + /data + +-__pycache__ +\ No newline at end of file ++__pycache__ ++ ++*.db +\ No newline at end of file +diff --git a/bot.py b/bot.py +index 782c870..e4c56ca 100644 +--- a/bot.py ++++ b/bot.py +@@ -10,6 +10,7 @@ from nonebot.adapters.onebot.v11 import Adapter as OnebotAdapter + from konabot.common.log import init_logger + from konabot.common.nb.exc import BotExceptionMessage + from konabot.common.path import LOG_PATH ++from konabot.core.preinit import preinit + + dotenv.load_dotenv() + env = os.environ.get("ENVIRONMENT", "prod") +@@ -48,6 +49,9 @@ def main(): + nonebot.load_plugins("konabot/plugins") + nonebot.load_plugin("nonebot_plugin_analysis_bilibili") + ++ # 预加载 ++ preinit("konabot/plugins") ++ + nonebot.run() + + if __name__ == "__main__": +diff --git a/konabot/common/database/__init__.py b/konabot/common/database/__init__.py +new file mode 100644 +index 0000000..2a44469 +--- /dev/null ++++ b/konabot/common/database/__init__.py +@@ -0,0 +1,64 @@ ++import os ++import sqlite3 ++from typing import List, Dict, Any, Optional ++ ++class DatabaseManager: ++ """超级无敌神奇的数据库!""" ++ ++ @classmethod ++ def query(cls, query: str, params: Optional[tuple] = None) -> List[Dict[str, Any]]: ++ """执行查询语句并返回结果""" ++ conn = sqlite3.connect(os.environ.get('DATABASE_PATH', './data/database.db')) ++ cursor = conn.cursor() ++ cursor.execute(query, params or ()) ++ columns = [description[0] for description in cursor.description] ++ results = [dict(zip(columns, row)) for row in cursor.fetchall()] ++ cursor.close() ++ conn.close() ++ return results ++ ++ @classmethod ++ def query_by_sql_file(cls, file_path: str, params: Optional[tuple] = None) -> List[Dict[str, Any]]: ++ """从 SQL 文件中读取查询语句并执行""" ++ with open(file_path, 'r', encoding='utf-8') as f: ++ query = f.read() ++ return cls.query(query, params) ++ ++ @classmethod ++ def execute(cls, command: str, params: Optional[tuple] = None) -> None: ++ """执行非查询语句""" ++ conn = sqlite3.connect(os.environ.get('DATABASE_PATH', './data/database.db')) ++ cursor = conn.cursor() ++ cursor.execute(command, params or ()) ++ conn.commit() ++ cursor.close() ++ conn.close() ++ ++ @classmethod ++ def execute_by_sql_file(cls, file_path: str, params: Optional[tuple] = None) -> None: ++ """从 SQL 文件中读取非查询语句并执行""" ++ with open(file_path, 'r', encoding='utf-8') as f: ++ command = f.read() ++ # 按照需要执行多条语句 ++ commands = command.split(';') ++ for cmd in commands: ++ cmd = cmd.strip() ++ if cmd: ++ cls.execute(cmd, params) ++ ++ @classmethod ++ def execute_many(cls, command: str, seq_of_params: List[tuple]) -> None: ++ """执行多条非查询语句""" ++ conn = sqlite3.connect(os.environ.get('DATABASE_PATH', './data/database.db')) ++ cursor = conn.cursor() ++ cursor.executemany(command, seq_of_params) ++ conn.commit() ++ cursor.close() ++ conn.close() ++ ++ @classmethod ++ def execute_many_values_by_sql_file(cls, file_path: str, seq_of_params: List[tuple]) -> None: ++ """从 SQL 文件中读取一条语句,但是被不同值同时执行""" ++ with open(file_path, 'r', encoding='utf-8') as f: ++ command = f.read() ++ cls.execute_many(command, seq_of_params) +\ No newline at end of file +diff --git a/konabot/core/preinit.py b/konabot/core/preinit.py +new file mode 100644 +index 0000000..ccfd3f7 +--- /dev/null ++++ b/konabot/core/preinit.py +@@ -0,0 +1,15 @@ ++from pathlib import Path ++ ++from nonebot import logger ++ ++def preinit(path: str): ++ # 执行预初始化,递归找到位于对应路径内文件名为 __preinit__.py 的所有文件都会被执行 ++ dir_path = Path(path) ++ for item in dir_path.iterdir(): ++ if item.is_dir(): ++ preinit(item) ++ elif item.is_file() and item.name == "__preinit__.py": ++ # 动态导入该文件以执行预初始化代码 ++ module_path = str(item.with_suffix("")).replace("/", ".").replace("\\", ".") ++ __import__(module_path) ++ logger.info(f"Preinitialized module: {module_path}") +\ No newline at end of file +diff --git a/konabot/plugins/air_conditioner/__init__.py b/konabot/plugins/air_conditioner/__init__.py +index 4f921fe..e148954 100644 +--- a/konabot/plugins/air_conditioner/__init__.py ++++ b/konabot/plugins/air_conditioner/__init__.py +@@ -7,16 +7,19 @@ from nonebot.adapters.discord.event import MessageEvent as DiscordMessageEvent + from nonebot_plugin_alconna import Alconna, AlconnaMatcher, Args, UniMessage, on_alconna + from PIL import Image + import numpy as np ++from konabot.common.database import DatabaseManager + from konabot.common.longtask import DepLongTaskTarget + from konabot.common.path import ASSETS_PATH + from konabot.common.web_render import WebRenderer + from konabot.plugins.air_conditioner.ac import AirConditioner, CrashType, generate_ac_image, wiggle_transform +- ++from pathlib import Path + import random + import math + ++ROOT_PATH = Path(__file__).resolve().parent ++ + def get_ac(id: str) -> AirConditioner: +- ac = AirConditioner.air_conditioners.get(id) ++ ac = AirConditioner.get_ac(id) + if ac is None: + ac = AirConditioner(id) + return ac +@@ -61,7 +64,7 @@ evt = on_alconna(Alconna( + async def _(event: BaseEvent, target: DepLongTaskTarget): + id = target.channel_id + ac = get_ac(id) +- ac.on = True ++ ac.update_ac(state=True) + await send_ac_image(evt, ac) + + evt = on_alconna(Alconna( +@@ -72,7 +75,7 @@ evt = on_alconna(Alconna( + async def _(event: BaseEvent, target: DepLongTaskTarget): + id = target.channel_id + ac = get_ac(id) +- ac.on = False ++ ac.update_ac(state=False) + await send_ac_image(evt, ac) + + evt = on_alconna(Alconna( +@@ -82,6 +85,8 @@ evt = on_alconna(Alconna( + + @evt.handle() + async def _(event: BaseEvent, target: DepLongTaskTarget, temp: Optional[Union[int, float]] = 1): ++ if temp is None: ++ temp = 1 + if temp <= 0: + return + id = target.channel_id +@@ -89,7 +94,7 @@ async def _(event: BaseEvent, target: DepLongTaskTarget, temp: Optional[Union[in + if not ac.on or ac.burnt == True or ac.frozen == True: + await send_ac_image(evt, ac) + return +- ac.temperature += temp ++ ac.update_ac(temperature_delta=temp) + if ac.temperature > 40: + # 根据温度随机出是否爆炸,40度开始,呈指数增长 + possibility = -math.e ** ((40-ac.temperature) / 50) + 1 +@@ -115,6 +120,8 @@ evt = on_alconna(Alconna( + + @evt.handle() + async def _(event: BaseEvent, target: DepLongTaskTarget, temp: Optional[Union[int, float]] = 1): ++ if temp is None: ++ temp = 1 + if temp <= 0: + return + id = target.channel_id +@@ -122,7 +129,7 @@ async def _(event: BaseEvent, target: DepLongTaskTarget, temp: Optional[Union[in + if not ac.on or ac.burnt == True or ac.frozen == True: + await send_ac_image(evt, ac) + return +- ac.temperature -= temp ++ ac.update_ac(temperature_delta=-temp) + if ac.temperature < 0: + # 根据温度随机出是否冻结,0度开始,呈指数增长 + possibility = -math.e ** (ac.temperature / 50) + 1 +@@ -141,6 +148,16 @@ async def _(event: BaseEvent, target: DepLongTaskTarget): + ac.change_ac() + await send_ac_image(evt, ac) + ++def query_number_ranking(id: str) -> tuple[int, int]: ++ result = DatabaseManager.query_by_sql_file( ++ ROOT_PATH / "sql" / "query_crash_and_rank.sql", ++ (id,id) ++ ) ++ if len(result) == 0: ++ return 0, 0 ++ else: ++ return result[0].values() ++ + evt = on_alconna(Alconna( + "空调炸炸排行榜", + ), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True) +@@ -148,8 +165,9 @@ evt = on_alconna(Alconna( + @evt.handle() + async def _(event: BaseEvent, target: DepLongTaskTarget): + id = target.channel_id +- ac = get_ac(id) +- number, ranking = ac.get_crashes_and_ranking() ++ # ac = get_ac(id) ++ # number, ranking = ac.get_crashes_and_ranking() ++ number, ranking = query_number_ranking(id) + params = { + "number": number, + "ranking": ranking +diff --git a/konabot/plugins/air_conditioner/__preinit__.py b/konabot/plugins/air_conditioner/__preinit__.py +new file mode 100644 +index 0000000..67054a0 +--- /dev/null ++++ b/konabot/plugins/air_conditioner/__preinit__.py +@@ -0,0 +1,9 @@ ++# 预初始化,只要是导入本插件包就会执行这里的代码 ++from pathlib import Path ++ ++from konabot.common.database import DatabaseManager ++ ++# 初始化数据库表 ++DatabaseManager.execute_by_sql_file( ++ Path(__file__).resolve().parent / "sql" / "create_table.sql" ++) +diff --git a/konabot/plugins/air_conditioner/ac.py b/konabot/plugins/air_conditioner/ac.py +index 6614784..9be7619 100644 +--- a/konabot/plugins/air_conditioner/ac.py ++++ b/konabot/plugins/air_conditioner/ac.py +@@ -1,66 +1,112 @@ + from enum import Enum + from io import BytesIO ++from pathlib import Path + + import cv2 + import numpy as np + from PIL import Image, ImageDraw, ImageFont + ++from konabot.common.database import DatabaseManager + from konabot.common.path import ASSETS_PATH, FONTS_PATH + from konabot.common.path import DATA_PATH + import json + ++ROOT_PATH = Path(__file__).resolve().parent ++ + class CrashType(Enum): + BURNT = 0 + FROZEN = 1 + + class AirConditioner: +- air_conditioners: dict[str, "AirConditioner"] = {} ++ @classmethod ++ def get_ac(cls, id: str) -> 'AirConditioner': ++ result = DatabaseManager.query_by_sql_file(ROOT_PATH / "sql" / "query_ac.sql", (id,)) ++ if len(result) == 0: ++ ac = cls.create_ac(id) ++ return ac ++ ac_data = result[0] ++ ac = AirConditioner(id) ++ ac.on = bool(ac_data["on"]) ++ ac.temperature = float(ac_data["temperature"]) ++ ac.burnt = bool(ac_data["burnt"]) ++ ac.frozen = bool(ac_data["frozen"]) ++ return ac + +- def __init__(self, id: str) -> None: +- self.id = id ++ @classmethod ++ def create_ac(cls, id: str) -> 'AirConditioner': ++ ac = AirConditioner(id) ++ DatabaseManager.execute_by_sql_file( ++ ROOT_PATH / "sql" / "insert_ac.sql", ++ (id, ac.on, ac.temperature, ac.burnt, ac.frozen) ++ ) ++ return ac ++ ++ def update_ac(self, state: bool = None, temperature_delta: float = None, burnt: bool = None, frozen: bool = None) -> 'AirConditioner': ++ if state is not None: ++ self.on = state ++ if temperature_delta is not None: ++ self.temperature += temperature_delta ++ if burnt is not None: ++ self.burnt = burnt ++ if frozen is not None: ++ self.frozen = frozen ++ DatabaseManager.execute_by_sql_file( ++ ROOT_PATH / "sql" / "update_ac.sql", ++ (self.on, self.temperature, self.burnt, self.frozen, self.id) ++ ) ++ return self ++ ++ def change_ac(self) -> 'AirConditioner': + self.on = False +- self.temperature = 24 # 默认温度 ++ self.temperature = 24 + self.burnt = False + self.frozen = False +- AirConditioner.air_conditioners[id] = self ++ DatabaseManager.execute_by_sql_file( ++ ROOT_PATH / "sql" / "update_ac.sql", ++ (self.on, self.temperature, self.burnt, self.frozen, self.id) ++ ) ++ return self + +- def change_ac(self): ++ def __init__(self, id: str) -> None: ++ self.id = id ++ self.on = False ++ self.temperature = 24 # 默认温度 + self.burnt = False + self.frozen = False +- self.on = False +- self.temperature = 24 # 重置为默认温度 + + def broke_ac(self, crash_type: CrashType): + ''' +- 让空调坏掉,并保存数据 +- ++ 让空调坏掉 + :param crash_type: CrashType 枚举,表示空调坏掉的类型 + ''' + match crash_type: + case CrashType.BURNT: +- self.burnt = True ++ self.update_ac(burnt=True) + case CrashType.FROZEN: +- self.frozen = True +- self.save_crash_data(crash_type) ++ self.update_ac(frozen=True) ++ DatabaseManager.execute_by_sql_file( ++ ROOT_PATH / "sql" / "insert_crash.sql", ++ (self.id, crash_type.value) ++ ) + +- def save_crash_data(self, crash_type: CrashType): +- ''' +- 如果空调爆炸了,就往本地的 ac_crash_data.json 里该 id 的记录加一 +- ''' +- data_file = DATA_PATH / "ac_crash_data.json" +- crash_data = {} +- if data_file.exists(): +- with open(data_file, "r", encoding="utf-8") as f: +- crash_data = json.load(f) +- if self.id not in crash_data: +- crash_data[self.id] = {"burnt": 0, "frozen": 0} +- match crash_type: +- case CrashType.BURNT: +- crash_data[self.id]["burnt"] += 1 +- case CrashType.FROZEN: +- crash_data[self.id]["frozen"] += 1 +- with open(data_file, "w", encoding="utf-8") as f: +- json.dump(crash_data, f, ensure_ascii=False, indent=4) ++ # def save_crash_data(self, crash_type: CrashType): ++ # ''' ++ # 如果空调爆炸了,就往本地的 ac_crash_data.json 里该 id 的记录加一 ++ # ''' ++ # data_file = DATA_PATH / "ac_crash_data.json" ++ # crash_data = {} ++ # if data_file.exists(): ++ # with open(data_file, "r", encoding="utf-8") as f: ++ # crash_data = json.load(f) ++ # if self.id not in crash_data: ++ # crash_data[self.id] = {"burnt": 0, "frozen": 0} ++ # match crash_type: ++ # case CrashType.BURNT: ++ # crash_data[self.id]["burnt"] += 1 ++ # case CrashType.FROZEN: ++ # crash_data[self.id]["frozen"] += 1 ++ # with open(data_file, "w", encoding="utf-8") as f: ++ # json.dump(crash_data, f, ensure_ascii=False, indent=4) + + def get_crashes_and_ranking(self) -> tuple[int, int]: + ''' +diff --git a/konabot/plugins/air_conditioner/sql/create_table.sql b/konabot/plugins/air_conditioner/sql/create_table.sql +new file mode 100644 +index 0000000..5203e23 +--- /dev/null ++++ b/konabot/plugins/air_conditioner/sql/create_table.sql +@@ -0,0 +1,15 @@ ++-- 创建所有表 ++CREATE TABLE IF NOT EXISTS air_conditioner ( ++ id VARCHAR(128) PRIMARY KEY, ++ 'on' BOOLEAN NOT NULL, ++ temperature REAL NOT NULL, ++ burnt BOOLEAN NOT NULL, ++ frozen BOOLEAN NOT NULL ++); ++ ++CREATE TABLE IF NOT EXISTS air_conditioner_crash_log ( ++ id VARCHAR(128) NOT NULL, ++ crash_type INT NOT NULL, ++ timestamp DATETIME NOT NULL, ++ FOREIGN KEY (id) REFERENCES air_conditioner(id) ++); +\ No newline at end of file +diff --git a/konabot/plugins/air_conditioner/sql/insert_ac.sql b/konabot/plugins/air_conditioner/sql/insert_ac.sql +new file mode 100644 +index 0000000..3fb1c76 +--- /dev/null ++++ b/konabot/plugins/air_conditioner/sql/insert_ac.sql +@@ -0,0 +1,3 @@ ++-- 插入一台新空调 ++INSERT INTO air_conditioner (id, 'on', temperature, burnt, frozen) ++VALUES (?, ?, ?, ?, ?); +\ No newline at end of file +diff --git a/konabot/plugins/air_conditioner/sql/insert_crash.sql b/konabot/plugins/air_conditioner/sql/insert_crash.sql +new file mode 100644 +index 0000000..aae3898 +--- /dev/null ++++ b/konabot/plugins/air_conditioner/sql/insert_crash.sql +@@ -0,0 +1,3 @@ ++-- 插入一条空调爆炸记录 ++INSERT INTO air_conditioner_crash_log (id, crash_type, timestamp) ++VALUES (?, ?, CURRENT_TIMESTAMP); +\ No newline at end of file +diff --git a/konabot/plugins/air_conditioner/sql/query_ac.sql b/konabot/plugins/air_conditioner/sql/query_ac.sql +new file mode 100644 +index 0000000..db957d3 +--- /dev/null ++++ b/konabot/plugins/air_conditioner/sql/query_ac.sql +@@ -0,0 +1,4 @@ ++-- 查询空调状态,如果没有就插入一条新的记录 ++SELECT * ++FROM air_conditioner ++WHERE id = ?; +\ No newline at end of file +diff --git a/konabot/plugins/air_conditioner/sql/query_crash_and_rank.sql b/konabot/plugins/air_conditioner/sql/query_crash_and_rank.sql +new file mode 100644 +index 0000000..c180638 +--- /dev/null ++++ b/konabot/plugins/air_conditioner/sql/query_crash_and_rank.sql +@@ -0,0 +1,23 @@ ++-- 从 air_conditioner_crash_log 表中获取指定 id 损坏的次数以及损坏次数的排名 ++SELECT crash_count, crash_rank ++FROM ( ++ SELECT id, ++ COUNT(*) AS crash_count, ++ RANK() OVER (ORDER BY COUNT(*) DESC) AS crash_rank ++ FROM air_conditioner_crash_log ++ GROUP BY id ++) AS ranked_data ++WHERE id = ? ++-- 如果该 id 没有损坏记录,则返回 0 次损坏和对应的最后一名 ++UNION ++SELECT 0 AS crash_count, ++ (SELECT COUNT(DISTINCT id) + 1 FROM air_conditioner_crash_log) AS crash_rank ++FROM ( ++ SELECT DISTINCT id ++ FROM air_conditioner_crash_log ++) AS ranked_data ++WHERE NOT EXISTS ( ++ SELECT 1 ++ FROM air_conditioner_crash_log ++ WHERE id = ? ++); +\ No newline at end of file +diff --git a/konabot/plugins/air_conditioner/sql/update_ac.sql b/konabot/plugins/air_conditioner/sql/update_ac.sql +new file mode 100644 +index 0000000..df9145e +--- /dev/null ++++ b/konabot/plugins/air_conditioner/sql/update_ac.sql +@@ -0,0 +1,4 @@ ++-- 更新空调状态 ++UPDATE air_conditioner ++SET 'on' = ?, temperature = ?, burnt = ?, frozen = ? ++WHERE id = ?; +\ No newline at end of file +diff --git a/konabot/plugins/idiomgame/__init__.py b/konabot/plugins/idiomgame/__init__.py +index ee4e26c..36710aa 100644 +--- a/konabot/plugins/idiomgame/__init__.py ++++ b/konabot/plugins/idiomgame/__init__.py +@@ -18,11 +18,14 @@ from nonebot_plugin_alconna import ( + on_alconna, + ) + ++from konabot.common.database import DatabaseManager + from konabot.common.longtask import DepLongTaskTarget + from konabot.common.path import ASSETS_PATH + + from konabot.common.llm import get_llm + ++ROOT_PATH = Path(__file__).resolve().parent ++ + DATA_DIR = Path(__file__).parent.parent.parent.parent / "data" + + DATA_FILE_PATH = ( +@@ -94,18 +97,19 @@ class IdiomGameLLM: + + @classmethod + async def storage_idiom(cls, idiom: str): +- # 将 idiom 存入本地文件以备后续分析 +- with open(DATA_DIR / "idiom_llm_storage.txt", "a", encoding="utf-8") as f: +- f.write(idiom + "\n") +- IdiomGame.append_into_word_list(idiom) ++ # 将 idiom 存入数据库 ++ DatabaseManager.execute_by_sql_file( ++ ROOT_PATH / "sql" / "insert_custom_word.sql", ++ (idiom,) ++ ) + + + class IdiomGame: +- ALL_WORDS = [] # 所有四字词语 +- ALL_IDIOMS = [] # 所有成语 ++ # ALL_WORDS = [] # 所有四字词语 ++ # ALL_IDIOMS = [] # 所有成语 + INSTANCE_LIST: dict[str, "IdiomGame"] = {} # 群号对应的游戏实例 +- IDIOM_FIRST_CHAR = {} # 所有成语包括词语的首字字典 +- AVALIABLE_IDIOM_FIRST_CHAR = {} # 真正有效的成语首字字典 ++ # IDIOM_FIRST_CHAR = {} # 所有成语包括词语的首字字典 ++ # AVALIABLE_IDIOM_FIRST_CHAR = {} # 真正有效的成语首字字典 + + __inited = False + +@@ -130,11 +134,10 @@ class IdiomGame: + ''' + 将一个新词加入到词语列表中 + ''' +- if word not in cls.ALL_WORDS: +- cls.ALL_WORDS.append(word) +- if word[0] not in cls.IDIOM_FIRST_CHAR: +- cls.IDIOM_FIRST_CHAR[word[0]] = [] +- cls.IDIOM_FIRST_CHAR[word[0]].append(word) ++ DatabaseManager.execute_by_sql_file( ++ ROOT_PATH / "sql" / "insert_custom_word.sql", ++ (word,) ++ ) + + def be_able_to_play(self) -> bool: + if self.last_play_date != datetime.date.today(): +@@ -145,11 +148,17 @@ class IdiomGame: + return True + return False + ++ @staticmethod ++ def random_idiom() -> str: ++ return DatabaseManager.query_by_sql_file( ++ ROOT_PATH / "sql" / "random_choose_idiom.sql" ++ )[0]["idiom"] ++ + def choose_start_idiom(self) -> str: + """ + 随机选择一个成语作为起始成语 + """ +- self.last_idiom = secrets.choice(IdiomGame.ALL_IDIOMS) ++ self.last_idiom = IdiomGame.random_idiom() + self.last_char = self.last_idiom[-1] + if not self.is_nextable(self.last_char): + self.choose_start_idiom() +@@ -208,7 +217,7 @@ class IdiomGame: + return self.last_idiom + + def _skip_idiom_async(self) -> str: +- self.last_idiom = secrets.choice(IdiomGame.ALL_IDIOMS) ++ self.last_idiom = IdiomGame.random_idiom() + self.last_char = self.last_idiom[-1] + if not self.is_nextable(self.last_char): + self._skip_idiom_async() +@@ -228,8 +237,11 @@ class IdiomGame: + """ + 判断是否有成语可以接 + """ +- return last_char in IdiomGame.AVALIABLE_IDIOM_FIRST_CHAR +- ++ return DatabaseManager.query_by_sql_file( ++ ROOT_PATH / "sql" / "is_nextable.sql", ++ (last_char,) ++ )[0]["DEED"] == 1 ++ + def add_already_idiom(self, idiom: str): + if idiom in self.already_idioms: + self.already_idioms[idiom] += 1 +@@ -259,7 +271,12 @@ class IdiomGame: + if idiom[0] != self.last_char: + state.append(TryVerifyState.WRONG_FIRST_CHAR) + return state +- if idiom not in IdiomGame.ALL_IDIOMS and idiom not in IdiomGame.ALL_WORDS: ++ # 成语是否存在 ++ result = DatabaseManager.query_by_sql_file( ++ ROOT_PATH / "sql" / "query_idiom.sql", ++ (idiom, idiom, idiom) ++ )[0]["status"] ++ if result == -1: + logger.info(f"用户 {user_id} 发送了未知词语 {idiom},正在使用 LLM 进行验证") + try: + if not await IdiomGameLLM.verify_idiom_with_llm(idiom): +@@ -281,7 +298,7 @@ class IdiomGame: + self.last_idiom = idiom + self.last_char = idiom[-1] + self.add_score(user_id, 1 * score_k) # 先加 1 分 +- if idiom in IdiomGame.ALL_IDIOMS: ++ if result == 1: + state.append(TryVerifyState.VERIFIED_AND_REAL) + self.add_score(user_id, 4 * score_k) # 再加 4 分 + self.remain_rounds -= 1 +@@ -319,14 +336,21 @@ class IdiomGame: + @classmethod + def random_idiom_starting_with(cls, first_char: str) -> Optional[str]: + cls.init_lexicon() +- if first_char not in cls.AVALIABLE_IDIOM_FIRST_CHAR: ++ result = DatabaseManager.query_by_sql_file( ++ ROOT_PATH / "sql" / "query_idiom_start_with.sql", ++ (first_char,) ++ ) ++ if len(result) == 0: + return None +- return secrets.choice(cls.AVALIABLE_IDIOM_FIRST_CHAR[first_char]) ++ return result[0]["idiom"] + + @classmethod + def init_lexicon(cls): + if cls.__inited: + return ++ DatabaseManager.execute_by_sql_file( ++ ROOT_PATH / "sql" / "create_table.sql" ++ ) # 确保数据库初始化 + cls.__inited = True + + # 成语大表 +@@ -334,11 +358,12 @@ class IdiomGame: + ALL_IDIOMS_INFOS = json.load(f) + + # 词语大表 ++ ALL_WORDS = [] + with open(ASSETS_PATH / "lexicon" / "ci.json", "r", encoding="utf-8") as f: + jsonData = json.load(f) +- cls.ALL_WORDS = [item["ci"] for item in jsonData] +- logger.debug(f"Loaded {len(cls.ALL_WORDS)} words from ci.json") +- logger.debug(f"Sample words: {cls.ALL_WORDS[:5]}") ++ ALL_WORDS = [item["ci"] for item in jsonData] ++ logger.debug(f"Loaded {len(ALL_WORDS)} words from ci.json") ++ logger.debug(f"Sample words: {ALL_WORDS[:5]}") + + COMMON_WORDS = [] + # 读取 COMMON 词语大表 +@@ -389,29 +414,44 @@ class IdiomGame: + logger.debug(f"Loaded additional {len(LOCAL_LLM_WORDS)} words from idiom_llm_storage.txt") + + # 只有成语的大表 +- cls.ALL_IDIOMS = [idiom["word"] for idiom in ALL_IDIOMS_INFOS] + THUOCL_IDIOMS +- cls.ALL_IDIOMS = list(set(cls.ALL_IDIOMS)) # 去重 ++ ALL_IDIOMS = [idiom["word"] for idiom in ALL_IDIOMS_INFOS] + THUOCL_IDIOMS ++ ALL_IDIOMS = list(set(ALL_IDIOMS)) # 去重 ++ # 批量插入数据库 ++ DatabaseManager.execute_many_values_by_sql_file( ++ ROOT_PATH / "sql" / "insert_idiom.sql", ++ [(idiom,) for idiom in ALL_IDIOMS] ++ ) ++ + + # 其他四字词语表,仅表示可以有这个词 +- cls.ALL_WORDS = ( +- [word for word in cls.ALL_WORDS if len(word) == 4] ++ ALL_WORDS = ( ++ [word for word in ALL_WORDS if len(word) == 4] + + THUOCL_WORDS + + COMMON_WORDS +- + LOCAL_LLM_WORDS + ) +- cls.ALL_WORDS = list(set(cls.ALL_WORDS)) # 去重 ++ # 插入数据库 ++ DatabaseManager.execute_many_values_by_sql_file( ++ ROOT_PATH / "sql" / "insert_word.sql", ++ [(word,) for word in ALL_WORDS] ++ ) ++ ++ # 自定义词语 LOCAL_LLM_WORDS 插入数据库,兼容用 ++ DatabaseManager.execute_many_values_by_sql_file( ++ ROOT_PATH / "sql" / "insert_custom_word.sql", ++ [(word,) for word in LOCAL_LLM_WORDS] ++ ) + +- # 根据成语大表,划分出成语首字字典 +- for idiom in cls.ALL_IDIOMS + cls.ALL_WORDS: +- if idiom[0] not in cls.IDIOM_FIRST_CHAR: +- cls.IDIOM_FIRST_CHAR[idiom[0]] = [] +- cls.IDIOM_FIRST_CHAR[idiom[0]].append(idiom) ++ # # 根据成语大表,划分出成语首字字典 ++ # for idiom in cls.ALL_IDIOMS + cls.ALL_WORDS: ++ # if idiom[0] not in cls.IDIOM_FIRST_CHAR: ++ # cls.IDIOM_FIRST_CHAR[idiom[0]] = [] ++ # cls.IDIOM_FIRST_CHAR[idiom[0]].append(idiom) + +- # 根据真正的成语大表,划分出有效成语首字字典 +- for idiom in cls.ALL_IDIOMS: +- if idiom[0] not in cls.AVALIABLE_IDIOM_FIRST_CHAR: +- cls.AVALIABLE_IDIOM_FIRST_CHAR[idiom[0]] = [] +- cls.AVALIABLE_IDIOM_FIRST_CHAR[idiom[0]].append(idiom) ++ # # 根据真正的成语大表,划分出有效成语首字字典 ++ # for idiom in cls.ALL_IDIOMS: ++ # if idiom[0] not in cls.AVALIABLE_IDIOM_FIRST_CHAR: ++ # cls.AVALIABLE_IDIOM_FIRST_CHAR[idiom[0]] = [] ++ # cls.AVALIABLE_IDIOM_FIRST_CHAR[idiom[0]].append(idiom) + + + evt = on_alconna( +@@ -514,7 +554,9 @@ async def end_game(event: BaseEvent, group_id: str): + for line in history_lines: + result_text += line + "\n" + await evt.send(await result_text.export()) +- instance.clear_score_board() ++ # instance.clear_score_board() ++ # 将实例删除 ++ del IdiomGame.INSTANCE_LIST[group_id] + + + evt = on_alconna( +diff --git a/konabot/plugins/idiomgame/sql/create_table.sql b/konabot/plugins/idiomgame/sql/create_table.sql +new file mode 100644 +index 0000000..5d38580 +--- /dev/null ++++ b/konabot/plugins/idiomgame/sql/create_table.sql +@@ -0,0 +1,15 @@ ++-- 创建成语大表 ++CREATE TABLE IF NOT EXISTS all_idioms ( ++ id INTEGER PRIMARY KEY AUTOINCREMENT, ++ idiom VARCHAR(128) NOT NULL UNIQUE ++); ++ ++CREATE TABLE IF NOT EXISTS all_words ( ++ id INTEGER PRIMARY KEY AUTOINCREMENT, ++ word VARCHAR(128) NOT NULL UNIQUE ++); ++ ++CREATE TABLE IF NOT EXISTS custom_words ( ++ id INTEGER PRIMARY KEY AUTOINCREMENT, ++ word VARCHAR(128) NOT NULL UNIQUE ++); +\ No newline at end of file +diff --git a/konabot/plugins/idiomgame/sql/insert_custom_word.sql b/konabot/plugins/idiomgame/sql/insert_custom_word.sql +new file mode 100644 +index 0000000..212c8a2 +--- /dev/null ++++ b/konabot/plugins/idiomgame/sql/insert_custom_word.sql +@@ -0,0 +1,3 @@ ++-- 插入自定义词 ++INSERT OR IGNORE INTO custom_words (word) ++VALUES (?); +\ No newline at end of file +diff --git a/konabot/plugins/idiomgame/sql/insert_idiom.sql b/konabot/plugins/idiomgame/sql/insert_idiom.sql +new file mode 100644 +index 0000000..eaedae8 +--- /dev/null ++++ b/konabot/plugins/idiomgame/sql/insert_idiom.sql +@@ -0,0 +1,3 @@ ++-- 插入成语大表,避免重复插入 ++INSERT OR IGNORE INTO all_idioms (idiom) ++VALUES (?); +\ No newline at end of file +diff --git a/konabot/plugins/idiomgame/sql/insert_word.sql b/konabot/plugins/idiomgame/sql/insert_word.sql +new file mode 100644 +index 0000000..b085aab +--- /dev/null ++++ b/konabot/plugins/idiomgame/sql/insert_word.sql +@@ -0,0 +1,3 @@ ++-- 插入词 ++INSERT OR IGNORE INTO all_words (word) ++VALUES (?); +\ No newline at end of file +diff --git a/konabot/plugins/idiomgame/sql/is_nextable.sql b/konabot/plugins/idiomgame/sql/is_nextable.sql +new file mode 100644 +index 0000000..a7bbeb1 +--- /dev/null ++++ b/konabot/plugins/idiomgame/sql/is_nextable.sql +@@ -0,0 +1,5 @@ ++-- 查询是否有以 xx 开头的成语,有则返回真,否则假 ++SELECT EXISTS( ++ SELECT 1 FROM all_idioms ++ WHERE idiom LIKE ? || '%' ++) AS DEED; +diff --git a/konabot/plugins/idiomgame/sql/query_idiom.sql b/konabot/plugins/idiomgame/sql/query_idiom.sql +new file mode 100644 +index 0000000..fa3bf93 +--- /dev/null ++++ b/konabot/plugins/idiomgame/sql/query_idiom.sql +@@ -0,0 +1,7 @@ ++-- 查询成语是否在 all_idioms 中,如果存在则返回 1,否则再判断是否在 custom_words 或 all_words 中,存在则返回 0,否则返回 -1 ++SELECT ++ CASE ++ WHEN EXISTS (SELECT 1 FROM all_idioms WHERE idiom = ?) THEN 1 ++ WHEN EXISTS (SELECT 1 FROM custom_words WHERE word = ?) OR EXISTS (SELECT 1 FROM all_words WHERE word = ?) THEN 0 ++ ELSE -1 ++ END AS status; +\ No newline at end of file +diff --git a/konabot/plugins/idiomgame/sql/query_idiom_start_with.sql b/konabot/plugins/idiomgame/sql/query_idiom_start_with.sql +new file mode 100644 +index 0000000..a6e8fc6 +--- /dev/null ++++ b/konabot/plugins/idiomgame/sql/query_idiom_start_with.sql +@@ -0,0 +1,4 @@ ++-- 查询以 xx 开头的成语,随机打乱后只取第一个 ++SELECT idiom FROM all_idioms ++WHERE idiom LIKE ? || '%' ++ORDER BY RANDOM() LIMIT 1; +\ No newline at end of file +diff --git a/konabot/plugins/idiomgame/sql/random_choose_idiom.sql b/konabot/plugins/idiomgame/sql/random_choose_idiom.sql +new file mode 100644 +index 0000000..f706092 +--- /dev/null ++++ b/konabot/plugins/idiomgame/sql/random_choose_idiom.sql +@@ -0,0 +1,2 @@ ++-- 随机从 all_idioms 表中选择一个成语 ++SELECT idiom FROM all_idioms ORDER BY RANDOM() LIMIT 1; +\ No newline at end of file diff --git a/docs/database.md b/docs/database.md new file mode 100644 index 0000000..bfcdc63 --- /dev/null +++ b/docs/database.md @@ -0,0 +1,223 @@ +# 数据库系统使用文档 + +本文档详细介绍了本项目中使用的异步数据库系统,包括其架构设计、使用方法和最佳实践。 + +## 系统概述 + +本项目的数据库系统基于 `aiosqlite` 库构建,提供了异步的 SQLite 数据库访问接口。系统主要特性包括: + +1. **异步操作**:完全支持异步/await模式,适配NoneBot2框架 +2. **连接池**:内置连接池机制,提高数据库访问性能 +3. **参数化查询**:支持安全的参数化查询,防止SQL注入 +4. **SQL文件支持**:可以直接执行SQL文件中的脚本 +5. **类型支持**:支持 `pathlib.Path` 和 `str` 类型的路径参数 + +## 核心类和方法 + +### DatabaseManager 类 + +`DatabaseManager` 是数据库操作的核心类,提供了以下主要方法: + +#### 初始化 +```python +from konabot.common.database import DatabaseManager +from pathlib import Path + +# 使用默认数据库路径 +db = DatabaseManager() + +# 指定了义数据库路径 +db = DatabaseManager("./data/myapp.db") +db = DatabaseManager(Path("./data/myapp.db")) +``` + +#### 查询操作 +```python +# 执行查询语句并返回结果 +results = await db.query("SELECT * FROM users WHERE age > ?", (18,)) + +# 从SQL文件执行查询 +results = await db.query_by_sql_file("./sql/get_users.sql", (18,)) +``` + +#### 执行操作 +```python +# 执行非查询语句 +await db.execute("INSERT INTO users (name, email) VALUES (?, ?)", ("张三", "zhangsan@example.com")) + +# 执行SQL脚本(不带参数) +await db.execute_script(""" + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + email TEXT UNIQUE + ); + INSERT INTO users (name, email) VALUES ('测试用户', 'test@example.com'); +""") + +# 从SQL文件执行非查询语句 +await db.execute_by_sql_file("./sql/create_tables.sql") + +# 带参数执行SQL文件 +await db.execute_by_sql_file("./sql/insert_user.sql", ("张三", "zhangsan@example.com")) + +# 执行多条语句(每条语句使用相同参数) +await db.execute_many("INSERT INTO users (name, email) VALUES (?, ?)", [ + ("张三", "zhangsan@example.com"), + ("李四", "lisi@example.com"), + ("王五", "wangwu@example.com") +]) + +# 从SQL文件执行多条语句(每条语句使用相同参数) +await db.execute_many_values_by_sql_file("./sql/batch_insert.sql", [ + ("张三", "zhangsan@example.com"), + ("李四", "lisi@example.com") +]) +``` + +## SQL文件处理机制 + +### 单语句SQL文件 +```sql +-- insert_user.sql +INSERT INTO users (name, email) VALUES (?, ?); +``` + +```python +# 使用方式 +await db.execute_by_sql_file("./sql/insert_user.sql", ("张三", "zhangsan@example.com")) +``` + +### 多语句SQL文件 +```sql +-- setup.sql +CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + email TEXT UNIQUE +); + +CREATE TABLE IF NOT EXISTS profiles ( + user_id INTEGER, + age INTEGER, + FOREIGN KEY (user_id) REFERENCES users(id) +); +``` + +```python +# 使用方式 +await db.execute_by_sql_file("./sql/setup.sql") +``` + +### 多语句带不同参数的SQL文件 +```sql +-- batch_operations.sql +INSERT INTO users (name, email) VALUES (?, ?); +INSERT INTO profiles (user_id, age) VALUES (?, ?); +``` + +```python +# 使用方式 +await db.execute_by_sql_file("./sql/batch_operations.sql", [ + ("张三", "zhangsan@example.com"), # 第一条语句的参数 + (1, 25) # 第二条语句的参数 +]) +``` + +## 最佳实践 + +### 1. 数据库表设计 +```sql +-- 推荐的表设计实践 +CREATE TABLE IF NOT EXISTS example_table ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP +); +``` + +### 2. SQL文件组织 +建议按照功能模块组织SQL文件: +``` +plugin/ +├── sql/ +│ ├── create_tables.sql +│ ├── insert_data.sql +│ ├── update_data.sql +│ └── query_data.sql +└── __init__.py +``` + +### 3. 错误处理 +```python +try: + results = await db.query("SELECT * FROM users WHERE id = ?", (user_id,)) +except Exception as e: + logger.error(f"数据库查询失败: {e}") + # 处理错误情况 +``` + +### 4. 连接管理 +```python +# 在应用启动时初始化 +db_manager = DatabaseManager() + +# 在应用关闭时清理连接 +async def shutdown(): + await db_manager.close_all_connections() +``` + +## 高级特性 + +### 连接池配置 +```python +class DatabaseManager: + def __init__(self, db_path: Optional[Union[str, Path]] = None): + # 连接池大小配置 + self._pool_size = 5 # 可根据需要调整 +``` + +### 事务支持 +```python +# 通过execute方法的自动提交机制支持事务 +await db.execute("BEGIN TRANSACTION") +try: + await db.execute("INSERT INTO users (name) VALUES (?)", ("张三",)) + await db.execute("INSERT INTO profiles (user_id, age) VALUES (?, ?)", (1, 25)) + await db.execute("COMMIT") +except Exception: + await db.execute("ROLLBACK") + raise +``` + +## 注意事项 + +1. **异步环境**:所有数据库操作都必须在异步环境中执行 +2. **参数安全**:始终使用参数化查询,避免SQL注入 +3. **资源管理**:确保在应用关闭时调用 `close_all_connections()` +4. **SQL解析**:使用 `sqlparse` 库准确解析SQL语句,正确处理包含分号的字符串和注释 +5. **错误处理**:适当处理数据库操作可能抛出的异常 + +## 常见问题 + +### Q: 如何处理数据库约束错误? +A: 确保SQL语句中的字段名正确引用,特别是保留字需要使用双引号包围: +```sql +CREATE TABLE air_conditioner ( + id VARCHAR(128) PRIMARY KEY, + "on" BOOLEAN NOT NULL, -- 使用双引号包围保留字 + temperature REAL NOT NULL +); +``` + +### Q: 如何处理多个语句和参数的匹配? +A: 当SQL文件包含多个语句时,参数应该是参数列表,每个语句对应一个参数元组: +```python +await db.execute_by_sql_file("./sql/batch.sql", [ + ("参数1", "参数2"), # 第一个语句的参数 + ("参数3", "参数4") # 第二个语句的参数 +]) +``` + +通过遵循这些指南和最佳实践,您可以充分利用本项目的异步数据库系统,构建高性能、安全的数据库应用。 \ No newline at end of file diff --git a/konabot/common/database/__init__.py b/konabot/common/database/__init__.py index b933bb4..03a5a5f 100644 --- a/konabot/common/database/__init__.py +++ b/konabot/common/database/__init__.py @@ -1,20 +1,43 @@ import os import asyncio +import sqlparse from pathlib import Path -from typing import List, Dict, Any, Optional, Union +from typing import List, Dict, Any, Optional, Union, TYPE_CHECKING import aiosqlite +if TYPE_CHECKING: + from . import DatabaseManager + +# 全局数据库管理器实例 +_global_db_manager: Optional['DatabaseManager'] = None + +def get_global_db_manager() -> 'DatabaseManager': + """获取全局数据库管理器实例""" + global _global_db_manager + if _global_db_manager is None: + from . import DatabaseManager + _global_db_manager = DatabaseManager() + return _global_db_manager + +def close_global_db_manager() -> None: + """关闭全局数据库管理器实例""" + global _global_db_manager + if _global_db_manager is not None: + # 注意:这个函数应该在async环境中调用close_all_connections + _global_db_manager = None + class DatabaseManager: """异步数据库管理器""" - def __init__(self, db_path: Optional[Union[str, Path]] = None): + def __init__(self, db_path: Optional[Union[str, Path]] = None, pool_size: int = 5): """ 初始化数据库管理器 Args: db_path: 数据库文件路径,支持str和Path类型 + pool_size: 连接池大小 """ if db_path is None: self.db_path = os.environ.get("DATABASE_PATH", "./data/database.db") @@ -23,27 +46,46 @@ class DatabaseManager: # 连接池 self._connection_pool = [] - self._pool_size = 5 + self._pool_size = pool_size self._lock = asyncio.Lock() + self._in_use = set() # 跟踪正在使用的连接 async def _get_connection(self) -> aiosqlite.Connection: """从连接池获取连接""" async with self._lock: - if self._connection_pool: - return self._connection_pool.pop() + # 尝试从池中获取现有连接 + while self._connection_pool: + conn = self._connection_pool.pop() + # 检查连接是否仍然有效 + try: + await conn.execute("SELECT 1") + self._in_use.add(conn) + return conn + except: + # 连接已失效,关闭它 + try: + await conn.close() + except: + pass - # 如果连接池为空,创建新连接 - conn = await aiosqlite.connect(self.db_path) - await conn.execute("PRAGMA foreign_keys = ON") - return conn + # 如果连接池为空,创建新连接 + conn = await aiosqlite.connect(self.db_path) + await conn.execute("PRAGMA foreign_keys = ON") + self._in_use.add(conn) + return conn async def _return_connection(self, conn: aiosqlite.Connection) -> None: """将连接返回到连接池""" async with self._lock: + self._in_use.discard(conn) if len(self._connection_pool) < self._pool_size: self._connection_pool.append(conn) else: - await conn.close() + # 池已满,直接关闭连接 + try: + await conn.close() + except: + pass async def query( self, query: str, params: Optional[tuple] = None @@ -57,6 +99,9 @@ class DatabaseManager: results = [dict(zip(columns, row)) for row in rows] await cursor.close() return results + except Exception as e: + # 记录错误但重新抛出,让调用者处理 + raise Exception(f"数据库查询失败: {str(e)}") from e finally: await self._return_connection(conn) @@ -75,6 +120,9 @@ class DatabaseManager: try: await conn.execute(command, params or ()) await conn.commit() + except Exception as e: + # 记录错误但重新抛出,让调用者处理 + raise Exception(f"数据库执行失败: {str(e)}") from e finally: await self._return_connection(conn) @@ -84,19 +132,47 @@ class DatabaseManager: try: await conn.executescript(script) await conn.commit() + except Exception as e: + # 记录错误但重新抛出,让调用者处理 + raise Exception(f"数据库脚本执行失败: {str(e)}") from e finally: await self._return_connection(conn) + def _parse_sql_statements(self, script: str) -> List[str]: + """解析SQL脚本,分割成独立的语句""" + # 使用sqlparse库更准确地分割SQL语句 + parsed = sqlparse.split(script) + statements = [] + + for statement in parsed: + statement = statement.strip() + if statement: + statements.append(statement) + + return statements + async def execute_by_sql_file( - self, file_path: Union[str, Path], params: Optional[tuple] = None + self, file_path: Union[str, Path], params: Optional[Union[tuple, List[tuple]]] = None ) -> None: """从 SQL 文件中读取非查询语句并执行""" path = str(file_path) if isinstance(file_path, Path) else file_path with open(path, "r", encoding="utf-8") as f: script = f.read() - # 如果有参数,使用execute方法而不是execute_script - if params: + + # 如果有参数且是元组,使用execute执行整个脚本 + if params is not None and isinstance(params, tuple): await self.execute(script, params) + # 如果有参数且是列表,分别执行每个语句 + elif params is not None and isinstance(params, list): + # 使用sqlparse准确分割SQL语句 + statements = self._parse_sql_statements(script) + if len(statements) != len(params): + raise ValueError(f"语句数量({len(statements)})与参数组数量({len(params)})不匹配") + + for statement, stmt_params in zip(statements, params): + if statement: + await self.execute(statement, stmt_params) + # 如果无参数,使用executescript else: await self.execute_script(script) @@ -106,6 +182,9 @@ class DatabaseManager: try: await conn.executemany(command, seq_of_params) await conn.commit() + except Exception as e: + # 记录错误但重新抛出,让调用者处理 + raise Exception(f"数据库批量执行失败: {str(e)}") from e finally: await self._return_connection(conn) @@ -121,7 +200,19 @@ class DatabaseManager: async def close_all_connections(self) -> None: """关闭所有连接""" async with self._lock: + # 关闭池中的连接 for conn in self._connection_pool: - await conn.close() + try: + await conn.close() + except: + pass self._connection_pool.clear() + # 关闭正在使用的连接 + for conn in self._in_use.copy(): + try: + await conn.close() + except: + pass + self._in_use.clear() + diff --git a/konabot/plugins/air_conditioner/__init__.py b/konabot/plugins/air_conditioner/__init__.py index 30a14dc..27b88e6 100644 --- a/konabot/plugins/air_conditioner/__init__.py +++ b/konabot/plugins/air_conditioner/__init__.py @@ -62,6 +62,12 @@ async def register_startup_hook(): Path(__file__).resolve().parent / "sql" / "create_table.sql" ) +@driver.on_shutdown +async def register_shutdown_hook(): + """注册关闭时需要执行的函数""" + # 关闭所有数据库连接 + await db_manager.close_all_connections() + evt = on_alconna(Alconna( "群空调" ), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True) @@ -135,7 +141,7 @@ evt = on_alconna(Alconna( ), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True) @evt.handle() -async def _(event: BaseEvent, target: DepLongTaskTarget, temp: Optional[Union[int, float]] = 1): +async def _(target: DepLongTaskTarget, temp: Optional[Union[int, float]] = 1): if temp is None: temp = 1 if temp <= 0: @@ -158,7 +164,7 @@ evt = on_alconna(Alconna( ), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True) @evt.handle() -async def _(event: BaseEvent, target: DepLongTaskTarget): +async def _(target: DepLongTaskTarget): id = target.channel_id ac = await get_ac(id) await ac.change_ac() @@ -181,7 +187,7 @@ evt = on_alconna(Alconna( ), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True) @evt.handle() -async def _(event: BaseEvent, target: DepLongTaskTarget): +async def _(target: DepLongTaskTarget): id = target.channel_id # ac = get_ac(id) # number, ranking = ac.get_crashes_and_ranking() diff --git a/konabot/plugins/idiomgame/__init__.py b/konabot/plugins/idiomgame/__init__.py index 7f1909f..e652cb1 100644 --- a/konabot/plugins/idiomgame/__init__.py +++ b/konabot/plugins/idiomgame/__init__.py @@ -73,6 +73,12 @@ async def register_startup_hook(): """注册启动时需要执行的函数""" await IdiomGame.init_lexicon() +@driver.on_shutdown +async def register_shutdown_hook(): + """注册关闭时需要执行的函数""" + # 关闭所有数据库连接 + await db_manager.close_all_connections() + class TryStartState(Enum): STARTED = 0 diff --git a/poetry.lock b/poetry.lock index 164736f..af13460 100644 --- a/poetry.lock +++ b/poetry.lock @@ -970,12 +970,12 @@ version = "0.4.6" description = "Cross-platform colored terminal text." optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" -groups = ["main"] -markers = "sys_platform == \"win32\" or platform_system == \"Windows\"" +groups = ["main", "dev"] files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +markers = {main = "sys_platform == \"win32\" or platform_system == \"Windows\"", dev = "sys_platform == \"win32\""} [package.source] type = "legacy" @@ -1592,6 +1592,23 @@ type = "legacy" url = "https://pypi.tuna.tsinghua.edu.cn/simple" reference = "mirrors" +[[package]] +name = "iniconfig" +version = "2.3.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.10" +groups = ["dev"] +files = [ + {file = "iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12"}, + {file = "iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730"}, +] + +[package.source] +type = "legacy" +url = "https://pypi.tuna.tsinghua.edu.cn/simple" +reference = "mirrors" + [[package]] name = "jiter" version = "0.11.1" @@ -2703,6 +2720,23 @@ type = "legacy" url = "https://pypi.tuna.tsinghua.edu.cn/simple" reference = "mirrors" +[[package]] +name = "packaging" +version = "25.0" +description = "Core utilities for Python packages" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484"}, + {file = "packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f"}, +] + +[package.source] +type = "legacy" +url = "https://pypi.tuna.tsinghua.edu.cn/simple" +reference = "mirrors" + [[package]] name = "pillow" version = "11.3.0" @@ -2882,6 +2916,27 @@ type = "legacy" url = "https://pypi.tuna.tsinghua.edu.cn/simple" reference = "mirrors" +[[package]] +name = "pluggy" +version = "1.6.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746"}, + {file = "pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["coverage", "pytest", "pytest-benchmark"] + +[package.source] +type = "legacy" +url = "https://pypi.tuna.tsinghua.edu.cn/simple" +reference = "mirrors" + [[package]] name = "propcache" version = "0.4.1" @@ -3368,7 +3423,7 @@ version = "2.19.2" description = "Pygments is a syntax highlighting package written in Python." optional = false python-versions = ">=3.8" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b"}, {file = "pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887"}, @@ -3399,6 +3454,58 @@ type = "legacy" url = "https://pypi.tuna.tsinghua.edu.cn/simple" reference = "mirrors" +[[package]] +name = "pytest" +version = "9.0.1" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.10" +groups = ["dev"] +files = [ + {file = "pytest-9.0.1-py3-none-any.whl", hash = "sha256:67be0030d194df2dfa7b556f2e56fb3c3315bd5c8822c6951162b92b32ce7dad"}, + {file = "pytest-9.0.1.tar.gz", hash = "sha256:3e9c069ea73583e255c3b21cf46b8d3c56f6e3a1a8f6da94ccb0fcf57b9d73c8"}, +] + +[package.dependencies] +colorama = {version = ">=0.4", markers = "sys_platform == \"win32\""} +iniconfig = ">=1.0.1" +packaging = ">=22" +pluggy = ">=1.5,<2" +pygments = ">=2.7.2" + +[package.extras] +dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests", "setuptools", "xmlschema"] + +[package.source] +type = "legacy" +url = "https://pypi.tuna.tsinghua.edu.cn/simple" +reference = "mirrors" + +[[package]] +name = "pytest-asyncio" +version = "1.3.0" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.10" +groups = ["dev"] +files = [ + {file = "pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5"}, + {file = "pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5"}, +] + +[package.dependencies] +pytest = ">=8.2,<10" +typing-extensions = {version = ">=4.12", markers = "python_version < \"3.13\""} + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + +[package.source] +type = "legacy" +url = "https://pypi.tuna.tsinghua.edu.cn/simple" +reference = "mirrors" + [[package]] name = "python-dotenv" version = "1.2.1" @@ -3723,6 +3830,27 @@ type = "legacy" url = "https://pypi.tuna.tsinghua.edu.cn/simple" reference = "mirrors" +[[package]] +name = "sqlparse" +version = "0.5.3" +description = "A non-validating SQL parser." +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "sqlparse-0.5.3-py3-none-any.whl", hash = "sha256:cf2196ed3418f3ba5de6af7e82c694a9fbdbfecccdfc72e281548517081f16ca"}, + {file = "sqlparse-0.5.3.tar.gz", hash = "sha256:09f67787f56a0b16ecdbde1bfc7f5d9c3371ca683cfeaa8e6ff60b4807ec9272"}, +] + +[package.extras] +dev = ["build", "hatch"] +doc = ["sphinx"] + +[package.source] +type = "legacy" +url = "https://pypi.tuna.tsinghua.edu.cn/simple" +reference = "mirrors" + [[package]] name = "starlette" version = "0.49.3" @@ -3926,11 +4054,12 @@ version = "4.15.0" description = "Backported and Experimental Type Hints for Python 3.9+" optional = false python-versions = ">=3.9" -groups = ["main"] +groups = ["main", "dev"] files = [ {file = "typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548"}, {file = "typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466"}, ] +markers = {dev = "python_version == \"3.12\""} [package.source] type = "legacy" @@ -4552,4 +4681,4 @@ reference = "mirrors" [metadata] lock-version = "2.1" python-versions = ">=3.12,<4.0" -content-hash = "5597aa165095a11fa08e4b6e1a1f4d3396711b684ed363ae0ced2dd59a09ec5d" +content-hash = "2c341fdc0d5b29ad3b24516c46e036b2eff4c11e244047d114971039255c2ac4" diff --git a/pyproject.toml b/pyproject.toml index a0595c1..4691cd4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "openai (>=2.7.1,<3.0.0)", "imageio (>=2.37.2,<3.0.0)", "aiosqlite (>=0.20.0,<1.0.0)", + "sqlparse (>=0.5.0,<1.0.0)", ] [tool.poetry] @@ -47,5 +48,7 @@ priority = "primary" [dependency-groups] dev = [ - "rust-just (>=1.43.0,<2.0.0)" + "rust-just (>=1.43.0,<2.0.0)", + "pytest (>=9.0.1,<10.0.0)", + "pytest-asyncio (>=1.3.0,<2.0.0)" ]