#!/usr/bin/env python3
"""
solve.py - Solve questions defined in questions.json using an AI API.

This script reads questions.json, processes each question with its
corresponding typst file and attachments, and generates answer Markdown files.
"""

import argparse
import asyncio
import base64
import json
import logging
import sys
from dataclasses import dataclass
from pathlib import Path

import aiohttp

from common import (
    DATA_DIR,
    console,
    find_attachments,
    load_env,
    load_prompt,
    setup_logging,
)

# Attachments with these extensions are sent to the API as inline base64
# images instead of being pasted into the prompt as text.
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp")

# Request defaults for the streaming chat-completions call.
DEFAULT_MAX_TOKENS = 4096
DEFAULT_TIMEOUT_SECONDS = 600

# Upper bound for the exponential backoff between retries, in seconds.
MAX_RETRY_DELAY_SECONDS = 30


@dataclass
class SolveResult:
    """Result of solving a question."""

    question: str  # question ID taken from questions.json
    target: str  # answer file name (A_<question>.md) under DATA_DIR
    skipped: bool  # True when the answer file already existed
    success: bool  # True for skipped, dry-run, or successfully written answers
    error: str | None = None  # failure description when success is False


def _is_image(name: str) -> bool:
    """Return True if *name* ends with one of IMAGE_EXTENSIONS (case-insensitive)."""
    return name.lower().endswith(IMAGE_EXTENSIONS)


def load_questions_json() -> list[dict]:
    """Load the question list from ``DATA_DIR / "questions.json"``.

    Returns an empty list (after printing a warning) when the file is missing.
    """
    questions_path = DATA_DIR / "questions.json"
    if not questions_path.exists():
        console.print(
            f"[yellow]Warning: questions.json not found at {questions_path}[/yellow]"
        )
        return []
    with open(questions_path, "r", encoding="utf-8") as f:
        return json.load(f)


def build_prompt(question_data: dict, typ_content: str | None) -> str:
    """Build the full text prompt for one question.

    The prompt is the base prompt from ``solve.prompt.txt`` (or a built-in
    fallback when that is empty), followed by the typst source if provided,
    followed by a listing of the question's attachments:

    * image attachments are only referenced by name (their bytes are sent
      separately, see ``get_image_attachments``);
    * other attachments are inlined in fenced code blocks when they can be
      read as UTF-8 text, otherwise they are listed as binary;
    * attachments that do not exist under DATA_DIR are omitted.
    """
    base_prompt = load_prompt("solve.prompt.txt")
    if not base_prompt:
        base_prompt = "请解答这道题。"
    parts = [base_prompt, "\n\n"]

    if typ_content:
        parts.append("## 题目描述 (来自 .typ 文件):\n")
        parts.append(typ_content)
        parts.append("\n\n")

    attachments = question_data.get("attachments", [])
    if attachments:
        parts.append("## 附件:\n")
        for att in attachments:
            att_path = DATA_DIR / att
            if not att_path.exists():
                continue
            if _is_image(att):
                parts.append(f"- [图片附件: {att}]\n")
                continue
            try:
                content = att_path.read_text(encoding="utf-8")
            except (UnicodeDecodeError, OSError):
                # Not valid UTF-8 (or unreadable): list it without contents.
                parts.append(f"- {att} (二进制文件)\n")
            else:
                parts.append(f"- {att}:\n```\n{content}\n```\n")
        parts.append("\n")

    return "".join(parts)


def get_image_attachments(question_data: dict) -> list[tuple[str, Path]]:
    """Return ``(attachment_name, path)`` for every existing image attachment."""
    return [
        (att, DATA_DIR / att)
        for att in question_data.get("attachments", [])
        if _is_image(att) and (DATA_DIR / att).exists()
    ]


def _image_content_part(att_path: Path) -> dict:
    """Encode an image file as an ``image_url`` message content part.

    The image is embedded as a base64 ``data:`` URL whose MIME type is derived
    from the file suffix (``jpg``/``jpeg`` both map to ``image/jpeg``).

    Raises:
        OSError: if the file cannot be read.
    """
    image_data = base64.b64encode(att_path.read_bytes()).decode("ascii")
    suffix = att_path.suffix[1:].lower()
    mime_type = "image/jpeg" if suffix in ("jpg", "jpeg") else f"image/{suffix}"
    return {
        "type": "image_url",
        "image_url": {"url": f"data:{mime_type};base64,{image_data}"},
    }


def _delta_text(data: str) -> str | None:
    """Return the text delta carried by one streamed JSON chunk, if any.

    Expects the shape ``{"choices": [{"delta": {"content": "..."}}]}``.
    Returns None for invalid JSON or any chunk not matching that shape
    (e.g. role-only or usage chunks), so one odd chunk never aborts a stream.
    """
    try:
        chunk = json.loads(data)
    except json.JSONDecodeError:
        return None
    if not isinstance(chunk, dict):
        return None
    choices = chunk.get("choices")
    if not isinstance(choices, list) or not choices:
        return None
    first = choices[0]
    if not isinstance(first, dict):
        return None
    delta = first.get("delta")
    if not isinstance(delta, dict):
        return None
    content = delta.get("content")
    return content if isinstance(content, str) else None


