Files
konabot/konabot/common/permsys/repo.py
passthem 0ba51bc9b2
All checks were successful
continuous-integration/drone/push Build is passing
修复 rollback 失效问题
2026-04-18 10:53:53 +08:00

255 lines
8.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from dataclasses import dataclass
import math
from pathlib import Path
import aiosqlite
from konabot.common.pager import PagerQuery, PagerResult
from .entity import PermEntity
def s(p: str) -> str:
    """Read the contents of a SQL file.

    Args:
        p: SQL file name, relative to the ``sql/`` directory that sits next
            to this module.

    Returns:
        The text of the SQL file.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the host
    # locale's default encoding.
    return (Path(__file__).parent / "./sql/" / p).read_text(encoding="utf-8")
@dataclass
class PermRepo:
    """Repository that stores permission entities and their settings.

    Attributes:
        conn: The aiosqlite connection every query is executed on.
    """

    conn: aiosqlite.Connection

    async def create_entity(self, entity: PermEntity) -> int:
        """Insert a permission entity and return its database ID.

        Args:
            entity: The permission entity to create.

        Returns:
            The database ID of the entity.

        Raises:
            AssertionError: If the entity ID cannot be read back after the
                insert.
        """
        try:
            await self.conn.execute(
                s("create_entity.sql"),
                (entity.platform, entity.entity_type, entity.external_id),
            )
            await self.conn.commit()
        except Exception:
            # Undo the partial write before propagating the error.
            await self.conn.rollback()
            raise

        eid = await self._get_entity_id_or_none(entity)
        if eid is None:
            # Raised explicitly (not via ``assert``) so the check is still
            # enforced when Python runs with ``-O``.
            raise AssertionError("entity ID not found after creating the entity")
        return eid

    async def _get_entity_id_or_none(self, entity: PermEntity) -> int | None:
        """Look up an entity's ID without creating it.

        Args:
            entity: The permission entity to look up.

        Returns:
            The entity ID, or None when no matching row exists.
        """
        res = await self.conn.execute(
            s("get_entity_id.sql"),
            (entity.platform, entity.entity_type, entity.external_id),
        )
        row = await res.fetchone()
        if row is None:
            return None
        return row[0]

    async def get_entity_id(self, entity: PermEntity) -> int:
        """Return an entity's ID, creating the entity when it is missing.

        Args:
            entity: The permission entity.

        Returns:
            The database ID of the entity.
        """
        eid = await self._get_entity_id_or_none(entity)
        if eid is None:
            return await self.create_entity(entity)
        return eid

    async def get_perm_info(self, entity: PermEntity, config_key: str) -> bool | None:
        """Read one permission setting of an entity.

        The entity is resolved through ``get_entity_id``, so it is created
        as a side effect when it does not exist yet.

        Args:
            entity: The permission entity.
            config_key: Name of the setting.

        Returns:
            The stored value coerced to bool, or None when the query yields
            no row.
        """
        eid = await self.get_entity_id(entity)
        res = await self.conn.execute(
            s("get_perm_info.sql"),
            (eid, config_key),
        )
        row = await res.fetchone()
        if row is None:
            return None
        return bool(row[0])

    async def update_perm_info(
        self, entity: PermEntity, config_key: str, value: bool | None
    ):
        """Write one permission setting of an entity.

        Args:
            entity: The permission entity (created when missing).
            config_key: Name of the setting.
            value: The value to store: True, False or None.
        """
        eid = await self.get_entity_id(entity)
        try:
            await self.conn.execute(
                s("update_perm_info.sql"), (eid, config_key, value)
            )
            await self.conn.commit()
        except Exception:
            # Undo the partial write before propagating the error.
            await self.conn.rollback()
            raise

    async def get_entity_id_batch(
        self, entities: list[PermEntity]
    ) -> dict[PermEntity, int]:
        """Resolve the IDs of many entities at once.

        Every entity is first passed through ``create_entity.sql``.
        NOTE(review): this presumably relies on that statement ignoring rows
        that already exist; confirm against the SQL file.

        Args:
            entities: The entities to resolve.

        Returns:
            A mapping from each PermEntity to its ID. Empty when ``entities``
            is empty.
        """
        if not entities:
            # An empty ``IN (VALUES )`` list is a SQL syntax error, and there
            # is nothing to resolve anyway.
            return {}

        try:
            await self.conn.executemany(
                s("create_entity.sql"),
                [(e.platform, e.entity_type, e.external_id) for e in entities],
            )
            await self.conn.commit()
        except Exception:
            await self.conn.rollback()
            raise

        # One "(?, ?, ?)" row value per entity, with parameters flattened in
        # the same order.
        val_placeholders = ", ".join(["(?, ?, ?)"] * len(entities))
        params = []
        for e in entities:
            params.extend([e.platform, e.entity_type, e.external_id])

        cursor = await self.conn.execute(
            f"""
            SELECT id, platform, entity_type, external_id
            FROM perm_entity
            WHERE (platform, entity_type, external_id) IN (VALUES {val_placeholders});
            """,
            params,
        )
        rows = await cursor.fetchall()
        return {PermEntity(row[1], row[2], row[3]): row[0] for row in rows}

    async def get_perm_info_batch(
        self, entities: list[PermEntity], config_keys: list[str]
    ) -> dict[tuple[PermEntity, str], bool]:
        """Read several permission settings for several entities.

        Args:
            entities: The entities to query.
            config_keys: The setting names to query.

        Returns:
            A mapping keyed by ``(PermEntity, config_key)`` whose values are
            bools. Settings whose stored value is NULL are left out.
        """
        entity_ids = {
            v: k for k, v in (await self.get_entity_id_batch(entities)).items()
        }
        if not entity_ids or not config_keys:
            # Nothing can match; skip building a query with an empty IN list.
            return {}

        placeholders1 = ", ".join("?" * len(entity_ids))
        placeholders2 = ", ".join("?" * len(config_keys))
        sql = f"""
            SELECT entity_id, config_key, value
            FROM perm_info
            WHERE entity_id IN ({placeholders1})
                AND config_key IN ({placeholders2})
                AND value IS NOT NULL;
        """
        params = tuple(entity_ids) + tuple(config_keys)
        cursor = await self.conn.execute(sql, params)
        rows = await cursor.fetchall()
        return {(entity_ids[row[0]], row[1]): bool(row[2]) for row in rows}

    async def list_perm_info_batch(
        self, entities: list[PermEntity], pager: PagerQuery
    ) -> PagerResult[tuple[PermEntity, str, bool]]:
        """List the permission settings of several entities, one page at a time.

        Rows are ordered by the position of their entity in ``entities`` and
        then by ``config_key``. Settings whose stored value is NULL are
        left out.

        Args:
            entities: The entities to list; their order defines the sort order.
            pager: The paging request (``page_index`` and ``page_size``).

        Returns:
            A PagerResult whose ``data`` is a list of
            ``(PermEntity, config_key, value)`` tuples for the requested page.
        """
        entity_to_id = await self.get_entity_id_batch(entities)
        id_to_entity = {v: k for k, v in entity_to_id.items()}
        # Keep caller order and drop duplicates: a repeated ID would only add
        # redundant bound parameters, since CASE picks the first matching WHEN.
        ordered_ids = list(
            dict.fromkeys(entity_to_id[e] for e in entities if e in entity_to_id)
        )

        if not ordered_ids:
            # ``CASE entity_id END`` without any WHEN branch is a SQL syntax
            # error, so answer the empty request directly.
            return PagerResult(
                data=[],
                success=True,
                message="",
                page_count=0,
                query=pager,
            )

        placeholders = ", ".join("?" * len(ordered_ids))
        order_by_cases = " ".join(
            f"WHEN ? THEN {i}" for i in range(len(ordered_ids))
        )

        pagecount_sql = f"SELECT COUNT(*) FROM perm_info WHERE entity_id IN ({placeholders}) AND value IS NOT NULL;"
        count_cursor = await self.conn.execute(pagecount_sql, tuple(ordered_ids))
        count_row = await count_cursor.fetchone()
        total_count = count_row[0] if count_row is not None else 0

        sql = f"""
            SELECT entity_id, config_key, value
            FROM perm_info
            WHERE entity_id IN ({placeholders})
                AND value IS NOT NULL
            ORDER BY
                (CASE entity_id {order_by_cases} END) ASC,
                config_key ASC
            LIMIT ?
            OFFSET ?;
        """
        # Parameter order follows the SQL text: IN list, CASE branches, then
        # LIMIT and OFFSET.
        params = (
            tuple(ordered_ids)
            + tuple(ordered_ids)
            + (
                pager.page_size,
                (pager.page_index - 1) * pager.page_size,
            )
        )
        cursor = await self.conn.execute(sql, params)
        rows = await cursor.fetchall()

        return PagerResult(
            # Coerce the stored value to bool, matching the declared return
            # type and get_perm_info_batch.
            data=[(id_to_entity[row[0]], row[1], bool(row[2])) for row in rows],
            success=True,
            message="",
            page_count=math.ceil(total_count / pager.page_size),
            query=pager,
        )