117 lines
3.7 KiB
Python
117 lines
3.7 KiB
Python
import asyncio
|
|
from contextlib import asynccontextmanager
|
|
from typing import Annotated
|
|
from nonebot.params import Depends
|
|
from pydantic import BaseModel, ValidationError
|
|
from konabot.common.longtask import LongTaskTarget
|
|
from konabot.common.pager import PagerQuery, PagerResult
|
|
from konabot.common.path import DATA_PATH
|
|
|
|
from .repository import IPosterRepo
|
|
|
|
|
|
class ChannelData(BaseModel):
    """Subscription record of a single poster channel.

    ``targets`` lists every chat that receives what is posted to the channel.
    Pydantic copies the ``[]`` default for each instance, so separate
    instances never share one list object.
    """

    targets: list[LongTaskTarget] = []
|
|
|
|
|
|
class PosterData(BaseModel):
    """Root document of the poster module's local storage.

    Maps each channel name to its ``ChannelData`` (the chats subscribed to
    that channel). This whole model is what gets serialised to JSON.
    """

    channels: dict[str, ChannelData] = {}
|
|
|
|
|
|
def is_the_same_target(target1: LongTaskTarget, target2: LongTaskTarget) -> bool:
    """Return whether two long-task targets denote the same chat.

    Targets only match when both are private chats or both are group chats,
    and they are on the same platform. After that, group chats are compared
    by ``channel_id`` and private chats by ``target_id``.
    """
    first_is_private = bool(target1.is_private_chat)
    second_is_private = bool(target2.is_private_chat)

    # A private chat can never be the same as a group chat.
    if first_is_private != second_is_private:
        return False
    if target1.platform != target2.platform:
        return False

    if first_is_private:
        return target1.target_id == target2.target_id
    # Group chat: the two targets must share the same channel_id.
    return target1.channel_id == target2.channel_id
|
|
|
|
|
|
class LocalPosterRepo(IPosterRepo):
    """Poster repository that operates on an in-memory ``PosterData`` object.

    Every method reads and mutates ``self.data`` directly and performs no I/O.
    Saving the data is the job of whoever supplied it; in this module,
    ``local_poster_data`` writes it back after use.
    """

    def __init__(self, data: PosterData) -> None:
        self.data = data
        super().__init__()

    async def get_channel_targets(self, channel: str) -> list[LongTaskTarget]:
        """Return the target list of *channel*, creating an empty channel if absent.

        The returned list is the one stored inside ``self.data``, so appending
        to it changes the repository. ``add_channel_target`` and
        ``merge_channel`` rely on this.
        """
        if channel not in self.data.channels:
            self.data.channels[channel] = ChannelData()
        return self.data.channels[channel].targets

    async def add_channel_target(self, channel: str, target: LongTaskTarget) -> bool:
        """Subscribe *target* to *channel*.

        Returns:
            ``True`` if the target was added. ``False`` if an equivalent target
            (per ``is_the_same_target``) was already subscribed.
        """
        targets = await self.get_channel_targets(channel)
        if any(is_the_same_target(t, target) for t in targets):
            return False
        targets.append(target)
        return True

    async def remove_channel_target(self, channel: str, target: LongTaskTarget) -> bool:
        """Unsubscribe every target equivalent to *target* from *channel*.

        Like every lookup here, this creates *channel* if it does not exist.

        Returns:
            ``True`` if at least one target was removed, ``False`` otherwise.
        """
        targets = await self.get_channel_targets(channel)
        remaining = [t for t in targets if not is_the_same_target(t, target)]
        self.data.channels[channel].targets = remaining
        return len(remaining) != len(targets)

    async def get_subscribed_channels(
        self, target: LongTaskTarget, pager: PagerQuery
    ) -> PagerResult[str]:
        """Return the channels *target* is subscribed to.

        The channel names are sorted and then paginated with *pager*.
        """
        channels = sorted(
            channel_id
            for channel_id, channel in self.data.channels.items()
            if any(is_the_same_target(target, t) for t in channel.targets)
        )
        return pager.apply(channels)

    async def merge_channel(self, from_channel: str, to_channel: str) -> None:
        """Move all targets of *from_channel* into *to_channel*, then delete *from_channel*.

        Targets already present in *to_channel* are not duplicated. Merging a
        channel into itself leaves the data unchanged.
        """
        if from_channel == to_channel:
            # Both names would resolve to the same list. Nothing would be
            # appended, and the ``del`` below would then erase the channel
            # together with all of its subscribers.
            return

        source = await self.get_channel_targets(from_channel)
        destination = await self.get_channel_targets(to_channel)
        for t in source:
            if not any(is_the_same_target(t, existing) for existing in destination):
                destination.append(t)

        del self.data.channels[from_channel]
