#!/usr/bin/env python3
"""Summarize Module A sweep over continuous_steps (N)."""

from __future__ import annotations

import argparse
import csv
import json
import math
import random
from pathlib import Path
from statistics import mean, median
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt


def _as_float(val: object) -> Optional[float]:
    if isinstance(val, (int, float)) and not math.isnan(val):
        return float(val)
    return None


def _get_path(obj: dict, path: str) -> object:
    cur = obj
    for part in path.split("."):
        if not isinstance(cur, dict):
            return None
        cur = cur.get(part)
    return cur


def _extract_hx_history(row: dict) -> List[dict]:
    meta = row.get("meta") or {}
    hx = meta.get("hx_history")
    if isinstance(hx, list) and hx:
        return hx
    stage2 = row.get("stage2") or {}
    hx = (stage2.get("qwen_hx") or {}).get("hx_history")
    if isinstance(hx, list):
        return hx
    return []


def _hx_to_list(hx: List[dict]) -> Tuple[List[float], List[int]]:
    values: List[float] = []
    times: List[int] = []
    for entry in hx:
        if isinstance(entry, dict):
            h_val = entry.get("h")
            t_val = entry.get("t")
            if h_val is None:
                continue
            values.append(float(h_val))
            if t_val is not None:
                times.append(int(t_val))
        elif isinstance(entry, (int, float)):
            values.append(float(entry))
    return values, times


def _auc_deficit(h_vals: List[float], times: List[int]) -> Optional[float]:
    if not h_vals:
        return None
    deficit = [max(0.0, -h) for h in h_vals]
    if times and len(times) == len(deficit):
        auc = 0.0
        for i in range(len(deficit)):
            if i == 0:
                dt = times[i]
            else:
                dt = max(1, times[i] - times[i - 1])
            auc += deficit[i] * dt
        return auc
    return float(sum(deficit))


def _extract_terminal_hmin(row: dict) -> Optional[float]:
    """Return the best available terminal h_min for *row*.

    Resolution order:
      1. numeric ``stage2.terminal_hmin``;
      2. boolean ``stage2.terminal_safe``, mapped to 0.0 (True) / -1.0 (False);
      3. the last value of the h(x) history;
      4. ``None`` when none of the above is available.
    """
    stage2 = row.get("stage2") or {}
    explicit = _as_float(stage2.get("terminal_hmin"))
    if explicit is not None:
        return explicit
    safe_flag = stage2.get("terminal_safe")
    if safe_flag is True or safe_flag is False:
        return 0.0 if safe_flag else -1.0
    history, _times = _hx_to_list(_extract_hx_history(row))
    return history[-1] if history else None


def _extract_quality(row: dict, quality_keys: List[str]) -> Optional[float]:
    """Return the first numeric value among the dotted *quality_keys*, else ``None``.

    Keys are tried in order, and lookup stops at the first hit.
    """
    candidates = (_as_float(_get_path(row, key)) for key in quality_keys)
    return next((value for value in candidates if value is not None), None)


def _extract_steps(row: dict) -> List[dict]:
    stage2 = row.get("stage2") or {}
    steps = stage2.get("per_step_diag") or []
    if steps:
        return steps
    cont = (stage2.get("continuous") or {}).get("steps") or []
    return cont


def _sum_tv_kl(steps: List[dict]) -> Tuple[float, float]:
    """Sum the per-step TV and KL diagnostics across *steps*.

    TV is read from ``tv_q_p``, falling back to ``tv``. KL is read from
    ``kl_q_ref``, falling back to ``kl_ref``. A step whose metric is missing or
    non-numeric adds nothing to that sum.
    """
    total_tv = 0.0
    total_kl = 0.0
    for rec in steps:
        tv_val = _as_float(rec.get("tv_q_p"))
        if tv_val is None:
            tv_val = _as_float(rec.get("tv"))
        kl_val = _as_float(rec.get("kl_q_ref"))
        if kl_val is None:
            kl_val = _as_float(rec.get("kl_ref"))
        if tv_val is not None:
            total_tv += tv_val
        if kl_val is not None:
            total_kl += kl_val
    return total_tv, total_kl


def _mean(vals: List[Optional[float]]) -> Optional[float]:
    arr = [v for v in vals if v is not None]
    if not arr:
        return None
    return float(mean(arr))


def _load_jsonl(path: Path) -> Tuple[List[dict], int, int]:
    rows: List[dict] = []
    total = 0
    ok = 0
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            total += 1
            try:
                rows.append(json.loads(line))
                ok += 1
            except Exception:
                continue
    return rows, total, ok


def _sample_id(row: dict) -> str:
    return (
        row.get("prompt_id")
        or (row.get("meta") or {}).get("sample_id")
        or str((row.get("meta") or {}).get("prompt_index"))
    )


