255 lines
8.0 KiB
Python
255 lines
8.0 KiB
Python
from dataclasses import dataclass
|
||
import math
|
||
from pathlib import Path
|
||
|
||
import aiosqlite
|
||
|
||
from konabot.common.pager import PagerQuery, PagerResult
|
||
|
||
from .entity import PermEntity
|
||
|
||
|
||
def s(p: str):
    """Read an SQL file from the ``sql/`` directory next to this module.

    Args:
        p: File name, relative to the ``sql/`` subdirectory of the directory
            that contains this source file.

    Returns:
        The contents of the SQL file as a string.
    """
    # Decode explicitly as UTF-8: ``read_text()`` without an encoding uses the
    # locale's preferred encoding, which varies by platform (e.g. Windows) and
    # can garble or reject non-ASCII SQL. Assumes the .sql files are saved as
    # UTF-8 — TODO confirm for the files under sql/.
    return (Path(__file__).parent / "sql" / p).read_text(encoding="utf-8")
|
||
|
||
|
||
@dataclass
class PermRepo:
    """Repository for permission entities and their settings, backed by SQLite.

    Attributes:
        conn: The aiosqlite connection every query in this class runs on.
    """

    conn: aiosqlite.Connection

    async def create_entity(self, entity: PermEntity) -> int:
        """Store a permission entity and return its database ID.

        Executes ``sql/create_entity.sql`` and commits; if the statement or the
        commit fails, the transaction is rolled back and the error re-raised.
        The ID is then read back with ``sql/get_entity_id.sql``.

        Args:
            entity: The permission entity to create.

        Returns:
            The database ID of the entity.

        Raises:
            AssertionError: If the entity ID cannot be found after the insert.
        """
        try:
            await self.conn.execute(
                s("create_entity.sql"),
                (entity.platform, entity.entity_type, entity.external_id),
            )
            await self.conn.commit()
        except Exception:
            await self.conn.rollback()
            raise
        eid = await self._get_entity_id_or_none(entity)
        # An explicit raise (instead of ``assert``) survives ``python -O``;
        # the exception type is kept so existing callers behave the same.
        if eid is None:
            raise AssertionError(f"entity id not found after creating {entity!r}")
        return eid

    async def _get_entity_id_or_none(self, entity: PermEntity) -> int | None:
        """Return the database ID of ``entity``, or ``None`` if it is not stored.

        Args:
            entity: The permission entity to look up.

        Returns:
            The first column of the row returned by ``sql/get_entity_id.sql``,
            or ``None`` when that query returns no row.
        """
        res = await self.conn.execute(
            s("get_entity_id.sql"),
            (entity.platform, entity.entity_type, entity.external_id),
        )
        row = await res.fetchone()
        if row is None:
            return None
        return row[0]

    async def get_entity_id(self, entity: PermEntity) -> int:
        """Return the database ID of ``entity``, creating the entity if needed.

        Args:
            entity: The permission entity.

        Returns:
            The database ID of the entity.
        """
        eid = await self._get_entity_id_or_none(entity)
        if eid is None:
            return await self.create_entity(entity)
        return eid

    async def get_perm_info(self, entity: PermEntity, config_key: str) -> bool | None:
        """Read one permission setting of an entity.

        The entity is resolved through ``get_entity_id``, so an entity that is
        not stored yet is created as a side effect.

        Args:
            entity: The permission entity.
            config_key: Key of the setting to read.

        Returns:
            The stored value as ``True``/``False``, or ``None`` when no value
            is set (no row, or a row whose value is NULL).
        """
        eid = await self.get_entity_id(entity)
        res = await self.conn.execute(
            s("get_perm_info.sql"),
            (eid, config_key),
        )
        row = await res.fetchone()
        # A NULL value means "unset" (the batch readers below filter it out
        # with ``value IS NOT NULL``); report it as None rather than letting
        # ``bool(None)`` turn it into False.
        if row is None or row[0] is None:
            return None
        return bool(row[0])

    async def update_perm_info(
        self, entity: PermEntity, config_key: str, value: bool | None
    ):
        """Write one permission setting of an entity.

        The entity is resolved (and created if missing) through
        ``get_entity_id``. The statement in ``sql/update_perm_info.sql`` is
        executed and committed; on failure the transaction is rolled back and
        the error re-raised.

        Args:
            entity: The permission entity.
            config_key: Key of the setting to write.
            value: New value (True/False/None). ``None`` is presumably stored
                as NULL, i.e. "unset" — confirm against update_perm_info.sql.
        """
        eid = await self.get_entity_id(entity)
        try:
            await self.conn.execute(s("update_perm_info.sql"), (eid, config_key, value))
            await self.conn.commit()
        except Exception:
            await self.conn.rollback()
            raise

    async def get_entity_id_batch(
        self, entities: list[PermEntity]
    ) -> dict[PermEntity, int]:
        """Ensure all ``entities`` are stored and return their database IDs.

        ``sql/create_entity.sql`` is executed once per entity (this presumes
        the statement ignores already-existing entities — verify against the
        SQL file) and committed, rolling back on failure. The IDs are then read
        with a single row-value ``IN (VALUES ...)`` query.

        Args:
            entities: The permission entities to resolve.

        Returns:
            A dict mapping each stored entity to its ID. Keys are rebuilt from
            the stored (platform, entity_type, external_id) columns, so
            duplicates in ``entities`` collapse into one entry.
        """
        if not entities:
            # ``IN (VALUES )`` with no rows is a SQL syntax error, and there is
            # nothing to create or look up anyway.
            return {}

        key_rows = [(e.platform, e.entity_type, e.external_id) for e in entities]
        try:
            await self.conn.executemany(s("create_entity.sql"), key_rows)
            await self.conn.commit()
        except Exception:
            await self.conn.rollback()
            raise

        val_placeholders = ", ".join(["(?, ?, ?)"] * len(key_rows))
        params = [value for row in key_rows for value in row]
        cursor = await self.conn.execute(
            f"""
            SELECT id, platform, entity_type, external_id
            FROM perm_entity
            WHERE (platform, entity_type, external_id) IN (VALUES {val_placeholders});
            """,
            params,
        )
        rows = await cursor.fetchall()
        # Assumes PermEntity's positional order is
        # (platform, entity_type, external_id) — matches the attributes above.
        return {PermEntity(row[1], row[2], row[3]): row[0] for row in rows}

    async def get_perm_info_batch(
        self, entities: list[PermEntity], config_keys: list[str]
    ) -> dict[tuple[PermEntity, str], bool]:
        """Read several permission settings for several entities at once.

        Entities are resolved (and created if missing) via
        ``get_entity_id_batch``.

        Args:
            entities: The permission entities to query.
            config_keys: The setting keys to query.

        Returns:
            A dict keyed by ``(entity, config_key)`` with boolean values.
            Settings whose value is NULL or that do not exist are omitted.
        """
        id_to_entity = {
            eid: entity
            for entity, eid in (await self.get_entity_id_batch(entities)).items()
        }
        if not id_to_entity or not config_keys:
            # Nothing can match; skip the query instead of relying on SQLite's
            # non-standard acceptance of an empty ``IN ()`` list.
            return {}

        id_placeholders = ", ".join("?" * len(id_to_entity))
        key_placeholders = ", ".join("?" * len(config_keys))
        sql = f"""
        SELECT entity_id, config_key, value
        FROM perm_info
        WHERE entity_id IN ({id_placeholders})
        AND config_key IN ({key_placeholders})
        AND value IS NOT NULL;
        """
        params = tuple(id_to_entity) + tuple(config_keys)
        cursor = await self.conn.execute(sql, params)
        rows = await cursor.fetchall()
        return {(id_to_entity[row[0]], row[1]): bool(row[2]) for row in rows}

    async def list_perm_info_batch(
        self, entities: list[PermEntity], pager: PagerQuery
    ) -> PagerResult[tuple[PermEntity, str, bool]]:
        """List the non-NULL permission settings of several entities, paginated.

        Rows are ordered by the position of their entity in ``entities``
        (first occurrence wins), then by ``config_key`` ascending. Entities
        are resolved (and created if missing) via ``get_entity_id_batch``.

        Args:
            entities: The permission entities whose settings are listed.
            pager: Pagination request; ``page_index`` is treated as 1-based
                (offset = ``(page_index - 1) * page_size``).

        Returns:
            A ``PagerResult`` whose ``data`` holds ``(entity, config_key,
            value)`` tuples with boolean values, and whose ``page_count`` is
            ``ceil(total_rows / page_size)``.
        """
        entity_to_id = await self.get_entity_id_batch(entities)
        id_to_entity = {eid: entity for entity, eid in entity_to_id.items()}
        # Preserve the caller's order while dropping duplicate ids; duplicates
        # would only add redundant ``WHEN`` branches and placeholders.
        ordered_ids = list(
            dict.fromkeys(entity_to_id[e] for e in entities if e in entity_to_id)
        )

        if not ordered_ids:
            # ``CASE entity_id END`` without any WHEN branch is a syntax error;
            # with no entities there are no rows, i.e. zero pages.
            return PagerResult(
                data=[],
                success=True,
                message="",
                page_count=0,
                query=pager,
            )

        placeholders = ", ".join("?" * len(ordered_ids))
        order_by_cases = " ".join(f"WHEN ? THEN {i}" for i in range(len(ordered_ids)))

        pagecount_sql = (
            f"SELECT COUNT(*) FROM perm_info "
            f"WHERE entity_id IN ({placeholders}) AND value IS NOT NULL;"
        )
        count_cursor = await self.conn.execute(pagecount_sql, tuple(ordered_ids))
        total_count = (await count_cursor.fetchone() or (0,))[0]

        sql = f"""
        SELECT entity_id, config_key, value
        FROM perm_info
        WHERE entity_id IN ({placeholders})
        AND value IS NOT NULL
        ORDER BY
            (CASE entity_id {order_by_cases} END) ASC,
            config_key ASC
        LIMIT ?
        OFFSET ?;
        """
        # Parameter order mirrors the SQL text: IN list, CASE branches,
        # then LIMIT and OFFSET.
        params = (
            tuple(ordered_ids)
            + tuple(ordered_ids)
            + (
                pager.page_size,
                (pager.page_index - 1) * pager.page_size,
            )
        )
        cursor = await self.conn.execute(sql, params)
        rows = await cursor.fetchall()

        return PagerResult(
            # SQLite returns stored booleans as integers; convert to match the
            # annotated element type and the other readers in this class.
            data=[(id_to_entity[row[0]], row[1], bool(row[2])) for row in rows],
            success=True,
            message="",
            page_count=math.ceil(total_count / pager.page_size),
            query=pager,
        )
|