"""§6 size ladder — chain depth on solved traces only.

Combines:
- New pass@8 cells: bridges_7x7dm + undead_4x4de × {SFT v2 ep5, GSPO v2 s20}
  (parquet at results/exploration_analysis/v90_size_ladder_pass8/)
- Existing pass@32 cells: bridges_8x8de + undead_5x5de × {SFT, GSPO}
  (parquets at v90_sft_puzzles/, v90_gspo_puzzles/)

Filters to `correct=True` traces, computes chain depth distribution per
(checkpoint × puzzle × size), bootstrap 95% CIs on the mean.

Outputs:
- writing/neurips_paper/figures/fig_chain_depth_size_ladder_solved.png
- results/exploration_analysis/v90_size_ladder_pass8/solved_chain_depth_stats.csv
"""

import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Input root for the trace-level parquets, and output locations for the figure
# and the per-cell stats table.
ROOT = Path("results/exploration_analysis")
OUT_DIR = Path("writing/neurips_paper/figures")
STATS_OUT = ROOT / "v90_size_ladder_pass8/solved_chain_depth_stats.csv"

# Primitive labels consumed by chain_depth(): a COMPUTE_PHASE label extends the
# current chain by one, an INTERRUPT label resets it to zero. Any other label
# neither extends nor resets the chain.
COMPUTE_PHASE = {"COMPUTE", "CHECK", "SETUP"}
INTERRUPT = {"HYPOTHESIZE", "BACKTRACK", "ERROR_DETECT"}

# Pastel Material palette per CLAUDE.md
# Display name → (marker fill colour, line/marker edge colour).
PALETTE = {
    "SFT":          ("#c8e6c9", "#2e7d32"),
    "Vanilla GSPO": ("#bbdefb", "#1565c0"),
}

# (checkpoint display name, puzzle, size)
#   → (parquet path, task_name value in parquet, checkpoint_id value in parquet)
# The three value fields are passed positionally to load_cell().
CELLS = {
    # bridges
    ("SFT",          "bridges", "7x7dm"): (ROOT / "v90_size_ladder_pass8/trace_level_metrics.parquet", "bridges_7x7dm_pass8", "dsr_sft_v2"),
    ("SFT",          "bridges", "8x8de"): (ROOT / "v90_sft_puzzles/trace_level_metrics.parquet",      "bridges_8x8de_pass32", "dsr_sft_v2"),
    ("Vanilla GSPO", "bridges", "7x7dm"): (ROOT / "v90_size_ladder_pass8/trace_level_metrics.parquet", "bridges_7x7dm_pass8", "gspo_v2_sft_step20"),
    ("Vanilla GSPO", "bridges", "8x8de"): (ROOT / "v90_gspo_puzzles/trace_level_metrics.parquet",     "bridges_8x8de_pass32", "gspo_v2_sft_step20"),
    # undead
    ("SFT",          "undead", "4x4de"):  (ROOT / "v90_size_ladder_pass8/trace_level_metrics.parquet", "undead_4x4de_pass8",  "dsr_sft_v2"),
    ("SFT",          "undead", "5x5de"):  (ROOT / "v90_sft_puzzles/trace_level_metrics.parquet",      "undead_5x5de_pass32", "dsr_sft_v2"),
    ("Vanilla GSPO", "undead", "4x4de"):  (ROOT / "v90_size_ladder_pass8/trace_level_metrics.parquet", "undead_4x4de_pass8",  "gspo_v2_sft_step20"),
    ("Vanilla GSPO", "undead", "5x5de"):  (ROOT / "v90_gspo_puzzles/trace_level_metrics.parquet",     "undead_5x5de_pass32", "gspo_v2_sft_step20"),
}


