# cirbench/reporting/aggregate.py
from __future__ import annotations
from pathlib import Path
import json
import csv
import statistics
from typing import Dict, Any, List, Optional, Tuple



def _glob_task_modes(run_dir: Path, task: str) -> List[Path]:
    """
    Return the ``<run_dir>/raw/<task>.<mode>`` directories, sorted by name.

    Only sub-directories whose name starts with ``task + "."`` are returned.
    If ``raw`` is missing or is not a directory, the result is ``[]``.
    """
    raw = run_dir / "raw"
    # is_dir() rather than exists(): a stray *file* named "raw" must not make
    # iterdir() raise NotADirectoryError.
    if not raw.is_dir():
        return []
    prefix = task + "."
    return sorted(
        (p for p in raw.iterdir() if p.is_dir() and p.name.startswith(prefix)),
        key=lambda p: p.name,
    )


def _read_metrics_json(case_dir: Path) -> Optional[Dict[str, Any]]:
    """
    Load ``<case_dir>/01_artifacts/metrics.json`` as a dict.

    Returns None when the file is missing or unreadable, is not valid UTF-8
    or JSON, or when its top-level value is not a JSON object. Callers use
    ``.get`` on the result, so a non-dict must not be returned.
    """
    path = case_dir / "01_artifacts" / "metrics.json"
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: missing, unreadable, or a directory.
        # ValueError: covers json.JSONDecodeError and UnicodeDecodeError.
        return None
    return data if isinstance(data, dict) else None


def _ensure_tables_dir(run_dir: Path) -> Path:
    """Create (if needed) and return the ``tables`` sub-directory of *run_dir*."""
    tables_dir = run_dir.joinpath("tables")
    tables_dir.mkdir(exist_ok=True, parents=True)
    return tables_dir


def _write_csv(path: Path, header: List[str], rows: List[List[Any]]) -> None:
    """Write *header* then *rows* to *path* as UTF-8 CSV, creating parent dirs."""
    path.parent.mkdir(parents=True, exist_ok=True)
    # newline="" lets the csv module control line endings itself.
    with path.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        writer.writerows(rows)


def _safe_rate(num: int, den: int) -> Optional[float]:
    """Return ``num / den``, or None when the denominator is not positive."""
    return num / den if den > 0 else None


def _num_stats(values: List[Optional[float]]) -> Dict[str, Optional[float]]:
    """
    Summarize the numeric entries of *values*.

    Non-numeric entries (None, strings, ...) are ignored. Booleans are ignored
    as well: ``bool`` subclasses ``int``, but True/False are not counts.

    Returns a dict with "sum", "avg", "median", "min" and "max" (all None when
    no numeric entry exists) and "count", the number of numeric entries.
    """
    vals = [
        float(v)
        for v in values
        if isinstance(v, (int, float)) and not isinstance(v, bool)
    ]
    if not vals:
        return {
            "sum": None,
            "avg": None,
            "median": None,
            "min": None,
            "max": None,
            "count": 0,
        }
    total = sum(vals)
    return {
        "sum": total,
        "avg": total / len(vals),
        "median": statistics.median(vals),
        "min": min(vals),
        "max": max(vals),
        "count": len(vals),
    }


def _case_prefix(case_id: str) -> str:
    """
    Return the family tag of a case id: the text before the first "_".

    Empty or missing ids yield "". Examples:
      A001_alias_001 -> A001
      RF003_loop_007 -> RF003
      T001_Loops_001 -> T001
      T005_Super_003 -> T005
    """
    head, _sep, _tail = (case_id or "").partition("_")
    return head


