#!/usr/bin/env python3
"""Compact GRPO-MVP training scaffold for the supplementary artifact.

This file keeps the training entry point independent of the full experiment
codebase. The `prepare` command only materializes prompts and runs in the
lightweight smoke-test environment, as does `reward-smoke` with its default
structural verifier. The `train` command requires the optional ML stack plus
an external verifier backend (unless `--allow_structural_training` is passed
for debugging).
"""

from __future__ import annotations

import argparse
import json
import os
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List

from grpo_mvp import contains_forbidden_placeholder, grpo_mvp_reward, order_by_difficulty
from verify import extract_theory, structural_check


# Prompt templates live in a sibling "prompts" directory, one level above the
# directory that contains this script.
PROMPT_DIR = Path(__file__).resolve().parents[1].joinpath("prompts")


def _load_prompt_template(filename: str) -> str:
    """Return the body of the fenced ```prompt block in ``PROMPT_DIR / filename``.

    The file is read as UTF-8 and the first ```prompt ... ``` fence is used.

    Raises:
        ValueError: if no ```prompt fence is present, or if the fenced text
            contains any non-ASCII character.
    """
    import re

    document = (PROMPT_DIR / filename).read_text(encoding="utf-8")
    fence = re.search(r"```prompt\n(.*?)\n```", document, flags=re.DOTALL)
    if fence is None:
        raise ValueError(f"{filename} must contain a fenced ```prompt block")
    body = fence.group(1)
    # str.isascii() is True exactly when every code point is <= 127.
    if not body.isascii():
        raise ValueError(f"{filename} prompt must be ASCII-only")
    return body


# Loaded once at import time, so a missing template file or a malformed /
# non-ASCII prompt block fails before any command runs.
THEORY_PROMPT_TEMPLATE: str = _load_prompt_template(filename="theory_model_prompt.md")


def read_jsonl(path: Path) -> List[Dict[str, Any]]:
    """Parse every non-blank line of the UTF-8 file *path* with ``json.loads``.

    Lines are stripped before parsing; whitespace-only lines are skipped.
    """
    with path.open("r", encoding="utf-8") as handle:
        stripped = (raw.strip() for raw in handle)
        return [json.loads(text) for text in stripped if text]


