Files
konabot/konabot/plugins/simple_notify/__init__.py
Passthem f6fd25a41d
All checks were successful
continuous-integration/drone/push Build is passing
continuous-integration/drone/tag Build is passing
fix: 添加 Notify 任务创建限制
2025-09-30 19:08:30 +08:00

208 lines
6.5 KiB
Python

import asyncio
import datetime
from pathlib import Path
from typing import Any, Literal, cast
import nonebot
from loguru import logger
from nonebot import on_message
from nonebot.adapters import Event
from nonebot.adapters.console import Bot as ConsoleBot
from nonebot.adapters.console.event import MessageEvent as ConsoleMessageEvent
from nonebot.adapters.discord import Bot as DiscordBot
from nonebot.adapters.discord.event import MessageEvent as DiscordMessageEvent
from nonebot.adapters.onebot.v11 import Bot as OnebotV11Bot
from nonebot.adapters.onebot.v11.event import \
GroupMessageEvent as OnebotV11GroupMessageEvent
from nonebot.adapters.onebot.v11.event import \
MessageEvent as OnebotV11MessageEvent
from nonebot_plugin_alconna import UniMessage, UniMsg
from pydantic import BaseModel
from konabot.plugins.simple_notify.parse_time import get_target_time
# Matcher that receives every incoming message; the handler registered on it
# filters for reminder requests itself.
evt = on_message()

# JSON file that stores the reminders, located in the "data" directory four
# levels above this module.
DATA_FILE_PATH = Path(__file__).parent.parent.parent.parent / "data" / "notify.json"

# Make sure the "data" directory exists as soon as the plugin is imported.
DATA_FILE_PATH.parent.mkdir(exist_ok=True)

# Serialises every read-modify-write cycle on DATA_FILE_PATH.
DATA_FILE_LOCK = asyncio.Lock()
class Notify(BaseModel):
    """One scheduled reminder, as stored in the notify data file."""

    # Chat platform on which the reminder is delivered.
    platform: Literal["console", "qq", "discord"]

    # The individual who should receive the notification.
    target: str

    # Where the notification is sent; None means a private-chat notification.
    target_env: str | None

    # Moment at which the reminder becomes due.
    notify_time: datetime.datetime

    # Text of the reminder.
    notify_msg: str

    def get_str(self):
        """Return the string used to recognise this reminder in the data file.

        The key is built from target, target_env, platform and notify_time;
        notify_msg is not part of it.
        """
        key_parts = (self.target, self.target_env, self.platform, self.notify_time)
        return "-".join(str(part) for part in key_parts)
class NotifyConfigFile(BaseModel):
    """Root document that is serialised to and parsed from the notify data file."""

    # Version marker written into the file; nothing in the visible code reads it.
    version: int = 1
    # Reminders that have been scheduled and not yet removed from the file.
    notifies: list[Notify] = []
    # Reminders whose delivery attempt reported failure.
    unsent: list[Notify] = []
def load_notify_config() -> NotifyConfigFile:
    """Read the notify data file and return its parsed contents.

    Returns:
        The parsed configuration, or a fresh default ``NotifyConfigFile`` when
        the file does not exist or cannot be read/validated (the problem is
        logged as a warning in that case).
    """
    if DATA_FILE_PATH.exists():
        try:
            raw_json = DATA_FILE_PATH.read_text()
            return NotifyConfigFile.model_validate_json(raw_json)
        except Exception as e:
            # Best effort: a broken file must not stop the plugin from working.
            logger.warning(f"在解析 Notify 时遇到问题:{e}")
    return NotifyConfigFile()
def save_notify_config(config: NotifyConfigFile):
    """Persist ``config`` to the notify data file.

    The JSON is first written to a temporary sibling file and then moved over
    the real file. Writing in place would leave truncated JSON behind if the
    process stopped mid-write, and ``load_notify_config`` answers an
    unparsable file with an empty configuration, so every stored reminder
    would be lost on the next save.

    Args:
        config: The complete configuration to store.
    """
    payload = config.model_dump_json(indent=4)
    # Same directory as the target so the final move stays on one filesystem.
    tmp_path = DATA_FILE_PATH.with_name(DATA_FILE_PATH.name + ".tmp")
    tmp_path.write_text(payload)
    tmp_path.replace(DATA_FILE_PATH)
async def notify_now(notify: Notify):
    """Send the reminder described by ``notify`` immediately.

    The bot is picked from the currently connected bots by adapter type, and
    exactly one bot of the adapter matching ``notify.platform`` must be
    connected.

    Args:
        notify: The reminder to deliver.

    Returns:
        True once the send call has completed; False (after logging a
        warning) when no unique bot for the platform is connected or the
        platform is not handled here.
    """
    platform = notify.platform

    if platform == 'console':
        consoles = [b for b in nonebot.get_bots().values() if isinstance(b, ConsoleBot)]
        if len(consoles) != 1:
            logger.warning(f"提醒未成功发送出去:{nonebot.get_bots()} {notify}")
            return False
        console_bot = consoles[0]
        await console_bot.send_private_message(notify.target, f"代办通知:{notify.notify_msg}")
        return True

    if platform == 'discord':
        discords = [b for b in nonebot.get_bots().values() if isinstance(b, DiscordBot)]
        if len(discords) != 1:
            logger.warning(f"提醒未成功发送出去:{nonebot.get_bots()} {notify}")
            return False
        discord_bot = discords[0]
        # The reminder goes to a direct-message channel opened with the target.
        dm_channel = await discord_bot.create_DM(recipient_id=int(notify.target))
        await discord_bot.send_to(dm_channel.id, f"代办通知:{notify.notify_msg}")
        return True

    if platform == 'qq':
        onebots = [b for b in nonebot.get_bots().values() if isinstance(b, OnebotV11Bot)]
        if len(onebots) != 1:
            logger.warning(f"提醒未成功发送出去:{nonebot.get_bots()} {notify}")
            return False
        qq_bot = onebots[0]

        if notify.target_env is None:
            # No group recorded: deliver as a private message.
            await qq_bot.send_private_msg(
                user_id=int(notify.target),
                message=f"代办通知:{notify.notify_msg}",
            )
            return True

        # Group delivery: mention the target, then append the reminder text.
        group_id = int(notify.target_env)
        group_message = await (
            UniMessage().at(notify.target).text(f" 代办通知:{notify.notify_msg}").export()
        )
        await qq_bot.send_group_msg(
            group_id=group_id,
            message=cast(Any, group_message),
        )
        return True

    logger.warning(f"提醒未成功发送出去:{notify}")
    return False
