303 lines
9.5 KiB
Python
303 lines
9.5 KiB
Python
from __future__ import annotations

import asyncio as asynkio
import datetime
import json
import uuid
from contextlib import asynccontextmanager
from typing import Annotated, Any, Callable, Coroutine, cast

import nonebot
from loguru import logger
from nonebot.adapters import Bot as BaseBot
from nonebot.adapters import Event as BaseEvent
from nonebot.adapters.console import Bot as ConsoleBot
from nonebot.adapters.console import MessageEvent as ConsoleMessageEvent
from nonebot.adapters.discord import Bot as DCBot
from nonebot.adapters.discord import MessageEvent as DCMessageEvent
from nonebot.adapters.onebot.v11 import Bot as OBBot
from nonebot.adapters.onebot.v11 import GroupMessageEvent as OBGroupMessageEvent
from nonebot.adapters.onebot.v11 import PrivateMessageEvent as OBPrivateMessageEvent
from nonebot.params import Depends
from nonebot_plugin_alconna import UniMessage
from pydantic import BaseModel, Field, ValidationError

from .path import DATA_PATH
|
|
|
|
# Path of the JSON *file* (despite the "_DIR" suffix) that stores all long tasks.
LONGTASK_DATA_DIR = DATA_PATH / "longtasks.json"

# Prefix written into LongTaskTarget.channel_id to mark a QQ private chat.
QQ_PRIVATE_CHAT_CHANNEL_PREFIX: str = "_CHANNEL_QQ_PRIVATE_"
|
|
|
|
|
|
class LongTaskTarget(BaseModel):
    """
    The party a long task talks to. Normally obtained through the
    ``DepLongTaskTarget`` dependency:

    ```python
    @cmd.handle()
    async def _(target: DepLongTaskTarget):
        ...
    ```
    """

    # Platform the target lives on; send_message understands "qq", "console"
    # and "discord".
    platform: str

    # ID of our own bot account used to reach the target (looked up with
    # nonebot.get_bot).
    self_id: str

    # Group or Discord channel the target is in. A blank value, or (for QQ) one
    # starting with QQ_PRIVATE_CHAT_CHANNEL_PREFIX, means a private chat.
    channel_id: str

    # ID of the user being talked to.
    target_id: str

    @property
    def is_private_chat(self) -> bool:
        """Whether this target is a private chat.

        True when ``channel_id`` starts with ``QQ_PRIVATE_CHAT_CHANNEL_PREFIX``
        or is blank. This matches the documented meaning of ``channel_id`` and
        the condition ``send_message`` uses to choose a QQ private message.
        """
        return (
            self.channel_id.startswith(QQ_PRIVATE_CHAT_CHANNEL_PREFIX)
            or not self.channel_id.strip()
        )

    def _warn_bot_mismatch(self, bot: BaseBot) -> None:
        """Log that ``self_id`` resolved to a bot of an unexpected adapter class."""
        logger.warning(
            f"编号对应的平台并非期望的平台 ID={self.self_id} PLATFORM={self.platform}"
            f" BOT_CLASS={bot.__class__.__name__}"
        )

    async def send_message(self, msg: UniMessage | str, at: bool = True) -> bool:
        """Send ``msg`` to this target through the bot identified by ``self_id``.

        Args:
            msg: Message to send. A plain ``str`` is wrapped as a text segment.
            at: Whether to mention ``target_id`` before the message. Only used
                for QQ group chats and Discord channels.

        Returns:
            True if the message was handed to the platform API. False when no
            bot has ``self_id``, the bot belongs to a different adapter than
            ``platform``, or ``platform`` is unknown. Errors raised by the
            platform API itself are not caught and propagate to the caller.
        """
        try:
            bot = nonebot.get_bot(self.self_id)
        except KeyError:
            logger.warning(f"试图访问了不存在的 Bot。ID={self.self_id}")
            return False

        if isinstance(msg, str):
            msg = UniMessage.text(msg)

        # Every export() below receives the bot explicitly. A long task normally
        # runs outside any event handler, where export() cannot infer which bot
        # (and therefore which adapter) to serialize for.
        if self.platform == "qq":
            if not isinstance(bot, OBBot):
                self._warn_bot_mismatch(bot)
                return False
            if self.is_private_chat:
                await bot.send_private_msg(
                    user_id=int(self.target_id),
                    message=cast(Any, await msg.export(bot)),
                    auto_escape=False,
                )
                return True
            if at:
                msg = UniMessage().at(self.target_id).text(" ") + msg
            await bot.send_group_msg(
                group_id=int(self.channel_id),
                message=cast(Any, await msg.export(bot)),
                auto_escape=False,
            )
            return True

        if self.platform == "console":
            if not isinstance(bot, ConsoleBot):
                self._warn_bot_mismatch(bot)
                return False
            await bot.send_message(self.channel_id, cast(Any, await msg.export(bot)))
            return True

        if self.platform == "discord":
            if not isinstance(bot, DCBot):
                self._warn_bot_mismatch(bot)
                return False
            # Honour ``at`` like the QQ group branch does (previously Discord
            # always mentioned the target).
            if at:
                msg = UniMessage().at(self.target_id) + msg
            await bot.send_to(
                channel_id=int(self.channel_id),
                message=cast(Any, await msg.export(bot)),
                tts=False,
            )
            return True

        logger.warning(f"没有一个平台是期望的平台 PLATFORM={self.platform}")
        return False
