#!/usr/bin/env python3
"""Dataset generation for LeanCheck.

The generator prefers a real Lean 4 executable when one is available, but it
also carries deterministic fallback labels for known accepted examples and
their mutation families. This keeps the experiment runnable in stripped-down
environments while preserving the checker API boundary.
"""

from __future__ import annotations

import argparse
import json
import random
import re
import shutil
import subprocess
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple


# Marker tokens of the serialized example format. Nothing in the visible module
# reads this list; sequence_for() and claim_token() spell the tokens out inline.
SPECIAL_TOKENS = [
    # Start/end markers written by sequence_for(), plus "[PAD]", which this
    # file never emits itself.
    "[BOS]", "[EOS]", "[PAD]",
    # Section headers written by sequence_for().
    "[THEOREM]", "[PROOF]", "[RAT]", "[CLAIM]",
    # The two values claim_token() can return.
    "[VERIFIES]", "[FAILS]",
]


@dataclass(frozen=True)
class LeanTemplate:
    """One accepted Lean example together with its family of rejected variants.

    Instances are immutable; build_candidate() reads them to produce either the
    accepted (theorem, proof) pair or one of the listed mutations.
    """

    # Theorem header line, ending in ":= by" in every entry of TEMPLATES.
    theorem: str
    # Proof body paired with `theorem` for the accepted example.
    proof: str
    # Short family key; used to look up POS_RATIONALE_PARAPHRASES and copied
    # into every generated row.
    domain: str
    # Default natural-language rationale for the accepted example.
    rationale_pos: str
    # Rejected variants as (mutation_name, proof_body, rationale) triples.
    # An empty proof_body means the proof line is omitted altogether.
    mutations: Tuple[Tuple[str, str, str], ...]


