Files
konabot/konabot/plugins/handle_text/base.py

588 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import cast
from loguru import logger
from nonebot.adapters import Event
# Safety cap for `while` loops: _execute_while aborts with code=2 once a
# loop has run this many iterations, so a script cannot spin forever.
MAX_WHILE_ITERATIONS = 100
@dataclass
class TextHandlerEnvironment:
    """Execution context shared by every handler in one pipeline run."""

    # Whether the triggering source is trusted (decided by the caller).
    is_trusted: bool
    # Originating adapter event, if any.
    event: Event | None = None
    # Named text buffers written by `>` / `>>` redirections.
    buffers: dict[str, str] = field(default_factory=dict)
@dataclass
class TextHandleResult:
    """Outcome of one handler invocation (shell-style: code 0 = success)."""

    # 0 on success; non-zero short-circuits pipelines and `&&` chains.
    code: int
    # Text output piped into the next command (None = no output).
    ostream: str | None
    # Optional binary payload carried alongside the text output.
    attachment: bytes | None = None
class TextHandler(ABC):
    """Base class for named, async text-processing commands.

    Subclasses set ``name`` (primary command name) and optionally
    ``keywords`` (alias names), then implement :meth:`handle`.
    """

    # Primary command name used by the resolver.
    name: str = ""
    # Alias names that also resolve to this handler (class-level default;
    # subclasses are expected to override rather than mutate it).
    keywords: list[str] = []

    @abstractmethod
    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        """Process ``istream`` with ``args`` inside ``env`` and return a result."""
        ...

    def __repr__(self) -> str:
        # Join with ", " so multiple keywords remain distinguishable; the
        # original "".join rendered ["a", "b"] as the ambiguous "ab".
        return f"<{self.__class__.__name__}: {self.name} [{', '.join(self.keywords)}]>"
