#!/usr/bin/env python3
"""Summarize Stage-2 runs with safety/quality/cost metrics for multiple baselines."""

from __future__ import annotations

import argparse
import csv
import json
import math
from pathlib import Path
from statistics import mean, median
from typing import Dict, Iterable, List, Optional, Tuple


def _as_float(val: object) -> Optional[float]:
    if isinstance(val, (int, float)) and not math.isnan(val):
        return float(val)
    return None


def _get_path(obj: dict, path: str) -> object:
    cur = obj
    for part in path.split("."):
        if not isinstance(cur, dict):
            return None
        cur = cur.get(part)
    return cur


def _extract_hx_history(row: dict) -> List[dict]:
    meta = row.get("meta") or {}
    hx = meta.get("hx_history")
    if isinstance(hx, list) and hx:
        return hx
    stage2 = row.get("stage2") or {}
    hx = (stage2.get("qwen_hx") or {}).get("hx_history")
    if isinstance(hx, list):
        return hx
    return []


def _hx_to_list(hx: List[dict]) -> Tuple[List[float], List[int]]:
    values = []
    times = []
    for entry in hx:
        if isinstance(entry, dict):
            h_val = entry.get("h")
            t_val = entry.get("t")
            if h_val is None:
                continue
            values.append(float(h_val))
            if t_val is not None:
                times.append(int(t_val))
        elif isinstance(entry, (int, float)):
            values.append(float(entry))
    return values, times


def _auc_deficit(h_vals: List[float], times: List[int]) -> Optional[float]:
    if not h_vals:
        return None
    deficit = [max(0.0, -h) for h in h_vals]
    if times and len(times) == len(deficit):
        auc = 0.0
        for i in range(len(deficit)):
            if i == 0:
                dt = times[i]
            else:
                dt = max(1, times[i] - times[i - 1])
            auc += deficit[i] * dt
        return auc
    return float(sum(deficit))


def _extract_terminal_hmin(row: dict) -> Optional[float]:
    """Resolve the terminal h_min value for one record.

    Sources, in priority order:
    1. a numeric, non-NaN ``stage2.terminal_hmin``;
    2. a boolean ``stage2.terminal_safe``, mapped to 0.0 (True) or -1.0
       (False) as synthetic stand-ins;
    3. the last h value in the record's h(x) history.
    Returns None if none of these is available.
    """
    stage2 = row.get("stage2") or {}
    explicit = _as_float(stage2.get("terminal_hmin"))
    if explicit is not None:
        return explicit
    safe_flag = stage2.get("terminal_safe")
    # Only a literal bool counts; other truthy/falsy values fall through.
    if isinstance(safe_flag, bool):
        return 0.0 if safe_flag else -1.0
    history_vals = _hx_to_list(_extract_hx_history(row))[0]
    return history_vals[-1] if history_vals else None


def _extract_quality(row: dict, quality_key: Optional[str], quality_keys: List[str]) -> Optional[float]:
    """Return the record's quality value as a float, or None.

    If ``quality_key`` is given (non-empty), only that dotted path is
    consulted. Otherwise the paths in ``quality_keys`` are tried in order and
    the first numeric, non-NaN value wins.
    """
    if quality_key:
        # An explicit key disables the fallback search entirely.
        return _as_float(_get_path(row, quality_key))
    candidates = (_as_float(_get_path(row, key)) for key in quality_keys)
    return next((value for value in candidates if value is not None), None)


def _extract_steps(row: dict) -> List[dict]:
    stage2 = row.get("stage2") or {}
    steps = stage2.get("per_step_diag") or []
    if steps:
        return steps
    cont = (stage2.get("continuous") or {}).get("steps") or []
    return cont


def _sum_tv_kl(steps: List[dict]) -> Tuple[float, float]:
    """Total the per-step TV and KL diagnostics over one record's steps.

    For each step, TV is read from ``tv_q_p`` (falling back to ``tv``) and
    KL from ``kl_q_ref`` (falling back to ``kl_ref``). Values that are
    missing, non-numeric or NaN contribute nothing.

    Returns:
        ``(tv_total, kl_total)``; both are 0.0 when ``steps`` is empty.
    """
    # (primary key, fallback key) for TV and KL, respectively.
    key_pairs = (("tv_q_p", "tv"), ("kl_q_ref", "kl_ref"))
    totals = [0.0, 0.0]
    for step in steps:
        for slot, (primary, fallback) in enumerate(key_pairs):
            value = _as_float(step.get(primary))
            if value is None:
                value = _as_float(step.get(fallback))
            if value is not None:
                totals[slot] += value
    return totals[0], totals[1]