# Catalogue of accepted examples and their rejected mutations.
#
# Every mutation listed here is labelled "FAILS" by the fallback path in
# build_candidate(), so each mutated proof must genuinely be rejected by Lean.
# Beware of terms whose type is only *syntactically* different from the goal:
# Lean unifies up to definitional equality, and e.g. `n + 0` reduces to `n`.
TEMPLATES: Tuple[LeanTemplate, ...] = (
    LeanTemplate(
        theorem="example (a b : Nat) : a + b = b + a := by",
        proof="  exact Nat.add_comm a b",
        domain="nat_add_comm",
        rationale_pos="The target is commutativity of addition over natural numbers, and Nat.add_comm directly proves it.",
        mutations=(
            ("wrong_lemma", "  exact Nat.mul_comm a b", "The proof uses a multiplication commutativity lemma for an addition goal, so the term does not match."),
            ("wrong_pairing", "  exact Nat.add_assoc a b b", "The proof term proves an associativity-shaped equality rather than the theorem target."),
            ("delete_final_line", "", "The proof body is missing the final step, so Lean is left with an unsolved goal."),
            ("rename_variable", "  exact Nat.add_comm a c", "The proof refers to c, but that variable is not in scope."),
            ("replace_tactic", "  exact rfl", "Reflexivity cannot prove a + b = b + a for arbitrary natural numbers."),
        ),
    ),
    LeanTemplate(
        theorem="example (a : Nat) : a + 0 = a := by",
        proof="  exact Nat.add_zero a",
        domain="nat_add_zero",
        rationale_pos="The theorem states addition by zero on the right, which Nat.add_zero proves for the variable.",
        mutations=(
            ("wrong_lemma", "  exact Nat.zero_add a", "The lemma zero_add has the zero on the left, so it does not match the stated goal."),
            ("wrong_pairing", "  exact Nat.succ_ne_zero a", "A nonzero successor lemma has the wrong proposition for this equality target."),
            ("delete_final_line", "", "The proof leaves the equality goal unsolved."),
            ("rename_variable", "  exact Nat.add_zero b", "The proof uses b even though only a is bound in the theorem."),
            ("replace_tactic", "  exact Nat.add_comm a 0", "The commutativity lemma proves a + 0 = 0 + a, not directly a + 0 = a."),
        ),
    ),
    LeanTemplate(
        theorem="example (p q : Prop) (hp : p) : p := by",
        proof="  exact hp",
        domain="prop_exact",
        rationale_pos="The hypothesis hp has exactly the proposition p required by the goal, so exact hp closes it.",
        mutations=(
            # The proof text is unchanged here; build_candidate() strips the
            # (hp : p) binder from the theorem for this mutation instead.
            ("missing_premise", "  exact hp", "The proof depends on hp, and without the premise the identifier is unavailable."),
            ("rename_variable", "  exact hq", "The proof names hq, but no such hypothesis exists in the context."),
            ("replace_tactic", "  exact q", "The proof tries to use the proposition q as a proof term, which is not a hypothesis."),
            ("delete_final_line", "", "No tactic is supplied to close the proposition goal."),
            ("adversarial_near_miss", "  exact And.intro hp hp", "The proof constructs a conjunction, but the target asks only for p."),
        ),
    ),
    LeanTemplate(
        theorem="example (p q : Prop) (hp : p) (hq : q) : p ∧ q := by",
        proof="  exact And.intro hp hq",
        domain="prop_and",
        rationale_pos="The target is a conjunction, and And.intro combines proofs of the left and right conjuncts.",
        mutations=(
            ("wrong_lemma", "  exact Or.inl hp", "The proof constructs a disjunction while the theorem requires a conjunction."),
            ("wrong_pairing", "  exact And.intro hq hp", "The proof supplies the conjuncts in the wrong order for p and q."),
            ("rename_variable", "  exact And.intro hp hr", "The proof mentions hr, which is not bound in the theorem."),
            ("delete_final_line", "", "The conjunction goal remains unsolved because the proof line is missing."),
            ("replace_tactic", "  exact hp", "A proof of p alone does not prove the conjunction p and q."),
        ),
    ),
    LeanTemplate(
        theorem="example (xs : List Nat) : xs ++ [] = xs := by",
        proof="  exact List.append_nil xs",
        domain="list_append_nil",
        rationale_pos="The theorem is the standard right-identity law for list append, and List.append_nil proves it.",
        mutations=(
            ("wrong_lemma", "  exact List.nil_append xs", "The proof uses the left-identity lemma, but the theorem needs the right-identity lemma."),
            ("rename_variable", "  exact List.append_nil ys", "The proof refers to ys, which is not in the theorem context."),
            ("delete_final_line", "", "The list equality goal is left open."),
            ("replace_tactic", "  exact rfl", "The append-by-empty equality is not closed by reflexivity for an arbitrary list."),
            ("adversarial_near_miss", "  exact List.append_assoc xs [] []", "An associativity lemma has a plausible list shape but the wrong equality."),
        ),
    ),
    LeanTemplate(
        theorem="example (n : Nat) : n = n := by",
        proof="  rfl",
        domain="rfl",
        rationale_pos="The goal is reflexive equality, so rfl proves it immediately.",
        mutations=(
            ("wrong_lemma", "  exact Nat.succ_ne_zero n", "The proof gives an inequality-like fact about successors, not n = n."),
            ("rename_variable", "  exact h", "The proof uses h, but no hypothesis with that name exists."),
            ("delete_final_line", "", "The equality goal is never discharged."),
            ("replace_tactic", "  exact Nat.zero_ne_succ n", "The lemma has an incompatible not-equal-to-successor proposition."),
            # Previously `exact Eq.symm (Nat.add_zero n)`, whose type
            # `n = n + 0` is definitionally `n = n`, so Lean accepts it and the
            # "FAILS" fallback label was wrong. `n ≤ n` is a different
            # proposition altogether and cannot be unified with the goal.
            ("adversarial_near_miss", "  exact Nat.le_refl n", "The proof shows that n is less than or equal to itself, a nearby reflexive fact, but not the equality target."),
        ),
    ),
)


