#!/usr/bin/env python3
"""Summarize decoding sweep for continuous_decay (appendix)."""

from __future__ import annotations

import argparse
import csv
import json
import math
from pathlib import Path
from statistics import mean, median
from typing import Dict, List, Optional, Tuple


# Dotted lookup paths tried in order by _extract_quality(); the first path
# that resolves to a non-NaN number supplies the row's quality score.
QUALITY_KEYS = [
    "meta.qwen_margin_final",
    "meta.qwen_margin",
    "meta.quality",
    "meta.judge_score",
    "stage2.quality",
    "stage2.reward",
]


def _as_float(val: object) -> Optional[float]:
    if isinstance(val, (int, float)) and not math.isnan(val):
        return float(val)
    return None


def _get_path(obj: dict, path: str) -> object:
    cur = obj
    for part in path.split("."):
        if not isinstance(cur, dict):
            return None
        cur = cur.get(part)
    return cur


def _load_jsonl(path: Path) -> Tuple[List[dict], int, int]:
    rows: List[dict] = []
    total = 0
    ok = 0
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            total += 1
            try:
                rows.append(json.loads(line))
                ok += 1
            except Exception:
                continue
    return rows, total, ok


def _extract_terminal_hmin(row: dict) -> Optional[float]:
    """Return the row's terminal h value, trying three sources in order.

    1. ``stage2.terminal_hmin`` when it is a usable number.
    2. ``stage2.terminal_safe``: ``True`` maps to ``0.0``, ``False`` to ``-1.0``.
    3. The last entry of the row's hx history.

    Returns ``None`` when none of the sources yields a value.
    """
    stage2 = row.get("stage2") or {}

    explicit = _as_float(stage2.get("terminal_hmin"))
    if explicit is not None:
        return explicit

    # Identity checks on purpose: only real booleans count as a safety flag.
    safe_flag = stage2.get("terminal_safe")
    if safe_flag is True or safe_flag is False:
        return 0.0 if safe_flag else -1.0

    history = _hx_values(row)
    return history[-1] if history else None


def _hx_values(row: dict) -> List[float]:
    meta = row.get("meta") or {}
    hx = meta.get("hx_history")
    if not hx:
        stage2 = row.get("stage2") or {}
        hx = (stage2.get("qwen_hx") or {}).get("hx_history")
    vals: List[float] = []
    if isinstance(hx, list):
        for entry in hx:
            if isinstance(entry, dict):
                h_val = entry.get("h")
                if h_val is not None:
                    vals.append(float(h_val))
            elif isinstance(entry, (int, float)):
                vals.append(float(entry))
    return vals


def _extract_quality(row: dict) -> Optional[float]:
    """Return the first usable quality score found under ``QUALITY_KEYS``.

    Paths are tried in list order; ``None`` is returned when no path
    resolves to a number.
    """
    # Lazy generator: lookups stop at the first path that yields a number.
    candidates = (_as_float(_get_path(row, key)) for key in QUALITY_KEYS)
    return next((score for score in candidates if score is not None), None)


def _extract_steps(row: dict) -> List[dict]:
    stage2 = row.get("stage2") or {}
    steps = stage2.get("per_step_diag") or []
    if steps:
        return steps
    cont = (stage2.get("continuous") or {}).get("steps") or []
    return cont


def _sum_tv_kl(steps: List[dict]) -> Tuple[float, float]:
    """Sum the TV and KL values over a sequence of step records.

    TV is read from ``tv_q_p`` (falling back to ``tv``) and KL from
    ``kl_q_ref`` (falling back to ``kl_ref``). Steps lacking a usable value
    contribute nothing to the corresponding total.

    Returns:
        ``(tv_total, kl_total)``; both are ``0.0`` for an empty sequence.
    """

    def pick(step: dict, primary: str, fallback: str) -> Optional[float]:
        found = _as_float(step.get(primary))
        if found is not None:
            return found
        return _as_float(step.get(fallback))

    tv_total = 0.0
    kl_total = 0.0
    for step in steps:
        tv_term = pick(step, "tv_q_p", "tv")
        kl_term = pick(step, "kl_q_ref", "kl_ref")
        if tv_term is not None:
            tv_total += tv_term
        if kl_term is not None:
            kl_total += kl_term
    return tv_total, kl_total


def _response_len(row: dict) -> Optional[int]:
    meta = row.get("meta") or {}
    resp = meta.get("response_len")
    if isinstance(resp, int) and resp > 0:
        return resp
    token_ids = row.get("token_ids")
    if isinstance(token_ids, list) and token_ids:
        return len(token_ids)
    return None


