"""Plot pass@k curves for the §3 figures: bridges_8x8de, undead_5x5de,
pattern_5x5de, OlymMATH Easy, OlymMATH Hard. 4 lines: Base / SFT / Vanilla
GSPO / Novelty bonus.

Outputs (overwrite paper figs in `writing/neurips_paper/figures/`):
- fig_passk_bridges.png
- fig_passk_undead.png
- fig_passk_pattern.png
- fig_main_passk_easy.png
- fig_main_passk_hard.png

Pastel Material palette per CLAUDE.md.
"""

import json
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd

# Presumably the repository root: three `.parent` hops up from this file.
# `results/` (ROOT) and the default figure dir (OUT_DIR) are resolved from it.
_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
# Put scripts/evals first on the import path so the pass@k scorer used by the
# eval scripts is reused here rather than re-implemented.
sys.path.insert(0, str(_PROJECT_ROOT / "scripts" / "evals"))
from compute_pass_at_k import load_and_score, pass_at_k  # noqa: E402


def _cached_correct_lists(jsonl_path: Path, num_workers: int = 32) -> list[list[bool]]:
    """Return per-doc correctness lists for `jsonl_path`, scoring it at most once.

    `load_and_score` runs math_verify, which times out on hard problems and is
    the plot-script bottleneck. Its result is therefore cached next to the
    JSONL as `<stem>.scored.json`, and re-runs read the cache instead.

    The cache is reused only when it is at least as new as the JSONL, or when
    the JSONL is absent (so a cache-only checkout can still plot). A JSONL
    modified after the cache was written is re-scored instead of being
    silently plotted with stale results. The cache is written to a temp file
    and renamed into place, so an interrupted run cannot leave a truncated
    cache behind.

    Args:
        jsonl_path: eval samples JSONL to score.
        num_workers: parallelism passed to `load_and_score`; only used when
            scoring (num_workers > 1 parallelises the first-run scoring).

    Returns:
        One inner list of bools per doc, as produced by `load_and_score`
        (presumably one entry per sampled trace).
    """
    cache = jsonl_path.with_suffix(".scored.json")
    cache_is_fresh = cache.exists() and (
        not jsonl_path.exists()
        or cache.stat().st_mtime >= jsonl_path.stat().st_mtime
    )
    if cache_is_fresh:
        with open(cache) as f:
            return json.load(f)
    correct_per_doc, _, _, _ = load_and_score(str(jsonl_path), num_workers=num_workers)
    correct_per_doc = [[bool(c) for c in row] for row in correct_per_doc]
    # Write-then-rename: readers see either the old cache or the complete new
    # one, never a partially written file.
    tmp = cache.with_name(cache.name + ".tmp")
    with open(tmp, "w") as f:
        json.dump(correct_per_doc, f)
    os.replace(tmp, cache)
    return correct_per_doc

# Sample budgets at which pass@k is reported; also the x-axis ticks of every plot.
K_VALUES = [1, 4, 8, 16, 32]
# Docs kept from the front of the Bridges Novelty-bonus JSONL (see plot_puzzles).
FIRST_N_DOCS = 100

# One row per plotted line: label, marker face colour, line + marker edge colour.
_PALETTE_ROWS = (
    ("Base", "#cfd8dc", "#607d8b"),
    ("SFT", "#c8e6c9", "#2e7d32"),
    ("Vanilla GSPO", "#bbdefb", "#1565c0"),
    ("Novelty bonus", "#f8bbd0", "#c2185b"),
)
# Line label -> (fill, edge), the pair plot_curves unpacks for each line.
PALETTE = {label: (fill, edge) for label, fill, edge in _PALETTE_ROWS}

# All eval outputs (parquets, samples JSONLs) are looked up under results/.
ROOT = _PROJECT_ROOT / "results"

# Figures go straight into the paper's figures dir, which lives in the
# `writing/neurips_paper` git submodule. Where the submodule cannot be
# initialised (e.g. a devpod without SSH access to its remote), main() still
# creates this path locally, so the figures are produced anyway. Copy them into
# the submodule once it is available, or commit them to `reports/figures/`
# instead. Set FIG_OUT_DIR to write somewhere else entirely.
_DEFAULT_OUT_DIR = _PROJECT_ROOT / "writing" / "neurips_paper" / "figures"
OUT_DIR = Path(os.environ.get("FIG_OUT_DIR", str(_DEFAULT_OUT_DIR)))


