Compare commits

..

1 Commits

Author SHA1 Message Date
0cb236b561 feat: evolve textfx into a mini shell 2026-03-14 23:35:29 +08:00
12 changed files with 48 additions and 457 deletions

View File

@ -1,5 +1,3 @@
{ {
"python.REPL.enableREPLSmartSend": false, "python.REPL.enableREPLSmartSend": false
"python-envs.defaultEnvManager": "ms-python.python:poetry",
"python-envs.defaultPackageManager": "ms-python.python:poetry"
} }

View File

@ -1,3 +1,16 @@
FROM alpine:latest AS artifacts
RUN apk add --no-cache curl xz
WORKDIR /tmp
RUN mkdir -p /artifacts
RUN curl -L -o typst.tar.xz "https://github.com/typst/typst/releases/download/v0.14.2/typst-x86_64-unknown-linux-musl.tar.xz" \
&& tar -xJf typst.tar.xz \
&& mv typst-x86_64-unknown-linux-musl/typst /artifacts
RUN chmod -R +x /artifacts/
FROM python:3.13-slim AS base FROM python:3.13-slim AS base
ENV VIRTUAL_ENV=/app/.venv \ ENV VIRTUAL_ENV=/app/.venv \
@ -38,6 +51,7 @@ RUN uv sync --no-install-project
FROM base AS runtime FROM base AS runtime
COPY --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV} COPY --from=builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
COPY --from=artifacts /artifacts/ /usr/local/bin/
WORKDIR /app WORKDIR /app

View File

