"""Within-problem paired analysis for §6/§7 mechanism claims.

For each of {Base, SFT, Vanilla GSPO, Novelty bonus}:
  - Per problem (doc_id), among that problem's traces compute:
      mean(metric | solved)  vs  mean(metric | failed)
    Only problems with at least one solved and one failed trace contribute.
  - Per-problem difference d = solved_mean - failed_mean
  - Wilcoxon signed-rank test on d across problems (requires >= 3 problems).
  - Report mean d, 95% bootstrap CI, n_problems, p-value.

Metrics:
  1. CC-loop density: COMPUTE->CHECK transitions per 1k tokens, recomputed from
     primitive_sequence (the classifier emits CHECK, not VERIFY).
  2. Chain depth: longest run of consecutive exploit primitives {COMPUTE, CHECK, SETUP},
     parsed from primitive_sequence per trace.
  3. Productivity ratio: per trace, (#successful recoveries) / (#recovery attempts);
     NaN when a trace has no recovery attempt.
     Recovery attempt = exploit→explore→exploit pattern in primitive_sequence.
     Successful recovery = the trailing exploit run has length >=2, i.e. the model does
     not immediately re-explore (definition from neurips_results_master.md Claim 10).

Outputs:
  results/within_problem_paired/{cc_loop, chain_depth, productivity}_paired.csv
  reports/figures/fig_within_problem_paired.png              (summary bar chart)
  reports/figures/fig_within_problem_paired_<checkpoint>.png (per-checkpoint paired points)
"""

import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy import stats

# Checkpoints analysed, in display order: (label used in tables/figures, trace-level parquet path).
CHECKPOINTS = [
    ("Base", "/tmp/v90_csvs/v90_base_traces.parquet"),
    ("SFT", "/tmp/v90_csvs/v90_sft_traces.parquet"),
    ("Vanilla GSPO", "/tmp/v90_csvs/v90_gspo_traces.parquet"),
    ("Novelty bonus", "/tmp/v90_csvs/v90_prod_s15_traces.parquet"),
]

# Primitive classes. A primitive in neither set breaks exploit runs and explore runs alike.
EXPLOIT = {"SETUP", "CHECK", "COMPUTE"}
EXPLORE = {"ENUMERATE", "BACKTRACK", "HYPOTHESIZE"}

# Output locations (relative to the working directory); created eagerly at import time.
OUT_DIR = Path("results/within_problem_paired")
FIG_DIR = Path("reports/figures")
for _output_dir in (OUT_DIR, FIG_DIR):
    _output_dir.mkdir(parents=True, exist_ok=True)
del _output_dir


def parse_seq(s):
    """Decode one ``primitive_sequence`` cell into a list of primitive labels.

    Accepts either a JSON-encoded string such as ``'["COMPUTE", "CHECK"]'`` or
    an already-decoded sequence (list, tuple, or 1-D numpy array -- the form a
    parquet list column is read back as). Anything else -- ``None``, NaN,
    malformed JSON, or JSON that does not decode to a list (e.g. ``"null"``) --
    yields an empty list, so the trace contributes no primitives instead of
    crashing the downstream metric functions.
    """
    if s is None:
        return []
    # Pre-decoded cells: previously these hit json.loads -> TypeError -> [] and
    # were silently treated as empty traces.
    if isinstance(s, np.ndarray) and s.ndim == 1:
        return s.tolist()
    if isinstance(s, (list, tuple)):
        return list(s)
    try:
        parsed = json.loads(s)
    except (json.JSONDecodeError, TypeError):
        return []
    # Metric functions iterate the result as a sequence of labels; reject
    # scalars / objects / null rather than passing them through.
    return parsed if isinstance(parsed, list) else []


def chain_depth(seq):
    """Return the length of the longest unbroken stretch of EXPLOIT primitives in *seq*.

    Any primitive outside EXPLOIT resets the current stretch. An empty sequence,
    or one with no exploit primitive at all, yields 0.
    """
    longest = 0
    current = 0
    for primitive in seq:
        current = current + 1 if primitive in EXPLOIT else 0
        if current > longest:
            longest = current
    return longest


