"""Stage 6b/6c: mint V3-SC labels on math spans + score the trained classifier.

Workflow:
    1. Read sampled math spans (from sample_math_spans.py).
    2. Run judge_runner --model deepseek-chat --temperature 0.7
       --n-samples N (default 5) to mint V3-SC labels
       (resumable on span_id).
    3. Predict with the classifier and write a markdown report with overall,
       per-benchmark, and per-class F1 (plus Cohen's kappa, a confusion
       matrix, and the V3-SC label distribution).
"""
from __future__ import annotations

import argparse
import json
import subprocess
import sys
from collections import Counter
from pathlib import Path

import joblib
import numpy as np
from sklearn.metrics import (
    classification_report,
    cohen_kappa_score,
    confusion_matrix,
    f1_score,
)

from analysis.exploration.llm_validation._client import PRIMITIVES
from analysis.exploration.llm_validation.classifier.evaluate import (
    per_class_table,
    render_confusion_md,
)


def _read_jsonl(path: Path) -> list[dict]:
    """Return the records of a JSONL file, one dict per non-blank line.

    Each line is decoded once, and the file handle is closed by the ``with``
    block even if a line fails to parse.
    """
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def _run_v3sc(spans_path: Path, judgments_path: Path,
              workers: int, n_samples: int) -> None:
    """Mint V3-SC labels for ``spans_path`` into ``judgments_path``.

    Runs ``analysis.exploration.llm_validation.judge_runner`` in a subprocess
    with the current interpreter (deepseek-chat, temperature 0.7). The runner
    is described as resumable on span_id.

    Raises:
        subprocess.CalledProcessError: if the runner exits non-zero.
    """
    print("Running V3-SC on math spans...")
    cmd = [
        sys.executable, "-m", "analysis.exploration.llm_validation.judge_runner",
        "--in", str(spans_path), "--out", str(judgments_path),
        "--model", "deepseek-chat",
        "--temperature", "0.7",
        "--workers", str(workers),
        "--n-samples", str(n_samples),
    ]
    print("  cmd:", " ".join(cmd))
    subprocess.run(cmd, check=True)


def _join_labels(spans: list[dict], judgments: list[dict],
                 allowed) -> list[dict]:
    """Attach each span's V3-SC ``llm_label``, keeping only labels in ``allowed``.

    Spans and judgments are keyed on ``span_id``; for duplicate ids the last
    record wins. Spans are emitted in first-seen ``span_id`` order.

    Spans are skipped when:
      - they have no judgment;
      - the judgment has no ``llm_label``;
      - the label is not in ``allowed``.
    """
    judg_by_id = {j["span_id"]: j for j in judgments}
    spans_by_id = {sp["span_id"]: sp for sp in spans}
    rows = []
    for sid, sp in spans_by_id.items():
        # .get: a judgment record without llm_label is unusable, not fatal.
        label = judg_by_id.get(sid, {}).get("llm_label")
        if label is not None and label in allowed:
            rows.append({**sp, "llm_label": label})
    return rows


def _per_benchmark_lines(rows: list[dict], y_true: np.ndarray,
                         y_pred: np.ndarray) -> list[str]:
    """Return markdown table lines with macro-F1 and kappa per benchmark.

    Rows are grouped by ``(task_name, checkpoint_id)``, and groups are
    emitted in sorted key order.
    """
    groups: dict[tuple, list[int]] = {}
    for i, r in enumerate(rows):
        groups.setdefault((r["task_name"], r["checkpoint_id"]), []).append(i)
    lines = [
        "| Benchmark | Checkpoint | N | Macro-F1 | Kappa |",
        "|---|---|---|---|---|",
    ]
    for (bm, ckpt), indices in sorted(groups.items()):
        sub_true = y_true[indices]
        sub_pred = y_pred[indices]
        bf1 = f1_score(sub_true, sub_pred, average="macro", zero_division=0)
        # NOTE: kappa may be undefined (nan) for tiny / single-class groups.
        bk = cohen_kappa_score(sub_true, sub_pred)
        lines.append(f"| {bm} | {ckpt} | {len(indices)} | {bf1:.4f} | {bk:.4f} |")
    return lines


def _distribution_lines(rows: list[dict], labels) -> list[str]:
    """Return markdown table lines with the V3-SC label count/percent per class."""
    counts = Counter(r["llm_label"] for r in rows)
    total = sum(counts.values())
    lines = ["| Class | N | % |", "|---|---|---|"]
    for p in labels:
        c = counts.get(p, 0)
        pct = c / total * 100 if total else 0.0  # guard against an empty join
        lines.append(f"| {p} | {c} | {pct:.1f}% |")
    return lines


