Compare commits

..

9 Commits

7 changed files with 80 additions and 85 deletions

View File

@ -5,6 +5,8 @@ ENV VIRTUAL_ENV=/app/.venv \
PLAYWRIGHT_BROWSERS_PATH=/usr/lib/pw-browsers PLAYWRIGHT_BROWSERS_PATH=/usr/lib/pw-browsers
# 安装所有都需要的底层依赖 # 安装所有都需要的底层依赖
#
# xz-utils: 解压需要它
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y --no-install-recommends \ apt-get install -y --no-install-recommends \
libfontconfig1 libgl1 libegl1 libglvnd0 mesa-vulkan-drivers at-spi2-common fontconfig \ libfontconfig1 libgl1 libegl1 libglvnd0 mesa-vulkan-drivers at-spi2-common fontconfig \
@ -16,6 +18,7 @@ RUN apt-get update && \
libatk-bridge2.0-0t64 libatspi2.0-0t64 libxcomposite1 libxdamage1 libxfixes3 \ libatk-bridge2.0-0t64 libatspi2.0-0t64 libxcomposite1 libxdamage1 libxfixes3 \
libxkbcommon0 libasound2t64 libnss3 fonts-noto-cjk fonts-noto-cjk-extra \ libxkbcommon0 libasound2t64 libnss3 fonts-noto-cjk fonts-noto-cjk-extra \
fonts-noto-color-emoji \ fonts-noto-color-emoji \
xz-utils \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*

View File

@ -50,7 +50,14 @@ class ArtifactDepends:
tasks = set() tasks = set()
for f in self.callbacks: for f in self.callbacks:
tasks.add(f(downloaded)) tasks.add(f(downloaded))
return await asyncio.gather(*tasks, return_exceptions=True) result = await asyncio.gather(*tasks, return_exceptions=True)
for r in result:
if isinstance(r, BaseException):
logger.warning("完成了二进制文件的下载,但是有未捕捉的错误")
logger.exception(r)
return result
class Config(BaseModel): class Config(BaseModel):
@ -73,12 +80,7 @@ async def _():
async def _task(artifact: ArtifactDepends): async def _task(artifact: ArtifactDepends):
async with semaphore: async with semaphore:
downloaded = await ensure_artifact(artifact) await ensure_artifact(artifact)
result = await artifact._finished(downloaded)
for r in result:
if isinstance(r, BaseException):
logger.warning("完成了二进制文件的下载,但是有未捕捉的错误")
logger.exception(r)
tasks: set[asyncio.Task] = set() tasks: set[asyncio.Task] = set()
for a in artifact_list: for a in artifact_list:
@ -116,9 +118,16 @@ async def download_artifact(artifact: ArtifactDepends):
f"下载到的二进制的 sha256 与需求不同 TARGET={artifact.target} REQUESTED={artifact.sha256} ACTUAL={m.hexdigest()}" f"下载到的二进制的 sha256 与需求不同 TARGET={artifact.target} REQUESTED={artifact.sha256} ACTUAL={m.hexdigest()}"
) )
await artifact._finished(True)
async def ensure_artifact(artifact: ArtifactDepends) -> bool: async def ensure_artifact(artifact: ArtifactDepends) -> bool:
"""
确保所需的二进制存在。返回是否下载了这个二进制文件。
"""
if not artifact.is_corresponding_platform(): if not artifact.is_corresponding_platform():
logger.debug(f"所需求的平台不是当前平台,跳过二进制下载 artifact={artifact}")
return False return False
if not artifact.target.exists(): if not artifact.target.exists():
@ -136,6 +145,7 @@ async def ensure_artifact(artifact: ArtifactDepends) -> bool:
artifact.target.unlink() artifact.target.unlink()
await download_artifact(artifact) await download_artifact(artifact)
return True return True
await artifact._finished(False)
return False return False

View File

