import asyncio
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Annotated

from nonebot.params import Depends
from pydantic import BaseModel, ValidationError

from konabot.common.longtask import LongTaskTarget
from konabot.common.pager import PagerQuery, PagerResult
from konabot.common.path import DATA_PATH

from .repository import IPosterRepo


class ChannelData(BaseModel):
    """Persisted state of a single poster channel."""

    # Subscribers of this channel.
    targets: list[LongTaskTarget] = []


class PosterData(BaseModel):
    """Root document stored in the local poster data file."""

    # Channel name -> channel state.
    channels: dict[str, ChannelData] = {}


def is_the_same_target(target1: LongTaskTarget, target2: LongTaskTarget) -> bool:
    """Return whether two targets refer to the same destination.

    Targets can only match if both are private chats or both are group chats,
    and they are on the same platform. Group chats are then compared by
    ``channel_id``, private chats by ``target_id``.
    """
    # A private chat never matches a group chat.
    if bool(target1.is_private_chat) != bool(target2.is_private_chat):
        return False
    if target1.platform != target2.platform:
        return False
    # For group chats, the channel_id must be the same.
    if not target1.is_private_chat:
        return target1.channel_id == target2.channel_id
    return target1.target_id == target2.target_id


class LocalPosterRepo(IPosterRepo):
    """``IPosterRepo`` implementation operating on an in-memory ``PosterData``.

    All methods mutate ``self.data`` in place; writing it back to disk is done
    by whoever supplied the data (see ``local_poster_data``).
    """

    def __init__(self, data: PosterData) -> None:
        self.data = data
        super().__init__()

    async def get_channel_targets(self, channel: str) -> list[LongTaskTarget]:
        """Return the stored target list of ``channel``, creating the channel if missing.

        The returned list is the live one held in ``self.data``, so mutating it
        mutates the repository state.
        """
        if channel not in self.data.channels:
            self.data.channels[channel] = ChannelData()
        return self.data.channels[channel].targets

    async def add_channel_target(self, channel: str, target: LongTaskTarget) -> bool:
        """Subscribe ``target`` to ``channel``.

        Returns:
            False if an equivalent target (per ``is_the_same_target``) is
            already subscribed; True if ``target`` was added.
        """
        targets = await self.get_channel_targets(channel)
        if any(is_the_same_target(t, target) for t in targets):
            return False
        targets.append(target)
        return True

    async def remove_channel_target(self, channel: str, target: LongTaskTarget) -> bool:
        """Unsubscribe every target equivalent to ``target`` from ``channel``.

        Returns:
            True if at least one target was removed, False otherwise.
        """
        targets = await self.get_channel_targets(channel)
        remaining = [t for t in targets if not is_the_same_target(t, target)]
        self.data.channels[channel].targets = remaining
        return len(remaining) != len(targets)

    async def get_subscribed_channels(
        self, target: LongTaskTarget, pager: PagerQuery
    ) -> PagerResult[str]:
        """Return the names of channels ``target`` is subscribed to, sorted and paged."""
        channels = sorted(
            channel_id
            for channel_id, channel in self.data.channels.items()
            if any(is_the_same_target(target, t) for t in channel.targets)
        )
        return pager.apply(channels)

    async def merge_channel(self, from_channel: str, to_channel: str) -> None:
        """Move all subscribers of ``from_channel`` into ``to_channel``, then drop ``from_channel``.

        Targets already present in ``to_channel`` are not duplicated. Merging a
        channel into itself is a no-op.
        """
        if from_channel == to_channel:
            # Both names would alias the same list: nothing gets appended and
            # the ``del`` below would wipe out the channel's subscribers.
            return
        source = await self.get_channel_targets(from_channel)
        dest = await self.get_channel_targets(to_channel)
        for candidate in source:
            if not any(is_the_same_target(candidate, t) for t in dest):
                dest.append(candidate)
        del self.data.channels[from_channel]


LOCAL_POSTER_DATA_LOCK = asyncio.Lock()
LOCAL_POSTER_DATA_PATH = DATA_PATH / "module_poster_data.json"


def _load_poster_data() -> PosterData:
    """Read the data file; return empty data if it is missing or fails validation."""
    try:
        raw = LOCAL_POSTER_DATA_PATH.read_text(encoding="utf-8")
    except FileNotFoundError:
        return PosterData()
    try:
        return PosterData.model_validate_json(raw)
    except ValidationError:
        # Best effort: an invalid file is treated as empty.
        # NOTE(review): it is overwritten on the next save, so its contents are
        # lost — consider keeping a backup copy.
        return PosterData()


def _save_poster_data(data: PosterData) -> None:
    """Write ``data`` to the data file atomically (temp file + rename)."""
    # Renaming a fully written temp file over the target means a crash during
    # the write cannot leave a truncated file behind (which the loader would
    # otherwise silently discard).
    tmp_path = LOCAL_POSTER_DATA_PATH.with_name(LOCAL_POSTER_DATA_PATH.name + ".tmp")
    tmp_path.write_text(data.model_dump_json(), encoding="utf-8")
    tmp_path.replace(LOCAL_POSTER_DATA_PATH)


@asynccontextmanager
async def local_poster_data() -> AsyncIterator[PosterData]:
    """Yield the poster data while holding the module lock, then save it.

    The data is saved only if the ``async with`` body exits normally; if it
    raises, the file on disk is left unchanged.
    """
    async with LOCAL_POSTER_DATA_LOCK:
        data = _load_poster_data()
        yield data
        _save_poster_data(data)


@asynccontextmanager
async def local_poster() -> AsyncIterator[LocalPosterRepo]:
    """Yield a ``LocalPosterRepo`` over the locked, auto-saved poster data."""
    async with local_poster_data() as data:
        yield LocalPosterRepo(data)


async def _local_poster_dependency() -> AsyncIterator[LocalPosterRepo]:
    """Async-generator wrapper around ``local_poster`` for dependency injection.

    ``@asynccontextmanager`` turns ``local_poster`` into an ordinary function
    that returns a context-manager object. NoneBot's ``Depends`` treats only
    (async) generator functions as setup/teardown dependencies, so the repo
    is yielded from a plain async generator here.
    """
    async with local_poster() as repo:
        yield repo


DepLocalPosterRepo = Annotated[LocalPosterRepo, Depends(_local_poster_dependency)]