From aaf0a75d659b867c0ef446e0909611375864c8ef Mon Sep 17 00:00:00 2001 From: passthem Date: Sun, 19 Oct 2025 04:45:15 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E8=8B=A5=E5=B9=B2=E6=9C=89?= =?UTF-8?q?=E7=94=A8=E7=9A=84=E5=B0=8F=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- konabot/common/data_man.py | 36 ++++ konabot/common/longtask.py | 275 +++++++++++++++++++++++++++++++ konabot/plugins/longtask_core.py | 50 ++++++ 3 files changed, 361 insertions(+) create mode 100644 konabot/common/data_man.py create mode 100644 konabot/common/longtask.py create mode 100644 konabot/plugins/longtask_core.py diff --git a/konabot/common/data_man.py b/konabot/common/data_man.py new file mode 100644 index 0000000..71c7fe0 --- /dev/null +++ b/konabot/common/data_man.py @@ -0,0 +1,36 @@ +import asyncio +from contextlib import asynccontextmanager +from pathlib import Path +from typing import Generic, TypeVar + +from pydantic import BaseModel, ValidationError + +T = TypeVar("T", bound=BaseModel) + + +class DataManager(Generic[T]): + def __init__(self, cls: type[T], fp: Path) -> None: + self.cls = cls + self.fp = fp + self._aio_lock = asyncio.Lock() + self._data: T | None = None + + def load(self) -> T: + if not self.fp.exists(): + return self.cls() + try: + return self.cls.model_validate_json(self.fp.read_text()) + except ValidationError: + return self.cls() + + def save(self, data: T): + self.fp.write_text(data.model_dump_json()) + + @asynccontextmanager + async def get_data(self): + await self._aio_lock.acquire() + self._data = self.load() + yield self._data + self.save(self._data) + self._data = None + self._aio_lock.release() diff --git a/konabot/common/longtask.py b/konabot/common/longtask.py new file mode 100644 index 0000000..1a69b2a --- /dev/null +++ b/konabot/common/longtask.py @@ -0,0 +1,275 @@ +from contextlib import asynccontextmanager +import datetime +import json +from typing import Annotated, Any, Callable, Coroutine, cast +import asyncio as asynkio +import uuid + +from loguru import logger +import nonebot +from nonebot.params import Depends +from nonebot.adapters import Event as BaseEvent +from nonebot.adapters import Bot as BaseBot +from nonebot.adapters.onebot.v11 import Bot as OBBot +from nonebot.adapters.onebot.v11 import GroupMessageEvent as OBGroupMessageEvent +from nonebot.adapters.onebot.v11 import PrivateMessageEvent as OBPrivateMessageEvent +from nonebot.adapters.console import Bot as ConsoleBot +from nonebot.adapters.console import MessageEvent as ConsoleMessageEvent +from nonebot.adapters.discord import MessageEvent as DCMessageEvent +from nonebot.adapters.discord import Bot as DCBot +from nonebot_plugin_alconna import UniMessage +from pydantic import BaseModel, ValidationError + +from .path import DATA_PATH + +LONGTASK_DATA_DIR = DATA_PATH / "longtasks.json" + + +class LongTaskTarget(BaseModel): + """ + 用于定义长期任务的目标沟通对象,一般通过 DepLongTaskTarget 依赖注入获取: + + ```python + @cmd.handle() + async def _(target: DepLongTaskTarget): + ... + ``` + """ + + platform: str + "沟通对象所在的平台" + + self_id: str + "进行沟通的对象自己的 ID" + + channel_id: str + "沟通对象所在的群或者 Discord Channel。若为空则代表是私聊" + + target_id: str + "沟通对象的 ID" + + async def send_message(self, msg: UniMessage, at: bool = True) -> bool: + try: + bot = nonebot.get_bot(self.self_id) + except KeyError: + logger.warning(f"试图访问了不存在的 Bot。ID={self.self_id}") + return False + + if self.platform == "qq": + if not isinstance(bot, OBBot): + logger.warning( + f"编号对应的平台并非期望的平台 ID={self.self_id} PLATFORM={ + self.platform + } BOT_CLASS={bot.__class__.__name__}" + ) + return False + if self.channel_id == "": + # 私聊模式 + await bot.send_private_msg( + user_id=int(self.target_id), + message=cast(Any, await msg.export(bot)), + auto_escape=False, + ) + return True + else: + if at: + msg = UniMessage().at(self.target_id).text(" ") + msg + await bot.send_group_msg( + group_id=int(self.channel_id), + message=cast(Any, await msg.export(bot)), + auto_escape=False, + ) + return True + if self.platform == "console": + if not isinstance(bot, ConsoleBot): + logger.warning( + f"编号对应的平台并非期望的平台 ID={self.self_id} PLATFORM={ + self.platform + } BOT_CLASS={bot.__class__.__name__}" + ) + return False + await bot.send_message(self.channel_id, cast(Any, await msg.export())) + if self.platform == "discord": + if not isinstance(bot, DCBot): + logger.warning( + f"编号对应的平台并非期望的平台 ID={self.self_id} PLATFORM={ + self.platform + } BOT_CLASS={bot.__class__.__name__}" + ) + return False + await bot.send_to( + channel_id=int(self.channel_id), + message=cast( + Any, await (UniMessage().at(self.target_id) + msg).export() + ), + tts=False, + ) + logger.warning(f"没有一个平台是期望的平台 PLATFORM={self.platform}") + return False + + +class LongTask(BaseModel): + uuid: str + data_json: str + target: "LongTaskTarget" + callback: str + deadline: datetime.datetime + canceled: bool = False + + _aio_task: asynkio.Task | None = None + + async def run(self): + now = datetime.datetime.now() + if self.deadline < now and not self.canceled: + await self._run_task() + return + await asynkio.sleep((self.deadline - now).total_seconds()) + if self.canceled: + return + await self._run_task() + + async def _run_task(self): + hdl = registered_long_task_handler.get(self.callback, None) + if hdl is None: + logger.warning( + f"Callback {self.callback} 未曾被注册,但是被期待调用,已忽略" + ) + async with longtask_data() as datafile: + datafile.to_handle[self.callback] = [ + t + for t in datafile.to_handle.get(self.callback, []) + if t.uuid != self.uuid + ] + datafile.unhandled.setdefault(self.callback, []).append(self) + + return + await hdl(self) + async with longtask_data() as datafile: + datafile.to_handle[self.callback] = [ + t for t in datafile.to_handle[self.callback] if t.uuid != self.uuid + ] + + def clean(self): + self._aio_task = None + + @property + def data(self): + return json.loads(self.data_json) + + async def start(self): + self._aio_task = asynkio.Task(self.run()) + self._aio_task.add_done_callback(lambda _: self.clean()) + + +class LongTaskModuleData(BaseModel): + to_handle: dict[str, list[LongTask]] + unhandled: dict[str, list[LongTask]] + + +async def get_long_task_target(event: BaseEvent, bot: BaseBot) -> LongTaskTarget | None: + if isinstance(event, OBGroupMessageEvent): + return LongTaskTarget( + platform="qq", + self_id=str(event.self_id), + channel_id=str(event.group_id), + target_id=str(event.user_id), + ) + if isinstance(event, OBPrivateMessageEvent): + return LongTaskTarget( + platform="qq", + self_id=str(event.self_id), + channel_id="", + target_id=str(event.user_id), + ) + if isinstance(event, ConsoleMessageEvent): + return LongTaskTarget( + platform="console", + self_id=str(event.self_id), + channel_id=str(event.channel.id), + target_id=str(event.user.id), + ) + if isinstance(event, DCMessageEvent): + self_id = "" + if isinstance(bot, DCBot): + self_id = str(bot.self_id) + return LongTaskTarget( + platform="discord", + self_id=self_id, + channel_id=str(event.channel_id), + target_id=str(event.user_id), + ) + + +_TaskHandler = Callable[[LongTask], Coroutine[Any, Any, Any]] + + +registered_long_task_handler: dict[str, _TaskHandler] = {} +longtask_lock = asynkio.Lock() + + +def handle_long_task(callback_id: str): + def _decorator(func: _TaskHandler): + assert callback_id not in registered_long_task_handler, ( + "有长任务的 ID 出现冲突,请换个名字!" + ) + registered_long_task_handler[callback_id] = func + return func + + return _decorator + + +def _load_longtask_data() -> LongTaskModuleData: + try: + txt = LONGTASK_DATA_DIR.read_text() + return LongTaskModuleData.model_validate_json(txt) + except (FileNotFoundError, ValidationError) as e: + logger.info(f"取得 LongTask 数据时出现问题:{e}") + return LongTaskModuleData( + to_handle={}, + unhandled={}, + ) + + +def _save_longtask_data(data: LongTaskModuleData): + LONGTASK_DATA_DIR.write_text(data.model_dump_json()) + + +@asynccontextmanager +async def longtask_data(): + await longtask_lock.acquire() + data = _load_longtask_data() + yield data + _save_longtask_data(data) + longtask_lock.release() + + +async def create_longtask( + handler: str, + data: dict[str, Any], + target: LongTaskTarget, + deadline: datetime.datetime, +): + task = LongTask( + uuid=str(uuid.uuid4()), + data_json=json.dumps(data), + target=target, + callback=handler, + deadline=deadline, + ) + + await task.start() + + async with longtask_data() as d: + d.to_handle.setdefault(handler, []).append(task) + + return task + + +async def init_longtask(): + async with longtask_data() as data: + for v in data.to_handle.values(): + for t in v: + await t.start() + + +DepLongTaskTarget = Annotated[LongTaskTarget, Depends(get_long_task_target)] diff --git a/konabot/plugins/longtask_core.py b/konabot/plugins/longtask_core.py new file mode 100644 index 0000000..68469d2 --- /dev/null +++ b/konabot/plugins/longtask_core.py @@ -0,0 +1,50 @@ +import asyncio + +# import datetime +from loguru import logger +import nonebot + +# from nonebot.adapters import Bot, Event +# from nonebot_plugin_alconna import UniMessage +from konabot.common.longtask import ( + # DepLongTaskTarget, + # LongTask, + # create_longtask, + # get_long_task_target, + # handle_long_task, + init_longtask, +) + + +driver = nonebot.get_driver() +INIT_FLAG = {"flag": False} + + +@driver.on_bot_connect +async def _(): + if INIT_FLAG["flag"]: + return + INIT_FLAG["flag"] = True + logger.info("有 Bot 连接,等待 5 秒后初始化 LongTask 模块") + await asyncio.sleep(5) + await init_longtask() + logger.info("LongTask 初始化完成") + + +# cmd1 = nonebot.on_command("test114") +# +# +# @handle_long_task("test_callback_001") +# async def _(lt: LongTask): +# await lt.target.send_message(UniMessage().text("Hello, world!"), at=True) +# +# +# @cmd1.handle() +# async def _(target: DepLongTaskTarget): +# await create_longtask( +# handler="test_callback_001", +# data={}, +# target=target, +# deadline=datetime.datetime.now() + datetime.timedelta(seconds=2), +# ) +# await target.send_message(UniMessage().text("Hello, world!"), at=True)