def chain_depth(seq):
    """Return the longest run of compute-phase primitives in ``seq``.

    ``seq`` may be a list / NumPy array of primitive labels, or a JSON string
    encoding one. A label in ``COMPUTE_PHASE`` lengthens the current run by
    one; a label in ``INTERRUPT`` resets the run to zero; any other label
    leaves the run untouched. The result is the largest run length reached.

    Returns 0 when the JSON cannot be decoded, when the (decoded) value is not
    a list or NumPy array, or when the sequence is empty.
    """
    if isinstance(seq, str):
        try:
            seq = json.loads(seq)
        except (json.JSONDecodeError, TypeError):
            return 0

    # Anything that is not a list/array after decoding counts as "no chain".
    if not isinstance(seq, (list, np.ndarray)):
        return 0
    if len(seq) == 0:
        return 0

    longest = 0
    run = 0
    for primitive in seq:
        if primitive in COMPUTE_PHASE:
            run += 1
            longest = max(longest, run)
        elif primitive in INTERRUPT:
            run = 0
    return longest


def bootstrap_ci(values: np.ndarray, n_boot: int = 5000, seed: int = 42, alpha: float = 0.05) -> tuple[float, float]:
    """Percentile-bootstrap confidence interval for the mean of ``values``.

    Draws ``n_boot`` resamples of ``len(values)`` items with replacement from
    a generator seeded with ``seed``, takes the mean of each resample, and
    returns the ``alpha / 2`` and ``1 - alpha / 2`` quantiles of those means
    as ``(lower, upper)``. Returns ``(nan, nan)`` for empty input.
    """
    sample_size = len(values)
    if sample_size == 0:
        return (np.nan, np.nan)

    generator = np.random.default_rng(seed)
    # One choice() call per replicate, so the random stream (and therefore the
    # reported interval) is reproducible for a given seed.
    resampled_means = np.fromiter(
        (
            np.mean(generator.choice(values, size=sample_size, replace=True))
            for _ in range(n_boot)
        ),
        dtype=float,
        count=n_boot,
    )

    lower = np.quantile(resampled_means, alpha / 2)
    upper = np.quantile(resampled_means, 1 - alpha / 2)
    return (lower, upper)


def load_cell(parq: Path, task: str, ckpt_id: str) -> np.ndarray:
    """Return chain depths for solved-only traces in this (parquet × task × ckpt) cell.

    Args:
        parq: Path of the trace-level metrics parquet to read.
        task: Value matched against the ``task_name`` column.
        ckpt_id: Value matched against the ``checkpoint_id`` column.

    Returns:
        Array with one ``chain_depth`` value per row whose ``task_name`` and
        ``checkpoint_id`` match and whose ``correct`` column equals ``True``.
        The array is empty when no row matches.
    """
    # Only these four columns are used, so project at read time instead of
    # materialising every column of the trace-level table.
    needed = ["task_name", "checkpoint_id", "correct", "primitive_sequence"]
    df = pd.read_parquet(parq, columns=needed)
    solved = (
        (df["task_name"] == task)
        & (df["checkpoint_id"] == ckpt_id)
        & df["correct"].eq(True)
    )
    return df.loc[solved, "primitive_sequence"].apply(chain_depth).to_numpy()


# Column order of the per-cell stats table. Fixed up front so the CSV keeps its
# header, and the plot can still select columns, when no cell has any solved
# trace.
_STATS_COLUMNS = ["checkpoint", "puzzle", "size", "n_solved",
                  "mean", "median", "p90", "ci_lo", "ci_hi"]


def _collect_stats() -> pd.DataFrame:
    """Build one summary row per non-empty cell in CELLS, printing each as it is computed."""
    rows = []
    for (ckpt, puzzle, size), (parq, task, ckpt_id) in CELLS.items():
        depths = load_cell(parq, task, ckpt_id)
        n = len(depths)
        if n == 0:
            # No solved traces: report it and leave the cell out of the table.
            print(f"  [empty] {ckpt} × {puzzle}_{size}")
            continue
        mean = float(np.mean(depths))
        median = float(np.median(depths))
        # p90 is only reported when there are at least 10 solved traces.
        p90 = float(np.quantile(depths, 0.9)) if n >= 10 else np.nan
        ci_lo, ci_hi = bootstrap_ci(depths)
        rows.append({"checkpoint": ckpt, "puzzle": puzzle, "size": size,
                     "n_solved": n, "mean": mean, "median": median, "p90": p90,
                     "ci_lo": ci_lo, "ci_hi": ci_hi})
        print(f"  {ckpt:<14} × {puzzle}_{size}  n_solved={n:>4}  "
              f"mean={mean:>5.2f} (95% CI [{ci_lo:.2f}, {ci_hi:.2f}])  "
              f"median={median:>4.0f}  p90={p90:>4.0f}")
    return pd.DataFrame(rows, columns=_STATS_COLUMNS)


