在一个更统一的地方管理 connection 的 rollback 和丢弃

This commit is contained in:
2026-04-27 23:33:53 +08:00
parent 6b152235cf
commit 4d4bbc86dc
4 changed files with 57 additions and 78 deletions

View File

@ -1,6 +1,7 @@
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import os import os
import asyncio import asyncio
from loguru import logger
import sqlparse import sqlparse
from pathlib import Path from pathlib import Path
from typing import List, Dict, Any, Optional, Union, TYPE_CHECKING from typing import List, Dict, Any, Optional, Union, TYPE_CHECKING
@ -10,10 +11,20 @@ import aiosqlite
if TYPE_CHECKING: if TYPE_CHECKING:
from . import DatabaseManager from . import DatabaseManager
# 全局数据库管理器实例
_global_db_manager: Optional["DatabaseManager"] = None _global_db_manager: Optional["DatabaseManager"] = None
async def try_close_connection(conn: aiosqlite.Connection) -> bool:
try:
await conn.close()
return True
except Exception as e:
logger.error("有的连接关闭失败了")
logger.exception(e)
return False
def get_global_db_manager() -> "DatabaseManager": def get_global_db_manager() -> "DatabaseManager":
"""获取全局数据库管理器实例""" """获取全局数据库管理器实例"""
global _global_db_manager global _global_db_manager
@ -24,16 +35,10 @@ def get_global_db_manager() -> "DatabaseManager":
return _global_db_manager return _global_db_manager
def close_global_db_manager() -> None:
"""关闭全局数据库管理器实例"""
global _global_db_manager
if _global_db_manager is not None:
# 注意这个函数应该在async环境中调用close_all_connections
_global_db_manager = None
class DatabaseManager: class DatabaseManager:
"""异步数据库管理器""" """
异步数据库管理器
"""
def __init__(self, db_path: Optional[Union[str, Path]] = None, pool_size: int = 5): def __init__(self, db_path: Optional[Union[str, Path]] = None, pool_size: int = 5):
""" """
@ -56,6 +61,7 @@ class DatabaseManager:
async def _get_connection(self) -> aiosqlite.Connection: async def _get_connection(self) -> aiosqlite.Connection:
"""从连接池获取连接""" """从连接池获取连接"""
async with self._lock: async with self._lock:
# 尝试从池中获取现有连接 # 尝试从池中获取现有连接
while self._connection_pool: while self._connection_pool:
@ -67,10 +73,7 @@ class DatabaseManager:
return conn return conn
except: except:
# 连接已失效,关闭它 # 连接已失效,关闭它
try: await try_close_connection(conn)
await conn.close()
except:
pass
# 如果连接池为空,创建新连接 # 如果连接池为空,创建新连接
conn = await aiosqlite.connect(self.db_path) conn = await aiosqlite.connect(self.db_path)
@ -86,16 +89,31 @@ class DatabaseManager:
self._connection_pool.append(conn) self._connection_pool.append(conn)
else: else:
# 池已满,直接关闭连接 # 池已满,直接关闭连接
try: await try_close_connection(conn)
await conn.close()
except:
pass
@asynccontextmanager @asynccontextmanager
async def get_conn(self): async def get_conn(self):
"""
从 db 中获取一个 Connection
"""
conn = await self._get_connection() conn = await self._get_connection()
yield conn
await self._return_connection(conn) try:
yield conn
# 只有当一切正常时才归还数据库连接
await self._return_connection(conn)
except Exception as e:
logger.error("有模块使用一个连接时出现了错误")
logger.exception(e)
try:
await conn.rollback()
await conn.close()
except Exception as e:
logger.error("在 Rollback 和关闭时也出现了问题")
logger.exception(e)
async def query( async def query(
self, query: str, params: Optional[tuple] = None self, query: str, params: Optional[tuple] = None
@ -190,42 +208,14 @@ class DatabaseManager:
else: else:
await self.execute_script(script) await self.execute_script(script)
async def execute_many(self, command: str, seq_of_params: List[tuple]) -> None:
"""执行多条非查询语句"""
conn = await self._get_connection()
try:
await conn.executemany(command, seq_of_params)
await conn.commit()
except Exception as e:
await conn.rollback()
raise Exception(f"数据库批量执行失败: {str(e)}") from e
finally:
await self._return_connection(conn)
async def execute_many_values_by_sql_file(
self, file_path: Union[str, Path], seq_of_params: List[tuple]
) -> None:
"""从 SQL 文件中读取一条语句,但是被不同值同时执行"""
path = str(file_path) if isinstance(file_path, Path) else file_path
with open(path, "r", encoding="utf-8") as f:
command = f.read()
await self.execute_many(command, seq_of_params)
async def close_all_connections(self) -> None: async def close_all_connections(self) -> None:
"""关闭所有连接""" """关闭所有连接"""
async with self._lock: async with self._lock:
# 关闭池中的连接
for conn in self._connection_pool: for conn in self._connection_pool:
try: await try_close_connection(conn)
await conn.close()
except:
pass
self._connection_pool.clear() self._connection_pool.clear()
# 关闭正在使用的连接
for conn in self._in_use.copy(): for conn in self._in_use.copy():
try: await try_close_connection(conn)
await conn.close()
except:
pass
self._in_use.clear() self._in_use.clear()

View File

@ -32,7 +32,7 @@ class PermManager:
def __init__(self, db: DatabaseManager) -> None: def __init__(self, db: DatabaseManager) -> None:
self.db = db self.db = db
async def check_has_permission_info(self, entities: _EntityLike, key: str): async def get_permission_info(self, entities: _EntityLike, key: str):
entities = await _to_entity_chain(entities) entities = await _to_entity_chain(entities)
key = key.removesuffix("*").removesuffix(".") key = key.removesuffix("*").removesuffix(".")
key_split = key.split(".") key_split = key.split(".")
@ -52,7 +52,7 @@ class PermManager:
return None return None
async def check_has_permission(self, entities: _EntityLike, key: str) -> bool: async def check_has_permission(self, entities: _EntityLike, key: str) -> bool:
res = await self.check_has_permission_info(entities, key) res = await self.get_permission_info(entities, key)
if res is None: if res is None:
return False return False
return res[2] return res[2]

View File

@ -43,15 +43,12 @@ class PermRepo:
Raises: Raises:
AssertionError: 如果创建后无法获取实体 ID。 AssertionError: 如果创建后无法获取实体 ID。
""" """
try: await self.conn.execute(
await self.conn.execute( s("create_entity.sql"),
s("create_entity.sql"), (entity.platform, entity.entity_type, entity.external_id),
(entity.platform, entity.entity_type, entity.external_id), )
) await self.conn.commit()
await self.conn.commit()
except Exception:
await self.conn.rollback()
raise
eid = await self._get_entity_id_or_none(entity) eid = await self._get_entity_id_or_none(entity)
assert eid is not None assert eid is not None
return eid return eid
@ -119,12 +116,8 @@ class PermRepo:
value: 要设置的配置值True/False/None value: 要设置的配置值True/False/None
""" """
eid = await self.get_entity_id(entity) eid = await self.get_entity_id(entity)
try: await self.conn.execute(s("update_perm_info.sql"), (eid, config_key, value))
await self.conn.execute(s("update_perm_info.sql"), (eid, config_key, value)) await self.conn.commit()
await self.conn.commit()
except Exception:
await self.conn.rollback()
raise
async def get_entity_id_batch( async def get_entity_id_batch(
self, entities: list[PermEntity] self, entities: list[PermEntity]
@ -143,15 +136,11 @@ class PermRepo:
# s("create_entity.sql"), # s("create_entity.sql"),
# (entity.platform, entity.entity_type, entity.external_id), # (entity.platform, entity.entity_type, entity.external_id),
# ) # )
try: await self.conn.executemany(
await self.conn.executemany( s("create_entity.sql"),
s("create_entity.sql"), [(e.platform, e.entity_type, e.external_id) for e in entities],
[(e.platform, e.entity_type, e.external_id) for e in entities], )
) await self.conn.commit()
await self.conn.commit()
except Exception:
await self.conn.rollback()
raise
val_placeholders = ", ".join(["(?, ?, ?)"] * len(entities)) val_placeholders = ", ".join(["(?, ?, ?)"] * len(entities))
params = [] params = []
for e in entities: for e in entities:

View File

@ -77,7 +77,7 @@ async def get_permission(
perm: str, perm: str,
event: Event, event: Event,
): ):
data = await pm.check_has_permission_info(ec, perm) data = await pm.get_permission_info(ec, perm)
obj_s = f"{ec[0].platform}.{ec[0].entity_type}.{ec[0].external_id}" obj_s = f"{ec[0].platform}.{ec[0].entity_type}.{ec[0].external_id}"