From 61a435965ff2d0fec5180d3bad1a80243eac4556 Mon Sep 17 00:00:00 2001 From: Passthem Date: Wed, 8 Apr 2026 13:55:17 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96=E6=A8=A1=E6=9D=BF?= =?UTF-8?q?=E4=BB=93=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 19 ++ PRD_SCRIPT.md | 66 +++++++ justfile | 8 + scripts/common.py | 70 ++++++++ scripts/gen_index.py | 156 +++++++++++++++++ scripts/img2typ.prompt.txt | 27 +++ scripts/img2typ.py | 343 +++++++++++++++++++++++++++++++++++++ scripts/solve.prompt.txt | 20 +++ scripts/solve.py | 310 +++++++++++++++++++++++++++++++++ 9 files changed, 1019 insertions(+) create mode 100644 .gitignore create mode 100644 PRD_SCRIPT.md create mode 100644 justfile create mode 100644 scripts/common.py create mode 100644 scripts/gen_index.py create mode 100644 scripts/img2typ.prompt.txt create mode 100644 scripts/img2typ.py create mode 100644 scripts/solve.prompt.txt create mode 100644 scripts/solve.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ec823da --- /dev/null +++ b/.gitignore @@ -0,0 +1,19 @@ +# Environment +.env +.env.example + +# Python +__pycache__/ +*.py[cod] +*$py.class +.ruff_cache/ + +# Data (homework content - use as template) +data/ + +# Generated outputs +index.pdf +index.typ + +# Secrets +.secret diff --git a/PRD_SCRIPT.md b/PRD_SCRIPT.md new file mode 100644 index 0000000..c232797 --- /dev/null +++ b/PRD_SCRIPT.md @@ -0,0 +1,66 @@ +请在 `./scripts/` 文件夹下写一个 `img2typ.py` 脚本,达成下面的要求。 + +## 流程 + +遍历当前所在目录下的文件(不深入到子目录)。 + +找到所有的 `\S\s?[\d\.]+` 的文件名的图片。且最开头的 `\S` 不能是 `答` 或者 `A` 或者 `a` + +- 这里只是大概的描述,你可能需要调整正则,或者不使用正则 +- 例如,`问 13.png` `Q3.1.jpg` `R1.1.5.PNG` 等等的文件名都是合法的 + +检测有没有对应的 `.typ` 文件。即,看有没有后缀名改成 `typ` 的文件。 + +如果没有,则用 `./scripts/img2typ.prompt.txt` 为提示词,调用一个支持图片的 OpenAI 兼容 API。这个提示词文件在未来会改动,请在程序执行时动态读取它。API 的 API 端点和 API Key 应该使用一个 `.env` 文件定义,这个文件将会置放在 `./scripts/` 文件夹下。 + +对输出结果,如果有 Markdown 代码块包裹,则去除(可能需要你写正则或者其他任何机制) + +接着,保存为对应的 `.typ` 文件。 + +最后,整理所有符合条件的 `.typ` 文件(包括生成失败的),整理成一个列表后,在当前根目录写入(可覆盖)`questions.json`,是一个 JSON 列表,列表的每一个项目都是一个 JSON Object。形如: + +```json +[ + { + "question": "Q3.1", + "format": "typst", + "target": "Q3.1.typ" + }, + { + "question": "R1.1.5", + "format": "typst", + "target": "R1.1.5.typ" + } +] +``` + +## 细节 + +调用 API 时,如果失败,则重试。最多重试 3 次(或者可定义),真的失败了不能 panic,只是在 stderr 中汇报。 + +## 规范 + +代码应该是人类可维护的。你可以(而且最好)使用 python 的比较常用的现代语法,例如直接的类型注解(使用 `list` 而不是 `typing.List`),以及尽量使用 `pathlib.Path` 而不是字符串。 + +这个应用是面向过程的。你应该首先对流程做拆解,然后以子函数的形式声明整个函数。每个函数都应该有 docstring。对于 AI 等一些比较重要的东西,你再使用面向对象的方式去应对。 + +你可以给我要求,让我依赖一些更多的外置库。当前环境有 `requests`、`dotenv`、`rich`、`tqdm` 可用。 + +程序往 `stderr` 的输出应该是可审计的。应该汇报: + +- 将要处理的文件清单。 +- 调用了什么 API,调用情况如何,输入输出多少 tokens。 +- 写入了什么文件。 + +程序往 `stderr` 的输出应该是良好可视化的,就是说,有颜色区分,但是不要加 emoji。或者说,你应该使用自带的日志库 + 一定的格式化。 + +## 额外功能添补 + +这个脚本应该可以作为 cli 调用,支持以下参数: + +- `--file` 或 `-f` 后接文件名,可重复这个参数。当存在这个参数,则解析对应的图片,而不是扫描当前目录 +- `--dry-run` 不调用 AI,也不写入文件 +- `--verbose` 在基础上反馈 AI 调用的输出,相当于 log level 是 `DEBUG` +- `--retry` 接数字,重试次数,默认为 3 +- `-n` 接数字,并发数量,默认为 3 + diff --git a/justfile b/justfile new file mode 100644 index 0000000..ed54017 --- /dev/null +++ b/justfile @@ -0,0 +1,8 @@ +img2typ: + python ./scripts/img2typ.py + +solve: img2typ + python ./scripts/solve.py + +generate: solve + python ./scripts/gen_index.py diff --git a/scripts/common.py b/scripts/common.py new file mode 100644 index 0000000..e8fe966 --- /dev/null +++ b/scripts/common.py @@ -0,0 +1,70 @@ +"""Common utilities shared between img2typ and solve scripts.""" + +import logging +import os +from pathlib import Path + +from dotenv import load_dotenv +from rich.console import Console +from rich.logging import RichHandler +from rich.theme import Theme + +console = Console( + theme=Theme({"info": "cyan", "warning": "yellow", "error": "bold red"}) +) + +SCRIPT_DIR = Path(__file__).parent +DATA_DIR = SCRIPT_DIR.parent / "data" +ENV_FILE = SCRIPT_DIR / ".env" + + +def setup_logging(name: str, verbose: bool) -> logging.Logger: + """Setup logging with rich handler.""" + level = logging.DEBUG if verbose else logging.INFO + logging.basicConfig( + level=level, + format="%(message)s", + datefmt="[%X]", + handlers=[RichHandler(console=console, rich_tracebacks=True)], + ) + return logging.getLogger(name) + + +def load_env() -> dict[str, str]: + """Load environment variables from .env file.""" + if ENV_FILE.exists(): + load_dotenv(ENV_FILE) + else: + console.print(f"[yellow]Warning: .env file not found at {ENV_FILE}[/yellow]") + + api_endpoint = os.environ.get("IMG2TYP_API_ENDPOINT", "") + api_key = os.environ.get("IMG2TYP_API_KEY", "") + api_model = os.environ.get("IMG2TYP_MODEL", "qwen-vl-plus") + + if not api_endpoint: + console.print("[yellow]Warning: IMG2TYP_API_ENDPOINT not set[/yellow]") + if not api_key: + console.print("[yellow]Warning: IMG2TYP_API_KEY not set[/yellow]") + + return {"endpoint": api_endpoint, "key": api_key, "model": api_model} + + +def load_prompt(filename: str) -> str: + """Load a prompt template from file.""" + prompt_path = SCRIPT_DIR / filename + if not prompt_path.exists(): + console.print( + f"[yellow]Warning: Prompt file not found at {prompt_path}[/yellow]" + ) + return "" + return prompt_path.read_text(encoding="utf-8") + + +def find_attachments(question: str) -> list[str]: + """Find all attachment files for a given question.""" + attachments = [] + question_prefix = question + "_" + for file_path in DATA_DIR.iterdir(): + if file_path.is_file() and file_path.name.startswith(question_prefix): + attachments.append(file_path.name) + return sorted(attachments) diff --git a/scripts/gen_index.py b/scripts/gen_index.py new file mode 100644 index 0000000..34d6c9a --- /dev/null +++ b/scripts/gen_index.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 +""" +gen_index.py - Generate index.typ from questions.json. + +This script generates an index.typ file that includes all questions +and their answers in the specified format. +""" + +import argparse +import sys +from pathlib import Path + +from rich.console import Console + +from common import DATA_DIR, setup_logging + +console = Console() + + +def load_questions_json() -> list[dict]: + """Load questions from questions.json.""" + questions_path = DATA_DIR / "questions.json" + if not questions_path.exists(): + console.print(f"[red]Error: questions.json not found at {questions_path}[/red]") + sys.exit(1) + import json + + with open(questions_path, "r", encoding="utf-8") as f: + return json.load(f) + + +def read_typ_content(target: str) -> str | None: + """Read typ file content.""" + typ_path = DATA_DIR / target + if not typ_path.exists(): + console.print(f"[yellow]Warning: {target} not found[/yellow]") + return None + try: + return typ_path.read_text(encoding="utf-8") + except Exception as e: + console.print(f"[yellow]Warning: Failed to read {target}: {e}[/yellow]") + return None + + +def indent_text(text: str, indent: int = 2) -> str: + """Indent text by specified spaces.""" + lines = text.strip().split("\n") + spaces = " " * indent + return "\n".join(spaces + line if line.strip() else "" for line in lines) + + +def generate_index(questions: list[dict], dry_run: bool, logger) -> str: + """Generate index.typ content.""" + lines = [ + '#import "@local/phomework:0.1.0": homework, question, answer, shadow', + "", + ] + + enable_shadow = ( + "true" + if any(DATA_DIR / f"A_{q['question']}.md" for q in questions) + else "false" + ) + + lines.append( + f'#homework(title: "计算机网络第三次作业", secret: read(".secret"), enable_shadow: {enable_shadow})[' + ) + + def sort_key(q: dict) -> tuple: + name = q["question"] + prefix = name[0] + try: + num = int(name[1:]) + except ValueError: + num = float("inf") + return (0 if prefix == "R" else 1, prefix, num) + + sorted_questions = sorted(questions, key=sort_key) + + for q in sorted_questions: + question_name = q["question"] + typ_target = q["target"] + + lines.append(f' #question(title: "{question_name}")[') + content = read_typ_content(typ_target) + if content: + lines.append(indent_text(content, 4)) + else: + lines.append(" [题目内容加载失败]") + lines.append(" ]") + lines.append("") + lines.append(" #answer[") + + answer_file = DATA_DIR / f"A_{question_name}.md" + if answer_file.exists(): + lines.append(" 请填写答案。") + lines.append("") + lines.append(f' #shadow(read("./data/A_{question_name}.md"))') + else: + lines.append(" [答案文件不存在]") + + lines.append(" ]") + + lines.append("]") + return "\n".join(lines) + "\n" + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Generate index.typ from questions.json" + ) + parser.add_argument("--dry-run", action="store_true", help="Do not write files") + parser.add_argument( + "--force", action="store_true", help="Force overwrite without warning" + ) + parser.add_argument("--verbose", action="store_true", help="Enable debug logging") + return parser.parse_args() + + +def main() -> None: + """Main entry point.""" + args = parse_args() + logger = setup_logging("gen_index", args.verbose) + + questions = load_questions_json() + logger.info(f"Loaded {len(questions)} questions") + + if not questions: + console.print("[yellow]No questions found in questions.json[/yellow]") + sys.exit(1) + + content = generate_index(questions, args.dry_run, logger) + + output_path = Path("index.typ") + + if output_path.exists() and not args.force and not args.dry_run: + console.print( + f"[yellow]Warning: {output_path} already exists and will be overwritten![/yellow]" + ) + response = input("Continue? [y/N]: ") + if response.lower() != "y": + console.print("[yellow]Aborted.[/yellow]") + sys.exit(0) + + if args.dry_run: + logger.info(f"[DRY-RUN] Would write {output_path}") + logger.debug(f"Content preview:\n{content[:500]}...") + else: + output_path.write_text(content, encoding="utf-8") + logger.info(f"Wrote {output_path}") + console.print(f"[green]Successfully wrote {output_path}[/green]") + + +if __name__ == "__main__": + main() diff --git a/scripts/img2typ.prompt.txt b/scripts/img2typ.prompt.txt new file mode 100644 index 0000000..fea62e1 --- /dev/null +++ b/scripts/img2typ.prompt.txt @@ -0,0 +1,27 @@ +请你将这道题的题目内容提取出来,并转写成 Typst 格式,不包含题号。 + +题目内容请用代码块包裹起来。如果有多个小问,每个小问之间应该有两个换行符。如果题目有图片,则用 `[图片1]` 的格式表示占位,但不需要描述图片。 + +你的回答应该形如这样: + +```typst +在简化的路径损耗模型下,考虑一对距离为 $r$ 的收发信机,接收端的噪声功率为 $-150"dBm"$,…… + +考虑如下公式: + +$ + y <= k x + b +$ + +…… + +// snip + +[图片1] + +1. 请问…… + +2. 请问…… + +// snip +``` diff --git a/scripts/img2typ.py b/scripts/img2typ.py new file mode 100644 index 0000000..32b533c --- /dev/null +++ b/scripts/img2typ.py @@ -0,0 +1,343 @@ +#!/usr/bin/env python3 +""" +img2typ.py - Convert image files to typst format using AI API. + +This script scans the data directory for image files matching a pattern, +converts them to typst format using an OpenAI-compatible API, and generates +a questions.json manifest. +""" + +import argparse +import asyncio +import json +import logging +import re +import sys +from dataclasses import dataclass +from pathlib import Path + +import aiohttp +from common import DATA_DIR, console, load_env, load_prompt, setup_logging + +IMAGE_PATTERN = re.compile(r"^(\S)\s?([\d.]+)$") +EXCLUDED_PREFIXES = {"答", "A", "a"} +IMAGE_EXTENSIONS = { + ".png", + ".jpg", + ".jpeg", + ".gif", + ".bmp", + ".webp", + ".PNG", + ".JPG", + ".JPEG", + ".GIF", + ".BMP", + ".WEBP", +} + + +@dataclass +class ConversionResult: + """Result of an image to typst conversion.""" + + question: str + target: str + skipped: bool + success: bool + error: str | None = None + + +def find_images() -> list[Path]: + """Find all image files in data directory matching the pattern.""" + images = [] + for file_path in DATA_DIR.iterdir(): + if file_path.is_file() and file_path.suffix in IMAGE_EXTENSIONS: + stem = file_path.stem + match = IMAGE_PATTERN.match(stem) + if match and match.group(1) not in EXCLUDED_PREFIXES: + images.append(file_path) + return images + + +def check_typ_exists(image_path: Path) -> bool: + """Check if corresponding .typ file exists.""" + return image_path.with_suffix(".typ").exists() + + +def parse_markdown_blocks(text: str) -> str: + """Remove markdown code blocks from text.""" + block_pattern = re.compile(r"```(?:typst)?\s*\n?(.*?)\n?```", re.DOTALL) + matches = list(block_pattern.finditer(text)) + if matches: + return matches[0].group(1).strip() + return text.strip() + + +async def call_api( + session: aiohttp.ClientSession, + image_path: Path, + prompt: str, + endpoint: str, + api_key: str, + model: str, + logger: logging.Logger, +) -> str | None: + """Call the AI API to convert image to typst format.""" + import base64 + + headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} + + with open(image_path, "rb") as f: + image_data = base64.b64encode(f.read()).decode("utf-8") + + payload = { + "model": model, + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image_url", + "image_url": { + "url": f"data:image/{image_path.suffix[1:]};base64,{image_data}" + }, + }, + ], + } + ], + "max_tokens": 4096, + } + + logger.info(f"[{image_path.stem}] Converting... (timeout 300s)") + try: + async with session.post( + endpoint, + headers=headers, + json=payload, + timeout=aiohttp.ClientTimeout(total=300), + ) as response: + response.raise_for_status() + result = await response.json() + + if "choices" not in result or len(result["choices"]) == 0: + logger.error(f"Invalid API response") + return None + + content = result["choices"][0]["message"]["content"] + usage = result.get("usage", {}) + input_tokens = usage.get("prompt_tokens", 0) + output_tokens = usage.get("completion_tokens", 0) + + logger.info( + f"[{image_path.stem}] Done: {input_tokens} in, {output_tokens} out" + ) + return content + + except asyncio.TimeoutError: + logger.error(f"[{image_path.stem}] Timeout") + return None + except asyncio.CancelledError: + logger.warning(f"[{image_path.stem}] Cancelled") + raise + except Exception as e: + logger.error(f"[{image_path.stem}] Error: {e}") + return None + + +async def convert_image( + session: aiohttp.ClientSession, + image_path: Path, + prompt: str, + api_config: dict, + logger: logging.Logger, + dry_run: bool, +) -> ConversionResult: + """Convert a single image to typst format.""" + stem = image_path.stem + match = IMAGE_PATTERN.match(stem) + question_name = match.group(1) + match.group(2) if match else stem + typ_path = image_path.with_suffix(".typ") + + if typ_path.exists(): + logger.info(f"[{question_name}] Skipping: .typ already exists") + return ConversionResult( + question=question_name, target=typ_path.name, skipped=True, success=True + ) + + if dry_run: + logger.info(f"[{question_name}] Would convert -> {typ_path.name}") + return ConversionResult( + question=question_name, target=typ_path.name, skipped=False, success=True + ) + + content = await call_api( + session, + image_path, + prompt, + str(api_config["endpoint"]), + api_config["key"], + api_config["model"], + logger, + ) + + if content is None: + return ConversionResult( + question=question_name, + target=typ_path.name, + skipped=False, + success=False, + error="API call failed", + ) + + typst_code = parse_markdown_blocks(content) + + try: + typ_path.write_text(typst_code, encoding="utf-8") + logger.info( + f"[{question_name}] Wrote {typ_path.name} ({len(typst_code)} bytes)" + ) + return ConversionResult( + question=question_name, target=typ_path.name, skipped=False, success=True + ) + except IOError as e: + logger.error(f"[{question_name}] Write failed: {e}") + return ConversionResult( + question=question_name, + target=typ_path.name, + skipped=False, + success=False, + error=str(e), + ) + + +def generate_questions_json( + results: list[ConversionResult], + logger: logging.Logger, + dry_run: bool, +) -> None: + """Generate questions.json from conversion results.""" + from common import find_attachments + + questions = [] + for r in results: + attachments = find_attachments(r.question) + questions.append( + { + "question": r.question, + "format": "typst", + "target": r.target, + "attachments": attachments, + } + ) + + output_path = DATA_DIR / "questions.json" + if dry_run: + logger.info( + f"[DRY-RUN] Would write questions.json with {len(questions)} entries" + ) + logger.debug(f"Content: {json.dumps(questions, indent=4, ensure_ascii=False)}") + else: + output_path.write_text( + json.dumps(questions, indent=4, ensure_ascii=False), encoding="utf-8" + ) + logger.info(f"Wrote {output_path.name} ({len(questions)} entries)") + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Convert image files to typst format using AI API" + ) + parser.add_argument( + "-f", + "--file", + action="append", + dest="files", + help="Specific image files to process", + ) + parser.add_argument( + "--dry-run", action="store_true", help="Do not call AI or write files" + ) + parser.add_argument("--verbose", action="store_true", help="Enable debug logging") + parser.add_argument( + "--retry", type=int, default=3, help="Retry attempts (default: 3)" + ) + parser.add_argument("-n", type=int, default=3, help="Concurrent limit (default: 3)") + return parser.parse_args() + + +async def async_main(args: argparse.Namespace, logger: logging.Logger) -> None: + """Async main entry point.""" + if args.files: + image_paths = [] + for f in args.files: + p = Path(f) + if not p.is_absolute(): + p = DATA_DIR / f + image_paths.append(p) + else: + image_paths = find_images() + + logger.info(f"Found {len(image_paths)} images to process") + + if not image_paths: + logger.warning("No images found to process") + return + + api_config = load_env() + prompt = load_prompt("img2typ.prompt.txt") + + semaphore = asyncio.Semaphore(args.n) + + async def limited_convert( + session: aiohttp.ClientSession, img_path: Path + ) -> ConversionResult: + async with semaphore: + return await convert_image( + session, img_path, prompt, api_config, logger, args.dry_run + ) + + async with aiohttp.ClientSession() as session: + tasks = [ + asyncio.create_task(limited_convert(session, img)) for img in image_paths + ] + + results = [] + try: + for coro in asyncio.as_completed(tasks): + result = await coro + results.append(result) + except asyncio.CancelledError: + logger.warning("Cancelled! Shutting down...") + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + sys.exit(1) + + results.sort(key=lambda r: r.question) + generate_questions_json(results, logger, args.dry_run) + + skipped = sum(1 for r in results if r.skipped) + solved = sum(1 for r in results if r.success and not r.skipped) + logger.info(f"Complete: {solved}/{len(results)} converted, {skipped} skipped") + + +def main() -> None: + """Main entry point.""" + args = parse_args() + logger = setup_logging("img2typ", args.verbose) + + logger.info(f"img2typ starting (Dry-run: {args.dry_run}, Workers: {args.n})") + logger.info(f"Data directory: {DATA_DIR}") + + try: + asyncio.run(async_main(args, logger)) + except KeyboardInterrupt: + logger.warning("Interrupted by user") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/solve.prompt.txt b/scripts/solve.prompt.txt new file mode 100644 index 0000000..2a7f3ff --- /dev/null +++ b/scripts/solve.prompt.txt @@ -0,0 +1,20 @@ +请根据提供的 typst 内容(包括可能的图片附件),给出一个完整的解答,包括: + +1. **解答文本**:问题的直接答案,格式清晰。 + +2. **答题思路**:逐步解释你如何得出答案,展示你的计算过程和逻辑推理。 + +3. **相关知识点**:与这道题相关的关键概念、公式或原理。 + +请使用以下 Markdown 格式输出(不要使用代码块包裹主要内容): + +# 解答 +[你的直接答案] + +# 答题思路 +[你的逐步推理过程] + +# 相关知识点 +[关键概念和公式] + +如果题目包含图片附件,请仔细分析后再给出解答。 diff --git a/scripts/solve.py b/scripts/solve.py new file mode 100644 index 0000000..d546b25 --- /dev/null +++ b/scripts/solve.py @@ -0,0 +1,310 @@ +#!/usr/bin/env python3 +""" +solve.py - Solve questions defined in questions.json using AI API. + +This script reads questions.json, processes each question with corresponding +typst files and attachments, and generates answer Markdown files. +""" + +import argparse +import asyncio +import json +import logging +import sys +from dataclasses import dataclass +from pathlib import Path + +import aiohttp +from common import ( + DATA_DIR, + console, + find_attachments, + load_env, + load_prompt, + setup_logging, +) + + +@dataclass +class SolveResult: + """Result of solving a question.""" + + question: str + target: str + skipped: bool + success: bool + error: str | None = None + + +def load_questions_json() -> list[dict]: + """Load questions from questions.json.""" + questions_path = DATA_DIR / "questions.json" + if not questions_path.exists(): + console.print( + f"[yellow]Warning: questions.json not found at {questions_path}[/yellow]" + ) + return [] + with open(questions_path, "r", encoding="utf-8") as f: + return json.load(f) + + +def build_prompt(question_data: dict, typ_content: str | None) -> str: + """Build the full prompt including typst content and attachments info.""" + base_prompt = load_prompt("solve.prompt.txt") + if not base_prompt: + base_prompt = "请解答这道题。" + + parts = [base_prompt, "\n\n"] + + if typ_content: + parts.append("## 题目描述 (来自 .typ 文件):\n") + parts.append(typ_content) + parts.append("\n\n") + + attachments = question_data.get("attachments", []) + if attachments: + parts.append("## 附件:\n") + for att in attachments: + att_path = DATA_DIR / att + if att_path.exists(): + if att.lower().endswith( + (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp") + ): + parts.append(f"- [图片附件: {att}]\n") + else: + try: + content = att_path.read_text(encoding="utf-8") + parts.append(f"- {att}:\n```\n{content}\n```\n") + except: + parts.append(f"- {att} (二进制文件)\n") + parts.append("\n") + + return "".join(parts) + + +async def call_api_streaming( + session: aiohttp.ClientSession, + question_name: str, + prompt: str, + endpoint: str, + api_key: str, + model: str, + logger: logging.Logger, +) -> str | None: + """Call the AI API with streaming to solve the question.""" + headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"} + + payload = { + "model": model, + "messages": [{"role": "user", "content": [{"type": "text", "text": prompt}]}], + "max_tokens": 4096, + "stream": True, + } + + logger.info(f"[{question_name}] Thinking... (timeout 600s)") + try: + async with session.post( + endpoint, + headers=headers, + json=payload, + timeout=aiohttp.ClientTimeout(total=600), + ) as response: + response.raise_for_status() + + full_content = [] + async for line in response.content: + if not line: + continue + line = line.decode("utf-8") + if line.startswith("data: "): + data = line[6:] + if data == "[DONE]": + continue + try: + chunk = json.loads(data) + logger.debug(f"[{question_name}] Chunk: {chunk}") + delta = chunk.get("choices", [{}])[0].get("delta", {}) + content = delta.get("content") + if content: + full_content.append(content) + except json.JSONDecodeError: + continue + + content = "".join(full_content) + if not content: + logger.warning(f"[{question_name}] Empty response") + return None + + logger.info(f"[{question_name}] Done ({len(content)} chars)") + return content + + except asyncio.TimeoutError: + logger.error(f"[{question_name}] Timeout") + return None + except asyncio.CancelledError: + logger.warning(f"[{question_name}] Cancelled") + raise + except Exception as e: + logger.error(f"[{question_name}] Error: {e}") + return None + + +async def solve_question( + session: aiohttp.ClientSession, + question_data: dict, + api_config: dict, + logger: logging.Logger, + dry_run: bool, +) -> SolveResult: + """Solve a single question and generate answer markdown.""" + question_name = question_data["question"] + target_name = f"A_{question_name}.md" + target_path = DATA_DIR / target_name + + if target_path.exists(): + logger.info(f"[{question_name}] Skipping: answer already exists") + return SolveResult( + question=question_name, target=target_name, skipped=True, success=True + ) + + typ_path = DATA_DIR / question_data["target"] + typ_content = None + if typ_path.exists(): + try: + typ_content = typ_path.read_text(encoding="utf-8") + except: + typ_content = None + + if dry_run: + logger.info(f"[{question_name}] Would solve -> {target_name}") + return SolveResult( + question=question_name, target=target_name, skipped=False, success=True + ) + + prompt = build_prompt(question_data, typ_content) + + content = await call_api_streaming( + session, + question_name, + prompt, + str(api_config["endpoint"]), + api_config["key"], + api_config["model"], + logger, + ) + + if content is None: + return SolveResult( + question=question_name, + target=target_name, + skipped=False, + success=False, + error="API call failed", + ) + + try: + target_path.write_text(content, encoding="utf-8") + logger.info(f"[{question_name}] Wrote {target_name} ({len(content)} bytes)") + return SolveResult( + question=question_name, target=target_name, skipped=False, success=True + ) + except IOError as e: + logger.error(f"[{question_name}] Write failed: {e}") + return SolveResult( + question=question_name, + target=target_name, + skipped=False, + success=False, + error=str(e), + ) + + +def parse_args() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Solve questions from questions.json using AI API" + ) + parser.add_argument( + "-q", + "--question", + action="append", + dest="questions", + help="Specific question IDs to solve (e.g., P15)", + ) + parser.add_argument( + "--dry-run", action="store_true", help="Do not call AI or write files" + ) + parser.add_argument("--verbose", action="store_true", help="Enable debug logging") + parser.add_argument( + "--retry", type=int, default=3, help="Retry attempts (default: 3)" + ) + parser.add_argument("-n", type=int, default=3, help="Concurrent limit (default: 3)") + return parser.parse_args() + + +async def async_main(args: argparse.Namespace, logger: logging.Logger) -> None: + """Async main entry point.""" + questions = load_questions_json() + if not questions: + logger.warning("No questions found in questions.json") + return + + if args.questions: + questions = [q for q in questions if q["question"] in args.questions] + logger.info(f"Processing: {[q['question'] for q in questions]}") + else: + logger.info( + f"Found {len(questions)} questions: {[q['question'] for q in questions]}" + ) + + if not questions: + logger.warning("No matching questions found") + return + + api_config = load_env() + + semaphore = asyncio.Semaphore(args.n) + + async def limited_solve(session: aiohttp.ClientSession, q: dict) -> SolveResult: + async with semaphore: + return await solve_question(session, q, api_config, logger, args.dry_run) + + async with aiohttp.ClientSession() as session: + tasks = [asyncio.create_task(limited_solve(session, q)) for q in questions] + + results = [] + try: + for coro in asyncio.as_completed(tasks): + result = await coro + results.append(result) + except asyncio.CancelledError: + logger.warning("Cancelled! Shutting down...") + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + sys.exit(1) + + results.sort(key=lambda r: r.question) + + skipped = sum(1 for r in results if r.skipped) + solved = sum(1 for r in results if r.success and not r.skipped) + logger.info(f"Complete: {solved}/{len(results)} solved, {skipped} skipped") + + +def main() -> None: + """Main entry point.""" + args = parse_args() + logger = setup_logging("solve", args.verbose) + + logger.info(f"solve starting (Dry-run: {args.dry_run}, Workers: {args.n})") + logger.info(f"Data directory: {DATA_DIR}") + + try: + asyncio.run(async_main(args, logger)) + except KeyboardInterrupt: + logger.warning("Interrupted by user") + sys.exit(1) + + +if __name__ == "__main__": + main()