|
|
|
|
|
|
class LongTask(BaseModel):
    """A persisted job that runs a registered handler once its deadline passes.

    Pending tasks are kept in ``LongTaskModuleData.to_handle`` (saved to
    LONGTASK_DATA_DIR). When ``deadline`` is reached, the coroutine registered
    under ``callback`` through ``handle_long_task`` is awaited with the task
    itself as its only argument.
    """

    # Unique ID; the key of this task inside to_handle[callback].
    uuid: str

    # Handler payload as a JSON string; read it through the ``data`` property.
    data_json: str

    # Who the handler is expected to talk to.
    target: LongTaskTarget

    # ID of the handler registered with handle_long_task.
    callback: str

    # When the handler should run. Naive and timezone-aware values both work.
    deadline: datetime.datetime

    # asyncio task driving ``run`` while it is alive (pydantic private
    # attribute, not part of the persisted data).
    _aio_task: asynkio.Task | None = None

    async def run(self):
        """Wait for ``deadline``, then execute the task via ``_run_task``.

        A deadline that has already passed runs the task right away. Otherwise,
        after sleeping, the task only runs if it is still listed in
        ``to_handle``. If it was removed in the meantime, it is skipped silently.
        """
        # Take now() in the deadline's own tzinfo (None -> naive local time) so
        # the comparison never mixes naive and aware datetimes, which would
        # raise TypeError.
        now = datetime.datetime.now(self.deadline.tzinfo)
        if self.deadline < now:
            await self._run_task()
            return

        await asynkio.sleep((self.deadline - now).total_seconds())

        async with longtask_data() as data:
            still_pending = self.uuid in data.to_handle.get(self.callback, {})
        # _run_task acquires longtask_lock itself, and asyncio.Lock is not
        # reentrant, so it is only called after the block above has released it.
        if still_pending:
            await self._run_task()

    async def _run_task(self):
        """Run the registered handler once and retire the task.

        The task is removed from ``to_handle`` in every case. If no handler is
        registered for ``callback``, or the handler raises, the task is appended
        to ``unhandled[callback]`` instead of being dropped.
        """
        hdl = registered_long_task_handler.get(self.callback)
        if hdl is None:
            logger.warning(
                f"Callback {self.callback} 未曾被注册,但是被期待调用,已忽略"
            )
            async with longtask_data() as datafile:
                datafile.to_handle.get(self.callback, {}).pop(self.uuid, None)
                datafile.unhandled.setdefault(self.callback, []).append(self)
            return

        success = False
        try:
            await hdl(self)
            success = True
        except Exception as e:
            logger.exception(e)

        async with longtask_data() as datafile:
            # pop() rather than del: the entry may already be missing (e.g. the
            # store was reset or edited while the handler ran). That must not
            # abort the bookkeeping below with a KeyError.
            datafile.to_handle.get(self.callback, {}).pop(self.uuid, None)
            if not success:
                datafile.unhandled.setdefault(self.callback, []).append(self)

        if success:
            logger.info(
                f"LongTask 工作完成 UUID={self.uuid} callback={self.callback}"
            )
        else:
            logger.info(
                f"LongTask 执行失败 UUID={self.uuid} callback={self.callback}"
            )

    def clean(self):
        """Forget the asyncio task once it has finished."""
        self._aio_task = None

    @property
    def data(self):
        """The payload decoded from ``data_json`` (decoded again on each access)."""
        return json.loads(self.data_json)

    async def start(self):
        """Schedule ``run`` on the running event loop and return immediately."""
        self._aio_task = asynkio.create_task(self.run())
        self._aio_task.add_done_callback(self._on_done)

    def _on_done(self, task: asynkio.Task) -> None:
        """Done-callback for ``start``: log an exception escaping ``run``, then clean up.

        Retrieving the exception here also prevents asyncio's "Task exception
        was never retrieved" message, which would otherwise appear only when
        the task is garbage-collected.
        """
        if not task.cancelled() and (exc := task.exception()) is not None:
            logger.opt(exception=exc).error(
                f"LongTask 运行时出现未处理的异常 UUID={self.uuid} callback={self.callback}"
            )
        self.clean()
