339 lines
12 KiB
Python
339 lines
12 KiB
Python
import asyncio
|
||
import datetime
|
||
import re
|
||
from pathlib import Path
|
||
from typing import Any, Literal, cast
|
||
|
||
import nonebot
|
||
from loguru import logger
|
||
from nonebot import on_message
|
||
from nonebot.adapters import Event
|
||
from nonebot.adapters.console import Bot as ConsoleBot
|
||
from nonebot.adapters.console.event import MessageEvent as ConsoleMessageEvent
|
||
from nonebot.adapters.discord import Bot as DiscordBot
|
||
from nonebot.adapters.discord.event import MessageEvent as DiscordMessageEvent
|
||
from nonebot.adapters.onebot.v11 import Bot as OnebotV11Bot
|
||
from nonebot.adapters.onebot.v11.event import \
|
||
GroupMessageEvent as OnebotV11GroupMessageEvent
|
||
from nonebot.adapters.onebot.v11.event import \
|
||
MessageEvent as OnebotV11MessageEvent
|
||
from nonebot_plugin_alconna import UniMessage, UniMsg
|
||
from pydantic import BaseModel
|
||
|
||
# Relative delay such as "3分钟后", "一天两小时后" or "1天2个小时3分4秒后".
# Groups 2/4/6/8 hold the day/hour/minute/second amounts; each is optional.
PATTERN_DELTA_HMS = re.compile(r"^((\d+|[零一两二三四五六七八九十]+) ?天)?((\d+|[零一两二三四五六七八九十]+) ?个?小?时)?((\d+|[零一两二三四五六七八九十]+) ?分钟?)?((\d+|[零一两二三四五六七八九十]+) ?秒钟?)? ?后 ?$")

# Absolute-time components. Each captures its number in group 1, either as
# Arabic digits or as a small Chinese numeral (including the colloquial "两"),
# ready for parse_chinese_or_digit. Arabic widths allow two-digit hours,
# minutes and seconds ("10点", "30分") and up to four-digit years.
PATTERN_DATE_SPECIFY = re.compile(r"(\d{1,2}|[零一两二三四五六七八九十]+) ?[日号]")

PATTERN_MONTH_SPECIFY = re.compile(r"(\d{1,2}|[零一两二三四五六七八九十]+) ?月")

PATTERN_YEAR_SPECIFY = re.compile(r"(\d{1,4}|[零一两二三四五六七八九十]+) ?年")

# Group 2 is "半" ("half past") or the empty string; it always participates.
PATTERN_HOUR_SPECIFY = re.compile(r"(\d{1,2}|[零一两二三四五六七八九十]+) ?[点时](半?)钟?")

PATTERN_MINUTE_SPECIFY = re.compile(r"(\d{1,2}|[零一两二三四五六七八九十]+) ?分(钟)?")

PATTERN_SECOND_SPECIFY = re.compile(r"(\d{1,2}|[零一两二三四五六七八九十]+) ?秒(钟)?")

# Clock time "H:MM", "HH:MM" or "HH:MM:SS" with ASCII colons.
PATTERN_HMS_SPECIFY = re.compile(r"\d{1,2}:\d{2}(:\d{2})?")

# Afternoon / evening markers.
PATTERN_PM_SPECIFY = re.compile(r"(下午|PM|晚上)")
|
||
|
||
|
||
# Single Chinese digits. "十" is handled separately as the tens marker.
_CHINESE_DIGITS = {
    '零': 0, '一': 1, '二': 2, '三': 3, '四': 4,
    '五': 5, '六': 6, '七': 7, '八': 8, '九': 9,
}


