#!/usr/bin/env python3
"""Summarize policy sensitivity for continuous_decay (t* selection strategies)."""

from __future__ import annotations

import argparse
import csv
import json
import math
import random
import sys
from pathlib import Path
from statistics import mean, median
from typing import Dict, List, Optional, Tuple


def _as_float(val: object) -> Optional[float]:
    """Coerce a decoded JSON scalar to ``float``, or return ``None``.

    Args:
        val: Any value read from a decoded JSON row.

    Returns:
        ``float(val)`` when ``val`` is an ``int`` or ``float`` that is not
        NaN; ``None`` for everything else.

    Booleans are rejected explicitly: ``bool`` is a subclass of ``int``, so a
    JSON ``true``/``false`` flag would otherwise be counted as the number
    1.0/0.0. Integers too large to be represented as a float also yield
    ``None`` instead of raising ``OverflowError``.
    """
    if isinstance(val, bool) or not isinstance(val, (int, float)):
        return None
    try:
        number = float(val)
    except OverflowError:
        # Arbitrary-precision ints from JSON can exceed the float range.
        return None
    return None if math.isnan(number) else number


def _get_path(obj: dict, path: str) -> object:
    """Look up a dot-separated key path inside nested dicts.

    Args:
        obj: Root mapping to descend from.
        path: Keys joined by ``"."``, e.g. ``"meta.quality"``.

    Returns:
        The value found at the end of the path, or ``None`` as soon as a
        non-dict is reached before the path is exhausted (a missing key also
        produces ``None`` via ``dict.get``).
    """
    node: object = obj
    for key in path.split("."):
        if isinstance(node, dict):
            node = node.get(key)
        else:
            return None
    return node


def _extract_hx_history(row: dict) -> List[dict]:
    """Return the ``hx_history`` list recorded for a row.

    ``row["meta"]["hx_history"]`` is used when it is a non-empty list.
    Otherwise ``row["stage2"]["qwen_hx"]["hx_history"]`` is returned when it
    is a list (possibly empty). If neither applies, an empty list is returned.
    """
    from_meta = (row.get("meta") or {}).get("hx_history")
    if isinstance(from_meta, list) and from_meta:
        return from_meta

    qwen_hx = (row.get("stage2") or {}).get("qwen_hx") or {}
    from_stage2 = qwen_hx.get("hx_history")
    return from_stage2 if isinstance(from_stage2, list) else []


def _hx_to_list(hx: List[dict]) -> List[float]:
    """Flatten an ``hx_history`` sequence into plain floats.

    Dict entries contribute ``float(entry["h"])`` when the ``"h"`` key is
    present and not ``None``; bare ``int``/``float`` entries are converted
    directly. Entries of any other type are skipped.
    """
    collected: List[float] = []
    for item in hx:
        if isinstance(item, dict):
            raw = item.get("h")
        elif isinstance(item, (int, float)):
            raw = item
        else:
            raw = None
        if raw is not None:
            collected.append(float(raw))
    return collected


def _extract_terminal_hmin(row: dict) -> Optional[float]:
    """Determine the terminal h-min value for a row, trying sources in order.

    1. ``stage2["terminal_hmin"]`` when ``_as_float`` accepts it.
    2. ``stage2["terminal_safe"]``: ``True`` maps to ``0.0`` and ``False``
       maps to ``-1.0``.
    3. The last value of the row's hx history.

    Returns ``None`` when none of the sources yields a value.
    """
    stage2 = row.get("stage2") or {}

    explicit = _as_float(stage2.get("terminal_hmin"))
    if explicit is not None:
        return explicit

    safe_flag = stage2.get("terminal_safe")
    # Only genuine booleans count; truthy/falsy non-bool values fall through.
    if isinstance(safe_flag, bool):
        return 0.0 if safe_flag else -1.0

    history = _hx_to_list(_extract_hx_history(row))
    return history[-1] if history else None


def _extract_quality(row: dict, quality_keys: List[str]) -> Optional[float]:
    """Return the first usable quality value found under ``quality_keys``.

    Each key is a dotted path resolved with ``_get_path`` and filtered
    through ``_as_float``. Keys are tried in the order given, and ``None`` is
    returned when no key produces a value.
    """
    # The generator is lazy, so lookup stops at the first hit.
    candidates = (_as_float(_get_path(row, key)) for key in quality_keys)
    return next((value for value in candidates if value is not None), None)