def _mean(vals: List[Optional[float]]) -> Optional[float]:
    arr = [v for v in vals if v is not None]
    if not arr:
        return None
    return float(mean(arr))


def _median(vals: List[Optional[float]]) -> Optional[float]:
    arr = [v for v in vals if v is not None]
    if not arr:
        return None
    return float(median(arr))


def summarize_run(path: Path, args: dict) -> Dict[str, object]:
    """Summarize one sweep run (a JSONL file of per-sample records) as a flat row.

    A record counts as *triggered* when its ``t_u`` (``stage2.t_u``, falling
    back to ``meta.t_u_qwen``) is neither ``None`` nor ``-1``. It counts as
    *applied* when ``stage2.n_ctrl_applied`` is positive or, when that field
    is absent, when the record has at least one step entry.

    Args:
        path: JSONL file holding the run's records; its stem becomes ``run_id``.
        args: Decoding settings of the run; ``temperature``, ``top_p``,
            ``top_k``, ``seed`` and ``continuous_steps`` are copied through
            (``None`` when absent).

    Returns:
        A dict of counts, rates and aggregate metrics. ``*_all`` metrics cover
        every parsed record and ``*_trig`` metrics only triggered records;
        an aggregate over an empty selection is ``None``.

    Raises:
        ValueError: If the set of ``stage2.baseline_name`` values across the
            records is not exactly ``{"continuous_decay"}`` (this includes a
            file that yields no records).
    """
    rows, total, ok = _load_jsonl(path)
    baseline_names = {((r.get("stage2") or {}).get("baseline_name")) for r in rows}
    if baseline_names != {"continuous_decay"}:
        raise ValueError(f"Non-continuous baseline in {path}: {baseline_names}")

    n_triggered = 0
    n_applied = 0
    n_applied_trig = 0
    dterm_all: List[float] = []
    dterm_trig: List[float] = []
    quality_all: List[float] = []
    quality_trig: List[float] = []
    tv_all: List[float] = []
    kl_all: List[float] = []
    tv_trig: List[float] = []
    kl_trig: List[float] = []
    len_all: List[int] = []
    len_trig: List[int] = []

    for row in rows:
        stage2 = row.get("stage2") or {}
        meta = row.get("meta") or {}
        steps = _extract_steps(row)

        t_u = stage2.get("t_u")
        if t_u is None:
            t_u = meta.get("t_u_qwen")
        is_triggered = t_u not in (None, -1)

        n_ctrl = stage2.get("n_ctrl_applied")
        if n_ctrl is None:
            n_ctrl = len(steps)
        is_applied = bool(n_ctrl and n_ctrl > 0)

        term_h = _extract_terminal_hmin(row)
        if term_h is not None:
            # Terminal deficit: how far below zero the terminal h value is.
            deficit = max(0.0, -term_h)
            dterm_all.append(deficit)
            if is_triggered:
                dterm_trig.append(deficit)

        q = _extract_quality(row)
        if q is not None:
            quality_all.append(q)
            if is_triggered:
                quality_trig.append(q)

        tv_sum, kl_sum = _sum_tv_kl(steps)
        # Normalizer: known response length, else step count, else 1 so the
        # per-token division below is always defined.
        resp_len = _response_len(row) or len(steps) or 1
        tv_all.append(tv_sum)
        kl_all.append(kl_sum)
        len_all.append(resp_len)

        if is_triggered:
            n_triggered += 1
            tv_trig.append(tv_sum)
            kl_trig.append(kl_sum)
            len_trig.append(resp_len)
        if is_applied:
            n_applied += 1
            if is_triggered:
                n_applied_trig += 1

    def per_token(vals: List[float], lens: List[int]) -> Optional[float]:
        """Mean over records of (summed value / record length)."""
        if not vals or not lens:
            return None
        return float(mean([v / l if l > 0 else 0.0 for v, l in zip(vals, lens)]))

    n_rows = len(rows)
    quality_missing_rate = None
    if rows:
        quality_missing_rate = 1.0 - (len(quality_all) / n_rows)

    return {
        "run_id": path.stem,
        "temperature": args.get("temperature"),
        "top_p": args.get("top_p"),
        "top_k": args.get("top_k"),
        "seed": args.get("seed"),
        "continuous_steps": args.get("continuous_steps"),
        "n_total": n_rows,
        "n_parse_ok": ok,
        "n_parse_total": total,
        "n_triggered": n_triggered,
        "triggered_rate": n_triggered / n_rows if rows else None,
        "applied_rate_all": n_applied / n_rows if rows else None,
        # Numerator is restricted to triggered records so this is a true
        # conditional rate in [0, 1], like every other *_trig metric.
        "applied_rate_trig": n_applied_trig / n_triggered if n_triggered else None,
        "dterm_mean_all": _mean(dterm_all),
        "dterm_med_all": _median(dterm_all),
        "dterm_mean_trig": _mean(dterm_trig),
        "dterm_med_trig": _median(dterm_trig),
        "kl_per_token_all": per_token(kl_all, len_all),
        "kl_per_token_trig": per_token(kl_trig, len_trig),
        "tv_per_token_all": per_token(tv_all, len_all),
        "tv_per_token_trig": per_token(tv_trig, len_trig),
        "quality_mean_all": _mean(quality_all),
        "quality_mean_trig": _mean(quality_trig),
        "quality_missing_rate": quality_missing_rate,
    }


