This commit is contained in:
116
konabot/common/subscribe/repo_local_data.py
Normal file
116
konabot/common/subscribe/repo_local_data.py
Normal file
@ -0,0 +1,116 @@
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Annotated
|
||||
from nonebot.params import Depends
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from konabot.common.longtask import LongTaskTarget
|
||||
from konabot.common.pager import PagerQuery, PagerResult
|
||||
from konabot.common.path import DATA_PATH
|
||||
|
||||
from .repository import IPosterRepo
|
||||
|
||||
|
||||
class ChannelData(BaseModel):
|
||||
targets: list[LongTaskTarget] = []
|
||||
|
||||
|
||||
class PosterData(BaseModel):
|
||||
channels: dict[str, ChannelData] = {}
|
||||
|
||||
|
||||
def is_the_same_target(target1: LongTaskTarget, target2: LongTaskTarget) -> bool:
|
||||
if target1.is_private_chat and not target2.is_private_chat:
|
||||
return False
|
||||
if target2.is_private_chat and not target1.is_private_chat:
|
||||
return False
|
||||
if target1.platform != target2.platform:
|
||||
return False
|
||||
|
||||
# 如果是群聊,则要求 channel_id 相同
|
||||
if not target1.is_private_chat:
|
||||
return target1.channel_id == target2.channel_id
|
||||
return target1.target_id == target2.target_id
|
||||
|
||||
|
||||
class LocalPosterRepo(IPosterRepo):
|
||||
def __init__(self, data: PosterData) -> None:
|
||||
self.data = data
|
||||
super().__init__()
|
||||
|
||||
async def get_channel_targets(self, channel: str) -> list[LongTaskTarget]:
|
||||
if channel not in self.data.channels:
|
||||
self.data.channels[channel] = ChannelData()
|
||||
return self.data.channels[channel].targets
|
||||
|
||||
async def add_channel_target(self, channel: str, target: LongTaskTarget) -> bool:
|
||||
targets = await self.get_channel_targets(channel)
|
||||
for t in targets:
|
||||
if is_the_same_target(t, target):
|
||||
return False
|
||||
targets.append(target)
|
||||
return True
|
||||
|
||||
async def remove_channel_target(self, channel: str, target: LongTaskTarget) -> bool:
|
||||
targets = await self.get_channel_targets(channel)
|
||||
len0 = len(targets)
|
||||
self.data.channels[channel].targets = [
|
||||
t for t in targets if not is_the_same_target(t, target)
|
||||
]
|
||||
len1 = len(self.data.channels[channel].targets)
|
||||
return len0 != len1
|
||||
|
||||
async def get_subscribed_channels(
|
||||
self, target: LongTaskTarget, pager: PagerQuery
|
||||
) -> PagerResult[str]:
|
||||
channels: list[str] = []
|
||||
for channel_id, channel in self.data.channels.items():
|
||||
for t in channel.targets:
|
||||
if is_the_same_target(target, t):
|
||||
channels.append(channel_id)
|
||||
break
|
||||
channels = sorted(channels)
|
||||
return pager.apply(channels)
|
||||
|
||||
async def merge_channel(self, from_channel: str, to_channel: str) -> None:
|
||||
channel_from = await self.get_channel_targets(from_channel)
|
||||
channel_to = await self.get_channel_targets(to_channel)
|
||||
|
||||
for t1 in channel_from:
|
||||
flag = True
|
||||
for t2 in channel_to:
|
||||
if is_the_same_target(t1, t2):
|
||||
flag = False
|
||||
break
|
||||
if flag:
|
||||
channel_to.append(t1)
|
||||
|
||||
del self.data.channels[from_channel]
|
||||
|
||||
|
||||
LOCAL_POSTER_DATA_LOCK = asyncio.Lock()
|
||||
LOCAL_POSTER_DATA_PATH = DATA_PATH / "module_poster_data.json"
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def local_poster_data():
|
||||
async with LOCAL_POSTER_DATA_LOCK:
|
||||
if not LOCAL_POSTER_DATA_PATH.exists():
|
||||
data = PosterData()
|
||||
else:
|
||||
try:
|
||||
data = PosterData.model_validate_json(
|
||||
LOCAL_POSTER_DATA_PATH.read_text()
|
||||
)
|
||||
except ValidationError:
|
||||
data = PosterData()
|
||||
yield data
|
||||
LOCAL_POSTER_DATA_PATH.write_text(data.model_dump_json())
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def local_poster():
|
||||
async with local_poster_data() as data:
|
||||
yield LocalPosterRepo(data)
|
||||
|
||||
|
||||
DepLocalPosterRepo = Annotated[LocalPosterRepo, Depends(local_poster)]
|
||||
Reference in New Issue
Block a user