Files
homework-template/scripts/fix_typ.py

269 lines
7.6 KiB
Python

#!/usr/bin/env python3
"""
fix_typ.py - Fix Typst files to conform to proper Typst syntax using AI API.
This script scans the data directory for .typ files, uses AI to fix common
syntax issues (e.g., Markdown math syntax), and overwrites the original files.
"""
import argparse
import asyncio
import logging
import re
import sys
from dataclasses import dataclass
from pathlib import Path

import aiohttp

from common import DATA_DIR, console, load_env, load_prompt, setup_logging
@dataclass
class FixResult:
"""Result of fixing a typst file."""
question: str
target: str
skipped: bool
success: bool
error: str | None = None
def find_typ_files() -> list[Path]:
"""Find all .typ files in data directory."""
typ_files = []
for file_path in DATA_DIR.iterdir():
if file_path.is_file() and file_path.suffix == ".typ":
if not file_path.name.startswith("A_"):
typ_files.append(file_path)
return sorted(typ_files)
async def call_api(
    session: aiohttp.ClientSession,
    typ_content: str,
    question_name: str,
    endpoint: str,
    api_key: str,
    model: str,
    logger: logging.Logger,
    *,
    max_tokens: int = 4096,
    timeout: float = 120,
) -> str | None:
    """Ask a chat-completions endpoint to rewrite *typ_content* as valid Typst.

    The instruction comes from ``load_prompt("fix_typ.prompt.txt")``. A
    built-in instruction is used when that returns nothing. The file text is
    appended inside a fenced block.

    Args:
        session: Shared HTTP session used for the POST.
        typ_content: Current text of the ``.typ`` file.
        question_name: Label prefixed to every log message.
        endpoint: URL the JSON payload is POSTed to.
        api_key: Sent as a ``Bearer`` token.
        model: Model identifier placed in the payload.
        logger: Destination for progress and error messages.
        max_tokens: Generation limit requested from the API.
        timeout: Total request timeout, in seconds.

    Returns:
        The assistant message text. Returns ``None`` on any of:
        HTTP or network error, timeout, malformed reply, empty content, or a
        reply truncated at *max_tokens*. A truncated file must never be
        written back over the original.

    Raises:
        asyncio.CancelledError: Re-raised so task cancellation propagates.
    """
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    prompt = load_prompt("fix_typ.prompt.txt")
    if not prompt:
        # Fallback instruction (Chinese). Translation: "Fix the syntax errors in
        # the following Typst code so it conforms to the Typst spec; in
        # particular, math expressions must not use Markdown/LaTeX syntax."
        prompt = "请修复以下 Typst 代码的语法错误,使其符合 Typst 规范。特别注意数学表达式不要使用 Markdown/LaTeX 语法。"
    full_prompt = f"""{prompt}
需要修复的文件内容:
```
{typ_content}
```
"""
    payload = {
        "model": model,
        "messages": [
            {"role": "user", "content": [{"type": "text", "text": full_prompt}]}
        ],
        "max_tokens": max_tokens,
    }
    logger.info(f"[{question_name}] Fixing... (timeout {timeout:g}s)")
    try:
        async with session.post(
            endpoint,
            headers=headers,
            json=payload,
            timeout=aiohttp.ClientTimeout(total=timeout),
        ) as response:
            response.raise_for_status()
            result = await response.json()

        choices = result.get("choices") if isinstance(result, dict) else None
        if not choices:
            logger.error(f"[{question_name}] Invalid API response")
            return None
        choice = choices[0]
        content = choice["message"]["content"]

        # Tolerate "usage": null as well as a missing key. Previously a null
        # value raised here and threw away an otherwise good reply.
        usage = result.get("usage") or {}
        input_tokens = usage.get("prompt_tokens", 0)
        output_tokens = usage.get("completion_tokens", 0)
        logger.info(
            f"[{question_name}] Done: {input_tokens} in, {output_tokens} out"
        )

        # A reply cut off by the token limit is an incomplete file. The caller
        # overwrites the original with whatever is returned, so refuse it.
        if choice.get("finish_reason") == "length":
            logger.error(
                f"[{question_name}] Output truncated at max_tokens={max_tokens}"
            )
            return None
        if not isinstance(content, str) or not content.strip():
            logger.error(f"[{question_name}] Empty response content")
            return None
        return content
    except asyncio.TimeoutError:
        logger.error(f"[{question_name}] Timeout")
        return None
    except asyncio.CancelledError:
        logger.warning(f"[{question_name}] Cancelled")
        raise
    except Exception as e:
        # Per-file boundary: log and report failure so other files still run.
        logger.error(f"[{question_name}] Error: {e}")
        return None
