#!/usr/bin/env python3
"""
fix_typ.py - Fix Typst files to conform to proper Typst syntax using AI API.

This script scans the data directory for .typ files, uses AI to fix common
syntax issues (e.g., Markdown math syntax), and overwrites the original files.
A file is only overwritten when the API returns a complete, non-empty reply;
truncated or empty replies leave the original untouched.
"""

import argparse
import asyncio
import contextlib
import logging
import re
import sys
from dataclasses import dataclass
from pathlib import Path

import aiohttp

from common import DATA_DIR, console, load_env, load_prompt, setup_logging

# Total timeout for one API request, in seconds.
REQUEST_TIMEOUT_S = 120
# Generation cap sent as "max_tokens". A reply that stopped because it hit
# this cap is incomplete and is rejected instead of being written to disk.
MAX_TOKENS = 4096

# First fenced block whose opening fence (3+ backticks plus an optional info
# string such as "typst" or "typ") and closing fence (at least as many
# backticks) each sit on their own line. Requiring the closing fence to be at
# least as long as the opener lets a longer fence wrap content that itself
# contains ``` raw blocks.
_STRICT_FENCE_RE = re.compile(
    r"^(?P<fence>`{3,})[^\n`]*\n(?P<body>.*?)\n?^(?P=fence)`*[ \t]*$",
    re.DOTALL | re.MULTILINE,
)
# Looser fallback for replies whose closing fence shares a line with code,
# e.g. "```typst\n...code```".
_LOOSE_FENCE_RE = re.compile(r"```[^\n`]*\n?(?P<body>.*?)```", re.DOTALL)
# A run of one or more backticks.
_BACKTICK_RUN_RE = re.compile(r"`+")


@dataclass
class FixResult:
    """Result of fixing a typst file."""

    question: str  # file stem; also used as the log prefix
    target: str  # file name that was (or would be) written
    skipped: bool
    success: bool
    error: str | None = None


def find_typ_files() -> list[Path]:
    """Return the sorted .typ files located directly inside DATA_DIR.

    Files whose name starts with "A_" are excluded (presumably answer files;
    confirm). Subdirectories are not searched.
    """
    return sorted(
        path
        for path in DATA_DIR.iterdir()
        if path.is_file()
        and path.suffix == ".typ"
        and not path.name.startswith("A_")
    )


def _prompt_fence(text: str) -> str:
    """Return a backtick fence longer than any backtick run inside *text*."""
    longest = max((len(run) for run in _BACKTICK_RUN_RE.findall(text)), default=0)
    return "`" * max(3, longest + 1)


def _extract_code(reply: str) -> str:
    """Return the stripped body of the first fenced block in *reply*.

    The fence's info string (``typst``, ``typ``, ...) is discarded. When the
    reply contains no fenced block, the whole reply is returned stripped.
    """
    match = _STRICT_FENCE_RE.search(reply) or _LOOSE_FENCE_RE.search(reply)
    text = match.group("body") if match else reply
    return text.strip()