# Template-independent failure rationales, keyed by mutation name (the first
# element of each LeanTemplate.mutations triple). For a rejected candidate,
# build_candidate() replaces the mutation's own rationale with a random entry
# from the matching tuple with probability 0.55.
# NOTE(review): "wrong_rewrite_direction" is not used by any mutation in
# TEMPLATES as visible in this file, so that entry is currently never selected.
NEG_RATIONALE_PARAPHRASES: Dict[str, Tuple[str, ...]] = {
    "wrong_lemma": (
        "A nearby lemma is used, but its statement does not match the formal target, so Lean should reject it.",
        "The cited theorem proves a different shape of fact than the goal requires.",
        "The proof appeals to a theorem from the wrong arithmetic or logical pattern, leaving the target unproved.",
        "Lean cannot coerce this lemma into the requested equality or proposition.",
        "Although the lemma name is plausible, its conclusion is not the goal currently in scope.",
    ),
    "wrong_pairing": (
        "The proof may be valid elsewhere, but it is paired with a theorem whose target is different.",
        "The proof object does not inhabit the proposition stated in this theorem.",
        "This candidate proof solves a neighboring theorem, not the theorem shown here.",
        "The theorem and proof have been mismatched, so the final proof term has the wrong type.",
        "Lean checks the proof against the displayed goal and finds that the propositions do not align.",
    ),
    "missing_premise": (
        "A hypothesis needed by the proof has been removed from the theorem context.",
        "The candidate relies on a premise that Lean cannot find in scope.",
        "The proof names an assumption that is no longer available in the local context.",
        "Without the required premise, the exact proof term cannot be elaborated.",
        "The script depends on context that the theorem statement does not provide.",
    ),
    "wrong_rewrite_direction": (
        "The rewrite step points the equality in an incompatible direction.",
        "The proposed rewrite does not transform the goal into the required form.",
        "The equality is being used in a direction that does not match the current target.",
        "The rewrite looks relevant, but it moves the expression away from the goal.",
        "Lean cannot apply this rewrite to close the stated theorem.",
    ),
    "delete_final_line": (
        "The proof script stops before closing all goals, so the checker should fail.",
        "A final tactic is missing and an unsolved goal remains.",
        "The candidate opens the proof block but never supplies the closing argument.",
        "Lean still has a target after the script ends.",
        "The proof body is incomplete, so acceptance would require an omitted final step.",
    ),
    "rename_variable": (
        "The proof refers to a variable that is not in scope, causing an unknown identifier error.",
        "A bound name was changed in the proof without changing the theorem context.",
        "The proof mentions a name absent from the theorem's local variables.",
        "Lean rejects the candidate because one of the identifiers cannot be resolved.",
        "The argument list contains a variable name that was never introduced.",
    ),
    "replace_tactic": (
        "The replacement tactic is too weak or has the wrong proof term for this goal.",
        "A plausible tactic was substituted, but it cannot close the formal target.",
        "The tactic succeeds only on simpler goals, not on this theorem.",
        "The proof command has the right flavor but does not produce a term of the required type.",
        "Lean needs a more specific proof step than the tactic supplied here.",
    ),
    "adversarial_near_miss": (
        "The proof looks related to the theorem but proves a subtly different proposition.",
        "This is a near miss: the syntax is plausible, but the formal statement is not satisfied.",
        "The candidate resembles a valid proof but its conclusion is not exactly the requested theorem.",
        "The proof is semantically close enough to look convincing, yet Lean checks exact types.",
        "A small mismatch in the target proposition makes the otherwise plausible proof fail.",
    ),
}