@ -1,10 +1,9 @@
import asyncio import asyncio
from typing import Any, Awaitable, Callable
import aiohttp import aiohttp
import hashlib import hashlib
import platform import platform
from dataclasses import dataclass, field from dataclasses import dataclass
from pathlib import Path from pathlib import Path
import nonebot import nonebot
@ -15,8 +14,6 @@ from pydantic import BaseModel
@dataclass @dataclass
class ArtifactDepends: class ArtifactDepends:
_Callback = Callable[[bool], Awaitable[Any]]
url: str url: str
sha256: str sha256: str
target: Path target: Path
@ -30,9 +27,6 @@ class ArtifactDepends:
use_proxy: bool = True use_proxy: bool = True
"网络问题,赫赫;使用的是 Discord 模块配置的 proxy" "网络问题,赫赫;使用的是 Discord 模块配置的 proxy"
callbacks: list[_Callback] = field(default_factory=list)
"在任务完成以后,应该做的事情"
def is_corresponding_platform(self) -> bool: def is_corresponding_platform(self) -> bool:
if self.required_os is not None: if self.required_os is not None:
if self.required_os.lower() != platform.system().lower(): if self.required_os.lower() != platform.system().lower():
@ -42,43 +36,26 @@ class ArtifactDepends:
return False return False
return True return True
def on_finished(self, task: _Callback) -> _Callback:
self.callbacks.append(task)
return task
async def _finished(self, downloaded: bool) -> list[Any | BaseException]:
tasks = set()
for f in self.callbacks:
tasks.add(f(downloaded))
return await asyncio.gather(*tasks, return_exceptions=True)
class Config(BaseModel): class Config(BaseModel):
prefetch_artifact: bool = False prefetch_artifact: bool = False
"是否提前下载好二进制依赖" "是否提前下载好二进制依赖"
artifact_list: list[ArtifactDepends] = [] artifact_list = []
driver = nonebot.get_driver() driver = nonebot.get_driver()
config = nonebot.get_plugin_config(Config) config = nonebot.get_plugin_config(Config)
@driver.on_startup @driver.on_startup
async def _(): async def _():
if config.prefetch_artifact: if config.prefetch_artifact:
logger.info("启动检测中:正在检测需求的二进制是否下载") logger.info("启动检测中:正在检测需求的二进制是否下载")
semaphore = asyncio.Semaphore(10) semaphore = asyncio.Semaphore(10)
async def _task(artifact: ArtifactDepends): async def _task(artifact: ArtifactDepends):
async with semaphore: async with semaphore:
downloaded = await ensure_artifact(artifact) await ensure_artifact(artifact)
result = await artifact._finished(downloaded)
for r in result:
if isinstance(r, BaseException):
logger.warning("完成了二进制文件的下载,但是有未捕捉的错误")
logger.exception(r)
tasks: set[asyncio.Task] = set() tasks: set[asyncio.Task] = set()
for a in artifact_list: for a in artifact_list:
@ -101,43 +78,35 @@ async def download_artifact(artifact: ArtifactDepends):
async with aiohttp.ClientSession(proxy=proxy) as client: async with aiohttp.ClientSession(proxy=proxy) as client:
result = await client.get(artifact.url) result = await client.get(artifact.url)
if result.status != 200: if result.status != 200:
logger.warning( logger.warning(f"已经下载了二进制,但是注意服务器没有返回 200 URL={artifact.url} TARGET={artifact.target} CODE={result.status}")
f"已经下载了二进制,但是注意服务器没有返回 200 URL={artifact.url} TARGET={artifact.target} CODE={result.status}"
)
data = await result.read() data = await result.read()
artifact.target.write_bytes(data) artifact.target.write_bytes(data)
if not platform.system().lower() == "windows": if not platform.system().lower() == 'windows':
artifact.target.chmod(0o755) artifact.target.chmod(0o755)
logger.info(f"下载好了 TARGET={artifact.target} URL={artifact.url}") logger.info(f"下载好了 TARGET={artifact.target} URL={artifact.url}")
m = hashlib.sha256(artifact.target.read_bytes()) m = hashlib.sha256(artifact.target.read_bytes())
if m.hexdigest().lower() != artifact.sha256.lower(): if m.hexdigest().lower() != artifact.sha256.lower():
logger.warning( logger.warning(f"下载到的二进制的 sha256 与需求不同 TARGET={artifact.target} REQUESTED={artifact.sha256} ACTUAL={m.hexdigest()}")
f"下载到的二进制的 sha256 与需求不同 TARGET={artifact.target} REQUESTED={artifact.sha256} ACTUAL={m.hexdigest()}"
)
async def ensure_artifact(artifact: ArtifactDepends) -> bool: async def ensure_artifact(artifact: ArtifactDepends):
if not artifact.is_corresponding_platform(): if not artifact.is_corresponding_platform():
return False return
if not artifact.target.exists(): if not artifact.target.exists():
logger.info(f"二进制依赖 {artifact.target} 不存在") logger.info(f"二进制依赖 {artifact.target} 不存在")
if not artifact.target.parent.exists(): if not artifact.target.parent.exists():
artifact.target.parent.mkdir(parents=True, exist_ok=True) artifact.target.parent.mkdir(parents=True, exist_ok=True)
await download_artifact(artifact) await download_artifact(artifact)
return True
else: else:
m = hashlib.sha256(artifact.target.read_bytes()) m = hashlib.sha256(artifact.target.read_bytes())
if m.hexdigest().lower() != artifact.sha256.lower(): if m.hexdigest().lower() != artifact.sha256.lower():
logger.info( logger.info(f"二进制依赖 {artifact.target} 的哈希无法对应需求的哈希,准备重新下载")
f"二进制依赖 {artifact.target} 的哈希无法对应需求的哈希,准备重新下载"
)
artifact.target.unlink() artifact.target.unlink()
await download_artifact(artifact) await download_artifact(artifact)
return True
return False
def register_artifacts(*artifacts: ArtifactDepends): def register_artifacts(*artifacts: ArtifactDepends):
artifact_list.extend(artifacts) artifact_list.extend(artifacts)

View File

