"""Compare heuristic vs two LLM judges (and the two judges to each other).

Produces a markdown summary with pairwise agreement, kappa, per-primitive
F1, and a judge-vs-judge confusion matrix. The judge-to-judge agreement
is the key signal: if R1 and V3 disagree with the heuristic in the SAME
way, the disagreement is real signal rather than per-model noise.
"""
from __future__ import annotations

import argparse
import json
from pathlib import Path

import numpy as np

from analysis.exploration.llm_validation._client import (
    PRIMITIVES, to_new_taxonomy,
)
from analysis.exploration.llm_validation.agreement_report import (
    build_confusion,
    cohens_kappa,
    fmt_f1,
    fmt_pct,
    per_primitive_metrics,
    render_confusion_md,
)


def load_judgments(path: Path) -> dict[str, str]:
    """Return {span_id: llm_label} for the usable judgments in a JSONL file.

    Only records whose ``llm_label`` is a member of ``PRIMITIVES`` are kept,
    which drops failure markers such as PARSE_ERROR / API_FAILURE. If a
    span_id occurs more than once, the last kept record wins.

    Args:
        path: JSONL file with one judgment object per line. Each object must
            carry an ``llm_label`` key, and a ``span_id`` key when kept.

    Returns:
        Mapping from span_id to the judge's label.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``llm_label``, or a kept record lacks
            ``span_id``.
    """
    out: dict[str, str] = {}
    # Read as UTF-8 explicitly rather than relying on the platform's locale
    # encoding, which is what open() falls back to by default.
    with open(path, encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. extra newlines at the end of the
            # file) instead of crashing inside json.loads.
            if not line.strip():
                continue
            j = json.loads(line)
            label = j["llm_label"]
            if label in PRIMITIVES:
                out[j["span_id"]] = label
    return out


def load_heuristic(spans_path: Path, map_to_new: bool = True) -> dict[str, str]:
    """Load heuristic labels from sampled_spans.jsonl.

    If `map_to_new`, each label is passed through `to_new_taxonomy`, folding
    legacy 10-class labels into the 8-class taxonomy
    (DECOMPOSE -> PLAN, VERIFY -> CHECK, ERROR_DETECT -> CHECK).

    Args:
        spans_path: JSONL file with one span object per line; each object
            must carry ``span_id`` and ``heuristic_label`` keys.
        map_to_new: Whether to remap labels via `to_new_taxonomy`.

    Returns:
        Mapping from span_id to heuristic label. If a span_id occurs more
        than once, the last record wins.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``span_id`` or ``heuristic_label``.
    """
    out: dict[str, str] = {}
    # Read as UTF-8 explicitly rather than relying on the platform's locale
    # encoding, which is what open() falls back to by default.
    with open(spans_path, encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. extra newlines at the end of the
            # file) instead of crashing inside json.loads.
            if not line.strip():
                continue
            s = json.loads(line)
            label = s["heuristic_label"]
            if map_to_new:
                label = to_new_taxonomy(label)
            out[s["span_id"]] = label
    return out


def pairwise_rows(
    a: dict[str, str], b: dict[str, str]
) -> list[dict]:
    """Join two label maps on the span_ids they share.

    Walks `a` in its own order and pairs each label with the one `b` holds
    for the same span_id. Ids that are missing from `b` (or mapped to None
    there) are left out. Every row stores the `a` label under
    "heuristic_label" and the `b` label under "llm_label", which is the row
    shape build_confusion expects.
    """
    # Lazily pair each label from `a` with its counterpart in `b` (or None).
    candidates = ((label_a, b.get(key)) for key, label_a in a.items())
    return [
        {"heuristic_label": label_a, "llm_label": label_b}
        for label_a, label_b in candidates
        if label_b is not None
    ]


def summarize(name_a: str, name_b: str, rows: list[dict]) -> dict:
    """Bundle the agreement statistics for one pairwise comparison.

    Builds a confusion matrix from `rows` via `build_confusion` and returns a
    dict with the two names, the row count ("n"), the raw agreement (matrix
    trace divided by the number of rows, NaN when there are no rows), the
    value of `cohens_kappa` for the matrix, the matrix itself ("cm"), and
    the output of `per_primitive_metrics` ("metrics").
    """
    confusion = build_confusion(rows)
    total = len(rows)

    if total:
        raw_agreement = float(np.trace(confusion)) / total
    else:
        # No shared spans: agreement is undefined rather than zero.
        raw_agreement = float("nan")

    summary = {
        "name_a": name_a,
        "name_b": name_b,
        "n": total,
        "agreement": raw_agreement,
    }
    summary["kappa"] = cohens_kappa(confusion)
    summary["cm"] = confusion
    summary["metrics"] = per_primitive_metrics(confusion)
    return summary


def render_metrics_table(metrics: list[dict], a: str, b: str) -> str:
    """Render per-primitive metrics as a markdown table.

    The header names `a` and `b` in the two count columns and marks `a` as
    the row side of the precision column. Each entry of `metrics`
    contributes one table row built from its "label", "heuristic_count",
    "llm_count", "precision", "recall" and "f1" values. The returned string
    ends with a newline.
    """
    lines = [
        f"| Primitive | {a} count | {b} count | Precision (rows={a}) | Recall | F1 |\n",
        "|---|---|---|---|---|---|\n",
    ]
    for entry in metrics:
        label = entry["label"]
        count_a = entry["heuristic_count"]
        count_b = entry["llm_count"]
        precision = fmt_pct(entry["precision"])
        recall = fmt_pct(entry["recall"])
        f1 = fmt_f1(entry["f1"])
        lines.append(
            f"| {label} | {count_a} | {count_b} | {precision} | {recall} | {f1} |\n"
        )
    return "".join(lines)


def main() -> None:
    """CLI entry point: compare the heuristic with two LLM judges.

    Reads the heuristic labels (--spans) and the two judges' judgment files
    (--judgments-a / --judgments-b), summarizes the three pairwise
    comparisons (heuristic vs A, heuristic vs B, A vs B), writes a markdown
    report to --out (creating parent directories as needed), and prints the
    input sizes and a short pairwise summary to stdout.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--spans", required=True, type=Path)
    ap.add_argument("--judgments-a", required=True, type=Path,
                    help="First judge (e.g. deepseek-reasoner)")
    ap.add_argument("--judgments-b", required=True, type=Path,
                    help="Second judge (e.g. deepseek-chat)")
    ap.add_argument("--name-a", default="R1")
    ap.add_argument("--name-b", default="V3")
    ap.add_argument("--out", required=True, type=Path)
    args = ap.parse_args()

    heuristic = load_heuristic(args.spans)
    judge_a = load_judgments(args.judgments_a)
    judge_b = load_judgments(args.judgments_b)

    print(f"spans (heuristic): {len(heuristic)}")
    print(f"{args.name_a} judgments: {len(judge_a)}")
    print(f"{args.name_b} judgments: {len(judge_b)}")

    # Each comparison is joined on the span_ids both sides have in common.
    pairs = [
        ("heuristic", args.name_a, pairwise_rows(heuristic, judge_a)),
        ("heuristic", args.name_b, pairwise_rows(heuristic, judge_b)),
        (args.name_a, args.name_b, pairwise_rows(judge_a, judge_b)),
    ]
    summaries = [summarize(na, nb, rows) for na, nb, rows in pairs]

    # The report is assembled as a list of chunks joined with newlines below.
    md = []
    md.append(f"# Two-Judge Comparison: heuristic vs {args.name_a} vs {args.name_b}\n")

    md.append("## Pairwise summary\n")
    md.append("| Comparison | N | Agreement | Cohen's kappa |")
    md.append("|---|---|---|---|")
    for s in summaries:
        md.append(
            f"| {s['name_a']} ↔ {s['name_b']} | {s['n']} "
            f"| {fmt_pct(s['agreement'])} | {fmt_f1(s['kappa'])} |"
        )
    md.append("")

    md.append(
        "## Reading the table\n"
        "- Lines 1 & 2 (heuristic vs each judge) measure the heuristic against "
        "each LLM as ground truth.\n"
        f"- Line 3 ({args.name_a} ↔ {args.name_b}) measures the two LLMs against "
        f"each other. **A high {args.name_a}↔{args.name_b} agreement combined with "
        "low heuristic↔judge agreement is strong evidence the heuristic is wrong, "
        "not just that each LLM has its own noise.**\n"
    )

    # One detail section per comparison: metrics table, then confusion matrix.
    for s in summaries:
        md.append(f"## {s['name_a']} (rows) vs {s['name_b']} (cols) — N={s['n']}\n")
        md.append(f"- Agreement: **{fmt_pct(s['agreement'])}**, kappa: **{fmt_f1(s['kappa'])}**\n")
        md.append(render_metrics_table(s["metrics"], s["name_a"], s["name_b"]))
        md.append("")
        md.append(render_confusion_md(s["cm"]))
        md.append("")

    args.out.parent.mkdir(parents=True, exist_ok=True)
    # The report contains non-ASCII characters ("↔", "—"). Write it as UTF-8
    # explicitly: the default is the locale encoding, which may be unable to
    # encode them and would raise UnicodeEncodeError.
    args.out.write_text("\n".join(md), encoding="utf-8")
    print(f"\nWrote {args.out}")
    print()
    print("Pairwise summary:")
    for s in summaries:
        print(
            f"  {s['name_a']:>10} ↔ {s['name_b']:<10} "
            f"agreement={fmt_pct(s['agreement']):>6}  kappa={fmt_f1(s['kappa'])}"
        )

# Run the comparison only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