async def create_notify_task(notify: Notify, fail2remove: bool = True):
    """Schedule ``notify`` as a background task and return that task.

    The task sleeps until ``notify.notify_time`` (not at all if that moment
    has already passed), delivers the reminder via ``notify_now`` and then
    updates the data file.

    Args:
        notify: The reminder to schedule.
        fail2remove: When True, a reminder whose delivery reported failure is
            removed from the pending list and appended to ``unsent``. When
            False, a failed reminder is left untouched in the data file.

    Returns:
        The ``asyncio.Task`` running the reminder.
    """

    async def mission():
        started_at = datetime.datetime.now()
        if started_at < notify.notify_time:
            await asyncio.sleep((notify.notify_time - started_at).total_seconds())

        res = await notify_now(notify)
        if not (fail2remove or res):
            # Delivery failed and the caller asked to keep failed reminders.
            return

        # `async with` guarantees the lock is released even if loading or
        # saving raises; a manual acquire()/release() pair would leave the
        # lock held forever and block every later reminder.
        async with DATA_FILE_LOCK:
            cfg = load_notify_config()
            own_key = notify.get_str()
            cfg.notifies = [n for n in cfg.notifies if n.get_str() != own_key]
            if not res:
                cfg.unsent.append(notify)
            save_notify_config(cfg)

    return asyncio.create_task(mission())
@evt.handle()
async def _(msg: UniMsg, mEvt: Event):
    """Turn a "<time>提醒我<text>" message into a stored, scheduled reminder.

    The message is ignored when it comes from one of the connected bots, does
    not contain the "提醒我" keyword, has no parsable time before the keyword,
    has no text after it, or arrives from an unsupported platform.

    Args:
        msg: The incoming message.
        mEvt: The event that carried the message; its concrete type decides
            the platform the reminder is delivered on.
    """
    if mEvt.get_user_id() in nonebot.get_bots():
        return

    text = msg.extract_plain_text()
    if "提醒我" not in text:
        return

    # Everything before the first keyword is the time expression, everything
    # after it is the reminder text.
    notify_time, _sep, notify_text = text.partition("提醒我")

    target_time = get_target_time(notify_time)
    if target_time is None:
        logger.info(f"无法从 {notify_time} 中解析出时间")
        return
    if not notify_text:
        return

    # Resolve the platform BEFORE taking the file lock. Previously the
    # unsupported-platform branch returned while the lock was still held,
    # which blocked every later handler call and reminder task.
    platform: Literal["console", "qq", "discord"]
    target_env: str | None = None
    if isinstance(mEvt, ConsoleMessageEvent):
        platform = "console"
    elif isinstance(mEvt, OnebotV11MessageEvent):
        platform = "qq"
        if isinstance(mEvt, OnebotV11GroupMessageEvent):
            # Remember the group so the reminder is posted back into it.
            target_env = str(mEvt.group_id)
    elif isinstance(mEvt, DiscordMessageEvent):
        platform = "discord"
    else:
        logger.warning(f"Notify 遇到不支持的平台:{type(mEvt).__name__}")
        return

    notify = Notify(
        platform=platform,
        target=mEvt.get_user_id(),
        target_env=target_env,
        notify_time=target_time,
        notify_msg=notify_text,
    )

    # The task is created while the lock is held, so its own file update
    # cannot run before the reminder has been written to the file.
    async with DATA_FILE_LOCK:
        cfg = load_notify_config()
        await create_notify_task(notify)
        cfg.notifies.append(notify)
        save_notify_config(cfg)

    await evt.send(await UniMessage().at(mEvt.get_user_id()).text(
        f" 了解啦!将会在 {notify.notify_time} 提醒你哦~").export())
# Driver object used below to register the bot-connect hook.
driver = nonebot.get_driver()
# Mutable module-level flag: "task_added" is set to True once the stored
# reminders have been scheduled, so later bot connections do not schedule
# them a second time.
NOTIFIED_FLAG: dict[str, bool] = {
    "task_added": False,
}
@driver.on_bot_connect
async def _():
    """Schedule every reminder stored in the data file, once per process.

    Runs on each bot connection, but only the first call does any work. The
    reminders are scheduled with ``fail2remove=False``, so a reminder whose
    delivery reports failure stays in the data file.
    """
    if NOTIFIED_FLAG["task_added"]:
        return
    NOTIFIED_FLAG["task_added"] = True

    # `async with` releases the lock even if reading the file raises; a
    # manual acquire()/release() pair would leave it held forever.
    async with DATA_FILE_LOCK:
        cfg = load_notify_config()

    # Scheduling happens outside the lock; each awaited call only creates a
    # background task and returns immediately.
    await asyncio.gather(
        *(create_notify_task(notify, fail2remove=False) for notify in cfg.notifies)
    )