添加一个可供管理的订阅制模块,并且接入 KonaPH
This commit is contained in:
76
konabot/common/pager.py
Normal file
76
konabot/common/pager.py
Normal file
@ -0,0 +1,76 @@
|
||||
from dataclasses import dataclass
|
||||
from math import ceil
|
||||
from typing import Any, Callable
|
||||
|
||||
from nonebot_plugin_alconna import UniMessage
|
||||
|
||||
|
||||
@dataclass
|
||||
class PagerQuery:
|
||||
page_index: int
|
||||
page_size: int
|
||||
|
||||
def apply[T](self, ls: list[T]) -> "PagerResult[T]":
|
||||
if self.page_size <= 0:
|
||||
return PagerResult(
|
||||
success=False,
|
||||
message="每页元素数量应该大于 0",
|
||||
data=[],
|
||||
page_count=-1,
|
||||
query=self,
|
||||
)
|
||||
page_count = ceil(len(ls) / self.page_size)
|
||||
if self.page_index <= 0 or self.page_size <= 0:
|
||||
return PagerResult(
|
||||
success=False,
|
||||
message="页数必须大于 0",
|
||||
data=[],
|
||||
page_count=page_count,
|
||||
query=self,
|
||||
)
|
||||
data = ls[(self.page_index - 1) * self.page_size: self.page_index * self.page_size]
|
||||
if len(data) > 0:
|
||||
return PagerResult(
|
||||
success=True,
|
||||
message="",
|
||||
data=data,
|
||||
page_count=page_count,
|
||||
query=self,
|
||||
)
|
||||
return PagerResult(
|
||||
success=False,
|
||||
message="指定的页数超过最大页数",
|
||||
data=data,
|
||||
page_count=page_count,
|
||||
query=self,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PagerResult[T]:
|
||||
data: list[T]
|
||||
success: bool
|
||||
message: str
|
||||
page_count: int
|
||||
query: PagerQuery
|
||||
|
||||
def to_unimessage(
|
||||
self,
|
||||
formatter: Callable[[T], str | UniMessage[Any]] = str,
|
||||
title: str = '查询结果',
|
||||
list_indicator: str = '- ',
|
||||
) -> UniMessage[Any]:
|
||||
msg = UniMessage.text(f'===== {title} =====\n\n')
|
||||
|
||||
if not self.success:
|
||||
msg = msg.text(f'⚠️ {self.message}\n')
|
||||
else:
|
||||
for obj in self.data:
|
||||
msg = msg.text(list_indicator)
|
||||
msg += formatter(obj)
|
||||
msg += '\n'
|
||||
|
||||
msg = msg.text(f'\n===== 第 {self.query.page_index} 页,共 {self.page_count} 页 =====')
|
||||
return msg
|
||||
|
||||
|
||||
@ -4,12 +4,12 @@ from math import ceil
|
||||
|
||||
from loguru import logger
|
||||
from nonebot import on_message
|
||||
import nonebot
|
||||
from nonebot_plugin_alconna import (Alconna, Args, UniMessage, UniMsg,
|
||||
on_alconna)
|
||||
from nonebot_plugin_apscheduler import scheduler
|
||||
|
||||
from konabot.common.longtask import DepLongTaskTarget
|
||||
from konabot.common.nb.qq_broadcast import qq_broadcast
|
||||
from konabot.plugins.kona_ph.core.message import (get_daily_report,
|
||||
get_daily_report_v2,
|
||||
get_puzzle_description,
|
||||
@ -18,8 +18,14 @@ from konabot.plugins.kona_ph.core.storage import get_today_date
|
||||
from konabot.plugins.kona_ph.manager import (PUZZLE_PAGE_SIZE, config,
|
||||
create_admin_commands,
|
||||
puzzle_manager)
|
||||
from konabot.plugins.poster.poster_info import PosterInfo, register_poster_info
|
||||
from konabot.plugins.poster.service import broadcast
|
||||
|
||||
create_admin_commands()
|
||||
register_poster_info("每日谜题", info=PosterInfo(
|
||||
aliases={"konaph", "kona_ph", "KonaPH", "此方谜题", "KONAPH"},
|
||||
description="此方 BOT 每日谜题推送",
|
||||
))
|
||||
|
||||
|
||||
async def is_play_group(target: DepLongTaskTarget):
|
||||
@ -125,11 +131,15 @@ async def _():
|
||||
yesterday = get_today_date() - datetime.timedelta(days=1)
|
||||
msg2 = get_daily_report(manager, yesterday)
|
||||
if msg2 is not None:
|
||||
await qq_broadcast(config.plugin_puzzle_playgroup, msg2)
|
||||
await broadcast("每日谜题", msg2)
|
||||
|
||||
puzzle = manager.get_today_puzzle()
|
||||
if puzzle is not None:
|
||||
logger.info(f"找到了题目 {puzzle.raw_id},发送")
|
||||
await qq_broadcast(config.plugin_puzzle_playgroup, get_puzzle_description(puzzle))
|
||||
await broadcast("每日谜题", get_puzzle_description(puzzle))
|
||||
else:
|
||||
logger.info("自动任务:没有找到题目,跳过")
|
||||
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
|
||||
70
konabot/plugins/poster/__init__.py
Normal file
70
konabot/plugins/poster/__init__.py
Normal file
@ -0,0 +1,70 @@
|
||||
import nonebot
|
||||
from nonebot_plugin_alconna import Alconna, Args, on_alconna
|
||||
|
||||
from konabot.common.longtask import DepLongTaskTarget
|
||||
from konabot.common.pager import PagerQuery
|
||||
from konabot.plugins.poster.poster_info import POSTER_INFO_DATA
|
||||
from konabot.plugins.poster.service import dep_poster_service
|
||||
|
||||
|
||||
cmd_subscribe = on_alconna(Alconna(
|
||||
"订阅",
|
||||
Args["channel", str],
|
||||
))
|
||||
|
||||
|
||||
@cmd_subscribe.handle()
|
||||
async def _(target: DepLongTaskTarget, channel: str):
|
||||
async with dep_poster_service() as service:
|
||||
result = await service.subscribe(channel, target)
|
||||
if result:
|
||||
await target.send_message(f"已订阅「{channel}」")
|
||||
else:
|
||||
await target.send_message(f"已经订阅过「{channel}」了")
|
||||
|
||||
|
||||
cmd_list = on_alconna(Alconna(
|
||||
"re:(?:查询|我的|获取)订阅(列表)?",
|
||||
Args["page?", int],
|
||||
))
|
||||
|
||||
|
||||
def better_channel_message(channel_id: str) -> str:
|
||||
if channel_id not in POSTER_INFO_DATA:
|
||||
return channel_id
|
||||
data = POSTER_INFO_DATA[channel_id]
|
||||
return f"{channel_id}:{data.description}"
|
||||
|
||||
|
||||
@cmd_list.handle()
|
||||
async def _(target: DepLongTaskTarget, page: int = 1):
|
||||
async with dep_poster_service() as service:
|
||||
result = await service.get_channels(target, PagerQuery(
|
||||
page_index=page,
|
||||
page_size=10,
|
||||
))
|
||||
await target.send_message(result.to_unimessage(title="订阅列表"))
|
||||
|
||||
|
||||
cmd_unsubscribe = on_alconna(Alconna(
|
||||
"取消订阅",
|
||||
Args["channel", str],
|
||||
))
|
||||
|
||||
|
||||
@cmd_unsubscribe.handle()
|
||||
async def _(target: DepLongTaskTarget, channel: str):
|
||||
async with dep_poster_service() as service:
|
||||
result = await service.subscribe(channel, target)
|
||||
if result:
|
||||
await target.send_message(f"已取消订阅「{channel}」")
|
||||
else:
|
||||
await target.send_message(f"这里没有订阅过「{channel}」")
|
||||
|
||||
|
||||
driver = nonebot.get_driver()
|
||||
|
||||
@driver.on_startup
|
||||
async def _():
|
||||
async with dep_poster_service() as service:
|
||||
await service.fix_data()
|
||||
15
konabot/plugins/poster/poster_info.py
Normal file
15
konabot/plugins/poster/poster_info.py
Normal file
@ -0,0 +1,15 @@
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class PosterInfo:
|
||||
aliases: set[str] = field(default_factory=set)
|
||||
description: str = field(default='')
|
||||
|
||||
|
||||
POSTER_INFO_DATA: dict[str, PosterInfo] = {}
|
||||
|
||||
|
||||
def register_poster_info(channel: str, info: PosterInfo):
|
||||
POSTER_INFO_DATA[channel] = info
|
||||
|
||||
112
konabot/plugins/poster/repo_local_data.py
Normal file
112
konabot/plugins/poster/repo_local_data.py
Normal file
@ -0,0 +1,112 @@
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Annotated
|
||||
from nonebot.params import Depends
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from konabot.common.longtask import LongTaskTarget
|
||||
from konabot.common.pager import PagerQuery, PagerResult
|
||||
from konabot.common.path import DATA_PATH
|
||||
from konabot.plugins.poster.repository import IPosterRepo
|
||||
|
||||
|
||||
class ChannelData(BaseModel):
|
||||
targets: list[LongTaskTarget] = []
|
||||
|
||||
|
||||
class PosterData(BaseModel):
|
||||
channels: dict[str, ChannelData] = {}
|
||||
|
||||
|
||||
def is_the_same_target(target1: LongTaskTarget, target2: LongTaskTarget) -> bool:
|
||||
if (target1.is_private_chat and not target2.is_private_chat):
|
||||
return False
|
||||
if (target2.is_private_chat and not target1.is_private_chat):
|
||||
return False
|
||||
if target1.platform != target2.platform:
|
||||
return False
|
||||
|
||||
# 如果是群聊,则要求 channel_id 相同
|
||||
if not target1.is_private_chat:
|
||||
return target1.channel_id == target2.channel_id
|
||||
return target1.target_id == target2.target_id
|
||||
|
||||
|
||||
class LocalPosterRepo(IPosterRepo):
|
||||
def __init__(self, data: PosterData) -> None:
|
||||
self.data = data
|
||||
super().__init__()
|
||||
|
||||
async def get_channel_targets(self, channel: str) -> list[LongTaskTarget]:
|
||||
if channel not in self.data.channels:
|
||||
self.data.channels[channel] = ChannelData()
|
||||
return self.data.channels[channel].targets
|
||||
|
||||
async def add_channel_target(self, channel: str, target: LongTaskTarget) -> bool:
|
||||
targets = await self.get_channel_targets(channel)
|
||||
for t in targets:
|
||||
if is_the_same_target(t, target):
|
||||
return False
|
||||
targets.append(target)
|
||||
return True
|
||||
|
||||
async def remove_channel_target(self, channel: str, target: LongTaskTarget) -> bool:
|
||||
targets = await self.get_channel_targets(channel)
|
||||
len0 = len(targets)
|
||||
self.data.channels[channel].targets = [
|
||||
t for t in targets if not is_the_same_target(t, target)
|
||||
]
|
||||
len1 = len(self.data.channels[channel].targets)
|
||||
return len0 != len1
|
||||
|
||||
async def get_subscribed_channels(self, target: LongTaskTarget, pager: PagerQuery) -> PagerResult[str]:
|
||||
channels: list[str] = []
|
||||
for channel_id, channel in self.data.channels.items():
|
||||
for t in channel.targets:
|
||||
if is_the_same_target(target, t):
|
||||
channels.append(channel_id)
|
||||
break
|
||||
channels = sorted(channels)
|
||||
return pager.apply(channels)
|
||||
|
||||
async def merge_channel(self, from_channel: str, to_channel: str) -> None:
|
||||
channel_from = await self.get_channel_targets(from_channel)
|
||||
channel_to = await self.get_channel_targets(to_channel)
|
||||
|
||||
for t1 in channel_from:
|
||||
flag = True
|
||||
for t2 in channel_to:
|
||||
if is_the_same_target(t1, t2):
|
||||
flag = False
|
||||
break
|
||||
if flag:
|
||||
channel_to.append(t1)
|
||||
|
||||
del self.data.channels[from_channel]
|
||||
|
||||
|
||||
LOCAL_POSTER_DATA_LOCK = asyncio.Lock()
|
||||
LOCAL_POSTER_DATA_PATH = DATA_PATH / "module_poster_data.json"
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def local_poster_data():
|
||||
async with LOCAL_POSTER_DATA_LOCK:
|
||||
if not LOCAL_POSTER_DATA_PATH.exists():
|
||||
data = PosterData()
|
||||
else:
|
||||
try:
|
||||
data = PosterData.model_validate_json(LOCAL_POSTER_DATA_PATH.read_text())
|
||||
except ValidationError:
|
||||
data = PosterData()
|
||||
yield data
|
||||
LOCAL_POSTER_DATA_PATH.write_text(data.model_dump_json())
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def local_poster():
|
||||
async with local_poster_data() as data:
|
||||
yield LocalPosterRepo(data)
|
||||
|
||||
|
||||
DepLocalPosterRepo = Annotated[LocalPosterRepo, Depends(local_poster)]
|
||||
|
||||
37
konabot/plugins/poster/repository.py
Normal file
37
konabot/plugins/poster/repository.py
Normal file
@ -0,0 +1,37 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from konabot.common.longtask import LongTaskTarget
|
||||
from konabot.common.pager import PagerQuery, PagerResult
|
||||
|
||||
|
||||
class IPosterRepo(ABC):
|
||||
@abstractmethod
|
||||
async def get_channel_targets(self, channel: str) -> list[LongTaskTarget]:
|
||||
"""
|
||||
获取广播通道的所有广播对象
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def add_channel_target(self, channel: str, target: LongTaskTarget) -> bool:
|
||||
"""
|
||||
向广播通道添加一个广播目标。若目标已存在,则返回 False
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def remove_channel_target(self, channel: str, target: LongTaskTarget) -> bool:
|
||||
"""
|
||||
移除一个广播通道的目标。若目标不存在,则返回 False
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def get_subscribed_channels(self, target: LongTaskTarget, pager: PagerQuery) -> PagerResult[str]:
|
||||
"""
|
||||
获得一个目标已经订阅了的广播通道
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def merge_channel(self, from_channel: str, to_channel: str) -> None:
|
||||
"""
|
||||
合并两个 Channel 为一个,并移除另一个
|
||||
"""
|
||||
|
||||
59
konabot/plugins/poster/service.py
Normal file
59
konabot/plugins/poster/service.py
Normal file
@ -0,0 +1,59 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Annotated, Any
|
||||
from nonebot.params import Depends
|
||||
from nonebot_plugin_alconna import UniMessage
|
||||
from konabot.common.longtask import LongTaskTarget
|
||||
from konabot.common.pager import PagerQuery, PagerResult
|
||||
from konabot.plugins.poster.poster_info import POSTER_INFO_DATA
|
||||
from konabot.plugins.poster.repo_local_data import local_poster
|
||||
from konabot.plugins.poster.repository import IPosterRepo
|
||||
|
||||
|
||||
class PosterService:
|
||||
def __init__(self, repo: IPosterRepo) -> None:
|
||||
self.repo = repo
|
||||
|
||||
def parse_channel_id(self, channel: str):
|
||||
for cid, cinfo in POSTER_INFO_DATA.items():
|
||||
if channel in cinfo.aliases:
|
||||
return cid
|
||||
return channel
|
||||
|
||||
async def subscribe(self, channel: str, target: LongTaskTarget) -> bool:
|
||||
channel = self.parse_channel_id(channel)
|
||||
return await self.repo.add_channel_target(channel, target)
|
||||
|
||||
async def unsubscribe(self, channel: str, target: LongTaskTarget) -> bool:
|
||||
channel = self.parse_channel_id(channel)
|
||||
return await self.repo.remove_channel_target(channel, target)
|
||||
|
||||
async def broadcast(self, channel: str, message: UniMessage[Any] | str) -> list[LongTaskTarget]:
|
||||
channel = self.parse_channel_id(channel)
|
||||
targets = await self.repo.get_channel_targets(channel)
|
||||
for target in targets:
|
||||
# 因为是订阅消息,就不要 At 对方了
|
||||
await target.send_message(message, at=False)
|
||||
return targets
|
||||
|
||||
async def get_channels(self, target: LongTaskTarget, pager: PagerQuery) -> PagerResult[str]:
|
||||
return await self.repo.get_subscribed_channels(target, pager)
|
||||
|
||||
async def fix_data(self):
|
||||
for cid, cinfo in POSTER_INFO_DATA.items():
|
||||
for alias in cinfo.aliases:
|
||||
await self.repo.merge_channel(alias, cid)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def dep_poster_service():
|
||||
async with local_poster() as repo:
|
||||
yield PosterService(repo)
|
||||
|
||||
|
||||
async def broadcast(channel: str, message: UniMessage[Any] | str):
|
||||
async with dep_poster_service() as service:
|
||||
return await service.broadcast(channel, message)
|
||||
|
||||
|
||||
DepPosterService = Annotated[PosterService, Depends(dep_poster_service)]
|
||||
|
||||
@ -8,6 +8,8 @@ base = Path(__file__).parent.parent.absolute()
|
||||
def filter(change: Change, path: str) -> bool:
|
||||
if "__pycache__" in path:
|
||||
return False
|
||||
if Path(path).absolute().is_relative_to(base / "data"):
|
||||
if Path(path).absolute().is_relative_to((base / "data").absolute()):
|
||||
return False
|
||||
if Path(path).absolute().is_relative_to((base / ".git").absolute()):
|
||||
return False
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user