# Acceptance rationales specific to one template family, keyed by
# LeanTemplate.domain. For an accepted candidate, build_candidate() pools the
# matching tuple with the template's own rationale_pos and with
# GENERIC_POS_RATIONALES, then draws one entry uniformly at random. A domain
# missing from this table simply contributes no extra entries to that pool.
POS_RATIONALE_PARAPHRASES: Dict[str, Tuple[str, ...]] = {
    "nat_add_comm": (
        "The goal asks for addition commutativity, and Nat.add_comm returns exactly that equality for a and b.",
        "The proof term is the standard commutativity theorem for natural-number addition with the displayed variables.",
        "Lean can close the equality because Nat.add_comm specializes to the theorem target.",
        "The lemma and goal have matching sides, so exact Nat.add_comm a b verifies.",
    ),
    "nat_add_zero": (
        "The theorem is right addition by zero, and Nat.add_zero gives precisely n + 0 = n.",
        "The proof specializes Nat.add_zero to the bound variable, matching the target equality.",
        "Lean accepts because the zero-addition identity is exactly the requested result.",
        "The candidate proof has the same proposition as the theorem after instantiating the variable.",
    ),
    "prop_exact": (
        "The context already contains hp as a proof of p, which is exactly the goal.",
        "The exact proof term matches the target proposition without further reasoning.",
        "Lean accepts because hp inhabits p in the local context.",
        "The proof simply returns the available hypothesis whose type is the goal.",
    ),
    "prop_and": (
        "The target is a conjunction, and And.intro combines the available proofs of p and q in order.",
        "Lean checks hp for the left conjunct and hq for the right conjunct, so the proof closes.",
        "The proof constructs exactly the conjunction demanded by the theorem.",
        "Both premises are in scope and are supplied to the conjunction introduction rule.",
    ),
    "list_append_nil": (
        "The theorem is the right identity law for list append, which List.append_nil proves for xs.",
        "Lean accepts because List.append_nil specializes to xs ++ [] = xs.",
        "The proof invokes the standard list lemma whose conclusion is the displayed equality.",
        "The list append goal matches the library theorem after instantiating xs.",
    ),
    "rfl": (
        "The goal is a reflexive equality, so rfl directly constructs the proof.",
        "Lean accepts because both sides of the equality are definitionally the same.",
        "The proof is immediate: reflexivity closes n = n.",
        "No rewriting is needed because the target is exactly self-equality.",
    ),
}


# Acceptance rationales that name no particular theorem or lemma.
# build_candidate() appends them to the rationale pool of every accepted
# candidate, whatever its domain.
GENERIC_POS_RATIONALES: Tuple[str, ...] = (
    "The proof applies a term whose type matches the theorem target, so Lean should accept it.",
    "Each identifier used by the proof is in scope, and the final expression has exactly the required proposition.",
    "After specializing the referenced lemma to the theorem variables, the proof closes the goal.",
    "The candidate proof and theorem target align at the type-checker level.",
)


def make_full_code(theorem: str, proof: str) -> str:
    """Assemble the Lean source text for a theorem header and a proof body.

    The theorem always comes first on its own newline-terminated line. The
    proof follows as a second newline-terminated line unless it is empty or
    whitespace-only, in which case it is left out entirely.
    """
    lines = [theorem]
    if proof.strip():
        lines.append(proof)
    return "".join(f"{line}\n" for line in lines)


# Definitive Lean verdicts, keyed by the exact source text that was checked.
# The generators re-check the same few (theorem, proof) pairs thousands of
# times; memoizing avoids spawning one Lean process per generated row.
_LEAN_VERDICT_CACHE: Dict[str, Tuple[str, str]] = {}


def lean_check(theorem: str, proof: str, timeout: float = 5.0) -> Tuple[str, str]:
    """Check a theorem/proof pair with the Lean 4 executable found on PATH.

    Args:
        theorem: Theorem header line passed to make_full_code().
        proof: Proof body passed to make_full_code(); may be empty.
        timeout: Seconds to wait for the Lean process before giving up.

    Returns:
        A ``(label, raw_output)`` pair. ``label`` is one of:

        * ``"VERIFIES"`` - Lean exited with status 0.
        * ``"FAILS"`` - Lean exited with a non-zero status.
        * ``"UNKNOWN"`` - no verdict was obtained, either because no ``lean``
          executable is on PATH or because the run timed out. Callers keep
          their fallback label in this case.

        ``raw_output`` is Lean's stdout and stderr joined by a newline and
        stripped, or a short explanation when the label is ``"UNKNOWN"``.
    """
    lean = shutil.which("lean")
    if not lean:
        return "UNKNOWN", "lean executable not found"
    code = make_full_code(theorem, proof)
    cached = _LEAN_VERDICT_CACHE.get(code)
    if cached is not None:
        return cached
    with tempfile.TemporaryDirectory() as td:
        path = Path(td) / "LeanCheckTmp.lean"
        path.write_text(code, encoding="utf-8")
        try:
            proc = subprocess.run(
                [lean, str(path)],
                text=True,
                capture_output=True,
                timeout=timeout,
            )
        except subprocess.TimeoutExpired as exc:
            # A timeout says nothing about the proof itself (slow machine,
            # cold start), so it must not be recorded as a rejection. It is
            # deliberately not cached either, so a later call can retry.
            return "UNKNOWN", f"TIMEOUT: {exc}"
    raw = (proc.stdout + "\n" + proc.stderr).strip()
    verdict = ("VERIFIES" if proc.returncode == 0 else "FAILS"), raw
    _LEAN_VERDICT_CACHE[code] = verdict
    return verdict


