import asyncio
import os
from contextlib import asynccontextmanager
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import aiosqlite
import sqlparse

if TYPE_CHECKING:
    from . import DatabaseManager

# Process-wide manager instance, created lazily by get_global_db_manager().
_global_db_manager: Optional["DatabaseManager"] = None


def get_global_db_manager() -> "DatabaseManager":
    """Return the process-wide database manager, creating it on first use.

    The class is imported from the enclosing package (``from . import
    DatabaseManager``) and constructed with its default arguments.
    """
    global _global_db_manager
    if _global_db_manager is None:
        from . import DatabaseManager

        _global_db_manager = DatabaseManager()
    return _global_db_manager


def close_global_db_manager() -> None:
    """Forget the process-wide database manager instance.

    NOTE: this only drops the reference; it does NOT close the manager's
    connections. From an async context, await
    ``get_global_db_manager().close_all_connections()`` first if the
    connections must actually be released.
    """
    global _global_db_manager
    _global_db_manager = None


class DatabaseManager:
    """Async SQLite access layer backed by a small pool of aiosqlite connections.

    ``pool_size`` caps how many *idle* connections are kept for reuse; more
    connections may be open at once while they are checked out.
    """

    def __init__(self, db_path: Optional[Union[str, Path]] = None, pool_size: int = 5):
        """Configure the manager; no connection is opened here.

        Args:
            db_path: Path of the SQLite database file (``str`` or ``Path``).
                When omitted, the ``DATABASE_PATH`` environment variable is
                used, falling back to ``./data/database.db``.
            pool_size: Maximum number of idle connections kept in the pool.
        """
        if db_path is None:
            self.db_path = os.environ.get("DATABASE_PATH", "./data/database.db")
        else:
            self.db_path = str(db_path) if isinstance(db_path, Path) else db_path

        # Idle connections available for reuse (popped from the end, LIFO).
        self._connection_pool: List[aiosqlite.Connection] = []
        self._pool_size = pool_size
        # Guards both _connection_pool and _in_use.
        self._lock = asyncio.Lock()
        # Connections currently checked out to callers.
        self._in_use: set = set()

    # ------------------------------------------------------------------ #
    # Connection pool internals
    # ------------------------------------------------------------------ #

    @staticmethod
    async def _close_quietly(conn: aiosqlite.Connection) -> None:
        """Best-effort close: ordinary errors are ignored, cancellation is not."""
        try:
            await conn.close()
        except Exception:
            pass

    @staticmethod
    async def _rollback_quietly(conn: aiosqlite.Connection) -> None:
        """Best-effort rollback so a rollback failure cannot mask the primary error."""
        try:
            await conn.rollback()
        except Exception:
            pass

    async def _get_connection(self) -> aiosqlite.Connection:
        """Check out a live connection, reusing an idle pooled one when possible.

        Pooled connections are probed with ``SELECT 1``; ones that fail the
        probe are closed and discarded. If no pooled connection is usable, a
        new one is opened with ``PRAGMA foreign_keys = ON``.
        """
        async with self._lock:
            while self._connection_pool:
                conn = self._connection_pool.pop()
                try:
                    probe = await conn.execute("SELECT 1")
                    await probe.close()
                except Exception:
                    # Stale/broken connection: drop it and try the next one.
                    await self._close_quietly(conn)
                    continue
                self._in_use.add(conn)
                return conn

            conn = await aiosqlite.connect(self.db_path)
            try:
                pragma = await conn.execute("PRAGMA foreign_keys = ON")
                await pragma.close()
            except BaseException:
                # Don't leak the freshly opened connection if setup fails
                # (including on cancellation).
                await self._close_quietly(conn)
                raise
            self._in_use.add(conn)
            return conn

    async def _return_connection(self, conn: aiosqlite.Connection) -> None:
        """Hand a checked-out connection back to the pool, or close it.

        The connection is re-pooled only if it is still tracked as in use and
        the pool has room. Connections already closed and untracked by
        close_all_connections() are therefore never put back into the pool.
        """
        async with self._lock:
            tracked = conn in self._in_use
            self._in_use.discard(conn)
            if tracked and len(self._connection_pool) < self._pool_size:
                self._connection_pool.append(conn)
                return
        # Pool full or connection no longer tracked: close it outside the lock.
        await self._close_quietly(conn)

    @staticmethod
    def _read_sql_file(file_path: Union[str, Path]) -> str:
        """Read a UTF-8 encoded SQL file and return its full text."""
        with open(file_path, "r", encoding="utf-8") as f:
            return f.read()

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    @asynccontextmanager
    async def get_conn(self):
        """Async context manager yielding a pooled connection.

        The connection is always handed back on exit, including when the
        body raises. Committing any changes is the caller's responsibility.
        """
        conn = await self._get_connection()
        try:
            yield conn
        finally:
            await self._return_connection(conn)

    async def query(
        self, query: str, params: Optional[tuple] = None
    ) -> List[Dict[str, Any]]:
        """Run a query and return each row as a ``{column_name: value}`` dict.

        Statements that produce no result set (``cursor.description is None``)
        yield an empty list.

        Raises:
            Exception: wrapping any database error; the original exception
                is chained as ``__cause__``.
        """
        conn = await self._get_connection()
        try:
            cursor = await conn.execute(query, params or ())
            try:
                if cursor.description is None:
                    return []
                columns = [description[0] for description in cursor.description]
                rows = await cursor.fetchall()
                return [dict(zip(columns, row)) for row in rows]
            finally:
                # Close the cursor even if fetching fails.
                await cursor.close()
        except Exception as e:
            raise Exception(f"数据库查询失败: {str(e)}") from e
        finally:
            await self._return_connection(conn)

    async def query_by_sql_file(
        self, file_path: Union[str, Path], params: Optional[tuple] = None
    ) -> List[Dict[str, Any]]:
        """Read a query from a SQL file and run it via :meth:`query`."""
        return await self.query(self._read_sql_file(file_path), params)

    async def execute(self, command: str, params: Optional[tuple] = None) -> None:
        """Execute a single non-query statement and commit it.

        On failure the transaction is rolled back (best effort) and an
        ``Exception`` wrapping the original error is raised.
        """
        conn = await self._get_connection()
        try:
            await conn.execute(command, params or ())
            await conn.commit()
        except Exception as e:
            await self._rollback_quietly(conn)
            raise Exception(f"数据库执行失败: {str(e)}") from e
        finally:
            await self._return_connection(conn)

    async def execute_script(self, script: str) -> None:
        """Execute a multi-statement SQL script with ``executescript`` and commit.

        NOTE: per the sqlite3 docs, executescript commits any pending
        transaction first, so the rollback on failure may not undo statements
        of the script that already ran.
        """
        conn = await self._get_connection()
        try:
            await conn.executescript(script)
            await conn.commit()
        except Exception as e:
            await self._rollback_quietly(conn)
            raise Exception(f"数据库脚本执行失败: {str(e)}") from e
        finally:
            await self._return_connection(conn)

    def _parse_sql_statements(self, script: str) -> List[str]:
        """Split a SQL script into stripped, non-empty statements using sqlparse."""
        stripped = (statement.strip() for statement in sqlparse.split(script))
        return [statement for statement in stripped if statement]

    async def execute_by_sql_file(
        self,
        file_path: Union[str, Path],
        params: Optional[Union[tuple, List[tuple]]] = None,
    ) -> None:
        """Read non-query SQL from a file and execute it.

        Args:
            file_path: Path of a UTF-8 SQL file.
            params: How to run the file:
                * ``tuple`` -- the whole file is run as one statement via
                  :meth:`execute` with these parameters.
                * ``list`` of tuples -- the file is split into statements
                  (sqlparse) and statement *i* runs with ``params[i]``. Each
                  statement is committed separately, so this is not atomic.
                * ``None`` -- the file is run as a script via
                  :meth:`execute_script`.

        Raises:
            ValueError: if a params list's length differs from the number
                of statements in the file.
        """
        script = self._read_sql_file(file_path)

        if isinstance(params, tuple):
            await self.execute(script, params)
        elif isinstance(params, list):
            statements = self._parse_sql_statements(script)
            if len(statements) != len(params):
                raise ValueError(
                    f"语句数量({len(statements)})与参数组数量({len(params)})不匹配"
                )
            for statement, stmt_params in zip(statements, params):
                await self.execute(statement, stmt_params)
        else:
            await self.execute_script(script)

    async def execute_many(self, command: str, seq_of_params: List[tuple]) -> None:
        """Execute one statement once per parameter tuple (``executemany``) and commit.

        On failure the transaction is rolled back (best effort) and an
        ``Exception`` wrapping the original error is raised.
        """
        conn = await self._get_connection()
        try:
            await conn.executemany(command, seq_of_params)
            await conn.commit()
        except Exception as e:
            await self._rollback_quietly(conn)
            raise Exception(f"数据库批量执行失败: {str(e)}") from e
        finally:
            await self._return_connection(conn)

    async def execute_many_values_by_sql_file(
        self, file_path: Union[str, Path], seq_of_params: List[tuple]
    ) -> None:
        """Read one statement from a SQL file and run it for every parameter tuple."""
        await self.execute_many(self._read_sql_file(file_path), seq_of_params)

    async def close_all_connections(self) -> None:
        """Close every idle and checked-out connection and reset pool tracking.

        Checked-out connections are closed underneath their users. Because
        they are no longer tracked, they are closed, not re-pooled, when
        later returned.
        """
        async with self._lock:
            to_close = list(self._connection_pool) + list(self._in_use)
            self._connection_pool.clear()
            self._in_use.clear()
            for conn in to_close:
                await self._close_quietly(conn)