228 lines
8.1 KiB
Python
228 lines
8.1 KiB
Python
from io import BytesIO
|
||
from pathlib import Path
|
||
from typing import Annotated
|
||
|
||
import httpx
|
||
import PIL.Image
|
||
from loguru import logger
|
||
import nonebot
|
||
from nonebot.matcher import Matcher
|
||
from nonebot.adapters import Bot, Event, Message
|
||
from nonebot.adapters.discord import Bot as DiscordBot
|
||
from nonebot.adapters.discord import MessageEvent as DiscordMessageEvent
|
||
from nonebot.adapters.discord.config import Config as DiscordConfig
|
||
from nonebot.adapters.onebot.v11 import Bot as OnebotV11Bot
|
||
from nonebot.adapters.onebot.v11 import Message as OnebotV11Message
|
||
from nonebot.adapters.onebot.v11 import MessageEvent as OnebotV11MessageEvent
|
||
import nonebot.params
|
||
from nonebot_plugin_alconna import Image, RefNode, Reply, UniMessage
|
||
from PIL import UnidentifiedImageError
|
||
from pydantic import BaseModel
|
||
from returns.result import Failure, Result, Success
|
||
|
||
|
||
# Discord adapter settings. In this module they supply `discord_proxy`, which is
# used when downloading image attachments from Discord events.
discordConfig: DiscordConfig = nonebot.get_plugin_config(DiscordConfig)
class ExtractImageConfig(BaseModel):
    """Settings for the image-extraction helpers, loaded with ``get_plugin_config``."""

    # "Forget it, don't download, just blow up": when True, no network download
    # is attempted and the local image at `module_extract_image_target` is served
    # instead. Meant for unusual network setups where files cannot be downloaded
    # from the protocol endpoint.
    module_extract_image_no_download: bool = False

    # Which image file to serve when downloading is disabled.
    module_extract_image_target: str = './assets/img/other/boom.jpg'
# Parsed plugin settings; consulted by download_image_bytes to decide whether to
# download at all or serve the configured local image instead.
module_config: ExtractImageConfig = nonebot.get_plugin_config(ExtractImageConfig)
async def download_image_bytes(url: str, proxy: str | None = None) -> Result[bytes, str]:
    """Download an image and return its raw bytes.

    When ``module_extract_image_no_download`` is enabled, no request is made and
    the local file at ``module_extract_image_target`` is returned instead.

    Args:
        url: Address of the image.
        proxy: Optional proxy URL handed to ``httpx.AsyncClient``.

    Returns:
        ``Success(bytes)`` with the response body, or ``Failure(str)`` carrying a
        user-facing error message. Network errors never propagate as exceptions.
    """
    if module_config.module_extract_image_no_download:
        target = Path(module_config.module_extract_image_target)
        try:
            return Success(target.read_bytes())
        except OSError as e:
            # A missing/unreadable placeholder must not crash the handler.
            logger.warning(f"读取替代图片 {target} 时出错:{e}")
            return Failure("由于一些内部问题,下载图片失败了orz")

    logger.debug(f"开始从 {url} 下载图片")

    # httpx does not follow redirects unless asked to; image hosts often redirect.
    async with httpx.AsyncClient(proxy=proxy, follow_redirects=True) as client:
        try:
            response = await client.get(url)
        except httpx.TimeoutException:
            # Covers connect, read, write and pool timeouts alike.
            return Failure("下载图片失败了,网络超时了qwq")
        except (httpx.HTTPError, httpx.InvalidURL) as e:
            # Any other transport/protocol failure (connect error, proxy error,
            # remote protocol error, malformed URL, ...).
            return Failure(f"HTTPX 模块下载图片时出错:{e}")

    if response.status_code != 200:
        logger.warning(f"下载图片 {url} 时收到 HTTP {response.status_code}")
        return Failure("无法下载图片,可能存在网络问题需要排查")

    return Success(response.content)
