forked from mttu-developers/konabot
Merge branch 'master' into marchtoy_gl
This commit is contained in:
@ -5,6 +5,8 @@ ENV VIRTUAL_ENV=/app/.venv \
|
||||
PLAYWRIGHT_BROWSERS_PATH=/usr/lib/pw-browsers
|
||||
|
||||
# 安装所有都需要的底层依赖
|
||||
#
|
||||
# xz-utils: 解压需要它
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
libfontconfig1 libgl1 libegl1 libglvnd0 mesa-vulkan-drivers at-spi2-common fontconfig \
|
||||
@ -16,6 +18,7 @@ RUN apt-get update && \
|
||||
libatk-bridge2.0-0t64 libatspi2.0-0t64 libxcomposite1 libxdamage1 libxfixes3 \
|
||||
libxkbcommon0 libasound2t64 libnss3 fonts-noto-cjk fonts-noto-cjk-extra \
|
||||
fonts-noto-color-emoji \
|
||||
xz-utils \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
|
||||
@ -50,7 +50,14 @@ class ArtifactDepends:
|
||||
tasks = set()
|
||||
for f in self.callbacks:
|
||||
tasks.add(f(downloaded))
|
||||
return await asyncio.gather(*tasks, return_exceptions=True)
|
||||
result = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for r in result:
|
||||
if isinstance(r, BaseException):
|
||||
logger.warning("完成了二进制文件的下载,但是有未捕捉的错误")
|
||||
logger.exception(r)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Config(BaseModel):
|
||||
@ -73,12 +80,7 @@ async def _():
|
||||
|
||||
async def _task(artifact: ArtifactDepends):
|
||||
async with semaphore:
|
||||
downloaded = await ensure_artifact(artifact)
|
||||
result = await artifact._finished(downloaded)
|
||||
for r in result:
|
||||
if isinstance(r, BaseException):
|
||||
logger.warning("完成了二进制文件的下载,但是有未捕捉的错误")
|
||||
logger.exception(r)
|
||||
await ensure_artifact(artifact)
|
||||
|
||||
tasks: set[asyncio.Task] = set()
|
||||
for a in artifact_list:
|
||||
@ -116,9 +118,16 @@ async def download_artifact(artifact: ArtifactDepends):
|
||||
f"下载到的二进制的 sha256 与需求不同 TARGET={artifact.target} REQUESTED={artifact.sha256} ACTUAL={m.hexdigest()}"
|
||||
)
|
||||
|
||||
await artifact._finished(True)
|
||||
|
||||
|
||||
async def ensure_artifact(artifact: ArtifactDepends) -> bool:
|
||||
"""
|
||||
确保所需的二进制存在。返回是否下载了这个二进制文件。
|
||||
"""
|
||||
|
||||
if not artifact.is_corresponding_platform():
|
||||
logger.debug(f"所需求的平台不是当前平台,跳过二进制下载 artifact={artifact}")
|
||||
return False
|
||||
|
||||
if not artifact.target.exists():
|
||||
@ -136,6 +145,7 @@ async def ensure_artifact(artifact: ArtifactDepends) -> bool:
|
||||
artifact.target.unlink()
|
||||
await download_artifact(artifact)
|
||||
return True
|
||||
await artifact._finished(False)
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from contextlib import asynccontextmanager
|
||||
import os
|
||||
import asyncio
|
||||
from loguru import logger
|
||||
import sqlparse
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Union, TYPE_CHECKING
|
||||
@ -10,10 +11,20 @@ import aiosqlite
|
||||
if TYPE_CHECKING:
|
||||
from . import DatabaseManager
|
||||
|
||||
# 全局数据库管理器实例
|
||||
_global_db_manager: Optional["DatabaseManager"] = None
|
||||
|
||||
|
||||
async def try_close_connection(conn: aiosqlite.Connection) -> bool:
|
||||
try:
|
||||
await conn.close()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error("有的连接关闭失败了")
|
||||
logger.exception(e)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_global_db_manager() -> "DatabaseManager":
|
||||
"""获取全局数据库管理器实例"""
|
||||
global _global_db_manager
|
||||
@ -24,16 +35,10 @@ def get_global_db_manager() -> "DatabaseManager":
|
||||
return _global_db_manager
|
||||
|
||||
|
||||
def close_global_db_manager() -> None:
|
||||
"""关闭全局数据库管理器实例"""
|
||||
global _global_db_manager
|
||||
if _global_db_manager is not None:
|
||||
# 注意:这个函数应该在async环境中调用close_all_connections
|
||||
_global_db_manager = None
|
||||
|
||||
|
||||
class DatabaseManager:
|
||||
"""异步数据库管理器"""
|
||||
"""
|
||||
异步数据库管理器
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: Optional[Union[str, Path]] = None, pool_size: int = 5):
|
||||
"""
|
||||
@ -56,6 +61,7 @@ class DatabaseManager:
|
||||
|
||||
async def _get_connection(self) -> aiosqlite.Connection:
|
||||
"""从连接池获取连接"""
|
||||
|
||||
async with self._lock:
|
||||
# 尝试从池中获取现有连接
|
||||
while self._connection_pool:
|
||||
@ -67,10 +73,7 @@ class DatabaseManager:
|
||||
return conn
|
||||
except:
|
||||
# 连接已失效,关闭它
|
||||
try:
|
||||
await conn.close()
|
||||
except:
|
||||
pass
|
||||
await try_close_connection(conn)
|
||||
|
||||
# 如果连接池为空,创建新连接
|
||||
conn = await aiosqlite.connect(self.db_path)
|
||||
@ -86,16 +89,31 @@ class DatabaseManager:
|
||||
self._connection_pool.append(conn)
|
||||
else:
|
||||
# 池已满,直接关闭连接
|
||||
try:
|
||||
await conn.close()
|
||||
except:
|
||||
pass
|
||||
await try_close_connection(conn)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_conn(self):
|
||||
"""
|
||||
从 db 中获取一个 Connection
|
||||
"""
|
||||
|
||||
conn = await self._get_connection()
|
||||
yield conn
|
||||
await self._return_connection(conn)
|
||||
|
||||
try:
|
||||
yield conn
|
||||
|
||||
# 只有当一切正常时才归还数据库连接
|
||||
await self._return_connection(conn)
|
||||
except Exception as e:
|
||||
logger.error("有模块使用一个连接时出现了错误")
|
||||
logger.exception(e)
|
||||
|
||||
try:
|
||||
await conn.rollback()
|
||||
await conn.close()
|
||||
except Exception as e:
|
||||
logger.error("在 Rollback 和关闭时也出现了问题")
|
||||
logger.exception(e)
|
||||
|
||||
async def query(
|
||||
self, query: str, params: Optional[tuple] = None
|
||||
@ -190,42 +208,14 @@ class DatabaseManager:
|
||||
else:
|
||||
await self.execute_script(script)
|
||||
|
||||
async def execute_many(self, command: str, seq_of_params: List[tuple]) -> None:
|
||||
"""执行多条非查询语句"""
|
||||
conn = await self._get_connection()
|
||||
try:
|
||||
await conn.executemany(command, seq_of_params)
|
||||
await conn.commit()
|
||||
except Exception as e:
|
||||
await conn.rollback()
|
||||
raise Exception(f"数据库批量执行失败: {str(e)}") from e
|
||||
finally:
|
||||
await self._return_connection(conn)
|
||||
|
||||
async def execute_many_values_by_sql_file(
|
||||
self, file_path: Union[str, Path], seq_of_params: List[tuple]
|
||||
) -> None:
|
||||
"""从 SQL 文件中读取一条语句,但是被不同值同时执行"""
|
||||
path = str(file_path) if isinstance(file_path, Path) else file_path
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
command = f.read()
|
||||
await self.execute_many(command, seq_of_params)
|
||||
|
||||
async def close_all_connections(self) -> None:
|
||||
"""关闭所有连接"""
|
||||
|
||||
async with self._lock:
|
||||
# 关闭池中的连接
|
||||
for conn in self._connection_pool:
|
||||
try:
|
||||
await conn.close()
|
||||
except:
|
||||
pass
|
||||
await try_close_connection(conn)
|
||||
self._connection_pool.clear()
|
||||
|
||||
# 关闭正在使用的连接
|
||||
for conn in self._in_use.copy():
|
||||
try:
|
||||
await conn.close()
|
||||
except:
|
||||
pass
|
||||
await try_close_connection(conn)
|
||||
self._in_use.clear()
|
||||
|
||||
@ -32,7 +32,7 @@ class PermManager:
|
||||
def __init__(self, db: DatabaseManager) -> None:
|
||||
self.db = db
|
||||
|
||||
async def check_has_permission_info(self, entities: _EntityLike, key: str):
|
||||
async def get_permission_info(self, entities: _EntityLike, key: str):
|
||||
entities = await _to_entity_chain(entities)
|
||||
key = key.removesuffix("*").removesuffix(".")
|
||||
key_split = key.split(".")
|
||||
@ -52,7 +52,7 @@ class PermManager:
|
||||
return None
|
||||
|
||||
async def check_has_permission(self, entities: _EntityLike, key: str) -> bool:
|
||||
res = await self.check_has_permission_info(entities, key)
|
||||
res = await self.get_permission_info(entities, key)
|
||||
if res is None:
|
||||
return False
|
||||
return res[2]
|
||||
|
||||
@ -43,15 +43,12 @@ class PermRepo:
|
||||
Raises:
|
||||
AssertionError: 如果创建后无法获取实体 ID。
|
||||
"""
|
||||
try:
|
||||
await self.conn.execute(
|
||||
s("create_entity.sql"),
|
||||
(entity.platform, entity.entity_type, entity.external_id),
|
||||
)
|
||||
await self.conn.commit()
|
||||
except Exception:
|
||||
await self.conn.rollback()
|
||||
raise
|
||||
await self.conn.execute(
|
||||
s("create_entity.sql"),
|
||||
(entity.platform, entity.entity_type, entity.external_id),
|
||||
)
|
||||
await self.conn.commit()
|
||||
|
||||
eid = await self._get_entity_id_or_none(entity)
|
||||
assert eid is not None
|
||||
return eid
|
||||
@ -119,12 +116,8 @@ class PermRepo:
|
||||
value: 要设置的配置值(True/False/None)。
|
||||
"""
|
||||
eid = await self.get_entity_id(entity)
|
||||
try:
|
||||
await self.conn.execute(s("update_perm_info.sql"), (eid, config_key, value))
|
||||
await self.conn.commit()
|
||||
except Exception:
|
||||
await self.conn.rollback()
|
||||
raise
|
||||
await self.conn.execute(s("update_perm_info.sql"), (eid, config_key, value))
|
||||
await self.conn.commit()
|
||||
|
||||
async def get_entity_id_batch(
|
||||
self, entities: list[PermEntity]
|
||||
@ -143,15 +136,11 @@ class PermRepo:
|
||||
# s("create_entity.sql"),
|
||||
# (entity.platform, entity.entity_type, entity.external_id),
|
||||
# )
|
||||
try:
|
||||
await self.conn.executemany(
|
||||
s("create_entity.sql"),
|
||||
[(e.platform, e.entity_type, e.external_id) for e in entities],
|
||||
)
|
||||
await self.conn.commit()
|
||||
except Exception:
|
||||
await self.conn.rollback()
|
||||
raise
|
||||
await self.conn.executemany(
|
||||
s("create_entity.sql"),
|
||||
[(e.platform, e.entity_type, e.external_id) for e in entities],
|
||||
)
|
||||
await self.conn.commit()
|
||||
val_placeholders = ", ".join(["(?, ?, ?)"] * len(entities))
|
||||
params = []
|
||||
for e in entities:
|
||||
|
||||
@ -77,7 +77,7 @@ async def get_permission(
|
||||
perm: str,
|
||||
event: Event,
|
||||
):
|
||||
data = await pm.check_has_permission_info(ec, perm)
|
||||
data = await pm.get_permission_info(ec, perm)
|
||||
|
||||
obj_s = f"{ec[0].platform}.{ec[0].entity_type}.{ec[0].external_id}"
|
||||
|
||||
|
||||
@ -41,6 +41,7 @@ bin_path: Path | None = None
|
||||
|
||||
@arti_typst_linux.on_finished
|
||||
async def _(downloaded: bool):
|
||||
logger.debug("安装好了 Linux 版本的 Typst")
|
||||
global bin_path
|
||||
|
||||
tar_path = arti_typst_linux.target
|
||||
@ -71,6 +72,7 @@ async def _(downloaded: bool):
|
||||
|
||||
@arti_typst_windows.on_finished
|
||||
async def _(downloaded: bool):
|
||||
logger.debug("安装好了 Windows 版本的 Typst")
|
||||
global bin_path
|
||||
zip_path = arti_typst_windows.target
|
||||
bin_path = BINARY_PATH / "typst.exe"
|
||||
@ -160,6 +162,7 @@ async def _(
|
||||
# 对于本地机器,一般不会在应用启动时自动下载,这里再保证存在
|
||||
await ensure_artifact(arti_typst_linux)
|
||||
await ensure_artifact(arti_typst_windows)
|
||||
|
||||
if bin_path is None or not bin_path.exists():
|
||||
logger.warning("当前环境不存在 Typst,但仍然调用了")
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user