"""§6 depth-aware metric candidates — produce 4 scatter plots replacing CC-loop density.

Per spec `reports/neurips/tasks/depth_metric_replacement.md`:
- Metric A: mean chain depth (longest exploit run per trace)
- Metric B: mean exploit-run length (averaged over runs, not weighted)
- Metric C: fraction of exploit primitives in runs ≥ 3
- Metric D: count of k=5 motif `[CHECK, COMPUTE, CHECK, COMPUTE, CHECK]` per trace

Each plot: scatter of (per-checkpoint-mean-metric × OlymMATH-Hard pass@32),
trend line through Base/SFT/Vanilla GSPO, deviation arrow at Novelty bonus.
Mirrors the visual style of `fig_cc_loop_vs_passk` from
`scripts/analysis/section6_7_figures.py`.
"""

import json
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Input / output locations. All are relative paths, so they resolve against
# the current working directory of the process.
ROOT = Path("results/exploration_analysis")
# Figure directory; overridable through the FIG_OUT_DIR environment variable.
OUT_DIR = Path(os.environ.get("FIG_OUT_DIR", "writing/neurips_paper/figures"))
CSV_OUT = ROOT / "depth_metric_comparison.csv"
SUMMARY_OUT = Path("reports/neurips/depth_metric_comparison.md")

# Primitive labels that count as part of an "exploit run".
EXPLOIT = {"COMPUTE", "CHECK", "SETUP"}

# Checkpoint label -> parquet file holding that checkpoint's per-trace rows.
PATHS = {
    "Base": ROOT / "v90_base_math/trace_level_metrics.parquet",
    "SFT": ROOT / "v90_sft_math/trace_level_metrics.parquet",
    "Vanilla GSPO": ROOT / "v90_gspo_math/trace_level_metrics.parquet",
    "Novelty bonus": ROOT / "v90_prod_s15_math/trace_level_metrics.parquet",
}

# OlymMATH-Hard rescored pass@32 (from reports/olymp_math_rescored_results.md)
PASSK_HARD = {
    "Base": 16.0,
    "SFT": 23.0,
    "Vanilla GSPO": 29.0,
    "Novelty bonus": 36.0,
}

# Marker fill colour per checkpoint.
PALETTE = {
    "Base": "#cfd8dc",
    "SFT": "#c8e6c9",
    "Vanilla GSPO": "#bbdefb",
    "Novelty bonus": "#f8bbd0",
}
# Marker edge colour per checkpoint (also used for the deviation arrow/label).
EDGE = {
    "Base": "#607d8b",
    "SFT": "#2e7d32",
    "Vanilla GSPO": "#1565c0",
    "Novelty bonus": "#c2185b",
}


# ---------------------------------------------------------------------------
# Per-trace metric helpers
# ---------------------------------------------------------------------------

def parse_seq(s):
    """Normalise one ``primitive_sequence`` cell into a plain Python list.

    Accepted inputs:
      * ``None``                -> ``[]``
      * ``list``                -> returned unchanged (same object)
      * ``tuple``               -> converted to a list
      * 1-D ``numpy.ndarray``   -> converted with ``tolist()``
      * JSON text encoding a list (``str`` / ``bytes``) -> the decoded list

    Anything else (malformed JSON, NaN/other scalars, arrays that are not
    1-D, or JSON that decodes to a non-list such as ``null`` or an object)
    yields ``[]`` so downstream metric helpers always receive a list.
    """
    if s is None:
        return []
    if isinstance(s, list):
        return s
    if isinstance(s, tuple):
        return list(s)
    # A list-typed parquet column is typically handed back by pandas as a
    # numpy array per cell; without this branch json.loads() would raise
    # TypeError and the whole sequence would be silently discarded.
    if isinstance(s, np.ndarray):
        return s.tolist() if s.ndim == 1 else []
    try:
        parsed = json.loads(s)
    except (json.JSONDecodeError, TypeError):
        return []
    # Valid JSON that is not an array (null, number, object, ...) is not a
    # primitive sequence; treat it like any other unparseable cell.
    return parsed if isinstance(parsed, list) else []