def claim_token(label: str) -> str:
    """Return the claim token for a label string.

    Only the exact label "VERIFIES" yields "[VERIFIES]"; every other value,
    including "FAILS" and anything unrecognized, yields "[FAILS]".
    """
    if label != "VERIFIES":
        return "[FAILS]"
    return "[VERIFIES]"


def sequence_for(theorem: str, proof: str, rationale: str, label: str) -> str:
    """Serialize one example into the tagged text format.

    The layout is [BOS], then the [THEOREM], [PROOF] and [RAT] sections, each
    followed by a blank line, then [CLAIM] with the token that claim_token()
    derives from ``label``, and finally [EOS]. There is no trailing newline.
    """
    pieces = [
        "[BOS]",
        "[THEOREM]",
        theorem,
        "",
        "[PROOF]",
        proof,
        "",
        "[RAT]",
        rationale,
        "",
        "[CLAIM]",
        claim_token(label),
        "[EOS]",
    ]
    return "\n".join(pieces)


def mutate_theorem_for_missing_premise(theorem: str) -> str:
    """Remove every ``(hp : p)`` binder from a theorem header.

    Whitespace immediately before the binder is removed with it, and spacing
    around the colon inside the binder is ignored. Text without such a binder
    is returned unchanged.
    """
    hp_binder = re.compile(r"\s*\(hp\s*:\s*p\)")
    return hp_binder.sub("", theorem)


def build_candidate(
    tmpl: LeanTemplate,
    rng: random.Random,
    positive: bool,
    verify_with_lean: bool,
) -> Dict[str, object]:
    """Build one dataset row from a template.

    Args:
        tmpl: Template supplying the theorem, accepted proof and mutations.
        rng: Random source for every sampling decision made here.
        positive: If true, emit the accepted proof; otherwise emit one
            randomly chosen mutation.
        verify_with_lean: If true, ask lean_check() for the label and use its
            verdict unless it reports "UNKNOWN".

    Returns:
        A dict with the keys ``theorem``, ``proof``, ``rationale``, ``label``,
        ``fallback_label``, ``mutation``, ``domain``, ``lean_output``,
        ``used_lean`` and ``text`` (the serialized sequence_for() string).
        ``used_lean`` is true only when Lean actually delivered the label.
    """
    theorem = tmpl.theorem
    if positive:
        proof = tmpl.proof
        mutation = "accepted"
        # Pool: the template's own rationale, its domain paraphrases, and the
        # generic acceptance rationales.
        rationale_choices = POS_RATIONALE_PARAPHRASES.get(tmpl.domain, ()) + GENERIC_POS_RATIONALES
        rationale = rng.choice((tmpl.rationale_pos,) + rationale_choices)
        fallback_label = "VERIFIES"
    else:
        mutation, proof, rationale = rng.choice(tmpl.mutations)
        if mutation == "missing_premise":
            # This mutation edits the theorem rather than the proof.
            theorem = mutate_theorem_for_missing_premise(theorem)
        paras = NEG_RATIONALE_PARAPHRASES.get(mutation, ())
        # rng.random() is only drawn when paraphrases exist; keeping that
        # short-circuit keeps the random stream identical for a given seed.
        if paras and rng.random() < 0.55:
            rationale = rng.choice(paras)
        fallback_label = "FAILS"

    label = fallback_label
    raw = "fallback label from known template/mutation"
    used_lean = False
    if verify_with_lean:
        checked_label, raw = lean_check(theorem, proof)
        # Decide from the "UNKNOWN" sentinel rather than by matching the
        # wording of lean_check()'s message, so the flag always agrees with
        # whether the label below came from Lean.
        used_lean = checked_label != "UNKNOWN"
        if used_lean:
            label = checked_label

    return {
        "theorem": theorem,
        "proof": proof,
        "rationale": rationale,
        "label": label,
        "fallback_label": fallback_label,
        "mutation": mutation,
        "domain": tmpl.domain,
        "lean_output": raw,
        "used_lean": used_lean,
        "text": sequence_for(theorem, proof, rationale, label),
    }


