
#!/usr/bin/env python3

"""

Protocol v2 threshold table finalizer.



Writes:

  - icml_threshold_table.csv

  - icml_threshold_table.png



Includes:

  - update_time_to_tau_{1e-06,1e-09}

  - time_to_tau_{1e-06,1e-09}

  - update_time_to_floor (computed from curves using floor_bw from run_note.json)

  - time_to_floor (from summary.csv)

"""



from __future__ import annotations



import argparse, csv, json, statistics

from collections import defaultdict

from pathlib import Path

from typing import Dict, List, Optional, Tuple



import matplotlib

matplotlib.use("Agg")

import matplotlib.pyplot as plt





# Display name for each method key, used in the rendered PNG table.
# Insertion order doubles as the preferred row order (see ORDER below).
LABEL = {
    "itspace":       "ITSPACE (ours)",
    "bw_geodesic":   "BW geodesic",
    "euclidean":     "Euclidean",
    "logeuclid":     "Log-Euclidean",
    "airm":          "AIRM",
    "coral":         "CORAL",
    "bw_gd":         "BW-GD",
    "sinkhorn":      "Sinkhorn",
    "sinkhorn_gaus": "Entropic Gaussian OT",
}

# Preferred row order for known methods; methods not listed here are
# appended alphabetically by the caller.
ORDER = list(LABEL)





def ffloat(x):
    """Parse ``x`` into a finite float, or return None.

    None is returned for: ``None``; strings that are empty after stripping;
    the words "none"/"nan" in any case; text that ``float()`` rejects; and
    values that parse to NaN or +/-inf.
    """
    if x is None:
        return None
    text = str(x).strip()
    if not text or text.lower() in ("none", "nan"):
        return None
    try:
        val = float(text)
    except Exception:
        return None
    is_nan = val != val  # NaN is the only float not equal to itself
    is_inf = val in (float("inf"), float("-inf"))
    return None if (is_nan or is_inf) else val





def finite(x) -> bool:
    """Return True iff ``float(x)`` succeeds and gives neither NaN nor +/-inf."""
    try:
        val = float(x)
    except Exception:
        return False
    return not (val != val or val in (float("inf"), float("-inf")))





def median(xs: List[Optional[float]]) -> float:
    """Median of the finite entries of ``xs``; NaN when there are none.

    ``None`` and non-finite entries (NaN, +/-inf, unparsable) are skipped.
    """
    kept: List[float] = []
    for item in xs:
        if item is None or not finite(item):
            continue
        kept.append(float(item))
    return float(statistics.median(kept)) if kept else float("nan")





def time_to_abs_from_curves(curves_path: Path, floor_bw: Optional[float]) -> Dict[Tuple[str, int], float]:
    """
    Return ``(method, seed) -> update_time_to_floor`` computed from curves.csv.

    For each (method, seed) group, the value is the smallest ``t_update`` among
    rows whose ``bw2`` (clamped at 0) is ``<= floor_bw``. The value is NaN when
    the group has rows but none reaches the floor. Taking the minimum, rather
    than the first qualifying row in file order, keeps the result independent
    of how rows were ordered or appended in the CSV.

    Rows with a blank, non-numeric or non-finite ``bw2`` or ``t_update`` do not
    qualify. A missing or unparsable ``seed`` is treated as seed 0. Returns an
    empty dict when the file does not exist or ``floor_bw`` is not finite.
    """
    if not curves_path.exists() or not finite(floor_bw):
        return {}
    floor = float(floor_bw)  # accept numeric strings too; finite() allowed them

    # Earliest qualifying t_update per group; None means "never reached floor".
    earliest: Dict[Tuple[str, int], Optional[float]] = {}
    with open(curves_path, newline="") as f:
        for r in csv.DictReader(f):
            try:
                seed = int(float(r.get("seed", 0)))
            except (TypeError, ValueError, OverflowError):
                seed = 0
            key = (r.get("method", "UNKNOWN"), seed)
            # Register every group so groups that never hit the floor yield NaN.
            earliest.setdefault(key, None)

            bw = ffloat(r.get("bw2"))
            tu = ffloat(r.get("t_update"))
            if bw is None or tu is None:
                continue
            if max(bw, 0.0) > floor:  # clamp tiny negative numerical noise to 0
                continue
            prev = earliest[key]
            earliest[key] = tu if prev is None else min(prev, tu)

    return {k: float("nan") if t is None else t for k, t in earliest.items()}





