我拿 AI 改坏枪代码!

This commit is contained in:
2025-11-18 23:55:31 +08:00
parent f21da657db
commit 0d540eea4c
16 changed files with 367 additions and 161 deletions

1
.gitignore vendored
View File

@ -3,4 +3,3 @@
__pycache__
*.db

View File

@ -71,6 +71,10 @@ code .
详见[konabot-web 配置文档](/docs/konabot-web.md)
#### 数据库配置
本项目使用SQLite作为数据库默认数据库文件位于`./data/database.db`。可以通过设置`DATABASE_PATH`环境变量来指定其他位置。
### 运行
使用命令行手动启动 Bot
@ -91,3 +95,7 @@ poetry run python bot.py
- [事件响应器](https://nonebot.dev/docs/tutorial/matcher)
- [事件处理](https://nonebot.dev/docs/tutorial/handler)
- [Alconna 插件](https://nonebot.dev/docs/best-practice/alconna/)
## 数据库模块
本项目的数据库模块已更新为异步实现,使用连接池来提高性能,并支持现代的`pathlib.Path`参数类型。详细使用方法请参考`konabot/common/database/__init__.py`文件中的实现。

5
bot.py
View File

@ -10,7 +10,7 @@ from nonebot.adapters.onebot.v11 import Adapter as OnebotAdapter
from konabot.common.log import init_logger
from konabot.common.nb.exc import BotExceptionMessage
from konabot.common.path import LOG_PATH
from konabot.core.preinit import preinit
dotenv.load_dotenv()
env = os.environ.get("ENVIRONMENT", "prod")
@ -49,9 +49,6 @@ def main():
nonebot.load_plugins("konabot/plugins")
nonebot.load_plugin("nonebot_plugin_analysis_bilibili")
# 预加载
preinit("konabot/plugins")
nonebot.run()
if __name__ == "__main__":

View File

@ -1,64 +1,127 @@
import os
import sqlite3
from typing import List, Dict, Any, Optional
import asyncio
from pathlib import Path
from typing import List, Dict, Any, Optional, Union
import aiosqlite
class DatabaseManager:
"""超级无敌神奇的数据库!"""
@classmethod
def query(cls, query: str, params: Optional[tuple] = None) -> List[Dict[str, Any]]:
"""异步数据库管理器"""
def __init__(self, db_path: Optional[Union[str, Path]] = None):
"""
初始化数据库管理器
Args:
db_path: 数据库文件路径支持str和Path类型
"""
if db_path is None:
self.db_path = os.environ.get("DATABASE_PATH", "./data/database.db")
else:
self.db_path = str(db_path) if isinstance(db_path, Path) else db_path
# 连接池
self._connection_pool = []
self._pool_size = 5
self._lock = asyncio.Lock()
async def _get_connection(self) -> aiosqlite.Connection:
"""从连接池获取连接"""
async with self._lock:
if self._connection_pool:
return self._connection_pool.pop()
# 如果连接池为空,创建新连接
conn = await aiosqlite.connect(self.db_path)
await conn.execute("PRAGMA foreign_keys = ON")
return conn
async def _return_connection(self, conn: aiosqlite.Connection) -> None:
"""将连接返回到连接池"""
async with self._lock:
if len(self._connection_pool) < self._pool_size:
self._connection_pool.append(conn)
else:
await conn.close()
async def query(
self, query: str, params: Optional[tuple] = None
) -> List[Dict[str, Any]]:
"""执行查询语句并返回结果"""
conn = sqlite3.connect(os.environ.get('DATABASE_PATH', './data/database.db'))
cursor = conn.cursor()
cursor.execute(query, params or ())
columns = [description[0] for description in cursor.description]
results = [dict(zip(columns, row)) for row in cursor.fetchall()]
cursor.close()
conn.close()
return results
@classmethod
def query_by_sql_file(cls, file_path: str, params: Optional[tuple] = None) -> List[Dict[str, Any]]:
conn = await self._get_connection()
try:
cursor = await conn.execute(query, params or ())
columns = [description[0] for description in cursor.description]
rows = await cursor.fetchall()
results = [dict(zip(columns, row)) for row in rows]
await cursor.close()
return results
finally:
await self._return_connection(conn)
async def query_by_sql_file(
self, file_path: Union[str, Path], params: Optional[tuple] = None
) -> List[Dict[str, Any]]:
"""从 SQL 文件中读取查询语句并执行"""
with open(file_path, 'r', encoding='utf-8') as f:
path = str(file_path) if isinstance(file_path, Path) else file_path
with open(path, "r", encoding="utf-8") as f:
query = f.read()
return cls.query(query, params)
return await self.query(query, params)
@classmethod
def execute(cls, command: str, params: Optional[tuple] = None) -> None:
async def execute(self, command: str, params: Optional[tuple] = None) -> None:
"""执行非查询语句"""
conn = sqlite3.connect(os.environ.get('DATABASE_PATH', './data/database.db'))
cursor = conn.cursor()
cursor.execute(command, params or ())
conn.commit()
cursor.close()
conn.close()
conn = await self._get_connection()
try:
await conn.execute(command, params or ())
await conn.commit()
finally:
await self._return_connection(conn)
@classmethod
def execute_by_sql_file(cls, file_path: str, params: Optional[tuple] = None) -> None:
async def execute_script(self, script: str) -> None:
"""执行SQL脚本"""
conn = await self._get_connection()
try:
await conn.executescript(script)
await conn.commit()
finally:
await self._return_connection(conn)
async def execute_by_sql_file(
self, file_path: Union[str, Path], params: Optional[tuple] = None
) -> None:
"""从 SQL 文件中读取非查询语句并执行"""
with open(file_path, 'r', encoding='utf-8') as f:
command = f.read()
# 按照需要执行多条语句
commands = command.split(';')
for cmd in commands:
cmd = cmd.strip()
if cmd:
cls.execute(cmd, params)
@classmethod
def execute_many(cls, command: str, seq_of_params: List[tuple]) -> None:
"""执行多条非查询语句"""
conn = sqlite3.connect(os.environ.get('DATABASE_PATH', './data/database.db'))
cursor = conn.cursor()
cursor.executemany(command, seq_of_params)
conn.commit()
cursor.close()
conn.close()
path = str(file_path) if isinstance(file_path, Path) else file_path
with open(path, "r", encoding="utf-8") as f:
script = f.read()
# 如果有参数使用execute方法而不是execute_script
if params:
await self.execute(script, params)
else:
await self.execute_script(script)
@classmethod
def execute_many_values_by_sql_file(cls, file_path: str, seq_of_params: List[tuple]) -> None:
async def execute_many(self, command: str, seq_of_params: List[tuple]) -> None:
"""执行多条非查询语句"""
conn = await self._get_connection()
try:
await conn.executemany(command, seq_of_params)
await conn.commit()
finally:
await self._return_connection(conn)
async def execute_many_values_by_sql_file(
self, file_path: Union[str, Path], seq_of_params: List[tuple]
) -> None:
"""从 SQL 文件中读取一条语句,但是被不同值同时执行"""
with open(file_path, 'r', encoding='utf-8') as f:
path = str(file_path) if isinstance(file_path, Path) else file_path
with open(path, "r", encoding="utf-8") as f:
command = f.read()
cls.execute_many(command, seq_of_params)
await self.execute_many(command, seq_of_params)
async def close_all_connections(self) -> None:
"""关闭所有连接"""
async with self._lock:
for conn in self._connection_pool:
await conn.close()
self._connection_pool.clear()

View File

@ -1,15 +0,0 @@
from pathlib import Path
from nonebot import logger
def preinit(path: str):
# 执行预初始化,递归找到位于对应路径内文件名为 __preinit__.py 的所有文件都会被执行
dir_path = Path(path)
for item in dir_path.iterdir():
if item.is_dir():
preinit(item)
elif item.is_file() and item.name == "__preinit__.py":
# 动态导入该文件以执行预初始化代码
module_path = str(item.with_suffix("")).replace("/", ".").replace("\\", ".")
__import__(module_path)
logger.info(f"Preinitialized module: {module_path}")

View File

@ -1,6 +1,7 @@
from io import BytesIO
from typing import Optional, Union
import cv2
import nonebot
from nonebot.adapters import Event as BaseEvent
from nonebot.adapters.console.event import MessageEvent as ConsoleMessageEvent
from nonebot.adapters.discord.event import MessageEvent as DiscordMessageEvent
@ -18,8 +19,11 @@ import math
ROOT_PATH = Path(__file__).resolve().parent
def get_ac(id: str) -> AirConditioner:
ac = AirConditioner.get_ac(id)
# 创建全局数据库管理器实例
db_manager = DatabaseManager()
async def get_ac(id: str) -> AirConditioner:
ac = await AirConditioner.get_ac(id)
if ac is None:
ac = AirConditioner(id)
return ac
@ -46,14 +50,26 @@ async def send_ac_image(event: type[AlconnaMatcher], ac: AirConditioner):
ac_image = await generate_ac_image(ac)
await event.send(await UniMessage().image(raw=ac_image).export())
driver = nonebot.get_driver()
@driver.on_startup
async def register_startup_hook():
"""注册启动时需要执行的函数"""
# 初始化数据库表
await db_manager.execute_by_sql_file(
Path(__file__).resolve().parent / "sql" / "create_table.sql"
)
evt = on_alconna(Alconna(
"群空调"
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
@evt.handle()
async def _(event: BaseEvent, target: DepLongTaskTarget):
async def _(target: DepLongTaskTarget):
id = target.channel_id
ac = get_ac(id)
ac = await get_ac(id)
await send_ac_image(evt, ac)
evt = on_alconna(Alconna(
@ -61,10 +77,10 @@ evt = on_alconna(Alconna(
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
@evt.handle()
async def _(event: BaseEvent, target: DepLongTaskTarget):
async def _(target: DepLongTaskTarget):
id = target.channel_id
ac = get_ac(id)
ac.update_ac(state=True)
ac = await get_ac(id)
await ac.update_ac(state=True)
await send_ac_image(evt, ac)
evt = on_alconna(Alconna(
@ -72,10 +88,10 @@ evt = on_alconna(Alconna(
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
@evt.handle()
async def _(event: BaseEvent, target: DepLongTaskTarget):
async def _(target: DepLongTaskTarget):
id = target.channel_id
ac = get_ac(id)
ac.update_ac(state=False)
ac = await get_ac(id)
await ac.update_ac(state=False)
await send_ac_image(evt, ac)
evt = on_alconna(Alconna(
@ -84,17 +100,17 @@ evt = on_alconna(Alconna(
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
@evt.handle()
async def _(event: BaseEvent, target: DepLongTaskTarget, temp: Optional[Union[int, float]] = 1):
async def _(target: DepLongTaskTarget, temp: Optional[Union[int, float]] = 1):
if temp is None:
temp = 1
if temp <= 0:
return
id = target.channel_id
ac = get_ac(id)
ac = await get_ac(id)
if not ac.on or ac.burnt == True or ac.frozen == True:
await send_ac_image(evt, ac)
return
ac.update_ac(temperature_delta=temp)
await ac.update_ac(temperature_delta=temp)
if ac.temperature > 40:
# 根据温度随机出是否爆炸40度开始呈指数增长
possibility = -math.e ** ((40-ac.temperature) / 50) + 1
@ -108,7 +124,7 @@ async def _(event: BaseEvent, target: DepLongTaskTarget, temp: Optional[Union[in
pil_frames[0].save(output, format="GIF", save_all=True, append_images=pil_frames[1:], loop=0, duration=35, disposal=2)
output.seek(0)
await evt.send(await UniMessage().image(raw=output).export())
ac.broke_ac(CrashType.BURNT)
await ac.broke_ac(CrashType.BURNT)
await evt.send("太热啦,空调炸了!")
return
await send_ac_image(evt, ac)
@ -125,16 +141,16 @@ async def _(event: BaseEvent, target: DepLongTaskTarget, temp: Optional[Union[in
if temp <= 0:
return
id = target.channel_id
ac = get_ac(id)
ac = await get_ac(id)
if not ac.on or ac.burnt == True or ac.frozen == True:
await send_ac_image(evt, ac)
return
ac.update_ac(temperature_delta=-temp)
await ac.update_ac(temperature_delta=-temp)
if ac.temperature < 0:
# 根据温度随机出是否冻结0度开始呈指数增长
possibility = -math.e ** (ac.temperature / 50) + 1
if random.random() < possibility:
ac.broke_ac(CrashType.FROZEN)
await ac.broke_ac(CrashType.FROZEN)
await send_ac_image(evt, ac)
evt = on_alconna(Alconna(
@ -144,19 +160,21 @@ evt = on_alconna(Alconna(
@evt.handle()
async def _(event: BaseEvent, target: DepLongTaskTarget):
id = target.channel_id
ac = get_ac(id)
ac.change_ac()
ac = await get_ac(id)
await ac.change_ac()
await send_ac_image(evt, ac)
def query_number_ranking(id: str) -> tuple[int, int]:
result = DatabaseManager.query_by_sql_file(
async def query_number_ranking(id: str) -> tuple[int, int]:
result = await db_manager.query_by_sql_file(
ROOT_PATH / "sql" / "query_crash_and_rank.sql",
(id,id)
)
if len(result) == 0:
return 0, 0
else:
return result[0].values()
# 将字典转换为值的元组
values = list(result[0].values())
return values[0], values[1]
evt = on_alconna(Alconna(
"空调炸炸排行榜",
@ -167,7 +185,7 @@ async def _(event: BaseEvent, target: DepLongTaskTarget):
id = target.channel_id
# ac = get_ac(id)
# number, ranking = ac.get_crashes_and_ranking()
number, ranking = query_number_ranking(id)
number, ranking = await query_number_ranking(id)
params = {
"number": number,
"ranking": ranking
@ -177,4 +195,4 @@ async def _(event: BaseEvent, target: DepLongTaskTarget):
target=".box",
params=params
)
await evt.send(await UniMessage().image(raw=image).export())
await evt.send(await UniMessage().image(raw=image).export())

View File

@ -1,9 +0,0 @@
# 预初始化,只要是导入本插件包就会执行这里的代码
from pathlib import Path
from konabot.common.database import DatabaseManager
# 初始化数据库表
DatabaseManager.execute_by_sql_file(
Path(__file__).resolve().parent / "sql" / "create_table.sql"
)

View File

@ -13,16 +13,19 @@ import json
ROOT_PATH = Path(__file__).resolve().parent
# 创建全局数据库管理器实例
db_manager = DatabaseManager()
class CrashType(Enum):
BURNT = 0
FROZEN = 1
class AirConditioner:
@classmethod
def get_ac(cls, id: str) -> 'AirConditioner':
result = DatabaseManager.query_by_sql_file(ROOT_PATH / "sql" / "query_ac.sql", (id,))
async def get_ac(cls, id: str) -> 'AirConditioner':
result = await db_manager.query_by_sql_file(ROOT_PATH / "sql" / "query_ac.sql", (id,))
if len(result) == 0:
ac = cls.create_ac(id)
ac = await cls.create_ac(id)
return ac
ac_data = result[0]
ac = AirConditioner(id)
@ -33,15 +36,15 @@ class AirConditioner:
return ac
@classmethod
def create_ac(cls, id: str) -> 'AirConditioner':
async def create_ac(cls, id: str) -> 'AirConditioner':
ac = AirConditioner(id)
DatabaseManager.execute_by_sql_file(
await db_manager.execute_by_sql_file(
ROOT_PATH / "sql" / "insert_ac.sql",
(id, ac.on, ac.temperature, ac.burnt, ac.frozen)
)
return ac
def update_ac(self, state: bool = None, temperature_delta: float = None, burnt: bool = None, frozen: bool = None) -> 'AirConditioner':
async def update_ac(self, state: bool = None, temperature_delta: float = None, burnt: bool = None, frozen: bool = None) -> 'AirConditioner':
if state is not None:
self.on = state
if temperature_delta is not None:
@ -50,18 +53,18 @@ class AirConditioner:
self.burnt = burnt
if frozen is not None:
self.frozen = frozen
DatabaseManager.execute_by_sql_file(
await db_manager.execute_by_sql_file(
ROOT_PATH / "sql" / "update_ac.sql",
(self.on, self.temperature, self.burnt, self.frozen, self.id)
)
return self
def change_ac(self) -> 'AirConditioner':
async def change_ac(self) -> 'AirConditioner':
self.on = False
self.temperature = 24
self.burnt = False
self.frozen = False
DatabaseManager.execute_by_sql_file(
await db_manager.execute_by_sql_file(
ROOT_PATH / "sql" / "update_ac.sql",
(self.on, self.temperature, self.burnt, self.frozen, self.id)
)
@ -74,17 +77,17 @@ class AirConditioner:
self.burnt = False
self.frozen = False
def broke_ac(self, crash_type: CrashType):
async def broke_ac(self, crash_type: CrashType):
'''
让空调坏掉
:param crash_type: CrashType 枚举,表示空调坏掉的类型
'''
match crash_type:
case CrashType.BURNT:
self.update_ac(burnt=True)
await self.update_ac(burnt=True)
case CrashType.FROZEN:
self.update_ac(frozen=True)
DatabaseManager.execute_by_sql_file(
await self.update_ac(frozen=True)
await db_manager.execute_by_sql_file(
ROOT_PATH / "sql" / "insert_crash.sql",
(self.id, crash_type.value)
)

View File

@ -1,7 +1,7 @@
-- 创建所有表
CREATE TABLE IF NOT EXISTS air_conditioner (
id VARCHAR(128) PRIMARY KEY,
'on' BOOLEAN NOT NULL,
"on" BOOLEAN NOT NULL,
temperature REAL NOT NULL,
burnt BOOLEAN NOT NULL,
frozen BOOLEAN NOT NULL

View File

@ -1,3 +1,3 @@
-- 插入一台新空调
INSERT INTO air_conditioner (id, 'on', temperature, burnt, frozen)
INSERT INTO air_conditioner (id, "on", temperature, burnt, frozen)
VALUES (?, ?, ?, ?, ?);

View File

@ -1,4 +1,4 @@
-- 更新空调状态
UPDATE air_conditioner
SET 'on' = ?, temperature = ?, burnt = ?, frozen = ?
SET "on" = ?, temperature = ?, burnt = ?, frozen = ?
WHERE id = ?;

View File

@ -8,6 +8,7 @@ from typing import Optional
from loguru import logger
from nonebot import on_message
import nonebot
from nonebot.adapters import Event as BaseEvent
from nonebot.adapters.discord.event import MessageEvent as DiscordMessageEvent
from nonebot_plugin_alconna import (
@ -32,6 +33,9 @@ DATA_FILE_PATH = (
DATA_DIR / "idiom_banned.json"
)
# 创建全局数据库管理器实例
db_manager = DatabaseManager()
def load_banned_ids() -> list[str]:
if not DATA_FILE_PATH.exists():
return []
@ -61,6 +65,15 @@ def remove_banned_id(group_id: str):
DATA_FILE_PATH.write_text(json.dumps(banned_ids, ensure_ascii=False, indent=4), "utf-8")
driver = nonebot.get_driver()
@driver.on_startup
async def register_startup_hook():
"""注册启动时需要执行的函数"""
await IdiomGame.init_lexicon()
class TryStartState(Enum):
STARTED = 0
ALREADY_PLAYING = 1
@ -98,7 +111,7 @@ class IdiomGameLLM:
@classmethod
async def storage_idiom(cls, idiom: str):
# 将 idiom 存入数据库
DatabaseManager.execute_by_sql_file(
await db_manager.execute_by_sql_file(
ROOT_PATH / "sql" / "insert_custom_word.sql",
(idiom,)
)
@ -130,11 +143,11 @@ class IdiomGame:
IdiomGame.INSTANCE_LIST[group_id] = self
@classmethod
def append_into_word_list(cls, word: str):
async def append_into_word_list(cls, word: str):
'''
将一个新词加入到词语列表中
'''
DatabaseManager.execute_by_sql_file(
await db_manager.execute_by_sql_file(
ROOT_PATH / "sql" / "insert_custom_word.sql",
(word,)
)
@ -149,26 +162,27 @@ class IdiomGame:
return False
@staticmethod
def random_idiom() -> str:
return DatabaseManager.query_by_sql_file(
async def random_idiom() -> str:
result = await db_manager.query_by_sql_file(
ROOT_PATH / "sql" / "random_choose_idiom.sql"
)[0]["idiom"]
)
return result[0]["idiom"]
def choose_start_idiom(self) -> str:
async def choose_start_idiom(self) -> str:
"""
随机选择一个成语作为起始成语
"""
self.last_idiom = IdiomGame.random_idiom()
self.last_idiom = await IdiomGame.random_idiom()
self.last_char = self.last_idiom[-1]
if not self.is_nextable(self.last_char):
self.choose_start_idiom()
if not await self.is_nextable(self.last_char):
await self.choose_start_idiom()
else:
self.add_history_idiom(self.last_idiom, new_chain=True)
return self.last_idiom
@classmethod
def try_start_game(cls, group_id: str, force: bool = False) -> TryStartState:
cls.init_lexicon()
async def try_start_game(cls, group_id: str, force: bool = False) -> TryStartState:
await cls.init_lexicon()
if not cls.INSTANCE_LIST.get(group_id):
cls(group_id)
instance = cls.INSTANCE_LIST[group_id]
@ -179,10 +193,10 @@ class IdiomGame:
instance.now_playing = True
return TryStartState.STARTED
def start_game(self, rounds: int = 100):
async def start_game(self, rounds: int = 100):
self.now_playing = True
self.remain_rounds = rounds
self.choose_start_idiom()
await self.choose_start_idiom()
@classmethod
def try_stop_game(cls, group_id: str) -> TryStopState:
@ -212,20 +226,20 @@ class IdiomGame:
跳过当前成语,选择下一个成语
"""
async with self.lock:
self._skip_idiom_async()
await self._skip_idiom_async()
self.add_buff_score(buff_score)
return self.last_idiom
def _skip_idiom_async(self) -> str:
self.last_idiom = IdiomGame.random_idiom()
async def _skip_idiom_async(self) -> str:
self.last_idiom = await IdiomGame.random_idiom()
self.last_char = self.last_idiom[-1]
if not self.is_nextable(self.last_char):
self._skip_idiom_async()
if not await self.is_nextable(self.last_char):
await self._skip_idiom_async()
else:
self.add_history_idiom(self.last_idiom, new_chain=True)
return self.last_idiom
async def try_verify_idiom(self, idiom: str, user_id: str) -> TryVerifyState:
async def try_verify_idiom(self, idiom: str, user_id: str) -> list[TryVerifyState]:
"""
用户发送成语
"""
@ -233,14 +247,15 @@ class IdiomGame:
state = await self._verify_idiom(idiom, user_id)
return state
def is_nextable(self, last_char: str) -> bool:
async def is_nextable(self, last_char: str) -> bool:
"""
判断是否有成语可以接
"""
return DatabaseManager.query_by_sql_file(
result = await db_manager.query_by_sql_file(
ROOT_PATH / "sql" / "is_nextable.sql",
(last_char,)
)[0]["DEED"] == 1
)
return result[0]["DEED"] == 1
def add_already_idiom(self, idiom: str):
if idiom in self.already_idioms:
@ -272,11 +287,12 @@ class IdiomGame:
state.append(TryVerifyState.WRONG_FIRST_CHAR)
return state
# 成语是否存在
result = DatabaseManager.query_by_sql_file(
result = await db_manager.query_by_sql_file(
ROOT_PATH / "sql" / "query_idiom.sql",
(idiom, idiom, idiom)
)[0]["status"]
if result == -1:
)
status_result = result[0]["status"]
if status_result == -1:
logger.info(f"用户 {user_id} 发送了未知词语 {idiom},正在使用 LLM 进行验证")
try:
if not await IdiomGameLLM.verify_idiom_with_llm(idiom):
@ -298,16 +314,16 @@ class IdiomGame:
self.last_idiom = idiom
self.last_char = idiom[-1]
self.add_score(user_id, 1 * score_k) # 先加 1 分
if result == 1:
if status_result == 1:
state.append(TryVerifyState.VERIFIED_AND_REAL)
self.add_score(user_id, 4 * score_k) # 再加 4 分
self.remain_rounds -= 1
if self.remain_rounds <= 0:
self.now_playing = False
state.append(TryVerifyState.GAME_END)
if not self.is_nextable(self.last_char):
if not await self.is_nextable(self.last_char):
# 没有成语可以接了,自动跳过
self._skip_idiom_async()
await self._skip_idiom_async()
self.add_buff_score(-100)
state.append(TryVerifyState.BUT_NO_NEXT)
return state
@ -334,9 +350,9 @@ class IdiomGame:
return self.last_char
@classmethod
def random_idiom_starting_with(cls, first_char: str) -> Optional[str]:
cls.init_lexicon()
result = DatabaseManager.query_by_sql_file(
async def random_idiom_starting_with(cls, first_char: str) -> Optional[str]:
await cls.init_lexicon()
result = await db_manager.query_by_sql_file(
ROOT_PATH / "sql" / "query_idiom_start_with.sql",
(first_char,)
)
@ -345,10 +361,10 @@ class IdiomGame:
return result[0]["idiom"]
@classmethod
def init_lexicon(cls):
async def init_lexicon(cls):
if cls.__inited:
return
DatabaseManager.execute_by_sql_file(
await db_manager.execute_by_sql_file(
ROOT_PATH / "sql" / "create_table.sql"
) # 确保数据库初始化
cls.__inited = True
@ -417,7 +433,7 @@ class IdiomGame:
ALL_IDIOMS = [idiom["word"] for idiom in ALL_IDIOMS_INFOS] + THUOCL_IDIOMS
ALL_IDIOMS = list(set(ALL_IDIOMS)) # 去重
# 批量插入数据库
DatabaseManager.execute_many_values_by_sql_file(
await db_manager.execute_many_values_by_sql_file(
ROOT_PATH / "sql" / "insert_idiom.sql",
[(idiom,) for idiom in ALL_IDIOMS]
)
@ -430,13 +446,13 @@ class IdiomGame:
+ COMMON_WORDS
)
# 插入数据库
DatabaseManager.execute_many_values_by_sql_file(
await db_manager.execute_many_values_by_sql_file(
ROOT_PATH / "sql" / "insert_word.sql",
[(word,) for word in ALL_WORDS]
)
# 自定义词语 LOCAL_LLM_WORDS 插入数据库,兼容用
DatabaseManager.execute_many_values_by_sql_file(
await db_manager.execute_many_values_by_sql_file(
ROOT_PATH / "sql" / "insert_custom_word.sql",
[(word,) for word in LOCAL_LLM_WORDS]
)
@ -483,7 +499,7 @@ async def play_game(
if rounds <= 0:
await evt.send(await UniMessage().text("干什么!你想玩负数局吗?").export())
return
state = IdiomGame.try_start_game(group_id, force)
state = await IdiomGame.try_start_game(group_id, force)
if state == TryStartState.ALREADY_PLAYING:
await evt.send(
await UniMessage()
@ -502,7 +518,7 @@ async def play_game(
.export()
)
instance = IdiomGame.INSTANCE_LIST[group_id]
instance.start_game(rounds)
await instance.start_game(rounds)
# 发布成语
await evt.send(
await UniMessage()
@ -595,7 +611,7 @@ async def _(target: DepLongTaskTarget):
instance = IdiomGame.INSTANCE_LIST.get(group_id)
if not instance or not instance.get_playing_state():
return
avaliable_idiom = IdiomGame.random_idiom_starting_with(instance.get_last_char())
avaliable_idiom = await IdiomGame.random_idiom_starting_with(instance.get_last_char())
# 发送哈哈狗图片
with open(ASSETS_PATH / "img" / "dog" / "haha_dog.jpg", "rb") as f:
img_data = f.read()

26
poetry.lock generated
View File

@ -209,6 +209,30 @@ type = "legacy"
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
reference = "mirrors"
[[package]]
name = "aiosqlite"
version = "0.21.0"
description = "asyncio bridge to the standard sqlite3 module"
optional = false
python-versions = ">=3.9"
groups = ["main"]
files = [
{file = "aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0"},
{file = "aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3"},
]
[package.dependencies]
typing_extensions = ">=4.0"
[package.extras]
dev = ["attribution (==1.7.1)", "black (==24.3.0)", "build (>=1.2)", "coverage[toml] (==7.6.10)", "flake8 (==7.0.0)", "flake8-bugbear (==24.12.12)", "flit (==3.10.1)", "mypy (==1.14.1)", "ufmt (==2.5.1)", "usort (==1.0.8.post1)"]
docs = ["sphinx (==8.1.3)", "sphinx-mdinclude (==0.6.1)"]
[package.source]
type = "legacy"
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
reference = "mirrors"
[[package]]
name = "annotated-doc"
version = "0.0.3"
@ -4528,4 +4552,4 @@ reference = "mirrors"
[metadata]
lock-version = "2.1"
python-versions = ">=3.12,<4.0"
content-hash = "478bd59d60d3b73397241c6ed552434486bd26d56cc3805ef34d1cfa1be7006e"
content-hash = "5597aa165095a11fa08e4b6e1a1f4d3396711b684ed363ae0ced2dd59a09ec5d"

View File

@ -27,6 +27,7 @@ dependencies = [
"playwright (>=1.55.0,<2.0.0)",
"openai (>=2.7.1,<3.0.0)",
"imageio (>=2.37.2,<3.0.0)",
"aiosqlite (>=0.20.0,<1.0.0)",
]
[tool.poetry]

View File

@ -22,3 +22,11 @@ logger.info(f"已经加载的插件数量 {len(plugins)}")
logger.info(f"期待加载的插件数量 {len_requires}")
assert len(plugins) == len_requires
# 测试数据库模块是否可以正确导入
try:
from konabot.common.database import DatabaseManager
logger.info("数据库模块导入成功")
except Exception as e:
logger.error(f"数据库模块导入失败: {e}")
raise

93
tests/test_database.py Normal file
View File

@ -0,0 +1,93 @@
import asyncio
import os
import tempfile
from pathlib import Path
import pytest
from konabot.common.database import DatabaseManager
@pytest.mark.asyncio
async def test_database_manager():
"""测试数据库管理器的基本功能"""
# 创建临时数据库文件
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
db_path = tmp_file.name
try:
# 初始化数据库管理器
db_manager = DatabaseManager(db_path)
# 创建测试表
create_table_sql = """
CREATE TABLE IF NOT EXISTS test_users (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
email TEXT UNIQUE
);
"""
await db_manager.execute(create_table_sql)
# 插入测试数据
insert_sql = "INSERT INTO test_users (name, email) VALUES (?, ?)"
await db_manager.execute(insert_sql, ("张三", "zhangsan@example.com"))
await db_manager.execute(insert_sql, ("李四", "lisi@example.com"))
# 查询数据
select_sql = "SELECT * FROM test_users WHERE name = ?"
results = await db_manager.query(select_sql, ("张三",))
assert len(results) == 1
assert results[0]["name"] == "张三"
assert results[0]["email"] == "zhangsan@example.com"
# 测试使用Path对象
results = await db_manager.query_by_sql_file(Path(__file__), ("李四",))
# 注意这里只是测试参数传递实际SQL文件内容不是有效的SQL
# 关闭所有连接
await db_manager.close_all_connections()
finally:
# 清理临时文件
if os.path.exists(db_path):
os.unlink(db_path)
@pytest.mark.asyncio
async def test_execute_script():
"""测试执行SQL脚本功能"""
# 创建临时数据库文件
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
db_path = tmp_file.name
try:
# 初始化数据库管理器
db_manager = DatabaseManager(db_path)
# 创建测试表的脚本
script = """
CREATE TABLE IF NOT EXISTS test_products (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
price REAL
);
INSERT INTO test_products (name, price) VALUES ('苹果', 5.0);
INSERT INTO test_products (name, price) VALUES ('香蕉', 3.0);
"""
await db_manager.execute_script(script)
# 查询数据
results = await db_manager.query("SELECT * FROM test_products ORDER BY name")
assert len(results) == 2
assert results[0]["name"] == "苹果"
assert results[1]["name"] == "香蕉"
# 关闭所有连接
await db_manager.close_all_connections()
finally:
# 清理临时文件
if os.path.exists(db_path):
os.unlink(db_path)