def parse_chinese_or_digit(s: str) -> int:
    """Convert an Arabic or small Chinese numeral (0-99) to an int.

    Accepted forms:
      * ASCII digit strings, e.g. "12";
      * a single Chinese digit, e.g. "三" ("两" is treated as "二");
      * tens forms "十", "十五", "二十", "二十三";
      * any other string that ``int()`` accepts (e.g. full-width digits).

    Returns -1 when *s* cannot be parsed, including the empty string and
    malformed numerals such as "十十".
    """
    if not s:
        return -1

    if set(s) <= set("0123456789"):
        return int(s)

    s = s.replace("两", "二")

    if s in _CHINESE_DIGITS:
        return _CHINESE_DIGITS[s]

    # Tens forms: an optional tens digit, "十", then an optional units digit.
    # "十五" has no tens digit (meaning 1x); "二十" has no units digit.
    tens, sep, units = s.partition("十")
    if sep and len(tens) <= 1 and len(units) <= 1:
        if all(ch in _CHINESE_DIGITS for ch in tens + units):
            tens_value = _CHINESE_DIGITS[tens] if tens else 1
            units_value = _CHINESE_DIGITS[units] if units else 0
            return 10 * tens_value + units_value
        return -1

    try:
        return int(s)
    except ValueError:
        return -1
|
||
|
||
|
||
def _take(pattern: re.Pattern[str], text: str) -> tuple[re.Match[str] | None, str]:
    """Find the first match of *pattern* anywhere in *text*.

    Returns ``(match, rest)``, where *rest* is *text* with the matched span cut
    out, or ``(None, text)`` when the pattern does not occur.
    """
    match = pattern.search(text)
    if match is None:
        return None, text
    return match, text[:match.start()] + text[match.end():]


def _relative_time(match: re.Match[str]) -> datetime.datetime | None:
    """Return ``now + delay`` for a PATTERN_DELTA_HMS match.

    Returns None if no amount was given or any amount is unparsable.
    """
    amounts = [match.group(i) for i in (2, 4, 6, 8)]
    if all(a is None for a in amounts):
        # A bare "后" names no amount of time at all.
        return None
    days, hours, minutes, seconds = (parse_chinese_or_digit(a or "0") for a in amounts)
    if min(days, hours, minutes, seconds) < 0:
        # parse_chinese_or_digit returns -1 for numerals it cannot read.
        return None
    return datetime.datetime.now() + datetime.timedelta(
        days=days, hours=hours, minutes=minutes, seconds=seconds)


def _absolute_time(text: str) -> datetime.datetime | None:
    """Resolve an absolute time description against the current time.

    Returns None on out-of-range values or leftover text.
    """
    t = datetime.datetime.now()

    # Day words: the first one present, in this priority order, wins.
    for word, days_ahead in (("明天", 1), ("后天", 2), ("今天", 0)):
        if word in text:
            text = text.replace(word, "")
            t += datetime.timedelta(days=days_ahead)
            break

    match, text = _take(PATTERN_DATE_SPECIFY, text)
    if match:
        day = parse_chinese_or_digit(match.group(1))
        if not 1 <= day <= 31:
            return None
        if day >= t.day:
            t = t.replace(day=day)
        elif t.month == 12:
            # That day already passed this month: roll over into next January.
            t = t.replace(year=t.year + 1, month=1, day=day)
        else:
            t = t.replace(month=t.month + 1, day=day)

    match, text = _take(PATTERN_MONTH_SPECIFY, text)
    if match:
        month = parse_chinese_or_digit(match.group(1))
        if not 1 <= month <= 12:
            return None
        # A month earlier than the current one refers to next year.
        year = t.year + 1 if month < t.month else t.year
        t = t.replace(year=year, month=month)

    match, text = _take(PATTERN_YEAR_SPECIFY, text)
    if match:
        year = parse_chinese_or_digit(match.group(1))
        if year < 0:
            return None
        if year < 100:
            year += 2000  # "25年" -> 2025
        if year < t.year:
            return None
        t = t.replace(year=year)

    match, text = _take(PATTERN_HOUR_SPECIFY, text)
    if match:
        hour = parse_chinese_or_digit(match.group(1))
        if not 0 <= hour <= 23:
            return None
        # Group 2 is "半" or "" and always participates in the match.
        # Test its truthiness: comparing it with None made every hour "half past".
        t = t.replace(hour=hour, minute=30 if match.group(2) else 0, second=0)

    match, text = _take(PATTERN_MINUTE_SPECIFY, text)
    if match:
        minute = parse_chinese_or_digit(match.group(1))
        if not 0 <= minute <= 59:
            return None
        t = t.replace(minute=minute, second=0)

    match, text = _take(PATTERN_SECOND_SPECIFY, text)
    if match:
        second = parse_chinese_or_digit(match.group(1))
        if not 0 <= second <= 59:
            return None
        t = t.replace(second=second)

    match, text = _take(PATTERN_HMS_SPECIFY, text)
    if match:
        fields = [int(f) for f in match.group(0).split(":")]
        hour, minute = fields[0], fields[1]
        second = fields[2] if len(fields) > 2 else 0
        if not (0 <= hour <= 23 and 0 <= minute <= 59 and 0 <= second <= 59):
            return None
        t = t.replace(hour=hour, minute=minute, second=second)

    # Morning markers carry no extra information and are simply dropped.
    for marker in ("上午", "AM", "凌晨"):
        text = text.replace(marker, "")

    match, text = _take(PATTERN_PM_SPECIFY, text)
    if match:
        if t.hour < 12:
            t = t.replace(hour=t.hour + 12)
        # Sequential checks, not elif: with a PM marker, both 0 o'clock and
        # 12 o'clock end up at the following midnight.
        if t.hour == 12:
            t += datetime.timedelta(hours=12)

    if text.strip():
        return None  # unrecognised words remain

    if t < datetime.datetime.now():
        t += datetime.timedelta(days=1)
    return t


