
#!/usr/bin/env python3

"""

ICML Protocol v2 finalizer.



Outputs in the run directory:

  - flow_contraction_update.png   (x = t_update, main)

  - flow_contraction_total.png    (x = t_total, appendix)

  - flow_contraction_uproj.png    (x = t_update + t_proj, fair low-rank compute; excludes shared eval)

  - icml_v2_table.csv / icml_v2_table.png



Fixes:

  - Avoid giant whitespace for low-rank runs by only drawing tau lines/labels when tau is inside the visible y-range.

  - Avoid tight_layout warnings by not calling tight_layout (bbox_inches="tight" is sufficient).

"""



from __future__ import annotations



import argparse, csv, json, math, statistics

from collections import defaultdict

from pathlib import Path

from typing import Dict, List, Optional, Tuple



import numpy as np



import matplotlib

matplotlib.use("Agg")

import matplotlib.pyplot as plt



try:

    from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset

except Exception:

    inset_axes = None

    mark_inset = None





# One row per known method: (key, hex colour, display label).  The row order is
# the canonical order in which methods are drawn and listed in the table.
_METHOD_STYLES = (
    ("itspace", "#1f77b4", "ITSPACE (ours)"),
    ("bw_geodesic", "#ff7f0e", "BW geodesic"),
    ("euclidean", "#2ca02c", "Euclidean"),
    ("logeuclid", "#d62728", "Log-Euclidean"),
    ("airm", "#9467bd", "AIRM"),
    ("coral", "#8c564b", "CORAL"),
    ("bw_gd", "#7f7f7f", "BW-GD"),
    ("sinkhorn", "#e377c2", "Sinkhorn"),
    ("sinkhorn_gaus", "#17becf", "Entropic Gaussian OT"),
)

# Method key -> hex colour used for that method's curves.
PALETTE = {key: colour for key, colour, _label in _METHOD_STYLES}

# Method key -> human-readable name shown in legends and the table image.
LABEL = {key: label for key, _colour, label in _METHOD_STYLES}

# Method keys in display order.
ORDER = [key for key, _colour, _label in _METHOD_STYLES]





def finite(x: float) -> bool:
    """Return True when *x* converts to a float that is neither NaN nor infinite.

    Anything ``float()`` cannot convert (``None``, non-numeric text, integers
    too large for a float, ...) is reported as not finite instead of raising.
    """
    try:
        return math.isfinite(float(x))
    except Exception:
        return False





def ffloat(x):
    """Convert *x* with ``float()``; yield NaN when the conversion fails."""
    try:
        value = float(x)
    except Exception:
        value = float("nan")
    return value





def fpos(x: float) -> float:
    """Convert *x* to a float clamped at zero from below.

    Returns NaN when *x* cannot be converted or is NaN/infinite; negative
    values become ``0.0``; everything else is returned unchanged.
    """
    nan = float("nan")
    try:
        value = float(x)
    except Exception:
        return nan
    if finite(value):
        if value < 0.0:
            return 0.0
        return value
    return nan





def running_min(arr: List[float]) -> List[float]:
    """Return the cumulative minimum of *arr*, one output entry per input entry.

    Entries that are not finite produce NaN in the output and leave the
    minimum carried forward untouched.
    """
    lowest = float("inf")
    result: List[float] = []
    for value in arr:
        if finite(value):
            lowest = min(lowest, float(value))
            result.append(lowest)
        else:
            result.append(float("nan"))
    return result





def median(xs: List[Optional[float]]) -> float:
    """Return the median of the usable entries of *xs*, or NaN if there are none.

    An entry is usable when it is not ``None`` and ``finite()`` accepts it.
    ``finite()`` performs the float conversion itself, so entries that cannot
    be converted (e.g. non-numeric text) are skipped rather than raising.
    """
    vals = [float(x) for x in xs if x is not None and finite(x)]
    if not vals:
        return float("nan")
    return float(statistics.median(vals))





def time_to_tau(xs: List[float], ys: List[float], tau: float) -> float:
    """Return the x of the first sample whose y is at or below *tau*.

    Samples are scanned in the order given; pairs where x or y is not finite
    are ignored.  NaN is returned when *tau* is ``None``, not finite, or not
    strictly positive, and when no sample reaches the threshold.
    """
    nan = float("nan")
    if tau is None or not finite(tau) or tau <= 0:
        return nan
    first_hit = next(
        (x for x, y in zip(xs, ys) if finite(x) and finite(y) and y <= tau),
        None,
    )
    return nan if first_hit is None else float(first_hit)