def bytes_to_pil(raw_data: bytes | BytesIO) -> Result[PIL.Image.Image, str]:
    """Validate raw image data and open it as a Pillow image.

    The data is checked with ``Image.verify()`` on a throwaway instance first,
    then opened again for the caller, because a verified image cannot be used.

    Args:
        raw_data: Encoded image data, as bytes or an in-memory buffer. A buffer
            is always read from its start, whatever its current position.

    Returns:
        ``Success(Image)`` with a freshly opened image, or ``Failure(str)`` with a
        user-facing message if the data is unrecognised, too large or corrupted.
    """
    # Normalise to bytes so the probe and the returned image read identical data.
    data = raw_data.getvalue() if isinstance(raw_data, BytesIO) else raw_data

    try:
        # The probe is closed right after verify(); it is unusable afterwards anyway.
        with PIL.Image.open(BytesIO(data)) as probe:
            probe.verify()
        img = PIL.Image.open(BytesIO(data))
    except UnidentifiedImageError:
        return Failure("图像无法读取,可能是格式不支持orz")
    except PIL.Image.DecompressionBombError:
        # Raised by Image.open for images far above Image.MAX_IMAGE_PIXELS.
        return Failure("图像太大了,无法处理orz")
    except (OSError, SyntaxError):
        # Truncated/corrupted data; some Pillow plugins (e.g. PNG) report a
        # broken file from verify() as SyntaxError rather than OSError.
        return Failure("图像无法读取,可能是网络存在问题orz")

    return Success(img)
async def unimsg_img_to_bytes(image: Image) -> Result[bytes, str]:
    """Get the raw bytes behind a UniMessage ``Image`` segment.

    A segment URL is downloaded when present. Otherwise the embedded ``raw``
    payload is used (bytes as-is, or the full contents of a BytesIO buffer).

    Returns:
        ``Success(bytes)`` or ``Failure(str)`` with a user-facing message.
    """
    if image.url is not None:
        return await download_image_bytes(image.url)

    payload = image.raw
    if payload is None:
        # Neither a URL nor inline data: nothing to read.
        return Failure("由于一些内部问题,下载图片失败了orz")

    return Success(payload if isinstance(payload, bytes) else payload.getvalue())
async def unimsg_img_to_pil(image: Image) -> Result[PIL.Image.Image, str]:
    """Fetch a UniMessage ``Image`` segment and decode it with Pillow.

    Failures from either the fetch or the decode step are passed through unchanged.
    """
    fetched = await unimsg_img_to_bytes(image)
    return fetched.bind(bytes_to_pil)
async def extract_image_from_qq_message(
|
||
msg: OnebotV11Message,
|
||
evt: OnebotV11MessageEvent,
|
||
bot: OnebotV11Bot,
|
||
allow_reply: bool = True,
|
||
) -> Result[bytes, str]:
|
||
if allow_reply and (reply := evt.reply) is not None:
|
||
return await extract_image_from_qq_message(
|
||
reply.message,
|
||
evt,
|
||
bot,
|
||
False,
|
||
)
|
||
for seg in msg:
|
||
if seg.type == "reply" and allow_reply:
|
||
msgid = seg.data.get("id")
|
||
if msgid is None:
|
||
return Failure("消息可能太久远,无法读取到消息原文")
|
||
try:
|
||
msg2 = await bot.get_msg(message_id=msgid)
|
||
except Exception as e:
|
||
logger.warning(f"获取消息内容时出错:{e}")
|
||
return Failure("消息可能太久远,无法读取到消息原文")
|
||
msg2_data = msg2.get("message")
|
||
if msg2_data is None:
|
||
return Failure("消息可能太久远,无法读取到消息原文")
|
||
logger.debug("发现消息引用,递归一层")
|
||
return await extract_image_from_qq_message(
|
||
msg=OnebotV11Message(msg2_data),
|
||
evt=evt,
|
||
bot=bot,
|
||
allow_reply=False,
|
||
)
|
||
if seg.type == "image":
|
||
url = seg.data.get("url")
|
||
if url is None:
|
||
return Failure("无法下载图片,可能有一些网络问题")
|
||
return await download_image_bytes(url)
|
||
|
||
return Failure("请在消息中包含图片,或者引用一个含有图片的消息")
|
||
|
||
|
||
async def extract_image_data_from_message(
    msg: Message,
    evt: Event,
    bot: Bot,
    allow_reply: bool = True,
) -> Result[bytes, str]:
    """Extract the bytes of the first image in a message, on any adapter.

    Dispatch order:

    1. OneBot V11 (QQ) messages are delegated to ``extract_image_from_qq_message``.
    2. Discord events use the first attachment whose content type contains
       ``image/``, downloaded through ``discord_proxy``.
    3. Otherwise (including Discord events without image attachments) the
       message is scanned as a ``UniMessage``. The first ``Image`` segment wins.
       When ``allow_reply`` is true, one ``Reply`` segment is followed one level
       deep. If the quote has no image, the message's own images are still used.

    Returns:
        ``Success(bytes)`` or ``Failure(str)`` with a user-facing message. If no
        image is found, a failure from a followed quote or an unsupported
        reference is preferred over the generic hint.
    """
    if (
        isinstance(bot, OnebotV11Bot)
        and isinstance(msg, OnebotV11Message)
        and isinstance(evt, OnebotV11MessageEvent)
    ):
        # UniMessage seems insufficient for this on QQ, so use the QQ-specific path.
        logger.debug('获取图片的路径 Fallback 到 QQ 模块')
        return await extract_image_from_qq_message(msg, evt, bot, allow_reply)

    if isinstance(evt, DiscordMessageEvent):
        logger.debug('获取图片的路径方式走 Discord')
        for attachment in evt.attachments:
            content_type = attachment.content_type
            # content_type is optional on Discord attachments; a missing value
            # (None/unset) made the `in` test raise TypeError, so skip it instead.
            if not isinstance(content_type, str) or "image/" not in content_type:
                continue
            return await download_image_bytes(
                attachment.proxy_url, discordConfig.discord_proxy
            )

    # Failure from a followed quote or an unsupported reference; reported only
    # if `msg` itself contains no image segment.
    fallback: Result[bytes, str] | None = None

    for seg in UniMessage.of(msg, bot):
        logger.debug(seg)
        if isinstance(seg, Image):
            return await unimsg_img_to_bytes(seg)
        if not allow_reply:
            continue
        if isinstance(seg, Reply):
            quoted = seg.msg
            logger.debug(f"深入搜索引用的消息:{quoted}")
            if quoted is None or isinstance(quoted, str):
                continue
            allow_reply = False  # follow at most one quote
            result = await extract_image_data_from_message(quoted, evt, bot, False)
            if isinstance(result, Success):
                return result
            fallback = result
        elif isinstance(seg, RefNode):
            allow_reply = False
            if isinstance(bot, DiscordBot):
                fallback = Failure("暂时不支持在 Discord 中通过引用的方式获取图片")
            else:
                fallback = Failure("暂时不支持在这里中通过引用的方式获取图片")

    if fallback is not None:
        return fallback
    return Failure("请在消息中包含图片,或者引用一个含有图片的消息")
