"""§6 supporting analyses: causal evidence on chain depth.

Implements the data-available versions of the analyses specified in
``reports/neurips/tasks/causal_chain_depth_analyses.md``:

  Analysis 1' — Within-checkpoint puzzle-type chain depth.
                The only puzzle types with classifier outputs on S3 are
                bridges_8x8de and undead_5x5de (both OOD evals). Compares chain
                depth between the two puzzle types at the SFT and vanilla GSPO
                checkpoints.

  Analysis 2' — Solved vs unsolved chain depth on OlymMATH-Hard.
                Per-checkpoint Mann-Whitney U test on chain depth distributions
                of solved vs unsolved traces.

Chain depth definition (per task spec docstring): longest run of
{COMPUTE, CHECK, SETUP} that is not interrupted by {HYPOTHESIZE, BACKTRACK,
ERROR_DETECT}. Other primitives (PLAN, ENUMERATE, DECOMPOSE, VERIFY,
SUMMARIZE, OTHER) neither extend nor interrupt the run.

Note: this differs from the §3/§6 chain-depth definition used in
``master Claim 6`` and ``fig_chain_depth.png`` (strict consecutive
EXPLOIT). The looser definition here is what the task spec asks for; we
flag the difference in the report rather than reconciling.
"""
import json
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import mannwhitneyu

# Primitive labels that extend the current run counted by chain_depth().
EXPLOIT = {"COMPUTE", "CHECK", "SETUP"}
# Primitive labels that reset the current run to zero in chain_depth().
INTERRUPT = {"HYPOTHESIZE", "BACKTRACK", "ERROR_DETECT"}

# Pastel palette per CLAUDE.md "Figure colour scheme".
# NOTE(review): PALETTE and EDGE are keyed by checkpoint display label but are
# not referenced in the visible code; presumably fill / outline colours by
# analogy with the *_FILL / *_EDGE pairs below — confirm before removing.
PALETTE = {
    "Base": "#cfd8dc",
    "SFT": "#c8e6c9",
    "Vanilla GSPO": "#bbdefb",
    "Novelty bonus": "#f8bbd0",
}
EDGE = {
    "Base": "#607d8b",
    "SFT": "#2e7d32",
    "Vanilla GSPO": "#1565c0",
    "Novelty bonus": "#c2185b",
}

# Two-class secondary palette for puzzle-type comparisons.
# Keys are the short task labels used in analysis1(); FILL is passed as the bar
# `color`, EDGE as the bar `edgecolor`.
PUZZLE_FILL = {"Bridges 8x8de": "#bbdefb", "Undead 5x5de": "#f8bbd0"}
PUZZLE_EDGE = {"Bridges 8x8de": "#1565c0", "Undead 5x5de": "#c2185b"}

# Solved / unsolved palette.
# Used by analysis2() for the solved and unsolved bars (fill and outline).
OUTCOME_FILL = {"Solved": "#c8e6c9", "Unsolved": "#ffccbc"}
OUTCOME_EDGE = {"Solved": "#2e7d32", "Unsolved": "#d84315"}


# Checkpoint display label -> directory name under ROOT that load() reads
# `trace_level_metrics.parquet` from. Iterated in this order by analysis2(),
# which also fixes the left-to-right bar order in its figure.
CHECKPOINT_FILES = {
    "Base": "v90_base_math",
    "SFT": "v90_sft_math",
    "Vanilla GSPO": "v90_gspo_math",
    "Novelty bonus": "v90_prod_s15_math",
}