|
|
|
|
|
|
class LongTaskModuleData(BaseModel):
    """Contents of the long-task store file (LONGTASK_DATA_DIR).

    Both fields default to empty dicts. A file that lacks one of the keys still
    loads, instead of failing validation and being replaced by an empty store.
    """

    # Pending tasks: callback ID -> task UUID -> task.
    to_handle: dict[str, dict[str, LongTask]] = Field(default_factory=dict)

    # Tasks whose handler was missing or raised, grouped by callback ID.
    unhandled: dict[str, list[LongTask]] = Field(default_factory=dict)
|
|
|
|
|
|
async def get_long_task_target(event: BaseEvent, bot: BaseBot) -> LongTaskTarget | None:
|
|
if isinstance(event, OBGroupMessageEvent):
|
|
return LongTaskTarget(
|
|
platform="qq",
|
|
self_id=str(event.self_id),
|
|
channel_id=str(event.group_id),
|
|
target_id=str(event.user_id),
|
|
)
|
|
if isinstance(event, OBPrivateMessageEvent):
|
|
return LongTaskTarget(
|
|
platform="qq",
|
|
self_id=str(event.self_id),
|
|
channel_id=f"{QQ_PRIVATE_CHAT_CHANNEL_PREFIX}{event.self_id}",
|
|
target_id=str(event.user_id),
|
|
)
|
|
if isinstance(event, ConsoleMessageEvent):
|
|
return LongTaskTarget(
|
|
platform="console",
|
|
self_id=str(event.self_id),
|
|
channel_id=str(event.channel.id),
|
|
target_id=str(event.user.id),
|
|
)
|
|
if isinstance(event, DCMessageEvent):
|
|
self_id = ""
|
|
if isinstance(bot, DCBot):
|
|
self_id = str(bot.self_id)
|
|
return LongTaskTarget(
|
|
platform="discord",
|
|
self_id=self_id,
|
|
channel_id=str(event.channel_id),
|
|
target_id=str(event.user_id),
|
|
)
|
|
|
|
|
|
# Signature every long-task handler must have: an async callable that receives
# the LongTask being executed.
_TaskHandler = Callable[[LongTask], Coroutine[Any, Any, Any]]

# Serializes every read-modify-write of the store file (see longtask_data).
longtask_lock = asynkio.Lock()

# Callback ID -> handler, populated by the handle_long_task decorator.
registered_long_task_handler: dict[str, _TaskHandler] = {}
|
|
|
|
|
|
def handle_long_task(callback_id: str):
    """Decorator factory registering a coroutine function as a long-task handler.

    Usage::

        @handle_long_task("my_callback")
        async def _(task: LongTask): ...

    The decorated function is stored in ``registered_long_task_handler`` under
    ``callback_id`` and returned unchanged.

    Raises:
        ValueError: When decorating, if ``callback_id`` is already registered.
    """

    def _decorator(func: _TaskHandler) -> _TaskHandler:
        # An explicit check instead of ``assert`` keeps duplicate detection
        # active when Python runs with -O (which strips asserts).
        if callback_id in registered_long_task_handler:
            raise ValueError("有长任务的 ID 出现冲突,请换个名字!")
        registered_long_task_handler[callback_id] = func
        return func

    return _decorator