def productivity_ratio(seq):
    """Return the fraction of recovery attempts in *seq* that succeed.

    The sequence is collapsed into maximal runs labelled "exploit" (EXPLOIT
    membership is checked first), "explore" (EXPLORE) or "other". A recovery
    attempt is any three consecutive runs shaped exploit -> explore -> exploit;
    it is successful when the closing exploit run holds at least two
    primitives. Sequences shorter than three primitives, or with no attempt,
    give NaN.
    """
    if len(seq) < 3:
        return np.nan

    # Run-length encode the sequence as [label, run_length] pairs.
    runs = []
    for primitive in seq:
        if primitive in EXPLOIT:
            label = "exploit"
        elif primitive in EXPLORE:
            label = "explore"
        else:
            label = "other"
        if runs and runs[-1][0] == label:
            runs[-1][1] += 1
        else:
            runs.append([label, 1])

    attempts = 0
    successes = 0
    # Slide over every window of three adjacent runs.
    for (head, _), (middle, _), (tail, tail_len) in zip(runs, runs[1:], runs[2:]):
        if head == "exploit" and middle == "explore" and tail == "exploit":
            attempts += 1
            if tail_len >= 2:
                successes += 1

    return np.nan if attempts == 0 else successes / attempts


def cc_loop_count(seq):
    """Return how many times a COMPUTE primitive is immediately followed by CHECK.

    Computed from the parsed primitive sequence rather than the trace-level
    ``COMPUTE->VERIFY_per_1k`` parquet column, which is stale: the classifier
    emits CHECK, not VERIFY.
    """
    total = 0
    for prev, nxt in zip(seq, seq[1:]):
        if (prev, nxt) == ("COMPUTE", "CHECK"):
            total += 1
    return total


def per_trace_metrics(df):
    """Return a copy of *df* with per-trace ``chain_depth``, ``productivity`` and ``cc_loop`` columns.

    ``primitive_sequence`` is decoded once per trace with ``parse_seq`` and all
    three metrics are derived from that decoded sequence. ``cc_loop`` is the
    number of COMPUTE->CHECK transitions per 1000 ``total_tokens`` (the per_1k
    convention); traces with zero tokens get NaN instead of a division by zero.
    The input frame is left unmodified.
    """
    decoded = df["primitive_sequence"].apply(parse_seq)
    tokens = df["total_tokens"].replace(0, np.nan)
    return df.assign(
        chain_depth=decoded.apply(chain_depth),
        productivity=decoded.apply(productivity_ratio),
        cc_loop=decoded.apply(cc_loop_count) / tokens * 1000.0,
    )