# Same mapping for the puzzle evals; analysis1() and analysis1_size_ladder()
# filter these tables on the `*_pass32` task names.
PUZZLE_FILES = {
    "SFT": "v90_sft_puzzles",
    "Vanilla GSPO": "v90_gspo_puzzles",
}
# Newly classified OOD puzzle evals (200 single-shot traces per task) at SFT+GSPO.
# Read by analysis1_size_ladder() when the pass@8 parquets are not present.
PUZZLE_FILES_EXTRA = {
    "SFT": "v90_sft_puzzles_extra",
    "Vanilla GSPO": "v90_gspo_puzzles_extra",
}
# Pass@8 re-evals on harder OOD puzzles (1600 traces per task; 8 rollouts × 200 docs).
# Used for the matched-N size_ladder comparison; combined with an 8-rollout
# subsample of the pass@32 easier-puzzle parquets gives 8 rollouts/doc on both sides.
PUZZLE_FILES_EXTRA_PASS8 = {
    "SFT": "v90_sft_puzzles_extra_pass8",
    "Vanilla GSPO": "v90_gspo_puzzles_extra_pass8",
}


def _pass8_ladder_available() -> bool:
    """Return True when every pass@8 puzzles_extra trace parquet is present.

    The candidate paths are built from ``ROOT`` and the directory names in
    ``PUZZLE_FILES_EXTRA_PASS8`` — the same pieces ``load`` combines when
    ``analysis1_size_ladder`` reads these files — so this check cannot drift
    from the files that are actually loaded. Both names are module globals
    resolved at call time, so defining ``ROOT`` further down the file is fine.

    Returns:
        ``True`` only if ``trace_level_metrics.parquet`` exists for every
        checkpoint listed in ``PUZZLE_FILES_EXTRA_PASS8``; ``False`` otherwise.
    """
    return all(
        (ROOT / dirname / "trace_level_metrics.parquet").exists()
        for dirname in PUZZLE_FILES_EXTRA_PASS8.values()
    )


# Input root: load() reads `<ROOT>/<run dir>/trace_level_metrics.parquet`.
ROOT = Path("results") / "exploration_analysis"
# CSV tables written by the analyses in this file.
OUT_DIR = ROOT / "causal_chain_depth_obs"
# Figure output directory; the FIG_OUT_DIR environment variable overrides it.
FIG_DIR = Path(os.getenv("FIG_OUT_DIR", "writing/neurips_paper/figures"))

# Both output directories are created as soon as the module is executed so the
# analyses can write without checking for them first.
FIG_DIR.mkdir(exist_ok=True, parents=True)
OUT_DIR.mkdir(exist_ok=True, parents=True)


def chain_depth(seq):
    """Return the longest uninterrupted run of EXPLOIT primitives in ``seq``.

    Walks the sequence once. Every label in ``EXPLOIT`` extends the current
    run, every label in ``INTERRUPT`` resets it to zero, and any other label
    is passed over without extending or resetting the run.

    Args:
        seq: A list, tuple or NumPy array of primitive labels, or a string
            holding the JSON encoding of such a list.

    Returns:
        Length of the longest run as an ``int``. ``0`` is returned for an
        empty sequence, a string that is not valid JSON, or any value that is
        not a list/tuple/array once decoded.
    """
    if isinstance(seq, str):
        try:
            seq = json.loads(seq)
        except (json.JSONDecodeError, TypeError):
            return 0
    # Tuples are accepted alongside lists and arrays so an in-memory sequence
    # is not silently scored as 0 just because of its container type.
    if not isinstance(seq, (list, tuple, np.ndarray)) or len(seq) == 0:
        return 0
    best = cur = 0
    for p in seq:
        # Both label sets contain only strings, so a non-string element can
        # never match; skipping it up front also avoids a TypeError from
        # hashing an unhashable element (e.g. a nested list in bad JSON).
        if not isinstance(p, str):
            continue
        if p in EXPLOIT:
            cur += 1
            if cur > best:
                best = cur
        elif p in INTERRUPT:
            cur = 0
        # else: pass-through (PLAN, ENUMERATE, DECOMPOSE, VERIFY, SUMMARIZE, OTHER)
    return best


