diff --git a/konabot/common/path.py b/konabot/common/path.py index 6bea695..e81760d 100644 --- a/konabot/common/path.py +++ b/konabot/common/path.py @@ -12,3 +12,10 @@ DOCS_PATH_MAN1 = DOCS_PATH / "user" DOCS_PATH_MAN3 = DOCS_PATH / "lib" DOCS_PATH_MAN7 = DOCS_PATH / "concepts" DOCS_PATH_MAN8 = DOCS_PATH / "sys" + +if not DATA_PATH.exists(): + DATA_PATH.mkdir() + +if not LOG_PATH.exists(): + LOG_PATH.mkdir() + diff --git a/konabot/plugins/poll/__init__.py b/konabot/plugins/poll/__init__.py index 266a1c6..8885d2a 100644 --- a/konabot/plugins/poll/__init__.py +++ b/konabot/plugins/poll/__init__.py @@ -1,16 +1,19 @@ import json, time -from nonebot_plugin_alconna import (Alconna, Args, Field, MultiVar, UniMessage, - on_alconna) -from nonebot_plugin_alconna.uniseg import UniMsg, At, Reply +from nonebot_plugin_alconna import Alconna, Args, Field, MultiVar, on_alconna from nonebot.adapters.onebot.v11 import Event -poll_json_path = "assets/json/poll.json" +from konabot.common.path import ASSETS_PATH, DATA_PATH -poll_file = open(poll_json_path,"r",encoding="utf-8") -poll_list_raw = poll_file.read() -poll_file.close() -poll_list = json.loads(poll_list_raw)['poll'] + +POLL_TEMPLATE_FILE = ASSETS_PATH / "json" / "poll.json" +POLL_DATA_FILE = DATA_PATH / "poll.json" + +if not POLL_DATA_FILE.exists(): + POLL_DATA_FILE.write_bytes(POLL_TEMPLATE_FILE.read_bytes()) + + +poll_list = json.loads(POLL_DATA_FILE.read_text())['poll'] async def createpoll(title,qqid,options): polllength = len(poll_list) @@ -44,8 +47,11 @@ def getpolldata(pollid_or_title): return [thepoll,polnum] def writeback(): - file = open(poll_json_path,"w",encoding="utf-8") - json.dump({'poll':poll_list},file,ensure_ascii=False,sort_keys=True) + # file = open(poll_json_path,"w",encoding="utf-8") + # json.dump({'poll':poll_list},file,ensure_ascii=False,sort_keys=True) + POLL_DATA_FILE.write_text(json.dumps({ + 'poll': poll_list, + }, ensure_ascii=False, sort_keys=True)) async def pollvote(polnum,optionnum,qqnum): optiond = poll_list[polnum]["polldata"] @@ -157,4 +163,4 @@ async def _(saying: list, event: Event): # 写入项目 else: await pollvote(polnum,optionnum,event.get_user_id()) - await viewpoll.send("投票成功!你投给了 "+saying[1]) \ No newline at end of file + await viewpoll.send("投票成功!你投给了 "+saying[1]) diff --git a/konabot/plugins/simple_notify/__init__.py b/konabot/plugins/simple_notify/__init__.py index 3f0d964..8f0b5d3 100644 --- a/konabot/plugins/simple_notify/__init__.py +++ b/konabot/plugins/simple_notify/__init__.py @@ -1,8 +1,10 @@ -import asyncio +import asyncio as asynkio import datetime +import functools from pathlib import Path from typing import Any, Literal, cast +import signal import nonebot import ptimeparse from loguru import logger @@ -24,7 +26,9 @@ evt = on_message() (Path(__file__).parent.parent.parent.parent / "data").mkdir(exist_ok=True) DATA_FILE_PATH = Path(__file__).parent.parent.parent.parent / "data" / "notify.json" -DATA_FILE_LOCK = asyncio.Lock() +DATA_FILE_LOCK = asynkio.Lock() + +ASYNK_TASKS: set[asynkio.Task[Any]] = set() class Notify(BaseModel): @@ -111,7 +115,11 @@ def create_notify_task(notify: Notify, fail2remove: bool = True): async def mission(): begin_time = datetime.datetime.now() if begin_time < notify.notify_time: - await asyncio.sleep((notify.notify_time - begin_time).total_seconds()) + try: + await asynkio.sleep((notify.notify_time - begin_time).total_seconds()) + except asynkio.CancelledError: + logger.debug("代办提醒被信号中止,任务退出") + return else: logger.warning( f"期望在 {notify.notify_time} 在平台 {notify.platform} {notify.target_env}" @@ -128,7 +136,7 @@ def create_notify_task(notify: Notify, fail2remove: bool = True): DATA_FILE_LOCK.release() else: pass - return asyncio.create_task(mission()) + return asynkio.create_task(mission()) @evt.handle() @@ -214,11 +222,11 @@ async def _(): DELTA = 2 logger.info(f"第一次探测到 Bot 连接,等待 {DELTA} 秒后开始通知") - await asyncio.sleep(DELTA) + await asynkio.sleep(DELTA) await DATA_FILE_LOCK.acquire() - tasks: set[asyncio.Task[Any]] = set() + # tasks: set[asynkio.Task[Any]] = set() cfg = load_notify_config() if cfg.version == 1: logger.info("将配置文件的版本升级为 2") @@ -227,11 +235,26 @@ async def _(): counter = 0 for notify in [*cfg.notifies]: task = create_notify_task(notify, fail2remove=False) - tasks.add(task) - task.add_done_callback(lambda self: tasks.remove(self)) + ASYNK_TASKS.add(task) + task.add_done_callback(lambda self: ASYNK_TASKS.remove(self)) counter += 1 logger.info(f"成功创建了 {counter} 条代办事项") save_notify_config(cfg) DATA_FILE_LOCK.release() - await asyncio.gather(*tasks) + loop = asynkio.get_running_loop() + + # 解决 asynk task 没有被 cancel 的问题 + async def shutdown(sig: signal.Signals): + logger.info(f"收到 {sig.name} 指令,正在关闭所有的东西") + for task in ASYNK_TASKS: + task.cancel() + await asynkio.gather(*ASYNK_TASKS, return_exceptions=True) + logger.info("所有的代办提醒 Task 都已经退出了") + + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler(sig, functools.partial( + asynkio.create_task, shutdown(sig) + )) + + await asynkio.gather(*ASYNK_TASKS)