def series(rows: List[Dict], xkey: str, ykey: str) -> Tuple[List[float], List[float]]:
    """Extract parallel x/y float lists from *rows*.

    ``xkey == "t_uproj"`` is a synthetic column: the sum of the row's
    ``t_update`` and ``t_proj`` values.  Any other *xkey* is read directly.
    Rows whose x or y is missing, unparseable, or not finite are dropped.
    For ``ykey`` of ``"bw2"`` or ``"rel_bw2"`` the y value goes through
    ``fpos`` (negatives become 0).
    """
    nan = float("nan")
    use_update_plus_proj = xkey == "t_uproj"
    clamp_y = ykey in ("bw2", "rel_bw2")

    xs: List[float] = []
    ys: List[float] = []
    for r in rows:
        if use_update_plus_proj:
            x = ffloat(r.get("t_update", nan)) + ffloat(r.get("t_proj", nan))
        else:
            x = ffloat(r.get(xkey, nan))
        y = ffloat(r.get(ykey, nan))
        if finite(x) and finite(y):
            xs.append(float(x))
            ys.append(fpos(y) if clamp_y else float(y))
    return xs, ys





def choose_xlim(method_to_seedxy, focus_tau: float, pad_mult: float) -> float:
    """Pick the right-hand x limit for a flow plot.

    *method_to_seedxy* maps a method to a list of ``(xs, ys)`` seed curves.
    If any curve reaches *focus_tau*, the limit is the latest such hit time
    multiplied by *pad_mult*, capped at the largest x seen.  Otherwise the
    90th percentile of all x values is used, and ``1.0`` when there is no data.
    """
    curves = [curve for seed_list in method_to_seedxy.values() for curve in seed_list]

    xmax = 0.0
    hits = []
    for xs, ys in curves:
        if xs:
            xmax = max(xmax, max(xs))
        hit_time = time_to_tau(xs, ys, focus_tau)
        if finite(hit_time):
            hits.append(hit_time)

    if hits:
        return min(xmax, max(hits) * pad_mult)

    allx = [x for xs, _ in curves for x in xs]
    if allx:
        return float(np.quantile(np.array(allx, dtype=np.float64), 0.90))
    return xmax if xmax > 0 else 1.0





def choose_ylim(method_to_seedxy, xlim: float, max_decades: int, taus: List[float], ymin_floor: float) -> Tuple[float, float]:
    """Pick ``(ymin, ymax)`` for the log-scale y axis of a flow plot.

    ``ymax`` is fixed at 1.05.  ``ymin`` is derived in this order:

    1. one decade below the decade of the smallest positive y with
       ``x <= xlim`` (or ``ymax / 10**max_decades`` when there is none),
       never lower than ``ymax / 10**max_decades``;
    2. lowered to a tenth of the smallest positive tau, if any tau is positive;
    3. raised again to a tenth of the smallest positive y, so thresholds far
       below the data do not stretch the axis;
    4. raised to *ymin_floor*.
    """
    ymax = 1.05
    cap = ymax / (10.0 ** max_decades)

    visible = [
        y
        for seed_list in method_to_seedxy.values()
        for xs, ys in seed_list
        for x, y in zip(xs, ys)
        if x <= xlim and finite(y) and y > 0
    ]
    min_pos = min(visible) if visible else None

    if min_pos is None:
        ymin = cap
    else:
        decade = math.floor(math.log10(min_pos))
        ymin = max(10.0 ** (decade - 1), cap)

    # Every positive tau is considered here, whether or not it ends up visible.
    positive_taus = [t for t in taus if finite(t) and t > 0]
    if positive_taus:
        ymin = min(ymin, min(positive_taus) / 10.0)

    if min_pos is not None and finite(min_pos):
        ymin = max(ymin, float(min_pos) / 10.0)

    return max(ymin, float(ymin_floor)), ymax