def _seed_of(row: Dict[str, str]) -> int:
    """Return ``row["seed"]`` as an int, or 0 if it is missing/blank/unparsable.

    Uses the same ``int(float(...))`` parsing and 0 fallback as the curves.csv
    reader, so summary rows and curve rows map to the same (method, seed) key.
    """
    try:
        return int(float(row.get("seed", 0)))
    except (TypeError, ValueError, OverflowError):
        return 0


def _load_floor_bw(note: Path) -> Optional[float]:
    """Return a finite ``floor_bw`` from run_note.json, or None if unavailable.

    Best effort:
    - A missing note, or a missing/non-finite value, returns None quietly.
    - An unreadable or malformed note (bad JSON, not an object, non-numeric
      value) returns None with a warning.
    Either way the floor column is disabled; the run is not aborted.
    """
    if not note.exists():
        return None
    try:
        v = float(json.loads(note.read_text()).get("floor_bw", float("nan")))
    except (OSError, ValueError, TypeError, AttributeError) as e:
        print(f"[icml_finalize] warning: ignoring floor_bw from {note}: {e}")
        return None
    return v if finite(v) else None


def _load_ok_rows(summ: Path) -> List[Dict[str, str]]:
    """Read summary.csv, keeping rows whose ``status`` is "ok" (a missing status column counts as ok)."""
    with open(summ, newline="") as f:
        return [r for r in csv.DictReader(f) if r.get("status", "ok") == "ok"]


def _median_col(rs: List[Dict[str, str]], col: str) -> float:
    """Median of column ``col`` across ``rs``; blank/non-numeric/non-finite cells are ignored (NaN if none remain)."""
    return median([ffloat(r.get(col)) for r in rs])


def _fmt(x) -> str:
    """Format a time cell: "—" if not finite, else clamp at 0 and show 3 significant digits."""
    if not finite(x):
        return "—"
    return f"{max(float(x), 0.0):.3g}"


def _fmt_rel(x) -> str:
    """Format a rel_bw_last cell: like ``_fmt``, but scientific notation for values in (0, 1e-3)."""
    if not finite(x):
        return "—"
    v = max(float(x), 0.0)
    return f"{v:.2e}" if 0 < v < 1e-3 else f"{v:.3g}"


def _header(col: str) -> str:
    """Short PNG column header for CSV column name ``col``."""
    if col == "method":
        return "method"
    if col.startswith("update_time_to_tau_"):
        return "upd@" + col.split("update_time_to_tau_", 1)[1].replace("_", "")
    if col.startswith("time_to_tau_"):
        return "tot@" + col.split("time_to_tau_", 1)[1].replace("_", "")
    if col == "update_time_to_floor":
        return "upd@floor2x"
    if col == "time_to_floor":
        return "tot@floor2x"
    if col.endswith("_last"):
        return col.replace("_last", "")
    return col


def _render_png(out_rows: List[Dict[str, object]], cols: List[str], floor_bw: Optional[float], png_path: Path) -> None:
    """Render ``out_rows`` as a matplotlib table image at ``png_path``.

    ``out_rows`` must be non-empty: ``Axes.table`` takes the column count from
    the first data row.
    """
    data = []
    for r in out_rows:
        cells = []
        for c in cols:
            v = r.get(c, float("nan"))
            if c == "method":
                cells.append(LABEL.get(r[c], r[c]))
            elif c == "rel_bw_last":
                cells.append(_fmt_rel(v))
            else:
                cells.append(_fmt(v))
        data.append(cells)

    plt.close("all")
    fig, ax = plt.subplots(figsize=(12, 0.7 + 0.35 * len(out_rows)), dpi=260)
    try:
        ax.axis("off")
        tbl = ax.table(
            cellText=data,
            colLabels=[_header(c) for c in cols],
            loc="center",
            cellLoc="center",
            colLoc="center",
        )
        try:
            # Cosmetic only: size the method column to its labels.
            tbl.auto_set_column_width(col=[0])
        except Exception:
            pass
        tbl.auto_set_font_size(False)
        tbl.set_fontsize(9)
        tbl.scale(1, 1.25)

        title = "Threshold summary (median over seeds) — times in seconds"
        if floor_bw is not None:
            title += "  |  floor2x = 2×best observed BW²"
        ax.set_title(title, fontsize=13, pad=10)

        fig.tight_layout()
        fig.savefig(png_path, bbox_inches="tight")
    finally:
        # Release the figure even if rendering/saving fails.
        plt.close(fig)