async def call_api(
    session: aiohttp.ClientSession,
    typ_content: str,
    question_name: str,
    endpoint: str,
    api_key: str,
    model: str,
    logger: logging.Logger,
) -> str | None:
    """Call the AI API to fix typst syntax.

    Args:
        session: Shared HTTP session.
        typ_content: Current contents of the .typ file.
        question_name: Label used as the prefix of every log line.
        endpoint: Chat-completions URL to POST to.
        api_key: Bearer token for the Authorization header.
        model: Model name placed in the request payload.
        logger: Logger for progress and error messages.

    Returns:
        The raw text of the first choice, or None when the request failed or
        timed out, the response had no choices, the reply was truncated at
        ``MAX_TOKENS``, or the reply was empty or not a string.

    Raises:
        asyncio.CancelledError: Re-raised so task cancellation propagates.
    """
    headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}

    prompt = load_prompt("fix_typ.prompt.txt")
    if not prompt:
        prompt = "请修复以下 Typst 代码的语法错误,使其符合 Typst 规范。特别注意数学表达式不要使用 Markdown/LaTeX 语法。"

    # Typst raw blocks use backtick fences too, so wrap the file in a fence
    # longer than any backtick run it contains to keep the prompt unambiguous.
    fence = _prompt_fence(typ_content)
    full_prompt = f"{prompt}\n\n需要修复的文件内容:\n{fence}\n{typ_content}\n{fence}\n"

    payload = {
        "model": model,
        "messages": [
            {"role": "user", "content": [{"type": "text", "text": full_prompt}]}
        ],
        "max_tokens": MAX_TOKENS,
    }

    logger.info(f"[{question_name}] Fixing... (timeout {REQUEST_TIMEOUT_S}s)")

    try:
        async with session.post(
            endpoint,
            headers=headers,
            json=payload,
            timeout=aiohttp.ClientTimeout(total=REQUEST_TIMEOUT_S),
        ) as response:
            response.raise_for_status()
            result = await response.json()

            choices = result.get("choices") or []
            if not choices:
                logger.error(f"[{question_name}] Invalid API response")
                return None

            choice = choices[0]
            content = choice["message"]["content"]
            # Some providers send "usage": null; treat it like a missing key.
            usage = result.get("usage") or {}
            input_tokens = usage.get("prompt_tokens", 0)
            output_tokens = usage.get("completion_tokens", 0)
            logger.info(
                f"[{question_name}] Done: {input_tokens} in, {output_tokens} out"
            )

            # A reply cut off by the token cap is incomplete; writing it would
            # silently drop the tail of the file.
            if choice.get("finish_reason") == "length":
                logger.error(
                    f"[{question_name}] Response truncated at max_tokens "
                    f"({MAX_TOKENS}); discarding it"
                )
                return None
            if not isinstance(content, str) or not content.strip():
                logger.error(f"[{question_name}] Empty response content")
                return None
            return content

    except asyncio.TimeoutError:
        logger.error(f"[{question_name}] Timeout")
        return None
    except asyncio.CancelledError:
        logger.warning(f"[{question_name}] Cancelled")
        raise
    except Exception as e:
        logger.error(f"[{question_name}] Error: {e}")
        return None


async def fix_typ_file(
    session: aiohttp.ClientSession,
    typ_path: Path,
    api_config: dict,
    logger: logging.Logger,
    dry_run: bool,
) -> FixResult:
    """Fix a single .typ file in place.

    Args:
        session: Shared HTTP session.
        typ_path: File to read and overwrite.
        api_config: Mapping read for the "endpoint", "key" and "model" keys.
        logger: Logger for progress and error messages.
        dry_run: When true, only log what would happen. Nothing is read,
            sent or written.

    Returns:
        A FixResult. Read, API, extraction and write failures are reported
        through ``success=False`` and ``error`` rather than raised.
    """
    question_name = typ_path.stem
    target = typ_path.name

    def _failed(error: str) -> FixResult:
        # Build the failure result shared by every error path below.
        return FixResult(
            question=question_name,
            target=target,
            skipped=False,
            success=False,
            error=error,
        )

    if dry_run:
        logger.info(f"[{question_name}] Would fix -> {target}")
        return FixResult(
            question=question_name, target=target, skipped=False, success=True
        )

    try:
        typ_content = typ_path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as e:
        logger.error(f"[{question_name}] Read failed: {e}")
        return _failed(str(e))

    reply = await call_api(
        session,
        typ_content,
        question_name,
        str(api_config["endpoint"]),
        api_config["key"],
        api_config["model"],
        logger,
    )
    if reply is None:
        return _failed("API call failed")

    fixed_content = _extract_code(reply)
    if not fixed_content:
        # Never replace a file with nothing.
        logger.error(f"[{question_name}] No content extracted; file left unchanged")
        return _failed("Empty fixed content")
    fixed_content += "\n"  # end the file with a newline

    # Write to a sibling temp file and then rename it over the target, so a
    # failed write cannot leave a half-written .typ file behind. Note that
    # Path.replace swaps the directory entry itself, so a symlinked target
    # becomes a regular file.
    tmp_path = typ_path.with_name(f"{target}.tmp")
    try:
        tmp_path.write_text(fixed_content, encoding="utf-8")
        tmp_path.replace(typ_path)
    except OSError as e:
        with contextlib.suppress(OSError):
            tmp_path.unlink()
        logger.error(f"[{question_name}] Write failed: {e}")
        return _failed(str(e))

    size = len(fixed_content.encode("utf-8"))
    logger.info(f"[{question_name}] Wrote {target} ({size} bytes)")
    return FixResult(
        question=question_name, target=target, skipped=False, success=True
    )


