#!/usr/bin/env python3
"""Summarize Module B forcing/slack ablation for continuous_decay."""

from __future__ import annotations

import argparse
import json
import math
from pathlib import Path
from statistics import mean
from typing import Dict, List, Optional, Tuple


def _as_float(val: object) -> Optional[float]:
    if isinstance(val, (int, float)) and not math.isnan(val):
        return float(val)
    return None


def _get_path(obj: dict, path: str) -> object:
    cur = obj
    for part in path.split("."):
        if not isinstance(cur, dict):
            return None
        cur = cur.get(part)
    return cur


def _extract_hx_history(row: dict) -> List[dict]:
    meta = row.get("meta") or {}
    hx = meta.get("hx_history")
    if isinstance(hx, list) and hx:
        return hx
    stage2 = row.get("stage2") or {}
    hx = (stage2.get("qwen_hx") or {}).get("hx_history")
    if isinstance(hx, list):
        return hx
    return []


def _hx_to_list(hx: List[dict]) -> Tuple[List[float], List[int]]:
    values: List[float] = []
    times: List[int] = []
    for entry in hx:
        if isinstance(entry, dict):
            h_val = entry.get("h")
            t_val = entry.get("t")
            if h_val is None:
                continue
            values.append(float(h_val))
            if t_val is not None:
                times.append(int(t_val))
        elif isinstance(entry, (int, float)):
            values.append(float(entry))
    return values, times


def _extract_terminal_hmin(row: dict) -> Optional[float]:
    """Return the terminal h value for a row, using the best available source.

    Lookup order:
      1. ``stage2.terminal_hmin`` when it is a usable number.
      2. ``stage2.terminal_safe``: ``True`` maps to 0.0, ``False`` to -1.0.
      3. The last h value of the row's hx history.
    ``None`` is returned when none of these is available.
    """
    stage2 = row.get("stage2") or {}
    explicit = _as_float(stage2.get("terminal_hmin"))
    if explicit is not None:
        return explicit

    safe_flag = stage2.get("terminal_safe")
    # Identity checks on purpose: only real booleans act as the flag.
    if safe_flag is True or safe_flag is False:
        return 0.0 if safe_flag else -1.0

    history, _times = _hx_to_list(_extract_hx_history(row))
    return history[-1] if history else None


def _extract_quality(row: dict, quality_keys: List[str]) -> Optional[float]:
    for key in quality_keys:
        val = _as_float(_get_path(row, key))
        if val is not None:
            return val
    return None


def _extract_steps(row: dict) -> List[dict]:
    stage2 = row.get("stage2") or {}
    steps = stage2.get("per_step_diag") or []
    if steps:
        return steps
    cont = (stage2.get("continuous") or {}).get("steps") or []
    return cont


def _sum_tv_kl(steps: List[dict]) -> Tuple[float, float]:
    tv_sum = 0.0
    kl_sum = 0.0
    for step in steps:
        tv = _as_float(step.get("tv_q_p"))
        if tv is None:
            tv = _as_float(step.get("tv"))
        kl = _as_float(step.get("kl_q_ref"))
        if kl is None:
            kl = _as_float(step.get("kl_ref"))
        if tv is not None:
            tv_sum += tv
        if kl is not None:
            kl_sum += kl
    return tv_sum, kl_sum


def _mean(vals: List[Optional[float]]) -> Optional[float]:
    arr = [v for v in vals if v is not None]
    if not arr:
        return None
    return float(mean(arr))


def _load_jsonl(path: Path) -> Tuple[List[dict], int, int]:
    rows: List[dict] = []
    total = 0
    ok = 0
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            total += 1
            try:
                rows.append(json.loads(line))
                ok += 1
            except Exception:
                continue
    return rows, total, ok


def _parse_flags_from_name(name: str) -> Tuple[int, int]:
    # Expected: cont_N20_f1_s0.jsonl
    forcing = None
    slack = None
    for part in name.split("_"):
        if part.startswith("f"):
            try:
                forcing = int(part[1:])
            except Exception:
                pass
        if part.startswith("s"):
            try:
                slack = int(part[1:])
            except Exception:
                pass
    if forcing is None or slack is None:
        raise ValueError(f"Cannot parse forcing/slack from filename: {name}")
    return forcing, slack


def _step_diagnostics(steps: List[dict]) -> Tuple[List[float], bool]:
    """Collect per-step slack usage and whether any step failed or fell back."""
    slack_vals: List[float] = []
    failed = False
    for step in steps:
        su = _as_float(step.get("slack_used"))
        if su is not None:
            slack_vals.append(su)
        status = step.get("solver_status")
        if status and str(status) != "success":
            failed = True
        if step.get("fallback_reason"):
            failed = True
    return slack_vals, failed


