#!/usr/bin/env python3
"""Teacher-translation wrapper for statement and proof generation.

This release file keeps the portable teacher API boundary and the model prompt
templates while omitting verifier calls, semantic judging, retry logic beyond a
bounded retry of failed or empty API requests, run-state logging, and private
infrastructure. It contains no API keys, private base URLs, local paths, logs,
or PISA process management.
"""

from __future__ import annotations

import argparse
import json
import os
import re
import time
from pathlib import Path
from typing import Any, Dict, Iterable, List

from verify import extract_theory


# Prompt templates live in a ``prompts`` directory one level above the
# directory that contains this file.
PROMPT_DIR = Path(__file__).resolve().parents[1].joinpath("prompts")


def _load_prompt_template(filename: str) -> str:
    """Return the body of the fenced ```prompt block in ``PROMPT_DIR/filename``.

    Only the first fenced ```prompt block is used; the fence lines themselves
    are excluded from the returned text.

    Raises:
        ValueError: if the file has no ```prompt block, or if the block
            contains any non-ASCII character.
    """
    source = (PROMPT_DIR / filename).read_text(encoding="utf-8")
    fenced = re.search(r"```prompt\n(.*?)\n```", source, flags=re.DOTALL)
    if fenced is None:
        raise ValueError(f"{filename} must contain a fenced ```prompt block")
    body = fenced.group(1)
    if not body.isascii():
        raise ValueError(f"{filename} prompt must be ASCII-only")
    return body


# Both templates are read once, at import time (statement template first).
# Importing this module therefore fails if a prompt file cannot be read, lacks
# a ```prompt fence, or contains non-ASCII prompt text.
STATEMENT_PROMPT_TEMPLATE, THEORY_PROMPT_TEMPLATE = (
    _load_prompt_template("statement_model_prompt.md"),
    _load_prompt_template("theory_model_prompt.md"),
)


def read_jsonl(path: Path) -> List[Dict[str, Any]]:
    """Parse a UTF-8 JSON Lines file into a list of decoded records.

    Blank and whitespace-only lines are skipped. A malformed line raises
    ``json.JSONDecodeError`` (a ``ValueError`` subclass).
    """
    with path.open("r", encoding="utf-8") as handle:
        stripped = (raw.strip() for raw in handle)
        # The comprehension drains the generator before the file is closed.
        return [json.loads(text) for text in stripped if text]


