from __future__ import annotations

import argparse
import csv
import json
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple

import numpy as np


def _safe_float(value: str | None) -> float | None:
    if value is None:
        return None
    value = value.strip()
    if not value:
        return None
    try:
        return float(value)
    except ValueError:
        return None


def _aggregate(values: List[float]) -> Dict[str, float]:
    arr = np.asarray(values, dtype=np.float64)
    if arr.size == 0:
        return {"mean": float("nan"), "std": float("nan")}
    if arr.size == 1:
        return {"mean": float(arr[0]), "std": 0.0}
    return {"mean": float(arr.mean()), "std": float(arr.std(ddof=1))}


def _collect_groups(
    paths: List[Path],
    group_cols: List[str],
    metrics: List[str],
) -> Tuple[Dict[Tuple[str, ...], Dict[str, Any]], Dict[Tuple[str, ...], Dict[str, List[float]]]]:
    """Read every CSV in *paths* and bucket metric samples by group key.

    Returns:
        ``(group_meta, groups)``. ``group_meta`` maps each group key to its
        group-column values plus a ``"source_count"`` row counter;
        ``groups`` maps the same keys to ``{metric: [samples...]}``. Every
        key in ``group_meta`` is also present in ``groups``, even when no
        metric value could be parsed for it.
    """
    group_meta: Dict[Tuple[str, ...], Dict[str, Any]] = {}
    groups: Dict[Tuple[str, ...], Dict[str, List[float]]] = {}

    for path in paths:
        with path.open("r", encoding="utf-8", newline="") as handle:
            for row in csv.DictReader(handle):
                # DictReader fills the missing cells of a short row with None;
                # normalise to "" so keys stay plain, sortable strings.
                labels = [row.get(col) or "" for col in group_cols]
                key = tuple(labels)
                meta = group_meta.get(key)
                if meta is None:
                    meta = dict(zip(group_cols, labels))
                    meta["source_count"] = 0
                    group_meta[key] = meta
                    # Register the group immediately so it is reported even
                    # if none of its rows carries a parsable metric.
                    groups[key] = defaultdict(list)
                meta["source_count"] += 1

                samples = groups[key]
                for m in metrics:
                    val = _safe_float(row.get(m))
                    if val is not None:
                        samples[m].append(val)

    return group_meta, groups


def _build_rows(
    group_meta: Dict[Tuple[str, ...], Dict[str, Any]],
    groups: Dict[Tuple[str, ...], Dict[str, List[float]]],
    metrics: List[str],
) -> List[Dict[str, Any]]:
    """Turn collected samples into one output row per group, sorted by task/method."""
    out_rows: List[Dict[str, Any]] = []
    for key, meta in group_meta.items():
        samples = groups.get(key, {})
        row: Dict[str, Any] = dict(meta)
        for m in metrics:
            vals = samples.get(m, [])
            agg = _aggregate(vals)
            row[f"{m}_mean"] = agg["mean"]
            row[f"{m}_std"] = agg["std"]
            row[f"{m}_n"] = len(vals)
        out_rows.append(row)

    # Stable sort: groups sharing (task, method) keep first-seen order.
    out_rows.sort(key=lambda r: (r.get("task", ""), r.get("method", "")))
    return out_rows


def _json_safe(row: Dict[str, Any]) -> Dict[str, Any]:
    """Copy *row* with non-finite floats replaced by None (JSON ``null``)."""
    return {
        name: (None if isinstance(value, float) and not bool(np.isfinite(value)) else value)
        for name, value in row.items()
    }


def main() -> int:
    """Aggregate per-run summary CSVs into grouped mean/std tables.

    Recursively searches ``--root`` for files matching ``--pattern``, groups
    their rows by the ``--group-cols`` columns, and for each metric in a
    fixed list computes mean, sample std and sample count per group. The
    result is written to ``--out`` as CSV and, next to it, as a ``.json``
    file with the same stem.

    In the JSON output, non-finite values (for example the NaN mean of a
    metric with no samples) are written as ``null`` so the file is strictly
    valid JSON; the CSV keeps them as written by ``csv.DictWriter``.

    Returns:
        ``0`` on success.

    Raises:
        SystemExit: If no input file matches the pattern under the root.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", type=str, default="plots", help="Root directory to search for *_summary.csv")
    parser.add_argument("--pattern", type=str, default="*_summary.csv", help="Glob pattern for summary CSVs")
    parser.add_argument(
        "--out",
        type=str,
        default=str(Path("plots") / "aggregate" / "summary_agg.csv"),
        help="Output CSV path",
    )
    parser.add_argument(
        "--group-cols",
        type=str,
        default="task,method,cfg_task_id,cfg_hidden,cfg_lr,cfg_epochs,cfg_scan_epochs,cfg_batch_size,cfg_best_g,cfg_time_weighting,cfg_step_labels",
        help="Comma-separated columns used to define a run group",
    )
    args = parser.parse_args()

    root = Path(args.root)
    out_path = Path(args.out)
    out_resolved = out_path.resolve()
    # Keep regular files only (the glob can also match directories) and never
    # feed a previous run's output back in as an input.
    paths = sorted(
        p for p in root.rglob(args.pattern) if p.is_file() and p.resolve() != out_resolved
    )
    if not paths:
        raise SystemExit(f"No files matched: {root} / {args.pattern}")

    group_cols = [c.strip() for c in str(args.group_cols).split(",") if c.strip()]
    metrics = [
        "metric",
        "val_metric",
        "val_loss",
        "lyap_pre",
        "lyap_post",
        "runtime_sec",
        "update_flops",
        "update_flops_per_step",
        "update_flops_per_update",
        "update_peak_cpu_alloc_delta_bytes",
        "update_peak_rss_bytes",
        "update_peak_rss_delta_bytes",
        "update_peak_cuda_bytes",
        "update_peak_cuda_delta_bytes",
        "update_peak_cuda_reserved_bytes",
        "update_peak_cuda_reserved_delta_bytes",
    ]

    group_meta, groups = _collect_groups(paths, group_cols, metrics)
    out_rows = _build_rows(group_meta, groups, metrics)

    # Build the header from the configuration rather than from the first row,
    # so an empty result still gets the full schema. dict.fromkeys dedupes
    # while preserving order, mirroring the key order of the row dicts.
    fieldnames = list(
        dict.fromkeys(
            [
                *group_cols,
                "source_count",
                *(f"{m}_{stat}" for m in metrics for stat in ("mean", "std", "n")),
            ]
        )
    )

    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(out_rows)

    json_path = out_path.with_suffix(".json")
    with json_path.open("w", encoding="utf-8") as handle:
        json.dump(
            {
                "root": str(root),
                "pattern": str(args.pattern),
                "group_cols": group_cols,
                "metrics": metrics,
                "groups": [_json_safe(row) for row in out_rows],
            },
            handle,
            indent=2,
            ensure_ascii=False,
            # Non-finite floats were mapped to None above; fail loudly rather
            # than emit the non-standard NaN/Infinity tokens.
            allow_nan=False,
        )

    print(f"[OK] Found {len(paths)} summary CSV files.")
    print(f"[OK] Wrote: {out_path}")
    print(f"[OK] Wrote: {json_path}")
    return 0


# Script entry point: run main() and exit with its return value as the
# process status.
if __name__ == "__main__":
    raise SystemExit(main())