@ -161,9 +161,9 @@ class PipelineRunner:
"'": "'", "'": "'",
} }
def flush_word(force: bool = False): def flush_word():
nonlocal buf nonlocal buf
if buf or force: if buf:
tokens.append(Token(TokenKind.WORD, buf)) tokens.append(Token(TokenKind.WORD, buf))
buf = "" buf = ""
@ -178,7 +178,6 @@ class PipelineRunner:
escape = True escape = True
elif c == quote: elif c == quote:
quote = None quote = None
flush_word(force=True) # 引号闭合时强制 flush即使是空字符串
else: else:
buf += c buf += c
i += 1 i += 1
@ -189,7 +188,7 @@ class PipelineRunner:
i += 1 i += 1
continue continue
if c.isspace(): if c.isspace() or c in "":
flush_word() flush_word()
i += 1 i += 1
continue continue
@ -564,20 +563,12 @@ class PipelineRunner:
results: list[TextHandleResult] = [] results: list[TextHandleResult] = []
for statement in pipeline.statements: for statement in pipeline.statements:
try: if isinstance(statement, IfNode):
if isinstance(statement, IfNode): results.append(await self._execute_if(statement, istream, env))
results.append(await self._execute_if(statement, istream, env)) elif isinstance(statement, WhileNode):
elif isinstance(statement, WhileNode): results.append(await self._execute_while(statement, istream, env))
results.append(await self._execute_while(statement, istream, env)) else:
else: results.append(await self._execute_group(statement, istream, env))
results.append(await self._execute_group(statement, istream, env))
except Exception as e:
logger.error(f"Pipeline execution failed: {e}")
logger.exception(e)
results.append(
TextHandleResult(code=-1, ostream="处理流水线时出现 python 错误")
)
return results
return results return results

View File

@ -15,12 +15,10 @@ class THQwen(TextHandler):
self, env: TextHandlerEnvironment, istream: str | None, args: list[str] self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
) -> TextHandleResult: ) -> TextHandleResult:
pm = perm_manager() pm = perm_manager()
if env.event is None or not await pm.check_has_permission( if env.event is None or not pm.check_has_permission(env.event, "textfx.qwen"):
env.event, "textfx.qwen"
):
return TextHandleResult( return TextHandleResult(
code=1, code=1,
ostream="你或当前环境没有使用 qwen 的权限。如有疑问请联系管理员", ostream="这里暂未开启 AI 功能",
) )
llm = get_llm() llm = get_llm()

View File

@ -13,8 +13,10 @@ class THEcho(TextHandler):
async def handle( async def handle(
self, env: TextHandlerEnvironment, istream: str | None, args: list[str] self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
) -> TextHandleResult: ) -> TextHandleResult:
# echo 不读 stdin只输出参数Unix 语义) if len(args) == 0 and istream is None:
# 无参数时输出空行(与 Unix echo 行为一致) return TextHandleResult(1, "请在 echo 后面添加需要输出的文本")
if istream is not None:
return TextHandleResult(0, "\n".join([istream] + args))
return TextHandleResult(0, "\n".join(args)) return TextHandleResult(0, "\n".join(args))
@ -116,8 +118,9 @@ class THTest(TextHandler):
self, env: TextHandlerEnvironment, istream: str | None, args: list[str] self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
) -> TextHandleResult: ) -> TextHandleResult:
expr = list(args) expr = list(args)
if self.name == "[":
pass
# 支持方括号语法:[ expr ] 会自动移除末尾的 ]
if expr and expr[-1] == "]": if expr and expr[-1] == "]":
expr = expr[:-1] expr = expr[:-1]

View File