def exploit_runs(seq: list) -> list[int]:
    """Measure every maximal stretch of consecutive exploit primitives.

    An element is part of a run when it is a member of ``EXPLOIT``; any other
    element ends the current run.  The lengths are returned in the order the
    runs appear, and the list is empty when ``seq`` has no exploit primitive.
    """
    lengths = []
    streak = 0
    for primitive in seq:
        if primitive in EXPLOIT:
            streak += 1
            continue
        # A non-exploit primitive closes whatever run was open.
        if streak:
            lengths.append(streak)
        streak = 0
    # Flush a run that extends to the very end of the sequence.
    if streak:
        lengths.append(streak)
    return lengths


def metric_A_chain_depth(seq):
    """Metric A: length of the longest exploit run in ``seq`` (0 if none)."""
    return max(exploit_runs(seq), default=0)


def metric_B_mean_run(seq):
    """Metric B: unweighted mean length of the exploit runs in ``seq``.

    Returns NaN when the sequence contains no exploit run at all.
    """
    run_lengths = exploit_runs(seq)
    if not run_lengths:
        return np.nan
    return float(np.mean(run_lengths))


def metric_C_frac_in_long(seq, k=3):
    """Metric C: share of exploit primitives that sit in runs of length >= ``k``.

    The denominator is the total number of exploit primitives (sum of all run
    lengths); NaN is returned when that total is zero.
    """
    run_lengths = exploit_runs(seq)
    n_exploit = sum(run_lengths)
    if n_exploit == 0:
        return np.nan
    n_in_long_runs = sum(length for length in run_lengths if length >= k)
    return n_in_long_runs / n_exploit


def metric_D_k5_motif(seq):
    """Metric D: occurrences of the k=5 motif CHECK, COMPUTE, CHECK, COMPUTE, CHECK.

    Every start index is tested for a contiguous five-element match, so
    overlapping occurrences are each counted.  Sequences shorter than five
    elements yield 0.
    """
    motif = ("CHECK", "COMPUTE", "CHECK", "COMPUTE", "CHECK")
    width = len(motif)
    hits = 0
    # range() is empty when the sequence is shorter than the motif.
    for start in range(len(seq) - width + 1):
        if tuple(seq[start:start + width]) == motif:
            hits += 1
    return hits


# Metric registry: metric id -> (human-readable label, per-trace function).
# The id prefix before the first underscore ("A".."D") is the metric letter.
METRICS = {
    "A_chain_depth": (
        "Mean chain depth (longest exploit run)",
        metric_A_chain_depth,
    ),
    "B_mean_run": (
        "Mean exploit-run length",
        metric_B_mean_run,
    ),
    "C_frac_long": (
        "Fraction of exploit primitives in runs ≥ 3",
        metric_C_frac_in_long,
    ),
    "D_k5_motif": (
        "k=5 CH→CO→CH→CO→CH count",
        metric_D_k5_motif,
    ),
}


def compute_per_checkpoint() -> dict[str, dict[str, float]]:
    """Average every registered metric over the traces of each checkpoint.

    For each entry in ``PATHS`` the parquet file is loaded, its
    ``primitive_sequence`` column is parsed with ``parse_seq``, and each
    function in ``METRICS`` is applied per trace.  Per-trace NaN values are
    excluded from the mean; a metric with no non-NaN value is reported as NaN.
    One summary line per checkpoint is printed.

    Returns:
        ``{checkpoint: {metric_id: mean_value}}``.
    """
    summary = {}
    for ckpt, parquet_path in PATHS.items():
        frame = pd.read_parquet(parquet_path)
        sequences = frame["primitive_sequence"].apply(parse_seq)

        means = {}
        for mid, (_label, metric_fn) in METRICS.items():
            usable = sequences.apply(metric_fn).dropna()
            means[mid] = np.nan if usable.empty else float(usable.mean())
        summary[ckpt] = means

        rendered = "  ".join(f"{mid}={means[mid]:.3f}" for mid in METRICS)
        print(f"  {ckpt:<20} " + rendered)
    return summary