async def call_api_streaming(
    session: aiohttp.ClientSession,
    question_name: str,
    prompt: str,
    endpoint: str,
    api_key: str,
    model: str,
    logger: logging.Logger,
    image_contents: list[tuple[str, Path]] | None = None,
    *,
    max_tokens: int = DEFAULT_MAX_TOKENS,
    timeout: float = DEFAULT_TIMEOUT_SECONDS,
) -> str | None:
    """Send one streaming chat request and return the concatenated answer.

    Args:
        session: shared aiohttp session.
        question_name: question ID, used only as a log prefix.
        prompt: text part of the user message.
        endpoint: URL the request is POSTed to.
        api_key: sent as a ``Bearer`` token.
        model: model name placed in the payload.
        logger: logger for progress and errors.
        image_contents: ``(name, path)`` pairs of images to attach inline.
        max_tokens: ``max_tokens`` value for the request.
        timeout: total request timeout in seconds.

    Returns:
        The streamed text, or None on timeout, HTTP/network/IO error, or an
        empty response.

    Raises:
        asyncio.CancelledError: re-raised so cancellation propagates.
    """
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    logger.info(f"[{question_name}] Thinking... (timeout {timeout}s)")

    try:
        content_parts: list[dict] = [{"type": "text", "text": prompt}]
        # Encode images inside the try so an unreadable file is reported as a
        # failed call rather than escaping and aborting the whole run.
        for _att_name, att_path in image_contents or []:
            content_parts.append(_image_content_part(att_path))

        payload = {
            "model": model,
            "messages": [{"role": "user", "content": content_parts}],
            "max_tokens": max_tokens,
            "stream": True,
        }

        async with session.post(
            endpoint,
            headers=headers,
            json=payload,
            timeout=aiohttp.ClientTimeout(total=timeout),
        ) as response:
            response.raise_for_status()
            full_content: list[str] = []
            async for raw_line in response.content:
                # Lines arrive with their trailing newline; strip it so blank
                # separators and the "[DONE]" sentinel are recognised.
                line = raw_line.decode("utf-8").strip()
                if not line.startswith("data:"):
                    continue  # blank separators, ": comments", other SSE fields
                data = line[len("data:"):].strip()
                if data == "[DONE]":
                    break
                logger.debug(f"[{question_name}] Chunk: {data}")
                text = _delta_text(data)
                if text:
                    full_content.append(text)

        content = "".join(full_content)
        if not content:
            logger.warning(f"[{question_name}] Empty response")
            return None
        logger.info(f"[{question_name}] Done ({len(content)} chars)")
        return content
    except asyncio.TimeoutError:
        logger.error(f"[{question_name}] Timeout")
        return None
    except asyncio.CancelledError:
        logger.warning(f"[{question_name}] Cancelled")
        raise
    except Exception as e:  # HTTP status, connection, decode, or file errors
        logger.error(f"[{question_name}] Error: {e}")
        return None


async def solve_question(
    session: aiohttp.ClientSession,
    question_data: dict,
    api_config: dict,
    logger: logging.Logger,
    dry_run: bool,
    retries: int = 0,
) -> SolveResult:
    """Solve a single question and write ``A_<question>.md`` under DATA_DIR.

    Questions whose answer file already exists are skipped. In dry-run mode
    nothing is sent or written. A failed API call is retried up to *retries*
    additional times with capped exponential backoff.

    Args:
        session: shared aiohttp session.
        question_data: one entry of questions.json; reads the keys
            ``question``, ``target`` (typst file) and optional ``attachments``.
        api_config: mapping with ``endpoint``, ``key`` and ``model``.
        logger: logger for progress and errors.
        dry_run: when True, only report what would be solved.
        retries: extra attempts after a failed API call (0 = single attempt).
    """
    question_name = question_data["question"]
    target_name = f"A_{question_name}.md"
    target_path = DATA_DIR / target_name

    if target_path.exists():
        logger.info(f"[{question_name}] Skipping: answer already exists")
        return SolveResult(
            question=question_name, target=target_name, skipped=True, success=True
        )

    typ_path = DATA_DIR / question_data["target"]
    typ_content = None
    if typ_path.exists():
        try:
            typ_content = typ_path.read_text(encoding="utf-8")
        except (UnicodeDecodeError, OSError) as e:
            # Best effort: solve without the typst description.
            logger.warning(f"[{question_name}] Could not read {typ_path.name}: {e}")

    if dry_run:
        logger.info(f"[{question_name}] Would solve -> {target_name}")
        return SolveResult(
            question=question_name, target=target_name, skipped=False, success=True
        )

    prompt = build_prompt(question_data, typ_content)
    image_contents = get_image_attachments(question_data)

    attempts = 1 + max(0, retries)
    content = None
    for attempt in range(1, attempts + 1):
        content = await call_api_streaming(
            session,
            question_name,
            prompt,
            str(api_config["endpoint"]),
            api_config["key"],
            api_config["model"],
            logger,
            image_contents,
        )
        if content is not None:
            break
        if attempt < attempts:
            delay = min(2**attempt, MAX_RETRY_DELAY_SECONDS)
            logger.warning(
                f"[{question_name}] Attempt {attempt}/{attempts} failed; "
                f"retrying in {delay}s"
            )
            await asyncio.sleep(delay)

    if content is None:
        return SolveResult(
            question=question_name,
            target=target_name,
            skipped=False,
            success=False,
            error="API call failed",
        )

    try:
        target_path.write_text(content, encoding="utf-8")
    except OSError as e:
        logger.error(f"[{question_name}] Write failed: {e}")
        return SolveResult(
            question=question_name,
            target=target_name,
            skipped=False,
            success=False,
            error=str(e),
        )
    logger.info(f"[{question_name}] Wrote {target_name} ({len(content)} chars)")
    return SolveResult(
        question=question_name, target=target_name, skipped=False, success=True
    )