@ -1,210 +0,0 @@
import copy
import re
from pathlib import Path
import nonebot
from nonebot import on_command
from nonebot.adapters import Bot, Event, Message
from nonebot.log import logger
from nonebot.message import handle_event
from nonebot.params import CommandArg
from konabot.common.database import DatabaseManager
from konabot.common.longtask import DepLongTaskTarget
ROOT_PATH = Path(__file__).resolve().parent
cmd = on_command(cmd="语法糖", aliases={"", "sugar"}, block=True)
db_manager = DatabaseManager()
driver = nonebot.get_driver()
@driver.on_startup
async def register_startup_hook():
await init_db()
@driver.on_shutdown
async def register_shutdown_hook():
await db_manager.close_all_connections()
async def init_db():
await db_manager.execute_by_sql_file(ROOT_PATH / "sql" / "create_table.sql")
table_info = await db_manager.query("PRAGMA table_info(syntactic_sugar)")
columns = {str(row.get("name")) for row in table_info}
if "channel_id" not in columns:
await db_manager.execute(
"ALTER TABLE syntactic_sugar ADD COLUMN channel_id VARCHAR(255) NOT NULL DEFAULT ''"
)
await db_manager.execute("DROP INDEX IF EXISTS idx_syntactic_sugar_name_belong_to")
await db_manager.execute(
"CREATE UNIQUE INDEX IF NOT EXISTS idx_syntactic_sugar_name_channel_target "
"ON syntactic_sugar(name, channel_id, belong_to)"
)
def _extract_reply_plain_text(evt: Event) -> str:
reply = getattr(evt, "reply", None)
if reply is None:
return ""
reply_message = getattr(reply, "message", None)
if reply_message is None:
return ""
extract_plain_text = getattr(reply_message, "extract_plain_text", None)
if callable(extract_plain_text):
return extract_plain_text().strip()
return str(reply_message).strip()
def _split_variables(tokens: list[str]) -> tuple[list[str], dict[str, str]]:
positional: list[str] = []
named: dict[str, str] = {}
for token in tokens:
if "=" in token:
key, value = token.split("=", 1)
key = key.strip()
if key:
named[key] = value
continue
positional.append(token)
return positional, named
def _render_template(content: str, positional: list[str], named: dict[str, str]) -> str:
def replace(match: re.Match[str]) -> str:
key = match.group(1).strip()
if key.isdigit():
idx = int(key) - 1
if 0 <= idx < len(positional):
return positional[idx]
return match.group(0)
return named.get(key, match.group(0))
return re.sub(r"\{([^{}]+)\}", replace, content)
async def _store_sugar(name: str, content: str, belong_to: str, channel_id: str):
await db_manager.execute_by_sql_file(
ROOT_PATH / "sql" / "insert_sugar.sql",
(name, content, belong_to, channel_id),
)
async def _delete_sugar(name: str, belong_to: str, channel_id: str):
await db_manager.execute(
"DELETE FROM syntactic_sugar WHERE name = ? AND belong_to = ? AND channel_id = ?",
(name, belong_to, channel_id),
)
async def _find_sugar(name: str, belong_to: str, channel_id: str) -> str | None:
rows = await db_manager.query(
(
"SELECT content FROM syntactic_sugar "
"WHERE name = ? AND channel_id = ? "
"ORDER BY CASE WHEN belong_to = ? THEN 0 ELSE 1 END, id ASC "
"LIMIT 1"
),
(name, channel_id, belong_to),
)
if not rows:
return None
return rows[0].get("content")
async def _reinject_command(bot: Bot, evt: Event, command_text: str) -> bool:
depth = int(getattr(evt, "_syntactic_sugar_depth", 0))
if depth >= 3:
return False
try:
cloned_evt = copy.deepcopy(evt)
except Exception:
logger.exception("语法糖克隆事件失败")
return False
message = getattr(cloned_evt, "message", None)
if message is None:
return False
try:
msg_obj = type(message)(command_text)
except Exception:
msg_obj = command_text
setattr(cloned_evt, "message", msg_obj)
if hasattr(cloned_evt, "original_message"):
setattr(cloned_evt, "original_message", msg_obj)
if hasattr(cloned_evt, "raw_message"):
setattr(cloned_evt, "raw_message", command_text)
setattr(cloned_evt, "_syntactic_sugar_depth", depth + 1)
try:
await handle_event(bot, cloned_evt)
except Exception:
logger.exception("语法糖回注事件失败")
return False
return True
@cmd.handle()
async def _(bot: Bot, evt: Event, target: DepLongTaskTarget, args: Message = CommandArg()):
raw = args.extract_plain_text().strip()
if not raw:
return
tokens = raw.split()
action = tokens[0]
target_id = target.target_id
channel_id = target.channel_id
if action == "存入":
if len(tokens) < 2:
await cmd.finish("请提供要存入的名称")
name = tokens[1].strip()
content = " ".join(tokens[2:]).strip()
if not content:
content = _extract_reply_plain_text(evt)
if not content:
await cmd.finish("请提供要存入的内容")
await _store_sugar(name, content, target_id, channel_id)
await cmd.finish(f"糖已存入:「{name}」!")
if action == "删除":
if len(tokens) < 2:
await cmd.finish("请提供要删除的名称")
name = tokens[1].strip()
await _delete_sugar(name, target_id, channel_id)
await cmd.finish(f"已删除糖:「{name}」!")
if action == "查看":
if len(tokens) < 2:
await cmd.finish("请提供要查看的名称")
name = tokens[1].strip()
content = await _find_sugar(name, target_id, channel_id)
if content is None:
await cmd.finish(f"没有糖:「{name}")
await cmd.finish(f"糖的内容:「{content}")
name = action
content = await _find_sugar(name, target_id, channel_id)
if content is None:
await cmd.finish(f"没有糖:「{name}")
positional, named = _split_variables(tokens[1:])
rendered = _render_template(content, positional, named)
ok = await _reinject_command(bot, evt, rendered)
if not ok:
await cmd.finish(f"糖的展开结果:「{rendered}")