# Outermost fenced block: the opening fence at a line start with any info string
# (```typst, ```typ, ```text, bare ```). The greedy body extends to the LAST bare
# closing fence, so Typst raw blocks (which also use ```) stay intact.
_FENCED_BLOCK_RE = re.compile(
    r"^```[\w+-]*[ \t\r]*\n(?:(.*)\n)?```[ \t\r]*$", re.DOTALL | re.MULTILINE
)


def _extract_fenced_body(text: str) -> str:
    """Return the stripped body of the outermost fenced block, else *text*."""
    match = _FENCED_BLOCK_RE.search(text)
    if match is None:
        return text
    return (match.group(1) or "").strip()


async def fix_typ_file(
    session: aiohttp.ClientSession,
    typ_path: Path,
    api_config: dict,
    logger: logging.Logger,
    dry_run: bool,
) -> FixResult:
    """Fix a single ``.typ`` file in place using ``call_api``.

    Args:
        session: HTTP session forwarded to ``call_api``.
        typ_path: File to read and overwrite.
        api_config: Mapping that provides ``"endpoint"``, ``"key"`` and ``"model"``.
        logger: Destination for progress and error messages.
        dry_run: When true, only log what would happen. Nothing is read, sent
            or written, and the result is reported as a success.

    Returns:
        A ``FixResult``. ``success`` is False when reading, the API call,
        output validation or writing fails, and ``error`` says why. The
        original file is left untouched on every failure before the write.
    """
    question_name = typ_path.stem

    def _failed(reason: str) -> FixResult:
        # Shared shape for every failure result of this file.
        return FixResult(
            question=question_name,
            target=typ_path.name,
            skipped=False,
            success=False,
            error=reason,
        )

    if dry_run:
        logger.info(f"[{question_name}] Would fix -> {typ_path.name}")
        return FixResult(
            question=question_name, target=typ_path.name, skipped=False, success=True
        )

    try:
        typ_content = typ_path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as e:
        # UnicodeDecodeError is not an OSError. Without it, one non-UTF-8 file
        # would escape as an exception and abort the whole batch.
        logger.error(f"[{question_name}] Read failed: {e}")
        return _failed(str(e))

    response = await call_api(
        session,
        typ_content,
        question_name,
        str(api_config["endpoint"]),
        api_config["key"],
        api_config["model"],
        logger,
    )
    if response is None:
        return _failed("API call failed")

    fixed_content = _extract_fenced_body(response)
    if not fixed_content.strip():
        # Never replace a real file with nothing.
        logger.error(f"[{question_name}] Empty fixed content; original kept")
        return _failed("Empty fixed content")
    if not fixed_content.endswith("\n"):
        # Extraction strips whitespace; restore the conventional final newline.
        fixed_content += "\n"

    try:
        typ_path.write_text(fixed_content, encoding="utf-8")
    except OSError as e:
        logger.error(f"[{question_name}] Write failed: {e}")
        return _failed(str(e))
    size = len(fixed_content.encode("utf-8"))  # bytes, not characters
    logger.info(f"[{question_name}] Wrote {typ_path.name} ({size} bytes)")
    return FixResult(
        question=question_name, target=typ_path.name, skipped=False, success=True
    )
