diff --git a/konabot/common/pager.py b/konabot/common/pager.py new file mode 100644 index 0000000..792e62e --- /dev/null +++ b/konabot/common/pager.py @@ -0,0 +1,76 @@ +from dataclasses import dataclass +from math import ceil +from typing import Any, Callable + +from nonebot_plugin_alconna import UniMessage + + +@dataclass +class PagerQuery: + page_index: int + page_size: int + + def apply[T](self, ls: list[T]) -> "PagerResult[T]": + if self.page_size <= 0: + return PagerResult( + success=False, + message="每页元素数量应该大于 0", + data=[], + page_count=-1, + query=self, + ) + page_count = ceil(len(ls) / self.page_size) + if self.page_index <= 0 or self.page_size <= 0: + return PagerResult( + success=False, + message="页数必须大于 0", + data=[], + page_count=page_count, + query=self, + ) + data = ls[(self.page_index - 1) * self.page_size: self.page_index * self.page_size] + if len(data) > 0: + return PagerResult( + success=True, + message="", + data=data, + page_count=page_count, + query=self, + ) + return PagerResult( + success=False, + message="指定的页数超过最大页数", + data=data, + page_count=page_count, + query=self, + ) + + +@dataclass +class PagerResult[T]: + data: list[T] + success: bool + message: str + page_count: int + query: PagerQuery + + def to_unimessage( + self, + formatter: Callable[[T], str | UniMessage[Any]] = str, + title: str = '查询结果', + list_indicator: str = '- ', + ) -> UniMessage[Any]: + msg = UniMessage.text(f'===== {title} =====\n\n') + + if not self.success: + msg = msg.text(f'⚠️ {self.message}\n') + else: + for obj in self.data: + msg = msg.text(list_indicator) + msg += formatter(obj) + msg += '\n' + + msg = msg.text(f'\n===== 第 {self.query.page_index} 页,共 {self.page_count} 页 =====') + return msg + + diff --git a/konabot/plugins/kona_ph/__init__.py b/konabot/plugins/kona_ph/__init__.py index 2c88186..7e02316 100644 --- a/konabot/plugins/kona_ph/__init__.py +++ b/konabot/plugins/kona_ph/__init__.py @@ -4,12 +4,12 @@ from math import ceil from loguru import logger from nonebot import on_message +import nonebot from nonebot_plugin_alconna import (Alconna, Args, UniMessage, UniMsg, on_alconna) from nonebot_plugin_apscheduler import scheduler from konabot.common.longtask import DepLongTaskTarget -from konabot.common.nb.qq_broadcast import qq_broadcast from konabot.plugins.kona_ph.core.message import (get_daily_report, get_daily_report_v2, get_puzzle_description, @@ -18,8 +18,14 @@ from konabot.plugins.kona_ph.core.storage import get_today_date from konabot.plugins.kona_ph.manager import (PUZZLE_PAGE_SIZE, config, create_admin_commands, puzzle_manager) +from konabot.plugins.poster.poster_info import PosterInfo, register_poster_info +from konabot.plugins.poster.service import broadcast create_admin_commands() +register_poster_info("每日谜题", info=PosterInfo( + aliases={"konaph", "kona_ph", "KonaPH", "此方谜题", "KONAPH"}, + description="此方 BOT 每日谜题推送", +)) async def is_play_group(target: DepLongTaskTarget): @@ -125,11 +131,15 @@ async def _(): yesterday = get_today_date() - datetime.timedelta(days=1) msg2 = get_daily_report(manager, yesterday) if msg2 is not None: - await qq_broadcast(config.plugin_puzzle_playgroup, msg2) + await broadcast("每日谜题", msg2) puzzle = manager.get_today_puzzle() if puzzle is not None: logger.info(f"找到了题目 {puzzle.raw_id},发送") - await qq_broadcast(config.plugin_puzzle_playgroup, get_puzzle_description(puzzle)) + await broadcast("每日谜题", get_puzzle_description(puzzle)) else: logger.info("自动任务:没有找到题目,跳过") + + +driver = nonebot.get_driver() + diff --git a/konabot/plugins/poster/__init__.py b/konabot/plugins/poster/__init__.py new file mode 100644 index 0000000..285d41b --- /dev/null +++ b/konabot/plugins/poster/__init__.py @@ -0,0 +1,70 @@ +import nonebot +from nonebot_plugin_alconna import Alconna, Args, on_alconna + +from konabot.common.longtask import DepLongTaskTarget +from konabot.common.pager import PagerQuery +from konabot.plugins.poster.poster_info import POSTER_INFO_DATA +from konabot.plugins.poster.service import dep_poster_service + + +cmd_subscribe = on_alconna(Alconna( + "订阅", + Args["channel", str], +)) + + +@cmd_subscribe.handle() +async def _(target: DepLongTaskTarget, channel: str): + async with dep_poster_service() as service: + result = await service.subscribe(channel, target) + if result: + await target.send_message(f"已订阅「{channel}」") + else: + await target.send_message(f"已经订阅过「{channel}」了") + + +cmd_list = on_alconna(Alconna( + "re:(?:查询|我的|获取)订阅(列表)?", + Args["page?", int], +)) + + +def better_channel_message(channel_id: str) -> str: + if channel_id not in POSTER_INFO_DATA: + return channel_id + data = POSTER_INFO_DATA[channel_id] + return f"{channel_id}:{data.description}" + + +@cmd_list.handle() +async def _(target: DepLongTaskTarget, page: int = 1): + async with dep_poster_service() as service: + result = await service.get_channels(target, PagerQuery( + page_index=page, + page_size=10, + )) + await target.send_message(result.to_unimessage(title="订阅列表")) + + +cmd_unsubscribe = on_alconna(Alconna( + "取消订阅", + Args["channel", str], +)) + + +@cmd_unsubscribe.handle() +async def _(target: DepLongTaskTarget, channel: str): + async with dep_poster_service() as service: + result = await service.subscribe(channel, target) + if result: + await target.send_message(f"已取消订阅「{channel}」") + else: + await target.send_message(f"这里没有订阅过「{channel}」") + + +driver = nonebot.get_driver() + +@driver.on_startup +async def _(): + async with dep_poster_service() as service: + await service.fix_data() diff --git a/konabot/plugins/poster/poster_info.py b/konabot/plugins/poster/poster_info.py new file mode 100644 index 0000000..5ccb419 --- /dev/null +++ b/konabot/plugins/poster/poster_info.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass, field + + +@dataclass +class PosterInfo: + aliases: set[str] = field(default_factory=set) + description: str = field(default='') + + +POSTER_INFO_DATA: dict[str, PosterInfo] = {} + + +def register_poster_info(channel: str, info: PosterInfo): + POSTER_INFO_DATA[channel] = info + diff --git a/konabot/plugins/poster/repo_local_data.py b/konabot/plugins/poster/repo_local_data.py new file mode 100644 index 0000000..5b686ba --- /dev/null +++ b/konabot/plugins/poster/repo_local_data.py @@ -0,0 +1,112 @@ +import asyncio +from contextlib import asynccontextmanager +from typing import Annotated +from nonebot.params import Depends +from pydantic import BaseModel, ValidationError +from konabot.common.longtask import LongTaskTarget +from konabot.common.pager import PagerQuery, PagerResult +from konabot.common.path import DATA_PATH +from konabot.plugins.poster.repository import IPosterRepo + + +class ChannelData(BaseModel): + targets: list[LongTaskTarget] = [] + + +class PosterData(BaseModel): + channels: dict[str, ChannelData] = {} + + +def is_the_same_target(target1: LongTaskTarget, target2: LongTaskTarget) -> bool: + if (target1.is_private_chat and not target2.is_private_chat): + return False + if (target2.is_private_chat and not target1.is_private_chat): + return False + if target1.platform != target2.platform: + return False + + # 如果是群聊,则要求 channel_id 相同 + if not target1.is_private_chat: + return target1.channel_id == target2.channel_id + return target1.target_id == target2.target_id + + +class LocalPosterRepo(IPosterRepo): + def __init__(self, data: PosterData) -> None: + self.data = data + super().__init__() + + async def get_channel_targets(self, channel: str) -> list[LongTaskTarget]: + if channel not in self.data.channels: + self.data.channels[channel] = ChannelData() + return self.data.channels[channel].targets + + async def add_channel_target(self, channel: str, target: LongTaskTarget) -> bool: + targets = await self.get_channel_targets(channel) + for t in targets: + if is_the_same_target(t, target): + return False + targets.append(target) + return True + + async def remove_channel_target(self, channel: str, target: LongTaskTarget) -> bool: + targets = await self.get_channel_targets(channel) + len0 = len(targets) + self.data.channels[channel].targets = [ + t for t in targets if not is_the_same_target(t, target) + ] + len1 = len(self.data.channels[channel].targets) + return len0 != len1 + + async def get_subscribed_channels(self, target: LongTaskTarget, pager: PagerQuery) -> PagerResult[str]: + channels: list[str] = [] + for channel_id, channel in self.data.channels.items(): + for t in channel.targets: + if is_the_same_target(target, t): + channels.append(channel_id) + break + channels = sorted(channels) + return pager.apply(channels) + + async def merge_channel(self, from_channel: str, to_channel: str) -> None: + channel_from = await self.get_channel_targets(from_channel) + channel_to = await self.get_channel_targets(to_channel) + + for t1 in channel_from: + flag = True + for t2 in channel_to: + if is_the_same_target(t1, t2): + flag = False + break + if flag: + channel_to.append(t1) + + del self.data.channels[from_channel] + + +LOCAL_POSTER_DATA_LOCK = asyncio.Lock() +LOCAL_POSTER_DATA_PATH = DATA_PATH / "module_poster_data.json" + + +@asynccontextmanager +async def local_poster_data(): + async with LOCAL_POSTER_DATA_LOCK: + if not LOCAL_POSTER_DATA_PATH.exists(): + data = PosterData() + else: + try: + data = PosterData.model_validate_json(LOCAL_POSTER_DATA_PATH.read_text()) + except ValidationError: + data = PosterData() + yield data + LOCAL_POSTER_DATA_PATH.write_text(data.model_dump_json()) + + +@asynccontextmanager +async def local_poster(): + async with local_poster_data() as data: + yield LocalPosterRepo(data) + + +DepLocalPosterRepo = Annotated[LocalPosterRepo, Depends(local_poster)] + diff --git a/konabot/plugins/poster/repository.py b/konabot/plugins/poster/repository.py new file mode 100644 index 0000000..2508108 --- /dev/null +++ b/konabot/plugins/poster/repository.py @@ -0,0 +1,37 @@ +from abc import ABC, abstractmethod + +from konabot.common.longtask import LongTaskTarget +from konabot.common.pager import PagerQuery, PagerResult + + +class IPosterRepo(ABC): + @abstractmethod + async def get_channel_targets(self, channel: str) -> list[LongTaskTarget]: + """ + 获取广播通道的所有广播对象 + """ + + @abstractmethod + async def add_channel_target(self, channel: str, target: LongTaskTarget) -> bool: + """ + 向广播通道添加一个广播目标。若目标已存在,则返回 False + """ + + @abstractmethod + async def remove_channel_target(self, channel: str, target: LongTaskTarget) -> bool: + """ + 移除一个广播通道的目标。若目标不存在,则返回 False + """ + + @abstractmethod + async def get_subscribed_channels(self, target: LongTaskTarget, pager: PagerQuery) -> PagerResult[str]: + """ + 获得一个目标已经订阅了的广播通道 + """ + + @abstractmethod + async def merge_channel(self, from_channel: str, to_channel: str) -> None: + """ + 合并两个 Channel 为一个,并移除另一个 + """ + diff --git a/konabot/plugins/poster/service.py b/konabot/plugins/poster/service.py new file mode 100644 index 0000000..770146a --- /dev/null +++ b/konabot/plugins/poster/service.py @@ -0,0 +1,59 @@ +from contextlib import asynccontextmanager +from typing import Annotated, Any +from nonebot.params import Depends +from nonebot_plugin_alconna import UniMessage +from konabot.common.longtask import LongTaskTarget +from konabot.common.pager import PagerQuery, PagerResult +from konabot.plugins.poster.poster_info import POSTER_INFO_DATA +from konabot.plugins.poster.repo_local_data import local_poster +from konabot.plugins.poster.repository import IPosterRepo + + +class PosterService: + def __init__(self, repo: IPosterRepo) -> None: + self.repo = repo + + def parse_channel_id(self, channel: str): + for cid, cinfo in POSTER_INFO_DATA.items(): + if channel in cinfo.aliases: + return cid + return channel + + async def subscribe(self, channel: str, target: LongTaskTarget) -> bool: + channel = self.parse_channel_id(channel) + return await self.repo.add_channel_target(channel, target) + + async def unsubscribe(self, channel: str, target: LongTaskTarget) -> bool: + channel = self.parse_channel_id(channel) + return await self.repo.remove_channel_target(channel, target) + + async def broadcast(self, channel: str, message: UniMessage[Any] | str) -> list[LongTaskTarget]: + channel = self.parse_channel_id(channel) + targets = await self.repo.get_channel_targets(channel) + for target in targets: + # 因为是订阅消息,就不要 At 对方了 + await target.send_message(message, at=False) + return targets + + async def get_channels(self, target: LongTaskTarget, pager: PagerQuery) -> PagerResult[str]: + return await self.repo.get_subscribed_channels(target, pager) + + async def fix_data(self): + for cid, cinfo in POSTER_INFO_DATA.items(): + for alias in cinfo.aliases: + await self.repo.merge_channel(alias, cid) + + +@asynccontextmanager +async def dep_poster_service(): + async with local_poster() as repo: + yield PosterService(repo) + + +async def broadcast(channel: str, message: UniMessage[Any] | str): + async with dep_poster_service() as service: + return await service.broadcast(channel, message) + + +DepPosterService = Annotated[PosterService, Depends(dep_poster_service)] + diff --git a/scripts/watch_filter.py b/scripts/watch_filter.py index a47c8f3..dff78eb 100644 --- a/scripts/watch_filter.py +++ b/scripts/watch_filter.py @@ -8,6 +8,8 @@ base = Path(__file__).parent.parent.absolute() def filter(change: Change, path: str) -> bool: if "__pycache__" in path: return False - if Path(path).absolute().is_relative_to(base / "data"): + if Path(path).absolute().is_relative_to((base / "data").absolute()): + return False + if Path(path).absolute().is_relative_to((base / ".git").absolute()): return False return True