def load(parquet_dir):
    """Read one run's trace-level metrics and attach a ``chain_depth`` column.

    Args:
        parquet_dir: Directory name under ``ROOT``; the table is read from
            ``ROOT / parquet_dir / "trace_level_metrics.parquet"``.

    Returns:
        The DataFrame read from that file, with a ``chain_depth`` column
        holding ``chain_depth(...)`` of each row's ``primitive_sequence``.
    """
    metrics_path = ROOT / parquet_dir / "trace_level_metrics.parquet"
    traces = pd.read_parquet(metrics_path)
    depths = traces["primitive_sequence"].apply(chain_depth)
    traces["chain_depth"] = depths
    return traces


# ---------------------------------------------------------------------------
# Analysis 1': within-checkpoint puzzle-type chain depth
# ---------------------------------------------------------------------------
def analysis1():
    """Compare chain depth between two puzzle types within each checkpoint.

    For every checkpoint in ``PUZZLE_FILES`` this summarises the chain depth
    of the ``bridges_8x8de_pass32`` and ``undead_5x5de_pass32`` traces (mean,
    median, p90, p95, fraction correct), writes the table to
    ``OUT_DIR / "analysis1_within_checkpoint_puzzle.csv"``, prints a two-sided
    Mann-Whitney U test between the two puzzle types per checkpoint, and saves
    a grouped bar chart of the means to
    ``FIG_DIR / "fig_chain_depth_per_puzzle.png"``.

    Returns:
        The summary DataFrame, one row per (checkpoint, task).
    """
    # (task_name value in the parquet, short label used in the table/figure)
    task_labels = [("bridges_8x8de_pass32", "Bridges 8x8de"),
                   ("undead_5x5de_pass32", "Undead 5x5de")]

    # Read and score each checkpoint once; the summary table and the MWU test
    # below both work off the same frame instead of re-reading the parquet.
    frames = {ck: load(dirname) for ck, dirname in PUZZLE_FILES.items()}

    rows = []
    for ck, df in frames.items():
        for task, short in task_labels:
            sub = df[df["task_name"] == task]
            cd = sub["chain_depth"]
            rows.append({
                "checkpoint": ck,
                "task": short,
                "n_traces": len(sub),
                "mean": float(cd.mean()),
                "median": float(cd.median()),
                "p90": float(cd.quantile(0.9)),
                "p95": float(cd.quantile(0.95)),
                "frac_correct": float(sub["correct"].mean()),
            })
    out = pd.DataFrame(rows)
    out.to_csv(OUT_DIR / "analysis1_within_checkpoint_puzzle.csv", index=False)

    # MWU between puzzle types within each checkpoint. Kept after the CSV
    # write so the table is on disk even if the test raises.
    for ck, df in frames.items():
        b = df[df["task_name"] == "bridges_8x8de_pass32"]["chain_depth"]
        u = df[df["task_name"] == "undead_5x5de_pass32"]["chain_depth"]
        U, p = mannwhitneyu(b, u, alternative="two-sided")
        print(f"[A1'] {ck}: bridges_8x8de vs undead_5x5de  U={U:.0f}  p={p:.3g}")

    # ---- Figure ----
    cks = list(frames)
    tasks = [short for _, short in task_labels]
    # reindex pins the row order to the checkpoint order used on the x axis.
    pivot = out.pivot(index="checkpoint", columns="task", values="mean").reindex(cks)
    xs = np.arange(len(cks))
    width = 0.36

    fig, ax = plt.subplots(figsize=(7, 4.2))
    for i, t in enumerate(tasks):
        # Two bars per checkpoint, centred on the tick: offsets -w/2 and +w/2.
        offset = (i - 0.5) * width
        ax.bar(xs + offset, pivot[t].values, width,
               color=PUZZLE_FILL[t], edgecolor=PUZZLE_EDGE[t], linewidth=1.0,
               label=t)
        for x, v in zip(xs + offset, pivot[t].values):
            ax.text(x, v + 0.4, f"{v:.1f}", ha="center", va="bottom", fontsize=9)
    ax.set_xticks(xs)
    ax.set_xticklabels(cks)
    ax.set_ylabel("Mean chain depth")
    ax.set_title("Chain depth on OOD puzzle evals: same model, different puzzle types",
                 fontsize=10)
    panel_max = max(pivot[t].max() for t in tasks)
    # Headroom above the tallest bar for its value label.
    ax.set_ylim(0, panel_max + 1.5)
    ax.grid(True, axis="y", alpha=0.3)
    ax.legend(loc="upper left", fontsize=9)
    fig.tight_layout()
    out_fig = FIG_DIR / "fig_chain_depth_per_puzzle.png"
    fig.savefig(out_fig, dpi=160, bbox_inches="tight")
    print(f"wrote {out_fig}")
    plt.close(fig)
    return out


