添加若干有用的小模块

This commit is contained in:
2025-10-19 04:45:15 +08:00
parent 9f3f79f51d
commit aaf0a75d65
3 changed files with 361 additions and 0 deletions

View File

@ -0,0 +1,36 @@
import asyncio
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Generic, TypeVar
from pydantic import BaseModel, ValidationError
T = TypeVar("T", bound=BaseModel)
class DataManager(Generic[T]):
def __init__(self, cls: type[T], fp: Path) -> None:
self.cls = cls
self.fp = fp
self._aio_lock = asyncio.Lock()
self._data: T | None = None
def load(self) -> T:
if not self.fp.exists():
return self.cls()
try:
return self.cls.model_validate_json(self.fp.read_text())
except ValidationError:
return self.cls()
def save(self, data: T):
self.fp.write_text(data.model_dump_json())
@asynccontextmanager
async def get_data(self):
await self._aio_lock.acquire()
self._data = self.load()
yield self._data
self.save(self._data)
self._data = None
self._aio_lock.release()

275
konabot/common/longtask.py Normal file
View File

@ -0,0 +1,275 @@
from contextlib import asynccontextmanager
import datetime
import json
from typing import Annotated, Any, Callable, Coroutine, cast
import asyncio as asynkio
import uuid
from loguru import logger
import nonebot
from nonebot.params import Depends
from nonebot.adapters import Event as BaseEvent
from nonebot.adapters import Bot as BaseBot
from nonebot.adapters.onebot.v11 import Bot as OBBot
from nonebot.adapters.onebot.v11 import GroupMessageEvent as OBGroupMessageEvent
from nonebot.adapters.onebot.v11 import PrivateMessageEvent as OBPrivateMessageEvent
from nonebot.adapters.console import Bot as ConsoleBot
from nonebot.adapters.console import MessageEvent as ConsoleMessageEvent
from nonebot.adapters.discord import MessageEvent as DCMessageEvent
from nonebot.adapters.discord import Bot as DCBot
from nonebot_plugin_alconna import UniMessage
from pydantic import BaseModel, ValidationError
from .path import DATA_PATH
LONGTASK_DATA_DIR = DATA_PATH / "longtasks.json"
class LongTaskTarget(BaseModel):
"""
用于定义长期任务的目标沟通对象,一般通过 DepLongTaskTarget 依赖注入获取:
```python
@cmd.handle()
async def _(target: DepLongTaskTarget):
...
```
"""
platform: str
"沟通对象所在的平台"
self_id: str
"进行沟通的对象自己的 ID"
channel_id: str
"沟通对象所在的群或者 Discord Channel。若为空则代表是私聊"
target_id: str
"沟通对象的 ID"
async def send_message(self, msg: UniMessage, at: bool = True) -> bool:
try:
bot = nonebot.get_bot(self.self_id)
except KeyError:
logger.warning(f"试图访问了不存在的 Bot。ID={self.self_id}")
return False
if self.platform == "qq":
if not isinstance(bot, OBBot):
logger.warning(
f"编号对应的平台并非期望的平台 ID={self.self_id} PLATFORM={
self.platform
} BOT_CLASS={bot.__class__.__name__}"
)
return False
if self.channel_id == "":
# 私聊模式
await bot.send_private_msg(
user_id=int(self.target_id),
message=cast(Any, await msg.export(bot)),
auto_escape=False,
)
return True
else:
if at:
msg = UniMessage().at(self.target_id).text(" ") + msg
await bot.send_group_msg(
group_id=int(self.channel_id),
message=cast(Any, await msg.export(bot)),
auto_escape=False,
)
return True
if self.platform == "console":
if not isinstance(bot, ConsoleBot):
logger.warning(
f"编号对应的平台并非期望的平台 ID={self.self_id} PLATFORM={
self.platform
} BOT_CLASS={bot.__class__.__name__}"
)
return False
await bot.send_message(self.channel_id, cast(Any, await msg.export()))
if self.platform == "discord":
if not isinstance(bot, DCBot):
logger.warning(
f"编号对应的平台并非期望的平台 ID={self.self_id} PLATFORM={
self.platform
} BOT_CLASS={bot.__class__.__name__}"
)
return False
await bot.send_to(
channel_id=int(self.channel_id),
message=cast(
Any, await (UniMessage().at(self.target_id) + msg).export()
),
tts=False,
)
logger.warning(f"没有一个平台是期望的平台 PLATFORM={self.platform}")
return False
class LongTask(BaseModel):
uuid: str
data_json: str
target: "LongTaskTarget"
callback: str
deadline: datetime.datetime
canceled: bool = False
_aio_task: asynkio.Task | None = None
async def run(self):
now = datetime.datetime.now()
if self.deadline < now and not self.canceled:
await self._run_task()
return
await asynkio.sleep((self.deadline - now).total_seconds())
if self.canceled:
return
await self._run_task()
async def _run_task(self):
hdl = registered_long_task_handler.get(self.callback, None)
if hdl is None:
logger.warning(
f"Callback {self.callback} 未曾被注册,但是被期待调用,已忽略"
)
async with longtask_data() as datafile:
datafile.to_handle[self.callback] = [
t
for t in datafile.to_handle.get(self.callback, [])
if t.uuid != self.uuid
]
datafile.unhandled.setdefault(self.callback, []).append(self)
return
await hdl(self)
async with longtask_data() as datafile:
datafile.to_handle[self.callback] = [
t for t in datafile.to_handle[self.callback] if t.uuid != self.uuid
]
def clean(self):
self._aio_task = None
@property
def data(self):
return json.loads(self.data_json)
async def start(self):
self._aio_task = asynkio.Task(self.run())
self._aio_task.add_done_callback(lambda _: self.clean())
class LongTaskModuleData(BaseModel):
to_handle: dict[str, list[LongTask]]
unhandled: dict[str, list[LongTask]]
async def get_long_task_target(event: BaseEvent, bot: BaseBot) -> LongTaskTarget | None:
if isinstance(event, OBGroupMessageEvent):
return LongTaskTarget(
platform="qq",
self_id=str(event.self_id),
channel_id=str(event.group_id),
target_id=str(event.user_id),
)
if isinstance(event, OBPrivateMessageEvent):
return LongTaskTarget(
platform="qq",
self_id=str(event.self_id),
channel_id="",
target_id=str(event.user_id),
)
if isinstance(event, ConsoleMessageEvent):
return LongTaskTarget(
platform="console",
self_id=str(event.self_id),
channel_id=str(event.channel.id),
target_id=str(event.user.id),
)
if isinstance(event, DCMessageEvent):
self_id = ""
if isinstance(bot, DCBot):
self_id = str(bot.self_id)
return LongTaskTarget(
platform="discord",
self_id=self_id,
channel_id=str(event.channel_id),
target_id=str(event.user_id),
)
_TaskHandler = Callable[[LongTask], Coroutine[Any, Any, Any]]
registered_long_task_handler: dict[str, _TaskHandler] = {}
longtask_lock = asynkio.Lock()
def handle_long_task(callback_id: str):
def _decorator(func: _TaskHandler):
assert callback_id not in registered_long_task_handler, (
"有长任务的 ID 出现冲突,请换个名字!"
)
registered_long_task_handler[callback_id] = func
return func
return _decorator
def _load_longtask_data() -> LongTaskModuleData:
try:
txt = LONGTASK_DATA_DIR.read_text()
return LongTaskModuleData.model_validate_json(txt)
except (FileNotFoundError, ValidationError) as e:
logger.info(f"取得 LongTask 数据时出现问题:{e}")
return LongTaskModuleData(
to_handle={},
unhandled={},
)
def _save_longtask_data(data: LongTaskModuleData):
LONGTASK_DATA_DIR.write_text(data.model_dump_json())
@asynccontextmanager
async def longtask_data():
await longtask_lock.acquire()
data = _load_longtask_data()
yield data
_save_longtask_data(data)
longtask_lock.release()
async def create_longtask(
handler: str,
data: dict[str, Any],
target: LongTaskTarget,
deadline: datetime.datetime,
):
task = LongTask(
uuid=str(uuid.uuid4()),
data_json=json.dumps(data),
target=target,
callback=handler,
deadline=deadline,
)
await task.start()
async with longtask_data() as d:
d.to_handle.setdefault(handler, []).append(task)
return task
async def init_longtask():
async with longtask_data() as data:
for v in data.to_handle.values():
for t in v:
await t.start()
DepLongTaskTarget = Annotated[LongTaskTarget, Depends(get_long_task_target)]

