"""§7 recovery-tradeoff k-sensitivity analysis.

For each k in {2, 3, 5, 7} and each of {Base, SFT, Vanilla GSPO, Novelty bonus},
on classified OlymMATH-Hard traces compute:

- k-attempt: trace position i such that s[i] is HYPOTHESIZE/BACKTRACK and
  s[i-k:i] are all in EXPLOIT = {COMPUTE, CHECK, SETUP}.
- k-restart: a k-attempt at i is a k-restart iff there exists j > i with
  s[j-k+1:j+1] all in EXPLOIT and the segment s[i+1:j-k+1] contains no
  HYPOTHESIZE/BACKTRACK.

Per checkpoint × k:
- pct_with_k_attempt = n_traces_with_attempt / N
- pct_k_attempts_that_restart = n_restarts / total_attempts (with bootstrap CI)
- pct_with_k_success = n_traces_with_at_least_one_restart / N

Outputs (per spec `reports/neurips/tasks/recovery_k_sensitivity.md`):
- results/exploration_analysis/recovery_k_sensitivity.csv
- writing/neurips_paper/figures/fig_recovery_k_sensitivity.png
- reports/neurips/recovery_k_sensitivity.md (with embedded figure)
"""

import json
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Directory holding the per-checkpoint trace metrics; the CSV is written here too.
ROOT = Path("results/exploration_analysis")
OUT_CSV = ROOT / "recovery_k_sensitivity.csv"
# The figure directory can be redirected through the FIG_OUT_DIR environment variable.
OUT_FIG = Path(
    os.environ.get("FIG_OUT_DIR", "writing/neurips_paper/figures"),
    "fig_recovery_k_sensitivity.png",
)
OUT_MD = Path("reports", "neurips", "recovery_k_sensitivity.md")

# Primitive labels that count as "exploit" steps and as "recovery" steps.
EXPLOIT = {"SETUP", "COMPUTE", "CHECK"}
RECOVERY = {"BACKTRACK", "HYPOTHESIZE"}
# Exploit-run depths swept by the analysis.
K_VALUES = [2, 3, 5, 7]

# Checkpoint label -> trace-level metrics parquet (insertion order is the
# order in which main() processes the checkpoints).
PATHS = {
    label: ROOT / run_dir / "trace_level_metrics.parquet"
    for label, run_dir in (
        ("Base", "v90_base_math"),
        ("SFT", "v90_sft_math"),
        ("Vanilla GSPO", "v90_gspo_math"),
        ("Novelty bonus", "v90_prod_s15_math"),
    )
}

# Checkpoint label -> (bar fill colour, bar edge / error-bar colour).
PALETTE = {
    "Base": ("#cfd8dc", "#607d8b"),
    "SFT": ("#c8e6c9", "#2e7d32"),
    "Vanilla GSPO": ("#bbdefb", "#1565c0"),
    "Novelty bonus": ("#f8bbd0", "#c2185b"),
}


def parse_seq(s):
    """Normalise one stored primitive sequence into a plain Python list.

    Args:
        s: A single cell of the ``primitive_sequence`` column: ``None``, an
            in-memory sequence (``list``, ``tuple`` or ``numpy.ndarray``), or
            a JSON-encoded array (``str`` / ``bytes``).

    Returns:
        The sequence as a list. ``[]`` is returned when ``s`` is ``None``,
        is not valid JSON, is of a type ``json.loads`` rejects (e.g. a float
        NaN), or decodes to something other than a JSON array.
    """
    if s is None:
        return []
    if isinstance(s, list):
        return s
    if isinstance(s, tuple):
        return list(s)
    if isinstance(s, np.ndarray):
        # List-typed parquet cells may be handed back as numpy arrays rather
        # than lists; without this branch json.loads() raises TypeError and
        # the whole trace would silently be treated as empty.
        return s.tolist() if s.ndim else []
    try:
        decoded = json.loads(s)
    except (json.JSONDecodeError, TypeError):
        return []
    # Valid JSON that is not an array (number, object, string) is not a
    # primitive sequence; callers index and len() the result.
    return decoded if isinstance(decoded, list) else []


