#!/usr/bin/env python3
"""Train and evaluate LeanCheck variants."""

from __future__ import annotations

import argparse
import csv
import json
import random
from pathlib import Path
from typing import Dict, List

import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from leancheck_data import generate_counterfactual, generate_examples, generate_minimal_pairs, write_jsonl
from leancheck_eval import (
    activation_patching_diagnostic,
    counterfactual_metrics,
    evaluate_loader,
    minimal_pair_flip_rate,
    predict_labels,
)
from leancheck_model import LABEL_TO_ID, LeanCheckModel, SimpleTokenizer


# Variants trained (in this order) when --variants is not given.
# "lm_only" and "no_consistency_loss" run with lambda_cons = 0 (see
# train_one_variant); "random_consistency" trains the head on random labels.
DEFAULT_VARIANTS = """
    lm_only
    no_consistency_loss
    rationale_only
    full_sequence
    proof_only
    random_consistency
    wrong_span
""".split()


def parse_layers(spec: str, model_name: str) -> List[int]:
    """Turn a ``--patch-layers`` spec into a list of layer indices.

    ``"default"`` maps to ``[0, 1]`` for the ``tiny-local`` model and to
    ``[0, 3, 6, 9, 11]`` for any other model name. Otherwise *spec* is a
    comma-separated list of integers; empty/blank entries are ignored.
    """
    if spec != "default":
        layers: List[int] = []
        for token in spec.split(","):
            if token.strip():
                layers.append(int(token))
        return layers
    if model_name == "tiny-local":
        return [0, 1]
    return [0, 3, 6, 9, 11]


class LeanCheckDataset(Dataset):
    """Map-style dataset that tokenizes LeanCheck rows on access.

    Each item is a dict with ``input_ids``, ``attention_mask``, ``lm_labels``
    (a copy of ``input_ids`` with positions where ``attention_mask == 0`` set
    to -100) and ``cls_labels`` (a scalar long tensor).

    Args:
        rows: example dicts; each must provide ``"text"`` and ``"label"``.
        tokenizer: a ``SimpleTokenizer`` (uses ``.encode``) or a Hugging
            Face-style callable tokenizer (padded/truncated to ``max_len``).
        max_len: maximum sequence length passed to the tokenizer.
        random_labels: if True, each class label is replaced by a draw from
            ``{0, 1}``. The true label is still looked up first, so an unknown
            label raises ``KeyError`` either way.
        seed: seed for the private label RNG (the global RNG is not touched).
    """

    def __init__(self, rows: List[Dict[str, object]], tokenizer, max_len: int, random_labels: bool = False, seed: int = 0):
        self.rows = rows
        self.tokenizer = tokenizer
        self.max_len = max_len
        label_rng = random.Random(seed)

        def pick_label(row):
            # Always resolve the gold label (validates it), then optionally
            # overwrite it with a random draw for the shuffled-label control.
            gold = LABEL_TO_ID[str(row["label"])]
            return label_rng.randint(0, 1) if random_labels else gold

        # Computed sequentially so the RNG draw order follows row order.
        self.cls_labels = [pick_label(row) for row in rows]

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        text = str(self.rows[idx]["text"])
        if not isinstance(self.tokenizer, SimpleTokenizer):
            batch = self.tokenizer(
                text,
                padding="max_length",
                truncation=True,
                max_length=self.max_len,
                return_tensors="pt",
            )
            # The HF tokenizer returns a leading batch dim of 1; drop it.
            encoded = {name: tensor.squeeze(0) for name, tensor in batch.items()}
        else:
            encoded = self.tokenizer.encode(text, max_length=self.max_len)

        ids = encoded["input_ids"].long()
        mask = encoded["attention_mask"].long()
        targets = ids.clone()
        # -100 marks padding positions to be ignored by the LM loss.
        targets[mask == 0] = -100
        return {
            "input_ids": ids,
            "attention_mask": mask,
            "lm_labels": targets,
            "cls_labels": torch.tensor(self.cls_labels[idx], dtype=torch.long),
        }