def generate_examples(n: int, seed: int, verify_with_lean: bool = False) -> List[Dict[str, object]]:
    """Generate ``n`` rows that alternate accepted and rejected candidates.

    Each row uses a template drawn at random from TEMPLATES. Rows at even
    positions are accepted candidates and rows at odd positions are rejected
    ones; the finished list is then shuffled. All randomness comes from a
    ``random.Random`` seeded with ``seed``.
    """
    rng = random.Random(seed)
    # Arguments are evaluated left to right, so the template is drawn before
    # build_candidate() consumes any further random numbers.
    rows: List[Dict[str, object]] = [
        build_candidate(rng.choice(TEMPLATES), rng, index % 2 == 0, verify_with_lean)
        for index in range(n)
    ]
    rng.shuffle(rows)
    return rows


def _swap_rationale(
    proof_row: Dict[str, object],
    rationale_row: Dict[str, object],
    swap_type: str,
) -> Dict[str, object]:
    """Pair the proof of one row with the rationale and claim of another.

    The result starts as a copy of ``proof_row``. Its ``rationale``, ``label``
    and ``text`` are then replaced so that all three describe the donated
    rationale, while ``proof_label`` keeps the label of the proof itself.
    """
    rationale = rationale_row["rationale"]
    claim = rationale_row["label"]
    return {
        **proof_row,
        # Overriding "rationale" keeps this field in agreement with the
        # rationale that is serialized into "text".
        "rationale": rationale,
        "label": claim,
        "text": sequence_for(proof_row["theorem"], proof_row["proof"], rationale, claim),
        "proof_label": proof_row["label"],
        "rationale_label": claim,
        "swap_type": swap_type,
    }


def generate_counterfactual(n: int, seed: int, verify_with_lean: bool = False) -> List[Dict[str, object]]:
    """Generate ``n`` rows whose rationale and claim contradict their proof.

    Each round draws one template and builds an accepted and a rejected
    candidate from it, then emits two rows: the accepted proof with the
    rejected candidate's rationale and claim ("proof_accept_rationale_fail"),
    and the reverse ("proof_fail_rationale_accept"). In every row ``label``
    and ``rationale_label`` follow the donated rationale, and ``proof_label``
    records the label of the proof.

    Returns:
        Exactly ``n`` rows in generation order, or an empty list when ``n`` is
        zero or negative. An odd ``n`` ends on a "proof_accept_rationale_fail"
        row.
    """
    rng = random.Random(seed)
    rows: List[Dict[str, object]] = []
    while len(rows) < n:
        tmpl = rng.choice(TEMPLATES)
        pos = build_candidate(tmpl, rng, True, verify_with_lean)
        neg = build_candidate(tmpl, rng, False, verify_with_lean)
        rows.append(_swap_rationale(pos, neg, "proof_accept_rationale_fail"))
        rows.append(_swap_rationale(neg, pos, "proof_fail_rationale_accept"))
    # The loop adds rows two at a time; trim the surplus row for odd n.
    return rows[:n]


