Qwen 说让我再改点东西所以改了,强化了数据库相关的事情
This commit is contained in:
@ -1,20 +1,43 @@
|
||||
import os
|
||||
import asyncio
|
||||
import sqlparse
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
from typing import List, Dict, Any, Optional, Union, TYPE_CHECKING
|
||||
|
||||
import aiosqlite
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import DatabaseManager
|
||||
|
||||
# 全局数据库管理器实例
|
||||
_global_db_manager: Optional['DatabaseManager'] = None
|
||||
|
||||
def get_global_db_manager() -> 'DatabaseManager':
|
||||
"""获取全局数据库管理器实例"""
|
||||
global _global_db_manager
|
||||
if _global_db_manager is None:
|
||||
from . import DatabaseManager
|
||||
_global_db_manager = DatabaseManager()
|
||||
return _global_db_manager
|
||||
|
||||
def close_global_db_manager() -> None:
|
||||
"""关闭全局数据库管理器实例"""
|
||||
global _global_db_manager
|
||||
if _global_db_manager is not None:
|
||||
# 注意:这个函数应该在async环境中调用close_all_connections
|
||||
_global_db_manager = None
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""异步数据库管理器"""
|
||||
|
||||
def __init__(self, db_path: Optional[Union[str, Path]] = None):
|
||||
def __init__(self, db_path: Optional[Union[str, Path]] = None, pool_size: int = 5):
|
||||
"""
|
||||
初始化数据库管理器
|
||||
|
||||
Args:
|
||||
db_path: 数据库文件路径,支持str和Path类型
|
||||
pool_size: 连接池大小
|
||||
"""
|
||||
if db_path is None:
|
||||
self.db_path = os.environ.get("DATABASE_PATH", "./data/database.db")
|
||||
@ -23,27 +46,46 @@ class DatabaseManager:
|
||||
|
||||
# 连接池
|
||||
self._connection_pool = []
|
||||
self._pool_size = 5
|
||||
self._pool_size = pool_size
|
||||
self._lock = asyncio.Lock()
|
||||
self._in_use = set() # 跟踪正在使用的连接
|
||||
|
||||
async def _get_connection(self) -> aiosqlite.Connection:
|
||||
"""从连接池获取连接"""
|
||||
async with self._lock:
|
||||
if self._connection_pool:
|
||||
return self._connection_pool.pop()
|
||||
# 尝试从池中获取现有连接
|
||||
while self._connection_pool:
|
||||
conn = self._connection_pool.pop()
|
||||
# 检查连接是否仍然有效
|
||||
try:
|
||||
await conn.execute("SELECT 1")
|
||||
self._in_use.add(conn)
|
||||
return conn
|
||||
except:
|
||||
# 连接已失效,关闭它
|
||||
try:
|
||||
await conn.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
# 如果连接池为空,创建新连接
|
||||
conn = await aiosqlite.connect(self.db_path)
|
||||
await conn.execute("PRAGMA foreign_keys = ON")
|
||||
return conn
|
||||
# 如果连接池为空,创建新连接
|
||||
conn = await aiosqlite.connect(self.db_path)
|
||||
await conn.execute("PRAGMA foreign_keys = ON")
|
||||
self._in_use.add(conn)
|
||||
return conn
|
||||
|
||||
async def _return_connection(self, conn: aiosqlite.Connection) -> None:
|
||||
"""将连接返回到连接池"""
|
||||
async with self._lock:
|
||||
self._in_use.discard(conn)
|
||||
if len(self._connection_pool) < self._pool_size:
|
||||
self._connection_pool.append(conn)
|
||||
else:
|
||||
await conn.close()
|
||||
# 池已满,直接关闭连接
|
||||
try:
|
||||
await conn.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
async def query(
|
||||
self, query: str, params: Optional[tuple] = None
|
||||
@ -57,6 +99,9 @@ class DatabaseManager:
|
||||
results = [dict(zip(columns, row)) for row in rows]
|
||||
await cursor.close()
|
||||
return results
|
||||
except Exception as e:
|
||||
# 记录错误但重新抛出,让调用者处理
|
||||
raise Exception(f"数据库查询失败: {str(e)}") from e
|
||||
finally:
|
||||
await self._return_connection(conn)
|
||||
|
||||
@ -75,6 +120,9 @@ class DatabaseManager:
|
||||
try:
|
||||
await conn.execute(command, params or ())
|
||||
await conn.commit()
|
||||
except Exception as e:
|
||||
# 记录错误但重新抛出,让调用者处理
|
||||
raise Exception(f"数据库执行失败: {str(e)}") from e
|
||||
finally:
|
||||
await self._return_connection(conn)
|
||||
|
||||
@ -84,19 +132,47 @@ class DatabaseManager:
|
||||
try:
|
||||
await conn.executescript(script)
|
||||
await conn.commit()
|
||||
except Exception as e:
|
||||
# 记录错误但重新抛出,让调用者处理
|
||||
raise Exception(f"数据库脚本执行失败: {str(e)}") from e
|
||||
finally:
|
||||
await self._return_connection(conn)
|
||||
|
||||
def _parse_sql_statements(self, script: str) -> List[str]:
|
||||
"""解析SQL脚本,分割成独立的语句"""
|
||||
# 使用sqlparse库更准确地分割SQL语句
|
||||
parsed = sqlparse.split(script)
|
||||
statements = []
|
||||
|
||||
for statement in parsed:
|
||||
statement = statement.strip()
|
||||
if statement:
|
||||
statements.append(statement)
|
||||
|
||||
return statements
|
||||
|
||||
async def execute_by_sql_file(
|
||||
self, file_path: Union[str, Path], params: Optional[tuple] = None
|
||||
self, file_path: Union[str, Path], params: Optional[Union[tuple, List[tuple]]] = None
|
||||
) -> None:
|
||||
"""从 SQL 文件中读取非查询语句并执行"""
|
||||
path = str(file_path) if isinstance(file_path, Path) else file_path
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
script = f.read()
|
||||
# 如果有参数,使用execute方法而不是execute_script
|
||||
if params:
|
||||
|
||||
# 如果有参数且是元组,使用execute执行整个脚本
|
||||
if params is not None and isinstance(params, tuple):
|
||||
await self.execute(script, params)
|
||||
# 如果有参数且是列表,分别执行每个语句
|
||||
elif params is not None and isinstance(params, list):
|
||||
# 使用sqlparse准确分割SQL语句
|
||||
statements = self._parse_sql_statements(script)
|
||||
if len(statements) != len(params):
|
||||
raise ValueError(f"语句数量({len(statements)})与参数组数量({len(params)})不匹配")
|
||||
|
||||
for statement, stmt_params in zip(statements, params):
|
||||
if statement:
|
||||
await self.execute(statement, stmt_params)
|
||||
# 如果无参数,使用executescript
|
||||
else:
|
||||
await self.execute_script(script)
|
||||
|
||||
@ -106,6 +182,9 @@ class DatabaseManager:
|
||||
try:
|
||||
await conn.executemany(command, seq_of_params)
|
||||
await conn.commit()
|
||||
except Exception as e:
|
||||
# 记录错误但重新抛出,让调用者处理
|
||||
raise Exception(f"数据库批量执行失败: {str(e)}") from e
|
||||
finally:
|
||||
await self._return_connection(conn)
|
||||
|
||||
@ -121,7 +200,19 @@ class DatabaseManager:
|
||||
async def close_all_connections(self) -> None:
|
||||
"""关闭所有连接"""
|
||||
async with self._lock:
|
||||
# 关闭池中的连接
|
||||
for conn in self._connection_pool:
|
||||
await conn.close()
|
||||
try:
|
||||
await conn.close()
|
||||
except:
|
||||
pass
|
||||
self._connection_pool.clear()
|
||||
|
||||
# 关闭正在使用的连接
|
||||
for conn in self._in_use.copy():
|
||||
try:
|
||||
await conn.close()
|
||||
except:
|
||||
pass
|
||||
self._in_use.clear()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user