def summarize_rows(rows: List[dict], quality_keys: List[str]) -> Dict[str, object]:
    """Aggregate per-row Module A diagnostics into run-level metrics.

    Args:
        rows: Parsed JSONL records for one run, or a subset of one.
        quality_keys: Dotted paths tried in order to find a quality score.

    Returns:
        A dict mapping metric name to value. Each mean is taken only over rows
        where the underlying quantity is available, and is ``None`` if no row
        has it.

        * ``n_total``: the row count.
        * ``applied_rate``: fraction of rows with ``n_ctrl > 0``. ``n_ctrl`` is
          ``stage2.n_ctrl_applied`` or, when that is absent, the number of
          recorded control steps.
        * ``D_term``: mean of ``max(0, -terminal_hmin)``.
        * ``D_traj_mean_deficit`` / ``D_traj_auc_deficit``: mean and AUC of
          the h(x) deficit trajectory, averaged over rows.
        * ``quality_score_mean``: mean of the first quality value found.
        * Per-token ratios: ``sum_tv_per_token``, ``sum_kl_per_token``,
          ``llm_eval_per_token`` and ``regen_overhead_ratio``. These use only
          rows with a non-zero response length, taken from
          ``meta.response_len`` or ``len(token_ids)``.
        * ``ctrl_steps_mean``: mean ``n_ctrl``.
    """
    terminal_deficits: List[float] = []
    traj_mean_deficits: List[float] = []
    traj_auc_deficits: List[float] = []
    quality_vals: List[float] = []
    llm_eval_per_token: List[float] = []
    sum_tv_per_token: List[float] = []
    sum_kl_per_token: List[float] = []
    regen_overhead_ratio: List[float] = []
    ctrl_steps: List[float] = []
    applied = 0

    for row in rows:
        stage2 = row.get("stage2") or {}
        meta = row.get("meta") or {}
        # Fetch steps once: they serve as the n_ctrl fallback and feed the TV/KL sums.
        steps = _extract_steps(row)

        n_ctrl = stage2.get("n_ctrl_applied")
        if n_ctrl is None:
            n_ctrl = len(steps)
        ctrl_steps.append(float(n_ctrl))
        if n_ctrl and n_ctrl > 0:
            applied += 1

        terminal_hmin = _extract_terminal_hmin(row)
        if terminal_hmin is not None:
            terminal_deficits.append(max(0.0, -terminal_hmin))

        hx_vals, hx_times = _hx_to_list(_extract_hx_history(row))
        if hx_vals:
            deficits = [max(0.0, -h) for h in hx_vals]
            traj_mean_deficits.append(mean(deficits))
            auc_val = _auc_deficit(hx_vals, hx_times)
            if auc_val is not None:
                traj_auc_deficits.append(auc_val)

        quality = _extract_quality(row, quality_keys)
        if quality is not None:
            quality_vals.append(quality)

        tv_sum, kl_sum = _sum_tv_kl(steps)
        resp_len = meta.get("response_len")
        if resp_len is None:
            token_ids = row.get("token_ids") or []
            if token_ids:
                resp_len = len(token_ids)
        if resp_len:
            denom = float(resp_len)
            llm_eval = meta.get("llm_eval_count")
            if llm_eval is not None:
                llm_eval_per_token.append(float(llm_eval) / denom)
            sum_tv_per_token.append(tv_sum / denom)
            sum_kl_per_token.append(kl_sum / denom)
            regen_tokens = stage2.get("total_regen_tokens", 0) or 0
            regen_overhead_ratio.append(float(regen_tokens) / denom)

    n_total = len(rows)
    applied_rate = (applied / n_total) if n_total else None
    return {
        "n_total": n_total,
        "applied_rate": applied_rate,
        "D_term": _mean(terminal_deficits),
        "D_traj_mean_deficit": _mean(traj_mean_deficits),
        "D_traj_auc_deficit": _mean(traj_auc_deficits),
        "quality_score_mean": _mean(quality_vals),
        "sum_tv_per_token": _mean(sum_tv_per_token),
        "sum_kl_per_token": _mean(sum_kl_per_token),
        "llm_eval_per_token": _mean(llm_eval_per_token),
        "regen_overhead_ratio": _mean(regen_overhead_ratio),
        "ctrl_steps_mean": _mean(ctrl_steps),
    }


# Reference N for the paired win-rate column ``W_vs_N20``.
_REF_N = 20

