Files
konabot/konabot/common/database/__init__.py

232 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from contextlib import asynccontextmanager
import os
import asyncio
import sqlparse
from pathlib import Path
from typing import List, Dict, Any, Optional, Union, TYPE_CHECKING
import aiosqlite
# Process-wide DatabaseManager instance; stays None until get_global_db_manager()
# creates it. "DatabaseManager" is a forward reference to the class defined
# later in this same module, so no (self-)import is required for it.
_global_db_manager: Optional["DatabaseManager"] = None
def get_global_db_manager() -> "DatabaseManager":
    """Return the process-wide DatabaseManager, creating it on first use.

    The instance is built with DatabaseManager's default arguments (so its
    database path comes from the ``DATABASE_PATH`` environment variable or the
    built-in default) and cached in the module-level ``_global_db_manager``;
    later calls return the cached instance.
    """
    global _global_db_manager
    if _global_db_manager is None:
        # DatabaseManager is defined in this same module and is resolved at
        # call time, so no `from . import DatabaseManager` self-import is needed
        # (that import also fails when the module is loaded outside its package).
        _global_db_manager = DatabaseManager()
    return _global_db_manager
def close_global_db_manager() -> None:
    """Forget the process-wide DatabaseManager instance.

    Only the module-level reference is cleared, so the next call to
    get_global_db_manager() builds a fresh manager. The old manager's pooled
    connections are NOT closed here; code running inside an event loop should
    ``await manager.close_all_connections()`` on the instance beforehand.
    """
    global _global_db_manager
    if _global_db_manager is None:
        # Nothing to forget.
        return
    _global_db_manager = None