def _mean(vals: Iterable[Optional[float]]) -> Optional[float]:
    arr = [v for v in vals if v is not None]
    if not arr:
        return None
    return float(mean(arr))


def _load_pairwise_summaries(pairwise_dir: Path) -> Dict[str, dict]:
    summaries: Dict[str, dict] = {}
    for path in pairwise_dir.glob("*_summary.json"):
        payload = json.loads(path.read_text(encoding="utf-8"))
        name = payload.get("method_name") or path.stem.replace("_summary", "")
        summaries[name] = payload
    return summaries


def summarize_run(path: Path, quality_key: Optional[str], quality_keys: List[str]) -> Dict[str, object]:
    """Aggregate one Stage-2 JSONL run file into a flat dict of summary metrics.

    Every non-blank line of ``path`` is parsed as one JSON record. From each
    record this collects:
    - the terminal h_min;
    - whether ``meta.t_u_qwen`` is set (anything but None or -1);
    - h(x) trajectory deficits and a quality value;
    - TV/KL summed over the control steps;
    - control, projection and scorer counters;
    - regen/introspection token counts and the response length.

    Args:
        path: JSONL file, one record per line.
        quality_key: optional dotted path that alone decides the quality value.
        quality_keys: ordered dotted-path fallbacks, used when ``quality_key``
            is not given.

    Returns:
        A dict mapping metric name to value:
        - ``run_id``: the first non-None ``stage2.baseline_name`` seen, or the
          file stem when that is missing or empty.
        - ``tu_rate``, ``ctrl_applied_rate``, ``quality_missing_rate``:
          fractions of all parsed records.
        - ``unsafe_rate``: the fraction of records with a terminal h_min that
          is below 0.
        - Means skip missing values and are None when nothing was collected.
        - ``quality_win_rate`` and ``quality_score_mean``: always None here.
        - Per-token ratios: a per-record mean divided by the mean response
          length; None when that length is unavailable or not positive.
    """
    n_rows = 0
    n_tu_defined = 0
    n_ctrl_applied = 0
    baseline_name = None

    terminal_hmins: List[float] = []
    terminal_deficits: List[float] = []
    traj_mean_deficits: List[float] = []
    traj_auc_deficits: List[float] = []
    qualities: List[float] = []
    tv_totals: List[float] = []
    kl_totals: List[float] = []
    ctrl_counts: List[float] = []
    proj_attempts: List[float] = []
    scorer_counts: List[float] = []
    regen_counts: List[float] = []
    introspect_counts: List[float] = []
    resp_lengths: List[float] = []

    def _num(container: dict, key: str) -> float:
        # Missing keys, None, 0 and other falsy values all become 0.0.
        return float(container.get(key, 0) or 0)

    with path.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            if raw_line.strip() == "":
                continue
            n_rows += 1
            record = json.loads(raw_line)
            stage2 = record.get("stage2") or {}
            # Keep looking until some record provides a baseline name.
            if baseline_name is None:
                baseline_name = stage2.get("baseline_name")

            hmin = _extract_terminal_hmin(record)
            if hmin is not None:
                terminal_hmins.append(hmin)
                terminal_deficits.append(max(0.0, -hmin))

            meta = record.get("meta") or {}
            if meta.get("t_u_qwen") not in (None, -1):
                n_tu_defined += 1

            h_vals, h_times = _hx_to_list(_extract_hx_history(record))
            if h_vals:
                traj_mean_deficits.append(mean([max(0.0, -h) for h in h_vals]))
                auc = _auc_deficit(h_vals, h_times)
                if auc is not None:
                    traj_auc_deficits.append(auc)

            quality = _extract_quality(record, quality_key, quality_keys)
            if quality is not None:
                qualities.append(quality)

            steps = _extract_steps(record)
            tv_total, kl_total = _sum_tv_kl(steps)
            tv_totals.append(tv_total)
            kl_totals.append(kl_total)

            # Without an explicit counter, every recorded step counts as applied.
            n_ctrl = stage2.get("n_ctrl_applied")
            if n_ctrl is None:
                n_ctrl = len(steps)
            ctrl_counts.append(float(n_ctrl))
            if n_ctrl and n_ctrl > 0:
                n_ctrl_applied += 1

            proj_attempts.append(_num(stage2, "projection_attempts"))
            scorer_counts.append(_num(meta, "llm_eval_count"))
            regen_counts.append(_num(stage2, "total_regen_tokens"))
            introspect_counts.append(_num(stage2, "introspect_tokens_len"))

            resp_len = meta.get("response_len")
            if resp_len is None:
                token_ids = record.get("token_ids") or []
                resp_len = len(token_ids) if token_ids else None
            if resp_len is not None:
                resp_lengths.append(float(resp_len))

    if terminal_hmins:
        unsafe_rate = sum(1 for v in terminal_hmins if v < 0) / len(terminal_hmins)
    else:
        unsafe_rate = None

    response_len_mean = _mean(resp_lengths)
    sum_tv_mean = _mean(tv_totals)
    sum_kl_mean = _mean(kl_totals)
    mean_scorer_calls = _mean(scorer_counts)
    mean_regen_tokens = _mean(regen_counts)
    mean_introspect_tokens = _mean(introspect_counts)

    # NaN lengths fail the "> 0" test, so they disable the ratios as well.
    has_length = bool(response_len_mean and response_len_mean > 0)

    def _per_token(value: Optional[float]) -> Optional[float]:
        # Normalize a per-record mean by the mean response length, when usable.
        if has_length and value is not None:
            return value / response_len_mean
        return None

    return {
        "run_id": baseline_name or path.stem,
        "n_total": n_rows,
        "tu_rate": (n_tu_defined / n_rows) if n_rows else None,
        "unsafe_rate": unsafe_rate,
        "mean_terminal_deficit": _mean(terminal_deficits),
        "traj_mean_deficit": _mean(traj_mean_deficits),
        "traj_auc_deficit": _mean(traj_auc_deficits),
        "quality_mean": _mean(qualities),
        "quality_win_rate": None,
        "quality_score_mean": None,
        "quality_missing_rate": (1.0 - (len(qualities) / n_rows)) if n_rows else None,
        "quality_source": "raw" if qualities else "missing",
        "sum_tv_mean": sum_tv_mean,
        "sum_kl_mean": sum_kl_mean,
        "sum_tv_per_token": _per_token(sum_tv_mean),
        "sum_kl_per_token": _per_token(sum_kl_mean),
        "mean_ctrl_steps": _mean(ctrl_counts),
        "ctrl_applied_rate": (n_ctrl_applied / n_rows) if n_rows else None,
        "mean_projection_attempts": _mean(proj_attempts),
        "mean_scorer_calls": mean_scorer_calls,
        "llm_eval_per_token": _per_token(mean_scorer_calls),
        "mean_regen_tokens": mean_regen_tokens,
        "regen_overhead_ratio": _per_token(mean_regen_tokens),
        "mean_introspect_tokens": mean_introspect_tokens,
        "introspect_overhead_ratio": _per_token(mean_introspect_tokens),
        "response_len_mean": response_len_mean,
    }