def parse_args() -> argparse.Namespace:
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="Fix .typ files to conform to Typst syntax using AI API"
)
parser.add_argument(
"-f",
"--file",
action="append",
dest="files",
help="Specific .typ files to fix",
)
parser.add_argument(
"--dry-run", action="store_true", help="Do not call AI or write files"
)
parser.add_argument("--verbose", action="store_true", help="Enable debug logging")
parser.add_argument("-n", type=int, default=3, help="Concurrent limit (default: 3)")
return parser.parse_args()
async def async_main(args: argparse.Namespace, logger: logging.Logger) -> None:
    """Resolve the target files and fix them concurrently.

    Files come from ``args.files``, with relative paths resolved against
    ``DATA_DIR``. When no files are given, they come from ``find_typ_files()``.
    At most ``args.n`` files are processed at once over one shared HTTP
    session. A summary is logged at the end, plus a warning listing any
    failed files.

    Raises:
        asyncio.CancelledError: Re-raised after the outstanding per-file tasks
            have been cancelled and awaited.
    """
    if args.files:
        requested = [Path(f) for f in args.files]
        typ_paths = [p if p.is_absolute() else DATA_DIR / p for p in requested]
        # Drop repeats (keeping order) so two tasks never rewrite the same file.
        typ_paths = list(dict.fromkeys(typ_paths))
    else:
        typ_paths = find_typ_files()
    logger.info(f"Found {len(typ_paths)} .typ files to fix")
    if not typ_paths:
        logger.warning("No .typ files found to fix")
        return
    api_config = load_env()
    # Clamp to 1: Semaphore(0) would block every task forever, and a negative
    # value raises ValueError.
    semaphore = asyncio.Semaphore(max(1, args.n))

    async def limited_fix(session: aiohttp.ClientSession, typ_path: Path) -> FixResult:
        """Run fix_typ_file while holding one semaphore slot."""
        async with semaphore:
            return await fix_typ_file(
                session, typ_path, api_config, logger, args.dry_run
            )

    results = []
    async with aiohttp.ClientSession() as session:
        tasks = [asyncio.create_task(limited_fix(session, typ)) for typ in typ_paths]
        try:
            for next_done in asyncio.as_completed(tasks):
                results.append(await next_done)
        except asyncio.CancelledError:
            logger.warning("Cancelled! Shutting down...")
            # Propagate (after the cleanup below) instead of calling sys.exit()
            # inside the coroutine. On Python 3.11+, asyncio.run() turns a
            # Ctrl+C cancellation into KeyboardInterrupt, which main() handles.
            raise
        finally:
            # However the loop ended, don't leave tasks running against a
            # session that is about to be closed.
            pending = [t for t in tasks if not t.done()]
            for task in pending:
                task.cancel()
            await asyncio.gather(*pending, return_exceptions=True)

    results.sort(key=lambda r: r.question)
    skipped = sum(1 for r in results if r.skipped)
    fixed = sum(1 for r in results if r.success and not r.skipped)
    logger.info(f"Complete: {fixed}/{len(results)} fixed, {skipped} skipped")
    failed = [r for r in results if not r.success]
    if failed:
        logger.warning(
            "Failed: " + ", ".join(f"{r.target} ({r.error})" for r in failed)
        )
def main() -> None:
    """Command-line entry point.

    Parses flags, configures logging, then drives ``async_main`` on a fresh
    event loop. A Ctrl+C that reaches here is logged and the process exits
    with status 1.
    """
    cli = parse_args()
    log = setup_logging("fix_typ", cli.verbose)
    log.info(f"fix_typ starting (Dry-run: {cli.dry_run}, Workers: {cli.n})")
    log.info(f"Data directory: {DATA_DIR}")
    try:
        asyncio.run(async_main(cli, log))
    except KeyboardInterrupt:
        log.warning("Interrupted by user")
        raise SystemExit(1)


if __name__ == "__main__":
    main()