View File

@ -1,12 +0,0 @@
-- 创建语法糖表
CREATE TABLE IF NOT EXISTS syntactic_sugar (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name VARCHAR(255) NOT NULL,
content TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
belong_to VARCHAR(255) NOT NULL,
channel_id VARCHAR(255) NOT NULL DEFAULT ''
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_syntactic_sugar_name_channel_target
ON syntactic_sugar(name, channel_id, belong_to);

View File

@ -1,5 +0,0 @@
-- 插入语法糖,如果同一用户下名称已存在则更新内容
INSERT INTO syntactic_sugar (name, content, belong_to, channel_id)
VALUES (?, ?, ?, ?)
ON CONFLICT(name, channel_id, belong_to) DO UPDATE SET
content = excluded.content;

View File

@ -1,11 +1,9 @@
import asyncio import asyncio
import os
import subprocess import subprocess
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import cast from typing import cast
import zipfile
from loguru import logger from loguru import logger
from nonebot import on_command from nonebot import on_command
@ -15,99 +13,22 @@ from nonebot.adapters.onebot.v11.event import MessageEvent as OB11MessageEvent
from nonebot.adapters.onebot.v11.bot import Bot as OB11Bot from nonebot.adapters.onebot.v11.bot import Bot as OB11Bot
from nonebot.adapters.onebot.v11.message import Message as OB11Message from nonebot.adapters.onebot.v11.message import Message as OB11Message
from konabot.common.artifact import ArtifactDepends, ensure_artifact, register_artifacts
from konabot.common.longtask import DepLongTaskTarget from konabot.common.longtask import DepLongTaskTarget
from konabot.common.path import BINARY_PATH, TMP_PATH from konabot.common.path import TMP_PATH
arti_typst_linux = ArtifactDepends(
url="https://github.com/typst/typst/releases/download/v0.14.2/typst-x86_64-unknown-linux-musl.tar.xz",
sha256="a6044cbad2a954deb921167e257e120ac0a16b20339ec01121194ff9d394996d",
target=BINARY_PATH / "typst.tar.xz",
required_os="Linux",
required_arch="x86_64",
)
arti_typst_windows = ArtifactDepends(
url="https://github.com/typst/typst/releases/download/v0.14.2/typst-x86_64-pc-windows-msvc.zip",
sha256="51353994ac83218c3497052e89b2c432c53b9d4439cdc1b361e2ea4798ebfc13",
target=BINARY_PATH / "typst.zip",
required_os="Windows",
required_arch="AMD64",
)
bin_path: Path | None = None
@arti_typst_linux.on_finished
async def _(downloaded: bool):
global bin_path
tar_path = arti_typst_linux.target
bin_path = BINARY_PATH / "typst"
if downloaded or not bin_path.exists():
bin_path.unlink(missing_ok=True)
process = await asyncio.create_subprocess_exec(
"tar",
"-xvf",
tar_path,
"--strip-components=1",
"-C",
BINARY_PATH,
"typst-x86_64-unknown-linux-musl/typst",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await process.communicate()
if process.returncode != 0 or not bin_path.exists():
logger.warning(
"似乎没有成功解压 Typst 二进制文件,检查一下吧! "
f"stdout={stdout} stderr={stderr}"
)
else:
os.chmod(bin_path, 0o755)
@arti_typst_windows.on_finished
async def _(downloaded: bool):
global bin_path
zip_path = arti_typst_windows.target
bin_path = BINARY_PATH / "typst.exe"
if downloaded or not bin_path.exists():
bin_path.unlink(missing_ok=True)
with zipfile.ZipFile(zip_path, "r") as zf:
target_name = "typst-x86_64-pc-windows-msvc/typst.exe"
if target_name not in zf.namelist():
logger.warning("在 Zip 压缩包里面没有找到目标文件")
return
zf.extract(target_name, BINARY_PATH)
(BINARY_PATH / target_name).rename(bin_path)
(BINARY_PATH / "typst-x86_64-pc-windows-msvc").rmdir()
register_artifacts(arti_typst_linux)
register_artifacts(arti_typst_windows)
TEMPLATE_PATH = Path(__file__).parent / "template.typ" TEMPLATE_PATH = Path(__file__).parent / "template.typ"
TEMPLATE = TEMPLATE_PATH.read_text() TEMPLATE = TEMPLATE_PATH.read_text()
def render_sync(code: str) -> bytes | None: def render_sync(code: str) -> bytes:
global bin_path
if bin_path is None:
return
with TemporaryDirectory(dir=TMP_PATH) as tmpdirname: with TemporaryDirectory(dir=TMP_PATH) as tmpdirname:
temp_dir = Path(tmpdirname).resolve() temp_dir = Path(tmpdirname).resolve()
temp_typ = temp_dir / "page.typ" temp_typ = temp_dir / "page.typ"
temp_typ.write_text(TEMPLATE + "\n\n" + code) temp_typ.write_text(TEMPLATE + "\n\n" + code)
cmd = [ cmd = [
bin_path, "typst",
"compile", "compile",
temp_typ.name, temp_typ.name,
"--format", "--format",
@ -140,7 +61,7 @@ def render_sync(code: str) -> bytes | None:
return result_png.read_bytes() return result_png.read_bytes()
async def render(code: str) -> bytes | None: async def render(code: str) -> bytes:
task = asyncio.to_thread(lambda: render_sync(code)) task = asyncio.to_thread(lambda: render_sync(code))
return await task return await task
@ -149,21 +70,7 @@ cmd = on_command("typst")
@cmd.handle() @cmd.handle()
async def _( async def _(evt: Event, bot: Bot, msg: UniMsg, target: DepLongTaskTarget):
evt: Event,
bot: Bot,
msg: UniMsg,
target: DepLongTaskTarget,
):
global bin_path
# 对于本地机器,一般不会在应用启动时自动下载,这里再保证存在
await ensure_artifact(arti_typst_linux)
await ensure_artifact(arti_typst_windows)
if bin_path is None or not bin_path.exists():
logger.warning("当前环境不存在 Typst但仍然调用了")
return
typst_code = "" typst_code = ""
if isinstance(evt, OB11MessageEvent): if isinstance(evt, OB11MessageEvent):
if evt.reply is not None: if evt.reply is not None:
@ -185,8 +92,6 @@ async def _(
try: try:
res = await render(typst_code) res = await render(typst_code)
if res is None:
raise FileNotFoundError("没有渲染出来内容")
except FileNotFoundError as e: except FileNotFoundError as e:
await target.send_message("渲染出错:内部错误") await target.send_message("渲染出错:内部错误")
raise e from e raise e from e

View File

@ -2,14 +2,7 @@ import nonebot
nonebot.init() nonebot.init()
import asyncio from konabot.plugins.handle_text.__init__ import _get_textfx_user_key
import pytest
from konabot.plugins.handle_text.__init__ import (
_get_textfx_user_key,
_textfx_running_users,
TEXTFX_MAX_RUNTIME_SECONDS,
)
from konabot.plugins.handle_text.base import PipelineRunner
class DummyEvent: class DummyEvent:
@ -38,38 +31,3 @@ def test_textfx_user_key_private():
def test_textfx_user_key_session_fallback(): def test_textfx_user_key_session_fallback():
evt = DummyEvent(session_id='console:alice') evt = DummyEvent(session_id='console:alice')
assert _get_textfx_user_key(evt) == 'session:console:alice' assert _get_textfx_user_key(evt) == 'session:console:alice'
@pytest.mark.asyncio
async def test_textfx_timeout_limit():
"""测试脚本执行超时限制"""
runner = PipelineRunner.get_runner()
# 创建一个会超时的脚本while true 会触发迭代限制,但我们用 sleep 模拟长时间运行)
# 由于实际超时是 60 秒,我们不能真的等那么久,所以这个测试验证超时机制存在
script = "echo start"
parsed = runner.parse_pipeline(script)
assert not isinstance(parsed, str), "脚本解析应该成功"
# 验证 TEXTFX_MAX_RUNTIME_SECONDS 常量存在且合理
assert TEXTFX_MAX_RUNTIME_SECONDS == 60
@pytest.mark.asyncio
async def test_textfx_concurrent_limit():
"""测试同一用户并发执行限制"""
user_key = "test:group:user123"
# 清理可能的残留状态
_textfx_running_users.discard(user_key)
# 模拟第一个脚本正在运行
assert user_key not in _textfx_running_users
_textfx_running_users.add(user_key)
# 验证用户已被标记为运行中
assert user_key in _textfx_running_users
# 清理
_textfx_running_users.discard(user_key)
assert user_key not in _textfx_running_users

View File

@ -205,21 +205,3 @@ async def test_while_body_can_use_if(runner: PipelineRunner):
assert not isinstance(parsed, str) assert not isinstance(parsed, str)
results = await runner.run_pipeline(parsed, None, TextHandlerEnvironment(False)) results = await runner.run_pipeline(parsed, None, TextHandlerEnvironment(False))
assert results[0].code == 1 assert results[0].code == 1
@pytest.mark.asyncio
async def test_echo_empty_string(runner: PipelineRunner):
"""测试 echo 空字符串"""
# 双引号空字符串
parsed = runner.parse_pipeline('echo ""')
assert not isinstance(parsed, str)
results = await runner.run_pipeline(parsed, None, TextHandlerEnvironment(False))
assert results[0].code == 0
assert results[0].ostream == ''
# 单引号空字符串
parsed2 = runner.parse_pipeline("echo ''")
assert not isinstance(parsed2, str)
results2 = await runner.run_pipeline(parsed2, None, TextHandlerEnvironment(False))
assert results2[0].code == 0
assert results2[0].ostream == ''