# Column order shared by the CSV and Markdown outputs.
_TRADEOFF_HEADERS = [
    "run_id",
    "n_total",
    "tu_rate",
    "unsafe_rate",
    "mean_terminal_deficit",
    "traj_mean_deficit",
    "traj_auc_deficit",
    "quality_mean",
    "quality_win_rate",
    "quality_score_mean",
    "quality_missing_rate",
    "quality_source",
    "sum_tv_mean",
    "sum_kl_mean",
    "sum_tv_per_token",
    "sum_kl_per_token",
    "mean_ctrl_steps",
    "ctrl_applied_rate",
    "mean_projection_attempts",
    "mean_scorer_calls",
    "llm_eval_per_token",
    "mean_regen_tokens",
    "regen_overhead_ratio",
    "mean_introspect_tokens",
    "introspect_overhead_ratio",
    "response_len_mean",
]


def _parse_pairwise_aliases(items: List[str]) -> Dict[str, str]:
    """Parse ``run_id=method_name`` strings into a dict; items without '=' are ignored."""
    alias_map: Dict[str, str] = {}
    for item in items:
        if "=" in item:
            left, right = item.split("=", 1)
            alias_map[left.strip()] = right.strip()
    return alias_map


def _attach_pairwise_quality(
    rows: List[Dict[str, object]],
    summaries: Dict[str, dict],
    alias_map: Dict[str, str],
) -> None:
    """Overlay pairwise quality onto per-run summary rows, mutating them in place.

    A row's ``run_id`` is mapped through ``alias_map`` to find its summary.
    When the summary has a ``win_rate``, it replaces ``quality_mean``: the
    missing rate becomes 0.0 and the source becomes "pairwise_win_rate".
    When there is no summary, or the summary has no ``win_rate``, the raw
    ``quality_mean`` is kept and the source is labelled "missing_pairwise".
    """
    for row in rows:
        method = alias_map.get(row["run_id"], row["run_id"])
        summary = summaries.get(method)
        if not summary:
            row["quality_source"] = "missing_pairwise"
            continue
        win_rate = summary.get("win_rate")
        row["quality_win_rate"] = win_rate
        row["quality_score_mean"] = summary.get("score_mean")
        if win_rate is None:
            # quality_mean still holds the raw value; don't label it pairwise.
            row["quality_source"] = "missing_pairwise"
        else:
            row["quality_mean"] = win_rate
            row["quality_missing_rate"] = 0.0
            row["quality_source"] = "pairwise_win_rate"


