Files
homework-template/scripts/solve.py
2026-04-08 13:55:17 +08:00

311 lines
9.3 KiB
Python

#!/usr/bin/env python3
"""
solve.py - Solve questions defined in questions.json using AI API.
This script reads questions.json, processes each question with corresponding
typst files and attachments, and generates answer Markdown files.
"""
import argparse
import asyncio
import json
import logging
import sys
from dataclasses import dataclass
from pathlib import Path
import aiohttp
from common import (
DATA_DIR,
console,
find_attachments,
load_env,
load_prompt,
setup_logging,
)
@dataclass
class SolveResult:
"""Result of solving a question."""
question: str
target: str
skipped: bool
success: bool
error: str | None = None
def load_questions_json() -> list[dict]:
    """Return the parsed content of ``questions.json`` in the data directory.

    When the file does not exist, a yellow warning is printed to the console
    and an empty list is returned instead.
    """
    questions_path = DATA_DIR / "questions.json"
    if questions_path.exists():
        with questions_path.open("r", encoding="utf-8") as handle:
            return json.load(handle)

    # Missing file: warn and let the caller deal with "no questions".
    console.print(
        f"[yellow]Warning: questions.json not found at {questions_path}[/yellow]"
    )
    return []
def build_prompt(question_data: dict, typ_content: str | None) -> str:
    """Assemble the prompt text sent to the model for one question.

    The prompt starts with the text loaded from ``solve.prompt.txt`` (or a
    short built-in instruction when that is empty), followed by the typst
    source when available and a section describing the attachments.

    Args:
        question_data: One entry of questions.json. Only its optional
            ``"attachments"`` list (paths relative to ``DATA_DIR``) is read.
        typ_content: Text of the question's ``.typ`` file, or ``None``/empty
            to leave the description section out.

    Returns:
        The complete prompt as a single string.
    """
    image_suffixes = (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp")

    base_prompt = load_prompt("solve.prompt.txt")
    if not base_prompt:
        base_prompt = "请解答这道题。"
    parts = [base_prompt, "\n\n"]

    if typ_content:
        parts.append("## 题目描述 (来自 .typ 文件):\n")
        parts.append(typ_content)
        parts.append("\n\n")

    attachments = question_data.get("attachments", [])
    if attachments:
        parts.append("## 附件:\n")
        for att in attachments:
            att_path = DATA_DIR / att
            if not att_path.exists():
                # Attachments that are not on disk are left out of the prompt.
                continue
            if att.lower().endswith(image_suffixes):
                # Images are only referenced by name; their bytes are not inlined.
                parts.append(f"- [图片附件: {att}]\n")
                continue
            try:
                content = att_path.read_text(encoding="utf-8")
            except (UnicodeDecodeError, OSError):
                # Not UTF-8 text or not readable: list it as a binary file.
                # Catching only these keeps KeyboardInterrupt/SystemExit
                # from being swallowed here.
                parts.append(f"- {att} (二进制文件)\n")
            else:
                parts.append(f"- {att}:\n```\n{content}\n```\n")
        # Blank line closing the attachments section.
        parts.append("\n")

    return "".join(parts)
async def call_api_streaming(
    session: aiohttp.ClientSession,
    question_name: str,
    prompt: str,
    endpoint: str,
    api_key: str,
    model: str,
    logger: logging.Logger,
) -> str | None:
    """Stream a chat completion for one question and return the joined text.

    Posts ``prompt`` as a single user message with ``stream=True`` and
    concatenates the ``choices[0].delta.content`` fragments found on the
    ``data:`` lines of the response body.

    Args:
        session: HTTP session used for the POST request.
        question_name: Question identifier; only used as a log prefix.
        prompt: Full prompt text.
        endpoint: URL the request is posted to.
        api_key: Token sent in the ``Authorization: Bearer`` header.
        model: Model name placed in the request payload.
        logger: Logger receiving progress and error messages.

    Returns:
        The concatenated response text, or ``None`` when the request fails,
        times out (600 s total), or produces no content.

    Raises:
        asyncio.CancelledError: Re-raised after logging so that task
            cancellation still propagates to the caller.
    """
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
        "max_tokens": 4096,
        "stream": True,
    }
    logger.info(f"[{question_name}] Thinking... (timeout 600s)")
    try:
        async with session.post(
            endpoint,
            headers=headers,
            json=payload,
            timeout=aiohttp.ClientTimeout(total=600),
        ) as response:
            response.raise_for_status()
            full_content: list[str] = []
            async for raw_line in response.content:
                # Lines read from the stream may still carry their line
                # terminator; strip it so the "[DONE]" comparison can match.
                line = raw_line.decode("utf-8").strip()
                if not line.startswith("data:"):
                    # Blank separators and non-data fields carry no text.
                    continue
                # The space after "data:" is optional, so strip instead of
                # slicing a fixed six characters.
                data = line[len("data:"):].strip()
                if data == "[DONE]":
                    break
                try:
                    chunk = json.loads(data)
                except json.JSONDecodeError:
                    continue
                logger.debug(f"[{question_name}] Chunk: {chunk}")
                if not isinstance(chunk, dict):
                    continue
                # A chunk may carry an empty "choices" list or a null delta;
                # indexing blindly would raise and throw away every fragment
                # collected so far via the generic handler below.
                choices = chunk.get("choices") or []
                if not choices:
                    continue
                delta = choices[0].get("delta") or {}
                content = delta.get("content")
                if content:
                    full_content.append(content)
        content = "".join(full_content)
        if not content:
            logger.warning(f"[{question_name}] Empty response")
            return None
        logger.info(f"[{question_name}] Done ({len(content)} chars)")
        return content
    except asyncio.TimeoutError:
        logger.error(f"[{question_name}] Timeout")
        return None
    except asyncio.CancelledError:
        logger.warning(f"[{question_name}] Cancelled")
        raise
    except Exception as e:
        # Boundary handler: any other failure is reported as "no answer".
        logger.error(f"[{question_name}] Error: {e}")
        return None