def correct_lists_from_parquet(parquet_path: Path, task_name: str) -> list[list[bool]]:
    """Load per-doc correctness lists for one task from a trace-level parquet.

    Rows are filtered to `task_name`, ordered by (doc_id, trace_id) and grouped
    by doc_id. The result has one inner list per doc, in ascending doc_id order,
    holding that doc's `correct` flags in trace_id order.

    Raises:
        ValueError: if the parquet has no rows for `task_name`. Otherwise a
            typo or renamed task would silently yield an empty list and be
            fed into the pass@k computation.
    """
    df = pd.read_parquet(parquet_path)
    df = df[df["task_name"] == task_name]
    if df.empty:
        raise ValueError(f"no rows with task_name={task_name!r} in {parquet_path}")
    df = df.sort_values(["doc_id", "trace_id"])
    # groupby keeps the (doc_id, trace_id) row order inside each group.
    return [[bool(c) for c in sub["correct"].tolist()]
            for _, sub in df.groupby("doc_id")]


def correct_lists_from_jsonl(jsonl_path: Path) -> list[list[bool]]:
    """Per-doc correctness lists for an eval samples JSONL.

    JSONL counterpart of `correct_lists_from_parquet`. It delegates to
    `_cached_correct_lists` with that helper's default worker count, so
    repeated calls reuse the `.scored.json` cache.
    """
    scored = _cached_correct_lists(jsonl_path)
    return scored


def passk_curve(correct_lists: list[list[bool]], ks=K_VALUES) -> list[float]:
    """Return pass@k as a percentage for every k in `ks`, in the order given.

    `ks` defaults to the module-wide K_VALUES, bound when the function is defined.
    """
    curve: list[float] = []
    for k in ks:
        fraction = pass_at_k(correct_lists, k)
        curve.append(fraction * 100)
    return curve