def _format_md_cell(val: object) -> str:
    """Format one Markdown table cell.

    Floats are shown to 4 decimals and None as "NA". Pipes are escaped so
    they cannot break the table layout.
    """
    if isinstance(val, float):
        return f"{val:.4f}"
    if val is None:
        return "NA"
    return str(val).replace("|", "\\|")


def _render_markdown_table(rows: List[Dict[str, object]], headers: List[str]) -> str:
    """Render ``rows`` as a Markdown table with ``headers`` as columns; each line ends in a newline."""
    lines = [
        "| " + " | ".join(headers) + " |",
        "|" + "|".join(["---"] * len(headers)) + "|",
    ]
    for row in rows:
        lines.append("| " + " | ".join(_format_md_cell(row.get(h)) for h in headers) + " |")
    return "\n".join(lines) + "\n"


def main() -> None:
    """CLI entry point: summarize Stage-2 JSONL runs into tradeoff tables.

    Each ``--input_jsonl`` (repeatable) is summarized with ``summarize_run``.
    When ``--pairwise_dir`` is given, its ``*_summary.json`` win rates
    override the raw quality, after optional ``--pairwise_alias`` renaming.

    Writes ``tradeoff_ext.csv`` and ``tradeoff_ext.md`` into ``--out_dir``.
    When ``--out_json`` is given, also writes a JSON object keyed by
    ``run_id``; rows sharing a ``run_id`` collapse there, and the last one
    wins. Missing output directories are created.
    """
    parser = argparse.ArgumentParser(description="Summarize Stage-2 tradeoff metrics")
    parser.add_argument("--input_jsonl", action="append", required=True)
    parser.add_argument("--out_dir", required=True)
    parser.add_argument("--quality_key", default=None)
    parser.add_argument("--pairwise_dir", default=None)
    parser.add_argument(
        "--pairwise_alias",
        action="append",
        default=[],
        help="Optional alias mapping like continuous_decay=ours_decay",
    )
    parser.add_argument("--out_json", default=None)
    args = parser.parse_args()

    # Fallback quality fields, tried in order when --quality_key is not set.
    quality_keys = [
        "meta.qwen_margin_final",
        "meta.qwen_margin",
        "meta.quality",
        "meta.judge_score",
        "stage2.quality",
        "stage2.reward",
    ]

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_csv = out_dir / "tradeoff_ext.csv"
    out_md = out_dir / "tradeoff_ext.md"
    out_json = Path(args.out_json) if args.out_json else None

    rows = [summarize_run(Path(p), args.quality_key, quality_keys) for p in args.input_jsonl]

    if args.pairwise_dir:
        summaries = _load_pairwise_summaries(Path(args.pairwise_dir))
        _attach_pairwise_quality(rows, summaries, _parse_pairwise_aliases(args.pairwise_alias))

    with out_csv.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=_TRADEOFF_HEADERS)
        writer.writeheader()
        writer.writerows(rows)

    with out_md.open("w", encoding="utf-8") as f:
        f.write(_render_markdown_table(rows, _TRADEOFF_HEADERS))

    print("wrote", out_csv)
    print("wrote", out_md)
    if out_json:
        # --out_json may point outside --out_dir; make sure its directory exists.
        out_json.parent.mkdir(parents=True, exist_ok=True)
        with out_json.open("w", encoding="utf-8") as f:
            json.dump({row["run_id"]: row for row in rows}, f, ensure_ascii=False, indent=2)
        print("wrote", out_json)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