def read_jsonl(path: Path) -> List[Dict[str, object]]:
    """Load a JSON Lines file into a list of decoded records.

    Records are split on line feeds only, per the JSON Lines convention.
    ``str.splitlines`` is deliberately not used: it also breaks on characters
    such as U+2028, U+2029 and U+0085, which JSON permits unescaped inside
    strings, so a record containing one would be cut mid-string and fail to
    parse. Blank or whitespace-only lines are skipped.

    Args:
        path: JSONL file to read, decoded as UTF-8.

    Returns:
        One decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # read_text() uses universal-newline translation, so CRLF files arrive
    # here with plain line feeds.
    text = path.read_text(encoding="utf-8")
    return [json.loads(line) for line in text.split("\n") if line.strip()]


def ensure_data(args) -> Dict[str, List[Dict[str, object]]]:
    """Return the four LeanCheck splits, generating and caching them if needed.

    The splits live as JSONL files under ``args.data_dir``. They are
    (re)generated when ``args.regenerate_data`` is set or when any of the four
    files is missing. Otherwise the cached files are reused as-is, so the
    ``*_samples`` flags only take effect when data is regenerated.

    Args:
        args: parsed CLI namespace (reads ``data_dir``, ``regenerate_data``,
            ``train_samples``, ``eval_samples``, ``counterfactual_samples``,
            ``minimal_pair_samples``, ``seed`` and ``use_lean``).

    Returns:
        Mapping with keys ``"train"``, ``"eval"``, ``"counterfactual"`` and
        ``"minimal_pairs"`` to the rows read back from disk.
    """
    data_dir = Path(args.data_dir)
    paths = {
        "train": data_dir / "leancheck_train.jsonl",
        "eval": data_dir / "leancheck_eval.jsonl",
        "counterfactual": data_dir / "leancheck_counterfactual.jsonl",
        "minimal_pairs": data_dir / "leancheck_minimal_pairs.jsonl",
    }
    if args.regenerate_data or not all(p.exists() for p in paths.values()):
        # A fresh checkout may not have the data directory yet; create it
        # before writing (idempotent if it, or write_jsonl, already did).
        data_dir.mkdir(parents=True, exist_ok=True)
        # Each split is generated with its own seed offset (seed, +1, +2, +3).
        train = generate_examples(args.train_samples, args.seed, args.use_lean)
        eval_rows = generate_examples(args.eval_samples, args.seed + 1, args.use_lean)
        cf = generate_counterfactual(args.counterfactual_samples, args.seed + 2, args.use_lean)
        mp = generate_minimal_pairs(args.seed + 3, args.use_lean, args.minimal_pair_samples)
        write_jsonl(paths["train"], train)
        write_jsonl(paths["eval"], eval_rows)
        write_jsonl(paths["counterfactual"], cf)
        write_jsonl(paths["minimal_pairs"], mp)
    # Always read back from disk so cached and fresh runs take the same path.
    return {k: read_jsonl(p) for k, p in paths.items()}


def build_tokenizer(model_name: str, rows: Dict[str, List[Dict[str, object]]]):
    """Create the tokenizer for *model_name*.

    For any name other than ``"tiny-local"``, load a Hugging Face
    ``AutoTokenizer`` and register the LeanCheck section/label markers as
    additional special tokens plus ``[PAD]`` as the pad token. For
    ``"tiny-local"``, fit a ``SimpleTokenizer`` on the ``"text"`` field of
    every row in every split in *rows*.
    NOTE(review): the tiny tokenizer is fit on eval/counterfactual text too,
    so its vocabulary is not train-only.
    """
    if model_name != "tiny-local":
        # Imported lazily so the offline tiny-local path needs no transformers.
        from transformers import AutoTokenizer

        hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
        markers = ["[BOS]", "[EOS]", "[THEOREM]", "[PROOF]", "[RAT]", "[CLAIM]", "[VERIFIES]", "[FAILS]"]
        hf_tokenizer.add_special_tokens({"additional_special_tokens": markers, "pad_token": "[PAD]"})
        return hf_tokenizer

    local_tokenizer = SimpleTokenizer()
    all_texts = (example["text"] for split_rows in rows.values() for example in split_rows)
    local_tokenizer.fit(all_texts)
    return local_tokenizer


def make_loader(rows, tokenizer, args, shuffle=False, random_labels=False):
    """Wrap *rows* in a ``LeanCheckDataset`` and return a batched ``DataLoader``.

    Uses ``args.max_seq_len`` for tokenization, ``args.batch_size`` for
    batching and ``args.seed`` for the dataset's label RNG.
    """
    dataset = LeanCheckDataset(
        rows,
        tokenizer,
        args.max_seq_len,
        random_labels=random_labels,
        seed=args.seed,
    )
    return DataLoader(dataset, shuffle=shuffle, batch_size=args.batch_size)


def train_one_variant(variant: str, rows, tokenizer, args, device: str) -> Dict[str, object]:
    """Train one LeanCheck variant from scratch and return its metric row.

    Builds loaders for the four splits and constructs a fresh
    ``LeanCheckModel``. It then optimises ``lm_loss + lambda_cons * cons_loss``
    with AdamW for ``args.epochs`` epochs. Finally it evaluates on the eval,
    counterfactual and minimal-pair splits, optionally runs activation
    patching, and optionally saves the model.

    Args:
        variant: variant name, forwarded to ``LeanCheckModel``.
            ``"lm_only"`` and ``"no_consistency_loss"`` use ``lambda_cons = 0``;
            ``"random_consistency"`` trains on randomly drawn class labels.
        rows: split name -> list of example dicts (as from ``ensure_data``).
        tokenizer: tokenizer from ``build_tokenizer``.
        args: parsed CLI namespace.
        device: torch device string.

    Returns:
        Flat dict of metrics and run metadata (one CSV row). Metrics that do
        not apply to this variant/run are set to ``""``.
    """
    # Reseed per variant so every ablation gets the same model init and data
    # order no matter which variants ran before it in this process; otherwise
    # differences between variants are confounded by leftover RNG state.
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Create the per-variant directory before anything writes into it
    # (activation patching puts its CSV here before model saving happens).
    out_dir = Path(args.output_dir) / variant
    out_dir.mkdir(parents=True, exist_ok=True)

    random_labels = variant == "random_consistency"
    train_loader = make_loader(rows["train"], tokenizer, args, shuffle=True, random_labels=random_labels)
    eval_loader = make_loader(rows["eval"], tokenizer, args)
    cf_loader = make_loader(rows["counterfactual"], tokenizer, args)
    mp_loader = make_loader(rows["minimal_pairs"], tokenizer, args)
    model = LeanCheckModel(
        args.model_name,
        tokenizer,
        variant=variant,
        hidden_size=args.tiny_hidden_size,
        tiny_layers=args.tiny_layers,
    ).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=args.lr)
    # LM-only baselines get no gradient from the consistency head.
    lambda_cons = 0.0 if variant in {"lm_only", "no_consistency_loss"} else args.lambda_cons

    # --- training -----------------------------------------------------------
    for epoch in range(args.epochs):
        model.train()
        pbar = tqdm(train_loader, desc=f"{variant} epoch {epoch+1}/{args.epochs}", leave=False)
        for batch in pbar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            lm_labels = batch["lm_labels"].to(device)
            cls_labels = batch["cls_labels"].to(device)
            out = model(input_ids, attention_mask, lm_labels, cls_labels, tokenizer)
            loss = out["lm_loss"] + lambda_cons * out["cons_loss"]
            opt.zero_grad(set_to_none=True)
            loss.backward()
            # Clip the global gradient norm to 1.0 before stepping.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            pbar.set_postfix(loss=loss.item())

    # --- evaluation ---------------------------------------------------------
    metrics = evaluate_loader(model, eval_loader, tokenizer, device)
    cf_preds = predict_labels(model, cf_loader, tokenizer, device)
    metrics.update(counterfactual_metrics(cf_preds, rows["counterfactual"]))
    mp_preds = predict_labels(model, mp_loader, tokenizer, device)
    metrics.update(minimal_pair_flip_rate(mp_preds, rows["minimal_pairs"]))
    if args.activation_patching:
        patch_csv = out_dir / "activation_patching.csv"
        metrics.update(activation_patching_diagnostic(
            model=model,
            tokenizer=tokenizer,
            rows=rows["minimal_pairs"],
            device=device,
            max_len=args.max_seq_len,
            layers=parse_layers(args.patch_layers, args.model_name),
            samples=args.patch_samples,
            output_csv=patch_csv,
            seed=args.seed,
        ))
    else:
        # Keep the column set stable when patching is disabled.
        metrics.update({
            "patch_rationale_effect": "",
            "patch_theorem_effect": "",
            "patch_random_effect": "",
            "patch_rationale_minus_random": "",
            "head_patch_rationale_effect": "",
            "head_patch_theorem_effect": "",
            "head_patch_random_effect": "",
            "head_patch_rationale_minus_random": "",
        })
    # Control-specific copies of the head accuracy; blank for other variants.
    metrics["shuffled_cls_acc"] = metrics["cls_claim_acc"] if variant == "random_consistency" else ""
    metrics["wrong_span_cls_acc"] = metrics["cls_claim_acc"] if variant == "wrong_span" else ""
    metrics["variant"] = variant
    metrics["lambda_cons"] = lambda_cons
    metrics["train_examples"] = len(rows["train"])
    metrics["eval_examples"] = len(rows["eval"])
    metrics["counterfactual_examples"] = len(rows["counterfactual"])
    metrics["model_name"] = args.model_name

    if args.save_models:
        torch.save(model.state_dict(), out_dir / "model.pt")
        # NOTE(review): assumes the tokenizer implements save_pretrained;
        # confirm SimpleTokenizer does before using --save-models with tiny-local.
        tokenizer.save_pretrained(str(out_dir / "tokenizer"))
    return metrics


def write_results_csv(path: Path, rows: List[Dict[str, object]]) -> None:
    """Write one CSV row per variant result using a fixed column order.

    The parent directory of *path* is created if needed. Columns missing from
    a result dict are written as empty cells; keys outside the schema are
    dropped silently.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    columns = [
        # identity and dataset sizes
        "variant", "model_name", "train_examples", "eval_examples", "counterfactual_examples",
        # training setup, losses and headline accuracies
        "lambda_cons", "final_lm_loss", "final_cons_loss", "gen_claim_acc", "cls_claim_acc",
        # counterfactual / minimal-pair probes
        "cfact_cls_follows_swap", "cfact_cls_follows_orig", "minimal_pair_flip_acc",
        # control-variant copies
        "shuffled_cls_acc", "wrong_span_cls_acc",
        # activation patching (LM logit effects, then consistency-head effects)
        "patch_rationale_effect", "patch_theorem_effect", "patch_random_effect", "patch_rationale_minus_random",
        "head_patch_rationale_effect", "head_patch_theorem_effect", "head_patch_random_effect", "head_patch_rationale_minus_random",
    ]
    with path.open("w", newline="", encoding="utf-8") as handle:
        # restval fills missing columns; extrasaction="ignore" drops extra keys.
        writer = csv.DictWriter(handle, fieldnames=columns, restval="", extrasaction="ignore")
        writer.writeheader()
        writer.writerows(rows)


