"""Stage 4 + 5: evaluate trained classifier.

Modes:
    --mode in_dist     test on the held-out 15% trace-grouped split
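    --mode calibration test on the existing 250-span sample labelled by both R1-SC and V3-SC

in_dist reports per-class precision/recall/F1, the 9x9 confusion matrix, Cohen's kappa,
and a side-by-side comparison vs the heuristic baseline. calibration reports pairwise
accuracy / macro-F1 / kappa between classifier, R1-SC and V3-SC, plus the confusion
matrix of classifier predictions vs R1-SC.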
"""
from __future__ import annotations

import argparse
import json
from collections import Counter
from pathlib import Path

import joblib
import numpy as np
from sklearn.metrics import (
    classification_report,
    cohen_kappa_score,
    confusion_matrix,
    f1_score,
)

from analysis.exploration.llm_validation._client import (
    PRIMITIVES,
    to_new_taxonomy,
)
from analysis.exploration.llm_validation.classifier.prepare_dataset import (
    load_silver,
)


def predict(bundle: dict, rows: list[dict]) -> list[str]:
    """Classify ``rows`` with the bundled model and return label strings.

    The bundle's ``feature_pipeline`` turns rows into features, its
    ``classifier`` emits class indices, and ``label_order`` maps each index
    back to its label name.
    """
    pipeline, model, label_order = (
        bundle["feature_pipeline"],
        bundle["classifier"],
        bundle["label_order"],
    )
    class_ids = model.predict(pipeline.transform(rows))
    return [label_order[k] for k in class_ids]


def render_confusion_md(cm: np.ndarray, labels: list[str]) -> str:
    """Render a square confusion matrix as a Markdown table.

    Row ``i`` / column ``j`` of ``cm`` correspond to ``labels[i]`` (true) and
    ``labels[j]`` (predicted). Diagonal cells are wrapped in ``**bold**``.
    Every emitted line, including the last, ends with a newline.
    """
    n = len(labels)
    out_lines = [
        "| true ↓ \\ pred → | " + " | ".join(labels) + " |",
        "|" + "---|" * (n + 1),
    ]
    for row, name in enumerate(labels):
        cells = []
        for col in range(n):
            count = str(int(cm[row, col]))
            cells.append(f"**{count}**" if row == col else count)
        out_lines.append(f"| {name} | " + " | ".join(cells) + " |")
    return "".join(line + "\n" for line in out_lines)


def per_class_table(y_true: np.ndarray, y_pred: np.ndarray, labels: list[str]) -> str:
    """Build a Markdown table of per-class and averaged precision/recall/F1.

    ``y_true`` / ``y_pred`` hold integer class indices into ``labels``. Every
    label gets a row (all zeros if the report has no entry for it), followed
    by the macro average (F1 in bold) and the support-weighted average.
    """
    all_ids = list(range(len(labels)))
    report = classification_report(
        y_true,
        y_pred,
        labels=all_ids,
        target_names=labels,
        output_dict=True,
        zero_division=0,
    )
    blank = {"precision": 0, "recall": 0, "f1-score": 0, "support": 0}

    def _row(name, precision, recall, f1_cell, support) -> str:
        return f"| {name} | {precision:.3f} | {recall:.3f} | {f1_cell} | {int(support)} |"

    table = ["| Class | Precision | Recall | F1 | Support |", "|---|---|---|---|---|"]
    for name in labels:
        stats = report.get(name, blank)
        table.append(
            _row(name, stats["precision"], stats["recall"],
                 f"{stats['f1-score']:.3f}", stats["support"])
        )

    # Summary rows: only the macro-average F1 is emphasised.
    for title, key, bold in (
        ("**macro avg**", "macro avg", True),
        ("**weighted avg**", "weighted avg", False),
    ):
        avg = report.get(key, {})
        f1_text = f"{avg.get('f1-score', 0):.3f}"
        table.append(
            _row(title, avg.get("precision", 0), avg.get("recall", 0),
                 f"**{f1_text}**" if bold else f1_text, avg.get("support", 0))
        )
    return "\n".join(table)


def heuristic_baseline_on(rows: list[dict]) -> list[str]:
    """Translate each row's legacy 10-class ``heuristic_label`` to the 9-class taxonomy.

    Rows with no ``heuristic_label`` (missing key or ``None``) map to ``"OTHER"``;
    all others go through ``to_new_taxonomy``.
    """
    def _map_one(row: dict) -> str:
        legacy = row.get("heuristic_label")
        return "OTHER" if legacy is None else to_new_taxonomy(legacy)

    return [_map_one(row) for row in rows]


def evaluate_in_dist(
    bundle: dict, silver_path: Path, splits_path: Path, out_md: Path,
) -> None:
    rows = load_silver(silver_path)
    by_id = {r["span_id"]: r for r in rows}
    splits = json.load(open(splits_path))
    test_rows = [by_id[sid] for sid in splits["test_ids"] if sid in by_id]
    if not test_rows:
        print("No test rows found")
        return
    print(f"In-dist test: {len(test_rows)} rows")

    labels = bundle["label_order"]
    y_true_str = [r["llm_label"] for r in test_rows]
    y_pred_str = predict(bundle, test_rows)
    idx = {l: i for i, l in enumerate(labels)}
    y_true = np.array([idx[l] for l in y_true_str])
    y_pred = np.array([idx[l] for l in y_pred_str])

    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(labels))))
    macro_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
    weighted_f1 = f1_score(y_true, y_pred, average="weighted", zero_division=0)
    kappa = cohen_kappa_score(y_true, y_pred)

    # Heuristic baseline on the same test rows
    heur_str = heuristic_baseline_on(test_rows)
    heur_idx = np.array([idx.get(l, idx["OTHER"]) for l in heur_str])
    heur_macro = f1_score(y_true, heur_idx, average="macro", zero_division=0)
    heur_kappa = cohen_kappa_score(y_true, heur_idx)
    heur_acc = float((heur_idx == y_true).mean())

    md = []
    md.append("# In-distribution evaluation (classifier vs V3-SC)\n")
    md.append("## Summary\n")
    md.append(f"- Test rows: **{len(test_rows)}** (held-out 15% by trace)")
    md.append(f"- Classifier accuracy: **{(y_pred == y_true).mean()*100:.1f}%**")
    md.append(f"- Classifier **macro-F1: {macro_f1:.4f}**, weighted-F1: {weighted_f1:.4f}")
    md.append(f"- Classifier Cohen's kappa vs V3-SC: **{kappa:.4f}**")
    md.append("")
    md.append("## Heuristic baseline on the same test rows")
    md.append(f"- Heuristic accuracy vs V3-SC: {heur_acc*100:.1f}%")
    md.append(f"- Heuristic macro-F1 vs V3-SC: {heur_macro:.4f}")
    md.append(f"- Heuristic kappa vs V3-SC: {heur_kappa:.4f}")
    md.append(f"- **Lift over heuristic** (macro-F1): +{(macro_f1 - heur_macro)*100:.1f} pp")
    md.append("")
    md.append("## Per-class metrics (classifier)\n")
    md.append(per_class_table(y_true, y_pred, labels))
    md.append("")
    md.append("## Confusion matrix\n")
    md.append("Rows = V3-SC label (true), columns = classifier prediction. Diagonal in **bold**.\n")
    md.append(render_confusion_md(cm, labels))
    md.append("")

    out_md.parent.mkdir(parents=True, exist_ok=True)
    out_md.write_text("\n".join(md))
    print(f"Wrote {out_md}")
    print(f"  macro-F1={macro_f1:.4f}, weighted-F1={weighted_f1:.4f}, kappa={kappa:.4f}")
    print(f"  lift over heuristic: +{(macro_f1 - heur_macro)*100:.1f} pp macro-F1")


def evaluate_calibration(
    bundle: dict,
    spans_path: Path, r1_judgments_path: Path, v3_judgments_path: Path,
    out_md: Path,
) -> None:
    """Compare classifier predictions against R1-SC and V3-SC on the 250-span set.

    Both R1 and V3 used the v4 9-class taxonomy in the comparison runs.
    """
    spans = {json.loads(l)["span_id"]: json.loads(l) for l in open(spans_path)}
    r1 = {json.loads(l)["span_id"]: json.loads(l) for l in open(r1_judgments_path)}
    v3 = {json.loads(l)["span_id"]: json.loads(l) for l in open(v3_judgments_path)}

    # Restrict to spans where both judges have a real label in our 9-class set.
    common = []
    for sid, sp in spans.items():
        if sid in r1 and sid in v3 \
            and r1[sid]["llm_label"] in PRIMITIVES \
            and v3[sid]["llm_label"] in PRIMITIVES:
            # Need to add n_episodes_in_trace; fall back to 1 if missing
            if "n_episodes_in_trace" not in sp:
                sp["n_episodes_in_trace"] = 1  # synthetic; pos feature degrades
            common.append(sid)

    print(f"Calibration set: {len(common)} spans (both R1 and V3 have valid labels)")
    rows = [spans[sid] for sid in common]
    r1_labels = [r1[sid]["llm_label"] for sid in common]
    v3_labels = [v3[sid]["llm_label"] for sid in common]
    pred_labels = predict(bundle, rows)

    labels = bundle["label_order"]
    idx = {l: i for i, l in enumerate(labels)}
    y_r1 = np.array([idx[l] for l in r1_labels])
    y_v3 = np.array([idx[l] for l in v3_labels])
    y_pred = np.array([idx[l] for l in pred_labels])

    def _stats(name_pair, y_true, y_pred):
        acc = float((y_true == y_pred).mean())
        f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
        kappa = cohen_kappa_score(y_true, y_pred)
        return name_pair, acc, f1, kappa

    rows_out = []
    rows_out.append(_stats(f"classifier ↔ R1-SC", y_r1, y_pred))
    rows_out.append(_stats(f"classifier ↔ V3-SC", y_v3, y_pred))
    rows_out.append(_stats(f"V3-SC ↔ R1-SC (baseline)", y_r1, y_v3))

    md = []
    md.append("# Calibration evaluation (classifier vs R1-SC + V3-SC)\n")
    md.append(f"Calibration set: {len(common)} spans from `sampled_spans.jsonl`.\n")
    md.append("## Pairwise summary\n")
    md.append("| Comparison | Accuracy | Macro-F1 | Cohen's kappa |")
    md.append("|---|---|---|---|")
    for name, acc, f1, kappa in rows_out:
        md.append(f"| {name} | {acc*100:.1f}% | {f1:.4f} | {kappa:.4f} |")
    md.append("")
    md.append("## Interpretation")
    md.append(
        "- A healthy classifier reaches **classifier ↔ R1-SC ≈ V3-SC ↔ R1-SC** "
        "(78.4%, κ=0.687 in our 9-class run). Beating it suggests overfitting "
        "to V3-SC quirks; trailing it materially means the classifier is "
        "weaker than its teacher.\n"
        "- `classifier ↔ V3-SC` should be the highest of the three rows; this "
        "is the teacher and we expect close imitation."
    )
    md.append("")
    md.append("## Confusion matrix vs R1-SC")
    cm = confusion_matrix(y_r1, y_pred, labels=list(range(len(labels))))
    md.append(render_confusion_md(cm, labels))

    out_md.parent.mkdir(parents=True, exist_ok=True)
    out_md.write_text("\n".join(md))
    print(f"Wrote {out_md}")
    for name, acc, f1, kappa in rows_out:
        print(f"  {name}: acc={acc*100:.1f}%, macro-F1={f1:.4f}, kappa={kappa:.4f}")


def main():
    """CLI entry point: parse arguments, load the model bundle, dispatch on ``--mode``.

    ``in_dist`` requires ``--silver``; when ``--splits`` is omitted it defaults to
    the bundle path with a ``.splits.json`` suffix. ``calibration`` requires
    ``--spans``, ``--r1-judgments`` and ``--v3-judgments``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", required=True, choices=["in_dist", "calibration"])
    parser.add_argument("--bundle", required=True, type=Path)
    parser.add_argument("--silver", type=Path, help="for --mode in_dist")
    parser.add_argument("--splits", type=Path, help="for --mode in_dist (auto from bundle)")
    parser.add_argument("--spans", type=Path, help="for --mode calibration: sampled_spans.jsonl")
    parser.add_argument("--r1-judgments", type=Path, help="for --mode calibration")
    parser.add_argument("--v3-judgments", type=Path, help="for --mode calibration")
    parser.add_argument("--out", required=True, type=Path)
    opts = parser.parse_args()

    bundle = joblib.load(opts.bundle)

    if opts.mode == "in_dist":
        if not opts.silver:
            raise SystemExit("--silver is required for in_dist")
        splits_path = opts.splits or opts.bundle.with_suffix(".splits.json")
        evaluate_in_dist(bundle, opts.silver, splits_path, opts.out)
        return

    if opts.mode == "calibration":
        if not all((opts.spans, opts.r1_judgments, opts.v3_judgments)):
            raise SystemExit(
                "--spans, --r1-judgments, --v3-judgments required for calibration"
            )
        evaluate_calibration(
            bundle, opts.spans, opts.r1_judgments, opts.v3_judgments, opts.out,
        )


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not on import.
    main()