# Column order of the metrics CSVs.
_CSV_HEADERS = [
    "N",
    "run_file",
    "n_total",
    "applied_rate",
    "D_term",
    "D_traj_mean_deficit",
    "D_traj_auc_deficit",
    "quality_score_mean",
    "W_vs_N20",
    "W_vs_N20_n",
    "sum_tv_per_token",
    "sum_kl_per_token",
    "llm_eval_per_token",
    "regen_overhead_ratio",
    "ctrl_steps_mean",
]


def _is_applied(row: dict) -> bool:
    """Return True when control was applied to *row*.

    Uses the same rule as ``summarize_rows``: ``stage2.n_ctrl_applied`` when
    present, otherwise the number of recorded control steps. This keeps the
    applied subset consistent with ``applied_rate``.
    """
    stage2 = row.get("stage2") or {}
    n_ctrl = stage2.get("n_ctrl_applied")
    if n_ctrl is None:
        n_ctrl = len(_extract_steps(row))
    return bool(n_ctrl) and n_ctrl > 0


def _paired_win_rate(
    scores: Dict[str, float], ref_scores: Dict[str, float]
) -> Tuple[Optional[float], int]:
    """Return the win rate of *scores* over *ref_scores* on shared sample ids.

    Ties count as half a win. A run identical to the reference therefore
    scores 0.5, which matches the value reported for the reference row. Returns
    ``(None, 0)`` when the two maps share no sample id.
    """
    wins = 0.0
    n_pairs = 0
    for sid, ref in ref_scores.items():
        score = scores.get(sid)
        if score is None:
            continue
        n_pairs += 1
        if score > ref:
            wins += 1.0
        elif score == ref:
            wins += 0.5
    return (wins / n_pairs if n_pairs else None), n_pairs


def _format_cell(value: object) -> str:
    """Format one CSV cell: floats to 6 decimals, None as empty, else ``str``."""
    if isinstance(value, float):
        return f"{value:.6f}"
    if value is None:
        return ""
    return str(value)