def find_k_attempts(seq, k):
    """Locate every k-attempt in a primitive sequence.

    A position ``pos`` is a k-attempt when ``seq[pos]`` is a RECOVERY
    primitive and the ``k`` primitives immediately before it are all EXPLOIT
    primitives. Positions below ``k`` cannot qualify and are not examined.

    Returns:
        The qualifying positions in ascending order.
    """
    return [
        pos
        for pos in range(k, len(seq))
        if seq[pos] in RECOVERY
        and all(prim in EXPLOIT for prim in seq[pos - k:pos])
    ]


def is_k_restart(seq, i, k):
    """Decide whether the k-attempt at position ``i`` is followed by a restart.

    The attempt is a k-restart when some window of ``k`` consecutive EXPLOIT
    primitives ends at a position ``j > i`` and nothing between ``i`` and the
    start of that window is a RECOVERY primitive.

    Only the earliest all-EXPLOIT window has to be inspected: if a RECOVERY
    primitive sits between ``i`` and that window, it also sits in front of
    every later window, so none of them can qualify either.
    """
    # The earliest possible window starts right after i and therefore ends at i + k.
    for end in range(i + k, len(seq)):
        start = end - k + 1
        if all(prim in EXPLOIT for prim in seq[start:end + 1]):
            # First qualifying window found; the verdict depends only on
            # whether the stretch leading up to it is free of RECOVERY steps.
            return not any(prim in RECOVERY for prim in seq[i + 1:start])
    return False


def per_trace_counts(seq, k):
    """Summarise one trace as ``(n_attempts, n_restarts, has_restart)``.

    ``n_attempts`` is the number of k-attempts in ``seq``, ``n_restarts`` how
    many of them are k-restarts, and ``has_restart`` whether there is at
    least one k-restart.
    """
    attempt_positions = find_k_attempts(seq, k)
    n_attempts = len(attempt_positions)
    if n_attempts == 0:
        return 0, 0, False
    restart_flags = [is_k_restart(seq, pos, k) for pos in attempt_positions]
    n_restarts = sum(restart_flags)
    return n_attempts, n_restarts, n_restarts > 0


