#!/usr/bin/env python3
"""
Compare MT-bench single-answer grading results between two models,
e.g. original vs defended-rewrite.

Outputs:
- Overall mean / per-turn mean
- Per-category mean
- Mean rewrite similarity (from the rewritten answer file): overall / per-turn / per-category
- Deltas (after - before) for the overall, per-turn, and per-category means
"""

from __future__ import annotations

import argparse
import json
import math
from collections import defaultdict
from pathlib import Path
from statistics import mean
from typing import Any, DefaultDict, Dict, Iterable, List, Tuple


def _read_jsonl(path: Path) -> Iterable[Dict[str, Any]]:
    """Lazily yield one decoded JSON value per non-blank line of *path*.

    Each line is stripped of surrounding whitespace before decoding; lines that
    are empty after stripping are skipped. The file is read as UTF-8 and kept
    open only while the generator is being consumed. ``json.loads`` errors
    propagate to the caller unchanged.
    """
    with path.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            payload = raw_line.strip()
            if payload:
                yield json.loads(payload)


def _load_question_category_map(question_file: Path) -> Dict[int, str]:
    """Build a ``question_id -> category`` lookup from a question JSONL file.

    Ids are coerced with ``int`` and categories with ``str``; a record without a
    ``"category"`` key is labelled ``"unknown"``. If an id repeats, the last
    record wins.
    """
    return {
        int(question["question_id"]): str(question.get("category", "unknown"))
        for question in _read_jsonl(question_file)
    }


def _collect_scores(
    judge_file: Path,
    model: str,
    qcat: Dict[int, str],
    *,
    skip_invalid: bool = True,
) -> Tuple[List[float], Dict[int, List[float]], Dict[str, List[float]]]:
    """Gather judge scores for ``model`` from one single-answer judgment JSONL file.

    Records whose ``"model"`` differs from ``model`` are ignored, as are records
    whose ``"score"`` is missing or cannot be converted to ``float``.

    Args:
        judge_file: Judgment JSONL file to read.
        model: Model id whose records are kept.
        qcat: ``question_id -> category`` map; ids not in it are bucketed as
            ``"unknown"``.
        skip_invalid: When true (default), also drop the ``-1`` sentinel score
            and non-finite scores (NaN/inf). The MT-bench llm_judge pipeline
            records ``-1`` when the judge's rating could not be parsed; averaging
            it in would bias every mean downward.

    Returns:
        ``(all_scores, by_turn, by_category)``: a flat list of kept scores plus
        the same scores grouped by turn (``"turn"`` field, default 0) and by
        question category. Both groupings are plain dicts.

    Raises:
        KeyError: a kept record has no ``"question_id"``.
        ValueError: ``"question_id"`` or ``"turn"`` is not integer-like.
    """
    invalid_sentinel = -1.0

    all_scores: List[float] = []
    by_turn: DefaultDict[int, List[float]] = defaultdict(list)
    by_cat: DefaultDict[str, List[float]] = defaultdict(list)

    for r in _read_jsonl(judge_file):
        if r.get("model") != model:
            continue
        raw_score = r.get("score")
        if raw_score is None:
            continue
        try:
            s = float(raw_score)
        except (TypeError, ValueError, OverflowError):
            # Best-effort: an unparseable score skips the record rather than
            # aborting the whole report.
            continue
        if skip_invalid and (s == invalid_sentinel or not math.isfinite(s)):
            continue

        qid = int(r["question_id"])
        turn = int(r.get("turn", 0))
        cat = qcat.get(qid, "unknown")

        all_scores.append(s)
        by_turn[turn].append(s)
        by_cat[cat].append(s)

    return all_scores, dict(by_turn), dict(by_cat)


def _collect_scores_from_multiple_files(
    judge_files: List[Path], model: str, qcat: Dict[int, str]
) -> Tuple[List[float], Dict[int, List[float]], Dict[str, List[float]]]:
    """Pool judge scores for ``model`` across several judgment JSONL files.

    Each existing file is read with ``_collect_scores``. Per-record filtering
    therefore lives in one place and cannot drift between the single-file and
    multi-file collectors. Results are concatenated in the order of
    ``judge_files``.

    Files that do not exist are skipped silently. This is a deliberate
    best-effort behaviour, so callers should check for an empty result.

    Returns:
        ``(all_scores, by_turn, by_category)`` with the same shape as
        ``_collect_scores``; both groupings are plain dicts.
    """
    all_scores: List[float] = []
    by_turn: DefaultDict[int, List[float]] = defaultdict(list)
    by_cat: DefaultDict[str, List[float]] = defaultdict(list)

    for jf in judge_files:
        if not jf.exists():
            continue
        file_all, file_turn, file_cat = _collect_scores(jf, model, qcat)
        all_scores.extend(file_all)
        for turn, scores in file_turn.items():
            by_turn[turn].extend(scores)
        for cat, scores in file_cat.items():
            by_cat[cat].extend(scores)

    return all_scores, dict(by_turn), dict(by_cat)


