"""Stage 3 — agreement report between heuristic and LLM judge.

Joins sampled_spans.jsonl + llm_judgments.jsonl on span_id, computes
overall agreement, per-primitive precision/recall/F1 (LLM = ground truth),
the primitive-by-primitive confusion matrix (one row/column per entry in
PRIMITIVES), Cohen's kappa, and a qualitative section sampling
disagreements from each off-diagonal cell that reaches a minimum count.
"""
from __future__ import annotations

import argparse
import json
import random
from collections import Counter, defaultdict
from pathlib import Path

import numpy as np

from analysis.exploration.llm_validation._client import PRIMITIVES


def load_joined(spans_path: Path, judgments_path: Path) -> list[dict]:
    """Join span records with LLM judgments on ``span_id``.

    Both inputs are JSON Lines files (one JSON object per line); blank lines
    are skipped. If several spans share a ``span_id`` the last one wins.
    Each judgment is merged onto its span, with judgment keys overriding span
    keys on collision. Judgments whose ``span_id`` has no matching span are
    dropped.

    Args:
        spans_path: JSONL file of span records, each with a ``span_id`` key.
        judgments_path: JSONL file of judgment records, each with a
            ``span_id`` key.

    Returns:
        One merged dict per matched judgment, in judgments-file order.
    """
    spans: dict = {}
    # `with` guarantees the handle is closed; each line is parsed only once.
    with open(spans_path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            span = json.loads(line)
            spans[span["span_id"]] = span

    rows: list[dict] = []
    with open(judgments_path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            judgment = json.loads(line)
            span = spans.get(judgment["span_id"])
            if span is None:
                continue
            rows.append({**span, **judgment})
    return rows


def build_confusion(rows: list[dict]) -> np.ndarray:
    """Tally (heuristic, LLM) label pairs into a square integer count matrix.

    Row index = heuristic label, column index = LLM label, both in
    ``PRIMITIVES`` order. A record whose heuristic or LLM label is not one of
    ``PRIMITIVES`` (e.g. PARSE_ERROR / API_FAILURE) is not counted.
    """
    position = dict(zip(PRIMITIVES, range(len(PRIMITIVES))))
    size = len(PRIMITIVES)
    matrix = np.zeros((size, size), dtype=int)
    for record in rows:
        row_pos = position.get(record["heuristic_label"])
        col_pos = position.get(record["llm_label"])
        if row_pos is None or col_pos is None:
            continue
        matrix[row_pos, col_pos] += 1
    return matrix


def per_primitive_metrics(
    cm: np.ndarray,
    labels: list[str] | tuple[str, ...] | None = None,
) -> list[dict]:
    """Compute per-label precision, recall and F1, treating the LLM as ground truth.

    Args:
        cm: Square count matrix with rows = heuristic label and columns =
            LLM label, in the same order as ``labels`` (the layout produced
            by ``build_confusion``).
        labels: Label names in matrix order. Defaults to ``PRIMITIVES``.

    Returns:
        One dict per label with keys ``label``, ``heuristic_count`` (row
        total), ``llm_count`` (column total), ``tp``, ``fp``, ``fn``,
        ``precision``, ``recall`` and ``f1``. Precision is NaN when the
        heuristic never predicted the label, recall is NaN when the LLM never
        assigned it, and F1 is NaN only when the label appears on neither side.
    """
    if labels is None:
        labels = PRIMITIVES
    out = []
    for i, label in enumerate(labels):
        tp = int(cm[i, i])
        heuristic_count = int(cm[i, :].sum())
        llm_count = int(cm[:, i].sum())
        fp = heuristic_count - tp  # heuristic said `label`, LLM said something else
        fn = llm_count - tp  # LLM said `label`, heuristic said something else
        precision = tp / heuristic_count if heuristic_count else float("nan")
        recall = tp / llm_count if llm_count else float("nan")
        # 2tp / (2tp + fp + fn) equals the harmonic mean of precision and
        # recall whenever both are defined and non-zero. It also stays defined
        # (0.0) when heuristic and LLM never agree on the label, so a total
        # miss reports as 0.000 instead of being hidden as "undefined".
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else float("nan")
        out.append({
            "label": label,
            "heuristic_count": heuristic_count,
            "llm_count": llm_count,
            "tp": tp, "fp": fp, "fn": fn,
            "precision": precision, "recall": recall, "f1": f1,
        })
    return out


def cohens_kappa(cm: np.ndarray) -> float:
    """Cohen's kappa for two raters summarised in a square confusion matrix.

    kappa = (p_o - p_e) / (1 - p_e), where p_o is the observed agreement
    (diagonal mass) and p_e is the chance agreement implied by the row and
    column marginals. Returns NaN for an empty matrix or when p_e is exactly 1.
    """
    total = cm.sum()
    if total == 0:
        return float("nan")
    observed = np.trace(cm) / total
    col_marginals = cm.sum(axis=0)
    row_marginals = cm.sum(axis=1)
    chance = (col_marginals * row_marginals).sum() / (total * total)
    return float("nan") if chance == 1.0 else float((observed - chance) / (1 - chance))


def fmt_pct(x: float) -> str:
    """Render a fraction as a one-decimal percentage (0.5 -> "50.0%").

    A NaN float (including numpy float64 NaN) renders as an em dash.
    """
    missing = isinstance(x, float) and bool(np.isnan(x))
    return "—" if missing else f"{x * 100:.1f}%"


def fmt_f1(x: float) -> str:
    """Render a score with three decimals (0.4 -> "0.400"); NaN floats become an em dash."""
    if not (isinstance(x, float) and np.isnan(x)):
        return f"{x:.3f}"
    return "—"


def render_confusion_md(cm: np.ndarray) -> str:
    """Render the confusion matrix as a markdown table.

    Rows are heuristic labels and columns are LLM labels, both in
    ``PRIMITIVES`` order; diagonal (agreement) counts are bolded. Every
    table line, including the last, ends with a newline.
    """
    n_labels = len(PRIMITIVES)
    table_lines = [
        "| heuristic ↓ \\ LLM → | " + " | ".join(PRIMITIVES) + " |",
        "|" + "---|" * (n_labels + 1),
    ]
    for row, label in enumerate(PRIMITIVES):
        counts = [
            f"**{int(cm[row, col])}**" if row == col else str(int(cm[row, col]))
            for col in range(n_labels)
        ]
        table_lines.append(f"| {label} | " + " | ".join(counts) + " |")
    return "".join(line + "\n" for line in table_lines)


def render_metrics_md(metrics: list[dict]) -> str:
    """Render per-primitive metrics (as returned by ``per_primitive_metrics``) as a markdown table.

    Precision and recall are shown as percentages and F1 with three
    decimals; undefined (NaN) values appear as an em dash.
    """
    table = [
        "| Primitive | Heuristic count | LLM count | Precision | Recall | F1 |\n",
        "|---|---|---|---|---|---|\n",
    ]
    for entry in metrics:
        cells = (
            entry["label"],
            entry["heuristic_count"],
            entry["llm_count"],
            fmt_pct(entry["precision"]),
            fmt_pct(entry["recall"]),
            fmt_f1(entry["f1"]),
        )
        table.append("| " + " | ".join(f"{cell}" for cell in cells) + " |\n")
    return "".join(table)


def _one_line(value: object) -> str:
    """Return ``value`` as one line: falsy -> "", every whitespace run -> one space."""
    return " ".join(str(value or "").split())


def _fmt_conf(value: object) -> str:
    """Format a confidence score to 3 decimals; None -> "?", non-numbers shown verbatim."""
    if value is None:
        return "?"
    if isinstance(value, (int, float)):
        return f"{value:.3f}"
    return str(value)


def render_disagreements(
    rows: list[dict], cm: np.ndarray,
    samples_per_cell: int = 2, min_count: int = 3,
    seed: int = 17,
) -> str:
    """Render sampled examples from each frequent off-diagonal confusion cell.

    Rows are grouped by their (heuristic_label, llm_label) pair; agreements
    and pairs involving a label outside ``PRIMITIVES`` are skipped. Cells with
    at least ``min_count`` rows are emitted largest-first (ties keep
    first-seen order). Each cell contributes up to ``samples_per_cell`` rows
    drawn with a ``random.Random(seed)``, so output is reproducible for a
    fixed input order.

    Args:
        rows: Joined span/judgment records with ``heuristic_label`` and
            ``llm_label`` keys; ``span_text``, ``heuristic_confidence``,
            ``llm_confidence`` and ``llm_reasoning`` are optional.
        cm: Unused; accepted for call-site compatibility.
        samples_per_cell: Maximum examples shown per cell.
        min_count: Minimum cell size to be reported.
        seed: Seed for the sampling RNG.

    Returns:
        Markdown text, or a placeholder sentence when no cell qualifies.
    """
    rng = random.Random(seed)
    by_cell: dict[tuple[str, str], list[dict]] = defaultdict(list)
    for r in rows:
        heur, llm = r["heuristic_label"], r["llm_label"]
        if heur == llm or heur not in PRIMITIVES or llm not in PRIMITIVES:
            continue
        by_cell[(heur, llm)].append(r)

    lines = []
    for (heur, llm), items in sorted(by_cell.items(), key=lambda kv: -len(kv[1])):
        if len(items) < min_count:
            continue
        lines.append(f"### {heur} → {llm}  (N={len(items)})\n")
        k = max(0, min(samples_per_cell, len(items)))
        for r in rng.sample(items, k):
            # Collapse newlines so multi-line text cannot break the markdown list.
            text = _one_line(r.get("span_text"))
            if len(text) > 300:
                text = text[:300] + "…"
            llm_reason = _one_line(r.get("llm_reasoning"))
            lines.append(
                f"- **span**: {text}\n"
                f"  - heuristic: `{heur}` (conf={_fmt_conf(r.get('heuristic_confidence'))}) "
                f"| llm: `{llm}` (conf={r.get('llm_confidence', '?')})\n"
                f"  - llm reasoning: {llm_reason}\n"
            )
        lines.append("")
    return "\n".join(lines) if lines else "_No off-diagonal cells reached the count threshold._"


def main():
    """Command-line entry point: build the agreement report and write it as markdown.

    Joins the spans and judgments files with ``load_joined``, scores the rows
    whose LLM label is a known primitive, and writes the report to ``--out``,
    creating parent directories as needed. Also prints the headline numbers.
    """
    ap = argparse.ArgumentParser(
        description="Agreement report between heuristic labels and an LLM judge."
    )
    ap.add_argument("--spans", required=True, type=Path,
                    help="JSONL of sampled spans, keyed by span_id")
    ap.add_argument("--judgments", required=True, type=Path,
                    help="JSONL of LLM judgments, keyed by span_id")
    ap.add_argument("--out", required=True, type=Path,
                    help="markdown report path to write")
    ap.add_argument("--sample-seed", type=int, default=17,
                    help="seed for sampling disagreement examples")
    args = ap.parse_args()

    rows = load_joined(args.spans, args.judgments)
    n_total = len(rows)
    llm_label_counts = Counter(r["llm_label"] for r in rows)
    n_parse = llm_label_counts["PARSE_ERROR"]
    n_api = llm_label_counts["API_FAILURE"]
    good = [r for r in rows if r["llm_label"] in PRIMITIVES]

    cm = build_confusion(good)
    metrics = per_primitive_metrics(cm)
    kappa = cohens_kappa(cm)

    n_judged = len(good)
    # build_confusion also drops rows whose *heuristic* label is not a
    # primitive, so divide by the matrix total rather than len(good).
    # This keeps agreement consistent with kappa and the tables.
    n_scored = int(cm.sum())
    agreement = float(np.trace(cm)) / n_scored if n_scored else float("nan")
    # Rows neither scored nor counted as parse/API failures: unknown LLM
    # labels, or unknown heuristic labels.
    n_unrecognised = n_total - n_parse - n_api - n_scored

    md = []
    md.append("# LLM-Judge Validation of Heuristic Primitive Classifier\n")
    md.append("## Summary\n")
    md.append(f"- Spans with a judgment record: **{n_total}**")
    md.append(f"- Spans judged successfully: **{n_judged}**")
    md.append(f"- Excluded (parse errors): **{n_parse}**")
    md.append(f"- Excluded (API failures): **{n_api}**")
    if n_unrecognised:
        md.append(f"- Excluded (unrecognised labels): **{n_unrecognised}**")
    md.append(f"- **Overall agreement: {fmt_pct(agreement)}**")
    md.append(f"- **Cohen's kappa: {fmt_f1(kappa)}**\n")

    md.append("## Per-primitive precision and recall (LLM as ground truth)\n")
    md.append(render_metrics_md(metrics))
    md.append("")

    md.append("## Confusion matrix")
    md.append("Rows = heuristic label, columns = LLM label. Diagonal in **bold**.\n")
    md.append(render_confusion_md(cm))
    md.append("")

    md.append("## Systematic misclassifications")
    md.append(
        "_Off-diagonal cells with count ≥3, "
        f"up to 2 sampled disagreements per cell (seed={args.sample_seed})._\n"
    )
    md.append(render_disagreements(good, cm, seed=args.sample_seed))
    md.append("")

    md.append("## Notes\n")
    md.append(
        "- Per-primitive F1 is the primary signal. Cohen's kappa is reported "
        "but moderately affected by class imbalance.\n"
        "- The judge prompt uses conceptual definitions (function-based), not "
        "trigger phrases drawn from the heuristic's regex (e.g. 'Wait', "
        "'Case 1...', 'Here is my plan'). Some shared vocabulary remains "
        "(label names like COMPUTE/VERIFY, conceptual words like "
        "'satisfies', 'contradiction') because removing them would make "
        "the definitions unintelligible.\n"
        "- Heuristic confidence is shown alongside disagreements to help spot "
        "low-confidence borderline cases that the LLM relabeled.\n"
    )

    args.out.parent.mkdir(parents=True, exist_ok=True)
    # The report contains non-ASCII (—, ↓, →, ≥); don't depend on the locale encoding.
    args.out.write_text("\n".join(md), encoding="utf-8")
    print(f"Wrote {args.out}")
    print(f"Overall agreement: {fmt_pct(agreement)}, kappa: {fmt_f1(kappa)}")


if __name__ == "__main__":
    # Build the report only when run as a script, not when imported.
    main()