def main() -> None:
    """Build the threshold table for ``--run-dir`` and write it as CSV and PNG.

    Produces one row per method, holding the median over that method's "ok"
    seeds. Known methods come first in ``ORDER``; any others follow
    alphabetically.

    Raises:
        SystemExit: if ``summary.csv`` is missing from the run directory.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--run-dir", required=True)
    args = ap.parse_args()

    run_dir = Path(args.run_dir)
    summ = run_dir / "summary.csv"
    if not summ.exists():
        raise SystemExit(f"[icml_finalize] missing summary.csv in {run_dir}")

    floor_bw = _load_floor_bw(run_dir / "run_note.json")

    by_m: Dict[str, List[Dict[str, str]]] = defaultdict(list)
    for r in _load_ok_rows(summ):
        by_m[r.get("method", "UNKNOWN")].append(r)

    # Per-(method, seed) time to reach floor_bw, taken from the curves file.
    upd_floor = time_to_abs_from_curves(run_dir / "curves.csv", floor_bw) if floor_bw is not None else {}

    cols = [
        "method",
        "update_time_to_tau_1e-06",
        "update_time_to_tau_1e-09",
        "time_to_tau_1e-06",
        "time_to_tau_1e-09",
        "update_time_to_floor",
        "time_to_floor",
        "t_update_last",
        "t_eval_last",
        "t_proj_last",
        "t_total_last",
        "rel_bw_last",
    ]

    methods = [m for m in ORDER if m in by_m] + [m for m in sorted(by_m) if m not in ORDER]
    out_rows: List[Dict[str, object]] = []
    for m in methods:
        rs = by_m[m]
        # Seed parsing falls back to 0 exactly like the curves reader, so a
        # blank/garbled seed no longer crashes the finalizer.
        upd_floor_med = median([upd_floor.get((m, _seed_of(x)), float("nan")) for x in rs])

        rel_last = _median_col(rs, "rel_bw_last")
        if finite(rel_last):
            rel_last = max(float(rel_last), 0.0)

        out_rows.append({
            "method": m,
            "update_time_to_tau_1e-06": _median_col(rs, "update_time_to_tau_1e-06"),
            "update_time_to_tau_1e-09": _median_col(rs, "update_time_to_tau_1e-09"),
            "time_to_tau_1e-06": _median_col(rs, "time_to_tau_1e-06"),
            "time_to_tau_1e-09": _median_col(rs, "time_to_tau_1e-09"),
            "update_time_to_floor": upd_floor_med,
            "time_to_floor": _median_col(rs, "time_to_floor"),
            "t_update_last": _median_col(rs, "t_update_last"),
            "t_eval_last": _median_col(rs, "t_eval_last"),
            "t_proj_last": _median_col(rs, "t_proj_last"),
            "t_total_last": _median_col(rs, "t_total_last"),
            "rel_bw_last": rel_last,
        })

    csv_path = run_dir / "icml_threshold_table.csv"
    with open(csv_path, "w", newline="") as f:
        w = csv.DictWriter(f, fieldnames=cols)
        w.writeheader()
        w.writerows(out_rows)

    if not out_rows:
        # A table with zero data rows cannot be rendered; report instead of crashing.
        print("[icml_finalize] no status=='ok' rows in summary.csv; "
              "wrote header-only icml_threshold_table.csv, skipped icml_threshold_table.png")
        return

    _render_png(out_rows, cols, floor_bw, run_dir / "icml_threshold_table.png")

    print("[icml_finalize] wrote icml_threshold_table.csv and icml_threshold_table.png")





if __name__ == "__main__":
    # main() returns None, so this exits with status 0 on success.
    raise SystemExit(main())