def _best_shot(shots: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """
    Pick the most successful shot from a list of shot records.

    The preference order is:
      1. the first shot whose ``equiv["ok"]`` is True;
      2. the first shot whose ``valid`` is True;
      3. the first shot whose ``pass`` is True;
      4. otherwise, the first shot.

    Returns None for an empty or missing list.
    """
    if not shots:
        return None
    checks = (
        lambda shot: (shot.get("equiv") or {}).get("ok") is True,
        lambda shot: shot.get("valid") is True,
        lambda shot: shot.get("pass") is True,
    )
    for check in checks:
        match = next((shot for shot in shots if check(shot)), None)
        if match is not None:
            return match
    return shots[0]

# Statistics reported for each token series (input, then output), in CSV column order.
_ANALYSIS_TOKEN_STATS: Tuple[str, ...] = ("sum", "avg", "median", "min", "max")

# (table name, grouping columns) for every analysis table, in write order.
# Grouping columns are record field names; () means "all cases of the mode".
_ANALYSIS_TABLES: Tuple[Tuple[str, Tuple[str, ...]], ...] = (
    ("overall", ()),
    ("by_difficulty", ("difficulty",)),
    ("by_prefix", ("prefix",)),
    ("by_prefix_difficulty", ("prefix", "difficulty")),
)


def _load_analysis_records(mode_dir: Path, mode: str) -> List[Dict[str, Any]]:
    """Return one flat record per case directory of *mode_dir* with readable metrics."""
    records: List[Dict[str, Any]] = []
    for case_dir in sorted(mode_dir.iterdir()):
        if not case_dir.is_dir():
            continue
        jj = _read_metrics_json(case_dir)
        if not jj:
            continue
        # Analysis metrics may carry the case id as "id" or as "case_id".
        case_id = jj.get("id") or jj.get("case_id") or case_dir.name
        toks = jj.get("tokens") or {}
        records.append(
            {
                "mode": mode,
                "case_id": case_id,
                "difficulty": jj.get("difficulty"),
                "prefix": _case_prefix(case_id),
                "em": bool(jj.get("em")),
                "unparseable": bool(jj.get("unparseable")),
                "tokens_in": toks.get("in"),
                "tokens_out": toks.get("out"),
            }
        )
    return records


def _analysis_metric_cells(rs: List[Dict[str, Any]]) -> List[Any]:
    """Return the metric columns (case_count .. tokens_out_max) for one group of records."""
    n = len(rs)
    t_in = _num_stats([r["tokens_in"] for r in rs])
    t_out = _num_stats([r["tokens_out"] for r in rs])
    return [
        n,
        _safe_rate(sum(1 for r in rs if r["em"]), n),
        _safe_rate(sum(1 for r in rs if r["unparseable"]), n),
        *(t_in[s] for s in _ANALYSIS_TOKEN_STATS),
        *(t_out[s] for s in _ANALYSIS_TOKEN_STATS),
    ]


def write_analysis_aggregates(run_dir: Path) -> None:
    """
    Aggregate per-case analysis metrics into CSV tables under ``run_dir/tables``.

    Every ``raw/analysis.<mode>`` directory is scanned for
    ``<case>/01_artifacts/metrics.json``. Each mode contributes rows to:

      - agg.analysis.overall.csv
      - agg.analysis.by_difficulty.csv
      - agg.analysis.by_prefix.csv
      - agg.analysis.by_prefix_difficulty.csv

    Each row holds the case count, the exact-match and unparseable rates, and
    sum/avg/median/min/max of input and output tokens. Falsy difficulty or
    prefix values are grouped as "UNKNOWN". Groups are sorted by key within a
    mode. No CSV is written when no analysis case has metrics.
    """
    tables = _ensure_tables_dir(run_dir)
    rows: Dict[str, List[List[Any]]] = {name: [] for name, _ in _ANALYSIS_TABLES}

    for mode_dir in _glob_task_modes(run_dir, "analysis"):
        mode = mode_dir.name.split(".", 1)[1] if "." in mode_dir.name else mode_dir.name
        records = _load_analysis_records(mode_dir, mode)
        if not records:
            continue
        for name, columns in _ANALYSIS_TABLES:
            groups: Dict[Tuple[Any, ...], List[Dict[str, Any]]] = {}
            for r in records:
                key = tuple(r[c] or "UNKNOWN" for c in columns)
                groups.setdefault(key, []).append(r)
            for key, rs in sorted(groups.items()):
                rows[name].append(["analysis", mode, *key, *_analysis_metric_cells(rs)])

    metric_header = [
        "case_count",
        "em_rate",
        "unparseable_rate",
        *(f"tokens_{d}_{s}" for d in ("in", "out") for s in _ANALYSIS_TOKEN_STATS),
    ]
    for name, columns in _ANALYSIS_TABLES:
        if rows[name]:
            _write_csv(
                tables / f"agg.analysis.{name}.csv",
                ["task", "mode", *columns, *metric_header],
                rows[name],
            )


# Per-case success flags read from metrics["summary"], in CSV column order.
_REPAIR_RATE_KEYS: Tuple[str, ...] = (
    "pass_at_1",
    "pass_at_5",
    "valid_at_1",
    "valid_at_5",
    "equiv_at_1",
    "equiv_at_5",
)

# Statistics reported for each token series (input, then output), in CSV column order.
_REPAIR_TOKEN_STATS: Tuple[str, ...] = ("sum", "avg", "median", "min", "max")

# (table name, grouping columns) for every repair table, in write order.
_REPAIR_TABLES: Tuple[Tuple[str, Tuple[str, ...]], ...] = (
    ("overall", ()),
    ("by_difficulty", ("difficulty",)),
)


def _load_repair_records(mode_dir: Path, mode: str) -> List[Dict[str, Any]]:
    """Return one flat record per case directory of *mode_dir* with readable metrics."""
    records: List[Dict[str, Any]] = []
    for case_dir in sorted(mode_dir.iterdir()):
        if not case_dir.is_dir():
            continue
        jj = _read_metrics_json(case_dir)
        if not jj:
            continue
        summ = jj.get("summary") or {}
        shots = jj.get("shots") or []
        # Token usage is taken from the first shot only.
        first_tokens = ((shots[0] or {}).get("tokens") or {}) if shots else {}
        record: Dict[str, Any] = {
            "mode": mode,
            "case_id": jj.get("case_id") or case_dir.name,
            "difficulty": jj.get("difficulty"),
            # A case is unparseable if any of its shots was.
            "unparseable": any(bool((s or {}).get("unparseable")) for s in shots),
            "tokens_in": first_tokens.get("in"),
            "tokens_out": first_tokens.get("out"),
        }
        record.update((k, bool(summ.get(k))) for k in _REPAIR_RATE_KEYS)
        records.append(record)
    return records


def _repair_metric_cells(rs: List[Dict[str, Any]]) -> List[Any]:
    """Return the metric columns (case_count .. tokens_out_max) for one group of records."""
    n = len(rs)
    t_in = _num_stats([r["tokens_in"] for r in rs])
    t_out = _num_stats([r["tokens_out"] for r in rs])
    return [
        n,
        *(_safe_rate(sum(1 for r in rs if r[k]), n) for k in _REPAIR_RATE_KEYS),
        _safe_rate(sum(1 for r in rs if r["unparseable"]), n),
        *(t_in[s] for s in _REPAIR_TOKEN_STATS),
        *(t_out[s] for s in _REPAIR_TOKEN_STATS),
    ]


def write_repair_aggregates(run_dir: Path) -> None:
    """
    Aggregate per-case repair metrics (e.g. repair.normal / repair.hard).

    Every ``raw/repair.<mode>`` directory is scanned for
    ``<case>/01_artifacts/metrics.json``. Each mode contributes rows to
    ``tables/agg.repair.overall.csv`` and ``tables/agg.repair.by_difficulty.csv``.

    Each row holds:
      - the case count;
      - the pass/valid/equiv @1 and @5 rates (from ``metrics["summary"]``);
      - the unparseable rate (a case counts if any shot was unparseable);
      - sum/avg/median/min/max of the first shot's input and output tokens.

    Falsy difficulties are grouped as "UNKNOWN". No CSV is written when no
    repair case has metrics.
    """
    tables = _ensure_tables_dir(run_dir)
    rows: Dict[str, List[List[Any]]] = {name: [] for name, _ in _REPAIR_TABLES}

    for mode_dir in _glob_task_modes(run_dir, "repair"):
        mode = mode_dir.name.split(".", 1)[1] if "." in mode_dir.name else mode_dir.name
        records = _load_repair_records(mode_dir, mode)
        if not records:
            continue
        for name, columns in _REPAIR_TABLES:
            groups: Dict[Tuple[Any, ...], List[Dict[str, Any]]] = {}
            for r in records:
                key = tuple(r[c] or "UNKNOWN" for c in columns)
                groups.setdefault(key, []).append(r)
            for key, rs in sorted(groups.items()):
                rows[name].append(["repair", mode, *key, *_repair_metric_cells(rs)])

    metric_header = [
        "case_count",
        *(f"{k}_rate" for k in _REPAIR_RATE_KEYS),
        "unparseable_rate",
        *(f"tokens_{d}_{s}" for d in ("in", "out") for s in _REPAIR_TOKEN_STATS),
    ]
    for name, columns in _REPAIR_TABLES:
        if rows[name]:
            _write_csv(
                tables / f"agg.repair.{name}.csv",
                ["task", "mode", *columns, *metric_header],
                rows[name],
            )



# Per-case success flags read from metrics["summary"] for refactor cases.
_REFACTOR_SUMMARY_KEYS: Tuple[str, ...] = (
    "pass_at_1",
    "pass_at_5",
    "valid_at_1",
    "valid_at_5",
    "equiv_at_1",
    "equiv_at_5",
)

# Token statistic columns shared by every refactor table.
_REFACTOR_TOKEN_HEADER: List[str] = [
    f"tokens_{d}_{s}"
    for d in ("in", "out")
    for s in ("sum", "avg", "median", "min", "max")
]

# Metric columns per target. These must match the row layout appended by
# _acc_refactor_normal / _acc_refactor_reverse.
_REFACTOR_METRIC_HEADERS: Dict[str, List[str]] = {
    "normal": [
        "case_count",
        *(f"{k}_rate" for k in _REFACTOR_SUMMARY_KEYS),
        "unparseable_rate",
        *_REFACTOR_TOKEN_HEADER,
    ],
    "reverse": [
        "case_count",
        "em_at_1_rate",
        "em_at_5_rate",
        "unparseable_rate",
        *_REFACTOR_TOKEN_HEADER,
    ],
}

# (table name, grouping columns), in the order the accumulators take their
# row lists and in the order the CSVs are written.
_REFACTOR_TABLES: Tuple[Tuple[str, Tuple[str, ...]], ...] = (
    ("overall", ()),
    ("by_difficulty", ("difficulty",)),
    ("by_prefix", ("prefix",)),
    ("by_prefix_difficulty", ("prefix", "difficulty")),
)


def _load_refactor_records(mode_dir: Path, mode: str) -> List[Dict[str, Any]]:
    """Return one flat record per case directory of *mode_dir* with readable metrics."""
    records: List[Dict[str, Any]] = []
    for case_dir in sorted(mode_dir.iterdir()):
        if not case_dir.is_dir():
            continue
        jj = _read_metrics_json(case_dir)
        if not jj:
            continue
        case_id = jj.get("case_id") or case_dir.name
        summ = jj.get("summary") or {}
        shots = jj.get("shots") or []
        # Token usage is taken from the first shot only.
        first_tokens = ((shots[0] or {}).get("tokens") or {}) if shots else {}
        record: Dict[str, Any] = {
            "mode": mode,
            "case_id": case_id,
            "difficulty": jj.get("difficulty"),
            "prefix": _case_prefix(case_id),
            # A case is unparseable if any of its shots was.
            "unparseable": any(bool((s or {}).get("unparseable")) for s in shots),
            "tokens_in": first_tokens.get("in"),
            "tokens_out": first_tokens.get("out"),
        }
        record.update((k, bool(summ.get(k))) for k in _REFACTOR_SUMMARY_KEYS)
        records.append(record)
    return records


def write_refactor_aggregates(run_dir: Path) -> None:
    """
    Aggregate per-case refactor metrics into CSV tables under ``run_dir/tables``.

    Every ``raw/refactor.<mode>`` directory is scanned for
    ``<case>/01_artifacts/metrics.json``. Routing depends on the mode name:
      - modes containing "reverse" are accumulated by _acc_refactor_reverse
        into ``agg.refactor.reverse.*.csv`` (em@1/5 columns);
      - all other modes go through _acc_refactor_normal into
        ``agg.refactor.normal.*.csv`` (pass/valid/equiv @1/5 columns).

    Each target writes overall, by_difficulty, by_prefix (e.g. RF001) and
    by_prefix_difficulty tables. A table is written only if it has rows.
    """
    tables = _ensure_tables_dir(run_dir)
    # target -> one row list per entry of _REFACTOR_TABLES.
    acc: Dict[str, List[List[List[Any]]]] = {
        target: [[] for _ in _REFACTOR_TABLES] for target in ("normal", "reverse")
    }

    for mode_dir in _glob_task_modes(run_dir, "refactor"):
        mode = mode_dir.name.split(".", 1)[1] if "." in mode_dir.name else mode_dir.name
        target = "reverse" if "reverse" in mode else "normal"
        records = _load_refactor_records(mode_dir, mode)
        if not records:
            continue
        accumulate = _acc_refactor_normal if target == "normal" else _acc_refactor_reverse
        # Positional order: overall, by_difficulty, by_prefix, by_prefix_difficulty.
        accumulate(records, *acc[target])

    for target in ("normal", "reverse"):
        for (name, columns), rows in zip(_REFACTOR_TABLES, acc[target]):
            if rows:
                _write_csv(
                    tables / f"agg.refactor.{target}.{name}.csv",
                    ["task", "mode", *columns, *_REFACTOR_METRIC_HEADERS[target]],
                    rows,
                )


# Per-case success flags aggregated for refactor.normal, in CSV column order.
_REFACTOR_NORMAL_RATE_KEYS: Tuple[str, ...] = (
    "pass_at_1",
    "pass_at_5",
    "valid_at_1",
    "valid_at_5",
    "equiv_at_1",
    "equiv_at_5",
)

# Statistics reported for each token series (input, then output), in CSV column order.
_REFACTOR_NORMAL_TOKEN_STATS: Tuple[str, ...] = ("sum", "avg", "median", "min", "max")


def _refactor_normal_cells(rs: List[Dict[str, Any]]) -> List[Any]:
    """Return [case_count, six success rates, unparseable_rate, ten token stats] for *rs*."""
    n = len(rs)  # every rate below uses this group's own size as denominator
    t_in = _num_stats([x["tokens_in"] for x in rs])
    t_out = _num_stats([x["tokens_out"] for x in rs])
    return [
        n,
        *(_safe_rate(sum(1 for x in rs if x[k]), n) for k in _REFACTOR_NORMAL_RATE_KEYS),
        _safe_rate(sum(1 for x in rs if x["unparseable"]), n),
        *(t_in[s] for s in _REFACTOR_NORMAL_TOKEN_STATS),
        *(t_out[s] for s in _REFACTOR_NORMAL_TOKEN_STATS),
    ]


def _acc_refactor_normal(
    records: List[Dict[str, Any]],
    overall_rows: List[List[Any]],
    diff_rows: List[List[Any]],
    pre_rows: List[List[Any]],
    pre_diff_rows: List[List[Any]],
) -> None:
    """
    Append refactor.normal aggregate rows for *records* to the four output lists.

    Records are grouped by ``mode`` (in sorted order). For each mode it appends:
      - one overall row to *overall_rows*;
      - one row per difficulty to *diff_rows*;
      - one row per prefix to *pre_rows*;
      - one row per (prefix, difficulty) to *pre_diff_rows*.

    Groups are sorted by key, and falsy difficulty/prefix values are grouped
    as "UNKNOWN". Row layout: "refactor", mode, group key columns, then the
    cells from _refactor_normal_cells.
    """
    if not records:
        return
    outputs = (
        (overall_rows, ()),
        (diff_rows, ("difficulty",)),
        (pre_rows, ("prefix",)),
        (pre_diff_rows, ("prefix", "difficulty")),
    )
    by_mode: Dict[str, List[Dict[str, Any]]] = {}
    for r in records:
        by_mode.setdefault(r["mode"], []).append(r)
    for mode, rs_mode in sorted(by_mode.items()):
        for out_rows, columns in outputs:
            groups: Dict[Tuple[Any, ...], List[Dict[str, Any]]] = {}
            for r in rs_mode:
                key = tuple(r[c] or "UNKNOWN" for c in columns)
                groups.setdefault(key, []).append(r)
            for key, rs in sorted(groups.items()):
                out_rows.append(["refactor", mode, *key, *_refactor_normal_cells(rs)])


def _acc_refactor_reverse(
    records: List[Dict[str, Any]],
    overall_rows: List[List[Any]],
    diff_rows: List[List[Any]],
    pre_rows: List[List[Any]],
    pre_diff_rows: List[List[Any]],
) -> None:
    """Append refactor aggregate rows, one set per mode, to four row lists.

    Records are grouped by their ``"mode"`` field (modes in sorted order).
    For each mode, this appends:

    - one row to ``overall_rows``;
    - one row per difficulty to ``diff_rows``;
    - one row per case prefix to ``pre_rows``;
    - one row per ``(prefix, difficulty)`` pair to ``pre_diff_rows``.

    Groups are emitted in sorted key order. A falsy difficulty or prefix is
    bucketed as ``"UNKNOWN"``.

    Every row has the layout ``["refactor", mode, *group_keys, case_count,
    pass_at_1 rate, pass_at_5 rate, unparseable rate, tokens_in
    sum/avg/median/min/max, tokens_out sum/avg/median/min/max]``. A rate is
    the fraction of records in the group whose field is truthy.

    Does nothing when ``records`` is empty.
    """
    if not records:
        return

    stat_order = ("sum", "avg", "median", "min", "max")

    def row_tail(group: List[Dict[str, Any]]) -> List[Any]:
        # Case count, three rates, then token statistics for the group.
        total = len(group)

        def share(field: str) -> Optional[float]:
            return _safe_rate(sum(1 for rec in group if rec[field]), total)

        tok_in = _num_stats([rec["tokens_in"] for rec in group])
        tok_out = _num_stats([rec["tokens_out"] for rec in group])
        return [
            total,
            share("pass_at_1"),
            share("pass_at_5"),
            share("unparseable"),
            *(tok_in[k] for k in stat_order),
            *(tok_out[k] for k in stat_order),
        ]

    def grouped(group: List[Dict[str, Any]], key_of) -> List[Tuple[Any, List[Dict[str, Any]]]]:
        # Bucket records by key_of(record) and return (key, records) pairs sorted by key.
        buckets: Dict[Any, List[Dict[str, Any]]] = {}
        for rec in group:
            buckets.setdefault(key_of(rec), []).append(rec)
        return sorted(buckets.items())

    def diff_of(rec: Dict[str, Any]) -> Any:
        return rec["difficulty"] or "UNKNOWN"

    def pref_of(rec: Dict[str, Any]) -> Any:
        return rec["prefix"] or "UNKNOWN"

    def pref_diff_of(rec: Dict[str, Any]) -> Tuple[Any, Any]:
        return pref_of(rec), diff_of(rec)

    for mode, mode_recs in grouped(records, lambda rec: rec["mode"]):
        overall_rows.append(["refactor", mode, *row_tail(mode_recs)])

        for diff, grp in grouped(mode_recs, diff_of):
            diff_rows.append(["refactor", mode, diff, *row_tail(grp)])

        for pref, grp in grouped(mode_recs, pref_of):
            pre_rows.append(["refactor", mode, pref, *row_tail(grp)])

        for (pref, diff), grp in grouped(mode_recs, pref_diff_of):
            pre_diff_rows.append(["refactor", mode, pref, diff, *row_tail(grp)])



# Metric columns shared by the agg.transform.{overall,by_difficulty,by_prefix,
# by_prefix_difficulty}.csv tables. They follow each table's grouping-key
# columns and must stay in the same order as the metric values emitted per row
# by _acc_transform_buckets.
_TRANSFORM_METRIC_COLUMNS: Tuple[str, ...] = (
    "case_count",
    "pass_at_1_rate",
    "pass_at_5_rate",
    "valid_at_1_rate",
    "valid_at_5_rate",
    "equiv_at_1_rate",
    "equiv_at_5_rate",
    "unparseable_rate",
    "speedup_mr_min",
    "speedup_mr_max",
    "speedup_mr_avg",
    "speedup_mr_median",
    "speedup_mg_min",
    "speedup_mg_max",
    "speedup_mg_avg",
    "speedup_mg_median",
    "speedup_gr_min",
    "speedup_gr_max",
    "speedup_gr_avg",
    "speedup_gr_median",
    "mca_raw_cycles_min",
    "mca_raw_cycles_max",
    "mca_raw_cycles_avg",
    "mca_raw_cycles_median",
    "mca_golden_cycles_min",
    "mca_golden_cycles_max",
    "mca_golden_cycles_avg",
    "mca_golden_cycles_median",
    "mca_variant_cycles_min",
    "mca_variant_cycles_max",
    "mca_variant_cycles_avg",
    "mca_variant_cycles_median",
    "mca_raw_ipc_avg",
    "mca_raw_ipc_median",
    "mca_golden_ipc_avg",
    "mca_golden_ipc_median",
    "mca_variant_ipc_avg",
    "mca_variant_ipc_median",
    "tokens_in_sum",
    "tokens_in_avg",
    "tokens_in_median",
    "tokens_in_min",
    "tokens_in_max",
    "tokens_out_sum",
    "tokens_out_avg",
    "tokens_out_median",
    "tokens_out_min",
    "tokens_out_max",
)


def _transform_case_record(case_dir: Path, mode: str) -> Optional[Dict[str, Any]]:
    """Build the flat per-case record consumed by the transform aggregates.

    Args:
        case_dir: One case directory inside a ``raw/transform.<mode>`` folder.
        mode: Mode label stored verbatim in the record's ``"mode"`` field.

    Returns:
        A dict containing:

        - the summary pass/valid/equiv @1/@5 flags (coerced to bool);
        - an ``unparseable`` flag, true if any shot is marked unparseable;
        - the first shot's token counts;
        - the runtime speedups and the MCA cycles/IPC of the shot chosen by
          ``_best_shot``. MCA values are None unless that entry's status is
          ``"ok"``.

        Returns ``None`` when metrics.json is missing, unreadable or empty,
        or when it lists no shots.
    """
    jj = _read_metrics_json(case_dir)
    if not jj:
        return None

    case_id = jj.get("case_id") or case_dir.name
    difficulty = jj.get("difficulty")
    prefix = _case_prefix(case_id)
    summ = jj.get("summary") or {}
    shots = jj.get("shots") or []
    if not shots:
        return None

    unparseable = any(bool((s or {}).get("unparseable")) for s in shots)

    # Token usage is reported for the first shot only.
    first_tokens = (shots[0] or {}).get("tokens") or {}

    best = _best_shot(shots) or {}
    rt = best.get("runtime") or {}
    mca = best.get("mca") or {}

    def mca_pair(kind: str) -> Tuple[Optional[float], Optional[float]]:
        # (cycles, ipc) for one MCA variant, or (None, None) unless status == "ok".
        mm = mca.get(kind) or {}
        if (mm.get("status") or "") != "ok":
            return None, None
        return mm.get("cycles"), mm.get("ipc")

    raw_cycles, raw_ipc = mca_pair("raw")
    g_cycles, g_ipc = mca_pair("golden")
    v_cycles, v_ipc = mca_pair("variant")

    return {
        "mode": mode,
        "case_id": case_id,
        "difficulty": difficulty,
        "prefix": prefix,
        "pass_at_1": bool(summ.get("pass_at_1")),
        "pass_at_5": bool(summ.get("pass_at_5")),
        "valid_at_1": bool(summ.get("valid_at_1")),
        "valid_at_5": bool(summ.get("valid_at_5")),
        "equiv_at_1": bool(summ.get("equiv_at_1")),
        "equiv_at_5": bool(summ.get("equiv_at_5")),
        "unparseable": unparseable,
        "tokens_in": first_tokens.get("in"),
        "tokens_out": first_tokens.get("out"),
        "speedup_mr": rt.get("speedup_model_raw"),
        "speedup_mg": rt.get("speedup_model_golden"),
        "speedup_gr": rt.get("speedup_golden_raw"),
        "raw_cycles": raw_cycles,
        "raw_ipc": raw_ipc,
        "golden_cycles": g_cycles,
        "golden_ipc": g_ipc,
        "variant_cycles": v_cycles,
        "variant_ipc": v_ipc,
    }


def write_transform_aggregates(run_dir: Path) -> None:
    """Aggregate per-case transform metrics into CSV tables under ``run_dir/tables``.

    Every ``raw/transform.<mode>`` directory found by ``_glob_task_modes`` is
    scanned. Each case subdirectory with a usable metrics.json becomes one
    record (see ``_transform_case_record``). The records are then summarised
    into:

    - agg.transform.overall.csv
    - agg.transform.by_difficulty.csv
    - agg.transform.by_prefix.csv
    - agg.transform.by_prefix_difficulty.csv
    - agg.transform.speedup_core.csv

    A table with no rows is not written.
    """
    tables = _ensure_tables_dir(run_dir)

    overall_rows: List[List[Any]] = []
    diff_rows: List[List[Any]] = []
    pre_rows: List[List[Any]] = []
    pre_diff_rows: List[List[Any]] = []
    speed_core_rows: List[List[Any]] = []

    for mode_dir in _glob_task_modes(run_dir, "transform"):
        mode = mode_dir.name.split(".", 1)[1] if "." in mode_dir.name else mode_dir.name

        case_recs: List[Dict[str, Any]] = []
        for case_dir in sorted(mode_dir.iterdir()):
            if not case_dir.is_dir():
                continue
            rec = _transform_case_record(case_dir, mode)
            if rec is not None:
                case_recs.append(rec)

        if not case_recs:
            continue

        # Every record from this directory carries the same `mode`, so the
        # records are aggregated directly instead of being regrouped by mode.
        _acc_transform_buckets(
            task="transform",
            mode=mode,
            records=case_recs,
            overall_rows=overall_rows,
            diff_rows=diff_rows,
            pre_rows=pre_rows,
            pre_diff_rows=pre_diff_rows,
        )

        # Core speedup summary: the runtime speedup_mg values, and the ratio of
        # golden to variant MCA cycles. Cases without numeric, non-zero variant
        # cycles are skipped in the ratio.
        mg_stats_core = _num_stats([x["speedup_mg"] for x in case_recs])
        vg_stats_core = _num_stats(
            [
                x["golden_cycles"] / x["variant_cycles"]
                for x in case_recs
                if isinstance(x["variant_cycles"], (int, float))
                and isinstance(x["golden_cycles"], (int, float))
                and x["variant_cycles"]
            ]
        )
        speed_core_rows.append(
            [
                "transform",
                mode,
                len(case_recs),
                mg_stats_core["min"],
                mg_stats_core["max"],
                mg_stats_core["median"],
                vg_stats_core["min"],
                vg_stats_core["max"],
                vg_stats_core["median"],
            ]
        )

    # The four grouped tables share the metric columns and differ only in
    # their leading grouping-key columns.
    grouped_tables = (
        ("agg.transform.overall.csv", ("task", "mode"), overall_rows),
        ("agg.transform.by_difficulty.csv", ("task", "mode", "difficulty"), diff_rows),
        ("agg.transform.by_prefix.csv", ("task", "mode", "prefix"), pre_rows),
        (
            "agg.transform.by_prefix_difficulty.csv",
            ("task", "mode", "prefix", "difficulty"),
            pre_diff_rows,
        ),
    )
    for filename, key_cols, rows in grouped_tables:
        if rows:
            _write_csv(tables / filename, [*key_cols, *_TRANSFORM_METRIC_COLUMNS], rows)

    if speed_core_rows:
        _write_csv(
            tables / "agg.transform.speedup_core.csv",
            [
                "task",
                "mode",
                "case_count",
                "speedup_mg_min",
                "speedup_mg_max",
                "speedup_mg_median",
                "mca_golden_over_variant_cycles_min",
                "mca_golden_over_variant_cycles_max",
                "mca_golden_over_variant_cycles_median",
            ],
            speed_core_rows,
        )


def _acc_transform_buckets(
    task: str,
    mode: str,
    records: List[Dict[str, Any]],
    overall_rows: List[List[Any]],
    diff_rows: List[List[Any]],
    pre_rows: List[List[Any]],
    pre_diff_rows: List[List[Any]],
) -> None:
    """Append transform aggregate rows for one ``(task, mode)`` to four row lists.

    Appends one row to ``overall_rows`` and one row per group to:

    - ``diff_rows``, grouped by difficulty;
    - ``pre_rows``, grouped by prefix;
    - ``pre_diff_rows``, grouped by ``(prefix, difficulty)``.

    Groups are emitted in sorted key order. A falsy difficulty or prefix is
    grouped under ``"UNKNOWN"``.

    Each row is ``[task, mode, *group_keys, case_count, *metrics]``. The
    metrics are, in order:

    - the pass/valid/equiv @1/@5 rates and the unparseable rate;
    - min/max/avg/median of the three speedups and of the raw/golden/variant
      MCA cycles;
    - avg/median of the raw/golden/variant IPC;
    - sum/avg/median/min/max of tokens in and tokens out.

    Does nothing when ``records`` is empty.
    """
    if not records:
        return

    # Record fields whose truthy fraction is reported as a rate.
    rate_fields = (
        "pass_at_1",
        "pass_at_5",
        "valid_at_1",
        "valid_at_5",
        "equiv_at_1",
        "equiv_at_5",
        "unparseable",
    )
    # Fields reported as min/max/avg/median.
    spread_fields = (
        "speedup_mr",
        "speedup_mg",
        "speedup_gr",
        "raw_cycles",
        "golden_cycles",
        "variant_cycles",
    )
    # Fields reported as avg/median only.
    ipc_fields = ("raw_ipc", "golden_ipc", "variant_ipc")
    # Fields reported as sum/avg/median/min/max.
    token_fields = ("tokens_in", "tokens_out")

    def metrics(group: List[Dict[str, Any]]) -> List[Any]:
        # Case count followed by the metric values for one group of records.
        count = len(group)
        row: List[Any] = [count]
        for field in rate_fields:
            row.append(_safe_rate(sum(1 for rec in group if rec[field]), count))
        for field in spread_fields:
            st = _num_stats([rec[field] for rec in group])
            row += [st["min"], st["max"], st["avg"], st["median"]]
        for field in ipc_fields:
            st = _num_stats([rec[field] for rec in group])
            row += [st["avg"], st["median"]]
        for field in token_fields:
            st = _num_stats([rec[field] for rec in group])
            row += [st["sum"], st["avg"], st["median"], st["min"], st["max"]]
        return row

    def grouped(key_of) -> List[Tuple[Any, List[Dict[str, Any]]]]:
        # Bucket all records by key_of(record); return (key, records) sorted by key.
        buckets: Dict[Any, List[Dict[str, Any]]] = {}
        for rec in records:
            buckets.setdefault(key_of(rec), []).append(rec)
        return sorted(buckets.items())

    def diff_of(rec: Dict[str, Any]) -> Any:
        return rec["difficulty"] or "UNKNOWN"

    def pref_of(rec: Dict[str, Any]) -> Any:
        return rec["prefix"] or "UNKNOWN"

    def pref_diff_of(rec: Dict[str, Any]) -> Tuple[Any, Any]:
        return pref_of(rec), diff_of(rec)

    overall_rows.append([task, mode, *metrics(records)])

    for diff, group in grouped(diff_of):
        diff_rows.append([task, mode, diff, *metrics(group)])

    for pref, group in grouped(pref_of):
        pre_rows.append([task, mode, pref, *metrics(group)])

    for (pref, diff), group in grouped(pref_diff_of):
        pre_diff_rows.append([task, mode, pref, diff, *metrics(group)])



def write_all_aggregates(run_dir: Path) -> None:
    """Create ``run_dir/tables``, then run each task's aggregate writer in turn.

    The writers run in this fixed order: analysis, repair, refactor,
    transform.
    """
    _ensure_tables_dir(run_dir)
    writers = (
        write_analysis_aggregates,
        write_repair_aggregates,
        write_refactor_aggregates,
        write_transform_aggregates,
    )
    for writer in writers:
        writer(run_dir)