class DatabaseManager:
    """Async SQLite access layer with a small pool of reusable aiosqlite connections.

    Idle connections are kept in ``_connection_pool`` (at most ``pool_size`` of
    them); connections currently lent out are tracked in ``_in_use``. Note that
    ``pool_size`` only bounds how many *idle* connections are retained -- when
    the pool is empty a new connection is always opened, so it does not cap the
    number of concurrently open connections.
    """

    def __init__(self, db_path: Optional[Union[str, Path]] = None, pool_size: int = 5):
        """Initialise the manager.

        Args:
            db_path: Path to the SQLite database file, as ``str`` or ``Path``.
                When omitted, the ``DATABASE_PATH`` environment variable is
                used, falling back to ``./data/database.db``.
            pool_size: Maximum number of idle connections kept for reuse.
        """
        if db_path is None:
            self.db_path = os.environ.get("DATABASE_PATH", "./data/database.db")
        else:
            self.db_path = str(db_path) if isinstance(db_path, Path) else db_path
        # Idle connections available for reuse.
        self._connection_pool = []
        self._pool_size = pool_size
        # Guards both _connection_pool and _in_use.
        self._lock = asyncio.Lock()
        # Connections currently lent out to callers.
        self._in_use = set()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    async def _close_quietly(conn: aiosqlite.Connection) -> None:
        """Close ``conn`` as best-effort cleanup, ignoring any error."""
        try:
            await conn.close()
        except Exception:
            pass

    @staticmethod
    async def _rollback_quietly(conn: aiosqlite.Connection) -> None:
        """Roll back any open transaction on ``conn``, ignoring errors.

        Used on failure paths so a half-finished transaction is not left on a
        pooled connection, where a later borrower's ``commit()`` could
        otherwise commit it.
        """
        try:
            await conn.rollback()
        except Exception:
            pass

    @staticmethod
    def _read_sql_file(file_path: Union[str, Path]) -> str:
        """Return the full text of a UTF-8 encoded SQL file."""
        with open(file_path, "r", encoding="utf-8") as f:
            return f.read()

    # ------------------------------------------------------------ pool handling

    async def _get_connection(self) -> aiosqlite.Connection:
        """Borrow a connection: reuse a healthy pooled one or open a new one.

        Pooled connections are probed with ``SELECT 1``; any that fail the probe
        are closed and discarded. New connections get
        ``PRAGMA foreign_keys = ON``. The returned connection is recorded in
        ``_in_use`` and must be handed back via ``_return_connection``.
        """
        async with self._lock:
            while self._connection_pool:
                conn = self._connection_pool.pop()
                try:
                    probe = await conn.execute("SELECT 1")
                    await probe.close()
                except Exception:
                    # Stale or broken connection: drop it and try the next one.
                    await self._close_quietly(conn)
                    continue
                self._in_use.add(conn)
                return conn

            # Pool exhausted: open a fresh connection.
            conn = await aiosqlite.connect(self.db_path)
            try:
                await conn.execute("PRAGMA foreign_keys = ON")
            except Exception:
                # Don't leak the freshly opened connection if setup fails.
                await self._close_quietly(conn)
                raise
            self._in_use.add(conn)
            return conn

    async def _return_connection(self, conn: aiosqlite.Connection) -> None:
        """Hand a borrowed connection back.

        The connection re-enters the idle pool only if it is still tracked as
        borrowed and the pool has room; otherwise it is closed. A connection
        that is no longer tracked (e.g. ``close_all_connections`` ran while it
        was borrowed) is therefore closed instead of being re-pooled.
        """
        async with self._lock:
            was_borrowed = conn in self._in_use
            self._in_use.discard(conn)
            if was_borrowed and len(self._connection_pool) < self._pool_size:
                self._connection_pool.append(conn)
                return
        await self._close_quietly(conn)

    @asynccontextmanager
    async def get_conn(self):
        """Borrow a pooled connection for the duration of an ``async with`` block.

        The connection is always handed back, even if the body raises. If the
        body raises an ``Exception``, any open transaction is rolled back first
        so the next borrower does not inherit it. Nothing is committed
        automatically on normal exit -- call ``await conn.commit()`` yourself.
        """
        conn = await self._get_connection()
        try:
            yield conn
        except Exception:
            await self._rollback_quietly(conn)
            raise
        finally:
            await self._return_connection(conn)

    # ------------------------------------------------------------------ queries

    async def query(
        self, query: str, params: Optional[tuple] = None
    ) -> List[Dict[str, Any]]:
        """Run one statement and return its rows as ``{column: value}`` dicts.

        Args:
            query: SQL text of a single statement.
            params: Positional bind parameters; defaults to none.

        Returns:
            One dict per result row. A statement that produces no result
            columns yields an empty list.

        Raises:
            Exception: wrapping the underlying driver error.
        """
        conn = await self._get_connection()
        try:
            cursor = await conn.execute(query, params or ())
            try:
                if cursor.description is None:
                    # No result set (e.g. a DDL statement): nothing to map.
                    return []
                columns = [description[0] for description in cursor.description]
                rows = await cursor.fetchall()
            finally:
                # Close the cursor on both the success and the error path.
                await cursor.close()
            return [dict(zip(columns, row)) for row in rows]
        except Exception as e:
            # Re-raise with context so the caller can handle it.
            raise Exception(f"数据库查询失败: {str(e)}") from e
        finally:
            await self._return_connection(conn)

    async def query_by_sql_file(
        self, file_path: Union[str, Path], params: Optional[tuple] = None
    ) -> List[Dict[str, Any]]:
        """Read a query from a UTF-8 SQL file and run it via ``query``."""
        return await self.query(self._read_sql_file(file_path), params)

    # --------------------------------------------------------------- execution

    async def execute(self, command: str, params: Optional[tuple] = None) -> None:
        """Execute one non-query statement and commit it.

        On failure any open transaction is rolled back before the error is
        re-raised, so a partially applied change never stays on a pooled
        connection.

        Raises:
            Exception: wrapping the underlying driver error.
        """
        conn = await self._get_connection()
        try:
            await conn.execute(command, params or ())
            await conn.commit()
        except Exception as e:
            await self._rollback_quietly(conn)
            raise Exception(f"数据库执行失败: {str(e)}") from e
        finally:
            await self._return_connection(conn)

    async def execute_script(self, script: str) -> None:
        """Run a multi-statement SQL script (no bind parameters) and commit.

        Raises:
            Exception: wrapping the underlying driver error; any open
                transaction is rolled back first.
        """
        conn = await self._get_connection()
        try:
            await conn.executescript(script)
            await conn.commit()
        except Exception as e:
            await self._rollback_quietly(conn)
            raise Exception(f"数据库脚本执行失败: {str(e)}") from e
        finally:
            await self._return_connection(conn)

    def _parse_sql_statements(self, script: str) -> List[str]:
        """Split ``script`` into trimmed, non-empty statements via ``sqlparse.split``."""
        return [stmt.strip() for stmt in sqlparse.split(script) if stmt.strip()]

    async def execute_by_sql_file(
        self,
        file_path: Union[str, Path],
        params: Optional[Union[tuple, List[tuple]]] = None,
    ) -> None:
        """Execute the non-query SQL contained in a UTF-8 file.

        Dispatch depends on ``params``:
          * tuple -- the whole file goes through ``execute`` with those
            parameters (so presumably it must contain a single statement);
          * list of tuples -- the file is split into statements and each one
            is executed (and committed separately, so not atomically) with the
            tuple at the same position;
          * None -- the file runs as a script via ``execute_script``.

        Raises:
            ValueError: if, in list mode, the statement count differs from the
                number of parameter tuples.
        """
        script = self._read_sql_file(file_path)
        if isinstance(params, tuple):
            await self.execute(script, params)
        elif isinstance(params, list):
            statements = self._parse_sql_statements(script)
            if len(statements) != len(params):
                raise ValueError(
                    f"语句数量({len(statements)})与参数组数量({len(params)})不匹配"
                )
            for statement, stmt_params in zip(statements, params):
                await self.execute(statement, stmt_params)
        else:
            await self.execute_script(script)

    async def execute_many(self, command: str, seq_of_params: List[tuple]) -> None:
        """Run ``command`` once per parameter tuple (``executemany``) and commit.

        A failure part-way through rolls back the rows already written in the
        open transaction before the error is re-raised.

        Raises:
            Exception: wrapping the underlying driver error.
        """
        conn = await self._get_connection()
        try:
            await conn.executemany(command, seq_of_params)
            await conn.commit()
        except Exception as e:
            await self._rollback_quietly(conn)
            raise Exception(f"数据库批量执行失败: {str(e)}") from e
        finally:
            await self._return_connection(conn)

    async def execute_many_values_by_sql_file(
        self, file_path: Union[str, Path], seq_of_params: List[tuple]
    ) -> None:
        """Read one statement from a UTF-8 SQL file and run it for every parameter tuple."""
        await self.execute_many(self._read_sql_file(file_path), seq_of_params)

    # ----------------------------------------------------------------- shutdown

    async def close_all_connections(self) -> None:
        """Close every idle connection and every currently borrowed one.

        Borrowed connections are closed out from under their users; when such
        a connection is later handed back it is closed again rather than being
        re-pooled. Errors raised while closing are ignored.
        """
        async with self._lock:
            to_close = list(self._connection_pool) + list(self._in_use)
            self._connection_pool.clear()
            self._in_use.clear()
            for conn in to_close:
                await self._close_quietly(conn)