def get_target_time(content: str) -> datetime.datetime | None:
    """Parse a Chinese time expression into the datetime it refers to.

    Two forms are understood:

    * A relative delay matching PATTERN_DELTA_HMS, e.g. "3分钟后" or
      "一天两小时后"; the result is now plus that delay.
    * An absolute moment assembled from optional pieces, in any order:
      - "明天" / "后天" / "今天";
      - "N年", "N月", "N日" / "N号";
      - "N点" / "N点半", "N分", "N秒";
      - "H:MM[:SS]" (full-width colons accepted);
      - "下午" / "PM" / "晚上".
      Unspecified fields keep the current values. A result in the past is
      pushed forward by one day.

    Returns None when *content* is empty, has leftover unrecognised text, has
    an out-of-range number, or names an impossible calendar date.
    """
    if not content.strip():
        # A whitespace-only time part names no time at all.
        return None
    try:
        if match := PATTERN_DELTA_HMS.match(content.strip()):
            return _relative_time(match)
        # Normalise the full-width colon that Chinese input methods produce.
        return _absolute_time(content.replace("：", ":"))
    except (ValueError, OverflowError):
        # datetime.replace() rejects impossible dates (e.g. "2月30日" or "31号"
        # in a 30-day month), and huge delays overflow datetime's range.
        return None
|
||
|
||
|
||
evt = on_message()

# Plugin data lives in a "data" directory four levels above this source file.
_DATA_DIR = Path(__file__).parent.parent.parent.parent / "data"
_DATA_DIR.mkdir(exist_ok=True)

DATA_FILE_PATH = _DATA_DIR / "notify.json"

# Guards every read-modify-write of DATA_FILE_PATH performed in this module.
DATA_FILE_LOCK = asyncio.Lock()
|
||
|
||
|
||
class Notify(BaseModel):
    """One scheduled reminder, as stored in the reminder data file."""

    # Adapter that will deliver the reminder.
    platform: Literal["console", "qq", "discord"]

    # The individual who should receive the notification (their user id).
    target: str

    # Where to deliver the notification; None means a private message.
    target_env: str | None

    # When the reminder should fire; compared against naive datetime.now().
    notify_time: datetime.datetime

    # The reminder text the user asked for.
    notify_msg: str

    def get_str(self):
        """Return the key used to identify this reminder inside the data file."""
        parts = (self.target, self.target_env, self.platform, self.notify_time)
        return "-".join(map(str, parts))
|
||
|
||
|
||
class NotifyConfigFile(BaseModel):
    """Top-level document persisted in the reminder data file."""

    # Format version of the data file.
    version: int = 1

    # Reminders that have not fired yet.
    notifies: list[Notify] = []

    # Reminders that came due but could not be delivered.
    # (pydantic copies mutable defaults per instance, so a literal [] is safe.)
    unsent: list[Notify] = []