def write_jsonl(path: Path, rows: Iterable[Dict[str, Any]]) -> None:
    """Write ``rows`` to ``path`` as UTF-8 JSON Lines, replacing any existing file.

    Missing parent directories are created first. Each record is serialized
    with sorted keys, and non-ASCII characters are written as-is rather than
    escaped.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(record, ensure_ascii=False, sort_keys=True) + "\n"
            for record in rows
        )


def _row_label(row: Dict[str, Any], index: int) -> str:
    """Return a label for ``row`` to use in error messages.

    Prefers ``row["id"]``, then ``row["split_key"]``, and falls back to
    ``row_<index>``. Falsy values (e.g. ``""``, ``0``, ``None``) are skipped.
    """
    for key in ("id", "split_key"):
        value = row.get(key)
        if value:
            return str(value)
    return f"row_{index}"


def _require_text(row: Dict[str, Any], index: int, path: Iterable[str]) -> str:
    """Return the non-blank text stored under a nested key path of ``row``.

    Args:
        row: Input record to read from.
        index: Position of ``row`` in the input; used only for error labels.
        path: Dict keys to follow, e.g. ``("lean", "formal_theorem")``. Any
            iterable is accepted and it is consumed exactly once.

    Returns:
        ``str()`` of the value found at the end of ``path``.

    Raises:
        ValueError: if an intermediate value is not a dict or a key is absent
            ("missing field"), or if the final value is None or blank after
            stripping ("empty field").
    """
    current: Any = row
    visited: List[str] = []
    for key in path:
        visited.append(key)
        if not isinstance(current, dict) or key not in current:
            raise ValueError(f"{_row_label(row, index)} missing field: {'.'.join(visited)}")
        current = current[key]
    # ``path`` may be a one-shot iterator that the loop above already
    # exhausted, so report the keys recorded during the walk instead of
    # iterating ``path`` a second time.
    if current is None or not str(current).strip():
        raise ValueError(f"{_row_label(row, index)} has empty field: {'.'.join(visited)}")
    return str(current)


def _optional_text(row: Dict[str, Any], path: Iterable[str]) -> str:
    """Return the text at a nested key path of ``row``, or ``""`` if unavailable.

    Walking stops and returns ``""`` as soon as an intermediate value is not a
    dict. A missing key yields None, which is also reported as ``""``. Any
    other final value is converted with ``str()`` (blank strings are returned
    unchanged).
    """
    node: Any = row
    for key in path:
        if isinstance(node, dict):
            node = node.get(key)
        else:
            return ""
    return str(node) if node is not None else ""


def build_teacher_prompt(
    row: Dict[str, Any],
    *,
    index: int = 0,
    stage: str,
) -> Dict[str, str]:
    """Render the teacher prompt for ``row`` at the given ``stage``.

    Both stages fill ``header`` from ``row["lean"]["header"]``, or the string
    ``"None"`` when it is missing or empty.

    * ``"statement"`` also requires ``row["lean"]["formal_theorem"]``.
    * ``"theory"`` requires ``row["lean"]["formal_proof"]`` and uses the
      stripped ``row["lean"]["isabelle_statement"]``, defaulting to
      ``"theorem omitted"``.

    Returns:
        ``{"developer": "", "user": <rendered template>}``.

    Raises:
        ValueError: for an unsupported ``stage`` or a missing/blank required field.
    """
    if stage not in ("statement", "theory"):
        raise ValueError(f"unsupported stage: {stage}")

    lean_header = _optional_text(row, ("lean", "header")) or "None"
    if stage == "statement":
        user_prompt = STATEMENT_PROMPT_TEMPLATE.format(
            header=lean_header,
            formal_theorem=_require_text(row, index, ("lean", "formal_theorem")),
        )
    else:
        formal_proof = _require_text(row, index, ("lean", "formal_proof"))
        statement = _optional_text(row, ("lean", "isabelle_statement")) or "theorem omitted"
        user_prompt = THEORY_PROMPT_TEMPLATE.format(
            header=lean_header,
            formal_proof=formal_proof,
            formal_statement_norm=statement.strip(),
        )
    return {"developer": "", "user": user_prompt}


def _messages(prompt: Dict[str, str], developer_role: str) -> List[Dict[str, str]]:
    """Convert a rendered prompt into a chat-completions message list.

    A developer turn (with role ``developer_role``) is emitted only when
    ``prompt["developer"]`` is present and not blank. The user turn is always
    last.

    Raises:
        ValueError: if ``developer_role`` is neither ``"developer"`` nor ``"system"``.
    """
    if developer_role not in {"developer", "system"}:
        raise ValueError("--developer_role must be 'developer' or 'system'")
    developer_text = prompt.get("developer", "")
    has_developer_turn = bool(developer_text.strip())
    user_turn = {"role": "user", "content": prompt["user"]}
    if not has_developer_turn:
        return [user_turn]
    return [{"role": developer_role, "content": developer_text}, user_turn]


# A line starting (after optional whitespace) with the word "theory", followed
# somewhere later by a line starting with the word "end".
_THEORY_BLOCK_RE = re.compile(r"^\s*theory\b.*?^\s*end\b", re.MULTILINE | re.DOTALL)


def has_complete_theory_block(text: str) -> bool:
    """Return True if ``text`` looks like it contains a ``theory ... end`` block.

    This is a cheap structural check, not a parse: any later line that begins
    with the word ``end`` satisfies it. ``text`` is passed through ``str()``
    first, so None is accepted and yields False.
    """
    return _THEORY_BLOCK_RE.search(str(text)) is not None


def dry_run(
    rows: List[Dict[str, Any]],
    *,
    stage: str,
    developer_role: str,
) -> List[Dict[str, Any]]:
    """Render the prompt for every row without contacting any model.

    Each returned record carries the rendered prompt texts and the message
    list that would be sent. The generation fields are empty,
    ``structural_ok`` is False, and ``dry_run`` is True.

    Raises:
        ValueError: propagated from prompt rendering (unsupported ``stage``,
            missing/blank required field) or from an invalid ``developer_role``.
    """

    def _preview(index: int, row: Dict[str, Any]) -> Dict[str, Any]:
        prompt = build_teacher_prompt(row, index=index, stage=stage)
        return {
            "id": row.get("id"),
            "split_key": row.get("split_key"),
            "stage": stage,
            "developer_prompt": prompt["developer"],
            "user_prompt": prompt["user"],
            "messages": _messages(prompt, developer_role),
            "generation": "",
            "raw_generation": "",
            "structural_ok": False,
            "dry_run": True,
            "error": "",
        }

    return [_preview(index, row) for index, row in enumerate(rows)]


def _chat_with_retries(
    client: Any,
    request: Dict[str, Any],
    attempts: int,
) -> tuple[str, str, str]:
    """Call ``client.chat.completions.create(**request)`` up to ``attempts`` times.

    A failed attempt is either an exception from the client or a reply whose
    text is empty or blank. Every failed attempt except the last is followed
    by a back-off sleep of ``min(30, 2**attempt)`` seconds.

    Returns:
        ``(text, finish_reason, error)`` describing the final attempt.
        ``error`` is ``""`` on success, otherwise ``"<ExceptionType>: <message>"``.
    """
    text = ""
    finish_reason = ""
    error = ""
    for attempt in range(attempts):
        # Reset per attempt so the returned text/finish_reason always belong
        # to the same attempt that ``error`` describes.
        text = ""
        finish_reason = ""
        try:
            response = client.chat.completions.create(**request)
            choice = response.choices[0] if response.choices else None
            message = getattr(choice, "message", None) if choice else None
            content = getattr(message, "content", None) if message else None
            # ``content`` may be None (e.g. refusals or tool-call replies).
            # Anything that is not a string is treated as an empty reply, so
            # the ``.strip()`` calls here and in the caller cannot raise.
            text = content if isinstance(content, str) else ""
            finish_reason = str(getattr(choice, "finish_reason", "") or "") if choice else ""
            if not text.strip():
                raise RuntimeError("empty response from chat.completions.create()")
            return text, finish_reason, ""
        except Exception as exc:  # any provider/network failure is treated as retryable
            error = f"{type(exc).__name__}: {exc}"
        if attempt + 1 < attempts:
            time.sleep(min(30.0, 2.0**attempt))
    return text, finish_reason, error


def openai_compatible_generate(
    rows: List[Dict[str, Any]],
    *,
    stage: str,
    model: str,
    api_key_env: str,
    base_url: str,
    temperature: float | None,
    max_tokens: int,
    reasoning_effort: str,
    sleep_sec: float,
    timeout_sec: float,
    max_retries: int,
    developer_role: str,
    keep_prompts: bool,
) -> List[Dict[str, Any]]:
    """Generate one teacher completion per row through an OpenAI-compatible API.

    Args:
        rows: Input records; each is rendered with ``build_teacher_prompt``.
        stage: ``"statement"`` or ``"theory"``; selects the prompt template.
        model: Model name sent with every request.
        api_key_env: Name of the environment variable holding the API key.
        base_url: Alternate API base URL; ``""`` uses the client default.
        temperature: Sampling temperature; omitted from requests when None.
        max_tokens: Sent as ``max_tokens`` only when > 0.
        reasoning_effort: Sent as ``reasoning_effort`` only when non-empty.
        sleep_sec: Pause after each row when > 0.
        timeout_sec: Client request timeout in seconds.
        max_retries: Extra attempts per row after the first. At least one
            attempt is always made.
        developer_role: Role used for a non-empty developer prompt
            (``"developer"`` or ``"system"``).
        keep_prompts: Also store the rendered prompt texts in each record.

    Returns:
        One record per row. ``generation`` holds the extracted theory, or the
        stripped reply when nothing was extracted. ``raw_generation`` holds the
        stripped reply. The record also has ``finish_reason``,
        ``structural_ok``, and the final ``error`` (``""`` on success).

    Raises:
        SystemExit: if the API key environment variable is unset or empty.
        ValueError: if any row fails prompt rendering or ``developer_role``
            is invalid. This is raised before any request is sent.
    """
    from openai import OpenAI

    api_key = os.environ.get(api_key_env)
    if not api_key:
        raise SystemExit(f"Missing API key env var: {api_key_env}")

    # Render and validate every prompt up front, so a malformed row fails fast
    # instead of after part of the API budget has already been spent.
    prompts = [
        build_teacher_prompt(row, index=index, stage=stage)
        for index, row in enumerate(rows)
    ]
    conversations = [_messages(prompt, developer_role) for prompt in prompts]

    # The request options are identical for every row and attempt.
    options: Dict[str, Any] = {"model": model}
    if temperature is not None:
        options["temperature"] = float(temperature)
    if max_tokens > 0:
        options["max_tokens"] = int(max_tokens)
    if reasoning_effort:
        options["reasoning_effort"] = reasoning_effort

    client = OpenAI(api_key=api_key, base_url=base_url or None, timeout=float(timeout_sec))
    attempts = max(1, int(max_retries) + 1)
    outputs: List[Dict[str, Any]] = []
    for row, prompt, messages in zip(rows, prompts, conversations):
        text, finish_reason, error = _chat_with_retries(
            client, {**options, "messages": messages}, attempts
        )
        theory = extract_theory(text)
        item: Dict[str, Any] = {
            "id": row.get("id"),
            "split_key": row.get("split_key"),
            "stage": stage,
            "model": model,
            "generation": theory or text.strip(),
            "raw_generation": text.strip(),
            "finish_reason": finish_reason,
            "structural_ok": has_complete_theory_block(theory),
            "dry_run": False,
            "error": error,
        }
        if keep_prompts:
            item["developer_prompt"] = prompt["developer"]
            item["user_prompt"] = prompt["user"]
        outputs.append(item)
        if sleep_sec > 0:
            time.sleep(float(sleep_sec))
    return outputs


def main() -> None:
    """Command-line entry point: render or generate teacher outputs, write JSONL.

    Exits via SystemExit in three cases:

    * ``--model`` is missing outside ``--dry_run`` mode.
    * Generation raises ValueError (e.g. a row missing a required field).
    * The API key environment variable is unset.

    On success, prints a JSON status summary that names only the output
    file's basename.
    """
    parser = argparse.ArgumentParser()
    option = parser.add_argument
    option("--input", type=Path, required=True)
    option("--output", type=Path, required=True)
    option("--stage", choices=["statement", "theory"], required=True)
    option("--limit", type=int, default=0)
    option("--dry_run", action="store_true")
    option("--model", default="")
    option("--api_key_env", default="OPENAI_API_KEY")
    option("--base_url", default="")
    option("--temperature", type=float, default=None)
    option("--max_tokens", type=int, default=0)
    option("--reasoning_effort", default="medium")
    option("--sleep_sec", type=float, default=0.0)
    option("--timeout_sec", type=float, default=120.0)
    option("--max_retries", type=int, default=2)
    option("--developer_role", choices=["developer", "system"], default="developer")
    option("--keep_prompts", action="store_true")
    args = parser.parse_args()

    if not (args.dry_run or args.model):
        raise SystemExit("--model is required unless --dry_run is set")

    rows = read_jsonl(args.input)
    if args.limit > 0:
        rows = rows[: args.limit]

    try:
        if args.dry_run:
            generations = dry_run(rows, stage=args.stage, developer_role=args.developer_role)
        else:
            generations = openai_compatible_generate(
                rows,
                stage=args.stage,
                model=args.model,
                api_key_env=args.api_key_env,
                base_url=args.base_url,
                temperature=args.temperature,
                max_tokens=int(args.max_tokens),
                reasoning_effort=args.reasoning_effort,
                sleep_sec=float(args.sleep_sec),
                timeout_sec=float(args.timeout_sec),
                max_retries=int(args.max_retries),
                developer_role=args.developer_role,
                keep_prompts=bool(args.keep_prompts),
            )
    except ValueError as exc:
        raise SystemExit(f"Input validation failed: {exc}") from exc

    write_jsonl(args.output, generations)
    summary = {"status": "ok", "count": len(generations), "output": args.output.name}
    print(json.dumps(summary, indent=2))


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