def paired_test(df, metric, *, n_boot=10_000, seed=42):
    """Within-problem paired test on (solved_mean - failed_mean) per doc_id.

    For every problem (``doc_id``) with at least one solved and one failed
    trace whose ``metric`` is not NaN, compute
    d = mean(metric | solved) - mean(metric | failed). Across those problems,
    run a two-sided Wilcoxon signed-rank test on d and a percentile-bootstrap
    95% CI on mean(d).

    Parameters
    ----------
    df : pandas.DataFrame
        Trace-level frame with ``doc_id``, ``correct`` (truthy = solved; cast
        to bool, so 0/1 integer columns work as well as bool) and ``metric``.
    metric : str
        Column to compare between solved and failed traces.
    n_boot : int, keyword-only
        Number of bootstrap resamples (default 10_000).
    seed : int, keyword-only
        Bootstrap RNG seed, so CIs are reproducible (default 42).

    Returns
    -------
    dict
        Keys ``n_problems``, ``mean_diff``, ``ci_low``, ``ci_high``,
        ``p_value``, ``test``, ``median_n_solved``, ``median_n_failed``.
        With fewer than 3 contributing problems the CI and p-value are NaN and
        ``test`` is ``"n/a"``; trace-count medians are still reported whenever
        at least one problem contributes. If every d is ~0 the p-value is 1.0.
    """
    diffs = []
    n_solved_traces = []
    n_failed_traces = []
    for _, sub in df.groupby("doc_id"):
        # Positional boolean mask: indexing with a raw 0/1 column would select
        # columns (and ``~`` would be bitwise), so normalise to bool first.
        is_solved = sub["correct"].astype(bool).to_numpy()
        values = sub[metric]
        sol = values[is_solved].dropna()
        fail = values[~is_solved].dropna()
        if len(sol) == 0 or len(fail) == 0:
            continue
        diffs.append(sol.mean() - fail.mean())
        n_solved_traces.append(len(sol))
        n_failed_traces.append(len(fail))
    diffs = np.asarray(diffs, dtype=float)
    n_problems = len(diffs)

    median_n_solved = float(np.median(n_solved_traces)) if n_problems else np.nan
    median_n_failed = float(np.median(n_failed_traces)) if n_problems else np.nan

    if n_problems < 3:
        return {
            "n_problems": n_problems,
            "mean_diff": float(diffs.mean()) if n_problems else np.nan,
            "ci_low": np.nan,
            "ci_high": np.nan,
            "p_value": np.nan,
            "test": "n/a",
            "median_n_solved": median_n_solved,
            "median_n_failed": median_n_failed,
        }

    # Wilcoxon signed-rank; scipy cannot rank an all-zero difference vector.
    if np.allclose(diffs, 0):
        p = 1.0
    else:
        try:
            _, p = stats.wilcoxon(diffs, zero_method="wilcox", alternative="two-sided")
        except ValueError:
            p = np.nan

    # Percentile bootstrap 95% CI on the mean difference.
    rng = np.random.default_rng(seed)
    boot_means = rng.choice(diffs, size=(n_boot, n_problems), replace=True).mean(axis=1)
    ci_low, ci_high = np.percentile(boot_means, [2.5, 97.5])

    return {
        "n_problems": n_problems,
        "mean_diff": float(diffs.mean()),
        "ci_low": float(ci_low),
        "ci_high": float(ci_high),
        "p_value": float(p),
        "test": "wilcoxon-signed-rank",
        "median_n_solved": median_n_solved,
        "median_n_failed": median_n_failed,
    }


def _sig_stars(p, ns_label):
    """Return "***"/"**"/"*" for p < 0.001/0.01/0.05, else ``ns_label`` (NaN p also gets ``ns_label``)."""
    if p < 0.001:
        return "***"
    if p < 0.01:
        return "**"
    if p < 0.05:
        return "*"
    return ns_label


def _write_paired_csvs(rows):
    """Write one ``<metric>_paired.csv`` per metric into OUT_DIR."""
    columns = [
        "metric",
        "checkpoint",
        "mean_diff",
        "ci_low",
        "ci_high",
        "n_problems",
        "p_value",
        "median_n_solved",
        "median_n_failed",
        "test",
    ]
    for metric, rs in rows.items():
        path = OUT_DIR / f"{metric}_paired.csv"
        pd.DataFrame(rs)[columns].to_csv(path, index=False)
        print(f"wrote {path}")