@ -1,6 +1,7 @@
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import os import os
import asyncio import asyncio
from loguru import logger
import sqlparse import sqlparse
from pathlib import Path from pathlib import Path
from typing import List, Dict, Any, Optional, Union, TYPE_CHECKING from typing import List, Dict, Any, Optional, Union, TYPE_CHECKING
@ -10,10 +11,20 @@ import aiosqlite
if TYPE_CHECKING: if TYPE_CHECKING:
from . import DatabaseManager from . import DatabaseManager
# 全局数据库管理器实例
_global_db_manager: Optional["DatabaseManager"] = None _global_db_manager: Optional["DatabaseManager"] = None
async def try_close_connection(conn: aiosqlite.Connection) -> bool:
try:
await conn.close()
return True
except Exception as e:
logger.error("有的连接关闭失败了")
logger.exception(e)
return False
def get_global_db_manager() -> "DatabaseManager": def get_global_db_manager() -> "DatabaseManager":
"""获取全局数据库管理器实例""" """获取全局数据库管理器实例"""
global _global_db_manager global _global_db_manager
@ -24,16 +35,10 @@ def get_global_db_manager() -> "DatabaseManager":
return _global_db_manager return _global_db_manager
def close_global_db_manager() -> None:
"""关闭全局数据库管理器实例"""
global _global_db_manager
if _global_db_manager is not None:
# 注意这个函数应该在async环境中调用close_all_connections
_global_db_manager = None
class DatabaseManager: class DatabaseManager:
"""异步数据库管理器""" """
异步数据库管理器
"""
def __init__(self, db_path: Optional[Union[str, Path]] = None, pool_size: int = 5): def __init__(self, db_path: Optional[Union[str, Path]] = None, pool_size: int = 5):
""" """
@ -56,6 +61,7 @@ class DatabaseManager:
async def _get_connection(self) -> aiosqlite.Connection: async def _get_connection(self) -> aiosqlite.Connection:
"""从连接池获取连接""" """从连接池获取连接"""
async with self._lock: async with self._lock:
# 尝试从池中获取现有连接 # 尝试从池中获取现有连接
while self._connection_pool: while self._connection_pool:
@ -67,10 +73,7 @@ class DatabaseManager:
return conn return conn
except: except:
# 连接已失效,关闭它 # 连接已失效,关闭它
try: await try_close_connection(conn)
await conn.close()
except:
pass
# 如果连接池为空,创建新连接 # 如果连接池为空,创建新连接
conn = await aiosqlite.connect(self.db_path) conn = await aiosqlite.connect(self.db_path)
@ -86,16 +89,31 @@ class DatabaseManager:
self._connection_pool.append(conn) self._connection_pool.append(conn)
else: else:
# 池已满,直接关闭连接 # 池已满,直接关闭连接
try: await try_close_connection(conn)
await conn.close()
except:
pass
@asynccontextmanager @asynccontextmanager
async def get_conn(self): async def get_conn(self):
"""
从 db 中获取一个 Connection
"""
conn = await self._get_connection() conn = await self._get_connection()
yield conn
await self._return_connection(conn) try:
yield conn
# 只有当一切正常时才归还数据库连接
await self._return_connection(conn)
except Exception as e:
logger.error("有模块使用一个连接时出现了错误")
logger.exception(e)
try:
await conn.rollback()
await conn.close()
except Exception as e:
logger.error("在 Rollback 和关闭时也出现了问题")
logger.exception(e)
async def query( async def query(
self, query: str, params: Optional[tuple] = None self, query: str, params: Optional[tuple] = None
@ -190,42 +208,14 @@ class DatabaseManager:
else: else:
await self.execute_script(script) await self.execute_script(script)
async def execute_many(self, command: str, seq_of_params: List[tuple]) -> None:
"""执行多条非查询语句"""
conn = await self._get_connection()
try:
await conn.executemany(command, seq_of_params)
await conn.commit()
except Exception as e:
await conn.rollback()
raise Exception(f"数据库批量执行失败: {str(e)}") from e
finally:
await self._return_connection(conn)
async def execute_many_values_by_sql_file(
self, file_path: Union[str, Path], seq_of_params: List[tuple]
) -> None:
"""从 SQL 文件中读取一条语句,但是被不同值同时执行"""
path = str(file_path) if isinstance(file_path, Path) else file_path
with open(path, "r", encoding="utf-8") as f:
command = f.read()
await self.execute_many(command, seq_of_params)
async def close_all_connections(self) -> None: async def close_all_connections(self) -> None:
"""关闭所有连接""" """关闭所有连接"""
async with self._lock: async with self._lock:
# 关闭池中的连接
for conn in self._connection_pool: for conn in self._connection_pool:
try: await try_close_connection(conn)
await conn.close()
except:
pass
self._connection_pool.clear() self._connection_pool.clear()
# 关闭正在使用的连接
for conn in self._in_use.copy(): for conn in self._in_use.copy():
try: await try_close_connection(conn)
await conn.close()
except:
pass
self._in_use.clear() self._in_use.clear()

View File

@ -32,7 +32,7 @@ class PermManager:
def __init__(self, db: DatabaseManager) -> None: def __init__(self, db: DatabaseManager) -> None:
self.db = db self.db = db
async def check_has_permission_info(self, entities: _EntityLike, key: str): async def get_permission_info(self, entities: _EntityLike, key: str):
entities = await _to_entity_chain(entities) entities = await _to_entity_chain(entities)
key = key.removesuffix("*").removesuffix(".") key = key.removesuffix("*").removesuffix(".")
key_split = key.split(".") key_split = key.split(".")
@ -52,7 +52,7 @@ class PermManager:
return None return None
async def check_has_permission(self, entities: _EntityLike, key: str) -> bool: async def check_has_permission(self, entities: _EntityLike, key: str) -> bool:
res = await self.check_has_permission_info(entities, key) res = await self.get_permission_info(entities, key)
if res is None: if res is None:
return False return False
return res[2] return res[2]

View File

@ -43,15 +43,12 @@ class PermRepo:
Raises: Raises:
AssertionError: 如果创建后无法获取实体 ID。 AssertionError: 如果创建后无法获取实体 ID。
""" """
try: await self.conn.execute(
await self.conn.execute( s("create_entity.sql"),
s("create_entity.sql"), (entity.platform, entity.entity_type, entity.external_id),
(entity.platform, entity.entity_type, entity.external_id), )
) await self.conn.commit()
await self.conn.commit()
except Exception:
await self.conn.rollback()
raise
eid = await self._get_entity_id_or_none(entity) eid = await self._get_entity_id_or_none(entity)
assert eid is not None assert eid is not None
return eid return eid
@ -119,12 +116,8 @@ class PermRepo:
value: 要设置的配置值True/False/None value: 要设置的配置值True/False/None
""" """
eid = await self.get_entity_id(entity) eid = await self.get_entity_id(entity)
try: await self.conn.execute(s("update_perm_info.sql"), (eid, config_key, value))
await self.conn.execute(s("update_perm_info.sql"), (eid, config_key, value)) await self.conn.commit()
await self.conn.commit()
except Exception:
await self.conn.rollback()
raise
async def get_entity_id_batch( async def get_entity_id_batch(
self, entities: list[PermEntity] self, entities: list[PermEntity]
@ -143,15 +136,11 @@ class PermRepo:
# s("create_entity.sql"), # s("create_entity.sql"),
# (entity.platform, entity.entity_type, entity.external_id), # (entity.platform, entity.entity_type, entity.external_id),
# ) # )
try: await self.conn.executemany(
await self.conn.executemany( s("create_entity.sql"),
s("create_entity.sql"), [(e.platform, e.entity_type, e.external_id) for e in entities],
[(e.platform, e.entity_type, e.external_id) for e in entities], )
) await self.conn.commit()
await self.conn.commit()
except Exception:
await self.conn.rollback()
raise
val_placeholders = ", ".join(["(?, ?, ?)"] * len(entities)) val_placeholders = ", ".join(["(?, ?, ?)"] * len(entities))
params = [] params = []
for e in entities: for e in entities:

View File

@ -77,7 +77,7 @@ async def get_permission(
perm: str, perm: str,
event: Event, event: Event,
): ):
data = await pm.check_has_permission_info(ec, perm) data = await pm.get_permission_info(ec, perm)
obj_s = f"{ec[0].platform}.{ec[0].entity_type}.{ec[0].external_id}" obj_s = f"{ec[0].platform}.{ec[0].entity_type}.{ec[0].external_id}"

View File

@ -41,6 +41,7 @@ bin_path: Path | None = None
@arti_typst_linux.on_finished @arti_typst_linux.on_finished
async def _(downloaded: bool): async def _(downloaded: bool):
logger.debug("安装好了 Linux 版本的 Typst")
global bin_path global bin_path
tar_path = arti_typst_linux.target tar_path = arti_typst_linux.target
@ -71,6 +72,7 @@ async def _(downloaded: bool):
@arti_typst_windows.on_finished @arti_typst_windows.on_finished
async def _(downloaded: bool): async def _(downloaded: bool):
logger.debug("安装好了 Windows 版本的 Typst")
global bin_path global bin_path
zip_path = arti_typst_windows.target zip_path = arti_typst_windows.target
bin_path = BINARY_PATH / "typst.exe" bin_path = BINARY_PATH / "typst.exe"
@ -160,6 +162,7 @@ async def _(
# 对于本地机器,一般不会在应用启动时自动下载,这里再保证存在 # 对于本地机器,一般不会在应用启动时自动下载,这里再保证存在
await ensure_artifact(arti_typst_linux) await ensure_artifact(arti_typst_linux)
await ensure_artifact(arti_typst_windows) await ensure_artifact(arti_typst_windows)
if bin_path is None or not bin_path.exists(): if bin_path is None or not bin_path.exists():
logger.warning("当前环境不存在 Typst但仍然调用了") logger.warning("当前环境不存在 Typst但仍然调用了")
return return