def generate_minimal_pairs(seed: int, verify_with_lean: bool = False, pair_count: int = 0) -> List[Dict[str, object]]:
    """Generate accepted/rejected row pairs that share a theorem template.

    Templates are visited in TEMPLATES order, wrapping around when more pairs
    are requested than there are templates. The rejected row of a pair uses
    the first mutation named "wrong_lemma", "rename_variable" or
    "replace_tactic", falling back to the template's first mutation.

    Args:
        seed: Seed for the rationale sampling of the accepted rows.
        verify_with_lean: If true, labels come from lean_check() whenever it
            returns a verdict other than "UNKNOWN".
        pair_count: Number of pairs; zero or a negative value means one pair
            per template.

    Returns:
        Two rows per pair, accepted first. Both carry the same ``pair_id`` and
        a ``pair_role`` of "accepted" or "rejected".
    """
    preferred = {"wrong_lemma", "rename_variable", "replace_tactic"}
    rng = random.Random(seed)
    total = len(TEMPLATES) if pair_count <= 0 else pair_count
    rows: List[Dict[str, object]] = []
    for pair_id in range(total):
        tmpl = TEMPLATES[pair_id % len(TEMPLATES)]

        accepted = build_candidate(tmpl, rng, True, verify_with_lean)
        accepted["pair_id"] = pair_id
        accepted["pair_role"] = "accepted"

        # Chosen by position rather than at random, so the rejected half of a
        # template's pair is always the same mutation.
        chosen = tmpl.mutations[0]
        for candidate in tmpl.mutations:
            if candidate[0] in preferred:
                chosen = candidate
                break
        mutation, proof, rationale = chosen

        theorem = tmpl.theorem
        if mutation == "missing_premise":
            theorem = mutate_theorem_for_missing_premise(tmpl.theorem)

        neg_label, raw = "FAILS", "fallback minimal-pair label"
        if verify_with_lean:
            checked, raw = lean_check(theorem, proof)
            if checked != "UNKNOWN":
                neg_label = checked

        rejected = {
            "theorem": theorem,
            "proof": proof,
            "rationale": rationale,
            "label": neg_label,
            "fallback_label": "FAILS",
            "mutation": mutation,
            "domain": tmpl.domain,
            "lean_output": raw,
            "used_lean": verify_with_lean and raw != "lean executable not found",
            "text": sequence_for(theorem, proof, rationale, neg_label),
            "pair_id": pair_id,
            "pair_role": "rejected",
        }
        rows += [accepted, rejected]
    return rows


def write_jsonl(path: Path, rows: Iterable[Dict[str, object]]) -> None:
    """Write ``rows`` to ``path`` as JSON Lines, one object per line.

    Missing parent directories are created and an existing file is
    overwritten. The file is UTF-8 and non-ASCII characters are written as-is
    rather than as escape sequences.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        handle.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)


def main() -> None:
    """Parse the command line, generate the four datasets, and write them out.

    The splits use consecutive seeds (seed, seed + 1, seed + 2, seed + 3). All
    four are generated before any file is written. Afterwards a JSON summary
    is printed to stdout; its "minimal_pairs" entry counts rows, which is two
    per pair.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--out-dir", default="data")
    parser.add_argument("--train-samples", type=int, default=2000)
    parser.add_argument("--eval-samples", type=int, default=500)
    parser.add_argument("--counterfactual-samples", type=int, default=200)
    parser.add_argument(
        "--minimal-pair-samples",
        type=int,
        default=0,
        help="Number of accepted/rejected minimal pairs to generate; 0 uses one per template.",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--use-lean", action="store_true")
    opts = parser.parse_args()

    seed = opts.seed
    use_lean = opts.use_lean
    train = generate_examples(opts.train_samples, seed, use_lean)
    eval_rows = generate_examples(opts.eval_samples, seed + 1, use_lean)
    counterfactual = generate_counterfactual(opts.counterfactual_samples, seed + 2, use_lean)
    minimal_pairs = generate_minimal_pairs(seed + 3, use_lean, opts.minimal_pair_samples)

    out_dir = Path(opts.out_dir)
    outputs = (
        ("leancheck_train.jsonl", train),
        ("leancheck_eval.jsonl", eval_rows),
        ("leancheck_counterfactual.jsonl", counterfactual),
        ("leancheck_minimal_pairs.jsonl", minimal_pairs),
    )
    for file_name, rows in outputs:
        write_jsonl(out_dir / file_name, rows)

    # Count the rows whose label was actually delivered by Lean.
    lean_checked = 0
    for _, rows in outputs:
        for row in rows:
            if row.get("used_lean"):
                lean_checked += 1

    summary = {
        "train": len(train),
        "eval": len(eval_rows),
        "counterfactual": len(counterfactual),
        "minimal_pairs": len(minimal_pairs),
        "lean_checked_rows": lean_checked,
        "lean_available": shutil.which("lean") is not None,
    }
    print(json.dumps(summary, indent=2))


# Run the generator only when this file is executed as a script, so importing
# the module (for example to reuse the generate_* functions) writes no files.
if __name__ == "__main__":
    main()