def _plot_paired_summary(rows):
    """Draw the summary bar chart (one panel per metric) of mean paired differences with bootstrap CIs."""
    metric_label = {
        "cc_loop": "CC-loop\n(COMPUTE→CHECK per 1k tokens)",
        "chain_depth": "Chain depth\n(max consecutive exploit primitives)",
        "productivity": "Productivity ratio\n(successful recoveries / attempts)",
    }
    ckpt_order = [name for name, _ in CHECKPOINTS]
    palette = {"Base": "#888", "SFT": "#4caf50", "Vanilla GSPO": "#2196f3", "Novelty bonus": "#e91e63"}
    fig, axes = plt.subplots(1, len(rows), figsize=(13, 4), sharey=False, squeeze=False)
    for ax, metric in zip(axes[0], rows):
        rs = {r["checkpoint"]: r for r in rows[metric]}
        means = [rs[c]["mean_diff"] for c in ckpt_order]
        lows = [rs[c]["ci_low"] for c in ckpt_order]
        highs = [rs[c]["ci_high"] for c in ckpt_order]
        errs = [[m - lo for m, lo in zip(means, lows)], [hi - m for m, hi in zip(means, highs)]]
        colors = [palette[c] for c in ckpt_order]
        ax.bar(range(len(ckpt_order)), means, yerr=errs, color=colors, capsize=4)
        ax.axhline(0, color="black", lw=0.8)
        ax.set_xticks(range(len(ckpt_order)))
        ax.set_xticklabels(ckpt_order, rotation=20, ha="right", fontsize=9)
        ax.set_title(metric_label[metric], fontsize=10)
        ax.set_ylabel("Mean(solved) − Mean(failed) per problem")

        # Label offset: 2% of the tallest finite, positive CI top, else 0.01.
        # Built-in max() over a list containing NaN is order-dependent and can
        # return NaN, which used to push every label to a NaN position.
        finite_highs = [h for h in highs if np.isfinite(h)]
        ymax = max(finite_highs, default=0.0)
        offset = 0.02 * ymax if ymax > 0 else 0.01
        for i, c in enumerate(ckpt_order):
            r = rs[c]
            p = r["p_value"]
            n = r["n_problems"]
            # NaN p means "not tested" (< 3 problems), not "not significant".
            sig = "n/a" if np.isnan(p) else _sig_stars(p, "n.s.")
            # CI is NaN for untested rows; anchor on the bar top (or 0) instead.
            if np.isfinite(highs[i]):
                anchor = highs[i]
            elif np.isfinite(means[i]):
                anchor = means[i]
            else:
                anchor = 0.0
            ax.text(i, anchor + offset, f"{sig}\nn={n}", ha="center", va="bottom", fontsize=8)
    fig.suptitle("Within-problem paired analysis (summary): per-problem (solved − failed) means",
                 fontsize=11)
    fig.tight_layout()
    fig.savefig(FIG_DIR / "fig_within_problem_paired.png", dpi=160, bbox_inches="tight")
    print(f"wrote {FIG_DIR / 'fig_within_problem_paired.png'}")
    plt.close(fig)


def main():
    """Run the paired test for every checkpoint x metric, then write the CSVs and the summary figure.

    For each checkpoint in CHECKPOINTS: load its trace parquet, add per-trace
    metrics, print the solve rate and one result line per metric. Results are
    then written to OUT_DIR (CSV per metric) and FIG_DIR (summary bar chart).
    """
    metrics = ("cc_loop", "chain_depth", "productivity")
    rows = {metric: [] for metric in metrics}
    for ckpt_name, path in CHECKPOINTS:
        print(f"=== {ckpt_name} ===")
        df = per_trace_metrics(pd.read_parquet(path))
        n_total = len(df)
        n_correct = int(df["correct"].sum())
        # Guard against an empty parquet instead of raising ZeroDivisionError.
        pct = 100 * n_correct / n_total if n_total else float("nan")
        print(f"  traces: {n_total}, correct: {n_correct} ({pct:.1f}%)")
        for metric in metrics:
            r = paired_test(df, metric)
            r["checkpoint"] = ckpt_name
            r["metric"] = metric
            rows[metric].append(r)
            sig = _sig_stars(r["p_value"], "")  # NaN p -> "" (no stars)
            print(
                f"  {metric:12s} mean_diff={r['mean_diff']:+.3f}  "
                f"95%CI=[{r['ci_low']:+.3f}, {r['ci_high']:+.3f}]  "
                f"n_problems={r['n_problems']:3d}  p={r['p_value']:.4f} {sig}"
            )

    _write_paired_csvs(rows)
    # ---- Summary bar chart (kept for backward compatibility) ----
    _plot_paired_summary(rows)