def write_jsonl(path: Path, rows: Iterable[Dict[str, Any]]) -> None:
    """Write *rows* to *path* as JSON Lines (one object per line, UTF-8).

    Missing parent directories are created; an existing file is overwritten.
    Keys are sorted and non-ASCII characters are written as-is, not escaped.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(row, ensure_ascii=False, sort_keys=True) + "\n" for row in rows
        )


def build_theory_prompt(row: Dict[str, Any], isabelle_statement_field: str = "isabelle_statement") -> str:
    """Fill ``THEORY_PROMPT_TEMPLATE`` from the ``lean`` mapping of *row*.

    The statement is taken from ``lean[isabelle_statement_field]``, then
    ``lean["isabelle_statement"]``, then the literal ``"theorem omitted"``,
    and is stripped of surrounding whitespace. A missing header becomes the
    literal string ``"None"`` and a missing proof becomes ``"proof omitted"``.
    """
    lean_meta = row.get("lean") or {}
    statement = (
        lean_meta.get(isabelle_statement_field)
        or lean_meta.get("isabelle_statement")
        or "theorem omitted"
    )
    placeholders = {
        "header": lean_meta.get("header") or "None",
        "formal_proof": lean_meta.get("formal_proof") or "proof omitted",
        "formal_statement_norm": str(statement).strip(),
    }
    return THEORY_PROMPT_TEMPLATE.format(**placeholders)


def truncate_prompt_for_policy(tokenizer: Any, prompt: str, max_prompt_length: int) -> str:
    """Clip *prompt* to at most ``max_prompt_length`` tokens of *tokenizer*.

    Which end of the prompt is dropped follows the tokenizer's own truncation
    setting. Prompts that already fit are returned unchanged. Decoding token
    ids is not guaranteed to reproduce the original text (some tokenizers
    clean up spaces on decode), so the decode round trip is only done when
    truncation actually happened.

    Args:
        tokenizer: Hugging Face-style tokenizer. It is called as
            ``tokenizer(text, truncation=..., max_length=..., add_special_tokens=...)``
            and must return a mapping with ``"input_ids"``, and it must provide
            ``decode(ids, skip_special_tokens=...)``.
        prompt: text to clip.
        max_prompt_length: token budget; values ``<= 0`` disable clipping.

    Returns:
        The original prompt if it fits, otherwise the decoded truncated ids.
    """
    if max_prompt_length <= 0:
        return prompt
    budget = int(max_prompt_length)
    encoded = tokenizer(
        prompt,
        truncation=True,
        max_length=budget,
        add_special_tokens=False,
    )
    input_ids = encoded.get("input_ids") or []
    # Strictly fewer ids than the budget means nothing was cut off, so keep
    # the exact original text instead of a possibly lossy re-decode.
    if len(input_ids) < budget:
        return prompt
    return tokenizer.decode(input_ids, skip_special_tokens=False)


def completion_to_text(completion: Any) -> str:
    """Flatten a trainer completion into plain text.

    Supported shapes:
      * ``str`` -> returned as-is;
      * ``dict`` -> its ``"content"`` (missing/falsy -> ``""``);
      * ``list`` -> the ``"content"`` of the last dict element whose content
        is not ``None`` (falsy content -> ``""``); ``""`` when there is none;
      * anything else -> ``str(completion)``.
    """
    if isinstance(completion, list):
        for message in reversed(completion):
            if not isinstance(message, dict):
                continue
            content = message.get("content")
            if content is not None:
                return str(content or "")
        return ""
    if isinstance(completion, dict):
        return str(completion.get("content") or "")
    return completion if isinstance(completion, str) else str(completion)


def prepare_grpo_rows(
    rows: List[Dict[str, Any]],
    *,
    stage: str,
    difficulty_order: bool,
    limit: int,
    isabelle_statement_field: str,
    tokenizer: Any | None = None,
    max_prompt_length: int = 0,
) -> List[Dict[str, Any]]:
    """Convert raw dataset rows into prompt records for GRPO.

    Args:
        rows: source records; ``id``, ``split_key`` and the ``lean`` mapping
            (its ``name`` plus whatever ``build_theory_prompt`` reads) are used.
        stage: accepted for CLI symmetry; not read by this function.
        difficulty_order: if true, rows are reordered by ``order_by_difficulty``
            before anything else happens.
        limit: keep only the first ``limit`` rows after ordering; ``<= 0`` keeps all.
        isabelle_statement_field: forwarded to ``build_theory_prompt``.
        tokenizer: when given, each prompt is clipped with
            ``truncate_prompt_for_policy``.
        max_prompt_length: token budget forwarded to the clipper.

    Returns:
        One ``{"id", "split_key", "lean_name", "prompt"}`` dict per kept row.
    """
    ordered = order_by_difficulty(rows) if difficulty_order else list(rows)
    kept = ordered[:limit] if limit > 0 else ordered

    def render(row: Dict[str, Any]) -> str:
        text = build_theory_prompt(row, isabelle_statement_field=isabelle_statement_field)
        if tokenizer is None:
            return text
        return truncate_prompt_for_policy(tokenizer, text, max_prompt_length)

    records: List[Dict[str, Any]] = []
    for row in kept:
        lean_meta = row.get("lean") or {}
        records.append(
            {
                "id": row.get("id"),
                "split_key": row.get("split_key"),
                "lean_name": lean_meta.get("name"),
                "prompt": render(row),
            }
        )
    return records


class StructuralVerifier:
    """Smoke-test verifier; not a substitute for PISA in reported runs.

    Wraps ``structural_check`` and reshapes its verdict into the common
    verifier result dict. It never reports per-step results.
    """

    def verify(self, theory: str) -> Dict[str, Any]:
        """Run ``structural_check`` on *theory* and wrap the outcome."""
        report = structural_check(theory)
        return dict(
            success=bool(report["success"]),
            step_results=[],
            backend="structural",
            raw=report,
        )


class PisaHttpVerifier:
    """Small adapter boundary for a Portal-to-Isabelle-compatible service.

    ``verify`` POSTs ``{"theory": <text>}`` as JSON to ``endpoint`` and maps the
    JSON object it receives onto the common verifier result dict
    (``success``, ``step_results``, ``backend``, ``raw``).
    """

    # String flag values meaning "no". A plain bool() would treat every
    # non-empty string -- including "false" -- as a successful proof.
    _FALSE_STRINGS = frozenset({"", "false", "0", "no", "none", "null"})

    def __init__(self, endpoint: str, timeout: float) -> None:
        """Store the service URL and the ``timeout`` passed to ``requests.post``.

        Raises:
            ValueError: if ``endpoint`` is empty.
        """
        if not endpoint:
            raise ValueError("--pisa_endpoint is required when --verifier_mode=pisa_http")
        self.endpoint = endpoint
        self.timeout = timeout

    @classmethod
    def _flag(cls, value: Any) -> bool:
        """Interpret a reply flag; strings like "false"/"0"/"no" count as False."""
        if isinstance(value, str):
            return value.strip().lower() not in cls._FALSE_STRINGS
        return bool(value)

    @classmethod
    def _parse_payload(cls, payload: Any) -> Dict[str, Any]:
        """Convert a decoded JSON reply into the verifier result dict.

        Success is reported when ``success`` or ``passed`` is truthy (see
        ``_flag``), or when ``complete`` stringifies to ``"true"``
        (case-insensitive). Step details come from ``step_results``, falling
        back to ``steps``, then to an empty list.

        Raises:
            ValueError: if the reply is not a JSON object.
        """
        if not isinstance(payload, dict):
            raise ValueError(
                f"PISA endpoint returned {type(payload).__name__}, expected a JSON object"
            )
        success = (
            cls._flag(payload.get("success"))
            or cls._flag(payload.get("passed"))
            or str(payload.get("complete", "")).lower() == "true"
        )
        step_results = payload.get("step_results") or payload.get("steps") or []
        return {"success": success, "step_results": step_results, "backend": "pisa_http", "raw": payload}

    def verify(self, theory: str) -> Dict[str, Any]:
        """POST *theory* to the service and return the parsed verdict.

        Raises:
            requests.HTTPError: on a non-success HTTP status (``raise_for_status``).
            ValueError: if the reply is not a JSON object.
        """
        import requests

        response = requests.post(self.endpoint, json={"theory": theory}, timeout=self.timeout)
        response.raise_for_status()
        return self._parse_payload(response.json())


def build_verifier(args: argparse.Namespace) -> Any:
    """Instantiate the verifier backend named by ``args.verifier_mode``.

    ``"structural"`` yields a ``StructuralVerifier``; ``"pisa_http"`` yields a
    ``PisaHttpVerifier`` built from ``args.pisa_endpoint``/``args.pisa_timeout``.

    Raises:
        ValueError: for any other mode (and, from ``PisaHttpVerifier``, when
            the pisa_http endpoint is empty).
    """
    mode = args.verifier_mode
    if mode == "pisa_http":
        return PisaHttpVerifier(args.pisa_endpoint, args.pisa_timeout)
    if mode == "structural":
        return StructuralVerifier()
    raise ValueError(f"unsupported verifier mode: {mode}")


def make_reward_function(args: argparse.Namespace) -> Callable[..., List[float]]:
    """Build the GRPO reward callable for the verifier selected in *args*.

    The verifier and every reward constant are resolved once, here, instead of
    for every completion. The per-completion loop therefore no longer
    re-converts the same ``args`` values, and a non-numeric reward setting
    fails when the reward function is built rather than on the first reward
    call.

    Scoring per completion:
      * ``extract_theory`` yields nothing -> ``args.invalid_structure_reward``;
      * the theory contains a forbidden placeholder -> ``args.placeholder_reward``
        (the verifier is not called);
      * otherwise the verifier result is scored by ``grpo_mvp_reward`` with
        ``lambda_mvp``/``pass_reward``/``fail_reward`` and its ``"reward"``
        entry is used.

    Returns:
        ``reward_func(prompts, completions, **kwargs) -> List[float]``, giving
        one reward per completion in order. ``prompts`` and extra keyword
        arguments are accepted for trainer compatibility and ignored.
    """
    verifier = build_verifier(args)
    invalid_structure_reward = float(args.invalid_structure_reward)
    placeholder_reward = float(args.placeholder_reward)
    lambda_mvp = float(args.lambda_mvp)
    pass_reward = float(args.pass_reward)
    fail_reward = float(args.fail_reward)

    def _score(completion: Any) -> float:
        # Reward for one completion (string, message dict, or message list).
        theory = extract_theory(completion_to_text(completion))
        if not theory:
            return invalid_structure_reward
        if contains_forbidden_placeholder(theory):
            return placeholder_reward
        verification = verifier.verify(theory)
        scored = grpo_mvp_reward(
            verification,
            lambda_mvp=lambda_mvp,
            pass_reward=pass_reward,
            fail_reward=fail_reward,
        )
        return float(scored["reward"])

    # The returned closure keeps the name ``reward_func``; trainers may use
    # its __name__ as the logged metric label.
    def reward_func(prompts: Iterable[Any], completions: Iterable[Any], **_: Any) -> List[float]:
        return [_score(completion) for completion in completions]

    return reward_func


def command_prepare(args: argparse.Namespace) -> None:
    """Write GRPO prompt rows built from ``args.train_jsonl`` to ``args.output``.

    No tokenizer is involved, so prompts are not clipped. A JSON summary
    (status, row count, output path) is printed to stdout.
    """
    records = prepare_grpo_rows(
        read_jsonl(args.train_jsonl),
        stage=args.stage,
        difficulty_order=bool(args.difficulty_order),
        limit=int(args.limit),
        isabelle_statement_field=args.isabelle_statement_field,
    )
    write_jsonl(args.output, records)
    summary = {"status": "ok", "count": len(records), "output": str(args.output)}
    print(json.dumps(summary, indent=2))


def command_reward_smoke(args: argparse.Namespace) -> None:
    """Score one tiny built-in Isabelle theory and print the rewards as JSON.

    Exercises the configured verifier and reward settings end to end without
    any model or dataset.
    """
    smoke_theory = "\n".join(
        [
            "theory Smoke",
            "imports Main",
            "begin",
            "",
            'theorem smoke: "True"',
            "  by simp",
            "",
            "end",
            "",
        ]
    )
    reward_func = make_reward_function(args)
    rewards = reward_func(prompts=[""], completions=[smoke_theory])
    report = {"status": "ok", "rewards": rewards}
    print(json.dumps(report, indent=2, sort_keys=True))


def _validate_train_args(args: argparse.Namespace) -> None:
    """Reject GRPO configurations that cannot run (cheap checks, no ML imports).

    Raises:
        SystemExit: if no model is given, if ``num_generations < 2``, or if
            ``per_device_train_batch_size * gradient_accumulation_steps *
            WORLD_SIZE`` is not divisible by ``num_generations``.
    """
    if not args.model_name_or_path:
        raise SystemExit("--model_name_or_path is required for training")
    num_generations = int(args.num_generations)
    if num_generations < 2:
        raise SystemExit("--num_generations must be >= 2 for GRPO")
    # WORLD_SIZE comes from the environment (presumably set by a distributed
    # launcher); unset or empty means a single process.
    world_size = int(os.environ.get("WORLD_SIZE") or "1")
    effective_generation_batch = (
        int(args.per_device_train_batch_size) * int(args.gradient_accumulation_steps) * world_size
    )
    if effective_generation_batch % num_generations != 0:
        raise SystemExit(
            "Invalid GRPO batch config: "
            "(per_device_train_batch_size * gradient_accumulation_steps * world_size) must be divisible by "
            "num_generations."
        )


def command_train(args: argparse.Namespace) -> None:
    """Fine-tune ``args.model_name_or_path`` with TRL's ``GRPOTrainer``.

    Prompts are built from ``args.train_jsonl``, clipped to
    ``args.max_prompt_length`` tokens, and scored by the reward function from
    ``make_reward_function``. The final model is saved to ``args.output_dir``.

    Raises:
        SystemExit: if the structural smoke-test verifier is selected without
            ``--allow_structural_training``, if the arguments fail
            ``_validate_train_args``, or if the optional ML stack is missing.
    """
    if args.verifier_mode == "structural" and not args.allow_structural_training:
        raise SystemExit(
            "Refusing to train with the structural smoke-test verifier. "
            "Use --verifier_mode=pisa_http with a local PISA wrapper, or pass "
            "--allow_structural_training only for debugging."
        )
    # Cheap argument checks run before the heavy ML imports, so a bad
    # configuration is reported even on machines without the ML stack.
    _validate_train_args(args)
    try:
        from datasets import Dataset
        from transformers import AutoTokenizer
        from trl import GRPOConfig, GRPOTrainer
    except ImportError as exc:
        raise SystemExit(
            "GRPO training requires the optional ML stack: "
            "torch, transformers, accelerate, datasets, and trl."
        ) from exc

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        local_files_only=bool(args.local_files_only),
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Left padding for batched generation. Left truncation keeps the end of
    # over-long prompts when truncate_prompt_for_policy clips them.
    tokenizer.padding_side = "left"
    tokenizer.truncation_side = "left"
    rows = read_jsonl(args.train_jsonl)
    prepared = prepare_grpo_rows(
        rows,
        stage=args.stage,
        difficulty_order=bool(args.difficulty_order),
        limit=int(args.limit),
        isabelle_statement_field=args.isabelle_statement_field,
        tokenizer=tokenizer,
        max_prompt_length=int(args.max_prompt_length),
    )
    dataset = Dataset.from_list(prepared)
    model_init_kwargs = {
        "trust_remote_code": True,
        "local_files_only": bool(args.local_files_only),
        "device_map": None,
    }
    if args.attn_implementation:
        model_init_kwargs["attn_implementation"] = args.attn_implementation
    training_args = GRPOConfig(
        output_dir=str(args.output_dir),
        learning_rate=float(args.learning_rate),
        per_device_train_batch_size=int(args.per_device_train_batch_size),
        gradient_accumulation_steps=int(args.gradient_accumulation_steps),
        num_generations=int(args.num_generations),
        max_prompt_length=int(args.max_prompt_length),
        max_completion_length=int(args.max_completion_length),
        num_train_epochs=float(args.num_train_epochs),
        logging_steps=int(args.logging_steps),
        save_steps=int(args.save_steps),
        remove_unused_columns=False,
        model_init_kwargs=model_init_kwargs,
    )
    trainer = GRPOTrainer(
        model=args.model_name_or_path,
        args=training_args,
        train_dataset=dataset,
        reward_funcs=make_reward_function(args),
        # Reuse the tokenizer configured above. Without it, GRPOTrainer loads
        # its own tokenizer from the model path, and local_files_only,
        # trust_remote_code and the pad-token fallback are not guaranteed to
        # apply to that copy (exact behavior varies by TRL version).
        processing_class=tokenizer,
    )
    trainer.train()
    trainer.save_model(str(args.output_dir))


def add_common_args(parser: argparse.ArgumentParser) -> None:
    """Register the prompt-selection options shared by ``prepare`` and ``train``."""
    common_options = (
        ("--stage", {"choices": ["theory"], "default": "theory"}),
        ("--isabelle_statement_field", {"default": "isabelle_statement"}),
        ("--difficulty_order", {"action": "store_true"}),
        ("--limit", {"type": int, "default": 0}),
    )
    for flag, options in common_options:
        parser.add_argument(flag, **options)


def add_reward_args(parser: argparse.ArgumentParser) -> None:
    """Register verifier-selection and reward-shaping options."""
    parser.add_argument("--verifier_mode", choices=["structural", "pisa_http"], default="structural")
    parser.add_argument("--pisa_endpoint", default="")
    float_options = (
        ("--pisa_timeout", 60.0),
        ("--lambda_mvp", 0.2),
        ("--pass_reward", 1.0),
        ("--fail_reward", 0.0),
        ("--invalid_structure_reward", -0.1),
        ("--placeholder_reward", -0.1),
    )
    for flag, default in float_options:
        parser.add_argument(flag, type=float, default=default)


def main() -> None:
    """Parse the command line and run the selected sub-command.

    Sub-commands: ``prepare`` (write prompt JSONL), ``reward-smoke`` (score one
    built-in theory with the configured verifier), and ``train`` (GRPO training).
    """
    cli = argparse.ArgumentParser()
    commands = cli.add_subparsers(dest="command", required=True)

    prepare_cmd = commands.add_parser("prepare")
    for flag in ("--train_jsonl", "--output"):
        prepare_cmd.add_argument(flag, type=Path, required=True)
    add_common_args(prepare_cmd)
    prepare_cmd.set_defaults(func=command_prepare)

    smoke_cmd = commands.add_parser("reward-smoke")
    add_reward_args(smoke_cmd)
    smoke_cmd.set_defaults(func=command_reward_smoke)

    train_cmd = commands.add_parser("train")
    train_options = (
        ("--train_jsonl", {"type": Path, "required": True}),
        ("--model_name_or_path", {"default": ""}),
        ("--output_dir", {"type": Path, "required": True}),
        ("--learning_rate", {"type": float, "default": 5e-6}),
        ("--per_device_train_batch_size", {"type": int, "default": 1}),
        ("--gradient_accumulation_steps", {"type": int, "default": 4}),
        ("--num_generations", {"type": int, "default": 4}),
        ("--max_prompt_length", {"type": int, "default": 2048}),
        ("--max_completion_length", {"type": int, "default": 1024}),
        ("--num_train_epochs", {"type": float, "default": 1.0}),
        ("--logging_steps", {"type": int, "default": 10}),
        ("--save_steps", {"type": int, "default": 100}),
        ("--local_files_only", {"action": "store_true"}),
        ("--attn_implementation", {"default": ""}),
        ("--allow_structural_training", {"action": "store_true"}),
    )
    for flag, options in train_options:
        train_cmd.add_argument(flag, **options)
    add_common_args(train_cmd)
    add_reward_args(train_cmd)
    train_cmd.set_defaults(func=command_train)

    parsed = cli.parse_args()
    parsed.func(parsed)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
