Merge pull request '修复偶发的数据库连接失效问题' (#72) from fix/database-lock into master
All checks were successful
continuous-integration/drone/push Build is passing

Reviewed-on: #72
This commit is contained in:
2026-04-28 00:12:44 +08:00
4 changed files with 57 additions and 78 deletions

View File

@ -1,6 +1,7 @@
from contextlib import asynccontextmanager
import os
import asyncio
from loguru import logger
import sqlparse
from pathlib import Path
from typing import List, Dict, Any, Optional, Union, TYPE_CHECKING
@ -10,10 +11,20 @@ import aiosqlite
if TYPE_CHECKING:
from . import DatabaseManager
# 全局数据库管理器实例
_global_db_manager: Optional["DatabaseManager"] = None
async def try_close_connection(conn: aiosqlite.Connection) -> bool:
try:
await conn.close()
return True
except Exception as e:
logger.error("有的连接关闭失败了")
logger.exception(e)
return False
def get_global_db_manager() -> "DatabaseManager":
"""获取全局数据库管理器实例"""
global _global_db_manager
@ -24,16 +35,10 @@ def get_global_db_manager() -> "DatabaseManager":
return _global_db_manager
def close_global_db_manager() -> None:
"""关闭全局数据库管理器实例"""
global _global_db_manager
if _global_db_manager is not None:
# 注意这个函数应该在async环境中调用close_all_connections
_global_db_manager = None
class DatabaseManager:
"""异步数据库管理器"""
"""
异步数据库管理器
"""
def __init__(self, db_path: Optional[Union[str, Path]] = None, pool_size: int = 5):
"""
@ -56,6 +61,7 @@ class DatabaseManager:
async def _get_connection(self) -> aiosqlite.Connection:
"""从连接池获取连接"""
async with self._lock:
# 尝试从池中获取现有连接
while self._connection_pool:
@ -67,10 +73,7 @@ class DatabaseManager:
return conn
except:
# 连接已失效,关闭它
try:
await conn.close()
except:
pass
await try_close_connection(conn)
# 如果连接池为空,创建新连接
conn = await aiosqlite.connect(self.db_path)
@ -86,16 +89,31 @@ class DatabaseManager:
self._connection_pool.append(conn)
else:
# 池已满,直接关闭连接
try:
await conn.close()
except:
pass
await try_close_connection(conn)
@asynccontextmanager
async def get_conn(self):
"""
从 db 中获取一个 Connection
"""
conn = await self._get_connection()
yield conn
await self._return_connection(conn)
try:
yield conn
# 只有当一切正常时才归还数据库连接
await self._return_connection(conn)
except Exception as e:
logger.error("有模块使用一个连接时出现了错误")
logger.exception(e)
try:
await conn.rollback()
await conn.close()
except Exception as e:
logger.error("在 Rollback 和关闭时也出现了问题")
logger.exception(e)
async def query(
self, query: str, params: Optional[tuple] = None
@ -190,42 +208,14 @@ class DatabaseManager:
else:
await self.execute_script(script)
async def execute_many(self, command: str, seq_of_params: List[tuple]) -> None:
"""执行多条非查询语句"""
conn = await self._get_connection()
try:
await conn.executemany(command, seq_of_params)
await conn.commit()
except Exception as e:
await conn.rollback()
raise Exception(f"数据库批量执行失败: {str(e)}") from e
finally:
await self._return_connection(conn)
async def execute_many_values_by_sql_file(
self, file_path: Union[str, Path], seq_of_params: List[tuple]
) -> None:
"""从 SQL 文件中读取一条语句,但是被不同值同时执行"""
path = str(file_path) if isinstance(file_path, Path) else file_path
with open(path, "r", encoding="utf-8") as f:
command = f.read()
await self.execute_many(command, seq_of_params)
async def close_all_connections(self) -> None:
"""关闭所有连接"""
async with self._lock:
# 关闭池中的连接
for conn in self._connection_pool:
try:
await conn.close()
except:
pass
await try_close_connection(conn)
self._connection_pool.clear()
# 关闭正在使用的连接
for conn in self._in_use.copy():
try:
await conn.close()
except:
pass
await try_close_connection(conn)
self._in_use.clear()

View File

@ -32,7 +32,7 @@ class PermManager:
def __init__(self, db: DatabaseManager) -> None:
self.db = db
async def check_has_permission_info(self, entities: _EntityLike, key: str):
async def get_permission_info(self, entities: _EntityLike, key: str):
entities = await _to_entity_chain(entities)
key = key.removesuffix("*").removesuffix(".")
key_split = key.split(".")
@ -52,7 +52,7 @@ class PermManager:
return None
async def check_has_permission(self, entities: _EntityLike, key: str) -> bool:
res = await self.check_has_permission_info(entities, key)
res = await self.get_permission_info(entities, key)
if res is None:
return False
return res[2]

View File

@ -43,15 +43,12 @@ class PermRepo:
Raises:
AssertionError: 如果创建后无法获取实体 ID。
"""
try:
await self.conn.execute(
s("create_entity.sql"),
(entity.platform, entity.entity_type, entity.external_id),
)
await self.conn.commit()
except Exception:
await self.conn.rollback()
raise
await self.conn.execute(
s("create_entity.sql"),
(entity.platform, entity.entity_type, entity.external_id),
)
await self.conn.commit()
eid = await self._get_entity_id_or_none(entity)
assert eid is not None
return eid
@ -119,12 +116,8 @@ class PermRepo:
value: 要设置的配置值True/False/None
"""
eid = await self.get_entity_id(entity)
try:
await self.conn.execute(s("update_perm_info.sql"), (eid, config_key, value))
await self.conn.commit()
except Exception:
await self.conn.rollback()
raise
await self.conn.execute(s("update_perm_info.sql"), (eid, config_key, value))
await self.conn.commit()
async def get_entity_id_batch(
self, entities: list[PermEntity]
@ -143,15 +136,11 @@ class PermRepo:
# s("create_entity.sql"),
# (entity.platform, entity.entity_type, entity.external_id),
# )
try:
await self.conn.executemany(
s("create_entity.sql"),
[(e.platform, e.entity_type, e.external_id) for e in entities],
)
await self.conn.commit()
except Exception:
await self.conn.rollback()
raise
await self.conn.executemany(
s("create_entity.sql"),
[(e.platform, e.entity_type, e.external_id) for e in entities],
)
await self.conn.commit()
val_placeholders = ", ".join(["(?, ?, ?)"] * len(entities))
params = []
for e in entities:

View File

@ -77,7 +77,7 @@ async def get_permission(
perm: str,
event: Event,
):
data = await pm.check_has_permission_info(ec, perm)
data = await pm.get_permission_info(ec, perm)
obj_s = f"{ec[0].platform}.{ec[0].entity_type}.{ec[0].external_id}"