"""Per-problem progression scatter plots: Base→SFT and SFT→Best RL (Novelty bonus).

Each panel: x = solve_rate at "from" stage, y = solve_rate at "to" stage.
Identity line shows where solve_rate is unchanged. Points above = unlocked,
below = lost. Vertical jitter separates overlapping integer-multiple solve
rates (e.g. many problems at 0/32 = 0%).

Reuses _cached_correct_lists from plot_puzzle_passk.py (cache + parallel).

Outputs:
- writing/neurips_paper/figures/fig_per_problem_progression.png  (Hard, 1×2 panels)
- writing/neurips_paper/figures/fig_per_problem_progression_easy.png  (Easy, 1×2 panels)
"""

import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
sys.path.insert(0, str(_PROJECT_ROOT / "scripts" / "analysis"))
from plot_puzzle_passk import _cached_correct_lists, _local_jsonl  # noqa: E402

# Figure destination: $FIG_OUT_DIR when it is set, otherwise the paper's
# figures folder under the project root.
OUT_DIR = Path(
    os.environ.get(
        "FIG_OUT_DIR",
        str(_PROJECT_ROOT.joinpath("writing", "neurips_paper", "figures")),
    )
)

# Pastel Material palette (checkpoint colors per CLAUDE.md).
# Maps stage name -> (light fill, dark accent); this module only reads the
# dark accent, index [1].
PALETTE = {
    "Base": ("#cfd8dc", "#607d8b"),
    "SFT": ("#c8e6c9", "#2e7d32"),
    "Vanilla GSPO": ("#bbdefb", "#1565c0"),
    "Novelty bonus": ("#f8bbd0", "#c2185b"),
}

# Per-benchmark sample locations, relative to results/ (matches
# plot_puzzle_passk.py). Each value is the (directory, filename glob) pair
# handed to per_doc_solve_rates. Directory names suggest Base was sampled
# 64 times per problem and SFT / Novelty bonus 32 times.
HARD_PATHS = {
    "Base": (
        Path("pass_at_k", "olmo3_base", "olymp_math_hard_pass64"),
        "samples_olymp_math_hard*.jsonl",
    ),
    "SFT": (
        Path(
            "sft_baseline_math_eval_diverse",
            "checkpoints__olmo3_7b_multi_puzzle_dsr_v2__merged_ep5_fp32",
        ),
        "samples_olymp_math_hard*.jsonl",
    ),
    "Novelty bonus": (
        Path("novelty_production_step15", "olymp_math_hard_pass32"),
        "samples_olymp_math_hard*.jsonl",
    ),
}
EASY_PATHS = {
    "Base": (
        Path("pass_at_k", "olmo3_base", "olymp_math_easy_pass64"),
        "samples_olymp_math_easy*.jsonl",
    ),
    "SFT": (
        Path("sft_v2_ep5", "olymp_math_easy_pass32"),
        "samples_olymp_math_easy*.jsonl",
    ),
    "Novelty bonus": (
        Path("novelty_production_step15", "olymp_math_easy_pass32_v2"),
        "samples_olymp_math_easy*.jsonl",
    ),
}


def per_doc_solve_rates(rel_path: Path, glob: str) -> np.ndarray:
    """Return one solve rate per problem for the samples at ``rel_path``/``glob``.

    ``_local_jsonl(rel_path, glob)`` resolves the sample location (presumably
    the JSONL files matching ``glob``), and ``_cached_correct_lists`` turns it
    into one row of per-rollout correctness values per problem. Each problem's
    rate is ``sum(row) / len(row)``, i.e. mean(correct) over its rollouts.
    A row with no rollouts raises ZeroDivisionError.
    """
    rows_by_problem = _cached_correct_lists(_local_jsonl(rel_path, glob))
    return np.fromiter(
        (sum(flags) / len(flags) for flags in rows_by_problem),
        dtype=float,
    )


def jitter(x: np.ndarray, amount: float = 0.012, seed: int = 0) -> np.ndarray:
    """Return ``x`` plus i.i.d. Gaussian noise with mean 0 and std ``amount``.

    Used to pull apart points that sit on exactly the same value. The noise
    comes from a fresh ``np.random.default_rng(seed)``, so the same ``seed``
    and the same shape of ``x`` always yield the same offsets. ``x`` itself
    is not modified. Any array-like (e.g. a plain list) is accepted, and
    ``amount=0`` returns an unperturbed copy.
    """
    # asarray lets lists/tuples through; ndarray input is passed unchanged.
    values = np.asarray(x)
    rng = np.random.default_rng(seed)
    return values + rng.normal(0.0, amount, size=values.shape)