|
|
|
|
|
|
def _load_longtask_data() -> LongTaskModuleData:
    """Read the long-task store from LONGTASK_DATA_DIR.

    Returns an empty store when the file does not exist yet. When the file
    exists but cannot be decoded or validated, it is first moved aside to
    ``<name>.bak`` (replacing any older backup), and then an empty store is
    returned. ``longtask_data`` writes the returned object back over the
    original path, so without the backup the unreadable content would be lost
    for good.
    """
    try:
        txt = LONGTASK_DATA_DIR.read_text("utf-8")
        return LongTaskModuleData.model_validate_json(txt)
    except FileNotFoundError as e:
        # Normal on first run: nothing has been stored yet.
        logger.info(f"取得 LongTask 数据时出现问题:{e}")
    except (ValidationError, UnicodeDecodeError) as e:
        logger.warning(f"取得 LongTask 数据时出现问题:{e}")
        backup = LONGTASK_DATA_DIR.with_name(LONGTASK_DATA_DIR.name + ".bak")
        try:
            LONGTASK_DATA_DIR.replace(backup)
            logger.warning(f"已将无法解析的 LongTask 数据备份至 {backup}")
        except OSError as backup_error:
            # Best effort only: still fall through to an empty store.
            logger.error(f"备份 LongTask 数据失败:{backup_error}")

    return LongTaskModuleData(
        to_handle={},
        unhandled={},
    )
|
|
|
|
|
|
def _save_longtask_data(data: LongTaskModuleData):
    """Persist ``data`` to LONGTASK_DATA_DIR atomically.

    The JSON is first written to a sibling ``.tmp`` file, which is then moved
    over the real file with ``Path.replace``. A crash mid-write therefore leaves
    the previous version intact instead of a truncated file that
    ``_load_longtask_data`` could not parse.
    """
    # Assumes LONGTASK_DATA_DIR is a pathlib.Path (it already uses
    # read_text/write_text synchronously).
    tmp_path = LONGTASK_DATA_DIR.with_name(LONGTASK_DATA_DIR.name + ".tmp")
    tmp_path.write_text(data.model_dump_json(), "utf-8")
    tmp_path.replace(LONGTASK_DATA_DIR)
|
|
|
|
|
|
@asynccontextmanager
async def longtask_data():
    """Yield the long-task store for a read-modify-write under ``longtask_lock``.

    The store is loaded from disk on entry and written back when the
    ``async with`` body finishes normally. If the body raises, nothing is
    saved. ``longtask_lock`` is an ``asyncio.Lock`` and is not reentrant, so
    this context manager must not be nested within one task.
    """
    async with longtask_lock:
        snapshot = _load_longtask_data()
        yield snapshot
        _save_longtask_data(snapshot)
|
|
|
|
|
|
async def create_longtask(
    handler: str,
    data: dict[str, Any],
    target: LongTaskTarget,
    deadline: datetime.datetime,
) -> LongTask:
    """Create, persist and schedule a new long task.

    Args:
        handler: Callback ID the task is dispatched to (see handle_long_task).
        data: JSON-serializable payload, stored as ``data_json``.
        target: Who the handler should talk to.
        deadline: When the handler should run.

    Returns:
        The newly scheduled LongTask.
    """
    task = LongTask(
        uuid=str(uuid.uuid4()),
        data_json=json.dumps(data),
        target=target,
        callback=handler,
        deadline=deadline,
    )

    logger.info(f"创建了新的 LongTask UUID={task.uuid} CALLBACK={task.callback}")

    # Record the task before scheduling it. LongTask.run/_run_task look the
    # task up in (and remove it from) to_handle, so it has to be there first.
    # Previously this held only thanks to asyncio.Lock's FIFO wake-up order.
    async with longtask_data() as d:
        d.to_handle.setdefault(handler, {})[task.uuid] = task

    await task.start()

    return task
|
|
|
|
|
|
async def init_longtask():
    """Schedule every task still pending in the store; intended for startup.

    Logs how many tasks were started and which callback IDs they need.
    """
    async with longtask_data() as snapshot:
        pending = [
            task
            for per_callback in snapshot.to_handle.values()
            for task in per_callback.values()
        ]
        for task in pending:
            await task.start()

    counter = len(pending)
    req = {task.callback for task in pending}
    logger.info(f"LongTask 启动了任务 数量={counter} 期望的门类=[{','.join(req)}]")
|
|
|
|
|
|
# Dependency-injection shortcut: annotate a handler parameter with this type and
# NoneBot fills it with get_long_task_target(event, bot). The value is None for
# unsupported event types.
DepLongTaskTarget = Annotated[
    LongTaskTarget,
    Depends(get_long_task_target),
]
|