def _extract_steps(row: dict) -> List[dict]:
    """Return the per-step records stored for a row.

    ``stage2["per_step_diag"]`` is returned when truthy. Otherwise
    ``stage2["continuous"]["steps"]`` is used, and an empty list is returned
    when that is missing or falsy as well.
    """
    stage2 = row.get("stage2") or {}
    diag = stage2.get("per_step_diag")
    if diag:
        return diag
    continuous = stage2.get("continuous") or {}
    return continuous.get("steps") or []


def _sum_tv_kl(steps: List[dict]) -> Tuple[float, float]:
    """Sum the TV and KL values over a list of step records.

    For each step the TV value is read from ``"tv_q_p"``, falling back to
    ``"tv"``; the KL value is read from ``"kl_q_ref"``, falling back to
    ``"kl_ref"``. Values rejected by ``_as_float`` are left out of the sums.

    Returns:
        ``(tv_sum, kl_sum)``; both are ``0.0`` for an empty list.
    """

    def _pick(diag: dict, primary: str, fallback: str) -> Optional[float]:
        # Prefer the primary key; consult the fallback only when it is unusable.
        value = _as_float(diag.get(primary))
        return value if value is not None else _as_float(diag.get(fallback))

    total_tv = 0.0
    total_kl = 0.0
    for diag in steps:
        tv_val = _pick(diag, "tv_q_p", "tv")
        kl_val = _pick(diag, "kl_q_ref", "kl_ref")
        if tv_val is not None:
            total_tv += tv_val
        if kl_val is not None:
            total_kl += kl_val
    return total_tv, total_kl


def _candidate_steps(cands: object) -> List[int]:
    """Extract integer step indices from a candidate list.

    Dict entries contribute ``int(entry["step"])`` when ``"step"`` is present
    and not ``None``; bare ``int``/``float`` entries are converted with
    ``int``. Other entries are ignored. Anything that is not a list produces
    an empty result.
    """
    if not isinstance(cands, list):
        return []

    result: List[int] = []
    for item in cands:
        if isinstance(item, dict):
            raw = item.get("step")
        elif isinstance(item, (int, float)):
            raw = item
        else:
            continue
        if raw is not None:
            result.append(int(raw))
    return result


def _intervention_meta(row: dict) -> Tuple[Optional[int], Optional[int], List[dict]]:
    """Read window bounds and candidates from a row's first intervention.

    Only ``stage2["interventions"][0]`` is inspected.

    Returns:
        ``(window_start, window_end, candidates)``. The bounds are converted
        with ``int`` when present and are ``None`` otherwise; ``candidates``
        defaults to an empty list. ``(None, None, [])`` is returned when the
        row has no interventions.
    """
    stage2 = row.get("stage2") or {}
    interventions = stage2.get("interventions") or []
    if not interventions:
        return None, None, []

    head = interventions[0]
    raw_start = head.get("window_start")
    raw_end = head.get("window_end")
    cands = head.get("candidates") or []

    start = None if raw_start is None else int(raw_start)
    end = None if raw_end is None else int(raw_end)
    return start, end, cands