|
||
|
||
|
||
def _read_data_file() -> str:
    """Return the raw text of DATA_FILE_PATH.

    The file is written as UTF-8. A file that is not valid UTF-8 is read with
    the platform's default encoding instead, because files saved without an
    explicit encoding used that default.
    """
    try:
        return DATA_FILE_PATH.read_text(encoding="utf-8")
    except UnicodeDecodeError:
        return DATA_FILE_PATH.read_text()


def load_notify_config() -> NotifyConfigFile:
    """Load the reminder data file.

    Returns an empty NotifyConfigFile when the file does not exist, or when it
    cannot be read or parsed. A warning is logged in the latter case.
    """
    if not DATA_FILE_PATH.exists():
        return NotifyConfigFile()
    try:
        return NotifyConfigFile.model_validate_json(_read_data_file())
    except Exception as e:
        logger.warning(f"在解析 Notify 时遇到问题:{e}")
        return NotifyConfigFile()


def save_notify_config(config: NotifyConfigFile):
    """Write *config* to DATA_FILE_PATH as indented UTF-8 JSON.

    The explicit encoding avoids UnicodeEncodeError on platforms whose default
    encoding cannot represent every character (e.g. emoji under a GBK locale).
    The JSON goes to a sibling temporary file first and is then moved over the
    real file, so an interrupted write leaves the previous file intact. A
    truncated file would otherwise be discarded wholesale by
    load_notify_config.
    """
    tmp_path = DATA_FILE_PATH.with_name(DATA_FILE_PATH.name + ".tmp")
    tmp_path.write_text(config.model_dump_json(indent=4), encoding="utf-8")
    tmp_path.replace(DATA_FILE_PATH)
|
||
|
||
|
||
async def notify_now(notify: Notify):
|
||
if notify.platform == 'console':
|
||
bot = [b for b in nonebot.get_bots().values() if isinstance(b, ConsoleBot)]
|
||
if len(bot) != 1:
|
||
logger.warning(f"提醒未成功发送出去:{nonebot.get_bots()} {notify}")
|
||
return False
|
||
bot = bot[0]
|
||
await bot.send_private_message(notify.target, f"代办通知:{notify.notify_msg}")
|
||
elif notify.platform == 'discord':
|
||
bot = [b for b in nonebot.get_bots().values() if isinstance(b, DiscordBot)]
|
||
if len(bot) != 1:
|
||
logger.warning(f"提醒未成功发送出去:{nonebot.get_bots()} {notify}")
|
||
return False
|
||
bot = bot[0]
|
||
channel = await bot.create_DM(recipient_id=int(notify.target))
|
||
await bot.send_to(channel.id, f"代办通知:{notify.notify_msg}")
|
||
elif notify.platform == 'qq':
|
||
bot = [b for b in nonebot.get_bots().values() if isinstance(b, OnebotV11Bot)]
|
||
if len(bot) != 1:
|
||
logger.warning(f"提醒未成功发送出去:{nonebot.get_bots()} {notify}")
|
||
return False
|
||
bot = bot[0]
|
||
if notify.target_env is None:
|
||
await bot.send_private_msg(
|
||
user_id=int(notify.target),
|
||
message=f"代办通知:{notify.notify_msg}",
|
||
)
|
||
else:
|
||
await bot.send_group_msg(
|
||
group_id=int(notify.target_env),
|
||
message=cast(Any,
|
||
await UniMessage().at(notify.target).text(f" 代办通知:{notify.notify_msg}").export()
|
||
),
|
||
)
|
||
else:
|
||
logger.warning(f"提醒未成功发送出去:{notify}")
|
||
return False
|
||
return True
|
||
|
||
|
||
# Pending reminder tasks keyed by Notify.get_str().
# asyncio's event loop keeps only weak references to tasks, so an unreferenced
# task can be garbage-collected before it fires. The key also keeps the same
# reminder from being scheduled twice, e.g. when several bots connect and each
# connection reloads the data file.
_SCHEDULED_TASKS: dict[str, asyncio.Task] = {}