def _write_metrics_csv(path: Path, rows: List[dict]) -> None:
    """Write *rows* as CSV with the ``_CSV_HEADERS`` columns.

    Fields containing commas or quotes (e.g. unusual run file names) are
    quoted by the ``csv`` module. Lines end with ``\\n``.
    """
    with path.open("w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(_CSV_HEADERS)
        writer.writerows([_format_cell(row.get(h)) for h in _CSV_HEADERS] for row in rows)


def _plot_sensitivity(metrics: List[dict], fig_pdf: Path, fig_png: Path) -> None:
    """Draw the two-panel N-sensitivity figure and save it as PDF and PNG.

    The left panel shows D_term and D_traj_mean_deficit against N. The right
    panel shows W_vs_N20 and sum_tv_per_token against N. Missing (None) values
    are plotted as NaN gaps.
    """
    def series(key: str) -> List[float]:
        return [math.nan if r.get(key) is None else r[key] for r in metrics]

    ns = [r["N"] for r in metrics]
    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    try:
        ax = axes[0]
        ax.plot(ns, series("D_term"), marker="o", label="D_term")
        ax.plot(ns, series("D_traj_mean_deficit"), marker="s", label="D_traj_mean")
        ax.set_xlabel("N (continuous steps)")
        ax.set_ylabel("Deficit")
        ax.legend()
        ax.grid(True, alpha=0.3)

        ax = axes[1]
        ax.plot(ns, series("W_vs_N20"), marker="o", label="W_vs_N20")
        ax.plot(ns, series("sum_tv_per_token"), marker="s", label="sum_tv_per_token")
        ax.set_xlabel("N (continuous steps)")
        ax.set_ylabel("Score / Cost")
        ax.legend()
        ax.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(fig_pdf)
        fig.savefig(fig_png, dpi=200)
    finally:
        # pyplot keeps figures alive until explicitly closed.
        plt.close(fig)


def main() -> None:
    """Summarize every ``*.jsonl`` run in ``--runs_dir`` into ``<runs_dir>/aggregate``.

    Input requirements:
      - Each run file needs a sibling ``<stem>.run_args.json`` recording
        ``continuous_steps`` (N).
      - Every row must have ``stage2.baseline_name == "continuous_decay"``.
      - Each N may appear in only one run file.

    Outputs: two metrics CSVs (all rows, and applied-only rows), a two-panel
    PDF/PNG figure, and a README. Sanity checks run after writing.

    Raises:
        SystemExit: on missing or invalid inputs, or on failed sanity checks.
    """
    parser = argparse.ArgumentParser(description="Summarize Module A N sweep")
    parser.add_argument("--runs_dir", required=True)
    args = parser.parse_args()

    runs_dir = Path(args.runs_dir)
    out_dir = runs_dir / "aggregate"
    out_dir.mkdir(parents=True, exist_ok=True)
    out_csv = out_dir / "moduleA_metrics.csv"
    out_csv_applied = out_dir / "moduleA_metrics_applied.csv"
    fig_pdf = out_dir / "fig_moduleA_N_sensitivity.pdf"
    fig_png = out_dir / "fig_moduleA_N_sensitivity.png"
    readme_path = out_dir / "README.txt"

    quality_keys = [
        "meta.qwen_margin_final",
        "meta.qwen_margin",
        "meta.quality",
        "meta.judge_score",
        "stage2.quality",
        "stage2.reward",
    ]

    runs = sorted(runs_dir.glob("*.jsonl"))
    if not runs:
        raise SystemExit(f"No jsonl files found under {runs_dir}")

    metrics: List[dict] = []
    metrics_applied: List[dict] = []
    # N -> {sample_id: quality score}, used for the paired W_vs_N20 comparison.
    sample_map: Dict[int, Dict[str, float]] = {}
    # N -> run file name; a second file with the same N would silently overwrite sample_map.
    run_file_by_n: Dict[int, str] = {}
    for path in runs:
        rows, total, ok = _load_jsonl(path)
        if total > 0 and ok / total < 0.9:
            raise SystemExit(f"{path}: only {ok}/{total} JSON lines parsed")
        if not rows:
            raise SystemExit(f"{path}: no JSON rows found")
        args_path = path.with_suffix(".run_args.json")
        if not args_path.exists():
            raise SystemExit(f"Missing run_args.json for {path}")
        run_args = json.loads(args_path.read_text(encoding="utf-8"))
        N = run_args.get("continuous_steps")
        if N is None:
            raise SystemExit(f"{path}: run_args missing continuous_steps")
        n_int = int(N)
        if n_int in run_file_by_n:
            raise SystemExit(
                f"{path}: continuous_steps={n_int} already provided by "
                f"{run_file_by_n[n_int]}; expected one run file per N"
            )
        run_file_by_n[n_int] = path.name
        base_names = {(r.get("stage2") or {}).get("baseline_name") for r in rows}
        if base_names != {"continuous_decay"}:
            raise SystemExit(f"{path}: baseline_name mismatch: {base_names}")

        row_all = summarize_rows(rows, quality_keys)
        row_all["N"] = n_int
        row_all["run_file"] = path.name
        metrics.append(row_all)

        applied_rows = [r for r in rows if _is_applied(r)]
        row_applied = summarize_rows(applied_rows, quality_keys)
        row_applied["N"] = n_int
        row_applied["run_file"] = path.name
        metrics_applied.append(row_applied)

        score_map: Dict[str, float] = {}
        for r in rows:
            sid = _sample_id(r)
            if sid is None:
                # Rows without any id cannot be paired across runs.
                continue
            score = _extract_quality(r, quality_keys)
            if score is not None:
                score_map[sid] = score
        sample_map[n_int] = score_map

    metrics.sort(key=lambda m: m["N"])
    metrics_applied.sort(key=lambda m: m["N"])

    ref_scores = sample_map.get(_REF_N, {})
    for row in metrics:
        if row["N"] == _REF_N:
            # The reference compared with itself is all ties.
            row["W_vs_N20"] = 0.5
            row["W_vs_N20_n"] = len(ref_scores)
            continue
        win_rate, n_pairs = _paired_win_rate(sample_map.get(row["N"], {}), ref_scores)
        row["W_vs_N20"] = win_rate
        row["W_vs_N20_n"] = n_pairs

    _write_metrics_csv(out_csv, metrics)
    _write_metrics_csv(out_csv_applied, metrics_applied)
    _plot_sensitivity(metrics, fig_pdf, fig_png)

    with readme_path.open("w", encoding="utf-8") as f:
        f.write("Module A N-sweep summary.\n")
        f.write("Reference for W_vs_N20: N=20.\n")
        f.write("W_vs_N20 ties count as half a win.\n")
        f.write("Cost proxy: sum_tv_per_token (per-token TV).\n")

    print("wrote", out_csv)
    print("wrote", out_csv_applied)
    print("wrote", fig_pdf)
    print("wrote", fig_png)
    print("wrote", readme_path)

    # Sanity checks
    for row in metrics:
        if row["applied_rate"] is not None and row["applied_rate"] < 0.0:
            raise SystemExit("applied_rate invalid")
    # Applied subset: ensure TV/KL available if any applied rows
    for row in metrics_applied:
        if row["n_total"] and row.get("sum_tv_per_token") is not None:
            if row["sum_tv_per_token"] <= 0:
                raise SystemExit("Applied subset has non-positive sum_tv_per_token")


if __name__ == "__main__":
    # Run the sweep summarizer only when executed as a script, not on import.
    main()
