修复坏枪从来没有运行过的单元测试,为项目引入单元测试框架(终于。。)
This commit is contained in:
@ -1,3 +1,4 @@
|
||||
from contextlib import asynccontextmanager
|
||||
import os
|
||||
import asyncio
|
||||
import sqlparse
|
||||
@ -10,16 +11,19 @@ if TYPE_CHECKING:
|
||||
from . import DatabaseManager
|
||||
|
||||
# 全局数据库管理器实例
|
||||
_global_db_manager: Optional['DatabaseManager'] = None
|
||||
_global_db_manager: Optional["DatabaseManager"] = None
|
||||
|
||||
def get_global_db_manager() -> 'DatabaseManager':
|
||||
|
||||
def get_global_db_manager() -> "DatabaseManager":
|
||||
"""获取全局数据库管理器实例"""
|
||||
global _global_db_manager
|
||||
if _global_db_manager is None:
|
||||
from . import DatabaseManager
|
||||
|
||||
_global_db_manager = DatabaseManager()
|
||||
return _global_db_manager
|
||||
|
||||
|
||||
def close_global_db_manager() -> None:
|
||||
"""关闭全局数据库管理器实例"""
|
||||
global _global_db_manager
|
||||
@ -87,6 +91,12 @@ class DatabaseManager:
|
||||
except:
|
||||
pass
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_conn(self):
|
||||
conn = await self._get_connection()
|
||||
yield conn
|
||||
await self._return_connection(conn)
|
||||
|
||||
async def query(
|
||||
self, query: str, params: Optional[tuple] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
@ -143,22 +153,24 @@ class DatabaseManager:
|
||||
# 使用sqlparse库更准确地分割SQL语句
|
||||
parsed = sqlparse.split(script)
|
||||
statements = []
|
||||
|
||||
|
||||
for statement in parsed:
|
||||
statement = statement.strip()
|
||||
if statement:
|
||||
statements.append(statement)
|
||||
|
||||
|
||||
return statements
|
||||
|
||||
async def execute_by_sql_file(
|
||||
self, file_path: Union[str, Path], params: Optional[Union[tuple, List[tuple]]] = None
|
||||
self,
|
||||
file_path: Union[str, Path],
|
||||
params: Optional[Union[tuple, List[tuple]]] = None,
|
||||
) -> None:
|
||||
"""从 SQL 文件中读取非查询语句并执行"""
|
||||
path = str(file_path) if isinstance(file_path, Path) else file_path
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
script = f.read()
|
||||
|
||||
|
||||
# 如果有参数且是元组,使用execute执行整个脚本
|
||||
if params is not None and isinstance(params, tuple):
|
||||
await self.execute(script, params)
|
||||
@ -167,8 +179,10 @@ class DatabaseManager:
|
||||
# 使用sqlparse准确分割SQL语句
|
||||
statements = self._parse_sql_statements(script)
|
||||
if len(statements) != len(params):
|
||||
raise ValueError(f"语句数量({len(statements)})与参数组数量({len(params)})不匹配")
|
||||
|
||||
raise ValueError(
|
||||
f"语句数量({len(statements)})与参数组数量({len(params)})不匹配"
|
||||
)
|
||||
|
||||
for statement, stmt_params in zip(statements, params):
|
||||
if statement:
|
||||
await self.execute(statement, stmt_params)
|
||||
@ -215,4 +229,3 @@ class DatabaseManager:
|
||||
except:
|
||||
pass
|
||||
self._in_use.clear()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user