311 lines
9.3 KiB
Python
311 lines
9.3 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
solve.py - Solve the questions defined in questions.json using an AI API.
|
|
|
|
This script reads questions.json, processes each question with corresponding
|
|
Typst files and attachments, and generates an answer Markdown file for each.
|
|
"""
|
|
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import sys
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import aiohttp
|
|
from common import (
|
|
DATA_DIR,
|
|
console,
|
|
find_attachments,
|
|
load_env,
|
|
load_prompt,
|
|
setup_logging,
|
|
)
|
|
|
|
|
|
@dataclass
|
|
class SolveResult:
|
|
"""Result of solving a question."""
|
|
|
|
question: str
|
|
target: str
|
|
skipped: bool
|
|
success: bool
|
|
error: str | None = None
|
|
|
|
|
|
def load_questions_json() -> list[dict]:
    """Return the parsed contents of ``DATA_DIR/questions.json``.

    When the file is absent, a yellow warning is printed to the console and
    an empty list is returned instead of raising.
    """
    source = DATA_DIR / "questions.json"
    if source.exists():
        return json.loads(source.read_text(encoding="utf-8"))

    console.print(
        f"[yellow]Warning: questions.json not found at {source}[/yellow]"
    )
    return []
|
|
|
|
|
|
def build_prompt(question_data: dict, typ_content: str | None) -> str:
    """Assemble the text prompt sent to the model for one question.

    The prompt consists of three parts, in order:

    1. The base instructions loaded from ``solve.prompt.txt``. If that prompt
       is empty, a short built-in instruction ("Please solve this question.")
       is used instead.
    2. The ``.typ`` question text, when one is given.
    3. A list of the question's attachments that exist under ``DATA_DIR``.

    Args:
        question_data: One entry from questions.json. Only its optional
            ``"attachments"`` list (paths relative to ``DATA_DIR``) is read.
        typ_content: Text of the question's ``.typ`` file; the section is
            omitted when this is None or empty.

    Returns:
        The concatenated prompt string.
    """
    base_prompt = load_prompt("solve.prompt.txt")
    if not base_prompt:
        base_prompt = "请解答这道题。"

    parts = [base_prompt, "\n\n"]

    if typ_content:
        parts.append("## 题目描述 (来自 .typ 文件):\n")
        parts.append(typ_content)
        parts.append("\n\n")

    attachments = question_data.get("attachments", [])
    if attachments:
        parts.append("## 附件:\n")
        for att in attachments:
            att_path = DATA_DIR / att
            if not att_path.exists():
                # Missing attachments are silently left out of the prompt.
                continue
            if att.lower().endswith(
                (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp")
            ):
                # Images are only mentioned by name; their bytes are not
                # embedded in this text prompt.
                parts.append(f"- [图片附件: {att}]\n")
                continue
            try:
                content = att_path.read_text(encoding="utf-8")
            except (UnicodeDecodeError, OSError):
                # Not UTF-8 text (or unreadable): list it as a binary file.
                parts.append(f"- {att} (二进制文件)\n")
                continue
            # Pick a fence longer than any backtick run in the file so an
            # embedded ``` cannot close the code block early.
            fence = "```"
            while fence in content:
                fence += "`"
            parts.append(f"- {att}:\n{fence}\n{content}\n{fence}\n")
        parts.append("\n")

    return "".join(parts)
|
|
|
|
|
|
def _extract_delta_text(chunk: object) -> str | None:
    """Return the non-empty ``choices[0].delta.content`` text of one stream chunk.

    Chunks without usable text (an empty or missing ``choices`` list, a null
    or missing ``delta``, a non-object payload) yield None instead of raising.
    """
    if not isinstance(chunk, dict):
        return None
    choices = chunk.get("choices")
    if not isinstance(choices, list) or not choices:
        return None
    first = choices[0]
    if not isinstance(first, dict):
        return None
    delta = first.get("delta")
    if not isinstance(delta, dict):
        return None
    content = delta.get("content")
    if isinstance(content, str) and content:
        return content
    return None


async def call_api_streaming(
    session: aiohttp.ClientSession,
    question_name: str,
    prompt: str,
    endpoint: str,
    api_key: str,
    model: str,
    logger: logging.Logger,
    *,
    max_tokens: int = 4096,
    timeout: int = 600,
) -> str | None:
    """POST *prompt* to a chat-completions style endpoint and collect the reply.

    The request asks for a streamed response. Each ``data:`` line of the
    stream is parsed as JSON, and the ``choices[0].delta.content`` fragments
    are concatenated.

    Args:
        session: Shared aiohttp session used for the request.
        question_name: Question identifier, used only as a log prefix.
        prompt: Full user prompt text.
        endpoint: URL the JSON payload is POSTed to.
        api_key: Sent as a Bearer token in the Authorization header.
        model: Model name placed in the payload.
        logger: Logger for progress and error messages.
        max_tokens: ``max_tokens`` value placed in the payload.
        timeout: Total request timeout in seconds.

    Returns:
        The concatenated answer text. Returns None on timeout, HTTP error
        status, any other request failure, or an empty answer.

    Raises:
        asyncio.CancelledError: Re-raised so that task cancellation propagates.
    """
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

    payload = {
        "model": model,
        "messages": [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
        "max_tokens": max_tokens,
        "stream": True,
    }

    logger.info(f"[{question_name}] Thinking... (timeout {timeout}s)")
    try:
        async with session.post(
            endpoint,
            headers=headers,
            json=payload,
            timeout=aiohttp.ClientTimeout(total=timeout),
        ) as response:
            if response.status >= 400:
                # Log the body as well: it may explain the rejection, which
                # raise_for_status() alone would discard.
                body = await response.text()
                logger.error(
                    f"[{question_name}] HTTP {response.status}: {body[:500]}"
                )
                return None

            full_content = []
            async for raw_line in response.content:
                # Iterated lines keep their trailing newline; strip it so the
                # "[DONE]" sentinel comparison can actually match.
                line = raw_line.decode("utf-8", errors="replace").strip()
                if not line.startswith("data:"):
                    continue
                data = line[len("data:"):].strip()
                if data == "[DONE]":
                    break
                try:
                    chunk = json.loads(data)
                except json.JSONDecodeError:
                    continue
                logger.debug(f"[{question_name}] Chunk: {chunk}")
                # Chunks without text (e.g. empty "choices") are skipped here
                # instead of raising and discarding everything collected so far.
                text = _extract_delta_text(chunk)
                if text:
                    full_content.append(text)

        content = "".join(full_content)
        if not content:
            logger.warning(f"[{question_name}] Empty response")
            return None

        logger.info(f"[{question_name}] Done ({len(content)} chars)")
        return content

    except asyncio.TimeoutError:
        logger.error(f"[{question_name}] Timeout")
        return None
    except asyncio.CancelledError:
        logger.warning(f"[{question_name}] Cancelled")
        raise
    except Exception as e:
        logger.error(f"[{question_name}] Error: {e}")
        return None
|
|
|
|
|
|
async def solve_question(
    session: aiohttp.ClientSession,
    question_data: dict,
    api_config: dict,
    logger: logging.Logger,
    dry_run: bool,
) -> SolveResult:
    """Answer one question and save the model's reply as ``A_<question>.md``.

    The steps are:

    1. Skip the question if the answer file already exists.
    2. Read the question's ``.typ`` file (``question_data["target"]``) if it
       is present.
    3. In dry-run mode, stop here.
    4. Otherwise build the prompt, call the API, and write the reply.

    Args:
        session: Shared aiohttp session passed to the API call.
        question_data: One questions.json entry. This function reads
            ``"question"`` and ``"target"``; attachments are read by
            ``build_prompt``.
        api_config: Mapping providing ``"endpoint"``, ``"key"`` and ``"model"``.
        logger: Logger for progress messages.
        dry_run: When True, no API call is made and nothing is written.

    Returns:
        A ``SolveResult`` describing the outcome. API-call and file-write
        failures are reported via ``success=False``, not raised.
    """
    question_name = question_data["question"]
    target_name = f"A_{question_name}.md"
    target_path = DATA_DIR / target_name

    if target_path.exists():
        logger.info(f"[{question_name}] Skipping: answer already exists")
        return SolveResult(
            question=question_name, target=target_name, skipped=True, success=True
        )

    typ_path = DATA_DIR / question_data["target"]
    typ_content = None
    if typ_path.exists():
        try:
            typ_content = typ_path.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError) as e:
            # Continue without the question text, but make the gap visible.
            logger.warning(f"[{question_name}] Could not read {typ_path}: {e}")

    if dry_run:
        logger.info(f"[{question_name}] Would solve -> {target_name}")
        return SolveResult(
            question=question_name, target=target_name, skipped=False, success=True
        )

    prompt = build_prompt(question_data, typ_content)

    content = await call_api_streaming(
        session,
        question_name,
        prompt,
        str(api_config["endpoint"]),
        api_config["key"],
        api_config["model"],
        logger,
    )

    if content is None:
        return SolveResult(
            question=question_name,
            target=target_name,
            skipped=False,
            success=False,
            error="API call failed",
        )

    # Write to a temporary sibling and rename it into place. Otherwise an
    # interrupted run could leave a partial A_*.md that the exists() check
    # above would skip from then on as if it were a finished answer.
    tmp_path = target_path.with_name(f"{target_name}.tmp")
    try:
        tmp_path.write_text(content, encoding="utf-8")
        tmp_path.replace(target_path)
    except (OSError, UnicodeEncodeError) as e:
        # UnicodeEncodeError: e.g. lone surrogates decoded from JSON escapes.
        logger.error(f"[{question_name}] Write failed: {e}")
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            pass  # best-effort cleanup; the write error is what gets reported
        return SolveResult(
            question=question_name,
            target=target_name,
            skipped=False,
            success=False,
            error=str(e),
        )

    size = len(content.encode("utf-8"))
    logger.info(f"[{question_name}] Wrote {target_name} ({size} bytes)")
    return SolveResult(
        question=question_name, target=target_name, skipped=False, success=True
    )
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
|
|
"""Parse command line arguments."""
|
|
parser = argparse.ArgumentParser(
|
|
description="Solve questions from questions.json using AI API"
|
|
)
|
|
parser.add_argument(
|
|
"-q",
|
|
"--question",
|
|
action="append",
|
|
dest="questions",
|
|
help="Specific question IDs to solve (e.g., P15)",
|
|
)
|
|
parser.add_argument(
|
|
"--dry-run", action="store_true", help="Do not call AI or write files"
|
|
)
|
|
parser.add_argument("--verbose", action="store_true", help="Enable debug logging")
|
|
parser.add_argument(
|
|
"--retry", type=int, default=3, help="Retry attempts (default: 3)"
|
|
)
|
|
parser.add_argument("-n", type=int, default=3, help="Concurrent limit (default: 3)")
|
|
return parser.parse_args()
|
|
|
|
|
|
async def async_main(args: argparse.Namespace, logger: logging.Logger) -> None:
    """Load questions, solve them concurrently, and log a summary.

    Args:
        args: Parsed CLI options. Uses ``questions`` (optional ID filter),
            ``n`` (concurrency limit), ``retry`` (extra attempts after a failed
            one) and ``dry_run``.
        logger: Logger for progress messages.

    If the run is cancelled, all pending tasks are cancelled and the process
    exits with status 1.
    """
    questions = load_questions_json()
    if not questions:
        logger.warning("No questions found in questions.json")
        return

    if args.questions:
        questions = [q for q in questions if q["question"] in args.questions]
        logger.info(f"Processing: {[q['question'] for q in questions]}")
    else:
        logger.info(
            f"Found {len(questions)} questions: {[q['question'] for q in questions]}"
        )

    if not questions:
        logger.warning("No matching questions found")
        return

    api_config = load_env()

    # asyncio.Semaphore(0) would block forever; always allow at least one worker.
    semaphore = asyncio.Semaphore(max(1, args.n))
    # --retry counts additional attempts after the first failed one.
    attempts = 1 + max(0, args.retry)

    async def limited_solve(session: aiohttp.ClientSession, q: dict) -> SolveResult:
        """Solve *q* under the concurrency limit, retrying failed attempts."""
        name = q.get("question", "?")
        result = None
        for attempt in range(1, attempts + 1):
            try:
                async with semaphore:
                    result = await solve_question(
                        session, q, api_config, logger, args.dry_run
                    )
            except Exception as e:
                # One malformed entry must not abort the whole batch.
                logger.exception(f"[{name}] Unexpected error")
                return SolveResult(
                    question=str(name),
                    target="",
                    skipped=False,
                    success=False,
                    error=str(e),
                )
            if result.success:
                return result
            if attempt < attempts:
                delay = min(2 ** attempt, 30)
                logger.warning(
                    f"[{name}] Attempt {attempt}/{attempts} failed "
                    f"({result.error}); retrying in {delay}s"
                )
                await asyncio.sleep(delay)
        return result

    async with aiohttp.ClientSession() as session:
        tasks = [asyncio.create_task(limited_solve(session, q)) for q in questions]

        results = []
        try:
            for coro in asyncio.as_completed(tasks):
                results.append(await coro)
        except asyncio.CancelledError:
            logger.warning("Cancelled! Shutting down...")
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            sys.exit(1)

    results.sort(key=lambda r: r.question)

    skipped = sum(1 for r in results if r.skipped)
    solved = sum(1 for r in results if r.success and not r.skipped)
    failed = [r.question for r in results if not r.success]
    logger.info(f"Complete: {solved}/{len(results)} solved, {skipped} skipped")
    if failed:
        logger.warning(f"Failed: {failed}")
|
|
|
|
|
|
def main() -> None:
    """Command-line entry point for solve.py.

    The steps are:

    - Parse the arguments.
    - Set up the "solve" logger.
    - Log the run settings and the data directory.
    - Run ``async_main`` on a fresh event loop.

    A KeyboardInterrupt is logged and turned into exit status 1.
    """
    args = parse_args()
    logger = setup_logging("solve", args.verbose)

    startup_banner = f"solve starting (Dry-run: {args.dry_run}, Workers: {args.n})"
    logger.info(startup_banner)
    logger.info(f"Data directory: {DATA_DIR}")

    try:
        asyncio.run(async_main(args, logger))
    except KeyboardInterrupt:
        logger.warning("Interrupted by user")
        raise SystemExit(1)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed directly, not when imported.
    main()
|