# ---------------------------------------------------------------------------
# Analysis 2': solved vs unsolved chain depth on OlymMATH-Hard
# ---------------------------------------------------------------------------
def analysis2():
    """Compare chain depth of solved vs unsolved traces per checkpoint.

    For every checkpoint in ``CHECKPOINT_FILES`` the traces are split on the
    ``correct`` column and the mean chain depth of each group is recorded.
    A two-sided Mann-Whitney U test is run only when there are at least five
    solved traces and at least one unsolved trace; otherwise ``delta``, ``U``
    and ``p_value`` are NaN. The table is written to
    ``OUT_DIR / "analysis2_solved_vs_unsolved.csv"``, a one-line summary per
    checkpoint is printed, and a grouped bar chart is saved to
    ``FIG_DIR / "fig_chain_depth_solved_vs_unsolved.png"``.

    Returns:
        The summary DataFrame, one row per checkpoint.
    """
    nan = float("nan")
    min_solved_for_test = 5

    rows = []
    for ck_name, dirname in CHECKPOINT_FILES.items():
        df = load(dirname)
        # Element-wise comparisons (not identity checks): rows whose `correct`
        # is neither True nor False fall into neither group.
        solved = df[df["correct"] == True]["chain_depth"]
        unsolved = df[df["correct"] == False]["chain_depth"]
        solved_mean = float(solved.mean()) if len(solved) else nan
        unsolved_mean = float(unsolved.mean()) if len(unsolved) else nan

        # The test needs observations on both sides; an empty unsolved group
        # would otherwise raise (or yield NaN, depending on the scipy version).
        if len(solved) >= min_solved_for_test and len(unsolved) > 0:
            U, p = mannwhitneyu(solved, unsolved, alternative="two-sided")
            delta, u_stat, p_value = solved_mean - unsolved_mean, int(U), float(p)
        else:
            delta = u_stat = p_value = nan

        rows.append({
            "checkpoint": ck_name,
            "n_solved": len(solved),
            "n_unsolved": len(unsolved),
            "solved_mean": solved_mean,
            "unsolved_mean": unsolved_mean,
            "delta": delta,
            "U": u_stat,
            "p_value": p_value,
        })
    plot_rows = {r["checkpoint"]: r for r in rows}

    out = pd.DataFrame(rows)
    out.to_csv(OUT_DIR / "analysis2_solved_vs_unsolved.csv", index=False)
    for r in rows:
        print(f"[A2'] {r['checkpoint']:18s}  solved={r['solved_mean']:5.2f}  "
              f"unsolved={r['unsolved_mean']:5.2f}  Δ={r['delta']:+5.2f}  "
              f"p={r['p_value']:.3g}  n_solved={r['n_solved']}")

    # ---- Figure ----
    cks = list(CHECKPOINT_FILES.keys())
    xs = np.arange(len(cks))
    width = 0.38

    fig, ax = plt.subplots(figsize=(8, 4.5))
    solved_means = [plot_rows[c]["solved_mean"] for c in cks]
    unsolved_means = [plot_rows[c]["unsolved_mean"] for c in cks]

    ax.bar(xs - width / 2, solved_means, width,
           color=OUTCOME_FILL["Solved"], edgecolor=OUTCOME_EDGE["Solved"],
           linewidth=1.0, label="Solved traces")
    ax.bar(xs + width / 2, unsolved_means, width,
           color=OUTCOME_FILL["Unsolved"], edgecolor=OUTCOME_EDGE["Unsolved"],
           linewidth=1.0, label="Unsolved traces")

    for i, c in enumerate(cks):
        sm = solved_means[i]
        um = unsolved_means[i]
        # A group with no traces has a NaN mean: there is no bar to label, and
        # a NaN y-coordinate would just be rejected at draw time.
        for x, m in ((xs[i] - width / 2, sm), (xs[i] + width / 2, um)):
            if np.isfinite(m):
                ax.text(x, m + 0.3, f"{m:.1f}", ha="center", va="bottom",
                        fontsize=9)
        # n_solved annotation only; significance lives in the appendix table
        # (tab:chain-depth-sig).
        n = plot_rows[c]["n_solved"]
        # Place the annotation above the taller *finite* bar so that "n=0"
        # still shows for a checkpoint that solved nothing.
        top = max((m for m in (sm, um) if np.isfinite(m)), default=0.0)
        ax.text(xs[i], top + 1.0, f"n={n}", ha="center", va="bottom", fontsize=8,
                color="#555")

    ax.set_xticks(xs)
    ax.set_xticklabels(cks, fontsize=10)
    ax.set_ylabel("Mean chain depth")
    # Axis limits must be finite, so NaN means are excluded from the maximum.
    finite_means = [m for m in unsolved_means + solved_means if np.isfinite(m)]
    ax.set_ylim(0, max(finite_means) * 1.30 if finite_means else 1.0)
    ax.set_title("OlymMATH-Hard: chain depth in solved vs unsolved traces",
                 fontsize=10)
    ax.legend(loc="upper left", fontsize=9)
    ax.grid(True, axis="y", alpha=0.3)
    fig.tight_layout()
    out_fig = FIG_DIR / "fig_chain_depth_solved_vs_unsolved.png"
    fig.savefig(out_fig, dpi=160, bbox_inches="tight")
    print(f"wrote {out_fig}")
    plt.close(fig)
    return out