def plot_one_transition(ax, from_rates, to_rates, from_label, to_label, *,
                        from_color, to_color, threshold=0.02,
                        jitter_amount=0.012):
    """Draw one "from stage -> to stage" per-problem scatter on ``ax``.

    Each point is one problem. x is its solve rate at the "from" stage, and
    y is its solve rate at the "to" stage plus a small vertical jitter.
    Points are bucketed by ``delta = to - from``:

    - ``delta > threshold`` is drawn in ``to_color`` as "unlocked";
    - ``delta < -threshold`` is drawn in a fixed red as "regressed";
    - everything else is drawn in grey as "≈unchanged".

    The legend shows the bucket sizes.

    Args:
        ax: Matplotlib Axes to draw on.
        from_rates: Per-problem solve rates at the "from" stage.
        to_rates: Per-problem solve rates at the "to" stage, aligned
            element-wise with ``from_rates``.
        from_label: Stage name used in the x-axis label.
        to_label: Stage name used in the y-axis label.
        from_color: Accepted for call-site symmetry; currently unused.
        to_color: Color of the "unlocked" points.
        threshold: A point counts as changed only if ``|delta|`` strictly
            exceeds this. Default 0.02 (the previously hard-coded value).
        jitter_amount: Std-dev of the vertical Gaussian jitter.

    Raises:
        ValueError: If ``from_rates`` and ``to_rates`` differ in shape.
    """
    from_rates = np.asarray(from_rates, dtype=float)
    to_rates = np.asarray(to_rates, dtype=float)
    if from_rates.shape != to_rates.shape:
        raise ValueError(
            "from_rates and to_rates must align per problem; got shapes "
            f"{from_rates.shape} and {to_rates.shape}"
        )

    # Identity line: points on it kept the same solve rate.
    ax.plot([0, 1], [0, 1], "--", color="#888", alpha=0.5, linewidth=1.0)

    # Jitter y only, so problems stacked at the same solve rate (e.g. 0.0 or
    # 1.0) separate. x stays at the exact "from" rate; the previous
    # zero-amount x jitter was a no-op.
    xj = from_rates
    yj = jitter(to_rates, amount=jitter_amount, seed=42)

    # Color points by transition direction (above identity = improved).
    # With the default 0.02 band, a 1/64 difference counts as unchanged,
    # while a 1/32 difference (0.03125) does not.
    delta = to_rates - from_rates
    above = delta > threshold
    below = delta < -threshold
    same = ~(above | below)

    ax.scatter(xj[same], yj[same], color="#bbb", s=14, alpha=0.55,
               edgecolor="none", zorder=2, label=f"≈unchanged (n={int(same.sum())})")
    ax.scatter(xj[below], yj[below], color="#cc6677", s=20, alpha=0.75,
               edgecolor="white", linewidth=0.4, zorder=3,
               label=f"regressed (n={int(below.sum())})")
    ax.scatter(xj[above], yj[above], color=to_color, s=22, alpha=0.85,
               edgecolor="white", linewidth=0.4, zorder=4,
               label=f"unlocked (n={int(above.sum())})")

    ax.set_xlabel(f"{from_label} solve rate", fontsize=10)
    ax.set_ylabel(f"{to_label} solve rate", fontsize=10)
    ax.set_xlim(-0.03, 1.03)
    ax.set_ylim(-0.03, 1.03)
    ax.set_aspect("equal")
    ax.grid(alpha=0.25)
    ax.legend(loc="upper left", fontsize=8, framealpha=0.92)


def plot_panel(paths: dict, title: str, out_path: Path):
    """Render the three per-problem transition scatters for one benchmark.

    Loads solve rates for the "Base", "SFT" and "Novelty bonus" entries of
    ``paths``. Each entry is a ``(rel_path, glob)`` pair passed to
    ``per_doc_solve_rates``. The function draws three panels side by side
    (1×3): Base -> SFT, SFT -> Novelty bonus, and Base -> Novelty bonus.
    It saves the figure to ``out_path`` at 150 dpi and prints a one-line
    summary. If the stages report different problem counts, it prints a
    warning to stderr and plots only the common prefix.
    """
    rates = {
        stage: per_doc_solve_rates(*paths[stage])
        for stage in ("Base", "SFT", "Novelty bonus")
    }

    # Problems are paired by position across stages.
    # NOTE(review): this assumes every stage lists problems in the same order.
    # Confirm against _cached_correct_lists. Rollout counts (64 vs 32) do not
    # affect these lengths. A count mismatch means the pairing may be off, so
    # report it instead of truncating silently.
    counts = {stage: len(stage_rates) for stage, stage_rates in rates.items()}
    n = min(counts.values())
    if len(set(counts.values())) > 1:
        print(f"WARNING: {title}: problem counts differ across stages "
              f"{counts}; plotting only the first {n}.", file=sys.stderr)
    rates = {stage: stage_rates[:n] for stage, stage_rates in rates.items()}

    # (from stage, to stage), one entry per panel, left to right.
    transitions = (
        ("Base", "SFT"),
        ("SFT", "Novelty bonus"),
        ("Base", "Novelty bonus"),
    )

    fig, axes = plt.subplots(1, len(transitions), figsize=(15.5, 5.2))
    try:
        fig.suptitle(title, fontsize=12, fontweight="bold")
        for ax, (src, dst) in zip(axes, transitions):
            plot_one_transition(ax, rates[src], rates[dst], src, dst,
                                from_color=PALETTE[src][1],
                                to_color=PALETTE[dst][1])
            ax.set_title(f"{src} → {dst}", fontsize=11)
        fig.tight_layout(rect=[0, 0, 1, 0.95])
        fig.savefig(out_path, dpi=150)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    print(f"Saved: {out_path}  (n={n} problems)")


def main():
    """Create OUT_DIR if needed, then render the Hard and Easy figures into it."""
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    # (benchmark paths, figure title, output filename), rendered in order.
    jobs = (
        (HARD_PATHS, "OlymMATH Hard: per-problem solve rate",
         "fig_per_problem_progression.png"),
        (EASY_PATHS, "OlymMATH Easy: per-problem solve rate",
         "fig_per_problem_progression_easy.png"),
    )
    for bench_paths, fig_title, filename in jobs:
        plot_panel(bench_paths, fig_title, OUT_DIR / filename)


if __name__ == "__main__":
    main()