def _load_run_args(path: Path) -> Dict[str, object]:
    args_path = path.with_suffix(".run_args.json")
    if not args_path.exists():
        raise FileNotFoundError(f"Missing run_args.json for {path}")
    return json.loads(args_path.read_text(encoding="utf-8"))


def _matrix(rows: List[Dict[str, object]], value_key: str) -> Tuple[List[str], List[str], List[List[object]]]:
    temps = sorted({str(r["temperature"]) for r in rows})
    ps = sorted({str(r["top_p"]) for r in rows})
    table: List[List[object]] = []
    for t in temps:
        row_vals = []
        for p in ps:
            val = next((r.get(value_key) for r in rows if str(r["temperature"]) == t and str(r["top_p"]) == p), None)
            row_vals.append(val)
        table.append(row_vals)
    return temps, ps, table


def _write_matrix(path: Path, temps: List[str], ps: List[str], table: List[List[object]]) -> None:
    with path.open("w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["T"] + ps)
        for t, row in zip(temps, table):
            writer.writerow([t] + row)


def main() -> None:
    """Summarize every ``*.jsonl`` run in ``--runs_dir`` into ``<runs_dir>/aggregate``.

    Outputs written to the aggregate directory:
      * ``decode_sweep_cells.csv`` - one summary row per run file.
      * ``matrix_dterm_med.csv``, ``matrix_kl_per_token.csv`` and
        ``matrix_quality_mean.csv`` - temperature x top_p grids of the
        triggered-only metrics.
      * ``README.txt`` - a short description of the files.

    A warning is printed when the number of runs found differs from
    ``--expected_runs`` (default 6); the outputs are written regardless.
    A run file without its ``.run_args.json`` sidecar aborts with
    ``FileNotFoundError``.
    """
    parser = argparse.ArgumentParser(description="Summarize appendix decoding sweep (continuous only).")
    parser.add_argument("--runs_dir", required=True)
    parser.add_argument(
        "--expected_runs",
        type=int,
        default=6,
        help="number of runs the sweep should contain; a warning is printed on mismatch (default: 6)",
    )
    args = parser.parse_args()

    runs_dir = Path(args.runs_dir)
    out_dir = runs_dir / "aggregate"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Sorted so the CSV row order is deterministic across filesystems.
    rows: List[Dict[str, object]] = []
    for path in sorted(runs_dir.glob("*.jsonl")):
        run_args = _load_run_args(path)
        rows.append(summarize_run(path, run_args))

    if len(rows) != args.expected_runs:
        print(f"WARN: expected {args.expected_runs} runs, found {len(rows)}")

    cells_path = out_dir / "decode_sweep_cells.csv"
    fieldnames = list(rows[0].keys()) if rows else []
    with cells_path.open("w", encoding="utf-8", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)

    # (metric in the summary row, output file name) for each exported grid.
    matrix_exports = (
        ("dterm_med_trig", "matrix_dterm_med.csv"),
        ("kl_per_token_trig", "matrix_kl_per_token.csv"),
        ("quality_mean_trig", "matrix_quality_mean.csv"),
    )
    for value_key, file_name in matrix_exports:
        temps, ps, table = _matrix(rows, value_key)
        _write_matrix(out_dir / file_name, temps, ps, table)

    readme = out_dir / "README.txt"
    readme.write_text(
        "decode_sweep_cells.csv: one row per (T, top_p) run.\n"
        "matrix_* files: rows=temperature, cols=top_p.\n"
        "dterm_* uses terminal_deficit = max(0, -terminal_hmin).\n"
        "kl/tv per token divide by response_len (or token_ids length).\n",
        encoding="utf-8",
    )

    print(f"wrote {cells_path}")


if __name__ == "__main__":
    main()