def summarize_rows(rows: List[dict], quality_keys: List[str]) -> Dict[str, object]:
    """Aggregate per-sample rows of one run into summary metrics.

    Args:
        rows: Parsed JSONL records; each is read through its ``stage2`` and
            ``meta`` sub-dicts (either may be missing).
        quality_keys: Dotted paths tried in order to find a row's quality score.

    Returns:
        A dict with these keys (means are ``None`` when no row contributed):

        * ``n_total``: number of rows.
        * ``applied_rate``: share of rows with at least one control step.
        * ``D_term``: mean of ``max(0, -terminal_hmin)``.
        * ``quality_score_mean``: mean quality score.
        * ``cost_total_per_token``: sum of the four per-token means below,
          with a missing mean counted as 0; ``None`` if all four are missing.
        * ``sum_tv_per_token``, ``sum_kl_per_token``, ``llm_eval_per_token``,
          ``regen_overhead_ratio``: per-row quantities divided by the response
          length, averaged over rows that have a response length.
        * ``ctrl_steps_mean``: mean number of control steps.
        * ``nonzero_slack_rate``: share of rows with any step using slack > 0.
        * ``mean_slack_used``: mean per-row total slack (rows without slack
          data count as 0).
        * ``failure_or_fallback_rate``: share of rows with any step whose
          ``solver_status`` is not ``"success"`` or that has a
          ``fallback_reason``.

        The three rates are ``None`` when ``rows`` is empty.
    """
    terminal_deficits: List[float] = []
    quality_vals: List[float] = []
    llm_eval_per_token: List[float] = []
    sum_tv_per_token: List[float] = []
    sum_kl_per_token: List[float] = []
    regen_overhead_ratio: List[float] = []
    ctrl_steps: List[float] = []
    slack_used_sum: List[float] = []
    applied = 0
    nonzero_slack = 0
    failure_or_fallback = 0

    for row in rows:
        stage2 = row.get("stage2") or {}
        meta = row.get("meta") or {}
        steps = _extract_steps(row)

        # Prefer the recorded count; otherwise use the number of step records.
        n_ctrl = stage2.get("n_ctrl_applied")
        if n_ctrl is None:
            n_ctrl = len(steps)
        ctrl_steps.append(float(n_ctrl))
        if n_ctrl and n_ctrl > 0:
            applied += 1

        terminal_hmin = _extract_terminal_hmin(row)
        if terminal_hmin is not None:
            # Deficit is how far below zero the terminal value ended up.
            terminal_deficits.append(max(0.0, -terminal_hmin))
        quality = _extract_quality(row, quality_keys)
        if quality is not None:
            quality_vals.append(quality)

        tv_sum, kl_sum = _sum_tv_kl(steps)

        resp_len = meta.get("response_len")
        if resp_len is None:
            token_ids = row.get("token_ids") or []
            if token_ids:
                resp_len = len(token_ids)
        if resp_len:
            # Per-token costs are defined only for rows with a known,
            # non-zero response length.
            denom = float(resp_len)
            llm_eval = meta.get("llm_eval_count")
            if llm_eval is not None:
                llm_eval_per_token.append(float(llm_eval) / denom)
            sum_tv_per_token.append(tv_sum / denom)
            sum_kl_per_token.append(kl_sum / denom)
            regen_tokens = stage2.get("total_regen_tokens", 0) or 0
            regen_overhead_ratio.append(float(regen_tokens) / denom)

        slack_vals, failed = _step_diagnostics(steps)
        if slack_vals and any(val > 0 for val in slack_vals):
            nonzero_slack += 1
        slack_used_sum.append(sum(slack_vals) if slack_vals else 0.0)
        if failed:
            failure_or_fallback += 1

    n_total = len(rows)
    applied_rate = (applied / n_total) if n_total else None
    failure_rate = (failure_or_fallback / n_total) if n_total else None
    nonzero_slack_rate = (nonzero_slack / n_total) if n_total else None

    cost_total_per_token = None
    if llm_eval_per_token or sum_tv_per_token or sum_kl_per_token or regen_overhead_ratio:
        cost_total_per_token = (
            (_mean(llm_eval_per_token) or 0.0)
            + (_mean(sum_tv_per_token) or 0.0)
            + (_mean(sum_kl_per_token) or 0.0)
            + (_mean(regen_overhead_ratio) or 0.0)
        )

    return {
        "n_total": n_total,
        "applied_rate": applied_rate,
        "D_term": _mean(terminal_deficits),
        "quality_score_mean": _mean(quality_vals),
        "cost_total_per_token": cost_total_per_token,
        "sum_tv_per_token": _mean(sum_tv_per_token),
        "sum_kl_per_token": _mean(sum_kl_per_token),
        "llm_eval_per_token": _mean(llm_eval_per_token),
        "regen_overhead_ratio": _mean(regen_overhead_ratio),
        "ctrl_steps_mean": _mean(ctrl_steps),
        "nonzero_slack_rate": nonzero_slack_rate,
        "mean_slack_used": _mean(slack_used_sum),
        "failure_or_fallback_rate": failure_rate,
    }