def short_title(run_dir: Path) -> str:
    """Build the plot title for *run_dir*.

    The title is the final path component of the ``dataset`` entry in
    ``run_dir/config.json``, followed by ``(d=..., K=...)`` when both ``d`` and
    ``K`` are present.  If the config file is missing or unreadable the run
    directory's own name is used instead.
    """
    dataset = run_dir.name
    dim = None
    k_val = None

    config_path = run_dir / "config.json"
    if config_path.exists():
        try:
            cfg = json.loads(config_path.read_text())
            dim = cfg.get("d", None)
            k_val = cfg.get("K", None)
            dataset = str(cfg.get("dataset", dataset))
        except Exception:
            # Best effort: fall back to whatever was gathered so far.
            pass

    base = Path(dataset).name
    if dim is None or k_val is None:
        return base
    return f"{base} (d={dim}, K={k_val})"





def _median_running_min(seed_curves, floor):
    """Return ``(ref_x, ys)`` for the bold median curve of one method.

    ``ref_x`` is the x list of the seed curve with the most samples.  At every
    reference x each seed contributes the y of its nearest-in-x sample; the
    per-point median is clamped from below at *floor* and then passed through
    ``running_min``.
    """
    ref = max(seed_curves, key=lambda curve: len(curve[0]))[0]
    # Build each seed's x array once rather than once per reference point.
    seed_arrays = [(np.array(xs), ys) for xs, ys in seed_curves]
    ys_interp = []
    for t in ref:
        vals = [ys[int(np.argmin(np.abs(xs_arr - t)))] for xs_arr, ys in seed_arrays]
        ys_interp.append(float(np.median(vals)))
    return ref, running_min([max(v, floor) for v in ys_interp])


