Files
konabot/konabot/common/subscribe/repo_local_data.py
passthem 392c699b33
All checks were successful
continuous-integration/drone/push Build is passing
移动 poster 模块到 common
2026-03-09 14:40:27 +08:00

117 lines
3.7 KiB
Python

import asyncio
from contextlib import asynccontextmanager
from typing import Annotated
from nonebot.params import Depends
from pydantic import BaseModel, ValidationError
from konabot.common.longtask import LongTaskTarget
from konabot.common.pager import PagerQuery, PagerResult
from konabot.common.path import DATA_PATH
from .repository import IPosterRepo
class ChannelData(BaseModel):
    """Subscription state of a single poster channel."""

    # Every target currently subscribed to this channel (pydantic gives each
    # instance its own copy of this default list).
    targets: list[LongTaskTarget] = []
class PosterData(BaseModel):
    """Root model of the locally persisted poster subscription data."""

    # channel identifier -> that channel's subscription state
    channels: dict[str, ChannelData] = {}
def is_the_same_target(target1: LongTaskTarget, target2: LongTaskTarget) -> bool:
    """Decide whether two targets point at the same conversation.

    Both targets must be of the same kind (both private chats or both group
    chats) and live on the same platform. Group chats are then identified by
    ``channel_id``; private chats by ``target_id``.
    """
    is_private = bool(target1.is_private_chat)
    # A private chat never matches a group chat, and platforms must agree.
    if is_private != bool(target2.is_private_chat) or target1.platform != target2.platform:
        return False
    if is_private:
        return target1.target_id == target2.target_id
    return target1.channel_id == target2.channel_id
class LocalPosterRepo(IPosterRepo):
    """:class:`IPosterRepo` implementation operating on an in-memory :class:`PosterData`.

    Every method reads and mutates ``self.data`` directly and performs no I/O
    itself; persisting the data is up to whoever supplied it.
    """

    def __init__(self, data: PosterData) -> None:
        self.data = data
        super().__init__()

    async def get_channel_targets(self, channel: str) -> list[LongTaskTarget]:
        """Return the target list of ``channel``, creating an empty channel if absent.

        The returned list is the one stored in ``self.data``, so mutating it
        mutates the repository.
        """
        if channel not in self.data.channels:
            self.data.channels[channel] = ChannelData()
        return self.data.channels[channel].targets

    async def add_channel_target(self, channel: str, target: LongTaskTarget) -> bool:
        """Subscribe ``target`` to ``channel``.

        Returns ``False`` if an equivalent target (per ``is_the_same_target``)
        is already subscribed, ``True`` if it was added.
        """
        targets = await self.get_channel_targets(channel)
        if any(is_the_same_target(existing, target) for existing in targets):
            return False
        targets.append(target)
        return True

    async def remove_channel_target(self, channel: str, target: LongTaskTarget) -> bool:
        """Unsubscribe every target equivalent to ``target`` from ``channel``.

        Returns ``True`` if at least one target was removed. An unknown
        channel yields ``False`` and is not created as a side effect.
        """
        channel_data = self.data.channels.get(channel)
        if channel_data is None:
            return False
        remaining = [
            t for t in channel_data.targets if not is_the_same_target(t, target)
        ]
        removed = len(remaining) != len(channel_data.targets)
        channel_data.targets = remaining
        return removed

    async def get_subscribed_channels(
        self, target: LongTaskTarget, pager: PagerQuery
    ) -> PagerResult[str]:
        """Return, sorted and paginated, the channels ``target`` is subscribed to."""
        channels = sorted(
            channel_id
            for channel_id, channel in self.data.channels.items()
            if any(is_the_same_target(target, t) for t in channel.targets)
        )
        return pager.apply(channels)

    async def merge_channel(self, from_channel: str, to_channel: str) -> None:
        """Move all targets of ``from_channel`` into ``to_channel`` and drop ``from_channel``.

        Targets already present in ``to_channel`` are not duplicated.
        ``to_channel`` is created if it does not exist yet.
        """
        if from_channel == to_channel:
            # Merging a channel into itself is a no-op. Without this guard the
            # final ``del`` would erase the channel and all its subscribers.
            return
        channel_from = await self.get_channel_targets(from_channel)
        channel_to = await self.get_channel_targets(to_channel)
        for candidate in channel_from:
            # channel_to grows inside the loop, so duplicates within
            # channel_from are collapsed as well.
            if not any(is_the_same_target(candidate, t) for t in channel_to):
                channel_to.append(candidate)
        del self.data.channels[from_channel]
# JSON file holding the locally persisted poster subscription data.
LOCAL_POSTER_DATA_PATH = DATA_PATH / "module_poster_data.json"
# In-process lock serialising the load/modify/save cycle on that file.
LOCAL_POSTER_DATA_LOCK = asyncio.Lock()
@asynccontextmanager
async def local_poster_data():
    """Load :class:`PosterData` from disk, yield it, and save it back afterwards.

    ``LOCAL_POSTER_DATA_LOCK`` is held for the whole ``async with`` body, so
    load/modify/save cycles inside this process never interleave. The data is
    written back only when the body exits without raising.

    A missing file yields an empty :class:`PosterData`. A file that is not
    valid UTF-8 or fails validation is moved aside to ``<name>.bak`` (with a
    warning logged) rather than being silently overwritten with empty data.
    """
    path = LOCAL_POSTER_DATA_PATH
    async with LOCAL_POSTER_DATA_LOCK:
        data = PosterData()
        if path.exists():
            try:
                # Explicit UTF-8: the JSON may contain non-ASCII text, and the
                # platform default encoding is not guaranteed to be UTF-8.
                data = PosterData.model_validate_json(path.read_text(encoding="utf-8"))
            except (ValidationError, UnicodeDecodeError) as err:
                from nonebot import logger

                # Keep the unreadable file for manual recovery instead of
                # letting the save below clobber it.
                backup = path.with_name(path.name + ".bak")
                path.replace(backup)
                logger.warning(
                    f"Unreadable poster data at {path} ({err!r}); "
                    f"moved it to {backup} and starting from empty data"
                )
        yield data
        path.parent.mkdir(parents=True, exist_ok=True)
        # Write to a sibling temp file and swap it in, so a crash mid-write
        # cannot leave a truncated data file behind.
        tmp = path.with_name(path.name + ".tmp")
        tmp.write_text(data.model_dump_json(), encoding="utf-8")
        tmp.replace(path)
@asynccontextmanager
async def local_poster():
    """Yield a :class:`LocalPosterRepo` over the locked on-disk poster data.

    Use as ``async with local_poster() as repo: ...``. The data file is locked
    for the duration of the block and written back when the block exits
    without an exception.
    """
    async with local_poster_data() as data:
        yield LocalPosterRepo(data)


async def _local_poster_dependency():
    """Plain async-generator wrapper around :func:`local_poster` for NoneBot ``Depends``.

    NoneBot treats (async) generator functions as context-style dependencies.
    A function decorated with ``@asynccontextmanager`` is neither an async
    generator function nor a coroutine function, so ``Depends(local_poster)``
    would be called like a plain function and inject the unentered
    context-manager object instead of a repository.
    """
    async with local_poster() as repo:
        yield repo


# Handler parameter type that injects a LocalPosterRepo; changes are saved when
# NoneBot tears the dependency down without an exception.
DepLocalPosterRepo = Annotated[LocalPosterRepo, Depends(_local_poster_dependency)]