def main() -> None:
    """Aggregate the forcing/slack ablation runs into CSV, LaTeX and README files.

    Reads every ``*.jsonl`` file directly under ``--runs_dir`` together with
    its sibling ``<stem>.run_args.json``, validates each run, summarizes it
    with ``summarize_rows`` and writes ``moduleB_table.csv``,
    ``moduleB_table.tex`` and ``README.txt`` into ``<runs_dir>/aggregate``.

    Raises:
        SystemExit: if no runs are found; fewer than 90% of a file's lines
            parse; a ``run_args.json`` is missing; the flags in a file name
            disagree with its ``run_args``; any row's ``baseline_name`` is
            not ``continuous_decay``; one of the four forcing/slack
            combinations is missing; or a slack-disabled run reports nonzero
            slack.
        ValueError: if a file name carries no parseable forcing/slack flags.
    """
    parser = argparse.ArgumentParser(description="Summarize Module B forcing/slack ablation")
    parser.add_argument("--runs_dir", required=True)
    args = parser.parse_args()

    runs_dir = Path(args.runs_dir)
    out_dir = runs_dir / "aggregate"
    out_dir.mkdir(parents=True, exist_ok=True)
    out_csv = out_dir / "moduleB_table.csv"
    out_tex = out_dir / "moduleB_table.tex"
    readme_path = out_dir / "README.txt"

    # Tried in order; the first path that yields a number is the quality score.
    quality_keys = [
        "meta.qwen_margin_final",
        "meta.qwen_margin",
        "meta.quality",
        "meta.judge_score",
        "stage2.quality",
        "stage2.reward",
    ]

    runs = sorted(runs_dir.glob("*.jsonl"))
    if not runs:
        raise SystemExit(f"No jsonl files found under {runs_dir}")

    rows = []
    combos = {}
    for path in runs:
        rows_json, total, ok = _load_jsonl(path)
        if total > 0 and ok / total < 0.9:
            raise SystemExit(f"{path}: only {ok}/{total} JSON lines parsed")
        args_path = path.with_suffix(".run_args.json")
        if not args_path.exists():
            raise SystemExit(f"Missing run_args.json for {path}")
        run_args = json.loads(args_path.read_text(encoding="utf-8"))
        forcing_enabled = int(run_args.get("forcing_enabled", 1))
        slack_enabled = int(run_args.get("slack_enabled", 1))
        # The file name and the recorded arguments must agree on the flags.
        f_name, s_name = _parse_flags_from_name(path.stem)
        if forcing_enabled != f_name or slack_enabled != s_name:
            raise SystemExit(f"{path}: forcing/slack mismatch with run_args")
        base_names = {((r.get("stage2") or {}).get("baseline_name")) for r in rows_json}
        if base_names != {"continuous_decay"}:
            raise SystemExit(f"{path}: baseline_name mismatch: {base_names}")

        metrics = summarize_rows(rows_json, quality_keys)
        metrics.update({"forcing": forcing_enabled, "slack": slack_enabled, "run_file": path.name})
        rows.append(metrics)
        combo_key = (forcing_enabled, slack_enabled)
        if combo_key in combos:
            # Both runs still appear in the tables, but the forcing comparison
            # below only sees the later file; make that visible.
            print(
                f"WARN: {path.name} duplicates combo forcing={forcing_enabled} "
                f"slack={slack_enabled} (already seen in {combos[combo_key]['run_file']})"
            )
        combos[combo_key] = metrics

    # Ensure 4 combos exist
    for combo in [(1, 1), (0, 1), (1, 0), (0, 0)]:
        if combo not in combos:
            raise SystemExit(f"Missing combo forcing={combo[0]} slack={combo[1]}")

    # sanity checks
    for metrics in rows:
        if metrics["slack"] == 0 and (metrics["nonzero_slack_rate"] or 0) > 0:
            raise SystemExit("slack=0 but nonzero_slack_rate > 0")
    for metrics in rows:
        if metrics["slack"] == 0 and (metrics["failure_or_fallback_rate"] or 0) == 0:
            print("WARN: slack=0 but failure_or_fallback_rate is 0 (check if slack disabled)")

    # Forcing off should reduce TV/KL
    f1 = combos[(1, 1)]
    f0 = combos[(0, 1)]
    if (
        f0.get("sum_tv_per_token") is not None
        and f1.get("sum_tv_per_token") is not None
        and f0["sum_tv_per_token"] >= 0.8 * f1["sum_tv_per_token"]
    ):
        print("WARN: forcing=0 has similar sum_tv_per_token to forcing=1")

    headers = [
        "forcing",
        "slack",
        "n_total",
        "applied_rate",
        "D_term",
        "quality_score_mean",
        "cost_total_per_token",
        "sum_tv_per_token",
        "sum_kl_per_token",
        "nonzero_slack_rate",
        "mean_slack_used",
        "failure_or_fallback_rate",
    ]

    # Rows are listed (1,1), (1,0), (0,1), (0,0) in both output tables.
    ordered = sorted(rows, key=lambda x: (x["forcing"], x["slack"]), reverse=True)

    with out_csv.open("w", encoding="utf-8") as f:
        f.write(",".join(headers) + "\n")
        for row in ordered:
            vals = []
            for h in headers:
                v = row.get(h)
                if isinstance(v, float):
                    vals.append(f"{v:.6f}")
                elif v is None:
                    vals.append("")
                else:
                    vals.append(str(v))
            f.write(",".join(vals) + "\n")

    # Build LaTeX table with best-value bolding for D_term (min), quality (max), cost (min)
    def _best(rows, key, mode):
        vals = [r.get(key) for r in rows if r.get(key) is not None]
        if not vals:
            return None
        return min(vals) if mode == "min" else max(vals)

    best_d = _best(rows, "D_term", "min")
    best_q = _best(rows, "quality_score_mean", "max")
    best_c = _best(rows, "cost_total_per_token", "min")

    def fmt(val, key):
        if val is None:
            return "NA"
        s = f"{val:.4f}"
        if (key == "D_term" and best_d is not None and val == best_d) or (
            key == "quality_score_mean" and best_q is not None and val == best_q
        ) or (key == "cost_total_per_token" and best_c is not None and val == best_c):
            return f"\\textbf{{{s}}}"
        return s

    def fmt_rate(val):
        # Rates are None when a run has no rows; ":.2f" would raise on None.
        return "NA" if val is None else f"{val:.2f}"

    with out_tex.open("w", encoding="utf-8") as f:
        f.write("\\begin{table}[t]\\centering\\small\n")
        # Seven columns, matching the header and data rows written below.
        f.write("\\begin{tabular}{cccccll}\\hline\n")
        # Each row ends with a single LaTeX "\\"; a doubled terminator would
        # insert an empty row. ">" is set in math mode so it renders correctly.
        f.write(
            "forcing & slack & $D_{term}\\downarrow$ & $W\\uparrow$ & $C\\downarrow$ & "
            "slack$>$0 & fail \\\\ \\hline\n"
        )
        for row in ordered:
            f.write(
                f"{row['forcing']} & {row['slack']} & {fmt(row.get('D_term'),'D_term')} & "
                f"{fmt(row.get('quality_score_mean'),'quality_score_mean')} & "
                f"{fmt(row.get('cost_total_per_token'),'cost_total_per_token')} & "
                f"{fmt_rate(row.get('nonzero_slack_rate'))} & "
                f"{fmt_rate(row.get('failure_or_fallback_rate'))} \\\\\n"
            )
        f.write("\\hline\\end{tabular}\n")
        f.write("\\caption{Forcing/Slack ablation under continuous control (N=20). Lower $D_{term}$ and $C$ are better; higher $W$ is better. Forcing prevents trivial updates; slack maintains feasibility under top-$V$ truncation.}\\end{table}\n")

    with readme_path.open("w", encoding="utf-8") as f:
        f.write("Module B ablation summary (N=20).\n")
        f.write("mean_slack_used is averaged over all samples (non-applied counted as 0).\n")
        f.write("failure_or_fallback_rate counts any step with solver_status!=success or fallback_reason.\n")
        f.write("Quality uses qwen_margin fields when available; otherwise NA.\n")

    print("wrote", out_csv)
    print("wrote", out_tex)
    print("wrote", readme_path)
    print("SANITY: combos", sorted(combos.keys()))


# Run the aggregation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
