修复坏枪从来没有运行过的单元测试,为项目引入单元测试框架(终于。。)

This commit is contained in:
2026-03-07 13:16:24 +08:00
parent a1c9f9bccb
commit 88861f4264
7 changed files with 1385 additions and 1309 deletions

View File

@ -1,3 +1,4 @@
from contextlib import asynccontextmanager
import os
import asyncio
import sqlparse
@ -10,16 +11,19 @@ if TYPE_CHECKING:
from . import DatabaseManager
# 全局数据库管理器实例
_global_db_manager: Optional['DatabaseManager'] = None
_global_db_manager: Optional["DatabaseManager"] = None
def get_global_db_manager() -> 'DatabaseManager':
def get_global_db_manager() -> "DatabaseManager":
"""获取全局数据库管理器实例"""
global _global_db_manager
if _global_db_manager is None:
from . import DatabaseManager
_global_db_manager = DatabaseManager()
return _global_db_manager
def close_global_db_manager() -> None:
"""关闭全局数据库管理器实例"""
global _global_db_manager
@ -87,6 +91,12 @@ class DatabaseManager:
except:
pass
@asynccontextmanager
async def get_conn(self):
conn = await self._get_connection()
yield conn
await self._return_connection(conn)
async def query(
self, query: str, params: Optional[tuple] = None
) -> List[Dict[str, Any]]:
@ -143,22 +153,24 @@ class DatabaseManager:
# 使用sqlparse库更准确地分割SQL语句
parsed = sqlparse.split(script)
statements = []
for statement in parsed:
statement = statement.strip()
if statement:
statements.append(statement)
return statements
async def execute_by_sql_file(
self, file_path: Union[str, Path], params: Optional[Union[tuple, List[tuple]]] = None
self,
file_path: Union[str, Path],
params: Optional[Union[tuple, List[tuple]]] = None,
) -> None:
"""从 SQL 文件中读取非查询语句并执行"""
path = str(file_path) if isinstance(file_path, Path) else file_path
with open(path, "r", encoding="utf-8") as f:
script = f.read()
# 如果有参数且是元组使用execute执行整个脚本
if params is not None and isinstance(params, tuple):
await self.execute(script, params)
@ -167,8 +179,10 @@ class DatabaseManager:
# 使用sqlparse准确分割SQL语句
statements = self._parse_sql_statements(script)
if len(statements) != len(params):
raise ValueError(f"语句数量({len(statements)})与参数组数量({len(params)})不匹配")
raise ValueError(
f"语句数量({len(statements)})与参数组数量({len(params)})不匹配"
)
for statement, stmt_params in zip(statements, params):
if statement:
await self.execute(statement, stmt_params)
@ -215,4 +229,3 @@ class DatabaseManager:
except:
pass
self._in_use.clear()