232 lines
8.3 KiB
Python
232 lines
8.3 KiB
Python
from contextlib import asynccontextmanager
|
||
import os
|
||
import asyncio
|
||
import sqlparse
|
||
from pathlib import Path
|
||
from typing import List, Dict, Any, Optional, Union, TYPE_CHECKING
|
||
|
||
import aiosqlite
|
||
|
||
if TYPE_CHECKING:
|
||
from . import DatabaseManager
|
||
|
||
# 全局数据库管理器实例
|
||
_global_db_manager: Optional["DatabaseManager"] = None
|
||
|
||
|
||
def get_global_db_manager() -> "DatabaseManager":
|
||
"""获取全局数据库管理器实例"""
|
||
global _global_db_manager
|
||
if _global_db_manager is None:
|
||
from . import DatabaseManager
|
||
|
||
_global_db_manager = DatabaseManager()
|
||
return _global_db_manager
|
||
|
||
|
||
def close_global_db_manager() -> None:
|
||
"""关闭全局数据库管理器实例"""
|
||
global _global_db_manager
|
||
if _global_db_manager is not None:
|
||
# 注意:这个函数应该在async环境中调用close_all_connections
|
||
_global_db_manager = None
|
||
|
||
|
||
class DatabaseManager:
|
||
"""异步数据库管理器"""
|
||
|
||
def __init__(self, db_path: Optional[Union[str, Path]] = None, pool_size: int = 5):
|
||
"""
|
||
初始化数据库管理器
|
||
|
||
Args:
|
||
db_path: 数据库文件路径,支持str和Path类型
|
||
pool_size: 连接池大小
|
||
"""
|
||
if db_path is None:
|
||
self.db_path = os.environ.get("DATABASE_PATH", "./data/database.db")
|
||
else:
|
||
self.db_path = str(db_path) if isinstance(db_path, Path) else db_path
|
||
|
||
# 连接池
|
||
self._connection_pool = []
|
||
self._pool_size = pool_size
|
||
self._lock = asyncio.Lock()
|
||
self._in_use = set() # 跟踪正在使用的连接
|
||
|
||
async def _get_connection(self) -> aiosqlite.Connection:
|
||
"""从连接池获取连接"""
|
||
async with self._lock:
|
||
# 尝试从池中获取现有连接
|
||
while self._connection_pool:
|
||
conn = self._connection_pool.pop()
|
||
# 检查连接是否仍然有效
|
||
try:
|
||
await conn.execute("SELECT 1")
|
||
self._in_use.add(conn)
|
||
return conn
|
||
except:
|
||
# 连接已失效,关闭它
|
||
try:
|
||
await conn.close()
|
||
except:
|
||
pass
|
||
|
||
# 如果连接池为空,创建新连接
|
||
conn = await aiosqlite.connect(self.db_path)
|
||
await conn.execute("PRAGMA foreign_keys = ON")
|
||
self._in_use.add(conn)
|
||
return conn
|
||
|
||
async def _return_connection(self, conn: aiosqlite.Connection) -> None:
|
||
"""将连接返回到连接池"""
|
||
async with self._lock:
|
||
self._in_use.discard(conn)
|
||
if len(self._connection_pool) < self._pool_size:
|
||
self._connection_pool.append(conn)
|
||
else:
|
||
# 池已满,直接关闭连接
|
||
try:
|
||
await conn.close()
|
||
except:
|
||
pass
|
||
|
||
@asynccontextmanager
|
||
async def get_conn(self):
|
||
conn = await self._get_connection()
|
||
yield conn
|
||
await self._return_connection(conn)
|
||
|
||
async def query(
|
||
self, query: str, params: Optional[tuple] = None
|
||
) -> List[Dict[str, Any]]:
|
||
"""执行查询语句并返回结果"""
|
||
conn = await self._get_connection()
|
||
try:
|
||
cursor = await conn.execute(query, params or ())
|
||
columns = [description[0] for description in cursor.description]
|
||
rows = await cursor.fetchall()
|
||
results = [dict(zip(columns, row)) for row in rows]
|
||
await cursor.close()
|
||
return results
|
||
except Exception as e:
|
||
# 记录错误但重新抛出,让调用者处理
|
||
raise Exception(f"数据库查询失败: {str(e)}") from e
|
||
finally:
|
||
await self._return_connection(conn)
|
||
|
||
async def query_by_sql_file(
|
||
self, file_path: Union[str, Path], params: Optional[tuple] = None
|
||
) -> List[Dict[str, Any]]:
|
||
"""从 SQL 文件中读取查询语句并执行"""
|
||
path = str(file_path) if isinstance(file_path, Path) else file_path
|
||
with open(path, "r", encoding="utf-8") as f:
|
||
query = f.read()
|
||
return await self.query(query, params)
|
||
|
||
async def execute(self, command: str, params: Optional[tuple] = None) -> None:
|
||
"""执行非查询语句"""
|
||
conn = await self._get_connection()
|
||
try:
|
||
await conn.execute(command, params or ())
|
||
await conn.commit()
|
||
except Exception as e:
|
||
# 记录错误但重新抛出,让调用者处理
|
||
raise Exception(f"数据库执行失败: {str(e)}") from e
|
||
finally:
|
||
await self._return_connection(conn)
|
||
|
||
async def execute_script(self, script: str) -> None:
|
||
"""执行SQL脚本"""
|
||
conn = await self._get_connection()
|
||
try:
|
||
await conn.executescript(script)
|
||
await conn.commit()
|
||
except Exception as e:
|
||
# 记录错误但重新抛出,让调用者处理
|
||
raise Exception(f"数据库脚本执行失败: {str(e)}") from e
|
||
finally:
|
||
await self._return_connection(conn)
|
||
|
||
def _parse_sql_statements(self, script: str) -> List[str]:
|
||
"""解析SQL脚本,分割成独立的语句"""
|
||
# 使用sqlparse库更准确地分割SQL语句
|
||
parsed = sqlparse.split(script)
|
||
statements = []
|
||
|
||
for statement in parsed:
|
||
statement = statement.strip()
|
||
if statement:
|
||
statements.append(statement)
|
||
|
||
return statements
|
||
|
||
async def execute_by_sql_file(
|
||
self,
|
||
file_path: Union[str, Path],
|
||
params: Optional[Union[tuple, List[tuple]]] = None,
|
||
) -> None:
|
||
"""从 SQL 文件中读取非查询语句并执行"""
|
||
path = str(file_path) if isinstance(file_path, Path) else file_path
|
||
with open(path, "r", encoding="utf-8") as f:
|
||
script = f.read()
|
||
|
||
# 如果有参数且是元组,使用execute执行整个脚本
|
||
if params is not None and isinstance(params, tuple):
|
||
await self.execute(script, params)
|
||
# 如果有参数且是列表,分别执行每个语句
|
||
elif params is not None and isinstance(params, list):
|
||
# 使用sqlparse准确分割SQL语句
|
||
statements = self._parse_sql_statements(script)
|
||
if len(statements) != len(params):
|
||
raise ValueError(
|
||
f"语句数量({len(statements)})与参数组数量({len(params)})不匹配"
|
||
)
|
||
|
||
for statement, stmt_params in zip(statements, params):
|
||
if statement:
|
||
await self.execute(statement, stmt_params)
|
||
# 如果无参数,使用executescript
|
||
else:
|
||
await self.execute_script(script)
|
||
|
||
async def execute_many(self, command: str, seq_of_params: List[tuple]) -> None:
|
||
"""执行多条非查询语句"""
|
||
conn = await self._get_connection()
|
||
try:
|
||
await conn.executemany(command, seq_of_params)
|
||
await conn.commit()
|
||
except Exception as e:
|
||
# 记录错误但重新抛出,让调用者处理
|
||
raise Exception(f"数据库批量执行失败: {str(e)}") from e
|
||
finally:
|
||
await self._return_connection(conn)
|
||
|
||
async def execute_many_values_by_sql_file(
|
||
self, file_path: Union[str, Path], seq_of_params: List[tuple]
|
||
) -> None:
|
||
"""从 SQL 文件中读取一条语句,但是被不同值同时执行"""
|
||
path = str(file_path) if isinstance(file_path, Path) else file_path
|
||
with open(path, "r", encoding="utf-8") as f:
|
||
command = f.read()
|
||
await self.execute_many(command, seq_of_params)
|
||
|
||
async def close_all_connections(self) -> None:
|
||
"""关闭所有连接"""
|
||
async with self._lock:
|
||
# 关闭池中的连接
|
||
for conn in self._connection_pool:
|
||
try:
|
||
await conn.close()
|
||
except:
|
||
pass
|
||
self._connection_pool.clear()
|
||
|
||
# 关闭正在使用的连接
|
||
for conn in self._in_use.copy():
|
||
try:
|
||
await conn.close()
|
||
except:
|
||
pass
|
||
self._in_use.clear()
|