class TextHandlerSync(TextHandler):
    """Adapter base for synchronous handlers.

    Subclasses implement :meth:`handle_sync`; :meth:`handle` runs it in a
    worker thread so it never blocks the event loop.
    """

    @abstractmethod
    def handle_sync(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        """Synchronous counterpart of :meth:`TextHandler.handle`."""
        ...

    async def handle(
        self, env: TextHandlerEnvironment, istream: str | None, args: list[str]
    ) -> TextHandleResult:
        # asyncio.to_thread forwards positional arguments to the callable,
        # so the wrapper closure in the original added nothing.
        return await asyncio.to_thread(self.handle_sync, env, istream, args)
@dataclass
class Redirect:
    """A `>` / `>>` redirection attached to a command."""

    # Buffer name (key in TextHandlerEnvironment.buffers) to write into.
    target: str
    # True for `>>` (append), False for `>` (overwrite).
    append: bool = False
@dataclass
class CommandNode:
    """One parsed command: resolved handler plus its arguments and redirects."""

    # Command name as written in the script.
    name: str
    # Handler resolved from the name at parse time.
    handler: TextHandler
    # Word arguments following the command name.
    args: list[str]
    # Output redirections (`>` / `>>`) in source order.
    redirects: list[Redirect] = field(default_factory=list)
@dataclass
class PipelineNode:
    """A `|`-connected sequence of commands, optionally negated by `!`."""

    commands: list[CommandNode] = field(default_factory=list)
    # True when an odd number of leading `!` inverts the exit status.
    negate: bool = False
@dataclass
class ConditionalPipeline:
    """A pipeline plus the operator linking it to the previous pipeline."""

    # "&&", "||", or None for the first pipeline of a chain.
    op: str | None
    pipeline: PipelineNode
@dataclass
class CommandGroup:
    """An `&&` / `||` chain of pipelines, executed left to right."""

    chains: list[ConditionalPipeline] = field(default_factory=list)
@dataclass
class IfNode:
    """`if <condition>; then ... [else ...] fi` statement."""

    condition: CommandGroup
    then_body: "Script"
    # None when the `else` branch is absent.
    else_body: "Script | None" = None
@dataclass
class WhileNode:
    """`while <condition>; do ... done` statement."""

    condition: CommandGroup
    body: "Script"
@dataclass
class Script:
    """A parsed script: an ordered list of top-level statements."""

    statements: list[CommandGroup | IfNode | WhileNode] = field(default_factory=list)
class TokenKind(Enum):
    """Lexical category of a token: a plain word or an operator."""

    WORD = "word"
    OP = "op"
@dataclass
class Token:
    """A lexed token: its kind plus the literal text value."""

    kind: TokenKind
    value: str
class PipelineRunner:
    """Parses and executes shell-like text-processing scripts.

    A script is a ``;``-separated sequence of statements.  A statement is
    either an ``&&``/``||`` chain of ``|`` pipelines (each optionally negated
    with ``!``), an ``if ... then ... [else ...] fi`` block, or a
    ``while ... do ... done`` loop.  Command names resolve to registered
    :class:`TextHandler` instances; ``>`` / ``>>`` redirect a command's text
    output into named buffers on the environment.

    Parsing methods report user errors by returning a message string instead
    of raising, so callers can relay them to chat directly.
    """

    # Handlers available for name resolution, in registration order.
    handlers: list[TextHandler]

    # Lazily created process-wide instance; see get_runner().
    _singleton: "PipelineRunner | None" = None

    def __init__(self) -> None:
        self.handlers = []

    @staticmethod
    def get_runner() -> "PipelineRunner":
        """Return the shared singleton runner, creating it on first use.

        The original stored the instance inside ``PipelineRunner.__annotations__``
        — the typing-metadata dict, not a storage slot — which corrupts the
        class's annotation info for introspection tools.  A plain class
        attribute provides identical get-or-create behavior.
        """
        if PipelineRunner._singleton is None:
            PipelineRunner._singleton = PipelineRunner()
        return PipelineRunner._singleton

    def register(self, handler: TextHandler):
        """Make ``handler`` resolvable as a command."""
        self.handlers.append(handler)

    def _resolve_handler(self, cmd_name: str) -> TextHandler | str:
        """Resolve ``cmd_name`` to a handler, or return an error message string."""
        matched = [
            h for h in self.handlers if cmd_name == h.name or cmd_name in h.keywords
        ]
        if not matched:
            return f"不存在名为 {cmd_name} 的函数"
        if len(matched) > 1:
            # Ambiguous name: warn, then fall through to the first match.
            logger.warning(
                f"指令能对应超过一个文本处理器 CMD={cmd_name} handlers={self.handlers}"
            )
        return matched[0]

    def tokenize(self, script: str) -> list[Token] | str:
        """Split ``script`` into tokens, or return an error message string.

        Supports single/double quotes, backslash escapes (C-style sequences
        via ``escape_map``) and the operators ``| ; > >> && || !``.
        """
        tokens: list[Token] = []
        buf = ""
        quote: str | None = None  # the active quote character, if inside quotes
        escape = False  # True right after a backslash inside quotes
        i = 0
        # Split into explicit one-/two-char sets; the original kept them in a
        # single set, which produced the same token stream but obscured the
        # two-character lookahead.
        two_char_ops = {"&&", "||", ">>"}
        one_char_ops = {"|", ";", ">", "!"}
        escape_map = {
            "n": "\n",
            "r": "\r",
            "t": "\t",
            "0": "\0",
            "a": "\a",
            "b": "\b",
            "f": "\f",
            "v": "\v",
            "\\": "\\",
            '"': '"',
            "'": "'",
        }

        def flush_word(force: bool = False):
            # Emit the accumulated word; `force` allows empty words (from "").
            nonlocal buf
            if buf or force:
                tokens.append(Token(TokenKind.WORD, buf))
                buf = ""

        while i < len(script):
            c = script[i]
            if quote is not None:
                if escape:
                    buf += escape_map.get(c, c)
                    escape = False
                elif c == "\\":
                    escape = True
                elif c == quote:
                    quote = None
                    # Force a flush so a closed quote yields a token even
                    # when it contains the empty string.
                    flush_word(force=True)
                else:
                    buf += c
                i += 1
                continue
            if c in "'\"":
                quote = c
                i += 1
                continue
            if c.isspace():
                flush_word()
                i += 1
                continue
            two = script[i : i + 2]
            if two in two_char_ops:
                flush_word()
                tokens.append(Token(TokenKind.OP, two))
                i += 2
                continue
            if c in one_char_ops:
                flush_word()
                tokens.append(Token(TokenKind.OP, c))
                i += 1
                continue
            if c == "\\":
                if i + 1 < len(script):
                    # Unquoted backslash escapes the next character.
                    i += 1
                    buf += escape_map.get(script[i], script[i])
                else:
                    # Trailing lone backslash: keep it literally.
                    buf += c
                i += 1
                continue
            buf += c
            i += 1
        if quote is not None:
            return "存在未闭合的引号"
        if escape:
            # Defensive: an unconsumed escape keeps its backslash.
            buf += "\\"
        flush_word()
        return tokens

    def parse_pipeline(self, script: str) -> Script | str:
        """Parse ``script`` into a :class:`Script` AST, or return an error string."""
        tokens = self.tokenize(script)
        if isinstance(tokens, str):
            return tokens
        if not tokens:
            return Script()
        pos = 0

        def peek(offset: int = 0) -> Token | None:
            idx = pos + offset
            return tokens[idx] if idx < len(tokens) else None

        def consume() -> Token:
            nonlocal pos
            tok = tokens[pos]
            pos += 1
            return tok

        def consume_if_op(value: str) -> bool:
            tok = peek()
            if tok is not None and tok.kind == TokenKind.OP and tok.value == value:
                consume()
                return True
            return False

        def consume_if_word(value: str) -> bool:
            tok = peek()
            if tok is not None and tok.kind == TokenKind.WORD and tok.value == value:
                consume()
                return True
            return False

        def expect_word(msg: str) -> Token | str:
            # Consume the next WORD token or return `msg` as the error.
            tok = peek()
            if tok is None or tok.kind != TokenKind.WORD:
                return msg
            return consume()

        def parse_command() -> CommandNode | str:
            # command := WORD arg* (('>' | '>>') WORD)*
            first = expect_word("缺少指令名")
            if isinstance(first, str):
                return first
            handler = self._resolve_handler(first.value)
            if isinstance(handler, str):
                return handler
            args: list[str] = []
            redirects: list[Redirect] = []
            while True:
                tok = peek()
                if tok is None:
                    break
                if tok.kind == TokenKind.OP and tok.value in {"|", ";", "&&", "||"}:
                    break
                if tok.kind == TokenKind.OP and tok.value in {">", ">>"}:
                    op_tok = consume()
                    target = expect_word("重定向操作符后面需要缓存名")
                    if isinstance(target, str):
                        return target
                    redirects.append(
                        Redirect(target=target.value, append=op_tok.value == ">>")
                    )
                    continue
                if tok.kind != TokenKind.WORD:
                    return f"无法解析的 token: {tok.value}"
                args.append(consume().value)
            return CommandNode(
                name=first.value,
                handler=handler,
                args=args,
                redirects=redirects,
            )

        def parse_pipe() -> PipelineNode | str:
            # pipeline := '!'* command ('|' command)*
            negate = False
            while consume_if_op("!"):
                negate = not negate  # double negation cancels out
            pipeline = PipelineNode(negate=negate)
            command = parse_command()
            if isinstance(command, str):
                return command
            pipeline.commands.append(command)
            while True:
                tok = peek()
                if tok is None or tok.kind != TokenKind.OP or tok.value != "|":
                    break
                consume()
                next_command = parse_command()
                if isinstance(next_command, str):
                    return next_command
                pipeline.commands.append(next_command)
            return pipeline

        def parse_chain() -> CommandGroup | str:
            # chain := pipeline (('&&' | '||') pipeline)*
            group = CommandGroup()
            first_pipeline = parse_pipe()
            if isinstance(first_pipeline, str):
                return first_pipeline
            group.chains.append(ConditionalPipeline(op=None, pipeline=first_pipeline))
            while True:
                tok = peek()
                if tok is None or tok.kind != TokenKind.OP or tok.value not in {"&&", "||"}:
                    break
                op = consume().value
                next_pipeline = parse_pipe()
                if isinstance(next_pipeline, str):
                    return next_pipeline
                group.chains.append(ConditionalPipeline(op=op, pipeline=next_pipeline))
            return group

        def parse_if() -> IfNode | str:
            # if := 'if' chain [';'] 'then' script ['else' script] 'fi'
            if not consume_if_word("if"):
                return "缺少 if"
            condition = parse_chain()
            if isinstance(condition, str):
                return condition
            consume_if_op(";")  # the ';' before 'then' is optional
            if not consume_if_word("then"):
                return "if 语句缺少 then"
            then_body = parse_script(stop_words={"else", "fi"})
            if isinstance(then_body, str):
                return then_body
            else_body: Script | None = None
            if consume_if_word("else"):
                else_body = parse_script(stop_words={"fi"})
                if isinstance(else_body, str):
                    return else_body
            if not consume_if_word("fi"):
                return "if 语句缺少 fi"
            return IfNode(condition=condition, then_body=then_body, else_body=else_body)

        def parse_while() -> WhileNode | str:
            # while := 'while' chain [';'] 'do' script 'done'
            if not consume_if_word("while"):
                return "缺少 while"
            condition = parse_chain()
            if isinstance(condition, str):
                return condition
            consume_if_op(";")  # the ';' before 'do' is optional
            if not consume_if_word("do"):
                return "while 语句缺少 do"
            body = parse_script(stop_words={"done"})
            if isinstance(body, str):
                return body
            if not consume_if_word("done"):
                return "while 语句缺少 done"
            return WhileNode(condition=condition, body=body)

        def parse_statement() -> CommandGroup | IfNode | WhileNode | str:
            # Dispatch on the leading keyword; anything else is a plain chain.
            tok = peek()
            if tok is not None and tok.kind == TokenKind.WORD:
                if tok.value == "if":
                    return parse_if()
                if tok.value == "while":
                    return parse_while()
            return parse_chain()

        def parse_script(stop_words: set[str] | None = None) -> Script | str:
            # Parse statements until EOF or one of `stop_words` (which is
            # left unconsumed for the caller, e.g. 'fi' / 'done').
            parsed = Script()
            nonlocal pos
            while pos < len(tokens):
                tok = peek()
                if tok is None:
                    break
                if stop_words and tok.kind == TokenKind.WORD and tok.value in stop_words:
                    break
                if tok.kind == TokenKind.OP and tok.value == ";":
                    consume()  # skip empty statements / stray separators
                    continue
                statement = parse_statement()
                if isinstance(statement, str):
                    return statement
                parsed.statements.append(statement)
                tok = peek()
                if tok is not None and tok.kind == TokenKind.OP and tok.value == ";":
                    consume()
            return parsed

        parsed = parse_script()
        if isinstance(parsed, str):
            return parsed
        if pos != len(tokens):
            # Leftover tokens mean an unconsumed construct (e.g. a stray 'fi').
            tok = tokens[pos]
            return f"无法解析的 token: {tok.value}"
        return parsed

    async def _execute_command(
        self,
        command: CommandNode,
        istream: str | None,
        env: TextHandlerEnvironment,
    ) -> TextHandleResult:
        """Run one command; on success apply its redirects.

        When a command has redirects its text output goes into the target
        buffers and the returned result carries ``ostream=None`` (the
        attachment, if any, still passes through).
        """
        logger.debug(
            f"Executing: {command.name} args={command.args} redirects={command.redirects}"
        )
        result = await command.handler.handle(env, istream, command.args)
        if result.code != 0:
            # Failed commands skip redirection, so buffers keep old content.
            return result
        if command.redirects:
            content = result.ostream or ""
            for redirect in command.redirects:
                if redirect.append:
                    old_content = env.buffers.get(redirect.target, "")
                    env.buffers[redirect.target] = old_content + content
                else:
                    env.buffers[redirect.target] = content
            return TextHandleResult(code=0, ostream=None, attachment=result.attachment)
        return result

    async def _execute_pipeline(
        self,
        pipeline: PipelineNode,
        istream: str | None,
        env: TextHandlerEnvironment,
    ) -> TextHandleResult:
        """Run a `|` pipeline, feeding each command's ostream to the next.

        A non-zero exit stops the pipeline; `!` inverts the final status
        (shell-style), collapsing the result to a bare code 0 / 1.
        """
        current_stream = istream
        last_result = TextHandleResult(code=0, ostream=None)
        for command in pipeline.commands:
            try:
                last_result = await self._execute_command(command, current_stream, env)
            except Exception as e:
                logger.error(f"Pipeline execution failed at {command.name}")
                logger.exception(e)
                return TextHandleResult(code=-1, ostream="处理流水线时出现 python 错误")
            if last_result.code != 0:
                if pipeline.negate:
                    return TextHandleResult(code=0, ostream=None)
                return last_result
            current_stream = last_result.ostream
        if pipeline.negate:
            return TextHandleResult(code=1, ostream=None)
        return last_result

    async def _execute_group(
        self,
        group: CommandGroup,
        istream: str | None,
        env: TextHandlerEnvironment,
    ) -> TextHandleResult:
        """Run an `&&`/`||` chain; every pipeline receives the original istream."""
        last_result = TextHandleResult(code=0, ostream=None)
        for chain in group.chains:
            should_run = True
            if chain.op == "&&":
                should_run = last_result.code == 0
            elif chain.op == "||":
                should_run = last_result.code != 0
            if should_run:
                last_result = await self._execute_pipeline(chain.pipeline, istream, env)
        return last_result

    async def _execute_if(
        self,
        if_node: IfNode,
        istream: str | None,
        env: TextHandlerEnvironment,
    ) -> TextHandleResult:
        """Run an if statement; condition exit code 0 selects the then-branch.

        A missing else-branch yields a success result, and an empty branch
        also collapses to success.
        """
        condition_result = await self._execute_group(if_node.condition, istream, env)
        if condition_result.code == 0:
            results = await self.run_pipeline(if_node.then_body, istream, env)
        else:
            results = (
                await self.run_pipeline(if_node.else_body, istream, env)
                if if_node.else_body is not None
                else [TextHandleResult(code=0, ostream=None)]
            )
        return results[-1] if results else TextHandleResult(code=0, ostream=None)

    async def _execute_while(
        self,
        while_node: WhileNode,
        istream: str | None,
        env: TextHandlerEnvironment,
    ) -> TextHandleResult:
        """Run a while loop, capped at MAX_WHILE_ITERATIONS iterations.

        Loops until the condition fails (returning the last body result) or
        the body itself fails.  NOTE(review): both condition and body receive
        the loop's original ``istream`` each iteration — consistent with how
        _execute_group threads input, but confirm it is intentional.
        """
        last_result = TextHandleResult(code=0, ostream=None)
        for _ in range(MAX_WHILE_ITERATIONS):
            condition_result = await self._execute_group(while_node.condition, istream, env)
            if condition_result.code != 0:
                return last_result
            body_results = await self.run_pipeline(while_node.body, istream, env)
            if body_results:
                last_result = body_results[-1]
                if last_result.code != 0:
                    return last_result
        return TextHandleResult(
            code=2,
            # Fixed: the original f-string was missing the closing ")".
            ostream=f"while 循环超过最大迭代次数限制({MAX_WHILE_ITERATIONS})",
        )

    async def run_pipeline(
        self,
        pipeline: Script,
        istream: str | None,
        env: TextHandlerEnvironment | None = None,
    ) -> list[TextHandleResult]:
        """Execute every statement of ``pipeline`` and collect their results.

        Creates a fresh untrusted environment when ``env`` is omitted.  A
        Python-level exception inside a statement is recorded as a ``code=-1``
        result and aborts the remaining statements.
        """
        if env is None:
            env = TextHandlerEnvironment(is_trusted=False, event=None, buffers={})
        results: list[TextHandleResult] = []
        for statement in pipeline.statements:
            try:
                if isinstance(statement, IfNode):
                    results.append(await self._execute_if(statement, istream, env))
                elif isinstance(statement, WhileNode):
                    results.append(await self._execute_while(statement, istream, env))
                else:
                    results.append(await self._execute_group(statement, istream, env))
            except Exception as e:
                logger.error(f"Pipeline execution failed: {e}")
                logger.exception(e)
                results.append(
                    TextHandleResult(code=-1, ostream="处理流水线时出现 python 错误")
                )
                # Abort the remaining statements after an internal error.
                return results
        return results
def register_text_handlers(*handlers: TextHandler):
    """Attach every handler in ``handlers`` to the global PipelineRunner.

    ``get_runner()`` is looked up inside the loop so that a zero-argument
    call does not instantiate the singleton, matching existing behavior.
    """
    for item in handlers:
        runner = PipelineRunner.get_runner()
        runner.register(item)