def plot_flow(
    run_dir: Path,
    by_ms: Dict[Tuple[str, int], List[Dict]],
    xkey: str,
    out_png: Path,
    taus: List[float],
    focus_tau: float,
    max_decades: int,
    xmult: float,
    inset_mode: str,
    ymin_floor: float,
) -> None:
    """Draw relative-W2^2 contraction curves against a time axis and save a PNG.

    Args:
        run_dir: Run directory; only used to build the plot title.
        by_ms: Rows of the curves table grouped by ``(method, seed)``.
        xkey: Time column for the x axis: ``"t_update"``, ``"t_total"`` or the
            synthetic ``"t_uproj"`` (update + projection time).
        out_png: Destination image path.
        taus: Thresholds drawn as dotted horizontal lines, but only when they
            fall inside the visible y range.
        focus_tau: Threshold whose hit times drive the x limit and the
            automatic inset decision.
        max_decades: Upper bound on the number of decades shown below ymax.
        xmult: Multiplier (floored at 0.5) applied to the x-limit padding.
        inset_mode: ``"on"``, ``"auto"`` or anything else for no inset.  An
            inset is only ever drawn for ``xkey == "t_update"`` and when the
            inset helpers could be imported.
        ymin_floor: Hard lower bound for the y axis.

    Each seed is drawn as a faint running-minimum curve and each method gets a
    bold median running-minimum curve.  Seeds with fewer than two usable
    samples are skipped.  All open figures are closed before plotting.
    """
    method_to_seedxy: Dict[str, List[Tuple[List[float], List[float]]]] = defaultdict(list)
    for (m, _seed), rs in by_ms.items():
        xs, ys = series(rs, xkey, "rel_bw2")
        if len(xs) >= 2:
            method_to_seedxy[m].append((xs, ys))

    # Known methods first in canonical order, then any others as encountered.
    methods = [m for m in ORDER if m in method_to_seedxy] + [m for m in method_to_seedxy if m not in ORDER]
    pad_mult = 1.25 * max(float(xmult), 0.5)
    xlim = choose_xlim(method_to_seedxy, focus_tau=focus_tau, pad_mult=pad_mult)
    ymin, ymax = choose_ylim(method_to_seedxy, xlim=xlim, max_decades=max_decades, taus=taus, ymin_floor=ymin_floor)

    # Values are clamped here so zeros stay drawable on the log axis; the clamp
    # sits just below the visible range.
    clamp = ymin * 0.5

    plt.close("all")
    fig, ax = plt.subplots(figsize=(7.6, 4.4), dpi=260)

    # Median curves are kept so the inset can reuse them without recomputation.
    median_curves = {}
    for m in methods:
        seed_curves = method_to_seedxy[m]
        color = PALETTE.get(m, "#000000")

        # per-seed faint running-min
        for xs, ys in seed_curves:
            ys_rm = running_min([max(v, clamp) for v in ys])
            ax.plot(xs, ys_rm, color=color, alpha=0.18, linewidth=0.9)

        # median running-min on a reference grid
        ref, ys_med = _median_running_min(seed_curves, clamp)
        median_curves[m] = (ref, ys_med)

        lw = 3.0 if m == "itspace" else 2.2
        ax.plot(ref, ys_med, color=color, linewidth=lw, label=LABEL.get(m, m))

    # tau lines ONLY if tau is inside visible range (prevents giant whitespace on low-rank runs)
    for t in taus:
        if not (finite(t) and t > 0):
            continue
        if t < ymin or t > ymax:
            continue
        ax.axhline(t, linestyle=":", linewidth=1.1, color="#444444")
        ax.text(xlim * 0.98, t * 1.15, f"τ={t:g}", ha="right", va="bottom",
                fontsize=10, color="#444444", clip_on=True)

    ax.set_yscale("log")
    ax.set_xlim(0.0, xlim)
    ax.set_ylim(ymin, ymax)
    ax.grid(True, which="major", alpha=0.25)
    ax.grid(True, which="minor", alpha=0.12)

    if xkey == "t_update":
        ax.set_xlabel("Cumulative update time (s)")
    elif xkey == "t_total":
        ax.set_xlabel("Total end-to-end time (s)")
    else:
        ax.set_xlabel("Update + projection time (s)")

    ax.set_ylabel(r"Relative $W_2^2$  ($W_{2,k}^2/W_{2,0}^2$)")
    ax.set_title(short_title(run_dir), fontsize=14, pad=8)
    if methods:
        # With nothing plotted there are no labelled artists to list.
        ax.legend(loc="center left", bbox_to_anchor=(1.02, 0.5), frameon=False)

    # inset only on update plot (optional)
    if inset_axes is not None and xkey == "t_update":
        hits = []
        for seed_list in method_to_seedxy.values():
            for xs, ys in seed_list:
                t = time_to_tau(xs, ys, focus_tau)
                if finite(t):
                    hits.append(t)
        min_hit = min(hits) if hits else float("inf")

        want_inset = False
        if inset_mode == "on":
            want_inset = True
        elif inset_mode == "auto":
            # Only worth zooming when the fastest hit is squeezed into the
            # first 5% of the x axis.
            want_inset = finite(min_hit) and min_hit < 0.05 * xlim

        if want_inset:
            xin = min(xlim * 0.25, max(0.05 * xlim, min_hit * 8.0))
            iax = inset_axes(ax, width="42%", height="46%", loc="lower left", borderpad=1.2)

            for m in methods:
                ref, ys_med = median_curves[m]
                # Filter x and y together so the pairs stay aligned even if
                # the reference grid is not sorted.
                cropped = [(x, y) for x, y in zip(ref, ys_med) if x <= xin]
                iax.plot([p[0] for p in cropped], [p[1] for p in cropped],
                         color=PALETTE.get(m, "#000000"), linewidth=1.6)

            for t in taus:
                if finite(t) and t > 0 and (t >= ymin and t <= ymax):
                    iax.axhline(t, linestyle=":", linewidth=0.8, color="#444444")

            iax.set_yscale("log")
            iax.set_xlim(0.0, xin)
            iax.set_ylim(ymin, ymax)
            iax.grid(True, which="major", alpha=0.18)
            iax.grid(True, which="minor", alpha=0.08)
            iax.tick_params(labelsize=7)

            if mark_inset is not None:
                try:
                    mark_inset(ax, iax, loc1=2, loc2=4, fc="none", ec="#888888", lw=0.8)
                except Exception:
                    # The connector lines are decorative; keep the figure.
                    pass

    fig.savefig(out_png, bbox_inches="tight", pad_inches=0.02)
    plt.close(fig)