def plot_curves(curves: dict[str, list[float]], n_per_line: dict[str, int],
                title: str, out_path: Path, ymin: float = -2,
                ymax: float | None = None, fake_base_zero: bool = False):
    """Draw one pass@k figure (log2 x-axis over K_VALUES) and save it as a PNG.

    Args:
        curves: line label -> pass@k values in %, aligned with K_VALUES.
            Every label must be a PALETTE key.
        n_per_line: line label -> count shown as "(n=...)" in the legend;
            must cover every label in `curves`.
        title: axes title.
        out_path: destination file, written at 150 dpi.
        ymin: lower y limit (default -2, presumably so a 0% line does not
            sit on the bottom spine).
        ymax: upper y limit; when None, 5 points above the largest value in
            `curves` (ValueError if `curves` is empty).
        fake_base_zero: additionally draw a flat 0% "Base (≈0%)" line before
            the others, with bigger markers (8 vs 5) and a thicker line
            (2.0 vs 1.4).
    """
    if ymax is None:
        ymax = max(v for series in curves.values() for v in series) + 5

    fig, ax = plt.subplots(figsize=(6.5, 4.5))
    if fake_base_zero:
        # Puzzle plots: Base is shown as flat 0 since real puzzle eval data isn't available
        base_fill, base_edge = PALETTE["Base"]
        ax.plot(K_VALUES, [0.0 for _ in K_VALUES], marker="o", color=base_edge,
                mfc=base_fill, mec=base_edge, markersize=8, linewidth=2.0,
                label="Base (≈0%)")
    for name, series in curves.items():
        face, line = PALETTE[name]
        ax.plot(K_VALUES, series, marker="o", color=line, mfc=face, mec=line,
                markersize=5, linewidth=1.4, label=f"{name} (n={n_per_line[name]})")

    # Scale first, ticks second: set_xscale installs the scale's default
    # locator, which would replace ticks set before it.
    ax.set_xscale("log", base=2)
    ax.set_xticks(K_VALUES)
    ax.set_xticklabels(list(map(str, K_VALUES)))
    ax.set_xlabel("k (samples per problem)", fontsize=11)
    ax.set_ylabel("pass@k (%)", fontsize=11)
    ax.set_title(title, fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(ymin, ymax)
    ax.legend(loc="upper left", fontsize=9, framealpha=0.95)

    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
    print(f"Saved: {out_path}")


# =============================================================================
# Puzzle plots — Base flat at 0 (real Base puzzle data not available)
# =============================================================================

def plot_puzzles():
    """Plot pass@k for the puzzle evals: Bridges 8x8de, Undead 5x5de, Pattern 5x5de.

    Base is drawn as a flat 0% line (``fake_base_zero=True``) because real Base
    puzzle eval data is not available. Each legend entry's ``n=`` is the
    number of problems actually loaded for that line. A hard-coded 100 could
    silently disagree with the data.
    """
    sft_parquet = ROOT / "exploration_analysis/v90_sft_puzzles/trace_level_metrics.parquet"
    gspo_parquet = ROOT / "exploration_analysis/v90_gspo_puzzles/trace_level_metrics.parquet"
    novelty_s15_root = ROOT / "novelty_prod_s15_puzzle_pass32/novelty_prod_s15"
    novelty_s15_ckpt = "checkpoints__olmo3-puzzle-grpo__novelty_production_gspo_topk100_a01__merged_step_15"

    def report(name, runs, title, out_name, ymax):
        # Turn each run's per-doc correctness lists into a pass@k curve,
        # print the table, and save the figure with Base faked at 0%.
        curves = {label: passk_curve(c) for label, c in runs.items()}
        n_per_line = {label: len(c) for label, c in runs.items()}
        print(f"\n=== {name} pass@k ===")
        print(pd.DataFrame(curves, index=[f"pass@{k}" for k in K_VALUES]).round(1))
        plot_curves(curves, n_per_line, title, OUT_DIR / out_name,
                    ymax=ymax, fake_base_zero=True)

    # Bridges: only the first FIRST_N_DOCS problems of the Novelty-bonus JSONL
    # are used, presumably to match the docs covered by the parquet-backed lines.
    bridges_runs = {
        "SFT": correct_lists_from_parquet(sft_parquet, "bridges_8x8de_pass32"),
        "Vanilla GSPO": correct_lists_from_parquet(gspo_parquet, "bridges_8x8de_pass32"),
        "Novelty bonus": correct_lists_from_jsonl(
            novelty_s15_root / "bridges_8x8de_pass32" / novelty_s15_ckpt
            / "samples_bridges_8x8de_pass32_2026-05-02T21-41-53.008758.jsonl")[:FIRST_N_DOCS],
    }
    report("Bridges 8x8de", bridges_runs,
           "Bridges 8x8de: pass@k (higher grid sizes than training)",
           "fig_passk_bridges.png", ymax=100)

    undead_runs = {
        "SFT": correct_lists_from_parquet(sft_parquet, "undead_5x5de_pass32"),
        "Vanilla GSPO": correct_lists_from_parquet(gspo_parquet, "undead_5x5de_pass32"),
        "Novelty bonus": correct_lists_from_jsonl(
            novelty_s15_root / "undead_5x5de_pass32" / novelty_s15_ckpt
            / "samples_undead_5x5de_pass32_2026-05-03T03-16-07.130798.jsonl"),
    }
    report("Undead 5x5de", undead_runs,
           "Undead 5x5de: pass@k (higher grid sizes than training)",
           "fig_passk_undead.png", ymax=80)

    # Pattern 5x5de — Pattern was trained at 4x4 only; 5x5 is OOD.
    # All 3 ckpts run on a matched 100-doc subset (--limit 100) at 40K-token
    # gen budget. 28K caused ~78-89% truncation rate; 40K drops it to <10%.
    pattern_runs = {
        "SFT": correct_lists_from_jsonl(
            ROOT / "pattern_5x5de_pass32/sft_v2_ep5"
            / "checkpoints__olmo3_7b_multi_puzzle_dsr_v2__merged_ep5_fp32"
            / "samples_pattern_5x5de_ascii_pass32_2026-05-05T22-49-09.714292.jsonl"),
        "Vanilla GSPO": correct_lists_from_jsonl(
            ROOT / "pattern_5x5de_pass32/gspo_v2_sft_v2_s20"
            / "checkpoints__olmo3-puzzle-grpo__multi_puzzle_gspo_olmo3_v2_sft_v2__merged_step_20"
            / "samples_pattern_5x5de_ascii_pass32_2026-05-06T02-53-52.119987.jsonl"),
        "Novelty bonus": correct_lists_from_jsonl(
            ROOT / "pattern_5x5de_pass32/novelty_prod_s15"
            / "checkpoints__olmo3-puzzle-grpo__novelty_production_gspo_topk100_a01__merged_step_15_fp32"
            / "samples_pattern_5x5de_ascii_pass32_2026-05-06T06-04-04.147787.jsonl"),
    }
    report("Pattern 5x5de", pattern_runs,
           "Pattern 5x5de: pass@k (trained at 4x4, OOD at 5x5)",
           "fig_passk_pattern.png", ymax=100)


# =============================================================================
# Math plots — real Base data
# =============================================================================

def _local_jsonl(rel: Path, glob: str = "samples_*.jsonl") -> Path:
    """Find a JSONL under `rel`. Use `glob` to disambiguate when a dir contains
    multiple eval JSONLs (e.g. pin to samples_olymp_math_hard_pass32_*)."""
    matches = list((ROOT / rel).rglob(glob))
    if not matches:
        raise FileNotFoundError(f"no JSONL matching {glob} under {ROOT / rel}")
    return matches[0]


def plot_math():
    """Plot OlymMATH Easy and Hard pass@k curves (real Base data, 4 lines each).

    Each legend entry's ``n=`` is the number of problems actually scored for
    that checkpoint, not a hard-coded 100.
    """

    def load(rel, glob):
        # Resolve the JSONL under results/<rel> and return its per-doc correctness.
        return correct_lists_from_jsonl(_local_jsonl(Path(rel), glob))

    def report(name, runs, out_name, ymax):
        # Turn each run into a pass@k curve, print the table, save the figure.
        curves = {label: passk_curve(c) for label, c in runs.items()}
        n_per_line = {label: len(c) for label, c in runs.items()}
        print(f"\n=== {name} pass@k ===")
        print(pd.DataFrame(curves, index=[f"pass@{k}" for k in K_VALUES]).round(1))
        plot_curves(curves, n_per_line, f"{name}: pass@k",
                    OUT_DIR / out_name, ymin=0, ymax=ymax)

    # OlymMATH Easy — v2 canonical pass@32 JSONLs from S3 sft_v2_ep5/, gspo_v2_s20/.
    # Base comes from a pass64 run dir; only k <= 32 (K_VALUES) is plotted.
    easy_glob = "samples_olymp_math_easy*.jsonl"
    easy_runs = {
        "Base": load("pass_at_k/olmo3_base/olymp_math_easy_pass64", easy_glob),
        "SFT": load("sft_v2_ep5/olymp_math_easy_pass32", easy_glob),
        "Vanilla GSPO": load("gspo_v2_s20/olymp_math_easy_pass32", easy_glob),
        "Novelty bonus": load("novelty_production_step15/olymp_math_easy_pass32_v2", easy_glob),
    }
    report("OlymMATH Easy", easy_runs, "fig_main_passk_easy.png", ymax=85)

    # OlymMATH Hard — uses v2-canonical checkpoints to match reports/olymp_math_rescored_results.md:
    #   SFT v2 ep5 (`merged_ep5_fp32`) → 23.0% pass@32
    #   GSPO v2 SFT s20 (`multi_puzzle_gspo_olmo3_v2_sft_v2/merged_step_20`) → 30.0% pass@32
    hard_glob = "samples_olymp_math_hard*.jsonl"
    hard_runs = {
        "Base": load("pass_at_k/olmo3_base/olymp_math_hard_pass64", hard_glob),
        "SFT": load("sft_baseline_math_eval_diverse/checkpoints__olmo3_7b_multi_puzzle_dsr_v2__merged_ep5_fp32",
                    hard_glob),
        "Vanilla GSPO": load("gspo_v2_sft_s20_math_eval_diverse/checkpoints__olmo3-puzzle-grpo__multi_puzzle_gspo_olmo3_v2_sft_v2__merged_step_20",
                             hard_glob),
        "Novelty bonus": load("novelty_production_step15/olymp_math_hard_pass32", hard_glob),
    }
    report("OlymMATH Hard", hard_runs, "fig_main_passk_hard.png", ymax=40)


def plot_hmmt_combined():
    """HMMT combined (N=123, Feb 2024 + Feb 2025 + Nov 2025 + Feb 2026)
    pass@k curve at full K=[1, 4, 8, 16, 32] from JSONLs.

    Each legend entry's ``n=`` is the number of problems actually scored for
    that checkpoint (expected 123), not a hard-coded constant.
    """
    glob = "samples_hmmt_combined*.jsonl"
    ckpt_dirs = {
        "Base": "hmmt_combined_pass32/base",
        "SFT": "hmmt_combined_pass32/sft_v2_ep5",
        "Vanilla GSPO": "hmmt_combined_pass32/gspo_v2_s20",
        "Novelty bonus": "hmmt_combined_pass32/novelty_s15",
    }
    runs = {label: correct_lists_from_jsonl(_local_jsonl(Path(rel), glob))
            for label, rel in ckpt_dirs.items()}
    hmmt_curves = {label: passk_curve(c) for label, c in runs.items()}
    hmmt_n = {label: len(c) for label, c in runs.items()}
    print("\n=== HMMT combined pass@k ===")
    print(pd.DataFrame(hmmt_curves, index=[f"pass@{k}" for k in K_VALUES]).round(1))
    plot_curves(hmmt_curves, hmmt_n,
                "HMMT combined: pass@k",
                OUT_DIR / "fig_main_passk_hmmt.png", ymin=0, ymax=55)


def plot_aime25():
    """AIME25 pass@k curve. JSONLs come from different S3 dirs per ckpt.

    SFT and Vanilla GSPO are pinned to the same v2-canonical checkpoint
    subdirs that the OlymMATH Hard plot uses inside the
    `*_math_eval_diverse` dirs. Searching a whole dir could return another
    checkpoint's AIME25 JSONL if it holds several. Each legend entry's ``n=``
    is the number of problems actually scored (expected 30).
    """
    glob = "samples_aime25_r1*.jsonl"
    ckpt_dirs = {
        # NOTE(review): this Base run dir is named after
        # allenai/OLMo-3-7B-Instruct-SFT — confirm that is the intended "Base".
        "Base": "pass_at_k/olmo3_base/aime25_r1_pass64/allenai__OLMo-3-7B-Instruct-SFT",
        "SFT": "sft_baseline_math_eval_diverse/checkpoints__olmo3_7b_multi_puzzle_dsr_v2__merged_ep5_fp32",
        "Vanilla GSPO": "gspo_v2_sft_s20_math_eval_diverse/checkpoints__olmo3-puzzle-grpo__multi_puzzle_gspo_olmo3_v2_sft_v2__merged_step_20",
        "Novelty bonus": "novelty_production_step15/aime25_pass32_v2",
    }
    runs = {label: correct_lists_from_jsonl(_local_jsonl(Path(rel), glob))
            for label, rel in ckpt_dirs.items()}
    aime25_curves = {label: passk_curve(c) for label, c in runs.items()}
    aime25_n = {label: len(c) for label, c in runs.items()}
    print("\n=== AIME25 pass@k ===")
    print(pd.DataFrame(aime25_curves, index=[f"pass@{k}" for k in K_VALUES]).round(1))
    plot_curves(aime25_curves, aime25_n,
                "AIME25: pass@k",
                OUT_DIR / "fig_appendix_passk_aime25.png", ymin=0, ymax=85)


def plot_omega_compositional():
    """OMEGA Compositional N=100 pass@k curve (strict scoring).

    Each legend entry's ``n=`` is the number of problems actually scored for
    that checkpoint (expected 100), not a hard-coded constant.
    """
    glob = "samples_omega_compositional_pass32_*.jsonl"
    base_rel = Path("omega_compositional_pass32")
    ckpt_subdirs = {
        "Base": "base",
        "SFT": "sft_v2_ep5",
        "Vanilla GSPO": "gspo_v2_s20",
        "Novelty bonus": "novelty_s15",
    }
    runs = {label: correct_lists_from_jsonl(_local_jsonl(base_rel / sub, glob))
            for label, sub in ckpt_subdirs.items()}
    curves = {label: passk_curve(c) for label, c in runs.items()}
    n_per = {label: len(c) for label, c in runs.items()}
    print("\n=== OMEGA Compositional pass@k ===")
    print(pd.DataFrame(curves, index=[f"pass@{k}" for k in K_VALUES]).round(1))
    plot_curves(curves, n_per,
                "OMEGA Compositional: pass@k",
                OUT_DIR / "fig_appendix_passk_omega_compositional.png", ymin=0, ymax=70)


def plot_omega_transformative():
    """OMEGA Transformative N=100 pass@k curve (strict scoring).

    Each legend entry's ``n=`` is the number of problems actually scored for
    that checkpoint (expected 100), not a hard-coded constant.
    """
    glob = "samples_omega_transformative_pass32_*.jsonl"
    base_rel = Path("omega_transformative_pass32")
    ckpt_subdirs = {
        "Base": "base",
        "SFT": "sft_v2_ep5",
        "Vanilla GSPO": "gspo_v2_s20",
        "Novelty bonus": "novelty_s15",
    }
    runs = {label: correct_lists_from_jsonl(_local_jsonl(base_rel / sub, glob))
            for label, sub in ckpt_subdirs.items()}
    curves = {label: passk_curve(c) for label, c in runs.items()}
    n_per = {label: len(c) for label, c in runs.items()}
    print("\n=== OMEGA Transformative pass@k ===")
    print(pd.DataFrame(curves, index=[f"pass@{k}" for k in K_VALUES]).round(1))
    plot_curves(curves, n_per,
                "OMEGA Transformative: pass@k",
                OUT_DIR / "fig_appendix_passk_omega_transformative.png", ymin=0, ymax=55)


def plot_omega_explorative_test_out():
    """OMEGA Explorative test_out (N=134) pass@k curve.

    Reads from `results/omega_explorative_test_out_pass32/<ckpt>/<model_dir>/samples_*.jsonl`
    pulled from `<S3_BUCKET>/results/omega_explorative_test_out_pass32/`.
    Strict `\\boxed{}`-only scoring (commit `b39f785d`); see omega utils.
    Each legend entry's ``n=`` is the number of problems actually scored for
    that checkpoint (expected 134), not a hard-coded constant.
    """
    base_rel = Path("omega_explorative_test_out_pass32")
    glob = "samples_omega_explorative_test_out_pass32_*.jsonl"
    ckpt_subdirs = {
        "Base": "base",
        "SFT": "sft_v2_ep5",
        "Vanilla GSPO": "gspo_v2_s20",
        "Novelty bonus": "novelty_s15",
    }
    runs = {label: correct_lists_from_jsonl(_local_jsonl(base_rel / sub, glob))
            for label, sub in ckpt_subdirs.items()}
    omega_curves = {label: passk_curve(c) for label, c in runs.items()}
    omega_n = {label: len(c) for label, c in runs.items()}
    print("\n=== OMEGA Explorative test_out pass@k ===")
    print(pd.DataFrame(omega_curves, index=[f"pass@{k}" for k in K_VALUES]).round(1))
    plot_curves(omega_curves, omega_n,
                "OMEGA Explorative: pass@k",
                OUT_DIR / "fig_main_passk_omega.png", ymin=0, ymax=60)


def main():
    """Create OUT_DIR if needed, then render the puzzle, OlymMATH, HMMT and
    OMEGA Explorative figures, in that order.

    The appendix plotters (plot_aime25, plot_omega_compositional,
    plot_omega_transformative) are not invoked here; call them explicitly.
    """
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    renderers = (
        plot_puzzles,
        plot_math,
        plot_hmmt_combined,
        plot_omega_explorative_test_out,
    )
    for render in renderers:
        render()


if __name__ == "__main__":
    main()