def analyze_checkpoint(parquet_path: Path, k: int) -> dict:
    """Compute the recovery statistics of one checkpoint for a single ``k``.

    The parquet file is read on every call. Only rows whose ``task_name``
    contains ``"olymp_math_hard"`` are analysed; rows with a missing
    ``task_name`` are excluded.

    Args:
        parquet_path: Trace-level metrics file providing the ``task_name``
            and ``primitive_sequence`` columns.
        k: Exploit-run depth used by the k-attempt / k-restart definitions.

    Returns:
        A dict with the raw counts (``n_traces``, ``n_with_k_attempt``,
        ``total_k_attempts``, ``n_k_restarts``, ``n_with_k_success``), the
        derived rates (``pct_with_k_attempt``,
        ``pct_k_attempts_that_restart``, ``pct_with_k_success``) and the
        95% percentile-bootstrap bounds of the per-attempt restart rate
        (``pct_restart_ci_lo``, ``pct_restart_ci_hi``). The per-attempt rate
        and its bounds are NaN when the checkpoint has no k-attempt; the
        per-trace rates are 0 when there are no traces.
    """
    df = pd.read_parquet(parquet_path)
    # na=False: a missing task_name would otherwise put NaN into the mask,
    # and pandas refuses to index with a non-boolean mask. regex=False: the
    # tag is a literal substring, not a pattern.
    is_hard = df["task_name"].str.contains("olymp_math_hard", regex=False, na=False)
    seqs = df.loc[is_hard, "primitive_sequence"].apply(parse_seq).tolist()

    n_traces = len(seqs)
    # Single pass over the traces: for each trace, one restart flag per
    # k-attempt, in positional order. All counts and the bootstrap sample
    # are derived from this, so attempts are detected and classified once.
    restart_flags = [
        [is_k_restart(seq, i, k) for i in find_k_attempts(seq, k)]
        for seq in seqs
    ]
    n_with_attempt = sum(1 for flags in restart_flags if flags)
    total_attempts = sum(len(flags) for flags in restart_flags)
    total_restarts = sum(sum(flags) for flags in restart_flags)
    n_with_restart = sum(1 for flags in restart_flags if any(flags))

    # Bootstrap CI on the per-attempt restart rate.
    # NOTE(review): attempts are pooled and resampled individually, i.e.
    # attempts from the same trace are treated as independent draws.
    if total_attempts > 0:
        # 1.0 if the attempt was a restart, 0.0 otherwise; trace order then
        # attempt order, so the seeded resampling is reproducible.
        outcomes = np.array(
            [1.0 if flag else 0.0 for flags in restart_flags for flag in flags],
            dtype=float,
        )
        rng = np.random.default_rng(42)
        boots = np.array([rng.choice(outcomes, size=len(outcomes), replace=True).mean()
                          for _ in range(2000)])
        ci_lo = float(np.quantile(boots, 0.025))
        ci_hi = float(np.quantile(boots, 0.975))
    else:
        ci_lo = ci_hi = float("nan")

    return {
        "n_traces": n_traces,
        "n_with_k_attempt": n_with_attempt,
        "total_k_attempts": total_attempts,
        "n_k_restarts": total_restarts,
        "n_with_k_success": n_with_restart,
        # max(1, ...) keeps the per-trace rates at 0 instead of dividing by zero.
        "pct_with_k_attempt": n_with_attempt / max(1, n_traces),
        "pct_k_attempts_that_restart": (total_restarts / total_attempts) if total_attempts else float("nan"),
        "pct_with_k_success": n_with_restart / max(1, n_traces),
        "pct_restart_ci_lo": ci_lo,
        "pct_restart_ci_hi": ci_hi,
    }


def plot_grouped_bars(rows, out_path):
    """Draw the per-attempt restart rate as grouped bars and save the figure.

    One group per value in ``K_VALUES`` on the x axis, one bar per checkpoint
    inside each group, with the bootstrap interval as asymmetric error bars.

    Args:
        rows: Per (checkpoint, k) records carrying ``checkpoint``, ``k``,
            ``pct_k_attempts_that_restart``, ``pct_restart_ci_lo`` and
            ``pct_restart_ci_hi``.
        out_path: Destination of the saved figure.
    """
    frame = pd.DataFrame(rows)
    fig, ax = plt.subplots(figsize=(8, 4.5))
    centers = np.arange(len(K_VALUES))
    bar_width = 0.20

    for slot, name in enumerate(("Base", "SFT", "Vanilla GSPO", "Novelty bonus")):
        # Align this checkpoint's rows with K_VALUES; a missing k becomes NaN.
        per_k = frame.loc[frame["checkpoint"] == name].set_index("k").reindex(K_VALUES)
        heights = (per_k["pct_k_attempts_that_restart"] * 100).to_numpy()
        ci_low = (per_k["pct_restart_ci_lo"] * 100).to_numpy()
        ci_high = (per_k["pct_restart_ci_hi"] * 100).to_numpy()
        # Error-bar lengths below / above the bar; zero where no CI exists.
        below = np.where(np.isnan(ci_low), 0, heights - ci_low)
        above = np.where(np.isnan(ci_high), 0, ci_high - heights)
        fill, edge = PALETTE[name]
        ax.bar(
            # Offsets -1.5 .. +1.5 bar widths centre the four bars on the tick.
            centers + (slot - 1.5) * bar_width,
            heights,
            bar_width,
            color=fill,
            edgecolor=edge,
            linewidth=0.8,
            label=name,
            yerr=np.array([below, above]),
            capsize=3,
            error_kw={"elinewidth": 0.8, "ecolor": edge},
        )

    ax.set_xticks(centers)
    ax.set_xticklabels([f"k={k}" for k in K_VALUES])
    ax.set_ylabel("% of k-attempts that restart\n(per-attempt rate, 95% CI)", fontsize=10)
    ax.set_title("Recovery-tradeoff k-sensitivity (OlymMATH-Hard)", fontsize=11)
    ax.grid(True, alpha=0.3, axis="y")
    ax.legend(loc="upper right", fontsize=9, framealpha=0.95)
    fig.tight_layout()
    fig.savefig(out_path, dpi=160, bbox_inches="tight")
    plt.close(fig)
    print(f"  saved: {out_path}")