def make_table(run_dir: Path, by_ms: Dict[Tuple[str, int], List[Dict]], taus: List[float]):
    """Compute the time-to-threshold summary and write ``icml_v2_table.csv``.

    For every ``(method, seed)`` group the first ``t_update`` / ``t_total`` at
    which ``rel_bw2`` drops to each tau is recorded, together with the last
    update time, last total time and last relative value.  When
    ``run_note.json`` in *run_dir* holds a finite ``floor_bw``, the same is
    done for absolute ``bw2`` against that value (the ``@floor2x`` columns).
    Per-seed values are then reduced to one row per method with ``median``.

    Returns:
        ``(out, cols, floor_bw)``: the per-method row dicts, the column names
        in CSV order, and the floor value (``None`` when unavailable).
    """
    floor_bw: Optional[float] = None
    note_path = run_dir / "run_note.json"
    if note_path.exists():
        try:
            parsed = float(json.loads(note_path.read_text()).get("floor_bw", float("nan")))
            floor_bw = parsed if finite(parsed) else None
        except Exception:
            floor_bw = None
    has_floor = floor_bw is not None

    # Per-seed measurements, grouped by method as they are produced.
    by_m = defaultdict(list)
    for (m, s), rs in by_ms.items():
        tu, relu = series(rs, "t_update", "rel_bw2")
        tt, relt = series(rs, "t_total", "rel_bw2")
        bu, bwabs_u = series(rs, "t_update", "bw2")
        bt, bwabs_t = series(rs, "t_total", "bw2")

        row = {"method": m, "seed": s}
        for tau in taus:
            row[f"t_update@{tau:g}"] = time_to_tau(tu, relu, tau)
            row[f"t_total@{tau:g}"] = time_to_tau(tt, relt, tau)
        if has_floor:
            row["t_update@floor2x"] = time_to_tau(bu, bwabs_u, floor_bw)
            row["t_total@floor2x"] = time_to_tau(bt, bwabs_t, floor_bw)

        row["t_update_last"] = tu[-1] if tu else float("nan")
        row["t_total_last"] = tt[-1] if tt else float("nan")
        row["rel_last"] = relu[-1] if relu else float("nan")
        by_m[m].append(row)

    tau_cols = [name for tau in taus for name in (f"t_update@{tau:g}", f"t_total@{tau:g}")]
    floor_cols = ["t_update@floor2x", "t_total@floor2x"] if has_floor else []
    cols = ["method"] + tau_cols + floor_cols + ["t_update_last", "t_total_last", "rel_last"]

    # Known methods in canonical order, then the rest alphabetically.
    ordered = [m for m in ORDER if m in by_m] + [m for m in sorted(by_m.keys()) if m not in ORDER]
    out = []
    for m in ordered:
        summary = {"method": m}
        for c in cols[1:]:
            summary[c] = median([r.get(c, float("nan")) for r in by_m[m]])
        out.append(summary)

    with open(run_dir / "icml_v2_table.csv", "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=cols)
        writer.writeheader()
        writer.writerows(out)

    return out, cols, floor_bw





def render_table_png(run_dir: Path, out, cols, floor_bw: Optional[float], time_unit: str):
    """Render the summary rows as ``icml_v2_table.png`` inside *run_dir*.

    Args:
        run_dir: Directory the image is written to.
        out: Per-method row dicts keyed by the names in *cols*.
        cols: Column names in display order; ``"method"`` is shown through
            ``LABEL``.
        floor_bw: Only checked for ``None``; when set, a note explaining the
            ``floor2x`` columns is appended to the title.
        time_unit: ``"ms"`` shows columns whose name starts with ``t_``
            multiplied by 1e3; any other value shows them unscaled.

    All open figures are closed before drawing.
    """
    in_ms = time_unit == "ms"

    def fmt(col: str, x) -> str:
        # Non-finite cells are shown as a dash; negatives are displayed as 0.
        if not finite(x):
            return "—"
        val = float(x)
        if val < 0:
            val = 0.0
        if col.startswith("t_"):
            return f"{val*1e3:.3g}" if in_ms else f"{val:.3g}"
        magnitude = abs(val)
        if magnitude >= 1e3 or 0 < magnitude < 1e-3:
            return f"{val:.2e}"
        return f"{val:.3g}"

    def header(col: str) -> str:
        # Shorten the threshold column names so the table stays narrow.
        for prefix, short in (("t_update@", "upd@"), ("t_total@", "tot@")):
            if col.startswith(prefix):
                return short + col.split("@", 1)[1]
        return col

    plt.close("all")
    fig, ax = plt.subplots(figsize=(12, 0.7 + 0.35 * len(out)), dpi=260)
    ax.axis("off")

    data = [
        [
            LABEL.get(r[c], r[c]) if c == "method" else fmt(c, r.get(c, float("nan")))
            for c in cols
        ]
        for r in out
    ]
    headers = [header(c) for c in cols]

    unit_note = "(times in ms)" if in_ms else "(times in seconds)"
    title = f"Time-to-threshold summary (median over seeds) {unit_note}"
    if floor_bw is not None:
        title += "  |  floor2x = 2×best observed BW²"

    tbl = ax.table(cellText=data, colLabels=headers, loc="center", cellLoc="center", colLoc="center")
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(9)
    tbl.scale(1, 1.25)

    try:
        tbl.auto_set_column_width(col=[0])
    except Exception:
        # Column auto-sizing is cosmetic; keep the default width on failure.
        pass

    ax.set_title(title, fontsize=13, pad=10)
    fig.savefig(run_dir / "icml_v2_table.png", bbox_inches="tight", pad_inches=0.02)
    plt.close(fig)





def main() -> None:
    """Command-line entry point: read ``curves.csv`` and emit plots and tables.

    Reads ``<run-dir>/curves.csv``, groups its rows by ``(method, seed)`` and
    writes the three flow-contraction PNGs plus the summary CSV/PNG into the
    run directory.

    Raises:
        SystemExit: if ``curves.csv`` is missing from the run directory.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--run-dir", required=True)
    # Accepted for command-line compatibility; not read anywhere in this function.
    ap.add_argument("--seed-plot", type=int, default=0)
    ap.add_argument("--taus", type=str, default="1e-6,1e-9")
    ap.add_argument("--focus-tau", type=float, default=1e-6)
    ap.add_argument("--max-decades", type=int, default=10)
    ap.add_argument("--xmult-update", type=float, default=5.0)
    ap.add_argument("--xmult-total", type=float, default=8.0)
    ap.add_argument("--inset", type=str, default="auto", choices=["auto", "on", "off"])
    ap.add_argument("--ymin-floor", type=float, default=1e-12)
    ap.add_argument("--time-unit", type=str, default="s", choices=["s", "ms"])
    args = ap.parse_args()

    run_dir = Path(args.run_dir)

    taus = []
    for part in str(args.taus).split(","):
        part = part.strip()
        if not part:
            continue
        try:
            taus.append(float(part))
        except ValueError:
            # Keep going with the remaining thresholds, but say so: a typo
            # would otherwise silently remove a column from every output.
            print(f"[icml_finalize_v2] ignoring unparseable tau {part!r}")

    curves = run_dir / "curves.csv"
    if not curves.exists():
        raise SystemExit(f"[icml_finalize_v2] missing curves.csv in {run_dir}")

    with open(curves, newline="") as f:
        rows = list(csv.DictReader(f))

    by_ms = defaultdict(list)
    for r in rows:
        # DictReader fills the columns of a short row with None, so a default
        # passed to .get() is not enough; None keys would later break sorting
        # of method names.
        m = r.get("method") or "UNKNOWN"
        try:
            s = int(float(r.get("seed", 0)))
        except (TypeError, ValueError, OverflowError):
            # Missing, non-numeric, NaN or infinite seed: treat as seed 0.
            s = 0
        by_ms[(m, s)].append(r)

    # Canonical plots
    plot_flow(run_dir, by_ms, "t_update", run_dir / "flow_contraction_update.png",
              taus, args.focus_tau, args.max_decades, args.xmult_update, args.inset, args.ymin_floor)
    plot_flow(run_dir, by_ms, "t_total",  run_dir / "flow_contraction_total.png",
              taus, args.focus_tau, args.max_decades, args.xmult_total, "off", args.ymin_floor)

    # Fair low-rank compute plot (update+projection, excludes shared eval)
    plot_flow(run_dir, by_ms, "t_uproj",  run_dir / "flow_contraction_uproj.png",
              taus, args.focus_tau, args.max_decades, args.xmult_update, "off", args.ymin_floor)

    out, cols, floor_bw = make_table(run_dir, by_ms, taus)
    render_table_png(run_dir, out, cols, floor_bw, time_unit=args.time_unit)

    print("[icml_finalize_v2] done (PNG+CSV)")



# Run the finalizer only when this file is executed as a script, not on import.
if __name__ == "__main__":

    main()