def _fmt(x: float) -> str:
    """Render *x* in fixed-point notation with exactly four decimal places."""
    return format(x, ".4f")


def _collect_rewrite_similarity_from_after_answer_file(
    after_answer_file: Path, qcat: Dict[int, str]
) -> Tuple[List[float], Dict[int, List[float]], Dict[str, List[float]]]:
    """
    Collect rewrite similarity directly from the rewritten answer file produced by
    dj_defense.py.

    Expected schema (per record):
      - choices[0]["rewrite_similarity"] = [float|None, float|None, ...]  (per turn)

    Entry ``i`` of that list is attributed to turn ``i + 1``. The following are
    skipped: records without a usable ``choices[0]`` dict or without a list-valued
    ``rewrite_similarity``, ``None`` entries, entries that cannot be converted
    to ``float``, and non-finite values (NaN/inf). A single NaN would otherwise
    turn every mean it is averaged into NaN.

    Args:
        after_answer_file: Rewritten-answer JSONL file to read.
        qcat: ``question_id -> category`` map; unknown ids go to ``"unknown"``.

    Returns:
        ``(sim_overall, sim_by_turn, sim_by_category)``; groupings are plain dicts.

    Raises:
        KeyError: a record has no ``"question_id"``.
    """
    sim_overall: List[float] = []
    sim_by_turn: DefaultDict[int, List[float]] = defaultdict(list)
    sim_by_cat: DefaultDict[str, List[float]] = defaultdict(list)

    for r in _read_jsonl(after_answer_file):
        qid = int(r["question_id"])
        cat = qcat.get(qid, "unknown")
        choices = r.get("choices") or []
        # Guard against malformed records instead of crashing on indexing/.get().
        if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
            continue
        sim_list = choices[0].get("rewrite_similarity")
        if not isinstance(sim_list, list):
            continue
        for turn, raw in enumerate(sim_list, start=1):
            if raw is None:
                continue
            try:
                v = float(raw)
            except (TypeError, ValueError, OverflowError):
                continue
            if not math.isfinite(v):
                continue
            sim_overall.append(v)
            sim_by_turn[turn].append(v)
            sim_by_cat[cat].append(v)

    return sim_overall, dict(sim_by_turn), dict(sim_by_cat)


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser (flag names, defaults and help text are the CLI contract)."""
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--bench-name",
        default="mt_bench",
        help="Bench name (default: mt_bench) used to find data/<bench>/question.jsonl",
    )
    ap.add_argument(
        "--judge-file",
        default="data/mt_bench/model_judgment/gpt-4_single.jsonl",
        help="Path to judgment jsonl (default: data/mt_bench/model_judgment/gpt-4_single.jsonl)",
    )
    ap.add_argument(
        "--judge-file-before",
        default="",
        help="Optional: separate judgment file for BEFORE model (if provided, overrides --judge-file for before)",
    )
    ap.add_argument(
        "--judge-file-after",
        default="",
        help="Optional: separate judgment file for AFTER model (if provided, overrides --judge-file for after)",
    )
    ap.add_argument(
        "--before",
        default="gpt-4",
        help="Model id before rewrite (default: gpt-4)",
    )
    ap.add_argument(
        "--after",
        default="gpt-4-defended-increase",
        help="Model id after rewrite (default: gpt-4-defended-increase)",
    )
    # NOTE: accepted for CLI compatibility; not currently read by main().
    ap.add_argument(
        "--before-answer-file",
        default="",
        help="Override before model_answer jsonl (default: data/<bench>/model_answer/<before>.jsonl)",
    )
    ap.add_argument(
        "--after-answer-file",
        default="",
        help="Override after model_answer jsonl (default: data/<bench>/model_answer/<after>.jsonl)",
    )
    return ap


def _resolve_under_base(base: Path, raw: str) -> Path:
    """Return *raw* as a Path, anchoring relative paths at *base*."""
    p = Path(raw)
    return p if p.is_absolute() else base / p


def _resolve_answer_file(base: Path, raw: str) -> Path:
    """Resolve a user-supplied answer-file path.

    Absolute paths, and relative paths that exist relative to the current
    working directory, are used as given. Otherwise the path is anchored at
    *base*, the same directory the judge-file flags resolve against.
    """
    p = Path(raw)
    if p.is_absolute() or p.exists():
        return p
    return base / p


def main() -> int:
    """Print a before/after comparison of MT-bench single-answer judge scores.

    Question categories come from ``<script dir>/data/<bench>/question.jsonl``.
    Scores come from the combined ``--judge-file`` and/or the per-side
    ``--judge-file-before`` / ``--judge-file-after`` overrides, all resolved
    relative to the script directory. Rewrite similarity is read from the AFTER
    model's answer file when that file exists.

    Returns:
        0 on success; used as the process exit status.

    Raises:
        FileNotFoundError: the question file or a required judge file is missing.
        RuntimeError: no scores were found for the before or after model.
    """
    args = _build_arg_parser().parse_args()

    base = Path(__file__).resolve().parent
    question_file = base / "data" / args.bench_name / "question.jsonl"
    judge_file = base / args.judge_file

    if not question_file.exists():
        raise FileNotFoundError(f"Missing question file: {question_file}")
    # The combined --judge-file is read for every side that lacks its own
    # override, so it is only optional when BOTH overrides are supplied.
    if not (args.judge_file_before and args.judge_file_after):
        if not judge_file.exists():
            raise FileNotFoundError(f"Missing judge file: {judge_file}")

    qcat = _load_question_category_map(question_file)

    jf_before = judge_file
    if args.judge_file_before:
        jf_before = _resolve_under_base(base, args.judge_file_before)
        if not jf_before.exists():
            raise FileNotFoundError(f"Missing BEFORE judge file: {jf_before}")
    b_all, b_turn, b_cat = _collect_scores(jf_before, args.before, qcat)

    jf_after = judge_file
    if args.judge_file_after:
        jf_after = _resolve_under_base(base, args.judge_file_after)
        if not jf_after.exists():
            raise FileNotFoundError(f"Missing AFTER judge file: {jf_after}")
    a_all, a_turn, a_cat = _collect_scores(jf_after, args.after, qcat)

    if not b_all:
        raise RuntimeError(f"No scores found for before model: {args.before}")
    if not a_all:
        raise RuntimeError(f"No scores found for after model: {args.after}")

    if args.judge_file_before or args.judge_file_after:
        print("Judge files:")
        if args.judge_file_before:
            print(f"  before: {jf_before}")
        else:
            print(f"  before: {judge_file} (combined)")
        if args.judge_file_after:
            print(f"  after : {jf_after}")
        else:
            print(f"  after : {judge_file} (combined)")
    else:
        print(f"Judge file: {judge_file}")
    b_mean = mean(b_all)
    a_mean = mean(a_all)
    print(f"Before: {args.before}  n={len(b_all)}  mean={_fmt(b_mean)}")
    print(f"After : {args.after}  n={len(a_all)}  mean={_fmt(a_mean)}")
    print(f"Delta : {_fmt(a_mean - b_mean)}")
    print("")

    # Similarity: prefer using rewrite output field (no recomputation)
    if args.after_answer_file:
        after_answer_file = _resolve_answer_file(base, args.after_answer_file)
    else:
        after_answer_file = base / "data" / args.bench_name / "model_answer" / f"{args.after}.jsonl"

    sim_by_cat: Dict[str, List[float]] = {}
    sim_by_turn: Dict[int, List[float]] = {}
    sim_overall: List[float] = []
    if after_answer_file.exists():
        sim_overall, sim_by_turn, sim_by_cat = _collect_rewrite_similarity_from_after_answer_file(
            after_answer_file, qcat
        )
        if sim_overall:
            print("Rewrite similarity (from rewritten answer file):")
            print(f"  n={len(sim_overall)}  mean={_fmt(mean(sim_overall))}")
            for t in sorted(sim_by_turn):
                print(f"  turn {t}: mean_sim={_fmt(mean(sim_by_turn[t]))}  n={len(sim_by_turn[t])}")
            print("")
        else:
            print("Rewrite similarity: not found in rewritten answer file (missing choices[0].rewrite_similarity)")
            print(f"  after_answer_file: {after_answer_file}")
            print("")
    else:
        print("Rewrite similarity: skipped (missing rewritten answer file)")
        print(f"  after_answer_file : {after_answer_file} exists={after_answer_file.exists()}")
        print("")

    # Turn breakdown. A side with no scores for a turn reports NaN instead of failing.
    turns = sorted(set(b_turn) | set(a_turn))
    print("Per-turn mean:")
    for t in turns:
        bm = mean(b_turn.get(t, [float("nan")]))
        am = mean(a_turn.get(t, [float("nan")]))
        print(f"  turn {t}: before={_fmt(bm)}  after={_fmt(am)}  delta={_fmt(am - bm)}")
    print("")

    # Category breakdown; `n` counts the AFTER model's scores in the category.
    cats = sorted(set(b_cat) | set(a_cat))
    print("Per-category mean:")
    for c in cats:
        bm = mean(b_cat.get(c, [float("nan")]))
        am = mean(a_cat.get(c, [float("nan")]))
        sim_str = ""
        if sim_by_cat.get(c):
            sim_str = f"  sim={_fmt(mean(sim_by_cat[c]))}"
        print(
            f"  {c:12s} before={_fmt(bm)}  after={_fmt(am)}  delta={_fmt(am - bm)}{sim_str}  n={len(a_cat.get(c, []))}"
        )

    return 0


if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)


