修复坏枪从来没有运行过的单元测试,为项目引入单元测试框架(终于。。)
This commit is contained in:
6
.sqls.yml
Normal file
6
.sqls.yml
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
lowercaseKeywords: false
|
||||||
|
connections:
|
||||||
|
- driver: sqlite
|
||||||
|
dataSourceName: "./data/database.db"
|
||||||
|
- driver: sqlite
|
||||||
|
dataSourceName: "./data/perm.sqlite3"
|
||||||
3
justfile
3
justfile
@ -1,4 +1,5 @@
|
|||||||
watch:
|
watch:
|
||||||
poetry run watchfiles bot.main . --filter scripts.watch_filter.filter
|
poetry run watchfiles bot.main . --filter scripts.watch_filter.filter
|
||||||
|
|
||||||
|
test:
|
||||||
|
poetry run pytest
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
import sqlparse
|
import sqlparse
|
||||||
@ -10,16 +11,19 @@ if TYPE_CHECKING:
|
|||||||
from . import DatabaseManager
|
from . import DatabaseManager
|
||||||
|
|
||||||
# 全局数据库管理器实例
|
# 全局数据库管理器实例
|
||||||
_global_db_manager: Optional['DatabaseManager'] = None
|
_global_db_manager: Optional["DatabaseManager"] = None
|
||||||
|
|
||||||
def get_global_db_manager() -> 'DatabaseManager':
|
|
||||||
|
def get_global_db_manager() -> "DatabaseManager":
|
||||||
"""获取全局数据库管理器实例"""
|
"""获取全局数据库管理器实例"""
|
||||||
global _global_db_manager
|
global _global_db_manager
|
||||||
if _global_db_manager is None:
|
if _global_db_manager is None:
|
||||||
from . import DatabaseManager
|
from . import DatabaseManager
|
||||||
|
|
||||||
_global_db_manager = DatabaseManager()
|
_global_db_manager = DatabaseManager()
|
||||||
return _global_db_manager
|
return _global_db_manager
|
||||||
|
|
||||||
|
|
||||||
def close_global_db_manager() -> None:
|
def close_global_db_manager() -> None:
|
||||||
"""关闭全局数据库管理器实例"""
|
"""关闭全局数据库管理器实例"""
|
||||||
global _global_db_manager
|
global _global_db_manager
|
||||||
@ -87,6 +91,12 @@ class DatabaseManager:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def get_conn(self):
|
||||||
|
conn = await self._get_connection()
|
||||||
|
yield conn
|
||||||
|
await self._return_connection(conn)
|
||||||
|
|
||||||
async def query(
|
async def query(
|
||||||
self, query: str, params: Optional[tuple] = None
|
self, query: str, params: Optional[tuple] = None
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
@ -143,22 +153,24 @@ class DatabaseManager:
|
|||||||
# 使用sqlparse库更准确地分割SQL语句
|
# 使用sqlparse库更准确地分割SQL语句
|
||||||
parsed = sqlparse.split(script)
|
parsed = sqlparse.split(script)
|
||||||
statements = []
|
statements = []
|
||||||
|
|
||||||
for statement in parsed:
|
for statement in parsed:
|
||||||
statement = statement.strip()
|
statement = statement.strip()
|
||||||
if statement:
|
if statement:
|
||||||
statements.append(statement)
|
statements.append(statement)
|
||||||
|
|
||||||
return statements
|
return statements
|
||||||
|
|
||||||
async def execute_by_sql_file(
|
async def execute_by_sql_file(
|
||||||
self, file_path: Union[str, Path], params: Optional[Union[tuple, List[tuple]]] = None
|
self,
|
||||||
|
file_path: Union[str, Path],
|
||||||
|
params: Optional[Union[tuple, List[tuple]]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""从 SQL 文件中读取非查询语句并执行"""
|
"""从 SQL 文件中读取非查询语句并执行"""
|
||||||
path = str(file_path) if isinstance(file_path, Path) else file_path
|
path = str(file_path) if isinstance(file_path, Path) else file_path
|
||||||
with open(path, "r", encoding="utf-8") as f:
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
script = f.read()
|
script = f.read()
|
||||||
|
|
||||||
# 如果有参数且是元组,使用execute执行整个脚本
|
# 如果有参数且是元组,使用execute执行整个脚本
|
||||||
if params is not None and isinstance(params, tuple):
|
if params is not None and isinstance(params, tuple):
|
||||||
await self.execute(script, params)
|
await self.execute(script, params)
|
||||||
@ -167,8 +179,10 @@ class DatabaseManager:
|
|||||||
# 使用sqlparse准确分割SQL语句
|
# 使用sqlparse准确分割SQL语句
|
||||||
statements = self._parse_sql_statements(script)
|
statements = self._parse_sql_statements(script)
|
||||||
if len(statements) != len(params):
|
if len(statements) != len(params):
|
||||||
raise ValueError(f"语句数量({len(statements)})与参数组数量({len(params)})不匹配")
|
raise ValueError(
|
||||||
|
f"语句数量({len(statements)})与参数组数量({len(params)})不匹配"
|
||||||
|
)
|
||||||
|
|
||||||
for statement, stmt_params in zip(statements, params):
|
for statement, stmt_params in zip(statements, params):
|
||||||
if statement:
|
if statement:
|
||||||
await self.execute(statement, stmt_params)
|
await self.execute(statement, stmt_params)
|
||||||
@ -215,4 +229,3 @@ class DatabaseManager:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
self._in_use.clear()
|
self._in_use.clear()
|
||||||
|
|
||||||
|
|||||||
2572
poetry.lock
generated
2572
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -34,6 +34,8 @@ dependencies = [
|
|||||||
"shapely (>=2.1.2,<3.0.0)",
|
"shapely (>=2.1.2,<3.0.0)",
|
||||||
"mcstatus (>=12.2.1,<13.0.0)",
|
"mcstatus (>=12.2.1,<13.0.0)",
|
||||||
"borax (>=4.1.3,<5.0.0)",
|
"borax (>=4.1.3,<5.0.0)",
|
||||||
|
"pytest (>=8.0.0,<9.0.0)",
|
||||||
|
"nonebug (>=0.4.3,<0.5.0)",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
@ -52,8 +54,14 @@ priority = "primary"
|
|||||||
|
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = ["rust-just (>=1.43.0,<2.0.0)", "pytest-asyncio (>=1.3.0,<2.0.0)"]
|
||||||
"rust-just (>=1.43.0,<2.0.0)",
|
|
||||||
"pytest (>=9.0.1,<10.0.0)",
|
[tool.pytest.ini_options]
|
||||||
"pytest-asyncio (>=1.3.0,<2.0.0)"
|
testpaths = "tests"
|
||||||
]
|
python_files = "test_*.py"
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
asyncio_default_fixture_loop_scope = "session"
|
||||||
|
|
||||||
|
[tool.nonebot]
|
||||||
|
# plugin_dirs = ["konabot/plugins/"]
|
||||||
|
plugin_dirs = []
|
||||||
|
|||||||
28
tests/conftest.py
Normal file
28
tests/conftest.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
# 文件内容来源:
|
||||||
|
# https://nonebot.dev/docs/best-practice/testing/
|
||||||
|
# 保证 nonebug 测试框架正常运作
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import nonebot
|
||||||
|
from pytest_asyncio import is_async_test
|
||||||
|
from nonebot.adapters.console import Adapter as ConsoleAdapter
|
||||||
|
from nonebug import NONEBOT_START_LIFESPAN
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_collection_modifyitems(items: list[pytest.Item]):
|
||||||
|
pytest_asyncio_tests = (item for item in items if is_async_test(item))
|
||||||
|
session_scope_marker = pytest.mark.asyncio(loop_scope="session")
|
||||||
|
for async_test in pytest_asyncio_tests:
|
||||||
|
async_test.add_marker(session_scope_marker, append=False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
async def after_nonebot_init(after_nonebot_init: None):
|
||||||
|
driver = nonebot.get_driver()
|
||||||
|
driver.register_adapter(ConsoleAdapter)
|
||||||
|
|
||||||
|
nonebot.load_from_toml("pyproject.toml")
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_configure(config: pytest.Config):
|
||||||
|
config.stash[NONEBOT_START_LIFESPAN] = False
|
||||||
@ -1,4 +1,3 @@
|
|||||||
import asyncio
|
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -12,13 +11,13 @@ from konabot.common.database import DatabaseManager
|
|||||||
async def test_database_manager():
|
async def test_database_manager():
|
||||||
"""测试数据库管理器的基本功能"""
|
"""测试数据库管理器的基本功能"""
|
||||||
# 创建临时数据库文件
|
# 创建临时数据库文件
|
||||||
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_file:
|
||||||
db_path = tmp_file.name
|
db_path = tmp_file.name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 初始化数据库管理器
|
# 初始化数据库管理器
|
||||||
db_manager = DatabaseManager(db_path)
|
db_manager = DatabaseManager(db_path)
|
||||||
|
|
||||||
# 创建测试表
|
# 创建测试表
|
||||||
create_table_sql = """
|
create_table_sql = """
|
||||||
CREATE TABLE IF NOT EXISTS test_users (
|
CREATE TABLE IF NOT EXISTS test_users (
|
||||||
@ -28,26 +27,27 @@ async def test_database_manager():
|
|||||||
);
|
);
|
||||||
"""
|
"""
|
||||||
await db_manager.execute(create_table_sql)
|
await db_manager.execute(create_table_sql)
|
||||||
|
|
||||||
# 插入测试数据
|
# 插入测试数据
|
||||||
insert_sql = "INSERT INTO test_users (name, email) VALUES (?, ?)"
|
insert_sql = "INSERT INTO test_users (name, email) VALUES (?, ?)"
|
||||||
await db_manager.execute(insert_sql, ("张三", "zhangsan@example.com"))
|
await db_manager.execute(insert_sql, ("张三", "zhangsan@example.com"))
|
||||||
await db_manager.execute(insert_sql, ("李四", "lisi@example.com"))
|
await db_manager.execute(insert_sql, ("李四", "lisi@example.com"))
|
||||||
|
|
||||||
# 查询数据
|
# 查询数据
|
||||||
select_sql = "SELECT * FROM test_users WHERE name = ?"
|
select_sql = "SELECT * FROM test_users WHERE name = ?"
|
||||||
results = await db_manager.query(select_sql, ("张三",))
|
results = await db_manager.query(select_sql, ("张三",))
|
||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
assert results[0]["name"] == "张三"
|
assert results[0]["name"] == "张三"
|
||||||
assert results[0]["email"] == "zhangsan@example.com"
|
assert results[0]["email"] == "zhangsan@example.com"
|
||||||
|
|
||||||
# 测试使用Path对象
|
# 测试使用Path对象
|
||||||
results = await db_manager.query_by_sql_file(Path(__file__), ("李四",))
|
# results = await db_manager.query_by_sql_file(Path(__file__), ("李四",))
|
||||||
# 注意:这里只是测试参数传递,实际SQL文件内容不是有效的SQL
|
# 注意:这里只是测试参数传递,实际SQL文件内容不是有效的SQL
|
||||||
|
## ^^^ 卧了个槽的坏枪,你让 AI 写单元测试不检查一下吗
|
||||||
|
|
||||||
# 关闭所有连接
|
# 关闭所有连接
|
||||||
await db_manager.close_all_connections()
|
await db_manager.close_all_connections()
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# 清理临时文件
|
# 清理临时文件
|
||||||
if os.path.exists(db_path):
|
if os.path.exists(db_path):
|
||||||
@ -58,13 +58,13 @@ async def test_database_manager():
|
|||||||
async def test_execute_script():
|
async def test_execute_script():
|
||||||
"""测试执行SQL脚本功能"""
|
"""测试执行SQL脚本功能"""
|
||||||
# 创建临时数据库文件
|
# 创建临时数据库文件
|
||||||
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as tmp_file:
|
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp_file:
|
||||||
db_path = tmp_file.name
|
db_path = tmp_file.name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 初始化数据库管理器
|
# 初始化数据库管理器
|
||||||
db_manager = DatabaseManager(db_path)
|
db_manager = DatabaseManager(db_path)
|
||||||
|
|
||||||
# 创建测试表的脚本
|
# 创建测试表的脚本
|
||||||
script = """
|
script = """
|
||||||
CREATE TABLE IF NOT EXISTS test_products (
|
CREATE TABLE IF NOT EXISTS test_products (
|
||||||
@ -75,19 +75,19 @@ async def test_execute_script():
|
|||||||
INSERT INTO test_products (name, price) VALUES ('苹果', 5.0);
|
INSERT INTO test_products (name, price) VALUES ('苹果', 5.0);
|
||||||
INSERT INTO test_products (name, price) VALUES ('香蕉', 3.0);
|
INSERT INTO test_products (name, price) VALUES ('香蕉', 3.0);
|
||||||
"""
|
"""
|
||||||
|
|
||||||
await db_manager.execute_script(script)
|
await db_manager.execute_script(script)
|
||||||
|
|
||||||
# 查询数据
|
# 查询数据
|
||||||
results = await db_manager.query("SELECT * FROM test_products ORDER BY name")
|
results = await db_manager.query("SELECT * FROM test_products ORDER BY name")
|
||||||
assert len(results) == 2
|
assert len(results) == 2
|
||||||
assert results[0]["name"] == "苹果"
|
assert results[0]["name"] == "苹果"
|
||||||
assert results[1]["name"] == "香蕉"
|
assert results[1]["name"] == "香蕉"
|
||||||
|
|
||||||
# 关闭所有连接
|
# 关闭所有连接
|
||||||
await db_manager.close_all_connections()
|
await db_manager.close_all_connections()
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# 清理临时文件
|
# 清理临时文件
|
||||||
if os.path.exists(db_path):
|
if os.path.exists(db_path):
|
||||||
os.unlink(db_path)
|
os.unlink(db_path)
|
||||||
|
|||||||
Reference in New Issue
Block a user