def analysis1_size_ladder():
    """Within-puzzle size and difficulty ladder using newly classified OOD evals.

    Two modes, selected by ``_pass8_ladder_available()`` (i.e. by whether the
    pass@8 puzzles_extra parquets exist on disk):

    A) Single-shot match (pass@8 parquets absent):
       Bridges: 8x8de (pass@32, subsampled to trace_id==0, n=100) and
                10x10dh (pass@1, n=200)
       Undead:  5x5de (pass@32, subsampled to trace_id==0, n=100) and
                5x5dm (pass@1, n=200)
       Pattern: 7x7dm (pass@1, n=200) — table only, not plotted

    B) Pass@8 match (pass@8 parquets present):
       Bridges: 8x8de (pass@32, subsampled to trace_id < 8, n=800) and
                10x10dh (pass@8 re-eval, n=1600)
       Undead:  5x5de (pass@32, subsampled to trace_id < 8, n=800) and
                5x5dm (pass@8 re-eval, n=1600)
       Pattern dropped from B (no pass@8 re-eval).

    Both modes fix the original artifact where the easier cell got 32 retries
    while the harder cell was single-shot, biasing the ladder anti-monotonically.

    Side effects: writes ``OUT_DIR / "analysis1_size_ladder.csv"``, prints a
    two-sided Mann-Whitney U test per checkpoint for the Bridges and Undead
    pairs, and saves ``FIG_DIR / "fig_chain_depth_size_ladder.png"``.

    Returns:
        The summary DataFrame, one row per (checkpoint, task).
    """
    use_pass8 = _pass8_ladder_available()
    print(f"[size_ladder] mode = {'PASS8' if use_pass8 else 'SINGLE-SHOT'}")

    if use_pass8:
        label_for = {
            "bridges_8x8de_pass32": ("Bridges", "8x8de"),
            "bridges_10x10dh_pass8": ("Bridges", "10x10dh"),
            "undead_5x5de_pass32":  ("Undead", "5x5de"),
            "undead_5x5dm_pass8":   ("Undead", "5x5dm"),
        }
        target_rollouts_per_doc = 8
        extra_files = PUZZLE_FILES_EXTRA_PASS8
    else:
        label_for = {
            "bridges_8x8de_pass32":     ("Bridges", "8x8de"),
            "bridges_10x10dh_test200":  ("Bridges", "10x10dh"),
            "undead_5x5de_pass32":      ("Undead", "5x5de"),
            "undead_5x5dm_test200":     ("Undead", "5x5dm"),
            "pattern_7x7dm_test200":    ("Pattern", "7x7dm"),
        }
        target_rollouts_per_doc = 1
        extra_files = PUZZLE_FILES_EXTRA
    bridges_harder = "bridges_10x10dh_pass8" if use_pass8 else "bridges_10x10dh_test200"
    undead_harder = "undead_5x5dm_pass8" if use_pass8 else "undead_5x5dm_test200"

    def _subsample(df: pd.DataFrame, n: int) -> pd.DataFrame:
        """Keep rows with ``trace_id < n`` when the frame has more than ``n``
        distinct trace ids; otherwise return the frame unchanged."""
        if "trace_id" in df.columns and df["trace_id"].nunique() > n:
            return df[df["trace_id"] < n].copy()
        return df

    cks = ["SFT", "Vanilla GSPO"]

    # Read, score and subsample every parquet exactly once per checkpoint. The
    # summary table and both MWU reports reuse these frames.
    frames = {}
    for ck in cks:
        # easier puzzles (pass@32, subsampled to target rollouts/doc)
        df_easy = _subsample(load(PUZZLE_FILES[ck]), target_rollouts_per_doc)
        # harder puzzles (single-shot test200 OR pass@8 re-eval, depending on mode)
        df_hard = _subsample(load(extra_files[ck]), target_rollouts_per_doc)
        frames[ck] = (df_easy, df_hard)

    rows = []
    for ck in cks:
        for df in frames[ck]:
            for task in df["task_name"].unique():
                if task not in label_for:
                    continue
                family, size = label_for[task]
                sub = df[df["task_name"] == task]
                cd = sub["chain_depth"]
                rows.append({
                    "checkpoint": ck,
                    "family": family,
                    "size": size,
                    "task": task,
                    "n_traces": len(sub),
                    "mean_chain_depth": float(cd.mean()),
                    "median": float(cd.median()),
                    "p90": float(cd.quantile(0.9)),
                    "frac_correct": float(sub["correct"].mean()),
                })

    out = pd.DataFrame(rows)
    out.to_csv(OUT_DIR / "analysis1_size_ladder.csv", index=False)

    def _report_pair(header, easy_task, hard_task, easy_label, hard_label):
        """Print a per-checkpoint MWU of easy vs hard task (matched rollouts/doc)."""
        print(header)
        for ck in cks:
            df_easy, df_hard = frames[ck]
            a = df_easy[df_easy["task_name"] == easy_task]["chain_depth"]
            b = df_hard[df_hard["task_name"] == hard_task]["chain_depth"]
            U, p = mannwhitneyu(a, b, alternative="two-sided")
            print(f"  {ck}: {easy_label} (mean={a.mean():.2f}, n={len(a)}) "
                  f"vs {hard_label} (mean={b.mean():.2f}, n={len(b)}) "
                  f"U={U:.0f} p={p:.3g}")

    _report_pair(
        f"\n[Bridges size ladder, {target_rollouts_per_doc} rollouts/doc]",
        "bridges_8x8de_pass32", bridges_harder, "8x8de", "10x10dh")
    _report_pair(
        f"\n[Undead difficulty pair, fixed size 5x5, {target_rollouts_per_doc} rollouts/doc]",
        "undead_5x5de_pass32", undead_harder, "5x5de", "5x5dm")

    # ---- Figure: 1x2 panel — Bridges size ladder, Undead difficulty pair ----
    fig, axes = plt.subplots(1, 2, figsize=(11, 4.4))
    width = 0.36
    xs = np.arange(len(cks))

    def _draw_panel(ax, family, variants, colors, edges, title):
        """Draw one grouped-bar panel: one bar per variant at each checkpoint."""
        panel_max = 0.0
        for i, variant in enumerate(variants):
            cells = [out[(out["checkpoint"] == c) & (out["family"] == family)
                         & (out["size"] == variant)] for c in cks]
            # .iloc[0] raises IndexError if a (checkpoint, variant) cell is
            # missing from the table, i.e. the task was absent from the data.
            means = [cell["mean_chain_depth"].iloc[0] for cell in cells]
            ns = [cell["n_traces"].iloc[0] for cell in cells]
            panel_max = max(panel_max, max(means))
            # Two bars per checkpoint, centred on the tick.
            offset = (i - 0.5) * width
            ax.bar(xs + offset, means, width, color=colors[i], edgecolor=edges[i],
                   linewidth=1.0, label=f"{family} {variant} (n={ns[0]}/{ns[1]})")
            for x, v in zip(xs + offset, means):
                ax.text(x, v + 0.4, f"{v:.1f}", ha="center", va="bottom", fontsize=9)
        ax.set_xticks(xs)
        ax.set_xticklabels(cks)
        ax.set_ylabel("Mean chain depth")
        ax.set_title(title, fontsize=10)
        # Headroom for the value labels and the legend.
        ax.set_ylim(0, panel_max + 2.5)
        ax.grid(True, axis="y", alpha=0.3)
        ax.legend(loc="upper left", fontsize=8)

    # Panel 1: Bridges size ladder
    _draw_panel(
        axes[0], "Bridges", ["8x8de", "10x10dh"],
        [PUZZLE_FILL["Bridges 8x8de"], "#90caf9"],  # within-family pastel pair
        [PUZZLE_EDGE["Bridges 8x8de"], "#0d47a1"],
        "Bridges size ladder: 8x8de vs 10x10dh")

    # Panel 2: Undead difficulty pair (size fixed)
    _draw_panel(
        axes[1], "Undead", ["5x5de", "5x5dm"],
        [PUZZLE_FILL["Undead 5x5de"], "#f48fb1"],
        [PUZZLE_EDGE["Undead 5x5de"], "#880e4f"],
        "Undead difficulty pair (size fixed): 5x5de vs 5x5dm")

    fig.suptitle("Within-puzzle size and difficulty ladders (same model, same family)",
                 fontsize=11)
    fig.tight_layout()
    out_fig = FIG_DIR / "fig_chain_depth_size_ladder.png"
    fig.savefig(out_fig, dpi=160, bbox_inches="tight")
    print(f"wrote {out_fig}")
    plt.close(fig)
    return out


# Script entry point: run the three analyses in order. Each call writes its CSV
# and figure, prints its own test results, and returns the summary table, which
# is then printed in full below its banner.
if __name__ == "__main__":
    print("=== Analysis 1': per-puzzle-type chain depth ===")
    a1 = analysis1()
    print(a1.to_string(index=False))

    print("\n=== Analysis 2': solved vs unsolved chain depth ===")
    a2 = analysis2()
    print(a2.to_string(index=False))

    # Mode (single-shot vs pass@8) is chosen inside the function from the
    # files present on disk.
    print("\n=== Analysis 1' extended: within-puzzle size + difficulty ladder ===")
    a1x = analysis1_size_ladder()
    print(a1x.to_string(index=False))