def _extract_policy_name(path: Path, rows: List[dict]) -> str:
    """Work out which policy a run file belongs to.

    Sources are tried in order:

    1. ``policy_name`` in the sidecar file obtained by replacing the suffix
       of ``path`` with ``.run_args.json``.
    2. ``stage2["policy_name"]`` of the first row.
    3. The part of the file stem after the first ``"policy_"``.
    4. The file stem itself.

    Args:
        path: Path of the run's JSONL file.
        rows: Rows already loaded from ``path``.

    Returns:
        A non-empty policy name, provided the file stem is non-empty.
    """
    run_args = path.with_suffix(".run_args.json")
    if run_args.exists():
        try:
            data = json.loads(run_args.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # The sidecar is optional metadata: an unreadable, badly encoded
            # or malformed file just means we fall back to the other sources.
            # (JSONDecodeError and UnicodeDecodeError are both ValueErrors.)
            data = None
        # Valid JSON that is not an object (e.g. a list) has no policy_name.
        if isinstance(data, dict):
            val = data.get("policy_name")
            if isinstance(val, str) and val:
                return val
    if rows:
        val = (rows[0].get("stage2") or {}).get("policy_name")
        if isinstance(val, str) and val:
            return val
    stem = path.stem
    _, marker, suffix = stem.partition("policy_")
    # Require text after the marker so a stem ending in "policy_" does not
    # produce an empty policy name.
    if marker and suffix:
        return suffix
    return stem


def _load_jsonl(path: Path) -> List[dict]:
    """Parse a UTF-8 JSON Lines file into a list of decoded values.

    Lines that are empty or contain only whitespace are skipped; every other
    line is decoded with ``json.loads``, so a malformed line raises.
    """
    with path.open("r", encoding="utf-8") as handle:
        return [json.loads(raw) for raw in handle if raw.strip()]


def _sample_id(row: dict) -> str:
    """Return a string identifier for a row.

    Sources, in order: ``row["prompt_id"]``, ``row["meta"]["sample_id"]``,
    then ``row["meta"]["prompt_index"]``. The first two are skipped only when
    they are ``None`` or an empty string, so a legitimate id of ``0`` is kept
    rather than treated as missing. The result is always converted with
    ``str`` so that ids compare consistently as dictionary keys.

    When no source is present the result is the string ``"None"`` (the
    ``str`` of the missing ``prompt_index``).
    """
    meta = row.get("meta") or {}
    for candidate in (row.get("prompt_id"), meta.get("sample_id")):
        if candidate is not None and candidate != "":
            return str(candidate)
    return str(meta.get("prompt_index"))


def summarize_rows(rows: List[dict], quality_keys: List[str]) -> Dict[str, object]:
    """Compute aggregate statistics over the rows of one run.

    Args:
        rows: Decoded JSONL rows of a single run.
        quality_keys: Dotted paths tried in order to find a row's quality.

    Returns:
        A dict with these keys:

        * ``n_total``: number of rows.
        * ``applied_rate``: fraction of rows with a positive control count
          (``stage2["n_ctrl_applied"]``, or the number of step records when
          that field is absent); ``0.0`` for no rows.
        * ``D_term_mean`` / ``D_term_median``: mean and median of
          ``max(0, -terminal_hmin)`` over rows that have a terminal h-min.
        * ``quality_mean``: mean quality over rows that have one.
        * ``quality_missing_rate``: fraction of rows without a quality value;
          ``1.0`` for no rows.
        * ``sum_tv_per_token`` / ``sum_kl_per_token``: mean over rows of the
          summed TV / KL divided by the response length, for rows whose
          length is a positive ``int``.
        * ``mean_rel_pos``: mean of ``(t_star - window_start) / window size``
          for rows that have a window and a ``t_star``.

        Statistics with no contributing rows are ``None``.
    """
    deficits: List[float] = []
    qualities: List[float] = []
    tv_rates: List[float] = []
    kl_rates: List[float] = []
    rel_positions: List[float] = []
    n_applied = 0

    for record in rows:
        stage2 = record.get("stage2") or {}
        meta = record.get("meta") or {}

        # Control count: explicit field first, else the number of step records.
        ctrl_count = stage2.get("n_ctrl_applied")
        if ctrl_count is None:
            ctrl_count = len(_extract_steps(record))
        if ctrl_count and ctrl_count > 0:
            n_applied += 1

        hmin = _extract_terminal_hmin(record)
        if hmin is not None:
            # Deficit is how far the terminal value lies below zero.
            deficits.append(max(0.0, -hmin))

        score = _extract_quality(record, quality_keys)
        if score is not None:
            qualities.append(score)

        tv_total, kl_total = _sum_tv_kl(_extract_steps(record))
        length = meta.get("response_len")
        if length is None:
            length = len(record.get("token_ids") or [])
        if isinstance(length, int) and length > 0:
            tv_rates.append(tv_total / length)
            kl_rates.append(kl_total / length)

        start, end, _ = _intervention_meta(record)
        t_star = stage2.get("t_star")
        if start is None or end is None or t_star is None:
            continue
        # Window bounds are inclusive; clamp the size to at least 1.
        span = max(1, int(end) - int(start) + 1)
        rel_positions.append((int(t_star) - int(start)) / span)

    def _avg(values: List[float]) -> Optional[float]:
        return mean(values) if values else None

    total = len(rows)
    return {
        "n_total": total,
        "applied_rate": n_applied / total if total else 0.0,
        "D_term_mean": _avg(deficits),
        "D_term_median": median(deficits) if deficits else None,
        "quality_mean": _avg(qualities),
        "quality_missing_rate": 1.0 - (len(qualities) / total) if total else 1.0,
        "sum_tv_per_token": _avg(tv_rates),
        "sum_kl_per_token": _avg(kl_rates),
        "mean_rel_pos": _avg(rel_positions),
    }


def main() -> None:
    """Aggregate per-policy run files into a CSV and a LaTeX table.

    Reads every ``cont_N*_policy_*.jsonl`` file in ``--runs_dir``, groups the
    rows by policy name, and then:

    1. Picks a reference policy (``score_max`` if present, otherwise the
       alphabetically first policy).
    2. Checks that every row of every other policy has the same intervention
       window and candidate steps as the reference row with the same sample
       id.
    3. Summarizes each policy with ``summarize_rows`` and computes the
       fraction of shared samples whose quality is strictly greater than the
       reference's.
    4. Writes ``aggregate/policy_sensitivity.csv`` and
       ``aggregate/policy_sensitivity_table.tex`` under the runs directory.
    5. Prints details for up to three randomly chosen samples per policy.

    Raises:
        SystemExit: if no run files are found, or if any window/candidate
            mismatch (including a sample id unknown to the reference run) is
            detected.
    """
    parser = argparse.ArgumentParser(description="Summarize policy sensitivity experiment")
    parser.add_argument("--runs_dir", required=True)
    args = parser.parse_args()

    runs_dir = Path(args.runs_dir)
    paths = sorted(runs_dir.glob("cont_N*_policy_*.jsonl"))
    if not paths:
        raise SystemExit(f"No runs found in {runs_dir}")

    quality_keys = [
        "meta.qwen_margin_final",
        "meta.qwen_margin",
        "meta.qwen_margin_sentence",
        "meta.quality",
        "stage2.quality",
    ]

    runs: Dict[str, List[dict]] = {}
    source_for_policy: Dict[str, Path] = {}
    for path in paths:
        rows = _load_jsonl(path)
        policy = _extract_policy_name(path, rows)
        if policy in runs:
            # Two files resolved to the same policy name. The later file
            # still wins, but say so instead of dropping data silently.
            print(
                f"warning: policy {policy!r} from {path} replaces rows "
                f"loaded from {source_for_policy[policy]}",
                file=sys.stderr,
            )
        runs[policy] = rows
        source_for_policy[policy] = path

    ref_policy = "score_max" if "score_max" in runs else min(runs)
    ref_rows = runs[ref_policy]
    ref_map: Dict[str, Tuple[Optional[int], Optional[int], List[int]]] = {}
    for row in ref_rows:
        pid = _sample_id(row)
        w_start, w_end, cand = _intervention_meta(row)
        ref_map[pid] = (w_start, w_end, _candidate_steps(cand))

    # Policies are only comparable if they chose among the same candidates
    # inside the same window for every sample.
    mismatch_count = 0
    for policy, rows in runs.items():
        if policy == ref_policy:
            continue
        for row in rows:
            pid = _sample_id(row)
            if pid not in ref_map:
                mismatch_count += 1
                continue
            w_start, w_end, cand = _intervention_meta(row)
            ref_start, ref_end, ref_cand = ref_map[pid]
            if w_start != ref_start or w_end != ref_end or _candidate_steps(cand) != ref_cand:
                mismatch_count += 1
    if mismatch_count:
        raise SystemExit(f"candidate/window mismatch_count={mismatch_count}")

    summaries: Dict[str, Dict[str, object]] = {}
    quality_by_policy: Dict[str, Dict[str, float]] = {}
    for policy, rows in runs.items():
        summaries[policy] = summarize_rows(rows, quality_keys)
        q_map: Dict[str, float] = {}
        for row in rows:
            pid = _sample_id(row)
            q_val = _extract_quality(row, quality_keys)
            if q_val is not None:
                q_map[pid] = q_val
        quality_by_policy[policy] = q_map

    # Win rate against the reference: share of samples present in both runs
    # where this policy's quality is strictly greater (ties are not wins).
    w_vs_ref: Dict[str, Optional[float]] = {}
    ref_q = quality_by_policy.get(ref_policy, {})
    for policy, q_map in quality_by_policy.items():
        if policy == ref_policy:
            w_vs_ref[policy] = None
            continue
        pairs = [(pid, q_map[pid], ref_q[pid]) for pid in q_map.keys() & ref_q.keys()]
        if not pairs:
            w_vs_ref[policy] = None
            continue
        wins = sum(1 for _, qa, qb in pairs if qa > qb)
        w_vs_ref[policy] = wins / len(pairs)

    out_dir = runs_dir / "aggregate"
    out_dir.mkdir(parents=True, exist_ok=True)
    csv_path = out_dir / "policy_sensitivity.csv"
    tex_path = out_dir / "policy_sensitivity_table.tex"

    # NOTE: the last column keeps its historical name even when the reference
    # policy is not "score_max"; it always holds the win rate vs ref_policy.
    header = [
        "policy",
        "n_total",
        "applied_rate",
        "D_term_mean",
        "D_term_median",
        "quality_mean",
        "quality_missing_rate",
        "sum_tv_per_token",
        "sum_kl_per_token",
        "mean_rel_pos",
        "W_vs_score_max",
    ]
    rows_out: List[Dict[str, object]] = []
    for policy in sorted(summaries.keys()):
        data = summaries[policy]
        row = {k: data.get(k) for k in header}
        row["policy"] = policy
        row["W_vs_score_max"] = w_vs_ref.get(policy)
        rows_out.append(row)

    # csv.writer quotes fields containing commas or quotes (e.g. an unusual
    # policy name) instead of corrupting the column layout. newline="" is
    # what the csv module expects; lineterminator keeps plain "\n" endings.
    with csv_path.open("w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(header)
        for row in rows_out:
            writer.writerow(["" if row.get(k) is None else str(row.get(k)) for k in header])

    def fmt(val: Optional[float]) -> str:
        if val is None:
            return "--"
        return f"{val:.3f}"

    # Policy names usually contain "_" (e.g. "score_max"), which is an error
    # in LaTeX text mode unless escaped. str.translate works in a single
    # pass, so inserted backslashes are never escaped a second time.
    tex_specials = str.maketrans(
        {
            "\\": r"\textbackslash{}",
            "&": r"\&",
            "%": r"\%",
            "$": r"\$",
            "#": r"\#",
            "_": r"\_",
            "{": r"\{",
            "}": r"\}",
            "~": r"\textasciitilde{}",
            "^": r"\textasciicircum{}",
        }
    )

    with tex_path.open("w", encoding="utf-8") as f:
        f.write("\\begin{tabular}{lcccc}\n")
        f.write("\\toprule\n")
        f.write("Policy & $D_{\\mathrm{term}}$ & W & $C_{TV}$ & RelPos \\\\\n")
        f.write("\\midrule\n")
        for row in rows_out:
            cells = [
                str(row["policy"]).translate(tex_specials),
                fmt(row.get("D_term_mean")),
                fmt(row.get("W_vs_score_max")),
                fmt(row.get("sum_tv_per_token")),
                fmt(row.get("mean_rel_pos")),
            ]
            f.write(" & ".join(cells) + " \\\\\n")
        f.write("\\bottomrule\n")
        f.write("\\end{tabular}\n")

    print("mismatch_count", mismatch_count)
    print("wrote", csv_path)
    print("wrote", tex_path)

    # Spot check: print the chosen step for a few random reference samples
    # under every policy.
    sample_pool = list(ref_map.keys())
    random.shuffle(sample_pool)
    sample_pool = sample_pool[:3]
    for policy, rows in runs.items():
        row_map = {_sample_id(r): r for r in rows}
        for pid in sample_pool:
            row = row_map.get(pid)
            if not row:
                continue
            w_start, w_end, cand = _intervention_meta(row)
            t_star = (row.get("stage2") or {}).get("t_star")
            u_val = None
            for cand_entry in cand if isinstance(cand, list) else []:
                if isinstance(cand_entry, dict) and cand_entry.get("step") == t_star:
                    u_val = cand_entry.get("u_total")
                    break
            print(
                policy,
                "prompt_id",
                pid,
                "window",
                w_start,
                w_end,
                "cands",
                len(_candidate_steps(cand)),
                "t_star",
                t_star,
                "u_total",
                u_val,
            )


# Run the summary only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