def main():
    """Mint V3-SC labels on math spans, score the classifier, and write a report.

    Steps:
      1. Run judge_runner in a subprocess.
      2. Join spans with their V3-SC labels on ``span_id``.
      3. Predict with the joblib bundle.
      4. Write a markdown report to ``--out``. It contains summary,
         per-benchmark, per-class, confusion-matrix and label-distribution
         sections.

    Raises:
        subprocess.CalledProcessError: if judge_runner fails.
        SystemExit: if no span ends up with a label the classifier can score.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--spans", required=True, type=Path,
                    help="silver_math_spans.jsonl from sample_math_spans.py")
    ap.add_argument("--judgments", required=True, type=Path,
                    help="V3-SC labels output (resumable)")
    ap.add_argument("--bundle", required=True, type=Path,
                    help="trained classifier joblib")
    ap.add_argument("--out", required=True, type=Path,
                    help="OOD eval markdown report")
    ap.add_argument(
        "--in-dist-macro-f1", type=float, default=None,
        help="(optional) in-dist macro-F1 from Stage 4, for OOD-gap reporting",
    )
    ap.add_argument("--workers", type=int, default=100)
    ap.add_argument("--n-samples", type=int, default=5)
    args = ap.parse_args()

    # ---- Run V3-SC labeling (subprocess; resumable) ----
    _run_v3sc(args.spans, args.judgments, args.workers, args.n_samples)

    # ---- Load and join ----
    rows = _join_labels(
        _read_jsonl(args.spans), _read_jsonl(args.judgments), PRIMITIVES
    )
    print(f"Joined {len(rows)} math spans with V3-SC labels.")

    # ---- Predict and score ----
    bundle = joblib.load(args.bundle)
    pipe = bundle["feature_pipeline"]
    clf = bundle["classifier"]
    labels = bundle["label_order"]
    idx = {lbl: i for i, lbl in enumerate(labels)}

    # A label can pass the PRIMITIVES filter yet be absent from the bundle's
    # label_order. It cannot be scored, so drop it loudly instead of
    # crashing with a KeyError.
    unknown = Counter(r["llm_label"] for r in rows if r["llm_label"] not in idx)
    if unknown:
        print(
            f"WARNING: dropping {sum(unknown.values())} spans whose V3-SC label "
            f"is not in the classifier's label_order: {dict(unknown)}",
            file=sys.stderr,
        )
        rows = [r for r in rows if r["llm_label"] in idx]
    if not rows:
        raise SystemExit(
            "No math spans with a usable V3-SC label; nothing to score."
        )

    X = pipe.transform(rows)
    # NOTE(review): predictions are compared against label_order indices, so
    # the classifier presumably emits integer-encoded labels — confirm.
    y_pred = np.asarray(clf.predict(X))
    y_true = np.array([idx[r["llm_label"]] for r in rows])

    macro_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
    weighted_f1 = f1_score(y_true, y_pred, average="weighted", zero_division=0)
    kappa = cohen_kappa_score(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(labels))))

    md = []
    md.append("# OOD math evaluation (classifier vs V3-SC on AIME/Olympiad)\n")
    md.append("## Summary\n")
    md.append(f"- Math spans evaluated: **{len(rows)}**")
    md.append(f"- **Macro-F1: {macro_f1:.4f}**, weighted-F1: {weighted_f1:.4f}")
    md.append(f"- Cohen's kappa vs V3-SC: **{kappa:.4f}**")
    if args.in_dist_macro_f1 is not None:
        delta = macro_f1 - args.in_dist_macro_f1
        md.append(
            f"- vs in-dist macro-F1 ({args.in_dist_macro_f1:.4f}): "
            f"**Δ = {delta*100:+.1f} pp**"
        )
    md.append("")

    # ---- Per-benchmark breakdown ----
    md.append("## Per-benchmark macro-F1\n")
    md.extend(_per_benchmark_lines(rows, y_true, y_pred))
    md.append("")

    # ---- Per-class F1 ----
    md.append("## Per-class metrics\n")
    md.append(per_class_table(y_true, y_pred, labels))
    md.append("")

    md.append("## Confusion matrix\n")
    md.append(render_confusion_md(cm, labels))
    md.append("")

    # ---- Per-class V3-SC distribution ----
    md.append("## V3-SC label distribution on math spans\n")
    md.extend(_distribution_lines(rows, labels))
    md.append("")

    args.out.parent.mkdir(parents=True, exist_ok=True)
    # The report contains non-ASCII text ("Δ"). Pin UTF-8 instead of relying
    # on the locale's default encoding.
    args.out.write_text("\n".join(md), encoding="utf-8")
    print(f"Wrote {args.out}")
    print(f"  macro-F1={macro_f1:.4f}, kappa={kappa:.4f}")


if __name__ == "__main__":
    main()