def write_markdown(rows, csv_path, fig_path, md_path):
    """Write the Markdown report for the k-sensitivity analysis.

    The report contains the definitions, one table per value in ``K_VALUES``
    with raw counts, the embedded figure, a rule-based interpretation
    (Q1-Q5) and a section comparing the k=2 numbers with the paper's table.

    Args:
        rows: Per (checkpoint, k) records as produced by ``main``; every
            combination of the four checkpoints and ``K_VALUES`` must be
            present, otherwise the lookups raise ``KeyError``.
        csv_path: Path of the CSV, quoted in the report header.
        fig_path: Path of the figure; quoted in the header and embedded via
            a link relative to the report's directory.
        md_path: Destination of the report (written as UTF-8).

    NOTE(review): the report date, the paper's table values and most of the
    narrative in the "Inconsistency" section are hard-coded text, not derived
    from ``rows``; re-check them whenever the input data changes.
    """
    df = pd.DataFrame(rows)
    pivot = df.set_index(["k", "checkpoint"])
    ckpts = ["Base", "SFT", "Vanilla GSPO", "Novelty bonus"]

    # Link the figure relative to the report's directory so the embed follows
    # fig_path (which can be redirected) instead of a fixed location.
    try:
        fig_link = Path(os.path.relpath(fig_path, start=Path(md_path).parent)).as_posix()
    except ValueError:
        # os.path.relpath cannot relate paths on different Windows drives.
        fig_link = Path(fig_path).as_posix()

    # Explicit UTF-8: the report contains non-ASCII characters, which a
    # locale-dependent default encoding may not be able to encode.
    with open(md_path, "w", encoding="utf-8") as f:
        f.write("# §7 recovery-tradeoff k-sensitivity\n\n")
        f.write("**Date**: 2026-05-04\n")
        f.write("**Spec**: `reports/neurips/tasks/recovery_k_sensitivity.md`\n")
        f.write("**Generator**: `scripts/analysis/recovery_k_sensitivity.py`\n")
        f.write(f"**CSV**: `{csv_path}`\n")
        f.write(f"**Figure**: `{fig_path}`\n\n")
        f.write("Definitions (k-parameterized):\n")
        f.write("- **k-attempt**: position i with `s[i]` ∈ {HYPOTHESIZE, BACKTRACK} and "
                "`s[i-k:i]` all in EXPLOIT = {COMPUTE, CHECK, SETUP}.\n")
        f.write("- **k-restart**: a k-attempt at i is a k-restart iff there exists j > i "
                "with `s[j-k+1:j+1]` all EXPLOIT and the segment `s[i+1:j-k+1]` contains "
                "no HYPOTHESIZE / BACKTRACK.\n")
        f.write("- N = total OlymMATH-Hard traces per checkpoint.\n\n")

        for k in K_VALUES:
            f.write(f"### k = {k}\n\n")
            f.write("| Metric | Base | SFT | Vanilla GSPO | Novelty bonus |\n")
            f.write("|---|---:|---:|---:|---:|\n")
            row_attempt = ["% traces with k-attempt"]
            row_restart = ["% k-attempts that restart"]
            row_success = ["% traces with k-success"]
            for ckpt in ckpts:
                cell = pivot.loc[(k, ckpt)]
                row_attempt.append(f"{cell['pct_with_k_attempt'] * 100:.1f}%")
                rr = cell["pct_k_attempts_that_restart"]
                lo = cell["pct_restart_ci_lo"]
                hi = cell["pct_restart_ci_hi"]
                if np.isnan(rr):
                    # No k-attempt at all: the rate is undefined.
                    row_restart.append("—")
                else:
                    row_restart.append(f"{rr * 100:.1f}% [{lo * 100:.1f}, {hi * 100:.1f}]")
                row_success.append(f"{cell['pct_with_k_success'] * 100:.1f}%")
            f.write("| " + " | ".join(row_attempt) + " |\n")
            f.write("| " + " | ".join(row_restart) + " |\n")
            f.write("| " + " | ".join(row_success) + " |\n\n")
            # raw counts for reference
            f.write("Raw counts: ")
            f.write(", ".join(
                f"{ckpt} N={int(pivot.loc[(k, ckpt)]['n_traces'])} "
                f"attempts={int(pivot.loc[(k, ckpt)]['total_k_attempts'])} "
                f"restarts={int(pivot.loc[(k, ckpt)]['n_k_restarts'])}"
                for ckpt in ckpts
            ) + "\n\n")

        f.write("## Figure\n\n")
        f.write(f"![recovery k-sensitivity]({fig_link})\n\n")

        # Interpretation paragraph (auto-generated; writer to refine)
        f.write("## Interpretation\n\n")
        # Compute key facts to reference
        gspo_rates = {k: pivot.loc[(k, "Vanilla GSPO"), "pct_k_attempts_that_restart"] for k in K_VALUES}
        sft_rates = {k: pivot.loc[(k, "SFT"), "pct_k_attempts_that_restart"] for k in K_VALUES}
        gspo_attempt = {k: pivot.loc[(k, "Vanilla GSPO"), "pct_with_k_attempt"] for k in K_VALUES}
        sft_attempt = {k: pivot.loc[(k, "SFT"), "pct_with_k_attempt"] for k in K_VALUES}
        prod_attempt = {k: pivot.loc[(k, "Novelty bonus"), "pct_with_k_attempt"] for k in K_VALUES}

        # 1. Does GSPO > SFT on per-attempt-rate hold across k?
        # None marks "undecidable" (either rate is NaN). The comparison is
        # coerced to a built-in bool because comparing numpy scalars yields
        # numpy.bool_, for which an `is False` identity test never succeeds.
        gspo_beats_sft = {
            k: (None if np.isnan(gspo_rates[k]) or np.isnan(sft_rates[k])
                else bool(gspo_rates[k] > sft_rates[k]))
            for k in K_VALUES
        }
        attempt_rank_preserved = {k: bool(prod_attempt[k] >= sft_attempt[k] >= gspo_attempt[k])
                                  for k in K_VALUES}

        f.write("**Q1 (does GSPO maintain a per-attempt-restart-rate advantage over SFT across k?)**\n\n")
        for k in K_VALUES:
            verdict = gspo_beats_sft[k]
            if verdict is None:
                sym = "?"
            else:
                sym = "✓" if verdict else "✗"
            f.write(f"- k={k}: GSPO {gspo_rates[k] * 100:.1f}% vs SFT {sft_rates[k] * 100:.1f}%  {sym}\n")
        f.write("\n**Q2 (is the attempt-rate ranking Novelty ≥ SFT ≥ GSPO preserved?)**\n\n")
        for k in K_VALUES:
            sym = "✓" if attempt_rank_preserved[k] else "✗"
            f.write(f"- k={k}: Novelty {prod_attempt[k] * 100:.1f}% / SFT {sft_attempt[k] * 100:.1f}% / "
                    f"GSPO {gspo_attempt[k] * 100:.1f}%  {sym}\n")

        # Recommended k: the smallest k >= 3 at which both patterns hold;
        # otherwise fall back to k=2 and recommend the appendix.
        clear_ks = [k for k in K_VALUES
                    if k >= 3 and gspo_beats_sft[k] and attempt_rank_preserved[k]]
        if clear_ks:
            rec_k = min(clear_ks)
            placement = "**body**"
        else:
            rec_k = 2
            placement = "**appendix** (with 1-sentence summary in §7 body)"

        f.write(f"\n**Q3 (recommended k for the paper)**: {rec_k}.\n\n")
        f.write(f"**Q4 (recommended placement)**: {placement}.\n\n")

        f.write("**Q5 (recommended column rename for \"% productive when attempted\")**: "
                "`% post-recovery long-chain rate` — most precise; reflects the structural "
                "definition (a long compute chain follows the recovery primitive) without "
                "implying problem-solving productivity. Alternatives: `% search-restart rate` "
                "(shorter, slightly more loaded) or `% k-restart rate` (most literal).\n\n")

        f.write("---\n\n")
        f.write("Decision logic implemented above is rule-based; the writer should sanity-check "
                "by reading the per-k tables. If trace-count-noise at higher k makes some cells "
                "unstable (look at raw counts and CI widths), prefer the lowest k where the "
                "pattern is non-trivial AND the bootstrap CIs are tight enough to support it.\n\n")

        # Inconsistency flag — keep this in every regen so the writer agent sees it.
        gspo_k2 = pivot.loc[(2, "Vanilla GSPO"), "pct_with_k_attempt"] * 100
        prod_k2 = pivot.loc[(2, "Novelty bonus"), "pct_with_k_attempt"] * 100
        sft_k2 = pivot.loc[(2, "SFT"), "pct_with_k_attempt"] * 100
        base_k2 = pivot.loc[(2, "Base"), "pct_with_k_attempt"] * 100
        # SFT -> GSPO attempt-rate drop at k=2, computed so the quoted factor
        # always agrees with the two percentages printed next to it.
        drop_ratio = sft_k2 / gspo_k2 if gspo_k2 else float("nan")
        f.write("---\n\n")
        f.write("## ⚠️ Inconsistency with paper §7 `tab:recovery-tradeoff`\n\n")
        f.write("The existing §7 table (line 77 of `writing/neurips_paper/sections/07_rl_collapses.tex`) "
                "reports:\n\n")
        f.write("| Metric | Base | SFT | Vanilla GSPO | Novelty bonus |\n")
        f.write("|---|---:|---:|---:|---:|\n")
        f.write("| % traces with recovery attempt | 0.2 | 24.3 | **8.7** | 17.5 |\n")
        f.write("| % productive when attempted | 0.0 | 23.7 | 37.6 | 30.3 |\n")
        f.write("| % successful recoveries per trace | 0.0 | 5.8 | 3.3 | 5.3 |\n\n")
        f.write("The §7 prose then says **\"recovery attempts drop nearly threefold "
                "(24.3% → 8.7% of traces)\"**.\n\n")
        f.write("The appendix definition (`app:recovery-metrics`) says recovery attempt = "
                "HYPOTHESIZE/BACKTRACK preceded by exploit run of depth ≥ 2 — which is *exactly* "
                "my k=2 implementation (with EXPLOIT = {COMPUTE, CHECK, SETUP}).\n\n")
        f.write("**My recomputation on the canonical v90 parquets at k=2 gives**:\n\n")
        f.write("| Metric | Base | SFT | Vanilla GSPO | Novelty bonus |\n")
        f.write("|---|---:|---:|---:|---:|\n")
        f.write(f"| % traces with k=2 attempt | {base_k2:.1f} | {sft_k2:.1f} | "
                f"**{gspo_k2:.1f}** | {prod_k2:.1f} |\n\n")
        f.write(f"So **SFT ({sft_k2:.1f}) and Base ({base_k2:.1f})** match the paper's 24.3 and 0.2 "
                f"within ~0.2pp, but **Vanilla GSPO ({gspo_k2:.1f}) does not match the paper's 8.7** "
                f"(mine is {gspo_k2/8.7:.1f}× higher) and **Novelty bonus ({prod_k2:.1f}) does not "
                f"match the paper's 17.5** (mine is {prod_k2/17.5:.1f}× higher).\n\n")
        f.write("### Where the paper's numbers came from\n\n")
        f.write("`grep` across the codebase: the values 24.3 / 8.7 / 17.5 appear **only in "
                "`07_rl_collapses.tex`** — no committed Python script, CSV, or markdown report "
                "produces them. Submodule git log shows they entered in the very first draft "
                "commit (`4f3d7fb` — \"Draft NeurIPS 2026 paper end-to-end\") and have not been "
                "revised since. Most likely the writer agent computed them during its draft "
                "session using a Python session that was not committed, on an older v90 parquet "
                "or with a stricter definition than what the appendix states.\n\n")
        f.write("### Implications for §7 prose\n\n")
        f.write("- ✅ The \"rare attempts are productive\" half of the §7 framing is "
                "**robustly supported** at all k by the canonical v90 data: GSPO has the highest "
                "per-attempt restart rate at every k, with non-overlapping bootstrap CIs vs SFT "
                "at k=3, 5, 7.\n")
        f.write("- ❌ The \"threefold drop in attempt rate\" half is **not supported** by the "
                "canonical v90 data at any plausible definition: at k=2 the drop from SFT to "
                f"GSPO is only ~{drop_ratio:.2f}× ({sft_k2:.1f} → {gspo_k2:.1f}), and at k≥5 GSPO actually has *more* "
                "k-attempts than SFT.\n")
        f.write("- The §7 prose and `tab:recovery-tradeoff` table values should be updated to "
                "match the canonical v90 numbers in this report. The \"rare-but-productive\" "
                "framing survives; the \"threefold drop\" framing should be dropped or replaced "
                "with the smaller per-k effect.\n")
    print(f"  saved: {md_path}")