def _plot_size_ladder(stats_df: pd.DataFrame, out: Path) -> None:
    """Draw the 1×2 (Bridges, Undead) mean-chain-depth figure and save it to ``out``."""
    size_order = {"bridges": ["7x7dm", "8x8de"], "undead": ["4x4de", "5x5de"]}
    fig, axes = plt.subplots(1, 2, figsize=(11, 4.5))
    fig.suptitle("Chain depth on solved traces — size ladder",
                 fontsize=12, fontweight="bold")

    for ax, puzzle in zip(axes, ["bridges", "undead"]):
        sizes = size_order[puzzle]
        x = np.arange(len(sizes))
        for ckpt in ["SFT", "Vanilla GSPO"]:
            fill, edge = PALETTE[ckpt]
            sub = stats_df[(stats_df["checkpoint"] == ckpt) & (stats_df["puzzle"] == puzzle)]
            # Align to the full size ladder; cells skipped as empty become NaN rows.
            sub = sub.set_index("size").reindex(sizes)
            means = sub["mean"].to_numpy(dtype=float)
            lo = sub["ci_lo"].to_numpy(dtype=float)
            hi = sub["ci_hi"].to_numpy(dtype=float)
            counts = sub["n_solved"].to_numpy(dtype=float)
            if np.isnan(means).all():
                # Nothing to draw for this checkpoint on this puzzle.
                continue
            # errorbar() rejects negative lengths; clamp so float round-off on a
            # degenerate interval cannot abort the figure. NaNs pass through.
            yerr = np.clip(np.array([means - lo, hi - means]), 0.0, None)
            ax.errorbar(x, means, yerr=yerr, fmt="o-", color=edge, mfc=fill,
                        mec=edge, markersize=6, linewidth=1.5, capsize=4,
                        label=f"{ckpt}")
            for xi, m, n in zip(x, means, counts):
                if np.isnan(n):
                    # Missing cell: int(NaN) would raise, and there is no point to label.
                    continue
                ax.annotate(f"n={int(n)}", (xi, m), xytext=(6, -10),
                            textcoords="offset points", fontsize=8, color=edge)

        ax.set_xticks(x)
        ax.set_xticklabels([f"{puzzle}_{s}" for s in sizes])
        ax.set_xlabel("Grid size", fontsize=10)
        ax.set_ylabel("Mean chain depth (solved traces, 95% CI)", fontsize=10)
        ax.set_title(f"{puzzle.title()} ladder", fontsize=11)
        ax.grid(alpha=0.3)
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend(loc="upper left", fontsize=9, framealpha=0.95)

    fig.tight_layout(rect=[0, 0, 1, 0.94])
    fig.savefig(out, dpi=150)
    plt.close(fig)


def main():
    """Compute per-cell solved-trace chain-depth stats, write the CSV, and save the figure.

    Cells with no solved traces are reported as ``[empty]`` and omitted from
    the CSV; in the figure they appear as gaps rather than aborting the run.
    """
    OUT_DIR.mkdir(exist_ok=True, parents=True)
    STATS_OUT.parent.mkdir(exist_ok=True, parents=True)

    stats_df = _collect_stats()
    stats_df.to_csv(STATS_OUT, index=False)
    print(f"\nSaved stats: {STATS_OUT}")

    # ---- Plot: 1×2 panels (Bridges, Undead) ----
    out = OUT_DIR / "fig_chain_depth_size_ladder_solved.png"
    _plot_size_ladder(stats_df, out)
    print(f"Saved: {out}")


if __name__ == "__main__":
    main()
