修复坏枪从来没有运行过的单元测试,为项目引入单元测试框架(终于。。)

This commit is contained in:
2026-03-07 13:16:24 +08:00
parent a1c9f9bccb
commit 88861f4264
7 changed files with 1385 additions and 1309 deletions

6
.sqls.yml Normal file
View File

@ -0,0 +1,6 @@
lowercaseKeywords: false
connections:
- driver: sqlite
dataSourceName: "./data/database.db"
- driver: sqlite
dataSourceName: "./data/perm.sqlite3"

View File

@ -1,4 +1,5 @@
watch: watch:
poetry run watchfiles bot.main . --filter scripts.watch_filter.filter poetry run watchfiles bot.main . --filter scripts.watch_filter.filter
test:
poetry run pytest

View File

@ -1,3 +1,4 @@
from contextlib import asynccontextmanager
import os import os
import asyncio import asyncio
import sqlparse import sqlparse
@ -10,16 +11,19 @@ if TYPE_CHECKING:
from . import DatabaseManager from . import DatabaseManager
# 全局数据库管理器实例 # 全局数据库管理器实例
_global_db_manager: Optional['DatabaseManager'] = None _global_db_manager: Optional["DatabaseManager"] = None
def get_global_db_manager() -> 'DatabaseManager':
def get_global_db_manager() -> "DatabaseManager":
"""获取全局数据库管理器实例""" """获取全局数据库管理器实例"""
global _global_db_manager global _global_db_manager
if _global_db_manager is None: if _global_db_manager is None:
from . import DatabaseManager from . import DatabaseManager
_global_db_manager = DatabaseManager() _global_db_manager = DatabaseManager()
return _global_db_manager return _global_db_manager
def close_global_db_manager() -> None: def close_global_db_manager() -> None:
"""关闭全局数据库管理器实例""" """关闭全局数据库管理器实例"""
global _global_db_manager global _global_db_manager
@ -87,6 +91,12 @@ class DatabaseManager:
except: except:
pass pass
@asynccontextmanager
async def get_conn(self):
conn = await self._get_connection()
yield conn
await self._return_connection(conn)
async def query( async def query(
self, query: str, params: Optional[tuple] = None self, query: str, params: Optional[tuple] = None
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
@ -143,22 +153,24 @@ class DatabaseManager:
# 使用sqlparse库更准确地分割SQL语句 # 使用sqlparse库更准确地分割SQL语句
parsed = sqlparse.split(script) parsed = sqlparse.split(script)
statements = [] statements = []
for statement in parsed: for statement in parsed:
statement = statement.strip() statement = statement.strip()
if statement: if statement:
statements.append(statement) statements.append(statement)
return statements return statements
async def execute_by_sql_file( async def execute_by_sql_file(
self, file_path: Union[str, Path], params: Optional[Union[tuple, List[tuple]]] = None self,
file_path: Union[str, Path],
params: Optional[Union[tuple, List[tuple]]] = None,
) -> None: ) -> None:
"""从 SQL 文件中读取非查询语句并执行""" """从 SQL 文件中读取非查询语句并执行"""
path = str(file_path) if isinstance(file_path, Path) else file_path path = str(file_path) if isinstance(file_path, Path) else file_path
with open(path, "r", encoding="utf-8") as f: with open(path, "r", encoding="utf-8") as f:
script = f.read() script = f.read()
# 如果有参数且是元组使用execute执行整个脚本 # 如果有参数且是元组使用execute执行整个脚本
if params is not None and isinstance(params, tuple): if params is not None and isinstance(params, tuple):
await self.execute(script, params) await self.execute(script, params)
@ -167,8 +179,10 @@ class DatabaseManager:
# 使用sqlparse准确分割SQL语句 # 使用sqlparse准确分割SQL语句
statements = self._parse_sql_statements(script) statements = self._parse_sql_statements(script)
if len(statements) != len(params): if len(statements) != len(params):
raise ValueError(f"语句数量({len(statements)})与参数组数量({len(params)})不匹配") raise ValueError(
f"语句数量({len(statements)})与参数组数量({len(params)})不匹配"
)
for statement, stmt_params in zip(statements, params): for statement, stmt_params in zip(statements, params):
if statement: if statement:
await self.execute(statement, stmt_params) await self.execute(statement, stmt_params)
@ -215,4 +229,3 @@ class DatabaseManager:
except: except:
pass pass
self._in_use.clear() self._in_use.clear()

2572
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -34,6 +34,8 @@ dependencies = [
"shapely (>=2.1.2,<3.0.0)", "shapely (>=2.1.2,<3.0.0)",
"mcstatus (>=12.2.1,<13.0.0)", "mcstatus (>=12.2.1,<13.0.0)",
"borax (>=4.1.3,<5.0.0)", "borax (>=4.1.3,<5.0.0)",
"pytest (>=8.0.0,<9.0.0)",
"nonebug (>=0.4.3,<0.5.0)",
] ]
[tool.poetry] [tool.poetry]
@ -52,8 +54,14 @@ priority = "primary"
[dependency-groups] [dependency-groups]
dev = [ dev = ["rust-just (>=1.43.0,<2.0.0)", "pytest-asyncio (>=1.3.0,<2.0.0)"]
"rust-just (>=1.43.0,<2.0.0)",
"pytest (>=9.0.1,<10.0.0)", [tool.pytest.ini_options]
"pytest-asyncio (>=1.3.0,<2.0.0)" testpaths = "tests"
] python_files = "test_*.py"
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "session"
[tool.nonebot]
# plugin_dirs = ["konabot/plugins/"]
plugin_dirs = []

28
tests/conftest.py Normal file
View File

@ -0,0 +1,28 @@
# 文件内容来源:
# https://nonebot.dev/docs/best-practice/testing/
# 保证 nonebug 测试框架正常运作
import pytest
import nonebot
from pytest_asyncio import is_async_test
from nonebot.adapters.console import Adapter as ConsoleAdapter
from nonebug import NONEBOT_START_LIFESPAN
def pytest_collection_modifyitems(items: list[pytest.Item]):
pytest_asyncio_tests = (item for item in items if is_async_test(item))
session_scope_marker = pytest.mark.asyncio(loop_scope="session")
for async_test in pytest_asyncio_tests:
async_test.add_marker(session_scope_marker, append=False)
@pytest.fixture(scope="session", autouse=True)
async def after_nonebot_init(after_nonebot_init: None):
driver = nonebot.get_driver()
driver.register_adapter(ConsoleAdapter)
nonebot.load_from_toml("pyproject.toml")
def pytest_configure(config: pytest.Config):
config.stash[NONEBOT_START_LIFESPAN] = False

View File

@ -1,4 +1,3 @@
import asyncio
import os import os
import tempfile import tempfile
from pathlib import Path from pathlib import Path
@ -12,13 +11,13 @@ from konabot.common.database import DatabaseManager
async def test_database_manager(): async def test_database_manager():
"""测试数据库管理器的基本功能""" """测试数据库管理器的基本功能"""
# 创建临时数据库文件 # 创建临时数据库文件
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file: with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_file:
db_path = tmp_file.name db_path = tmp_file.name
try: try:
# 初始化数据库管理器 # 初始化数据库管理器
db_manager = DatabaseManager(db_path) db_manager = DatabaseManager(db_path)
# 创建测试表 # 创建测试表
create_table_sql = """ create_table_sql = """
CREATE TABLE IF NOT EXISTS test_users ( CREATE TABLE IF NOT EXISTS test_users (
@ -28,26 +27,27 @@ async def test_database_manager():
); );
""" """
await db_manager.execute(create_table_sql) await db_manager.execute(create_table_sql)
# 插入测试数据 # 插入测试数据
insert_sql = "INSERT INTO test_users (name, email) VALUES (?, ?)" insert_sql = "INSERT INTO test_users (name, email) VALUES (?, ?)"
await db_manager.execute(insert_sql, ("张三", "zhangsan@example.com")) await db_manager.execute(insert_sql, ("张三", "zhangsan@example.com"))
await db_manager.execute(insert_sql, ("李四", "lisi@example.com")) await db_manager.execute(insert_sql, ("李四", "lisi@example.com"))
# 查询数据 # 查询数据
select_sql = "SELECT * FROM test_users WHERE name = ?" select_sql = "SELECT * FROM test_users WHERE name = ?"
results = await db_manager.query(select_sql, ("张三",)) results = await db_manager.query(select_sql, ("张三",))
assert len(results) == 1 assert len(results) == 1
assert results[0]["name"] == "张三" assert results[0]["name"] == "张三"
assert results[0]["email"] == "zhangsan@example.com" assert results[0]["email"] == "zhangsan@example.com"
# 测试使用Path对象 # 测试使用Path对象
results = await db_manager.query_by_sql_file(Path(__file__), ("李四",)) # results = await db_manager.query_by_sql_file(Path(__file__), ("李四",))
# 注意这里只是测试参数传递实际SQL文件内容不是有效的SQL # 注意这里只是测试参数传递实际SQL文件内容不是有效的SQL
## ^^^ 卧了个槽的坏枪,你让 AI 写单元测试不检查一下吗
# 关闭所有连接 # 关闭所有连接
await db_manager.close_all_connections() await db_manager.close_all_connections()
finally: finally:
# 清理临时文件 # 清理临时文件
if os.path.exists(db_path): if os.path.exists(db_path):
@ -58,13 +58,13 @@ async def test_database_manager():
async def test_execute_script(): async def test_execute_script():
"""测试执行SQL脚本功能""" """测试执行SQL脚本功能"""
# 创建临时数据库文件 # 创建临时数据库文件
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file: with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_file:
db_path = tmp_file.name db_path = tmp_file.name
try: try:
# 初始化数据库管理器 # 初始化数据库管理器
db_manager = DatabaseManager(db_path) db_manager = DatabaseManager(db_path)
# 创建测试表的脚本 # 创建测试表的脚本
script = """ script = """
CREATE TABLE IF NOT EXISTS test_products ( CREATE TABLE IF NOT EXISTS test_products (
@ -75,19 +75,19 @@ async def test_execute_script():
INSERT INTO test_products (name, price) VALUES ('苹果', 5.0); INSERT INTO test_products (name, price) VALUES ('苹果', 5.0);
INSERT INTO test_products (name, price) VALUES ('香蕉', 3.0); INSERT INTO test_products (name, price) VALUES ('香蕉', 3.0);
""" """
await db_manager.execute_script(script) await db_manager.execute_script(script)
# 查询数据 # 查询数据
results = await db_manager.query("SELECT * FROM test_products ORDER BY name") results = await db_manager.query("SELECT * FROM test_products ORDER BY name")
assert len(results) == 2 assert len(results) == 2
assert results[0]["name"] == "苹果" assert results[0]["name"] == "苹果"
assert results[1]["name"] == "香蕉" assert results[1]["name"] == "香蕉"
# 关闭所有连接 # 关闭所有连接
await db_manager.close_all_connections() await db_manager.close_all_connections()
finally: finally:
# 清理临时文件 # 清理临时文件
if os.path.exists(db_path): if os.path.exists(db_path):
os.unlink(db_path) os.unlink(db_path)