def parse_args() -> argparse.Namespace:
    """Parse and validate command line arguments."""
    parser = argparse.ArgumentParser(
        description="Solve questions from questions.json using AI API"
    )
    parser.add_argument(
        "-q",
        "--question",
        action="append",
        dest="questions",
        help="Specific question IDs to solve (e.g., P15)",
    )
    parser.add_argument(
        "--dry-run", action="store_true", help="Do not call AI or write files"
    )
    parser.add_argument("--verbose", action="store_true", help="Enable debug logging")
    parser.add_argument(
        "--retry",
        type=int,
        default=3,
        help="Retry attempts after a failed API call (default: 3)",
    )
    parser.add_argument("-n", type=int, default=3, help="Concurrent limit (default: 3)")
    args = parser.parse_args()
    # Semaphore(0) would deadlock and a negative value raises ValueError.
    if args.n < 1:
        parser.error("-n must be at least 1")
    if args.retry < 0:
        parser.error("--retry must be non-negative")
    return args


async def _cancel_and_wait(tasks: list[asyncio.Task]) -> None:
    """Cancel every task and wait until all of them have finished."""
    for task in tasks:
        task.cancel()
    await asyncio.gather(*tasks, return_exceptions=True)


async def async_main(args: argparse.Namespace, logger: logging.Logger) -> None:
    """Solve the selected questions concurrently and log a summary.

    At most ``args.n`` questions are solved at the same time. On
    cancellation all pending tasks are cancelled and the process exits with
    status 1.
    """
    questions = load_questions_json()
    if not questions:
        logger.warning("No questions found in questions.json")
        return

    if args.questions:
        wanted = set(args.questions)
        questions = [q for q in questions if q["question"] in wanted]
        logger.info(f"Processing: {[q['question'] for q in questions]}")
    else:
        logger.info(
            f"Found {len(questions)} questions: {[q['question'] for q in questions]}"
        )

    if not questions:
        logger.warning("No matching questions found")
        return

    api_config = load_env()
    semaphore = asyncio.Semaphore(args.n)

    async def limited_solve(session: aiohttp.ClientSession, q: dict) -> SolveResult:
        async with semaphore:
            return await solve_question(
                session, q, api_config, logger, args.dry_run, retries=args.retry
            )

    async with aiohttp.ClientSession() as session:
        tasks = [asyncio.create_task(limited_solve(session, q)) for q in questions]
        results: list[SolveResult] = []
        try:
            for next_done in asyncio.as_completed(tasks):
                results.append(await next_done)
        except asyncio.CancelledError:
            logger.warning("Cancelled! Shutting down...")
            await _cancel_and_wait(tasks)
            sys.exit(1)
        except Exception:
            # One task failed unexpectedly: stop the others before the
            # session is closed underneath them, then propagate.
            await _cancel_and_wait(tasks)
            raise

    results.sort(key=lambda r: r.question)
    skipped = sum(1 for r in results if r.skipped)
    solved = sum(1 for r in results if r.success and not r.skipped)
    failed = sum(1 for r in results if not r.success)
    logger.info(
        f"Complete: {solved}/{len(results)} solved, {skipped} skipped, {failed} failed"
    )


def main() -> None:
    """Parse arguments, set up logging and run ``async_main``.

    Exits with status 1 when interrupted with Ctrl-C.
    """
    args = parse_args()
    logger = setup_logging("solve", args.verbose)
    logger.info(f"solve starting (Dry-run: {args.dry_run}, Workers: {args.n})")
    logger.info(f"Data directory: {DATA_DIR}")
    try:
        asyncio.run(async_main(args, logger))
    except KeyboardInterrupt:
        logger.warning("Interrupted by user")
        sys.exit(1)


if __name__ == "__main__":
    main()