def write_markdown_report(path: Path, results: List[Dict[str, object]], rows) -> None:
    """Render the per-variant results as a Markdown report and write it to *path*.

    The report contains fixed prose sections, a results table and an
    activation-patching table with one row per entry of *results*. Any metric
    missing from a result dict is rendered as an empty cell rather than
    raising ``KeyError``. The parent directory of *path* is created if needed.

    Args:
        path: destination ``.md`` file (overwritten).
        results: metric dicts, one per variant (as from ``train_one_variant``).
        rows: split name -> rows; only the split lengths are reported.
    """
    def fmt(x):
        # Floats get three decimals; everything else (ints, "" placeholders) is str().
        return f"{x:.3f}" if isinstance(x, float) else str(x)

    def table_row(result, keys):
        # One Markdown table row: the variant name followed by the given metrics.
        cells = [fmt(result.get(k, "")) for k in ("variant", *keys)]
        return "| " + " | ".join(cells) + " |"

    result_keys = (
        "cls_claim_acc", "gen_claim_acc", "final_cons_loss",
        "cfact_cls_follows_swap", "cfact_cls_follows_orig", "minimal_pair_flip_acc",
    )
    patch_keys = (
        "patch_rationale_effect", "patch_theorem_effect", "patch_random_effect",
        "patch_rationale_minus_random",
        "head_patch_rationale_effect", "head_patch_theorem_effect", "head_patch_random_effect",
        "head_patch_rationale_minus_random",
    )

    lines = [
        "# LeanCheck: Coupling Informal Rationales to Formal Proof-Checker Outcomes",
        "",
        "## Motivation",
        "LeanCheck isolates verifier-coupled reasoning from open-ended proof search. The model consumes a Lean theorem, a candidate proof, and a natural-language rationale, while labels come from Lean when available or from deterministic checker-derived templates in fallback mode.",
        "",
        "## Method",
        "Sequences use `[THEOREM]`, `[PROOF]`, `[RAT]`, and `[CLAIM]` sections. A causal LM is trained with next-token loss, and a linear consistency head predicts VERIFIES/FAILS from pooled hidden states over a variant-specific span.",
        "",
        "## Dataset Construction",
        f"Generated {len(rows['train'])} train examples, {len(rows['eval'])} eval examples, {len(rows['counterfactual'])} counterfactual swaps, and {len(rows['minimal_pairs'])} minimal-pair rows.",
        "Domains include natural-number equalities, propositional logic, and simple list lemmas.",
        "",
        "## Mutation Families",
        "Wrong lemma, wrong theorem/proof pairing, missing premise, deleted proof line, renamed variable, replacement tactic, and adversarial near-miss mutations are included.",
        "",
        "## Results",
        "| variant | cls_claim_acc | gen_claim_acc | cons_loss | cfact_follows_swap | cfact_follows_orig | minimal_pair_flip |",
        "|---|---:|---:|---:|---:|---:|---:|",
    ]
    lines.extend(table_row(r, result_keys) for r in results)
    lines += [
        "",
        "## Activation Patching",
        "When enabled, activation patching takes accepted/rejected minimal pairs and patches source hidden states into a base example at selected GPT-2 layers. LM patch effects are shifts in the source label token logit at the claim position; head patch effects are shifts in the source class logit after pooling patched final hidden states through the consistency head.",
        # Blank line so Markdown renderers start a table instead of
        # continuing the paragraph above.
        "",
        "| variant | lm_rat | lm_theorem | lm_random | lm_rat_minus_random | head_rat | head_theorem | head_random | head_rat_minus_random |",
        "|---|---:|---:|---:|---:|---:|---:|---:|---:|",
    ]
    lines.extend(table_row(r, patch_keys) for r in results)
    lines += [
        "",
        "## Interpretation",
        "LeanCheck instantiates verifier-coupled reasoning with a formal proof checker. The goal is not open-ended proof synthesis, but measuring whether informal rationale representations encode formal verifier outcomes. Consistency-trained rationale spans become substantially more predictive of Lean accept/reject labels than untrained, random-label, or wrong-span controls, suggesting that natural-language explanations can be coupled to programmatic verification signals.",
        "",
        "Caveat: treat the paragraph above as a finding only when the Results table supports it (consistency-trained rationale spans clearly outperforming the control variants); otherwise treat this run as a smoke-test validation of the pipeline rather than a paper claim.",
        "",
        "## Limitations",
        "Templated rationales may make the task easier. Binary accept/reject is simpler than proof synthesis. Consistency-head accuracy proves decodability, not causal use. Activation patching or RL is needed to show stronger causal faithfulness. The dataset mostly covers simple Lean examples unless expanded.",
        "",
        "## Recommended Next Steps",
        "Run the full GPT-2 configuration on GPU, enable a real Lean 4 checker for every generated example, expand mutation coverage with multiclass error labels, and run activation patching (`--activation-patching`) on every variant for causal diagnostics.",
    ]
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text("\n".join(lines) + "\n", encoding="utf-8")


