forked from mttu-developers/konabot
我拿 AI 改坏枪代码!
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@ -3,4 +3,3 @@
|
|||||||
|
|
||||||
__pycache__
|
__pycache__
|
||||||
|
|
||||||
*.db
|
|
||||||
@ -71,6 +71,10 @@ code .
|
|||||||
|
|
||||||
详见[konabot-web 配置文档](/docs/konabot-web.md)
|
详见[konabot-web 配置文档](/docs/konabot-web.md)
|
||||||
|
|
||||||
|
#### 数据库配置
|
||||||
|
|
||||||
|
本项目使用SQLite作为数据库,默认数据库文件位于`./data/database.db`。可以通过设置`DATABASE_PATH`环境变量来指定其他位置。
|
||||||
|
|
||||||
### 运行
|
### 运行
|
||||||
|
|
||||||
使用命令行手动启动 Bot:
|
使用命令行手动启动 Bot:
|
||||||
@ -91,3 +95,7 @@ poetry run python bot.py
|
|||||||
- [事件响应器](https://nonebot.dev/docs/tutorial/matcher)
|
- [事件响应器](https://nonebot.dev/docs/tutorial/matcher)
|
||||||
- [事件处理](https://nonebot.dev/docs/tutorial/handler)
|
- [事件处理](https://nonebot.dev/docs/tutorial/handler)
|
||||||
- [Alconna 插件](https://nonebot.dev/docs/best-practice/alconna/)
|
- [Alconna 插件](https://nonebot.dev/docs/best-practice/alconna/)
|
||||||
|
|
||||||
|
## 数据库模块
|
||||||
|
|
||||||
|
本项目的数据库模块已更新为异步实现,使用连接池来提高性能,并支持现代的`pathlib.Path`参数类型。详细使用方法请参考`konabot/common/database/__init__.py`文件中的实现。
|
||||||
|
|||||||
5
bot.py
5
bot.py
@ -10,7 +10,7 @@ from nonebot.adapters.onebot.v11 import Adapter as OnebotAdapter
|
|||||||
from konabot.common.log import init_logger
|
from konabot.common.log import init_logger
|
||||||
from konabot.common.nb.exc import BotExceptionMessage
|
from konabot.common.nb.exc import BotExceptionMessage
|
||||||
from konabot.common.path import LOG_PATH
|
from konabot.common.path import LOG_PATH
|
||||||
from konabot.core.preinit import preinit
|
|
||||||
|
|
||||||
dotenv.load_dotenv()
|
dotenv.load_dotenv()
|
||||||
env = os.environ.get("ENVIRONMENT", "prod")
|
env = os.environ.get("ENVIRONMENT", "prod")
|
||||||
@ -49,9 +49,6 @@ def main():
|
|||||||
nonebot.load_plugins("konabot/plugins")
|
nonebot.load_plugins("konabot/plugins")
|
||||||
nonebot.load_plugin("nonebot_plugin_analysis_bilibili")
|
nonebot.load_plugin("nonebot_plugin_analysis_bilibili")
|
||||||
|
|
||||||
# 预加载
|
|
||||||
preinit("konabot/plugins")
|
|
||||||
|
|
||||||
nonebot.run()
|
nonebot.run()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -1,64 +1,127 @@
|
|||||||
import os
|
import os
|
||||||
import sqlite3
|
import asyncio
|
||||||
from typing import List, Dict, Any, Optional
|
from pathlib import Path
|
||||||
|
from typing import List, Dict, Any, Optional, Union
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
|
|
||||||
class DatabaseManager:
|
class DatabaseManager:
|
||||||
"""超级无敌神奇的数据库!"""
|
"""异步数据库管理器"""
|
||||||
|
|
||||||
@classmethod
|
def __init__(self, db_path: Optional[Union[str, Path]] = None):
|
||||||
def query(cls, query: str, params: Optional[tuple] = None) -> List[Dict[str, Any]]:
|
"""
|
||||||
|
初始化数据库管理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: 数据库文件路径,支持str和Path类型
|
||||||
|
"""
|
||||||
|
if db_path is None:
|
||||||
|
self.db_path = os.environ.get("DATABASE_PATH", "./data/database.db")
|
||||||
|
else:
|
||||||
|
self.db_path = str(db_path) if isinstance(db_path, Path) else db_path
|
||||||
|
|
||||||
|
# 连接池
|
||||||
|
self._connection_pool = []
|
||||||
|
self._pool_size = 5
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def _get_connection(self) -> aiosqlite.Connection:
|
||||||
|
"""从连接池获取连接"""
|
||||||
|
async with self._lock:
|
||||||
|
if self._connection_pool:
|
||||||
|
return self._connection_pool.pop()
|
||||||
|
|
||||||
|
# 如果连接池为空,创建新连接
|
||||||
|
conn = await aiosqlite.connect(self.db_path)
|
||||||
|
await conn.execute("PRAGMA foreign_keys = ON")
|
||||||
|
return conn
|
||||||
|
|
||||||
|
async def _return_connection(self, conn: aiosqlite.Connection) -> None:
|
||||||
|
"""将连接返回到连接池"""
|
||||||
|
async with self._lock:
|
||||||
|
if len(self._connection_pool) < self._pool_size:
|
||||||
|
self._connection_pool.append(conn)
|
||||||
|
else:
|
||||||
|
await conn.close()
|
||||||
|
|
||||||
|
async def query(
|
||||||
|
self, query: str, params: Optional[tuple] = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
"""执行查询语句并返回结果"""
|
"""执行查询语句并返回结果"""
|
||||||
conn = sqlite3.connect(os.environ.get('DATABASE_PATH', './data/database.db'))
|
conn = await self._get_connection()
|
||||||
cursor = conn.cursor()
|
try:
|
||||||
cursor.execute(query, params or ())
|
cursor = await conn.execute(query, params or ())
|
||||||
columns = [description[0] for description in cursor.description]
|
columns = [description[0] for description in cursor.description]
|
||||||
results = [dict(zip(columns, row)) for row in cursor.fetchall()]
|
rows = await cursor.fetchall()
|
||||||
cursor.close()
|
results = [dict(zip(columns, row)) for row in rows]
|
||||||
conn.close()
|
await cursor.close()
|
||||||
return results
|
return results
|
||||||
|
finally:
|
||||||
@classmethod
|
await self._return_connection(conn)
|
||||||
def query_by_sql_file(cls, file_path: str, params: Optional[tuple] = None) -> List[Dict[str, Any]]:
|
|
||||||
|
async def query_by_sql_file(
|
||||||
|
self, file_path: Union[str, Path], params: Optional[tuple] = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
"""从 SQL 文件中读取查询语句并执行"""
|
"""从 SQL 文件中读取查询语句并执行"""
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
path = str(file_path) if isinstance(file_path, Path) else file_path
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
query = f.read()
|
query = f.read()
|
||||||
return cls.query(query, params)
|
return await self.query(query, params)
|
||||||
|
|
||||||
@classmethod
|
async def execute(self, command: str, params: Optional[tuple] = None) -> None:
|
||||||
def execute(cls, command: str, params: Optional[tuple] = None) -> None:
|
|
||||||
"""执行非查询语句"""
|
"""执行非查询语句"""
|
||||||
conn = sqlite3.connect(os.environ.get('DATABASE_PATH', './data/database.db'))
|
conn = await self._get_connection()
|
||||||
cursor = conn.cursor()
|
try:
|
||||||
cursor.execute(command, params or ())
|
await conn.execute(command, params or ())
|
||||||
conn.commit()
|
await conn.commit()
|
||||||
cursor.close()
|
finally:
|
||||||
conn.close()
|
await self._return_connection(conn)
|
||||||
|
|
||||||
@classmethod
|
async def execute_script(self, script: str) -> None:
|
||||||
def execute_by_sql_file(cls, file_path: str, params: Optional[tuple] = None) -> None:
|
"""执行SQL脚本"""
|
||||||
|
conn = await self._get_connection()
|
||||||
|
try:
|
||||||
|
await conn.executescript(script)
|
||||||
|
await conn.commit()
|
||||||
|
finally:
|
||||||
|
await self._return_connection(conn)
|
||||||
|
|
||||||
|
async def execute_by_sql_file(
|
||||||
|
self, file_path: Union[str, Path], params: Optional[tuple] = None
|
||||||
|
) -> None:
|
||||||
"""从 SQL 文件中读取非查询语句并执行"""
|
"""从 SQL 文件中读取非查询语句并执行"""
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
path = str(file_path) if isinstance(file_path, Path) else file_path
|
||||||
command = f.read()
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
# 按照需要执行多条语句
|
script = f.read()
|
||||||
commands = command.split(';')
|
# 如果有参数,使用execute方法而不是execute_script
|
||||||
for cmd in commands:
|
if params:
|
||||||
cmd = cmd.strip()
|
await self.execute(script, params)
|
||||||
if cmd:
|
else:
|
||||||
cls.execute(cmd, params)
|
await self.execute_script(script)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def execute_many(cls, command: str, seq_of_params: List[tuple]) -> None:
|
|
||||||
"""执行多条非查询语句"""
|
|
||||||
conn = sqlite3.connect(os.environ.get('DATABASE_PATH', './data/database.db'))
|
|
||||||
cursor = conn.cursor()
|
|
||||||
cursor.executemany(command, seq_of_params)
|
|
||||||
conn.commit()
|
|
||||||
cursor.close()
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
@classmethod
|
async def execute_many(self, command: str, seq_of_params: List[tuple]) -> None:
|
||||||
def execute_many_values_by_sql_file(cls, file_path: str, seq_of_params: List[tuple]) -> None:
|
"""执行多条非查询语句"""
|
||||||
|
conn = await self._get_connection()
|
||||||
|
try:
|
||||||
|
await conn.executemany(command, seq_of_params)
|
||||||
|
await conn.commit()
|
||||||
|
finally:
|
||||||
|
await self._return_connection(conn)
|
||||||
|
|
||||||
|
async def execute_many_values_by_sql_file(
|
||||||
|
self, file_path: Union[str, Path], seq_of_params: List[tuple]
|
||||||
|
) -> None:
|
||||||
"""从 SQL 文件中读取一条语句,但是被不同值同时执行"""
|
"""从 SQL 文件中读取一条语句,但是被不同值同时执行"""
|
||||||
with open(file_path, 'r', encoding='utf-8') as f:
|
path = str(file_path) if isinstance(file_path, Path) else file_path
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
command = f.read()
|
command = f.read()
|
||||||
cls.execute_many(command, seq_of_params)
|
await self.execute_many(command, seq_of_params)
|
||||||
|
|
||||||
|
async def close_all_connections(self) -> None:
|
||||||
|
"""关闭所有连接"""
|
||||||
|
async with self._lock:
|
||||||
|
for conn in self._connection_pool:
|
||||||
|
await conn.close()
|
||||||
|
self._connection_pool.clear()
|
||||||
|
|
||||||
|
|||||||
@ -1,15 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from nonebot import logger
|
|
||||||
|
|
||||||
def preinit(path: str):
|
|
||||||
# 执行预初始化,递归找到位于对应路径内文件名为 __preinit__.py 的所有文件都会被执行
|
|
||||||
dir_path = Path(path)
|
|
||||||
for item in dir_path.iterdir():
|
|
||||||
if item.is_dir():
|
|
||||||
preinit(item)
|
|
||||||
elif item.is_file() and item.name == "__preinit__.py":
|
|
||||||
# 动态导入该文件以执行预初始化代码
|
|
||||||
module_path = str(item.with_suffix("")).replace("/", ".").replace("\\", ".")
|
|
||||||
__import__(module_path)
|
|
||||||
logger.info(f"Preinitialized module: {module_path}")
|
|
||||||
@ -1,6 +1,7 @@
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
import cv2
|
import cv2
|
||||||
|
import nonebot
|
||||||
from nonebot.adapters import Event as BaseEvent
|
from nonebot.adapters import Event as BaseEvent
|
||||||
from nonebot.adapters.console.event import MessageEvent as ConsoleMessageEvent
|
from nonebot.adapters.console.event import MessageEvent as ConsoleMessageEvent
|
||||||
from nonebot.adapters.discord.event import MessageEvent as DiscordMessageEvent
|
from nonebot.adapters.discord.event import MessageEvent as DiscordMessageEvent
|
||||||
@ -18,8 +19,11 @@ import math
|
|||||||
|
|
||||||
ROOT_PATH = Path(__file__).resolve().parent
|
ROOT_PATH = Path(__file__).resolve().parent
|
||||||
|
|
||||||
def get_ac(id: str) -> AirConditioner:
|
# 创建全局数据库管理器实例
|
||||||
ac = AirConditioner.get_ac(id)
|
db_manager = DatabaseManager()
|
||||||
|
|
||||||
|
async def get_ac(id: str) -> AirConditioner:
|
||||||
|
ac = await AirConditioner.get_ac(id)
|
||||||
if ac is None:
|
if ac is None:
|
||||||
ac = AirConditioner(id)
|
ac = AirConditioner(id)
|
||||||
return ac
|
return ac
|
||||||
@ -46,14 +50,26 @@ async def send_ac_image(event: type[AlconnaMatcher], ac: AirConditioner):
|
|||||||
ac_image = await generate_ac_image(ac)
|
ac_image = await generate_ac_image(ac)
|
||||||
await event.send(await UniMessage().image(raw=ac_image).export())
|
await event.send(await UniMessage().image(raw=ac_image).export())
|
||||||
|
|
||||||
|
|
||||||
|
driver = nonebot.get_driver()
|
||||||
|
|
||||||
|
|
||||||
|
@driver.on_startup
|
||||||
|
async def register_startup_hook():
|
||||||
|
"""注册启动时需要执行的函数"""
|
||||||
|
# 初始化数据库表
|
||||||
|
await db_manager.execute_by_sql_file(
|
||||||
|
Path(__file__).resolve().parent / "sql" / "create_table.sql"
|
||||||
|
)
|
||||||
|
|
||||||
evt = on_alconna(Alconna(
|
evt = on_alconna(Alconna(
|
||||||
"群空调"
|
"群空调"
|
||||||
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
|
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
|
||||||
|
|
||||||
@evt.handle()
|
@evt.handle()
|
||||||
async def _(event: BaseEvent, target: DepLongTaskTarget):
|
async def _(target: DepLongTaskTarget):
|
||||||
id = target.channel_id
|
id = target.channel_id
|
||||||
ac = get_ac(id)
|
ac = await get_ac(id)
|
||||||
await send_ac_image(evt, ac)
|
await send_ac_image(evt, ac)
|
||||||
|
|
||||||
evt = on_alconna(Alconna(
|
evt = on_alconna(Alconna(
|
||||||
@ -61,10 +77,10 @@ evt = on_alconna(Alconna(
|
|||||||
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
|
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
|
||||||
|
|
||||||
@evt.handle()
|
@evt.handle()
|
||||||
async def _(event: BaseEvent, target: DepLongTaskTarget):
|
async def _(target: DepLongTaskTarget):
|
||||||
id = target.channel_id
|
id = target.channel_id
|
||||||
ac = get_ac(id)
|
ac = await get_ac(id)
|
||||||
ac.update_ac(state=True)
|
await ac.update_ac(state=True)
|
||||||
await send_ac_image(evt, ac)
|
await send_ac_image(evt, ac)
|
||||||
|
|
||||||
evt = on_alconna(Alconna(
|
evt = on_alconna(Alconna(
|
||||||
@ -72,10 +88,10 @@ evt = on_alconna(Alconna(
|
|||||||
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
|
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
|
||||||
|
|
||||||
@evt.handle()
|
@evt.handle()
|
||||||
async def _(event: BaseEvent, target: DepLongTaskTarget):
|
async def _(target: DepLongTaskTarget):
|
||||||
id = target.channel_id
|
id = target.channel_id
|
||||||
ac = get_ac(id)
|
ac = await get_ac(id)
|
||||||
ac.update_ac(state=False)
|
await ac.update_ac(state=False)
|
||||||
await send_ac_image(evt, ac)
|
await send_ac_image(evt, ac)
|
||||||
|
|
||||||
evt = on_alconna(Alconna(
|
evt = on_alconna(Alconna(
|
||||||
@ -84,17 +100,17 @@ evt = on_alconna(Alconna(
|
|||||||
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
|
), use_cmd_start=True, use_cmd_sep=False, skip_for_unmatch=True)
|
||||||
|
|
||||||
@evt.handle()
|
@evt.handle()
|
||||||
async def _(event: BaseEvent, target: DepLongTaskTarget, temp: Optional[Union[int, float]] = 1):
|
async def _(target: DepLongTaskTarget, temp: Optional[Union[int, float]] = 1):
|
||||||
if temp is None:
|
if temp is None:
|
||||||
temp = 1
|
temp = 1
|
||||||
if temp <= 0:
|
if temp <= 0:
|
||||||
return
|
return
|
||||||
id = target.channel_id
|
id = target.channel_id
|
||||||
ac = get_ac(id)
|
ac = await get_ac(id)
|
||||||
if not ac.on or ac.burnt == True or ac.frozen == True:
|
if not ac.on or ac.burnt == True or ac.frozen == True:
|
||||||
await send_ac_image(evt, ac)
|
await send_ac_image(evt, ac)
|
||||||
return
|
return
|
||||||
ac.update_ac(temperature_delta=temp)
|
await ac.update_ac(temperature_delta=temp)
|
||||||
if ac.temperature > 40:
|
if ac.temperature > 40:
|
||||||
# 根据温度随机出是否爆炸,40度开始,呈指数增长
|
# 根据温度随机出是否爆炸,40度开始,呈指数增长
|
||||||
possibility = -math.e ** ((40-ac.temperature) / 50) + 1
|
possibility = -math.e ** ((40-ac.temperature) / 50) + 1
|
||||||
@ -108,7 +124,7 @@ async def _(event: BaseEvent, target: DepLongTaskTarget, temp: Optional[Union[in
|
|||||||
pil_frames[0].save(output, format="GIF", save_all=True, append_images=pil_frames[1:], loop=0, duration=35, disposal=2)
|
pil_frames[0].save(output, format="GIF", save_all=True, append_images=pil_frames[1:], loop=0, duration=35, disposal=2)
|
||||||
output.seek(0)
|
output.seek(0)
|
||||||
await evt.send(await UniMessage().image(raw=output).export())
|
await evt.send(await UniMessage().image(raw=output).export())
|
||||||
ac.broke_ac(CrashType.BURNT)
|
await ac.broke_ac(CrashType.BURNT)
|
||||||
await evt.send("太热啦,空调炸了!")
|
await evt.send("太热啦,空调炸了!")
|
||||||
return
|
return
|
||||||
await send_ac_image(evt, ac)
|
await send_ac_image(evt, ac)
|
||||||
@ -125,16 +141,16 @@ async def _(event: BaseEvent, target: DepLongTaskTarget, temp: Optional[Union[in
|
|||||||
if temp <= 0:
|
if temp <= 0:
|
||||||
return
|
return
|
||||||
id = target.channel_id
|
id = target.channel_id
|
||||||
ac = get_ac(id)
|
ac = await get_ac(id)
|
||||||
if not ac.on or ac.burnt == True or ac.frozen == True:
|
if not ac.on or ac.burnt == True or ac.frozen == True:
|
||||||
await send_ac_image(evt, ac)
|
await send_ac_image(evt, ac)
|
||||||
return
|
return
|
||||||
ac.update_ac(temperature_delta=-temp)
|
await ac.update_ac(temperature_delta=-temp)
|
||||||
if ac.temperature < 0:
|
if ac.temperature < 0:
|
||||||
# 根据温度随机出是否冻结,0度开始,呈指数增长
|
# 根据温度随机出是否冻结,0度开始,呈指数增长
|
||||||
possibility = -math.e ** (ac.temperature / 50) + 1
|
possibility = -math.e ** (ac.temperature / 50) + 1
|
||||||
if random.random() < possibility:
|
if random.random() < possibility:
|
||||||
ac.broke_ac(CrashType.FROZEN)
|
await ac.broke_ac(CrashType.FROZEN)
|
||||||
await send_ac_image(evt, ac)
|
await send_ac_image(evt, ac)
|
||||||
|
|
||||||
evt = on_alconna(Alconna(
|
evt = on_alconna(Alconna(
|
||||||
@ -144,19 +160,21 @@ evt = on_alconna(Alconna(
|
|||||||
@evt.handle()
|
@evt.handle()
|
||||||
async def _(event: BaseEvent, target: DepLongTaskTarget):
|
async def _(event: BaseEvent, target: DepLongTaskTarget):
|
||||||
id = target.channel_id
|
id = target.channel_id
|
||||||
ac = get_ac(id)
|
ac = await get_ac(id)
|
||||||
ac.change_ac()
|
await ac.change_ac()
|
||||||
await send_ac_image(evt, ac)
|
await send_ac_image(evt, ac)
|
||||||
|
|
||||||
def query_number_ranking(id: str) -> tuple[int, int]:
|
async def query_number_ranking(id: str) -> tuple[int, int]:
|
||||||
result = DatabaseManager.query_by_sql_file(
|
result = await db_manager.query_by_sql_file(
|
||||||
ROOT_PATH / "sql" / "query_crash_and_rank.sql",
|
ROOT_PATH / "sql" / "query_crash_and_rank.sql",
|
||||||
(id,id)
|
(id,id)
|
||||||
)
|
)
|
||||||
if len(result) == 0:
|
if len(result) == 0:
|
||||||
return 0, 0
|
return 0, 0
|
||||||
else:
|
else:
|
||||||
return result[0].values()
|
# 将字典转换为值的元组
|
||||||
|
values = list(result[0].values())
|
||||||
|
return values[0], values[1]
|
||||||
|
|
||||||
evt = on_alconna(Alconna(
|
evt = on_alconna(Alconna(
|
||||||
"空调炸炸排行榜",
|
"空调炸炸排行榜",
|
||||||
@ -167,7 +185,7 @@ async def _(event: BaseEvent, target: DepLongTaskTarget):
|
|||||||
id = target.channel_id
|
id = target.channel_id
|
||||||
# ac = get_ac(id)
|
# ac = get_ac(id)
|
||||||
# number, ranking = ac.get_crashes_and_ranking()
|
# number, ranking = ac.get_crashes_and_ranking()
|
||||||
number, ranking = query_number_ranking(id)
|
number, ranking = await query_number_ranking(id)
|
||||||
params = {
|
params = {
|
||||||
"number": number,
|
"number": number,
|
||||||
"ranking": ranking
|
"ranking": ranking
|
||||||
@ -177,4 +195,4 @@ async def _(event: BaseEvent, target: DepLongTaskTarget):
|
|||||||
target=".box",
|
target=".box",
|
||||||
params=params
|
params=params
|
||||||
)
|
)
|
||||||
await evt.send(await UniMessage().image(raw=image).export())
|
await evt.send(await UniMessage().image(raw=image).export())
|
||||||
|
|||||||
@ -1,9 +0,0 @@
|
|||||||
# 预初始化,只要是导入本插件包就会执行这里的代码
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from konabot.common.database import DatabaseManager
|
|
||||||
|
|
||||||
# 初始化数据库表
|
|
||||||
DatabaseManager.execute_by_sql_file(
|
|
||||||
Path(__file__).resolve().parent / "sql" / "create_table.sql"
|
|
||||||
)
|
|
||||||
@ -13,16 +13,19 @@ import json
|
|||||||
|
|
||||||
ROOT_PATH = Path(__file__).resolve().parent
|
ROOT_PATH = Path(__file__).resolve().parent
|
||||||
|
|
||||||
|
# 创建全局数据库管理器实例
|
||||||
|
db_manager = DatabaseManager()
|
||||||
|
|
||||||
class CrashType(Enum):
|
class CrashType(Enum):
|
||||||
BURNT = 0
|
BURNT = 0
|
||||||
FROZEN = 1
|
FROZEN = 1
|
||||||
|
|
||||||
class AirConditioner:
|
class AirConditioner:
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_ac(cls, id: str) -> 'AirConditioner':
|
async def get_ac(cls, id: str) -> 'AirConditioner':
|
||||||
result = DatabaseManager.query_by_sql_file(ROOT_PATH / "sql" / "query_ac.sql", (id,))
|
result = await db_manager.query_by_sql_file(ROOT_PATH / "sql" / "query_ac.sql", (id,))
|
||||||
if len(result) == 0:
|
if len(result) == 0:
|
||||||
ac = cls.create_ac(id)
|
ac = await cls.create_ac(id)
|
||||||
return ac
|
return ac
|
||||||
ac_data = result[0]
|
ac_data = result[0]
|
||||||
ac = AirConditioner(id)
|
ac = AirConditioner(id)
|
||||||
@ -33,15 +36,15 @@ class AirConditioner:
|
|||||||
return ac
|
return ac
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_ac(cls, id: str) -> 'AirConditioner':
|
async def create_ac(cls, id: str) -> 'AirConditioner':
|
||||||
ac = AirConditioner(id)
|
ac = AirConditioner(id)
|
||||||
DatabaseManager.execute_by_sql_file(
|
await db_manager.execute_by_sql_file(
|
||||||
ROOT_PATH / "sql" / "insert_ac.sql",
|
ROOT_PATH / "sql" / "insert_ac.sql",
|
||||||
(id, ac.on, ac.temperature, ac.burnt, ac.frozen)
|
(id, ac.on, ac.temperature, ac.burnt, ac.frozen)
|
||||||
)
|
)
|
||||||
return ac
|
return ac
|
||||||
|
|
||||||
def update_ac(self, state: bool = None, temperature_delta: float = None, burnt: bool = None, frozen: bool = None) -> 'AirConditioner':
|
async def update_ac(self, state: bool = None, temperature_delta: float = None, burnt: bool = None, frozen: bool = None) -> 'AirConditioner':
|
||||||
if state is not None:
|
if state is not None:
|
||||||
self.on = state
|
self.on = state
|
||||||
if temperature_delta is not None:
|
if temperature_delta is not None:
|
||||||
@ -50,18 +53,18 @@ class AirConditioner:
|
|||||||
self.burnt = burnt
|
self.burnt = burnt
|
||||||
if frozen is not None:
|
if frozen is not None:
|
||||||
self.frozen = frozen
|
self.frozen = frozen
|
||||||
DatabaseManager.execute_by_sql_file(
|
await db_manager.execute_by_sql_file(
|
||||||
ROOT_PATH / "sql" / "update_ac.sql",
|
ROOT_PATH / "sql" / "update_ac.sql",
|
||||||
(self.on, self.temperature, self.burnt, self.frozen, self.id)
|
(self.on, self.temperature, self.burnt, self.frozen, self.id)
|
||||||
)
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def change_ac(self) -> 'AirConditioner':
|
async def change_ac(self) -> 'AirConditioner':
|
||||||
self.on = False
|
self.on = False
|
||||||
self.temperature = 24
|
self.temperature = 24
|
||||||
self.burnt = False
|
self.burnt = False
|
||||||
self.frozen = False
|
self.frozen = False
|
||||||
DatabaseManager.execute_by_sql_file(
|
await db_manager.execute_by_sql_file(
|
||||||
ROOT_PATH / "sql" / "update_ac.sql",
|
ROOT_PATH / "sql" / "update_ac.sql",
|
||||||
(self.on, self.temperature, self.burnt, self.frozen, self.id)
|
(self.on, self.temperature, self.burnt, self.frozen, self.id)
|
||||||
)
|
)
|
||||||
@ -74,17 +77,17 @@ class AirConditioner:
|
|||||||
self.burnt = False
|
self.burnt = False
|
||||||
self.frozen = False
|
self.frozen = False
|
||||||
|
|
||||||
def broke_ac(self, crash_type: CrashType):
|
async def broke_ac(self, crash_type: CrashType):
|
||||||
'''
|
'''
|
||||||
让空调坏掉
|
让空调坏掉
|
||||||
:param crash_type: CrashType 枚举,表示空调坏掉的类型
|
:param crash_type: CrashType 枚举,表示空调坏掉的类型
|
||||||
'''
|
'''
|
||||||
match crash_type:
|
match crash_type:
|
||||||
case CrashType.BURNT:
|
case CrashType.BURNT:
|
||||||
self.update_ac(burnt=True)
|
await self.update_ac(burnt=True)
|
||||||
case CrashType.FROZEN:
|
case CrashType.FROZEN:
|
||||||
self.update_ac(frozen=True)
|
await self.update_ac(frozen=True)
|
||||||
DatabaseManager.execute_by_sql_file(
|
await db_manager.execute_by_sql_file(
|
||||||
ROOT_PATH / "sql" / "insert_crash.sql",
|
ROOT_PATH / "sql" / "insert_crash.sql",
|
||||||
(self.id, crash_type.value)
|
(self.id, crash_type.value)
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
-- 创建所有表
|
-- 创建所有表
|
||||||
CREATE TABLE IF NOT EXISTS air_conditioner (
|
CREATE TABLE IF NOT EXISTS air_conditioner (
|
||||||
id VARCHAR(128) PRIMARY KEY,
|
id VARCHAR(128) PRIMARY KEY,
|
||||||
'on' BOOLEAN NOT NULL,
|
"on" BOOLEAN NOT NULL,
|
||||||
temperature REAL NOT NULL,
|
temperature REAL NOT NULL,
|
||||||
burnt BOOLEAN NOT NULL,
|
burnt BOOLEAN NOT NULL,
|
||||||
frozen BOOLEAN NOT NULL
|
frozen BOOLEAN NOT NULL
|
||||||
|
|||||||
@ -1,3 +1,3 @@
|
|||||||
-- 插入一台新空调
|
-- 插入一台新空调
|
||||||
INSERT INTO air_conditioner (id, 'on', temperature, burnt, frozen)
|
INSERT INTO air_conditioner (id, "on", temperature, burnt, frozen)
|
||||||
VALUES (?, ?, ?, ?, ?);
|
VALUES (?, ?, ?, ?, ?);
|
||||||
@ -1,4 +1,4 @@
|
|||||||
-- 更新空调状态
|
-- 更新空调状态
|
||||||
UPDATE air_conditioner
|
UPDATE air_conditioner
|
||||||
SET 'on' = ?, temperature = ?, burnt = ?, frozen = ?
|
SET "on" = ?, temperature = ?, burnt = ?, frozen = ?
|
||||||
WHERE id = ?;
|
WHERE id = ?;
|
||||||
@ -8,6 +8,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from nonebot import on_message
|
from nonebot import on_message
|
||||||
|
import nonebot
|
||||||
from nonebot.adapters import Event as BaseEvent
|
from nonebot.adapters import Event as BaseEvent
|
||||||
from nonebot.adapters.discord.event import MessageEvent as DiscordMessageEvent
|
from nonebot.adapters.discord.event import MessageEvent as DiscordMessageEvent
|
||||||
from nonebot_plugin_alconna import (
|
from nonebot_plugin_alconna import (
|
||||||
@ -32,6 +33,9 @@ DATA_FILE_PATH = (
|
|||||||
DATA_DIR / "idiom_banned.json"
|
DATA_DIR / "idiom_banned.json"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 创建全局数据库管理器实例
|
||||||
|
db_manager = DatabaseManager()
|
||||||
|
|
||||||
def load_banned_ids() -> list[str]:
|
def load_banned_ids() -> list[str]:
|
||||||
if not DATA_FILE_PATH.exists():
|
if not DATA_FILE_PATH.exists():
|
||||||
return []
|
return []
|
||||||
@ -61,6 +65,15 @@ def remove_banned_id(group_id: str):
|
|||||||
DATA_FILE_PATH.write_text(json.dumps(banned_ids, ensure_ascii=False, indent=4), "utf-8")
|
DATA_FILE_PATH.write_text(json.dumps(banned_ids, ensure_ascii=False, indent=4), "utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
driver = nonebot.get_driver()
|
||||||
|
|
||||||
|
|
||||||
|
@driver.on_startup
|
||||||
|
async def register_startup_hook():
|
||||||
|
"""注册启动时需要执行的函数"""
|
||||||
|
await IdiomGame.init_lexicon()
|
||||||
|
|
||||||
|
|
||||||
class TryStartState(Enum):
|
class TryStartState(Enum):
|
||||||
STARTED = 0
|
STARTED = 0
|
||||||
ALREADY_PLAYING = 1
|
ALREADY_PLAYING = 1
|
||||||
@ -98,7 +111,7 @@ class IdiomGameLLM:
|
|||||||
@classmethod
|
@classmethod
|
||||||
async def storage_idiom(cls, idiom: str):
|
async def storage_idiom(cls, idiom: str):
|
||||||
# 将 idiom 存入数据库
|
# 将 idiom 存入数据库
|
||||||
DatabaseManager.execute_by_sql_file(
|
await db_manager.execute_by_sql_file(
|
||||||
ROOT_PATH / "sql" / "insert_custom_word.sql",
|
ROOT_PATH / "sql" / "insert_custom_word.sql",
|
||||||
(idiom,)
|
(idiom,)
|
||||||
)
|
)
|
||||||
@ -130,11 +143,11 @@ class IdiomGame:
|
|||||||
IdiomGame.INSTANCE_LIST[group_id] = self
|
IdiomGame.INSTANCE_LIST[group_id] = self
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def append_into_word_list(cls, word: str):
|
async def append_into_word_list(cls, word: str):
|
||||||
'''
|
'''
|
||||||
将一个新词加入到词语列表中
|
将一个新词加入到词语列表中
|
||||||
'''
|
'''
|
||||||
DatabaseManager.execute_by_sql_file(
|
await db_manager.execute_by_sql_file(
|
||||||
ROOT_PATH / "sql" / "insert_custom_word.sql",
|
ROOT_PATH / "sql" / "insert_custom_word.sql",
|
||||||
(word,)
|
(word,)
|
||||||
)
|
)
|
||||||
@ -149,26 +162,27 @@ class IdiomGame:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def random_idiom() -> str:
|
async def random_idiom() -> str:
|
||||||
return DatabaseManager.query_by_sql_file(
|
result = await db_manager.query_by_sql_file(
|
||||||
ROOT_PATH / "sql" / "random_choose_idiom.sql"
|
ROOT_PATH / "sql" / "random_choose_idiom.sql"
|
||||||
)[0]["idiom"]
|
)
|
||||||
|
return result[0]["idiom"]
|
||||||
|
|
||||||
def choose_start_idiom(self) -> str:
|
async def choose_start_idiom(self) -> str:
|
||||||
"""
|
"""
|
||||||
随机选择一个成语作为起始成语
|
随机选择一个成语作为起始成语
|
||||||
"""
|
"""
|
||||||
self.last_idiom = IdiomGame.random_idiom()
|
self.last_idiom = await IdiomGame.random_idiom()
|
||||||
self.last_char = self.last_idiom[-1]
|
self.last_char = self.last_idiom[-1]
|
||||||
if not self.is_nextable(self.last_char):
|
if not await self.is_nextable(self.last_char):
|
||||||
self.choose_start_idiom()
|
await self.choose_start_idiom()
|
||||||
else:
|
else:
|
||||||
self.add_history_idiom(self.last_idiom, new_chain=True)
|
self.add_history_idiom(self.last_idiom, new_chain=True)
|
||||||
return self.last_idiom
|
return self.last_idiom
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def try_start_game(cls, group_id: str, force: bool = False) -> TryStartState:
|
async def try_start_game(cls, group_id: str, force: bool = False) -> TryStartState:
|
||||||
cls.init_lexicon()
|
await cls.init_lexicon()
|
||||||
if not cls.INSTANCE_LIST.get(group_id):
|
if not cls.INSTANCE_LIST.get(group_id):
|
||||||
cls(group_id)
|
cls(group_id)
|
||||||
instance = cls.INSTANCE_LIST[group_id]
|
instance = cls.INSTANCE_LIST[group_id]
|
||||||
@ -179,10 +193,10 @@ class IdiomGame:
|
|||||||
instance.now_playing = True
|
instance.now_playing = True
|
||||||
return TryStartState.STARTED
|
return TryStartState.STARTED
|
||||||
|
|
||||||
def start_game(self, rounds: int = 100):
|
async def start_game(self, rounds: int = 100):
|
||||||
self.now_playing = True
|
self.now_playing = True
|
||||||
self.remain_rounds = rounds
|
self.remain_rounds = rounds
|
||||||
self.choose_start_idiom()
|
await self.choose_start_idiom()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def try_stop_game(cls, group_id: str) -> TryStopState:
|
def try_stop_game(cls, group_id: str) -> TryStopState:
|
||||||
@ -212,20 +226,20 @@ class IdiomGame:
|
|||||||
跳过当前成语,选择下一个成语
|
跳过当前成语,选择下一个成语
|
||||||
"""
|
"""
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
self._skip_idiom_async()
|
await self._skip_idiom_async()
|
||||||
self.add_buff_score(buff_score)
|
self.add_buff_score(buff_score)
|
||||||
return self.last_idiom
|
return self.last_idiom
|
||||||
|
|
||||||
def _skip_idiom_async(self) -> str:
|
async def _skip_idiom_async(self) -> str:
|
||||||
self.last_idiom = IdiomGame.random_idiom()
|
self.last_idiom = await IdiomGame.random_idiom()
|
||||||
self.last_char = self.last_idiom[-1]
|
self.last_char = self.last_idiom[-1]
|
||||||
if not self.is_nextable(self.last_char):
|
if not await self.is_nextable(self.last_char):
|
||||||
self._skip_idiom_async()
|
await self._skip_idiom_async()
|
||||||
else:
|
else:
|
||||||
self.add_history_idiom(self.last_idiom, new_chain=True)
|
self.add_history_idiom(self.last_idiom, new_chain=True)
|
||||||
return self.last_idiom
|
return self.last_idiom
|
||||||
|
|
||||||
async def try_verify_idiom(self, idiom: str, user_id: str) -> TryVerifyState:
|
async def try_verify_idiom(self, idiom: str, user_id: str) -> list[TryVerifyState]:
|
||||||
"""
|
"""
|
||||||
用户发送成语
|
用户发送成语
|
||||||
"""
|
"""
|
||||||
@ -233,14 +247,15 @@ class IdiomGame:
|
|||||||
state = await self._verify_idiom(idiom, user_id)
|
state = await self._verify_idiom(idiom, user_id)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def is_nextable(self, last_char: str) -> bool:
|
async def is_nextable(self, last_char: str) -> bool:
|
||||||
"""
|
"""
|
||||||
判断是否有成语可以接
|
判断是否有成语可以接
|
||||||
"""
|
"""
|
||||||
return DatabaseManager.query_by_sql_file(
|
result = await db_manager.query_by_sql_file(
|
||||||
ROOT_PATH / "sql" / "is_nextable.sql",
|
ROOT_PATH / "sql" / "is_nextable.sql",
|
||||||
(last_char,)
|
(last_char,)
|
||||||
)[0]["DEED"] == 1
|
)
|
||||||
|
return result[0]["DEED"] == 1
|
||||||
|
|
||||||
def add_already_idiom(self, idiom: str):
|
def add_already_idiom(self, idiom: str):
|
||||||
if idiom in self.already_idioms:
|
if idiom in self.already_idioms:
|
||||||
@ -272,11 +287,12 @@ class IdiomGame:
|
|||||||
state.append(TryVerifyState.WRONG_FIRST_CHAR)
|
state.append(TryVerifyState.WRONG_FIRST_CHAR)
|
||||||
return state
|
return state
|
||||||
# 成语是否存在
|
# 成语是否存在
|
||||||
result = DatabaseManager.query_by_sql_file(
|
result = await db_manager.query_by_sql_file(
|
||||||
ROOT_PATH / "sql" / "query_idiom.sql",
|
ROOT_PATH / "sql" / "query_idiom.sql",
|
||||||
(idiom, idiom, idiom)
|
(idiom, idiom, idiom)
|
||||||
)[0]["status"]
|
)
|
||||||
if result == -1:
|
status_result = result[0]["status"]
|
||||||
|
if status_result == -1:
|
||||||
logger.info(f"用户 {user_id} 发送了未知词语 {idiom},正在使用 LLM 进行验证")
|
logger.info(f"用户 {user_id} 发送了未知词语 {idiom},正在使用 LLM 进行验证")
|
||||||
try:
|
try:
|
||||||
if not await IdiomGameLLM.verify_idiom_with_llm(idiom):
|
if not await IdiomGameLLM.verify_idiom_with_llm(idiom):
|
||||||
@ -298,16 +314,16 @@ class IdiomGame:
|
|||||||
self.last_idiom = idiom
|
self.last_idiom = idiom
|
||||||
self.last_char = idiom[-1]
|
self.last_char = idiom[-1]
|
||||||
self.add_score(user_id, 1 * score_k) # 先加 1 分
|
self.add_score(user_id, 1 * score_k) # 先加 1 分
|
||||||
if result == 1:
|
if status_result == 1:
|
||||||
state.append(TryVerifyState.VERIFIED_AND_REAL)
|
state.append(TryVerifyState.VERIFIED_AND_REAL)
|
||||||
self.add_score(user_id, 4 * score_k) # 再加 4 分
|
self.add_score(user_id, 4 * score_k) # 再加 4 分
|
||||||
self.remain_rounds -= 1
|
self.remain_rounds -= 1
|
||||||
if self.remain_rounds <= 0:
|
if self.remain_rounds <= 0:
|
||||||
self.now_playing = False
|
self.now_playing = False
|
||||||
state.append(TryVerifyState.GAME_END)
|
state.append(TryVerifyState.GAME_END)
|
||||||
if not self.is_nextable(self.last_char):
|
if not await self.is_nextable(self.last_char):
|
||||||
# 没有成语可以接了,自动跳过
|
# 没有成语可以接了,自动跳过
|
||||||
self._skip_idiom_async()
|
await self._skip_idiom_async()
|
||||||
self.add_buff_score(-100)
|
self.add_buff_score(-100)
|
||||||
state.append(TryVerifyState.BUT_NO_NEXT)
|
state.append(TryVerifyState.BUT_NO_NEXT)
|
||||||
return state
|
return state
|
||||||
@ -334,9 +350,9 @@ class IdiomGame:
|
|||||||
return self.last_char
|
return self.last_char
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def random_idiom_starting_with(cls, first_char: str) -> Optional[str]:
|
async def random_idiom_starting_with(cls, first_char: str) -> Optional[str]:
|
||||||
cls.init_lexicon()
|
await cls.init_lexicon()
|
||||||
result = DatabaseManager.query_by_sql_file(
|
result = await db_manager.query_by_sql_file(
|
||||||
ROOT_PATH / "sql" / "query_idiom_start_with.sql",
|
ROOT_PATH / "sql" / "query_idiom_start_with.sql",
|
||||||
(first_char,)
|
(first_char,)
|
||||||
)
|
)
|
||||||
@ -345,10 +361,10 @@ class IdiomGame:
|
|||||||
return result[0]["idiom"]
|
return result[0]["idiom"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_lexicon(cls):
|
async def init_lexicon(cls):
|
||||||
if cls.__inited:
|
if cls.__inited:
|
||||||
return
|
return
|
||||||
DatabaseManager.execute_by_sql_file(
|
await db_manager.execute_by_sql_file(
|
||||||
ROOT_PATH / "sql" / "create_table.sql"
|
ROOT_PATH / "sql" / "create_table.sql"
|
||||||
) # 确保数据库初始化
|
) # 确保数据库初始化
|
||||||
cls.__inited = True
|
cls.__inited = True
|
||||||
@ -417,7 +433,7 @@ class IdiomGame:
|
|||||||
ALL_IDIOMS = [idiom["word"] for idiom in ALL_IDIOMS_INFOS] + THUOCL_IDIOMS
|
ALL_IDIOMS = [idiom["word"] for idiom in ALL_IDIOMS_INFOS] + THUOCL_IDIOMS
|
||||||
ALL_IDIOMS = list(set(ALL_IDIOMS)) # 去重
|
ALL_IDIOMS = list(set(ALL_IDIOMS)) # 去重
|
||||||
# 批量插入数据库
|
# 批量插入数据库
|
||||||
DatabaseManager.execute_many_values_by_sql_file(
|
await db_manager.execute_many_values_by_sql_file(
|
||||||
ROOT_PATH / "sql" / "insert_idiom.sql",
|
ROOT_PATH / "sql" / "insert_idiom.sql",
|
||||||
[(idiom,) for idiom in ALL_IDIOMS]
|
[(idiom,) for idiom in ALL_IDIOMS]
|
||||||
)
|
)
|
||||||
@ -430,13 +446,13 @@ class IdiomGame:
|
|||||||
+ COMMON_WORDS
|
+ COMMON_WORDS
|
||||||
)
|
)
|
||||||
# 插入数据库
|
# 插入数据库
|
||||||
DatabaseManager.execute_many_values_by_sql_file(
|
await db_manager.execute_many_values_by_sql_file(
|
||||||
ROOT_PATH / "sql" / "insert_word.sql",
|
ROOT_PATH / "sql" / "insert_word.sql",
|
||||||
[(word,) for word in ALL_WORDS]
|
[(word,) for word in ALL_WORDS]
|
||||||
)
|
)
|
||||||
|
|
||||||
# 自定义词语 LOCAL_LLM_WORDS 插入数据库,兼容用
|
# 自定义词语 LOCAL_LLM_WORDS 插入数据库,兼容用
|
||||||
DatabaseManager.execute_many_values_by_sql_file(
|
await db_manager.execute_many_values_by_sql_file(
|
||||||
ROOT_PATH / "sql" / "insert_custom_word.sql",
|
ROOT_PATH / "sql" / "insert_custom_word.sql",
|
||||||
[(word,) for word in LOCAL_LLM_WORDS]
|
[(word,) for word in LOCAL_LLM_WORDS]
|
||||||
)
|
)
|
||||||
@ -483,7 +499,7 @@ async def play_game(
|
|||||||
if rounds <= 0:
|
if rounds <= 0:
|
||||||
await evt.send(await UniMessage().text("干什么!你想玩负数局吗?").export())
|
await evt.send(await UniMessage().text("干什么!你想玩负数局吗?").export())
|
||||||
return
|
return
|
||||||
state = IdiomGame.try_start_game(group_id, force)
|
state = await IdiomGame.try_start_game(group_id, force)
|
||||||
if state == TryStartState.ALREADY_PLAYING:
|
if state == TryStartState.ALREADY_PLAYING:
|
||||||
await evt.send(
|
await evt.send(
|
||||||
await UniMessage()
|
await UniMessage()
|
||||||
@ -502,7 +518,7 @@ async def play_game(
|
|||||||
.export()
|
.export()
|
||||||
)
|
)
|
||||||
instance = IdiomGame.INSTANCE_LIST[group_id]
|
instance = IdiomGame.INSTANCE_LIST[group_id]
|
||||||
instance.start_game(rounds)
|
await instance.start_game(rounds)
|
||||||
# 发布成语
|
# 发布成语
|
||||||
await evt.send(
|
await evt.send(
|
||||||
await UniMessage()
|
await UniMessage()
|
||||||
@ -595,7 +611,7 @@ async def _(target: DepLongTaskTarget):
|
|||||||
instance = IdiomGame.INSTANCE_LIST.get(group_id)
|
instance = IdiomGame.INSTANCE_LIST.get(group_id)
|
||||||
if not instance or not instance.get_playing_state():
|
if not instance or not instance.get_playing_state():
|
||||||
return
|
return
|
||||||
avaliable_idiom = IdiomGame.random_idiom_starting_with(instance.get_last_char())
|
avaliable_idiom = await IdiomGame.random_idiom_starting_with(instance.get_last_char())
|
||||||
# 发送哈哈狗图片
|
# 发送哈哈狗图片
|
||||||
with open(ASSETS_PATH / "img" / "dog" / "haha_dog.jpg", "rb") as f:
|
with open(ASSETS_PATH / "img" / "dog" / "haha_dog.jpg", "rb") as f:
|
||||||
img_data = f.read()
|
img_data = f.read()
|
||||||
|
|||||||
26
poetry.lock
generated
26
poetry.lock
generated
@ -209,6 +209,30 @@ type = "legacy"
|
|||||||
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
|
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
|
||||||
reference = "mirrors"
|
reference = "mirrors"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "aiosqlite"
|
||||||
|
version = "0.21.0"
|
||||||
|
description = "asyncio bridge to the standard sqlite3 module"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.9"
|
||||||
|
groups = ["main"]
|
||||||
|
files = [
|
||||||
|
{file = "aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0"},
|
||||||
|
{file = "aiosqlite-0.21.0.tar.gz", hash = "sha256:131bb8056daa3bc875608c631c678cda73922a2d4ba8aec373b19f18c17e7aa3"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
typing_extensions = ">=4.0"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
dev = ["attribution (==1.7.1)", "black (==24.3.0)", "build (>=1.2)", "coverage[toml] (==7.6.10)", "flake8 (==7.0.0)", "flake8-bugbear (==24.12.12)", "flit (==3.10.1)", "mypy (==1.14.1)", "ufmt (==2.5.1)", "usort (==1.0.8.post1)"]
|
||||||
|
docs = ["sphinx (==8.1.3)", "sphinx-mdinclude (==0.6.1)"]
|
||||||
|
|
||||||
|
[package.source]
|
||||||
|
type = "legacy"
|
||||||
|
url = "https://pypi.tuna.tsinghua.edu.cn/simple"
|
||||||
|
reference = "mirrors"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "annotated-doc"
|
name = "annotated-doc"
|
||||||
version = "0.0.3"
|
version = "0.0.3"
|
||||||
@ -4528,4 +4552,4 @@ reference = "mirrors"
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.1"
|
lock-version = "2.1"
|
||||||
python-versions = ">=3.12,<4.0"
|
python-versions = ">=3.12,<4.0"
|
||||||
content-hash = "478bd59d60d3b73397241c6ed552434486bd26d56cc3805ef34d1cfa1be7006e"
|
content-hash = "5597aa165095a11fa08e4b6e1a1f4d3396711b684ed363ae0ced2dd59a09ec5d"
|
||||||
|
|||||||
@ -27,6 +27,7 @@ dependencies = [
|
|||||||
"playwright (>=1.55.0,<2.0.0)",
|
"playwright (>=1.55.0,<2.0.0)",
|
||||||
"openai (>=2.7.1,<3.0.0)",
|
"openai (>=2.7.1,<3.0.0)",
|
||||||
"imageio (>=2.37.2,<3.0.0)",
|
"imageio (>=2.37.2,<3.0.0)",
|
||||||
|
"aiosqlite (>=0.20.0,<1.0.0)",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
|
|||||||
@ -22,3 +22,11 @@ logger.info(f"已经加载的插件数量 {len(plugins)}")
|
|||||||
logger.info(f"期待加载的插件数量 {len_requires}")
|
logger.info(f"期待加载的插件数量 {len_requires}")
|
||||||
|
|
||||||
assert len(plugins) == len_requires
|
assert len(plugins) == len_requires
|
||||||
|
|
||||||
|
# 测试数据库模块是否可以正确导入
|
||||||
|
try:
|
||||||
|
from konabot.common.database import DatabaseManager
|
||||||
|
logger.info("数据库模块导入成功")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"数据库模块导入失败: {e}")
|
||||||
|
raise
|
||||||
|
|||||||
93
tests/test_database.py
Normal file
93
tests/test_database.py
Normal file
@ -0,0 +1,93 @@
|
|||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from konabot.common.database import DatabaseManager
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_database_manager():
|
||||||
|
"""测试数据库管理器的基本功能"""
|
||||||
|
# 创建临时数据库文件
|
||||||
|
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
|
||||||
|
db_path = tmp_file.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 初始化数据库管理器
|
||||||
|
db_manager = DatabaseManager(db_path)
|
||||||
|
|
||||||
|
# 创建测试表
|
||||||
|
create_table_sql = """
|
||||||
|
CREATE TABLE IF NOT EXISTS test_users (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
email TEXT UNIQUE
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
await db_manager.execute(create_table_sql)
|
||||||
|
|
||||||
|
# 插入测试数据
|
||||||
|
insert_sql = "INSERT INTO test_users (name, email) VALUES (?, ?)"
|
||||||
|
await db_manager.execute(insert_sql, ("张三", "zhangsan@example.com"))
|
||||||
|
await db_manager.execute(insert_sql, ("李四", "lisi@example.com"))
|
||||||
|
|
||||||
|
# 查询数据
|
||||||
|
select_sql = "SELECT * FROM test_users WHERE name = ?"
|
||||||
|
results = await db_manager.query(select_sql, ("张三",))
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0]["name"] == "张三"
|
||||||
|
assert results[0]["email"] == "zhangsan@example.com"
|
||||||
|
|
||||||
|
# 测试使用Path对象
|
||||||
|
results = await db_manager.query_by_sql_file(Path(__file__), ("李四",))
|
||||||
|
# 注意:这里只是测试参数传递,实际SQL文件内容不是有效的SQL
|
||||||
|
|
||||||
|
# 关闭所有连接
|
||||||
|
await db_manager.close_all_connections()
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# 清理临时文件
|
||||||
|
if os.path.exists(db_path):
|
||||||
|
os.unlink(db_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_script():
|
||||||
|
"""测试执行SQL脚本功能"""
|
||||||
|
# 创建临时数据库文件
|
||||||
|
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
|
||||||
|
db_path = tmp_file.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 初始化数据库管理器
|
||||||
|
db_manager = DatabaseManager(db_path)
|
||||||
|
|
||||||
|
# 创建测试表的脚本
|
||||||
|
script = """
|
||||||
|
CREATE TABLE IF NOT EXISTS test_products (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
price REAL
|
||||||
|
);
|
||||||
|
INSERT INTO test_products (name, price) VALUES ('苹果', 5.0);
|
||||||
|
INSERT INTO test_products (name, price) VALUES ('香蕉', 3.0);
|
||||||
|
"""
|
||||||
|
|
||||||
|
await db_manager.execute_script(script)
|
||||||
|
|
||||||
|
# 查询数据
|
||||||
|
results = await db_manager.query("SELECT * FROM test_products ORDER BY name")
|
||||||
|
assert len(results) == 2
|
||||||
|
assert results[0]["name"] == "苹果"
|
||||||
|
assert results[1]["name"] == "香蕉"
|
||||||
|
|
||||||
|
# 关闭所有连接
|
||||||
|
await db_manager.close_all_connections()
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# 清理临时文件
|
||||||
|
if os.path.exists(db_path):
|
||||||
|
os.unlink(db_path)
|
||||||
Reference in New Issue
Block a user