async def solve_question(
    session: aiohttp.ClientSession,
    question_data: dict,
    api_config: dict,
    logger: logging.Logger,
    dry_run: bool,
) -> SolveResult:
    """Solve one question and write its answer to ``A_<question>.md``.

    Args:
        session: HTTP session passed through to ``call_api_streaming``.
        question_data: questions.json entry; ``"question"`` (identifier) and
            ``"target"`` (typst file path relative to ``DATA_DIR``) are
            required, ``"attachments"`` is consumed by ``build_prompt``.
        api_config: Mapping providing ``"endpoint"``, ``"key"`` and ``"model"``.
        logger: Logger receiving progress and error messages.
        dry_run: When true, only log what would be done; no API call is made
            and no file is written.

    Returns:
        A ``SolveResult``: ``skipped=True`` when the answer file already
        exists, ``success=False`` with ``error`` set when the API call or the
        file write failed.
    """
    question_name = question_data["question"]
    target_name = f"A_{question_name}.md"
    target_path = DATA_DIR / target_name

    # Existing answers are never overwritten.
    if target_path.exists():
        logger.info(f"[{question_name}] Skipping: answer already exists")
        return SolveResult(
            question=question_name, target=target_name, skipped=True, success=True
        )

    typ_path = DATA_DIR / question_data["target"]
    typ_content = None
    if typ_path.exists():
        try:
            typ_content = typ_path.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError) as e:
            # Best effort: continue without the description, but say so
            # instead of hiding the problem (and without trapping
            # KeyboardInterrupt/SystemExit as a bare except would).
            logger.warning(f"[{question_name}] Could not read {typ_path.name}: {e}")
            typ_content = None

    if dry_run:
        logger.info(f"[{question_name}] Would solve -> {target_name}")
        return SolveResult(
            question=question_name, target=target_name, skipped=False, success=True
        )

    prompt = build_prompt(question_data, typ_content)
    content = await call_api_streaming(
        session,
        question_name,
        prompt,
        str(api_config["endpoint"]),
        api_config["key"],
        api_config["model"],
        logger,
    )
    if content is None:
        return SolveResult(
            question=question_name,
            target=target_name,
            skipped=False,
            success=False,
            error="API call failed",
        )

    try:
        target_path.write_text(content, encoding="utf-8")
    except OSError as e:
        logger.error(f"[{question_name}] Write failed: {e}")
        return SolveResult(
            question=question_name,
            target=target_name,
            skipped=False,
            success=False,
            error=str(e),
        )
    # len() of a str counts characters, not bytes, so report it as such.
    logger.info(f"[{question_name}] Wrote {target_name} ({len(content)} chars)")
    return SolveResult(
        question=question_name, target=target_name, skipped=False, success=True
    )
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse the process arguments.

    Returns:
        Namespace with ``questions`` (list of IDs or ``None``), ``dry_run``,
        ``verbose``, ``retry`` and ``n``.
    """
    cli = argparse.ArgumentParser(
        description="Solve questions from questions.json using AI API"
    )
    # (flags, keyword settings) for every supported option, in display order.
    option_table = (
        (
            ("-q", "--question"),
            {
                "action": "append",
                "dest": "questions",
                "help": "Specific question IDs to solve (e.g., P15)",
            },
        ),
        (
            ("--dry-run",),
            {"action": "store_true", "help": "Do not call AI or write files"},
        ),
        (
            ("--verbose",),
            {"action": "store_true", "help": "Enable debug logging"},
        ),
        (
            ("--retry",),
            {"type": int, "default": 3, "help": "Retry attempts (default: 3)"},
        ),
        (
            ("-n",),
            {"type": int, "default": 3, "help": "Concurrent limit (default: 3)"},
        ),
    )
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)
    return cli.parse_args()
async def async_main(args: argparse.Namespace, logger: logging.Logger) -> None:
    """Load the questions, solve them concurrently and log a summary.

    Args:
        args: Parsed CLI options. ``questions`` (optional ID filter), ``n``
            (concurrency limit) and ``dry_run`` are read; ``retry`` is used as
            the total number of attempts per question when present.
        logger: Logger receiving progress, failure and summary messages.

    Raises:
        SystemExit: With status 1 when the run is cancelled while tasks are
            still in flight.
    """
    questions = load_questions_json()
    if not questions:
        logger.warning("No questions found in questions.json")
        return

    if args.questions:
        questions = [q for q in questions if q["question"] in args.questions]
        logger.info(f"Processing: {[q['question'] for q in questions]}")
    else:
        logger.info(
            f"Found {len(questions)} questions: {[q['question'] for q in questions]}"
        )
    if not questions:
        logger.warning("No matching questions found")
        return

    api_config = load_env()

    # Semaphore(0) would make every task wait forever and a negative value
    # raises ValueError, so always allow at least one worker.
    semaphore = asyncio.Semaphore(max(1, args.n))
    # Total attempts per question; at least one, and tolerant of callers that
    # pass a namespace without a "retry" attribute.
    attempts = max(1, getattr(args, "retry", 1))

    async def limited_solve(session: aiohttp.ClientSession, q: dict) -> SolveResult:
        """Solve ``q`` under the concurrency limit, retrying failed attempts."""
        async with semaphore:
            result = await solve_question(session, q, api_config, logger, args.dry_run)
            for attempt in range(2, attempts + 1):
                if result.success:
                    break
                logger.warning(
                    f"[{result.question}] Attempt {attempt - 1}/{attempts} failed: "
                    f"{result.error}; retrying"
                )
                result = await solve_question(
                    session, q, api_config, logger, args.dry_run
                )
            return result

    async with aiohttp.ClientSession() as session:
        tasks = [asyncio.create_task(limited_solve(session, q)) for q in questions]
        results = []
        try:
            for coro in asyncio.as_completed(tasks):
                results.append(await coro)
        except asyncio.CancelledError:
            logger.warning("Cancelled! Shutting down...")
            for task in tasks:
                task.cancel()
            # Let every task finish unwinding before the session is closed.
            await asyncio.gather(*tasks, return_exceptions=True)
            sys.exit(1)

    results.sort(key=lambda r: r.question)
    skipped = sum(1 for r in results if r.skipped)
    solved = sum(1 for r in results if r.success and not r.skipped)
    logger.info(f"Complete: {solved}/{len(results)} solved, {skipped} skipped")
    # Name the failures explicitly; the totals alone do not say which
    # questions still lack an answer or why.
    for r in results:
        if not r.success:
            logger.error(f"[{r.question}] Failed: {r.error}")
def main() -> None:
    """Script entry point: parse options, set up logging, run the async driver.

    Exits with status 1 when the run is interrupted from the keyboard.
    """
    args = parse_args()
    log = setup_logging("solve", args.verbose)

    banner = f"solve starting (Dry-run: {args.dry_run}, Workers: {args.n})"
    log.info(banner)
    log.info(f"Data directory: {DATA_DIR}")

    try:
        asyncio.run(async_main(args, log))
    except KeyboardInterrupt:
        log.warning("Interrupted by user")
        sys.exit(1)


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()