在我写的模块采用更安全的 asyncio 锁写法
All checks were successful
continuous-integration/drone/push Build is passing
continuous-integration/drone/tag Build is passing

This commit is contained in:
2025-10-19 20:27:18 +08:00
parent fd4c9302c2
commit 67382a0c0a
2 changed files with 41 additions and 31 deletions

View File

@ -254,11 +254,10 @@ def _save_longtask_data(data: LongTaskModuleData):
@asynccontextmanager @asynccontextmanager
async def longtask_data(): async def longtask_data():
await longtask_lock.acquire() async with longtask_lock:
data = _load_longtask_data() data = _load_longtask_data()
yield data yield data
_save_longtask_data(data) _save_longtask_data(data)
longtask_lock.release()
async def create_longtask( async def create_longtask(

View File

@ -15,10 +15,10 @@ from nonebot.adapters.console.event import MessageEvent as ConsoleMessageEvent
from nonebot.adapters.discord import Bot as DiscordBot from nonebot.adapters.discord import Bot as DiscordBot
from nonebot.adapters.discord.event import MessageEvent as DiscordMessageEvent from nonebot.adapters.discord.event import MessageEvent as DiscordMessageEvent
from nonebot.adapters.onebot.v11 import Bot as OnebotV11Bot from nonebot.adapters.onebot.v11 import Bot as OnebotV11Bot
from nonebot.adapters.onebot.v11.event import \ from nonebot.adapters.onebot.v11.event import (
GroupMessageEvent as OnebotV11GroupMessageEvent GroupMessageEvent as OnebotV11GroupMessageEvent,
from nonebot.adapters.onebot.v11.event import \ )
MessageEvent as OnebotV11MessageEvent from nonebot.adapters.onebot.v11.event import MessageEvent as OnebotV11MessageEvent
from nonebot_plugin_alconna import UniMessage, UniMsg from nonebot_plugin_alconna import UniMessage, UniMsg
from pydantic import BaseModel from pydantic import BaseModel
@ -68,14 +68,14 @@ def save_notify_config(config: NotifyConfigFile):
async def notify_now(notify: Notify): async def notify_now(notify: Notify):
if notify.platform == 'console': if notify.platform == "console":
bot = [b for b in nonebot.get_bots().values() if isinstance(b, ConsoleBot)] bot = [b for b in nonebot.get_bots().values() if isinstance(b, ConsoleBot)]
if len(bot) != 1: if len(bot) != 1:
logger.warning(f"提醒未成功发送出去:{nonebot.get_bots()} {notify}") logger.warning(f"提醒未成功发送出去:{nonebot.get_bots()} {notify}")
return False return False
bot = bot[0] bot = bot[0]
await bot.send_private_message(notify.target, f"代办通知:{notify.notify_msg}") await bot.send_private_message(notify.target, f"代办通知:{notify.notify_msg}")
elif notify.platform == 'discord': elif notify.platform == "discord":
bot = [b for b in nonebot.get_bots().values() if isinstance(b, DiscordBot)] bot = [b for b in nonebot.get_bots().values() if isinstance(b, DiscordBot)]
if len(bot) != 1: if len(bot) != 1:
logger.warning(f"提醒未成功发送出去:{nonebot.get_bots()} {notify}") logger.warning(f"提醒未成功发送出去:{nonebot.get_bots()} {notify}")
@ -83,7 +83,7 @@ async def notify_now(notify: Notify):
bot = bot[0] bot = bot[0]
channel = await bot.create_DM(recipient_id=int(notify.target)) channel = await bot.create_DM(recipient_id=int(notify.target))
await bot.send_to(channel.id, f"代办通知:{notify.notify_msg}") await bot.send_to(channel.id, f"代办通知:{notify.notify_msg}")
elif notify.platform == 'qq': elif notify.platform == "qq":
bot = [b for b in nonebot.get_bots().values() if isinstance(b, OnebotV11Bot)] bot = [b for b in nonebot.get_bots().values() if isinstance(b, OnebotV11Bot)]
if len(bot) != 1: if len(bot) != 1:
logger.warning(f"提醒未成功发送出去:{nonebot.get_bots()} {notify}") logger.warning(f"提醒未成功发送出去:{nonebot.get_bots()} {notify}")
@ -92,17 +92,22 @@ async def notify_now(notify: Notify):
if notify.target_env is None: if notify.target_env is None:
await bot.send_private_msg( await bot.send_private_msg(
user_id=int(notify.target), user_id=int(notify.target),
message=cast(Any, await UniMessage.text(f"代办通知:{notify.notify_msg}").export( message=cast(
Any,
await UniMessage.text(f"代办通知:{notify.notify_msg}").export(
bot=bot, bot=bot,
)), ),
),
) )
else: else:
await bot.send_group_msg( await bot.send_group_msg(
group_id=int(notify.target_env), group_id=int(notify.target_env),
message=cast(Any, message=cast(
await UniMessage().at( Any,
notify.target await UniMessage()
).text(f" 代办通知:{notify.notify_msg}").export(bot=bot) .at(notify.target)
.text(f" 代办通知:{notify.notify_msg}")
.export(bot=bot),
), ),
) )
else: else:
@ -127,15 +132,17 @@ def create_notify_task(notify: Notify, fail2remove: bool = True):
) )
res = await notify_now(notify) res = await notify_now(notify)
if fail2remove or res: if fail2remove or res:
await DATA_FILE_LOCK.acquire() async with DATA_FILE_LOCK:
cfg = load_notify_config() cfg = load_notify_config()
cfg.notifies = [n for n in cfg.notifies if n.get_str() != notify.get_str()] cfg.notifies = [
n for n in cfg.notifies if n.get_str() != notify.get_str()
]
if not res: if not res:
cfg.unsent.append(notify) cfg.unsent.append(notify)
save_notify_config(cfg) save_notify_config(cfg)
DATA_FILE_LOCK.release()
else: else:
pass pass
return asynkio.create_task(mission()) return asynkio.create_task(mission())
@ -201,8 +208,12 @@ async def _(msg: UniMsg, mEvt: Event):
save_notify_config(cfg) save_notify_config(cfg)
DATA_FILE_LOCK.release() DATA_FILE_LOCK.release()
await evt.send(await UniMessage().at(mEvt.get_user_id()).text( await evt.send(
f" 了解啦!将会在 {notify.notify_time} 提醒你哦~").export()) await UniMessage()
.at(mEvt.get_user_id())
.text(f" 了解啦!将会在 {notify.notify_time} 提醒你哦~")
.export()
)
logger.info(f"创建了一条于 {notify.notify_time} 的代办提醒") logger.info(f"创建了一条于 {notify.notify_time} 的代办提醒")
@ -253,8 +264,8 @@ async def _():
logger.info("所有的代办提醒 Task 都已经退出了") logger.info("所有的代办提醒 Task 都已经退出了")
for sig in (signal.SIGINT, signal.SIGTERM): for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(sig, functools.partial( loop.add_signal_handler(
asynkio.create_task, shutdown(sig) sig, functools.partial(asynkio.create_task, shutdown(sig))
)) )
await asynkio.gather(*ASYNK_TASKS) await asynkio.gather(*ASYNK_TASKS)