View File

@ -0,0 +1,50 @@
import asyncio
# import datetime
from loguru import logger
import nonebot
# from nonebot.adapters import Bot, Event
# from nonebot_plugin_alconna import UniMessage
from konabot.common.longtask import (
# DepLongTaskTarget,
# LongTask,
# create_longtask,
# get_long_task_target,
# handle_long_task,
init_longtask,
)
driver = nonebot.get_driver()
INIT_FLAG = {"flag": False}
@driver.on_bot_connect
async def _():
if INIT_FLAG["flag"]:
return
INIT_FLAG["flag"] = True
logger.info("有 Bot 连接,等待 5 秒后初始化 LongTask 模块")
await asyncio.sleep(5)
await init_longtask()
logger.info("LongTask 初始化完成")
# cmd1 = nonebot.on_command("test114")
#
#
# @handle_long_task("test_callback_001")
# async def _(lt: LongTask):
# await lt.target.send_message(UniMessage().text("Hello, world!"), at=True)
#
#
# @cmd1.handle()
# async def _(target: DepLongTaskTarget):
# await create_longtask(
# handler="test_callback_001",
# data={},
# target=target,
# deadline=datetime.datetime.now() + datetime.timedelta(seconds=2),
# )
# await target.send_message(UniMessage().text("Hello, world!"), at=True)