113 lines
3.6 KiB
Python
113 lines
3.6 KiB
Python
import asyncio
|
||
import aiohttp
|
||
import hashlib
|
||
import platform
|
||
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
|
||
import nonebot
|
||
from loguru import logger
|
||
from nonebot.adapters.discord.config import Config as DiscordConfig
|
||
from pydantic import BaseModel
|
||
|
||
|
||
@dataclass
|
||
class ArtifactDepends:
|
||
url: str
|
||
sha256: str
|
||
target: Path
|
||
|
||
required_os: str | None = None
|
||
"示例值:Windows, Linux, Darwin"
|
||
|
||
required_arch: str | None = None
|
||
"示例值:AMD64, x86_64, arm64"
|
||
|
||
use_proxy: bool = True
|
||
"网络问题,赫赫;使用的是 Discord 模块配置的 proxy"
|
||
|
||
def is_corresponding_platform(self) -> bool:
|
||
if self.required_os is not None:
|
||
if self.required_os.lower() != platform.system().lower():
|
||
return False
|
||
if self.required_arch is not None:
|
||
if self.required_arch.lower() != platform.machine().lower():
|
||
return False
|
||
return True
|
||
|
||
|
||
class Config(BaseModel):
    """Plugin configuration, loaded via ``nonebot.get_plugin_config``."""

    # Whether to download the binary dependencies ahead of time.
    # When True, the startup hook checks every registered artifact.
    prefetch_artifact: bool = False
|
||
|
||
|
||
# Registry of binary dependencies. register_artifacts() appends to it,
# and the startup hook iterates over it when prefetching is enabled.
artifact_list: list[ArtifactDepends] = []

# The nonebot driver; used below to register the startup hook.
driver = nonebot.get_driver()
# This plugin's configuration (an instance of Config).
config = nonebot.get_plugin_config(Config)
|
||
|
||
@driver.on_startup
async def _():
    """Optionally ensure every registered binary dependency at startup.

    Does nothing unless ``prefetch_artifact`` is enabled. Otherwise each
    entry of ``artifact_list`` is passed to ``ensure_artifact``, with at
    most 10 running at once. The first exception raised by any of them
    propagates out of this hook (``return_exceptions=False``).
    """
    if not config.prefetch_artifact:
        return

    logger.info("启动检测中:正在检测需求的二进制是否下载")
    # Bound concurrency so a long artifact list doesn't start every
    # check/download at the same time.
    semaphore = asyncio.Semaphore(10)

    async def _task(artifact: ArtifactDepends):
        async with semaphore:
            await ensure_artifact(artifact)

    # asyncio.create_task is the supported way to schedule a coroutine;
    # instantiating asyncio.Task directly is discouraged by the asyncio docs.
    tasks = [asyncio.create_task(_task(a)) for a in artifact_list]
    await asyncio.gather(*tasks, return_exceptions=False)
    logger.info("检测好了")
|
||
|
||
|
||
async def download_artifact(artifact: ArtifactDepends):
    """Download ``artifact.url`` into ``artifact.target`` and check its SHA-256.

    When ``artifact.use_proxy`` is set, ``discord_proxy`` from the Discord
    adapter config is used as the HTTP proxy (it may be ``None``). The
    function is deliberately lenient: a non-200 status or a hash mismatch
    is only logged as a warning, and the received bytes are still
    installed. On non-Windows systems the file is made executable (0o755).

    Network errors from aiohttp and filesystem errors are not caught and
    propagate to the caller.
    """
    proxy = None
    if artifact.use_proxy:
        discord_config = nonebot.get_plugin_config(DiscordConfig)
        proxy = discord_config.discord_proxy

    if proxy is not None:
        logger.info(f"正在使用 Proxy 下载 TARGET={artifact.target} PROXY={proxy}")
    else:
        logger.info(f"正在下载 TARGET={artifact.target}")

    # Pass the proxy per request. The session-level ``proxy`` argument only
    # exists in newer aiohttp releases, while the request-level one has long
    # been supported. Using the response as a context manager releases the
    # connection even if reading the body fails.
    async with aiohttp.ClientSession() as client:
        async with client.get(artifact.url, proxy=proxy) as response:
            if response.status != 200:
                logger.warning(f"已经下载了二进制,但是注意服务器没有返回 200! URL={artifact.url} TARGET={artifact.target} CODE={response.status}")
            data = await response.read()

    # Write to a sibling temp file, then atomically swap it into place, so
    # an interrupted write never leaves a truncated binary at ``target``.
    partial = artifact.target.with_name(artifact.target.name + ".part")
    partial.write_bytes(data)
    if not platform.system().lower() == 'windows':
        partial.chmod(0o755)
    partial.replace(artifact.target)

    logger.info(f"下载好了 TARGET={artifact.target} URL={artifact.url}")

    # Hash the bytes already in memory rather than re-reading the file.
    actual = hashlib.sha256(data).hexdigest()
    if actual.lower() != artifact.sha256.lower():
        logger.warning(f"下载到的二进制的 sha256 与需求不同 TARGET={artifact.target} REQUESTED={artifact.sha256} ACTUAL={actual}")
|
||
|
||
|
||
async def ensure_artifact(artifact: ArtifactDepends):
    """Make sure ``artifact.target`` exists locally with the expected hash.

    Artifacts not meant for this OS/architecture are skipped. A missing
    target is downloaded, creating parent directories as needed. An
    existing target whose SHA-256 differs from ``artifact.sha256`` (compared
    case-insensitively) is deleted and downloaded again.
    """
    if not artifact.is_corresponding_platform():
        return

    target = artifact.target
    if not target.exists():
        logger.info(f"二进制依赖 {target} 不存在")
        # exist_ok=True already tolerates an existing directory, so a
        # separate exists() check beforehand is redundant.
        target.parent.mkdir(parents=True, exist_ok=True)
        await download_artifact(artifact)
        return

    if _sha256_of_file(target) != artifact.sha256.lower():
        logger.info(f"二进制依赖 {target} 的哈希无法对应需求的哈希,准备重新下载")
        target.unlink()
        await download_artifact(artifact)


def _sha256_of_file(path: Path, chunk_size: int = 1 << 20) -> str:
    """Return the lowercase hex SHA-256 of *path*, read in fixed-size chunks.

    Streaming keeps memory use flat for large binaries, instead of loading
    the whole file just to hash it.
    """
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        while chunk := fh.read(chunk_size):
            digest.update(chunk)
    return digest.hexdigest()
|
||
|
||
|
||
def register_artifacts(*artifacts: ArtifactDepends):
    """Append each given artifact to the module-level ``artifact_list``.

    Registration performs no I/O; the artifacts are only checked or
    downloaded later by the startup hook, when prefetching is enabled.
    """
    for artifact in artifacts:
        artifact_list.append(artifact)
|
||
|