# ---------------------------------------------------------------------------
# Plot one metric in fig_cc_loop_vs_passk style
# ---------------------------------------------------------------------------

def plot_metric(metric_id: str, metric_label: str,
                values: dict[str, float], out_path: Path) -> tuple[float, float, float]:
    """Scatter one metric against OlymMATH-Hard pass@32 and save the figure.

    A degree-1 least-squares line is fitted through the Base, SFT and
    Vanilla GSPO points only; the Novelty bonus point is drawn as a star with
    a vertical arrow from the fitted line to its actual pass@32.

    Args:
        metric_id: Registry id such as ``"A_chain_depth"``; the part before
            the first underscore is used as the metric letter in the title.
        metric_label: Human-readable metric name for the axis label and title.
        values: Per-checkpoint mean metric value, keyed by checkpoint label.
        out_path: Destination image file.

    Returns:
        ``(slope, intercept, deviation)`` where ``deviation`` is the Novelty
        bonus pass@32 minus the fitted line evaluated at its metric value.
    """
    novelty = "Novelty bonus"
    trend_ckpts = ("Base", "SFT", "Vanilla GSPO")

    x_of = {c: values[c] for c in (*trend_ckpts, novelty)}
    y_of = {c: PASSK_HARD[c] for c in x_of}
    x_lo = min(x_of.values())
    x_hi = max(x_of.values())
    # The small epsilon keeps the span non-zero when all x values coincide.
    x_span = x_hi - x_lo + 1e-9

    fig, ax = plt.subplots(figsize=(6.5, 4.5))

    # Trend line through Base, SFT, Vanilla GSPO
    fit_x = np.array([x_of[c] for c in trend_ckpts])
    fit_y = np.array([y_of[c] for c in trend_ckpts])
    slope, intercept = np.polyfit(fit_x, fit_y, 1)
    margin = max(0.05 * x_span, 0.01)
    grid_x = np.linspace(x_lo - margin, x_hi + margin, 100)
    ax.plot(grid_x, slope * grid_x + intercept, color="#888", linestyle="--",
            lw=1.2, alpha=0.8, label="Base→SFT→GSPO trend")

    # One labelled marker per checkpoint; Novelty bonus gets a larger star
    # and a label placed below-left instead of above-right.
    for c in x_of:
        is_novelty = c == novelty
        point = (x_of[c], y_of[c])
        ax.scatter(point[0], point[1], color=PALETTE[c],
                   s=280 if is_novelty else 130,
                   marker="*" if is_novelty else "o",
                   edgecolor=EDGE[c], linewidth=1.2, zorder=4)
        ax.annotate(c, point,
                    xytext=(-6, -16) if is_novelty else (8, 6),
                    textcoords="offset points",
                    fontsize=10, fontweight="bold",
                    ha="right" if is_novelty else "left")

    # Deviation arrow at Novelty bonus
    nov_x = x_of[novelty]
    nov_y = y_of[novelty]
    fitted_y = slope * nov_x + intercept
    deviation = nov_y - fitted_y
    ax.annotate("", xy=(nov_x, nov_y), xytext=(nov_x, fitted_y),
                arrowprops=dict(arrowstyle="->", color=EDGE[novelty], lw=1.6))
    if deviation >= 0:
        sign = "+"
    else:
        sign = ""  # the minus sign comes from the number formatting itself
    ax.text(nov_x + 0.02 * x_span, (nov_y + fitted_y) / 2,
            f"{sign}{deviation:.1f}pp\nfrom trend",
            fontsize=9, color=EDGE[novelty], va="center")

    metric_letter = metric_id.split('_')[0]
    ax.set_xlabel(f"{metric_label} (mean over rollouts)", fontsize=10)
    ax.set_ylabel("OlymMATH-Hard pass@32 (%)", fontsize=10)
    ax.set_title(f"Metric {metric_letter}: {metric_label}", fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.legend(loc="lower right", fontsize=9, framealpha=0.95)
    fig.tight_layout()
    fig.savefig(out_path, dpi=160, bbox_inches="tight")
    plt.close(fig)
    print(f"  saved: {out_path}")
    return float(slope), float(intercept), float(deviation)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def _markdown_link_target(target: Path, doc_dir: Path) -> str:
    """Return ``target`` as a forward-slash path relative to ``doc_dir``.

    Used for image links inside the markdown report so they keep working
    when the figure directory is redirected.  Falls back to the absolute
    path when no relative path exists.
    """
    try:
        relative = os.path.relpath(target, doc_dir)
    except ValueError:
        # os.path.relpath raises on Windows when the paths are on different drives.
        return Path(target).resolve().as_posix()
    return Path(relative).as_posix()


def main():
    """Compute the metrics, then write the CSV, the four figures and the report.

    Side effects: creates the output directories if needed, writes
    ``CSV_OUT``, one PNG per metric under ``OUT_DIR`` and the markdown summary
    at ``SUMMARY_OUT`` (UTF-8), and prints progress to stdout.
    """
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    SUMMARY_OUT.parent.mkdir(parents=True, exist_ok=True)
    CSV_OUT.parent.mkdir(parents=True, exist_ok=True)

    checkpoints = ["Base", "SFT", "Vanilla GSPO", "Novelty bonus"]

    print("=== Computing per-checkpoint metrics ===")
    values = compute_per_checkpoint()

    # CSV: one row per (checkpoint, metric) pair.
    rows = [
        {"metric": mid, "checkpoint": ckpt,
         "mean_value": values[ckpt][mid],
         "passk_hard": PASSK_HARD[ckpt]}
        for ckpt in checkpoints
        for mid in METRICS
    ]
    pd.DataFrame(rows).to_csv(CSV_OUT, index=False)
    print(f"\nSaved CSV: {CSV_OUT}")

    # Plot each metric, remembering where each figure was actually saved so
    # the report links to the real files even when FIG_OUT_DIR is set.
    print("\n=== Generating plots ===")
    deviations = {}
    fig_paths = {}
    for mid, (label, _) in METRICS.items():
        per_ckpt = {c: values[c][mid] for c in PATHS}
        out_path = OUT_DIR / f"fig_depth_metric_{mid}.png"
        slope, intercept, dev = plot_metric(mid, label, per_ckpt, out_path)
        fig_paths[mid] = out_path
        deviations[mid] = {"slope": slope, "intercept": intercept,
                           "novelty_deviation_pp": dev,
                           "values": per_ckpt}

    # (metric id, section heading, image alt text) for the embedded figures.
    figure_sections = [
        ("A_chain_depth",
         "### Metric A — Mean chain depth (longest exploit run)",
         "Metric A: chain depth"),
        ("B_mean_run",
         "### Metric B — Mean exploit-run length",
         "Metric B: mean run length"),
        ("C_frac_long",
         "### Metric C — Fraction of exploit primitives in runs ≥ 3",
         "Metric C: frac in runs ≥ 3"),
        ("D_k5_motif",
         "### Metric D — k=5 CHECK→COMPUTE→CHECK→COMPUTE→CHECK motif count",
         "Metric D: k=5 motif count"),
    ]

    # Markdown summary
    print(f"\n=== Writing summary: {SUMMARY_OUT} ===")
    # Explicit UTF-8: the report contains non-ASCII characters (§, ≥, →, —, ×)
    # that a platform-default encoding may be unable to represent.
    with open(SUMMARY_OUT, "w", encoding="utf-8") as f:
        f.write("# §6 depth-metric replacement — comparison of 4 candidates\n\n")
        # NOTE(review): the date is a fixed literal, not the generation date.
        f.write("**Date**: 2026-05-04\n")
        f.write("**Spec**: `reports/neurips/tasks/depth_metric_replacement.md`\n")
        f.write("**Generator**: `scripts/analysis/depth_metric_replacement.py`\n")
        f.write("**Source data**: `results/exploration_analysis/v90_{base,sft,gspo,prod_s15}_math/trace_level_metrics.parquet`\n\n")
        f.write("## Metric values per checkpoint\n\n")
        f.write("| Checkpoint | pass@32 | A: chain depth | B: mean run len | C: frac in runs ≥ 3 | D: k=5 motif count |\n")
        f.write("|---|---:|---:|---:|---:|---:|\n")
        for ckpt in checkpoints:
            f.write(f"| {ckpt} | {PASSK_HARD[ckpt]:.1f}% |"
                    f" {values[ckpt]['A_chain_depth']:.2f} |"
                    f" {values[ckpt]['B_mean_run']:.2f} |"
                    f" {values[ckpt]['C_frac_long']:.3f} |"
                    f" {values[ckpt]['D_k5_motif']:.2f} |\n")

        f.write("\n## Trend (Base→SFT→GSPO) and Novelty bonus deviation\n\n")
        f.write("| Metric | Slope (pp/unit) | Intercept (pp) | Prod x | Trend y at Prod x (pp) | Prod actual (pp) | **Deviation (pp)** |\n")
        f.write("|---|---:|---:|---:|---:|---:|---:|\n")
        for mid, d in deviations.items():
            prod_x = d["values"]["Novelty bonus"]
            trend_y = d["slope"] * prod_x + d["intercept"]
            f.write(f"| {mid} | {d['slope']:+.3f} | {d['intercept']:+.2f} |"
                    f" {prod_x:.3f} | {trend_y:.2f} | {PASSK_HARD['Novelty bonus']:.1f} |"
                    f" **{d['novelty_deviation_pp']:+.2f}** |\n")

        # Embed the figures inline so the markdown report is self-contained.
        # Link targets are computed relative to the report's own directory.
        f.write("\n## Figures\n\n")
        for mid, heading, alt_text in figure_sections:
            link = _markdown_link_target(fig_paths[mid], SUMMARY_OUT.parent)
            f.write(f"{heading}\n\n")
            f.write(f"![{alt_text}]({link})\n\n")

        f.write("## Decision criteria (per spec §Decision criteria)\n\n")
        f.write("Looking for the metric where novelty sits *clearly* off the trend with **less metric value than vanilla GSPO** (the \"depth without scattering\" critique). Positive deviation in the table above = novelty is *above* the Base→SFT→GSPO trend at its observed metric value (good).\n\n")
        f.write("- **Metric A (chain depth)**: already in Fig 4 as a box plot. Reuse on a scatter is light redundancy.\n")
        f.write("- **Metric B (mean run length)**: averages all runs, not just longest. Sensitive to the \"scattered short runs vs one long run\" distinction the spec flags.\n")
        f.write("- **Metric C (frac in runs ≥ 3)**: filters out one-off CV pairs. Most directly responsive to the clustering critique. Requires one-line definition in the figure caption.\n")
        f.write("- **Metric D (k=5 motif count)**: reuses Fig 6 motif (`CHECK→COMPUTE→CHECK→COMPUTE→CHECK`). Cleanest interpretability — a deep verify chain by definition.\n\n")
        f.write("Outputs:\n")
        for mid in METRICS:
            f.write(f"- `{fig_paths[mid].as_posix()}`\n")
        f.write(f"- `{CSV_OUT}` ({len(rows)} rows, {len(METRICS)} metrics × {len(checkpoints)} checkpoints)\n")
        f.write("\n## Recommendation\n\n")
        # Rank metrics by absolute novelty deviation
        ranked = sorted(deviations.items(), key=lambda kv: -abs(kv[1]["novelty_deviation_pp"]))
        top = ranked[0]
        f.write(f"By absolute novelty deviation: **{top[0]}** has the largest off-trend signal "
                f"({top[1]['novelty_deviation_pp']:+.1f} pp). Per the spec the writer's prior is **Metric D (k=5 motif)** for interpretability; "
                f"if the deviation magnitudes are comparable, prefer D for clean wording. "
                f"If A/B/C have substantially larger deviations, that's a stronger anti-confound case but with a more bespoke metric.\n")

    print(f"Done: {SUMMARY_OUT}")


if __name__ == "__main__":
    main()