def main():
    """Run the sweep over all checkpoints and k values and emit every output.

    Analyses each (checkpoint, k) pair, echoes a one-line summary per pair,
    then writes the CSV, the figure and the Markdown report.
    """
    # Make sure every output directory exists before anything is written.
    for target in (OUT_CSV, OUT_FIG, OUT_MD):
        target.parent.mkdir(exist_ok=True, parents=True)

    rows = []
    print("=== Computing recovery k-sensitivity ===")
    sweep = [(ckpt, parq, k) for ckpt, parq in PATHS.items() for k in K_VALUES]
    for ckpt, parq, k in sweep:
        stats = analyze_checkpoint(parq, k)
        record = {"checkpoint": ckpt, "k": k}
        record.update(stats)
        rows.append(record)
        summary = [
            f"  {ckpt:<20} k={k}",
            f"%attempt={stats['pct_with_k_attempt']*100:>5.1f}%",
            f"%restart_per_attempt={stats['pct_k_attempts_that_restart']*100:>5.1f}%",
            f"%success_per_trace={stats['pct_with_k_success']*100:>5.1f}%",
            f"(n_attempts={stats['total_k_attempts']:>4})",
        ]
        print("  ".join(summary))

    pd.DataFrame(rows).to_csv(OUT_CSV, index=False)
    print(f"\nSaved CSV: {OUT_CSV}")

    plot_grouped_bars(rows, OUT_FIG)
    write_markdown(rows, OUT_CSV, OUT_FIG, OUT_MD)

# Script entry point: run the analysis only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