def parse_args() -> argparse.Namespace:
    """Parse command line arguments.

    Exits via ``parser.error`` when ``-n`` is not a positive integer, because
    a semaphore of 0 would block every worker forever.
    """
    parser = argparse.ArgumentParser(
        description="Fix .typ files to conform to Typst syntax using AI API"
    )
    parser.add_argument(
        "-f",
        "--file",
        action="append",
        dest="files",
        help="Specific .typ files to fix",
    )
    parser.add_argument(
        "--dry-run", action="store_true", help="Do not call AI or write files"
    )
    parser.add_argument("--verbose", action="store_true", help="Enable debug logging")
    parser.add_argument("-n", type=int, default=3, help="Concurrent limit (default: 3)")
    args = parser.parse_args()
    if args.n < 1:
        parser.error("-n must be a positive integer")
    return args


async def async_main(args: argparse.Namespace, logger: logging.Logger) -> None:
    """Async main entry point.

    Fixes the files given with ``-f`` (relative paths resolve against
    DATA_DIR), or every file from find_typ_files(). At most ``args.n`` files
    are processed concurrently. On cancellation, outstanding tasks are
    cancelled and the process exits with status 1.
    """
    if args.files:
        typ_paths = [
            path if path.is_absolute() else DATA_DIR / path
            for path in map(Path, args.files)
        ]
    else:
        typ_paths = find_typ_files()

    logger.info(f"Found {len(typ_paths)} .typ files to fix")
    if not typ_paths:
        logger.warning("No .typ files found to fix")
        return

    api_config = load_env()
    semaphore = asyncio.Semaphore(args.n)

    async def limited_fix(session: aiohttp.ClientSession, typ_path: Path) -> FixResult:
        # Bound the number of in-flight API requests.
        async with semaphore:
            return await fix_typ_file(
                session, typ_path, api_config, logger, args.dry_run
            )

    async with aiohttp.ClientSession() as session:
        tasks = [asyncio.create_task(limited_fix(session, typ)) for typ in typ_paths]
        results = []
        try:
            for coro in asyncio.as_completed(tasks):
                result = await coro
                results.append(result)
        except asyncio.CancelledError:
            logger.warning("Cancelled! Shutting down...")
            for task in tasks:
                task.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            sys.exit(1)

    results.sort(key=lambda r: r.question)
    for r in results:
        if not r.success:
            logger.warning(f"[{r.question}] Failed: {r.error}")

    skipped = sum(1 for r in results if r.skipped)
    fixed = sum(1 for r in results if r.success and not r.skipped)
    logger.info(f"Complete: {fixed}/{len(results)} fixed, {skipped} skipped")


def main() -> None:
    """Main entry point."""
    args = parse_args()
    logger = setup_logging("fix_typ", args.verbose)

    logger.info(f"fix_typ starting (Dry-run: {args.dry_run}, Workers: {args.n})")
    logger.info(f"Data directory: {DATA_DIR}")

    try:
        asyncio.run(async_main(args, logger))
    except KeyboardInterrupt:
        logger.warning("Interrupted by user")
        sys.exit(1)


if __name__ == "__main__":
    main()