async def _ext_img_data(
|
||
evt: Event,
|
||
bot: Bot,
|
||
matcher: Matcher,
|
||
) -> bytes | None:
|
||
match await extract_image_data_from_message(evt.get_message(), evt, bot):
|
||
case Success(img):
|
||
return img
|
||
case Failure(err):
|
||
# raise BotExceptionMessage(err)
|
||
await matcher.send(await UniMessage().text(err).export())
|
||
return None
|
||
assert False
|
||
|
||
|
||
async def _ext_img(
|
||
evt: Event,
|
||
bot: Bot,
|
||
matcher: Matcher,
|
||
) -> PIL.Image.Image | None:
|
||
r = await _ext_img_data(evt, bot, matcher)
|
||
if r:
|
||
match bytes_to_pil(r):
|
||
case Success(img):
|
||
return img
|
||
case Failure(msg):
|
||
await matcher.send(await UniMessage.text(msg).export())
|
||
return None
|
||
|
||
async def _try_ext_img(
    evt: Event,
    bot: Bot,
    matcher: Matcher,
) -> bytes | None:
    """Dependency: raw image bytes from the triggering event, or ``None``.

    Unlike ``_ext_img_data``, failures are swallowed quietly: nothing is sent to
    the user. ``matcher`` is unused and kept so the signature matches the other
    extraction dependencies.
    """
    result = await extract_image_data_from_message(evt.get_message(), evt, bot)
    return result.value_or(None)
# NoneBot dependency-injection aliases: annotate a handler parameter with one of
# these to receive the image found in the triggering event (or its quote).

# Raw bytes; on failure the reason is sent to the user and None is injected.
DepImageBytes = Annotated[
    bytes,
    nonebot.params.Depends(_ext_img_data),
]

# Pillow image; extraction and decoding failures are reported to the user.
DepPILImage = Annotated[
    PIL.Image.Image,
    nonebot.params.Depends(_ext_img),
]

# Raw bytes, or None (without notifying the user) when no image is available.
DepImageBytesOrNone = Annotated[
    bytes | None,
    nonebot.params.Depends(_try_ext_img),
]