def main() -> None:
    """CLI entry point: parse flags, prepare data and tokenizer, run each variant.

    The results CSV and Markdown report are rewritten after every variant, so
    an interrupted run still leaves the results of completed variants on disk.
    All results are printed as JSON at the end.
    """
    parser = argparse.ArgumentParser()
    # Model and file locations.
    parser.add_argument("--model-name", default="tiny-local", help="Use gpt2 for Hugging Face GPT-2, tiny-local for offline smoke.")
    parser.add_argument("--data-dir", default="data")
    parser.add_argument("--output-dir", default="leancheck_runs")
    parser.add_argument("--results-csv", default="results_leancheck.csv")
    parser.add_argument("--results-md", default="results_leancheck.md")
    parser.add_argument("--variants", nargs="+", default=DEFAULT_VARIANTS)
    # Dataset sizes.
    parser.add_argument("--train-samples", type=int, default=200)
    parser.add_argument("--eval-samples", type=int, default=50)
    parser.add_argument("--counterfactual-samples", type=int, default=50)
    parser.add_argument("--minimal-pair-samples", type=int, default=0, help="Number of accepted/rejected minimal pairs to generate; 0 uses one per template.")
    # Optimisation and model shape.
    parser.add_argument("--max-seq-len", type=int, default=192)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--epochs", type=int, default=2)
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--lambda-cons", type=float, default=0.5)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--tiny-hidden-size", type=int, default=96)
    parser.add_argument("--tiny-layers", type=int, default=2)
    # Behaviour switches.
    parser.add_argument("--use-lean", action="store_true")
    parser.add_argument("--regenerate-data", action="store_true")
    parser.add_argument("--save-models", action="store_true")
    parser.add_argument("--activation-patching", action="store_true", help="Run GPT-2 activation patching on minimal pairs after each variant.")
    parser.add_argument("--patch-layers", default="default", help="Comma-separated GPT-2 layer ids, or 'default'.")
    parser.add_argument("--patch-samples", type=int, default=24, help="Number of directed accepted/rejected examples to patch.")
    args = parser.parse_args()

    # Seed the global RNGs before data generation and model construction.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    rows = ensure_data(args)
    tokenizer = build_tokenizer(args.model_name, rows)
    csv_path = Path(args.results_csv)
    md_path = Path(args.results_md)
    results: List[Dict[str, object]] = []
    for variant in args.variants:
        print(f"\n=== Running {variant} on {device} ===")
        metrics = train_one_variant(variant, rows, tokenizer, args, device)
        results.append(metrics)
        write_results_csv(csv_path, results)
        write_markdown_report(md_path, results, rows)
    print(json.dumps(results, indent=2))


if __name__ == "__main__":
    main()
