"""§6 and §7 figures for the NeurIPS 2026 paper.

Produces:
  1. fig_chain_depth.png — 4-checkpoint chain-depth box plot (Base, SFT, GSPO, Novelty bonus)
  2. fig_cc_loop_vs_passk.png — 4-point scatter of CC-loop vs OlymMATH-Hard pass@32
  3. fig_exploration_primitive_counts.png — raw HYPOTHESIZE / BACKTRACK counts
     per trace at Base, SFT, GSPO, novelty bonus (no grouping, no ratio)
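  4. fig_exploit_primitive_counts.png — COMPUTE / CHECK / SETUP per trace at the
     four main checkpoints; primitive layer for §4.2 (paired with motif fig in
     a 1:3 subfigure block).
  5. fig_exploration_primitive_counts_puzzles.png — HYPOTHESIZE / BACKTRACK per
     trace at SFT and vanilla GSPO on the puzzle distribution (§4.3).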

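Source data: trace_level_metrics.parquet from
results/exploration_analysis/v90_{base,sft,gspo,prod_s15,mini}_math/.
The script itself reads the parquet files listed in PATHS and PATHS_PUZZLES
(under /tmp/v90_csvs/).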

Pass@32 numbers (post OlymMATH scorer fix, commit d3d7789f):
  Base 16.1%, SFT 23.0%, vanilla GSPO 30.0%, novelty bonus 36.0%
  Source: reports/olymp_math_rescored_results.md
"""

import json
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Primitive labels treated as exploitation steps; chain_depth() measures the
# longest consecutive run of these within one trace.
EXPLOIT = {"COMPUTE", "CHECK", "SETUP"}

# Trace-level parquet per checkpoint (math). Keys are the display names used
# on every figure and must also be keys of PALETTE / EDGE below.
PATHS = {
    "Base": "/tmp/v90_csvs/v90_base_traces.parquet",
    "SFT": "/tmp/v90_csvs/v90_sft_traces.parquet",
    "Vanilla GSPO": "/tmp/v90_csvs/v90_gspo_traces.parquet",
    "Novelty bonus": "/tmp/v90_csvs/v90_prod_s15_traces.parquet",
    "Novelty Mini": "/tmp/v90_csvs/v90_mini_traces.parquet",
}

# Puzzle-distribution trace files; only SFT and vanilla GSPO have them.
PATHS_PUZZLES = {
    "SFT": "/tmp/v90_csvs/v90_sft_puzzles.parquet",
    "Vanilla GSPO": "/tmp/v90_csvs/v90_gspo_puzzles.parquet",
}

# OlymMATH-Hard pass@32 in percent, post OlymMATH scorer fix (commit d3d7789f);
# source: reports/olymp_math_rescored_results.md.
PASSK_HARD = {
    "Base": 16.1,
    "SFT": 23.0,
    "Vanilla GSPO": 30.0,
    "Novelty bonus": 36.0,
}

# Pastel palette for the 4 main checkpoints + Novelty Mini.
# Material Design 100-level shades; dark counterparts are used for edges/emphasis.
# IMPORTANT: keep this in sync with CLAUDE.md "Figure colour scheme" section.
PALETTE = {
    "Base": "#cfd8dc",              # blue-grey 100
    "SFT": "#c8e6c9",               # green 100
    "Vanilla GSPO": "#bbdefb",      # blue 100
    "Novelty bonus": "#f8bbd0",     # pink 100
    "Novelty Mini": "#ffe0b2",      # orange 100
}
# Darker edge colour per checkpoint, paired with the PALETTE fill.
EDGE = {
    "Base": "#607d8b",
    "SFT": "#2e7d32",
    "Vanilla GSPO": "#1565c0",
    "Novelty bonus": "#c2185b",
    "Novelty Mini": "#ef6c00",
}

# Output directory for every figure; override with the FIG_OUT_DIR environment
# variable. NOTE: the directory is created as a side effect of importing this module.
OUT_DIR = Path(os.environ.get("FIG_OUT_DIR", "writing/neurips_paper/figures"))
OUT_DIR.mkdir(parents=True, exist_ok=True)


def parse_seq(s):
    """Decode one ``primitive_sequence`` cell into a list of primitive labels.

    Accepts a JSON-encoded ``str``/``bytes`` (e.g. ``'["SETUP", "COMPUTE"]'``)
    or an already-decoded list, tuple or NumPy array. Missing values (``None``,
    NaN), malformed JSON, and JSON that does not decode to a list all yield
    ``[]``, so a single bad row cannot crash the per-trace metrics.

    Args:
        s: The raw cell value.

    Returns:
        A (possibly empty) list of primitive labels.
    """
    if s is None:
        return []
    # If the column is ever stored as a list-like (e.g. a Parquet list column),
    # json.loads would reject it with TypeError and silently drop the sequence.
    if isinstance(s, (list, tuple)):
        return list(s)
    if isinstance(s, np.ndarray):
        return s.tolist()
    try:
        seq = json.loads(s)
    except (ValueError, TypeError):
        # ValueError covers JSONDecodeError and UnicodeDecodeError on bad bytes;
        # TypeError covers non-string scalars such as a NaN float.
        return []
    # "null", a bare JSON string, or a JSON object is not a primitive sequence;
    # returning it would crash or mislead chain_depth / cc_loop_count.
    return seq if isinstance(seq, list) else []


def chain_depth(seq):
    """Return the length of the longest unbroken run of EXPLOIT primitives.

    Any primitive not in ``EXPLOIT`` resets the current run. A sequence that is
    empty, or contains no exploit primitives, has depth 0.
    """
    longest = 0
    current = 0
    for prim in seq:
        current = current + 1 if prim in EXPLOIT else 0
        if current > longest:
            longest = current
    return longest


def cc_loop_count(seq):
    """Count CC-loops: positions where COMPUTE is immediately followed by CHECK.

    Every adjacent pair in ``seq`` is examined, so
    ``["COMPUTE", "CHECK", "COMPUTE", "CHECK"]`` contains 2 loops. Sequences
    with fewer than two elements contain none.

    Args:
        seq: A sliceable sequence of primitive labels.

    Returns:
        The number of adjacent ``("COMPUTE", "CHECK")`` pairs.
    """
    return sum(1 for cur, nxt in zip(seq, seq[1:]) if cur == "COMPUTE" and nxt == "CHECK")


def load(name):
    """Read the trace-level parquet for checkpoint *name* and add derived metrics.

    Three columns are derived from ``primitive_sequence``:
      * ``_chain_depth`` — longest run of exploit primitives (``chain_depth``);
      * ``_cc_count`` — number of COMPUTE→CHECK pairs (``cc_loop_count``);
      * ``_cc_per_1k`` — ``_cc_count`` per 1000 ``total_tokens``. Zero-token
        rows become NaN rather than dividing by zero.

    The per-primitive ``<PRIM>_count`` columns used by the figures are taken
    from the file as-is.
    """
    raw = pd.read_parquet(PATHS[name])
    parsed = raw["primitive_sequence"].apply(parse_seq)
    depth = parsed.apply(chain_depth)
    loops = parsed.apply(cc_loop_count)
    tokens = raw["total_tokens"].replace(0, np.nan)
    return raw.assign(
        _chain_depth=depth,
        _cc_count=loops,
        _cc_per_1k=(loops / tokens) * 1000.0,
    )


# ============================================================================
# Figure 1: chain depth box plot, 4 checkpoints
# ============================================================================
def fig_chain_depth():
    """Box plot of per-trace chain depth at the four main checkpoints.

    Boxes use Matplotlib's default whiskers with outliers hidden. The mean
    (diamond) and 90th percentile (triangle) are overlaid and annotated.
    Writes ``fig_chain_depth.png`` to ``OUT_DIR``.

    Returns:
        ``{"means": {ckpt: mean}, "p90s": {ckpt: p90}}``.
    """
    ckpts = ["Base", "SFT", "Vanilla GSPO", "Novelty bonus"]
    data = [load(c)["_chain_depth"].to_numpy() for c in ckpts]

    fig, ax = plt.subplots(figsize=(7, 4.2))
    bp = ax.boxplot(
        data,
        widths=0.55,
        patch_artist=True,
        showfliers=False,
        medianprops=dict(color="black", linewidth=1.5),
        whiskerprops=dict(color="#444"),
        capprops=dict(color="#444"),
    )
    # boxplot places the boxes (and fixed ticks) at x = 1..n. Label the ticks
    # here instead of passing ``labels=``: Matplotlib 3.9 deprecated that in
    # favour of ``tick_labels=``, which older releases do not accept.
    ax.set_xticklabels(ckpts)
    for patch, c in zip(bp["boxes"], ckpts):
        patch.set_facecolor(PALETTE[c])
        patch.set_edgecolor(EDGE[c])
        patch.set_linewidth(1.0)

    # overlay mean and p90
    means = [float(np.mean(d)) for d in data]
    p90s = [float(np.percentile(d, 90)) for d in data]
    xs = np.arange(1, len(ckpts) + 1)
    ax.scatter(xs, means, color="black", marker="D", s=42, zorder=5, label="mean")
    ax.scatter(xs, p90s, color="black", marker="^", s=46, zorder=5, label="p90")
    for x, m, p in zip(xs, means, p90s):
        ax.annotate(f"{m:.1f}", (x, m), xytext=(8, 0), textcoords="offset points",
                    fontsize=9, va="center")
        ax.annotate(f"p90={p:.0f}", (x, p), xytext=(8, 0), textcoords="offset points",
                    fontsize=8, va="center", color="#333")

    ax.set_ylabel("Chain depth")
    ax.legend(loc="upper left", fontsize=9)
    ax.grid(True, axis="y", alpha=0.3)
    fig.tight_layout()
    out = OUT_DIR / "fig_chain_depth.png"
    fig.savefig(out, dpi=160, bbox_inches="tight")
    print(f"wrote {out}")
    plt.close(fig)
    return {"means": dict(zip(ckpts, means)), "p90s": dict(zip(ckpts, p90s))}


# ============================================================================
# Figure 2: CC-loop density vs OlymMATH-Hard pass@32 (4 labeled points)
# ============================================================================
def fig_cc_loop_vs_passk():
    """Scatter CC-loop density against OlymMATH-Hard pass@32 for 4 checkpoints.

    A least-squares line is fitted through the on-trend checkpoints (Base, SFT,
    vanilla GSPO) and drawn across the plotted x-range. The novelty-bonus point
    is drawn as a star, with an arrow and a signed label showing its vertical
    gap from that line. Writes ``fig_cc_loop_vs_passk.png`` to ``OUT_DIR``.

    Returns:
        ``{"cc_means": {ckpt: mean CC/1k}, "passk_hard": {ckpt: pass@32 %}}``.
    """
    ckpts = ["Base", "SFT", "Vanilla GSPO", "Novelty bonus"]
    # Select trend points by name, so reordering ``ckpts`` cannot silently
    # change which checkpoints define the line.
    on_trend = ["Base", "SFT", "Vanilla GSPO"]
    highlight = "Novelty bonus"

    # Mean over rollouts; pandas skips the NaNs load() emits for 0-token traces.
    cc_means = [float(load(c)["_cc_per_1k"].mean()) for c in ckpts]
    passk = [PASSK_HARD[c] for c in ckpts]

    fig, ax = plt.subplots(figsize=(6.5, 4.5))
    fit = None  # (slope, intercept) of the on-trend line, once fitted
    on_idx = [ckpts.index(c) for c in on_trend if c in ckpts]
    if len(on_idx) >= 2:
        xs_on = np.array([cc_means[i] for i in on_idx])
        ys_on = np.array([passk[i] for i in on_idx])
        fit = np.polyfit(xs_on, ys_on, 1)
        slope, intercept = fit
        xline = np.linspace(min(cc_means) - 0.05, max(cc_means) + 0.05, 100)
        ax.plot(xline, slope * xline + intercept, color="#888", linestyle="--", lw=1.2,
                alpha=0.8, label="Base→SFT→GSPO trend (depth raises ceiling)")

    for i, c in enumerate(ckpts):
        is_star = c == highlight
        ax.scatter(cc_means[i], passk[i], color=PALETTE[c],
                   s=280 if is_star else 130, marker="*" if is_star else "o",
                   edgecolor=EDGE[c], linewidth=1.2, zorder=4)
        ax.annotate(c, (cc_means[i], passk[i]),
                    xytext=(-6, -16) if is_star else (8, 6), textcoords="offset points",
                    fontsize=10, fontweight="bold", ha="right" if is_star else "left")

    # Arrow from the extrapolated trend to the highlighted point. The label uses
    # a signed gap, so a point below the line is not rendered "+-x.x pp above".
    if fit is not None and highlight in ckpts:
        slope, intercept = fit
        j = ckpts.index(highlight)
        prod_x, prod_y = cc_means[j], passk[j]
        trend_y_at_prod = slope * prod_x + intercept
        gap = prod_y - trend_y_at_prod
        side = "above" if gap >= 0 else "below"
        ax.annotate("",
                    xy=(prod_x, prod_y),
                    xytext=(prod_x, trend_y_at_prod),
                    arrowprops=dict(arrowstyle="->", color=EDGE[highlight], lw=1.6))
        ax.text(prod_x + 0.01, (prod_y + trend_y_at_prod) / 2,
                f"{gap:+.1f}pp\n{side} trend",
                fontsize=9, color=EDGE[highlight], va="center")

    ax.set_xlabel("CC-loop density (COMPUTE→CHECK per 1k tokens, mean over rollouts)")
    ax.set_ylabel("OlymMATH-Hard pass@32 (%)")
    ax.set_title("Vanilla RL trades depth for ceiling; novelty bonus sits above the trend",
                 fontsize=10)
    ax.grid(True, alpha=0.3)
    if fit is not None:
        # The trend line is the only labelled artist; without it legend() warns.
        ax.legend(loc="lower right", fontsize=9)
    fig.tight_layout()
    out = OUT_DIR / "fig_cc_loop_vs_passk.png"
    fig.savefig(out, dpi=160, bbox_inches="tight")
    print(f"wrote {out}")
    plt.close(fig)
    return {"cc_means": dict(zip(ckpts, cc_means)),
            "passk_hard": dict(zip(ckpts, passk))}


# ============================================================================
# Figure 3: HYPOTHESIZE + BACKTRACK raw counts (Base, SFT, GSPO, Novelty bonus)
# ============================================================================
def fig_exploration_primitive_counts():
    """Paired bars of mean HYPOTHESIZE and BACKTRACK counts per trace.

    There is one x group per main checkpoint (Base, SFT, vanilla GSPO, novelty
    bonus), and each bar is annotated with its value. Writes
    ``fig_exploration_primitive_counts.png`` to ``OUT_DIR``.

    Returns:
        ``{"hyp": {ckpt: mean}, "btk": {ckpt: mean}}``.
    """
    ckpts = ["Base", "SFT", "Vanilla GSPO", "Novelty bonus"]
    columns = ("HYPOTHESIZE_count", "BACKTRACK_count")

    # One [hyp, btk] row per checkpoint, loading each parquet once.
    per_ckpt = []
    for ckpt in ckpts:
        traces = load(ckpt)
        per_ckpt.append([float(traces[col].mean()) for col in columns])
    hyp_means = [row[0] for row in per_ckpt]
    btk_means = [row[1] for row in per_ckpt]

    centers = np.arange(len(ckpts))
    bar_w = 0.38
    # (legend label, values, x shift, face colour, edge colour)
    series = (
        ("HYPOTHESIZE", hyp_means, -bar_w / 2, "#d1c4e9", "#5e35b1"),
        ("BACKTRACK", btk_means, bar_w / 2, "#ffccbc", "#d84315"),
    )

    fig, ax = plt.subplots(figsize=(7, 4.2))
    drawn = []
    for label, vals, shift, face, edge in series:
        bars = ax.bar(centers + shift, vals, bar_w, label=label,
                      color=face, edgecolor=edge, linewidth=0.8)
        drawn.append((bars, vals))

    for bars, vals in drawn:
        for bar, v in zip(bars, vals):
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.04,
                    f"{v:.1f}", ha="center", va="bottom", fontsize=9)

    ax.set_xticks(centers)
    ax.set_xticklabels(ckpts, fontsize=10)
    ax.set_ylabel("Mean count per trace")
    ax.set_title("Exploratory primitive counts across recipes", fontsize=10)
    ax.set_ylim(0, max(max(hyp_means), max(btk_means)) + 1)
    ax.legend(loc="upper right", fontsize=9)
    ax.grid(True, axis="y", alpha=0.3)
    fig.tight_layout()
    out = OUT_DIR / "fig_exploration_primitive_counts.png"
    fig.savefig(out, dpi=160, bbox_inches="tight")
    print(f"wrote {out}")
    plt.close(fig)
    return {"hyp": dict(zip(ckpts, hyp_means)),
            "btk": dict(zip(ckpts, btk_means))}


def fig_exploit_primitive_counts():
    """Grouped bars of mean COMPUTE / CHECK / SETUP counts per trace (§4.2).

    The x groups are primitives. Each group holds one bar per main checkpoint
    (Base, SFT, vanilla GSPO, novelty bonus), coloured with the shared
    PALETTE / EDGE scheme and annotated with its value. This is the exploit
    counterpart of the HYPOTHESIZE / BACKTRACK figure. Writes
    ``fig_exploit_primitive_counts.png`` to ``OUT_DIR``.

    Returns:
        ``{primitive: {ckpt: mean count per trace}}``.
    """
    ckpts = ["Base", "SFT", "Vanilla GSPO", "Novelty bonus"]
    prims = ["COMPUTE", "CHECK", "SETUP"]

    # rows[i][k] is the mean count of prims[k] per trace at ckpts[i].
    rows = []
    for ckpt in ckpts:
        traces = load(ckpt)
        rows.append([float(traces[f"{p}_count"].mean()) for p in prims])
    by_prim = {p: [row[k] for row in rows] for k, p in enumerate(prims)}

    n_ckpt = len(ckpts)
    bar_w = 0.8 / n_ckpt
    group_x = np.arange(len(prims))

    fig, ax = plt.subplots(figsize=(7, 4.2))
    for i, (ckpt, row) in enumerate(zip(ckpts, rows)):
        # Spread the n_ckpt bars of every group symmetrically around its tick.
        shift = (i - (n_ckpt - 1) / 2) * bar_w
        bars = ax.bar(group_x + shift, row, bar_w, label=ckpt,
                      color=PALETTE[ckpt], edgecolor=EDGE[ckpt], linewidth=0.8)
        for bar, v in zip(bars, row):
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.08,
                    f"{v:.1f}", ha="center", va="bottom", fontsize=8)

    ax.set_xticks(group_x)
    ax.set_xticklabels(prims, fontsize=10)
    ax.set_ylabel("Mean count per trace")
    ax.set_title("Vanilla RL amplifies exploitation primitives on math",
                 fontsize=10)
    top = max(v for p in prims for v in by_prim[p])
    ax.set_ylim(0, top + 1)
    ax.legend(loc="upper right", fontsize=9)
    ax.grid(True, axis="y", alpha=0.3)
    fig.tight_layout()
    out = OUT_DIR / "fig_exploit_primitive_counts.png"
    fig.savefig(out, dpi=160, bbox_inches="tight")
    print(f"wrote {out}")
    plt.close(fig)
    return {p: dict(zip(ckpts, by_prim[p])) for p in prims}


def fig_exploration_primitive_counts_puzzles():
    """§4.3 puzzle panel: mean HYPOTHESIZE / BACKTRACK per trace on puzzles.

    Uses the same colours and paired-bar layout as the math exploration
    figure, but only for SFT and vanilla GSPO, the only checkpoints with puzzle
    parquets in PATHS_PUZZLES. Values are annotated to two decimals. Writes
    ``fig_exploration_primitive_counts_puzzles.png`` to ``OUT_DIR``.

    Returns:
        ``{"hyp": {ckpt: mean}, "btk": {ckpt: mean}}``.
    """
    ckpts = ["SFT", "Vanilla GSPO"]

    stats = []
    for ckpt in ckpts:
        traces = pd.read_parquet(PATHS_PUZZLES[ckpt])
        stats.append((float(traces["HYPOTHESIZE_count"].mean()),
                      float(traces["BACKTRACK_count"].mean())))
    hyp_means = [hyp for hyp, _ in stats]
    btk_means = [btk for _, btk in stats]

    centers = np.arange(len(ckpts))
    half = 0.38 / 2

    fig, ax = plt.subplots(figsize=(4.0, 4.2))
    hyp_bars = ax.bar(centers - half, hyp_means, 0.38, label="HYPOTHESIZE",
                      color="#d1c4e9", edgecolor="#5e35b1", linewidth=0.8)
    btk_bars = ax.bar(centers + half, btk_means, 0.38, label="BACKTRACK",
                      color="#ffccbc", edgecolor="#d84315", linewidth=0.8)

    for bars, vals in ((hyp_bars, hyp_means), (btk_bars, btk_means)):
        for bar, v in zip(bars, vals):
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.1,
                    f"{v:.2f}", ha="center", va="bottom", fontsize=9)

    ax.set_xticks(centers)
    ax.set_xticklabels(ckpts, fontsize=10)
    ax.set_ylabel("Mean count per trace")
    ax.set_title("Puzzles (Bridges 8x8 + Undead 5x5)", fontsize=10)
    ax.set_ylim(0, max(max(hyp_means), max(btk_means)) + 1)
    ax.legend(loc="upper right", fontsize=9)
    ax.grid(True, axis="y", alpha=0.3)
    fig.tight_layout()
    out = OUT_DIR / "fig_exploration_primitive_counts_puzzles.png"
    fig.savefig(out, dpi=160, bbox_inches="tight")
    print(f"wrote {out}")
    plt.close(fig)
    return {"hyp": dict(zip(ckpts, hyp_means)),
            "btk": dict(zip(ckpts, btk_means))}


if __name__ == "__main__":
    # (banner, figure builder) in output order. Each builder writes its PNG and
    # returns a dict summarising the plotted numbers, which is echoed after it.
    _JOBS = (
        ("=== Figure 1: chain depth ===", fig_chain_depth),
        ("\n=== Figure 2: CC-loop vs pass@32 ===", fig_cc_loop_vs_passk),
        ("\n=== Figure 3: exploration primitive counts ===",
         fig_exploration_primitive_counts),
        ("\n=== Figure 4: exploit primitive counts (§4.2) ===",
         fig_exploit_primitive_counts),
        ("\n=== Figure 5: exploration primitive counts on puzzles (§4.3) ===",
         fig_exploration_primitive_counts_puzzles),
    )
    for _banner, _build in _JOBS:
        print(_banner)
        print(_build())
