#!/usr/bin/env python3
"""Summarize MYA hard-filter baselines vs RBCBF (continuous)."""

from __future__ import annotations

import argparse
import csv
import json
from pathlib import Path
from statistics import median
from math import floor
from typing import Dict, Iterable, List, Optional, Tuple


def _iter_rows(path: Path) -> Iterable[dict]:
    """Yield the JSON objects stored one per line in a JSONL file.

    Blank lines and lines that fail to parse are skipped so that a partially
    corrupted log can still be summarized. Lines that hold valid JSON which
    is not an object (a number, list, string, ``null``) are skipped as well,
    because every consumer in this file treats a row as a dict and calls
    ``.get`` on it.

    Args:
        path: Location of the UTF-8 encoded JSONL file.

    Yields:
        One dict per well-formed, object-valued line, in file order.
    """
    with path.open(encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                row = json.loads(line)
            except (ValueError, RecursionError):
                # json.JSONDecodeError is a ValueError; RecursionError covers
                # pathologically nested input. Either way the line is unusable.
                continue
            if isinstance(row, dict):
                yield row


def _terminal_h_with_source(row: dict) -> Tuple[Optional[float], str]:
    """Return the row's terminal ``h`` value and the field it was read from.

    Lookup order:
      1. ``row["stage2"]["terminal_hmin"]`` -> source ``"terminal_hmin"``.
      2. The last entry of ``row["meta"]["hx_history"]`` -> source
         ``"hx_history_last"``. A dict entry contributes its ``"h"`` key
         (``None`` when that key is absent or null); any other entry is
         converted with ``float`` directly.
      3. Neither available -> ``(None, "missing")``.
    """
    stage2_h = (row.get("stage2") or {}).get("terminal_hmin")
    if stage2_h is not None:
        return float(stage2_h), "terminal_hmin"

    history = (row.get("meta") or {}).get("hx_history") or []
    if not history:
        return None, "missing"

    final_entry = history[-1]
    if not isinstance(final_entry, dict):
        return float(final_entry), "hx_history_last"

    recorded = final_entry.get("h")
    value = None if recorded is None else float(recorded)
    return value, "hx_history_last"


def _response_len(row: dict) -> Optional[int]:
    """Return the response length recorded for a row, or ``None`` if unknown.

    ``row["meta"]["response_len"]`` wins when present; otherwise the length
    of ``row["token_ids"]`` is used (a falsy ``token_ids`` counts as empty).
    """
    declared = (row.get("meta") or {}).get("response_len")
    if declared is not None:
        return int(declared)

    token_ids = row.get("token_ids")
    if token_ids is None:
        return None
    return len(token_ids or [])


def _quality_score(row: dict) -> Optional[float]:
    """Return the first non-null quality-like value found in a row.

    ``row["meta"]`` is searched first (``qwen_margin_final``, ``qwen_margin``,
    ``quality``, ``judge_score``), then ``row["stage2"]`` (``quality``,
    ``qwen_margin``). Returns ``None`` when none of those keys has a value.
    """
    search_plan = (
        (row.get("meta") or {}, ("qwen_margin_final", "qwen_margin", "quality", "judge_score")),
        (row.get("stage2") or {}, ("quality", "qwen_margin")),
    )
    for container, keys in search_plan:
        for key in keys:
            candidate = container.get(key)
            if candidate is not None:
                return float(candidate)
    return None


def _steps(row: dict) -> List[dict]:
    """Return the per-step diagnostics list of a row.

    Prefers a non-empty ``row["stage2"]["per_step_diag"]``; otherwise falls
    back to ``row["stage2"]["continuous"]["steps"]``; otherwise ``[]``.
    """
    stage2 = row.get("stage2") or {}
    diag = stage2.get("per_step_diag")
    if diag:
        return diag
    continuous_steps = (stage2.get("continuous") or {}).get("steps")
    return continuous_steps or []


def _applied_and_sums(row: dict) -> Tuple[bool, float, float]:
    """Accumulate the per-step TV and KL fields of a row.

    For every step, the TV term is ``tv_q_p`` (falling back to ``tv``) and the
    KL term is ``kl_q_ref`` (falling back to ``kl_ref``); a missing or falsy
    value contributes 0.

    Returns:
        ``(applied, sum_tv, sum_kl)`` where ``applied`` is true exactly when
        the accumulated TV total is strictly positive.
    """
    tv_total = 0.0
    kl_total = 0.0
    for step in _steps(row):
        tv_term = step.get("tv_q_p") or step.get("tv") or 0.0
        kl_term = step.get("kl_q_ref") or step.get("kl_ref") or 0.0
        tv_total += float(tv_term)
        kl_total += float(kl_term)
    return tv_total > 0.0, tv_total, kl_total


def _triggered_flag(row: dict) -> bool:
    """Return whether the row carries a trigger marker.

    A row counts as triggered when ``row["stage2"]["t_u"]`` or
    ``row["meta"]["t_u_qwen"]`` is present and not null.
    """
    stage2_marker = (row.get("stage2") or {}).get("t_u")
    meta_marker = (row.get("meta") or {}).get("t_u_qwen")
    return stage2_marker is not None or meta_marker is not None


def _dterm(row: dict) -> Tuple[Optional[float], str]:
    """Return ``max(0, -h)`` for the row's terminal ``h`` plus its source tag.

    The value is ``None`` when no terminal ``h`` could be read; the source
    string is passed through from ``_terminal_h_with_source`` unchanged.
    """
    h_value, origin = _terminal_h_with_source(row)
    violation = None if h_value is None else max(0.0, -float(h_value))
    return violation, origin


def _quantiles(vals: List[float]) -> Dict[str, Optional[float]]:
    """Return min / p1 / median / p99 / max of ``vals`` without interpolation.

    Each quantile is the element at index ``floor(p * (n - 1))`` of the sorted
    values, so the median of an even-length list is the lower middle element.
    Every entry is ``None`` when ``vals`` is empty.
    """
    labels = ("min", "p1", "med", "p99", "max")
    if not vals:
        return {label: None for label in labels}

    ordered = sorted(vals)
    last = len(ordered) - 1

    def _at(fraction: float) -> float:
        # Clamp defensively so rounding can never index outside the list.
        position = int(floor(fraction * last))
        return float(ordered[min(last, max(0, position))])

    return {
        "min": float(ordered[0]),
        "p1": _at(0.01),
        "med": _at(0.5),
        "p99": _at(0.99),
        "max": float(ordered[-1]),
    }


def _empty_set_stats(steps: List[dict]) -> Tuple[Optional[float], Optional[float]]:
    """Summarize the ``mya_empty_set`` / ``mya_allowed_size`` step fields.

    Returns:
        ``(empty_rate, allowed_median)``: the mean of the non-null
        ``mya_empty_set`` values and the median of the non-null
        ``mya_allowed_size`` values. Each component is ``None`` when no step
        carries the corresponding field.
    """
    empty_flags: List[float] = []
    allowed_sizes: List[float] = []
    for step in steps:
        flag = step.get("mya_empty_set")
        if flag is not None:
            empty_flags.append(float(flag))
        size = step.get("mya_allowed_size")
        if size is not None:
            allowed_sizes.append(float(size))

    rate = sum(empty_flags) / len(empty_flags) if empty_flags else None
    typical_size = median(allowed_sizes) if allowed_sizes else None
    return rate, typical_size


def _hx_tail(row: dict, n: int = 5) -> List[float]:
    """Return the last ``n`` values of ``row["meta"]["hx_history"]`` as floats.

    Dict entries contribute their ``"h"`` key and are skipped when that key is
    absent or null; any other entry is converted with ``float`` directly.

    Args:
        row: A parsed result row.
        n: Maximum number of trailing history entries to inspect. A
            non-positive value yields an empty list.

    Returns:
        Up to ``n`` floats in history order (fewer when the history is shorter
        or some dict entries lack ``"h"``).
    """
    if n <= 0:
        # ``history[-0:]`` is the whole list and a negative n would slice from
        # the front, so non-positive sizes need an explicit guard.
        return []

    history = (row.get("meta") or {}).get("hx_history") or []
    tail: List[float] = []
    # Slicing already copes with histories shorter than n.
    for item in history[-n:]:
        if isinstance(item, dict):
            h = item.get("h")
            if h is not None:
                tail.append(float(h))
        else:
            tail.append(float(item))
    return tail


def _tv_sum(row: dict) -> float:
    """Return the sum over steps of ``tv_q_p`` (falling back to ``tv``).

    Steps where both fields are missing or falsy contribute 0.
    """
    per_step_tv = (
        float(step.get("tv_q_p") or step.get("tv") or 0.0)
        for step in _steps(row)
    )
    return sum(per_step_tv)


def _empty_set_rate_sample(row: dict) -> Optional[float]:
    """Return the mean of the non-null ``mya_empty_set`` step values of a row.

    ``None`` is returned when the row has no steps or no step carries the
    field.
    """
    flags = []
    for step in _steps(row):
        raw = step.get("mya_empty_set")
        if raw is not None:
            flags.append(float(raw))
    return sum(flags) / len(flags) if flags else None


def summarize(path: Path, method: str) -> Dict[str, object]:
    """Aggregate headline statistics over the "applied" rows of one JSONL file.

    Only rows whose accumulated TV is positive (see ``_applied_and_sums``) are
    counted. Per-token TV/KL figures are only collected for rows with a
    positive response length.

    Args:
        path: JSONL file to read.
        method: Label stored under the ``"method"`` key of the result.

    Returns:
        A dict with, in order: ``method``, ``n_applied``,
        ``D_term_med_applied``, ``quality_mean``, ``sum_tv_per_token_med``,
        ``sum_kl_per_token_med``, ``empty_set_rate``, ``median_allowed_size``.
        Statistics with no contributing rows are ``None``.
    """
    n_applied = 0
    d_terms: List[float] = []
    q_scores: List[float] = []
    tv_per_token: List[float] = []
    kl_per_token: List[float] = []
    empty_rates: List[float] = []
    allowed_meds: List[float] = []

    for row in _iter_rows(path):
        applied, sum_tv, sum_kl = _applied_and_sums(row)
        if not applied:
            continue
        n_applied += 1

        d_val, _src = _dterm(row)
        if d_val is not None:
            d_terms.append(d_val)

        quality = _quality_score(row)
        if quality is not None:
            q_scores.append(float(quality))

        resp_len = _response_len(row)
        if resp_len and resp_len > 0:
            tokens = float(resp_len)
            tv_per_token.append(sum_tv / tokens)
            kl_per_token.append(sum_kl / tokens)

        empty_rate, allowed_med = _empty_set_stats(_steps(row))
        if empty_rate is not None:
            empty_rates.append(empty_rate)
        if allowed_med is not None:
            allowed_meds.append(allowed_med)

    def _median_or_none(vals: List[float]) -> Optional[float]:
        return median(vals) if vals else None

    def _mean_or_none(vals: List[float]) -> Optional[float]:
        return sum(vals) / len(vals) if vals else None

    return {
        "method": method,
        "n_applied": n_applied,
        "D_term_med_applied": _median_or_none(d_terms),
        "quality_mean": _mean_or_none(q_scores),
        "sum_tv_per_token_med": _median_or_none(tv_per_token),
        "sum_kl_per_token_med": _median_or_none(kl_per_token),
        "empty_set_rate": _mean_or_none(empty_rates),
        "median_allowed_size": _median_or_none(allowed_meds),
    }


def _fmt(val: Optional[float]) -> str:
    """Format a number with three decimals, or ``"--"`` when it is ``None``."""
    return "--" if val is None else f"{val:.3f}"


def summarize_debug(path: Path, method: str) -> Dict[str, object]:
    """Collect diagnostic statistics about the D_term values of one JSONL file.

    Rows are bucketed three ways: all rows, rows flagged by
    ``_triggered_flag``, and rows reported as applied by
    ``_applied_and_sums``. For each bucket the D_term values (when available)
    are collected; for the triggered and applied buckets the TV sum per
    response token is collected as well (only for positive response lengths).

    Args:
        path: JSONL file to read.
        method: Label stored under the ``"method"`` key of the result.

    Returns:
        A flat dict of counts, zero-fractions, quantiles and medians, plus
        ``dterm_head_values``: the first ten D_term values, comma-joined.
    """
    rows = list(_iter_rows(path))

    source_counts = {"terminal_hmin": 0, "hx_history_last": 0, "missing": 0}
    n_triggered = 0
    n_applied = 0
    n_missing_terminal = 0
    n_missing_hx = 0
    dterm_by_group: Dict[str, List[float]] = {"all": [], "trig": [], "applied": []}
    tv_by_group: Dict[str, List[float]] = {"trig": [], "applied": []}

    for row in rows:
        d_val, src = _dterm(row)
        source_counts[src] = source_counts.get(src, 0) + 1
        if (row.get("stage2") or {}).get("terminal_hmin") is None:
            n_missing_terminal += 1
        if not ((row.get("meta") or {}).get("hx_history") or []):
            n_missing_hx += 1
        if d_val is not None:
            dterm_by_group["all"].append(d_val)

        applied, sum_tv, _sum_kl = _applied_and_sums(row)
        groups: List[str] = []
        if _triggered_flag(row):
            n_triggered += 1
            groups.append("trig")
        if applied:
            n_applied += 1
            groups.append("applied")
        if not groups:
            continue

        # The response length is only needed for triggered/applied rows.
        resp_len = _response_len(row)
        per_token = sum_tv / float(resp_len) if resp_len and resp_len > 0 else None
        for group in groups:
            if d_val is not None:
                dterm_by_group[group].append(d_val)
            if per_token is not None:
                tv_by_group[group].append(per_token)

    def _zero_fraction(vals: List[float]) -> Optional[float]:
        if not vals:
            return None
        return sum(1 for v in vals if v == 0.0) / len(vals)

    def _median_or_none(vals: List[float]) -> Optional[float]:
        return median(vals) if vals else None

    q_all = _quantiles(dterm_by_group["all"])
    q_trig = _quantiles(dterm_by_group["trig"])
    q_applied = _quantiles(dterm_by_group["applied"])

    return {
        "method": method,
        "n_total": len(rows),
        "n_triggered": n_triggered,
        "n_applied": n_applied,
        "applied_rate": (n_applied / n_triggered if n_triggered else None),
        "dterm_source_field": max(source_counts, key=source_counts.get),
        "missing_terminal_hmin": n_missing_terminal,
        "missing_hx_history": n_missing_hx,
        "fraction_zero_all": _zero_fraction(dterm_by_group["all"]),
        "fraction_zero_trig": _zero_fraction(dterm_by_group["trig"]),
        "fraction_zero_applied": _zero_fraction(dterm_by_group["applied"]),
        "dterm_min_all": q_all["min"],
        "dterm_p1_all": q_all["p1"],
        "dterm_med_all": q_all["med"],
        "dterm_p99_all": q_all["p99"],
        "dterm_max_all": q_all["max"],
        "dterm_med_trig": q_trig["med"],
        "dterm_med_applied": q_applied["med"],
        "tv_per_token_med_trig": _median_or_none(tv_by_group["trig"]),
        "tv_per_token_med_applied": _median_or_none(tv_by_group["applied"]),
        "dterm_head_values": ",".join(f"{v:.6f}" for v in dterm_by_group["all"][:10]),
    }


# Characters that LaTeX treats specially in text mode, mapped to safe forms.
_LATEX_SPECIALS = str.maketrans(
    {
        "\\": r"\textbackslash{}",
        "&": r"\&",
        "%": r"\%",
        "$": r"\$",
        "#": r"\#",
        "_": r"\_",
        "{": r"\{",
        "}": r"\}",
        "~": r"\textasciitilde{}",
        "^": r"\textasciicircum{}",
    }
)


def _tex_escape(text: object) -> str:
    """Escape LaTeX special characters so ``text`` is safe inside a table cell."""
    return str(text).translate(_LATEX_SPECIALS)


def main() -> None:
    """Command-line entry point.

    Reads three JSONL result files (``--rbcbf``, ``--mya_start``,
    ``--mya_post``) and writes:

      * ``--out_csv``: one summary row per method (see ``summarize``);
      * ``--out_tex``: a booktabs-style ``tabular`` with a subset of those
        columns;
      * ``<out_csv stem>_debug.csv`` next to ``--out_csv``: one diagnostic row
        per method (see ``summarize_debug``).

    Finally prints up to ten ``mya_from_start`` rows whose D_term is exactly
    0, alongside the D_term of the RBCBF row sharing the same prompt id.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--rbcbf", required=True)
    parser.add_argument("--mya_start", required=True)
    parser.add_argument("--mya_post", required=True)
    parser.add_argument("--out_csv", required=True)
    parser.add_argument("--out_tex", required=True)
    args = parser.parse_args()

    rows = [
        summarize(Path(args.rbcbf), "rbcbf_continuous"),
        summarize(Path(args.mya_start), "mya_from_start"),
        summarize(Path(args.mya_post), "mya_post_rollback"),
    ]

    out_csv = Path(args.out_csv)
    out_csv.parent.mkdir(parents=True, exist_ok=True)
    with out_csv.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)

    out_tex = Path(args.out_tex)
    out_tex.parent.mkdir(parents=True, exist_ok=True)
    with out_tex.open("w", encoding="utf-8") as f:
        f.write("\\begin{tabular}{lcccccc}\n")
        f.write("\\toprule\n")
        f.write("Method & $n_{app}$ & $D_{term}$ & W & $C_{TV}$ & empty\\_set & allowed\\\\\n")
        f.write("\\midrule\n")
        for row in rows:
            # Method labels contain underscores, which are a LaTeX error in
            # text mode unless escaped (the header above escapes its own).
            f.write(
                f"{_tex_escape(row['method'])} & {row['n_applied']} & "
                f"{_fmt(row['D_term_med_applied'])} & "
                f"{_fmt(row['quality_mean'])} & "
                f"{_fmt(row['sum_tv_per_token_med'])} & "
                f"{_fmt(row['empty_set_rate'])} & "
                f"{_fmt(row['median_allowed_size'])} \\\\\n"
            )
        f.write("\\bottomrule\n")
        f.write("\\end{tabular}\n")

    print(f"wrote {out_csv}")
    print(f"wrote {out_tex}")

    # Debug summary, written next to the main CSV.
    debug_rows = [
        summarize_debug(Path(args.rbcbf), "rbcbf_continuous"),
        summarize_debug(Path(args.mya_start), "mya_from_start"),
        summarize_debug(Path(args.mya_post), "mya_post_rollback"),
    ]
    debug_csv = out_csv.with_name(out_csv.stem + "_debug.csv")
    with debug_csv.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(debug_rows[0].keys()))
        writer.writeheader()
        writer.writerows(debug_rows)
    print(f"wrote {debug_csv}")

    # Spot-check D_term == 0 samples for mya_from_start (and compare to rbcbf).
    rb_rows = list(_iter_rows(Path(args.rbcbf)))
    mya_rows = list(_iter_rows(Path(args.mya_start)))
    # Index RBCBF rows by prompt id; a later row with the same id wins.
    rb_by_prompt = {}
    for r in rb_rows:
        pid = r.get("prompt_id") or (r.get("meta") or {}).get("sample_id")
        if pid is not None:
            rb_by_prompt[pid] = r
    printed = 0
    print("\n[D_term==0 samples] mya_from_start (first 10)")
    for r in mya_rows:
        d_val, _ = _dterm(r)
        # Rows without a D_term (None) are skipped here as well.
        if d_val != 0.0:
            continue
        pid = r.get("prompt_id") or (r.get("meta") or {}).get("sample_id")
        resp_len = _response_len(r)
        tv_sum = _tv_sum(r)
        empty_rate = _empty_set_rate_sample(r)
        hx_tail = _hx_tail(r, n=5)
        rb = rb_by_prompt.get(pid)
        rb_d = None
        if rb is not None:
            rb_d, _ = _dterm(rb)
        print(
            f"prompt_id={pid} resp_len={resp_len} tv_sum={tv_sum:.4f} "
            f"empty_set_rate={empty_rate} hx_tail={hx_tail} rbcbf_dterm={rb_d}"
        )
        printed += 1
        if printed >= 10:
            break


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