async def create_notify_task(notify: Notify, fail2remove: bool = True):
    """Schedule *notify* to be delivered at its notify_time.

    After delivery the reminder is removed from the data file. On failure,
    *fail2remove* decides what happens:
      * True: the reminder is moved to the ``unsent`` list.
      * False: it is left in the file untouched, so a later reload (on the
        next bot connection) tries it again.

    Returns the asyncio.Task running the reminder. If an identical reminder is
    already pending, that existing task is returned instead of a new one.
    """
    key = notify.get_str()
    pending = _SCHEDULED_TASKS.get(key)
    if pending is not None and not pending.done():
        return pending

    async def mission():
        delay = (notify.notify_time - datetime.datetime.now()).total_seconds()
        if delay > 0:
            await asyncio.sleep(delay)

        try:
            res = await notify_now(notify)
        except Exception:
            # A failing adapter call counts as a failed delivery rather than
            # silently killing the task.
            logger.exception(f"提醒发送时出现异常:{notify}")
            res = False

        if not (fail2remove or res):
            return  # keep it on file for a later retry

        # `async with` releases the lock even if loading or saving raises.
        async with DATA_FILE_LOCK:
            cfg = load_notify_config()
            cfg.notifies = [n for n in cfg.notifies if n.get_str() != key]
            if not res:
                cfg.unsent.append(notify)
            save_notify_config(cfg)

    def forget(done: asyncio.Task) -> None:
        # Only drop our own entry; a newer task may already have replaced it.
        if _SCHEDULED_TASKS.get(key) is done:
            del _SCHEDULED_TASKS[key]

    task = asyncio.create_task(mission())
    _SCHEDULED_TASKS[key] = task
    task.add_done_callback(forget)
    return task
|
||
|
||
|
||
@evt.handle()
async def _(msg: UniMsg, mEvt: Event):
    """Turn a "<time>提醒我<text>" chat message into a scheduled reminder.

    The part before "提醒我" is parsed by get_target_time; the part after it
    (stripped of surrounding whitespace) becomes the reminder text.

    The message is ignored when:
      * "提醒我" does not appear exactly once;
      * the time cannot be parsed;
      * the reminder text is blank;
      * the event comes from an unsupported platform.
    """
    text = msg.extract_plain_text()
    segments = text.split("提醒我")
    if len(segments) != 2:
        # Either no trigger phrase at all, or an ambiguous repeated one.
        return

    notify_time, notify_text = segments
    target_time = get_target_time(notify_time)
    if target_time is None:
        logger.info(f"无法从 {notify_time} 中解析出时间")
        return
    notify_text = notify_text.strip()
    if not notify_text:
        return

    # Work out the delivery platform before touching the data file. The old
    # code returned here while still holding DATA_FILE_LOCK, which deadlocked
    # every later reminder.
    target_env: str | None = None
    if isinstance(mEvt, ConsoleMessageEvent):
        platform = "console"
    elif isinstance(mEvt, OnebotV11MessageEvent):
        platform = "qq"
        if isinstance(mEvt, OnebotV11GroupMessageEvent):
            target_env = str(mEvt.group_id)
    elif isinstance(mEvt, DiscordMessageEvent):
        platform = "discord"
    else:
        logger.warning(f"Notify 遇到不支持的平台:{type(mEvt).__name__}")
        return

    notify = Notify(
        platform=platform,
        target=mEvt.get_user_id(),
        target_env=target_env,
        notify_time=target_time,
        notify_msg=notify_text,
    )

    # Persist first, with the lock released on any exception, then schedule.
    async with DATA_FILE_LOCK:
        cfg = load_notify_config()
        cfg.notifies.append(notify)
        save_notify_config(cfg)

    await create_notify_task(notify)

    await evt.send(await UniMessage().at(mEvt.get_user_id()).text(
        f" 了解啦!将会在 {notify.notify_time} 提醒你哦~").export())
|
||
|
||
|
||
driver = nonebot.get_driver()


@driver.on_bot_connect
async def _():
    """Re-schedule every reminder stored in the data file when a bot connects.

    Reminders scheduled here use ``fail2remove=False``. If delivery fails (for
    example, because the bot for that platform is not connected yet), the
    entry stays in the file and is scheduled again on a later connection.
    """
    # `async with` guarantees the lock is released even if loading raises.
    async with DATA_FILE_LOCK:
        cfg = load_notify_config()

    await asyncio.gather(
        *(create_notify_task(notify, fail2remove=False) for notify in cfg.notifies)
    )
|