|
|
|
|
|
|
# JSON file holding every poster subscription of this bot instance.
LOCAL_POSTER_DATA_PATH = DATA_PATH / "module_poster_data.json"

# Serialises access to the file above within this process:
# ``local_poster_data`` holds it for the whole read-modify-write cycle.
LOCAL_POSTER_DATA_LOCK = asyncio.Lock()
|
|
|
|
|
|
@asynccontextmanager
async def local_poster_data():
    """Load ``PosterData`` from ``LOCAL_POSTER_DATA_PATH``, yield it, then save it.

    The whole load -> yield -> save cycle runs while holding
    ``LOCAL_POSTER_DATA_LOCK``, so users within this process are serialised.

    - A missing file yields an empty ``PosterData``.
    - A file that is not valid UTF-8 or fails ``PosterData`` validation is
      renamed to ``<name>.bak``, and an empty ``PosterData`` is used instead.
      This keeps its contents from being silently overwritten on exit.
    - If the ``async with`` body raises, nothing is written back.
    - The new contents are written to ``<name>.tmp``, which then replaces the
      real file. An interrupted write therefore cannot leave a truncated JSON
      file behind, which would otherwise be read as "no data" next time.
    """
    async with LOCAL_POSTER_DATA_LOCK:
        path = LOCAL_POSTER_DATA_PATH
        if not path.exists():
            data = PosterData()
        else:
            try:
                data = PosterData.model_validate_json(
                    path.read_text(encoding="utf-8")
                )
            except (UnicodeDecodeError, ValidationError):
                # Keep the unreadable file for manual recovery instead of
                # overwriting it with an empty data set when the context exits.
                path.replace(path.with_name(path.name + ".bak"))
                data = PosterData()

        yield data

        tmp_path = path.with_name(path.name + ".tmp")
        tmp_path.write_text(data.model_dump_json(), encoding="utf-8")
        # Path.replace delegates to os.replace: an atomic swap on POSIX, so
        # readers see either the old or the new file, never a partial one.
        tmp_path.replace(path)
|
|
|
|
|
|
@asynccontextmanager
async def local_poster():
    """Yield a ``LocalPosterRepo`` wrapping the persisted poster data.

    The underlying data is obtained from ``local_poster_data``. Changes made
    through the repository are written back when the context exits without
    an exception.
    """
    async with local_poster_data() as poster_data:
        repo = LocalPosterRepo(poster_data)
        yield repo
|
|
|
|
|
|
async def _local_poster_dependency():
    """Async-generator dependency yielding a ``LocalPosterRepo`` to NoneBot handlers.

    ``local_poster`` is decorated with ``asynccontextmanager``. That makes it a
    plain function returning a context-manager object, not an async
    generator function. NoneBot's ``Depends`` only enters and cleans up
    dependencies that are (async) generator functions; anything else is
    simply called and its return value injected. Wrapping it in a real async
    generator makes handlers receive the repository itself, with the data
    saved after the handler finishes.
    """
    async with local_poster() as repo:
        yield repo


# Handler parameter annotation that injects a ``LocalPosterRepo``.
DepLocalPosterRepo = Annotated[LocalPosterRepo, Depends(_local_poster_dependency)]
|