from dataclasses import dataclass
import math
from pathlib import Path

import aiosqlite

from konabot.common.pager import PagerQuery, PagerResult

from .entity import PermEntity


def s(p: str) -> str:
    """Read the contents of a SQL file.

    Args:
        p: SQL file name, relative to the ``sql/`` sub-directory that sits
            next to this module.

    Returns:
        The text of the SQL file.
    """
    # Read as UTF-8 explicitly so the result does not depend on the host
    # locale (the platform default is not UTF-8 everywhere).
    return (Path(__file__).parent / "./sql/" / p).read_text(encoding="utf-8")


@dataclass
class PermRepo:
    """Repository that persists permission entities and their settings.

    Attributes:
        conn: The aiosqlite connection every query is executed on.
    """

    conn: aiosqlite.Connection

    async def create_entity(self, entity: PermEntity) -> int:
        """Insert a permission entity and return its database ID.

        Args:
            entity: The permission entity to create.

        Returns:
            The database ID of the entity.

        Raises:
            AssertionError: If the entity ID cannot be read back after the
                insert.
        """
        try:
            await self.conn.execute(
                s("create_entity.sql"),
                (entity.platform, entity.entity_type, entity.external_id),
            )
            await self.conn.commit()
        except Exception:
            # Leave the connection in a clean state before propagating.
            await self.conn.rollback()
            raise
        eid = await self._get_entity_id_or_none(entity)
        assert eid is not None
        return eid

    async def _get_entity_id_or_none(self, entity: PermEntity) -> int | None:
        """Look up an entity ID without creating the entity.

        Args:
            entity: The permission entity to look up.

        Returns:
            The entity ID, or ``None`` when no matching row exists.
        """
        res = await self.conn.execute(
            s("get_entity_id.sql"),
            (entity.platform, entity.entity_type, entity.external_id),
        )
        row = await res.fetchone()
        if row is None:
            return None
        return row[0]

    async def get_entity_id(self, entity: PermEntity) -> int:
        """Return the entity ID, creating the entity when it does not exist.

        Args:
            entity: The permission entity.

        Returns:
            The database ID of the entity.
        """
        eid = await self._get_entity_id_or_none(entity)
        if eid is None:
            return await self.create_entity(entity)
        return eid

    async def get_perm_info(self, entity: PermEntity, config_key: str) -> bool | None:
        """Fetch one permission setting of an entity.

        The entity is resolved through :meth:`get_entity_id`, so it is created
        as a side effect when it does not exist yet.

        Args:
            entity: The permission entity.
            config_key: Name of the setting.

        Returns:
            The stored value as a bool, or ``None`` when the query yields no
            row.
        """
        eid = await self.get_entity_id(entity)
        res = await self.conn.execute(
            s("get_perm_info.sql"),
            (eid, config_key),
        )
        row = await res.fetchone()
        if row is None:
            return None
        return bool(row[0])

    async def update_perm_info(
        self, entity: PermEntity, config_key: str, value: bool | None
    ) -> None:
        """Write one permission setting of an entity.

        The entity is resolved through :meth:`get_entity_id`, so it is created
        as a side effect when it does not exist yet.

        Args:
            entity: The permission entity.
            config_key: Name of the setting.
            value: The value to store (``True``, ``False`` or ``None``).
        """
        eid = await self.get_entity_id(entity)
        try:
            await self.conn.execute(s("update_perm_info.sql"), (eid, config_key, value))
            await self.conn.commit()
        except Exception:
            await self.conn.rollback()
            raise

    async def get_entity_id_batch(
        self, entities: list[PermEntity]
    ) -> dict[PermEntity, int]:
        """Resolve the database IDs of several entities at once.

        ``create_entity.sql`` is executed for every given entity before the
        lookup.
        NOTE(review): this presumes that statement is idempotent for rows that
        already exist (e.g. ``INSERT OR IGNORE``) - confirm against the SQL
        file.

        Args:
            entities: The entities to resolve.

        Returns:
            A dict mapping each entity found in the table to its ID. Empty
            when ``entities`` is empty.
        """
        # An empty list would render "IN (VALUES )", which is a syntax error.
        if not entities:
            return {}

        keys = [(e.platform, e.entity_type, e.external_id) for e in entities]
        try:
            await self.conn.executemany(s("create_entity.sql"), keys)
            await self.conn.commit()
        except Exception:
            await self.conn.rollback()
            raise

        val_placeholders = ", ".join(["(?, ?, ?)"] * len(keys))
        # Flatten the (platform, entity_type, external_id) triples so they
        # line up with the "(?, ?, ?)" groups.
        params = [field for key in keys for field in key]
        cursor = await self.conn.execute(
            f"""
            SELECT id, platform, entity_type, external_id
            FROM perm_entity
            WHERE (platform, entity_type, external_id) IN (VALUES {val_placeholders});
            """,
            params,
        )
        rows = await cursor.fetchall()
        return {PermEntity(row[1], row[2], row[3]): row[0] for row in rows}

    async def get_perm_info_batch(
        self, entities: list[PermEntity], config_keys: list[str]
    ) -> dict[tuple[PermEntity, str], bool]:
        """Fetch several permission settings for several entities.

        Args:
            entities: The entities to query.
            config_keys: The setting names to query.

        Returns:
            A dict keyed by ``(entity, config_key)`` with the stored value as
            a bool. Settings whose value is NULL are left out.
        """
        entity_ids = {
            v: k for k, v in (await self.get_entity_id_batch(entities)).items()
        }
        # Nothing can match, so skip the query. This check runs after the
        # entity resolution above so that its side effect is preserved.
        if not entity_ids or not config_keys:
            return {}

        placeholders1 = ", ".join(["?"] * len(entity_ids))
        placeholders2 = ", ".join(["?"] * len(config_keys))
        sql = f"""
            SELECT entity_id, config_key, value
            FROM perm_info
            WHERE entity_id IN ({placeholders1})
              AND config_key IN ({placeholders2})
              AND value IS NOT NULL;
        """
        params = tuple(entity_ids) + tuple(config_keys)
        cursor = await self.conn.execute(sql, params)
        rows = await cursor.fetchall()
        return {(entity_ids[row[0]], row[1]): bool(row[2]) for row in rows}

    async def list_perm_info_batch(
        self, entities: list[PermEntity], pager: PagerQuery
    ) -> PagerResult[tuple[PermEntity, str, bool]]:
        """List the permission settings of several entities, paginated.

        Rows are ordered by the position of their entity in ``entities`` and
        then by ``config_key``. Settings whose value is NULL are left out.

        Args:
            entities: The entities to query; their order drives the ordering
                of the result.
            pager: The requested page (``page_index`` is treated as 1-based
                when computing the offset).

        Returns:
            A ``PagerResult`` whose ``data`` is a list of
            ``(entity, config_key, value)`` tuples for the requested page and
            whose ``page_count`` is derived from the total number of matching
            rows.
        """
        entity_to_id = await self.get_entity_id_batch(entities)
        id_to_entity = {v: k for k, v in entity_to_id.items()}
        # Deduplicate while keeping first-seen order: the relative ordering of
        # the result is unchanged and no redundant parameters are bound.
        ordered_ids = list(
            dict.fromkeys(entity_to_id[e] for e in entities if e in entity_to_id)
        )

        # With no IDs the ORDER BY would render "CASE entity_id END", which is
        # a syntax error; there is nothing to list anyway.
        if not ordered_ids:
            return PagerResult(
                data=[],
                success=True,
                message="",
                page_count=0,
                query=pager,
            )

        placeholders = ", ".join(["?"] * len(ordered_ids))
        order_by_cases = " ".join(
            f"WHEN ? THEN {position}" for position, _ in enumerate(ordered_ids)
        )

        pagecount_sql = f"SELECT COUNT(*) FROM perm_info WHERE entity_id IN ({placeholders}) AND value IS NOT NULL;"
        count_cursor = await self.conn.execute(pagecount_sql, tuple(ordered_ids))
        count_row = await count_cursor.fetchone()
        total_count = count_row[0] if count_row else 0

        sql = f"""
            SELECT entity_id, config_key, value
            FROM perm_info
            WHERE entity_id IN ({placeholders})
              AND value IS NOT NULL
            ORDER BY (CASE entity_id {order_by_cases} END) ASC, config_key ASC
            LIMIT ?
            OFFSET ?;
        """
        # The IDs are bound twice: once for the IN list, once for the CASE
        # expression that reproduces the caller's ordering.
        params = (
            tuple(ordered_ids)
            + tuple(ordered_ids)
            + (
                pager.page_size,
                (pager.page_index - 1) * pager.page_size,
            )
        )
        cursor = await self.conn.execute(sql, params)
        rows = await cursor.fetchall()

        return PagerResult(
            # Convert to bool to match the declared return type and the other
            # accessors of this repository.
            data=[(id_to_entity[row[0]], row[1], bool(row[2])) for row in rows],
            success=True,
            message="",
            page_count=math.ceil(total_count / pager.page_size),
            query=pager,
        )