def per_checkpoint_paired_plots():
    """One figure per checkpoint, 3 panels (one per metric).

    Each panel shows per-problem paired points: solved-mean vs failed-mean,
    connected by a line, plus the across-problem group means as a thick black
    line. Reveals magnitude, direction, and consistency of the within-problem
    signal at that checkpoint without bar-chart compression.

    The Δ and p-value in each panel title come from ``paired_test``, so they
    match the values ``main`` writes to the CSVs: same all-zero handling and
    the same >= 3-problem requirement (below it p is NaN, shown as "n/a").
    Figures are saved to FIG_DIR as ``fig_within_problem_paired_<slug>.png``.
    """
    metric_label = {
        "cc_loop": "CC-loop (COMPUTE→CHECK per 1k tokens)",
        "chain_depth": "Chain depth (max consecutive exploit)",
        "productivity": "Productivity ratio (success/attempt)",
    }
    palette = {"Base": "#888", "SFT": "#4caf50", "Vanilla GSPO": "#2196f3", "Novelty bonus": "#e91e63"}

    for ckpt_name, path in CHECKPOINTS:
        df = per_trace_metrics(pd.read_parquet(path))

        fig, axes = plt.subplots(1, 3, figsize=(15, 4.5))
        color = palette[ckpt_name]

        for ax, metric in zip(axes, ["cc_loop", "chain_depth", "productivity"]):
            xs_fail = []
            xs_sol = []
            for _, sub in df.groupby("doc_id"):
                # Positional mask so both bool and 0/1 `correct` columns work.
                is_solved = sub["correct"].astype(bool).to_numpy()
                values = sub[metric]
                sol = values[is_solved].dropna()
                fail = values[~is_solved].dropna()
                if len(sol) == 0 or len(fail) == 0:
                    continue
                xs_fail.append(fail.mean())
                xs_sol.append(sol.mean())

            if not xs_sol:
                ax.set_title(metric_label[metric] + "\n(no contributing problems)", fontsize=9)
                continue

            # paired connector lines
            for fmean, smean in zip(xs_fail, xs_sol):
                ax.plot([0, 1], [fmean, smean], color=color, alpha=0.35, lw=1.0)
            # endpoint dots
            ax.scatter([0] * len(xs_fail), xs_fail, color=color, s=30, alpha=0.65,
                       edgecolor="white", linewidth=0.6, zorder=3)
            ax.scatter([1] * len(xs_sol), xs_sol, color=color, s=30, alpha=0.65,
                       edgecolor="white", linewidth=0.6, zorder=3)
            # group means with thicker line
            ax.plot([0, 1], [np.mean(xs_fail), np.mean(xs_sol)], color="black",
                    lw=2.5, zorder=4, label="group mean")

            # Same statistic as the CSVs (previously a separate wilcoxon call
            # without the all-zero guard or the minimum-problem rule).
            result = paired_test(df, metric)
            mean_diff = result["mean_diff"]
            pval = result["p_value"]
            if np.isnan(pval):
                sig = "n/a"
            elif pval < 0.001:
                sig = "***"
            elif pval < 0.01:
                sig = "**"
            elif pval < 0.05:
                sig = "*"
            else:
                sig = "n.s."

            ax.set_xticks([0, 1])
            ax.set_xticklabels(["failed traces", "solved traces"])
            ax.set_xlim(-0.3, 1.3)
            ax.set_title(f"{metric_label[metric]}\nΔ={mean_diff:+.3f}, p={pval:.3f} {sig}, n={len(xs_sol)}",
                         fontsize=9)
            ax.grid(True, axis="y", alpha=0.3)

        slug = ckpt_name.lower().replace(" ", "_")
        fig.suptitle(f"Per-problem paired metrics: {ckpt_name}", fontsize=12, y=1.02)
        fig.tight_layout()
        out = FIG_DIR / f"fig_within_problem_paired_{slug}.png"
        fig.savefig(out, dpi=160, bbox_inches="tight")
        print(f"wrote {out}")
        plt.close(fig)


if __name__ == "__main__":
    # Summary tables + bar chart first, then one paired-points figure per checkpoint.
    for _step in (main, per_checkpoint_paired_plots):
        _step()
