
#!/usr/bin/env python3

from __future__ import annotations



import argparse, json

from pathlib import Path

import numpy as np

import pandas as pd

import matplotlib.pyplot as plt



# Maps each canonical method key to every spelling of that method that may
# show up in the "method" column of results.csv. The canonical key itself is
# always the first alias; lookups try the aliases in the listed order.
ALIASES = {
    # references, baselines and oracles
    "undrifted": [
        "undrifted",
        "Undrifted (reference)",
        "Undrifted (upper bound)",
        "Undrifted (ref)",
    ],
    "no_adapt": ["no_adapt", "No adapt"],
    "oracle_undo": ["oracle_undo", "Oracle-Undo (sanity)"],
    "oracle_coral_true": [
        "oracle_coral_true",
        "Oracle-CORAL (true Ct)",
        "Oracle-CORAL(true Ct)",
    ],
    # adaptation methods
    "itspace": ["itspace", "ITSPACE"],
    "bw_geodesic": ["bw_geodesic", "BW-geodesic"],
    "bw_gd": ["bw_gd", "BW-GD"],
    "euclidean": ["euclidean", "Euclidean"],
    "logeuclid": ["logeuclid", "Log-Euclidean"],
    "airm": ["airm", "AIRM"],
    "coral": ["coral", "CORAL"],
    "sinkhorn": ["sinkhorn", "Sinkhorn"],
    "sinkhorn_gaus": [
        "sinkhorn_gaus",
        "sinkhorn_gaussian",
        "Sinkhorn-Gaussian",
        "Sinkhorn-Gauss",
        "SinkhornGaussian",
    ],
}



# Display label for each canonical method key, used in table rows and plot
# legends.
PRETTY = {
    # references, baselines and oracles
    "undrifted": "Undrifted (reference)",
    "no_adapt": "No adapt",
    "oracle_undo": "Oracle-Undo (sanity)",
    "oracle_coral_true": "Oracle-CORAL (true Ct)",
    # adaptation methods
    "itspace": "ITSPACE",
    "bw_geodesic": "BW-geodesic",
    "bw_gd": "BW-GD",
    "euclidean": "Euclidean",
    "logeuclid": "Log-Euclidean",
    "airm": "AIRM",
    "coral": "CORAL",
    "sinkhorn": "Sinkhorn",
    "sinkhorn_gaus": "Sinkhorn-Gaussian",
}



# Row / legend order for the "all methods" table and plot.
ORDER_ALL = [
    "undrifted",
    "no_adapt",
    "oracle_undo",
    "oracle_coral_true",
    "itspace",
    "bw_geodesic",
    "bw_gd",
    "euclidean",
    "logeuclid",
    "airm",
    "coral",
    "sinkhorn",
    "sinkhorn_gaus",
]

# Row / legend order for the reduced "main" table and plot.
ORDER_MAIN = [
    "no_adapt",
    "coral",
    "itspace",
    "bw_gd",
    "sinkhorn_gaus",
]



def resolve_present_name(present: set[str], key: str) -> str | None:
    """Return the spelling of method *key* that actually occurs in *present*.

    The aliases registered for *key* in ``ALIASES`` are tried in their listed
    order; a key without an ``ALIASES`` entry is looked up under its own name.
    Returns ``None`` when no spelling is found in *present*.
    """
    candidates = ALIASES.get(key, [key])
    return next((alias for alias in candidates if alias in present), None)



def pick_metric(df: pd.DataFrame) -> tuple[str, str, pd.DataFrame]:
    """Choose the metric to report and return the rows/columns that carry it.

    Supported layouts:

      - long format: columns ``metric`` and ``value``. The metric is the first
        of ``roc_auc``, ``auroc``, ``acc``, ``accuracy`` that occurs in the
        ``metric`` column; if none occurs, the alphabetically first metric
        name is used with a generic ``"<name> (%)"`` label. Only the rows of
        the chosen metric are returned.
      - wide format: the metric is a column of its own. ``roc_auc`` and then
        ``acc`` are preferred; ``auroc`` and ``accuracy`` are accepted as
        fallbacks so both layouts understand the same spellings.

    Returns:
        ``(metric_name, ylabel, df_metric)`` where ``df_metric`` is a copy.

    Raises:
        SystemExit: if no metric can be inferred, including a long-format
            frame that has no rows.
    """
    if "metric" in df.columns and "value" in df.columns:
        metric_names = df["metric"].astype(str)
        available = set(metric_names.unique())
        if not available:
            # Without this guard the alphabetical fallback below would fail
            # with an opaque IndexError on an empty frame.
            raise SystemExit(
                "Could not infer metric: long-format results contain no rows."
            )

        # Preference order: AUROC spellings win over accuracy spellings.
        long_preference = (
            ("roc_auc", "AUROC (%)"),
            ("auroc", "AUROC (%)"),
            ("acc", "Accuracy (%)"),
            ("accuracy", "Accuracy (%)"),
        )
        for name, label in long_preference:
            if name in available:
                break
        else:
            name = sorted(available)[0]
            label = f"{name} (%)"
        return name, label, df[metric_names == name].copy()

    # wide format fallback; roc_auc/acc keep their historical precedence and
    # the extra spellings are only consulted when neither is present.
    wide_preference = (
        ("roc_auc", "AUROC (%)"),
        ("acc", "Accuracy (%)"),
        ("auroc", "AUROC (%)"),
        ("accuracy", "Accuracy (%)"),
    )
    for name, label in wide_preference:
        if name in df.columns:
            return name, label, df.copy()

    raise SystemExit(
        "Could not infer metric column (need either (metric,value) long format "
        "OR acc/roc_auc/accuracy/auroc wide columns)."
    )



def scale_if_needed(vals: np.ndarray) -> float:
    """Return the factor that brings *vals* onto a percentage scale.

    The NaN-ignoring median decides: a median of at most 1.2 is taken to mean
    the values are fractions, so the factor is 100.0; otherwise it is 1.0.
    An empty array is treated as having median 0.0.
    """
    if vals.size:
        median = float(np.nanmedian(vals))
    else:
        median = 0.0

    if median <= 1.2:
        return 100.0
    return 1.0



def fmt_pm(vals: np.ndarray, dec: int = 2) -> str:
    """Format *vals* as ``"mean ± std"`` with *dec* decimal places.

    NaN entries are ignored. The standard deviation is the sample standard
    deviation (``ddof=1``) and is reported as 0 when fewer than two non-NaN
    values remain. Returns ``"NA"`` when there is no non-NaN value at all.
    """
    # Drop NaNs up front so the "more than one sample" test counts usable
    # values; counting NaNs too would yield "± nan" for inputs like [1.0, nan].
    valid = np.asarray(vals, dtype=float)
    valid = valid[~np.isnan(valid)]
    if valid.size == 0:
        return "NA"

    mean = float(valid.mean())
    std = float(valid.std(ddof=1)) if valid.size > 1 else 0.0
    return f"{mean:.{dec}f} ± {std:.{dec}f}"



def fmt_time(vals: np.ndarray) -> str:
    """Format timing samples as ``"mean ± std"`` with three decimals.

    NaN entries are ignored. The standard deviation is the sample standard
    deviation (``ddof=1``) and is reported as 0 when fewer than two non-NaN
    values remain. When both mean and std are below 5e-4 (i.e. both would
    print as 0.000) the string ``"<0.001"`` is returned instead. Returns
    ``"NA"`` when there is no non-NaN value at all.
    """
    # Drop NaNs up front so the "more than one sample" test counts usable
    # values; counting NaNs too would yield "± nan" for inputs like [0.5, nan].
    valid = np.asarray(vals, dtype=float)
    valid = valid[~np.isnan(valid)]
    if valid.size == 0:
        return "NA"

    mean = float(valid.mean())
    std = float(valid.std(ddof=1)) if valid.size > 1 else 0.0
    if mean < 5e-4 and std < 5e-4:
        return "<0.001"
    return f"{mean:.3f} ± {std:.3f}"



def agg(dfm: pd.DataFrame, metric_name: str) -> pd.DataFrame:
    """Aggregate metric and timing statistics per ``(method, K)`` group.

    The metric is read from the ``value`` column when *dfm* is in long format
    (both ``metric`` and ``value`` columns exist), otherwise from the column
    named *metric_name*. It is multiplied by the factor returned by
    ``scale_if_needed`` before aggregation. Missing ``t_adapt_s`` /
    ``t_shared_s`` columns are filled with 0.0. The input frame is not
    modified.

    Returns:
        One row per ``(method, K)`` with columns ``metric_mean``,
        ``metric_std``, ``t_mean``, ``t_std`` and ``t_shared_mean``.
    """
    is_long = "value" in dfm.columns and "metric" in dfm.columns
    source_col = "value" if is_long else metric_name

    factor = scale_if_needed(dfm[source_col].to_numpy())

    # Work on a copy so the helper columns never leak into the caller's frame.
    work = dfm.copy()
    work["valp"] = work[source_col] * factor

    for timing_col in ("t_adapt_s", "t_shared_s"):
        if timing_col not in work.columns:
            work[timing_col] = 0.0

    grouped = work.groupby(["method", "K"], as_index=False)
    return grouped.agg(
        metric_mean=("valp", "mean"),
        metric_std=("valp", "std"),
        t_mean=("t_adapt_s", "mean"),
        t_std=("t_adapt_s", "std"),
        t_shared_mean=("t_shared_s", "mean"),
    )



def make_table(aggdf: pd.DataFrame, keys: list[str], Ks: list[int]) -> pd.DataFrame:
    """Build a table of formatted strings with one row per present method.

    Methods are emitted in the order of *keys*; a key none of whose aliases
    occurs in ``aggdf["method"]`` is skipped. Each row holds:

      - ``method``: the display label from ``PRETTY`` (falling back to the
        name found in the data),
      - ``k<K>`` for every K in *Ks*: ``"mean ± std"`` with two decimals, or
        ``"NA"`` when there is no aggregated row for that K,
      - ``t5``: adaptation time ``"mean ± std"`` with three decimals taken at
        K=5 when 5 is in *Ks*, otherwise at ``Ks[0]``; ``"<0.001"`` when both
        mean and std are below 5e-4, ``"NA"`` when there is no such row.

    A missing (NaN) standard deviation is shown as 0.
    """
    available = set(aggdf["method"].unique())
    records = []

    for key in keys:
        source_name = resolve_present_name(available, key)
        if source_name is None:
            continue

        of_method = aggdf["method"] == source_name
        record = {"method": PRETTY.get(key, source_name)}

        # Metric cells, one per requested K.
        for k in Ks:
            hit = aggdf[of_method & (aggdf["K"] == k)]
            if hit.empty:
                cell = "NA"
            else:
                mean = float(hit["metric_mean"].iloc[0])
                std = float(hit["metric_std"].fillna(0.0).iloc[0])
                cell = f"{mean:.2f} ± {std:.2f}"
            record[f"k{k}"] = cell

        # Timing cell: prefer K=5, otherwise the first K of the table.
        timing_k = 5 if 5 in Ks else Ks[0]
        hit = aggdf[of_method & (aggdf["K"] == timing_k)]
        if hit.empty:
            timing = "NA"
        else:
            t_mean = float(hit["t_mean"].iloc[0])
            t_std = float(hit["t_std"].fillna(0.0).iloc[0])
            if t_mean < 5e-4 and t_std < 5e-4:
                timing = "<0.001"
            else:
                timing = f"{t_mean:.3f} ± {t_std:.3f}"
        record["t5"] = timing

        records.append(record)

    return pd.DataFrame(records)



def plot(aggdf: pd.DataFrame, keys: list[str], Ks: list[int], title: str, ylabel: str, out_png: Path, show_ref: bool = True) -> None:
    """Plot metric versus adaptation time (log x-axis) and save it to *out_png*.

    One curve is drawn per method in *keys* that is present in *aggdf*, with
    one point per K (sorted by K): x is ``t_mean`` (clamped to at least 1e-6
    so it can be shown on a log axis), y is ``metric_mean``, and the error
    bars are ``metric_std`` (NaN shown as 0).

    Args:
        aggdf: aggregated frame with columns ``method``, ``K``, ``t_mean``,
            ``metric_mean`` and ``metric_std``.
        keys: canonical method keys, in legend order.
        Ks: K values; only ``Ks[0]`` is used, to pick the reference row.
        title: figure title.
        ylabel: y-axis label.
        out_png: destination image path.
        show_ref: when true, draw a dashed horizontal line at the undrifted
            method's ``metric_mean`` for ``K == Ks[0]`` (if such a row exists).
    """
    present = set(aggdf["method"].unique())
    eps = 1e-6
    fig, ax = plt.subplots(figsize=(7.6, 4.9))

    try:
        # Reference horizontal line from undrifted at the first K. The `Ks`
        # check avoids an IndexError when no K values are supplied.
        if show_ref and Ks:
            und = resolve_present_name(present, "undrifted")
            if und is not None:
                sub = aggdf[(aggdf["method"] == und) & (aggdf["K"] == Ks[0])]
                if not sub.empty:
                    yref = float(sub["metric_mean"].iloc[0])
                    ax.axhline(yref, linestyle="--", linewidth=1.6, label="Undrifted (reference)")

        for key in keys:
            mname = resolve_present_name(present, key)
            if mname is None:
                continue
            sub = aggdf[aggdf["method"] == mname].sort_values("K")
            xs = np.maximum(sub["t_mean"].to_numpy(), eps)
            ys = sub["metric_mean"].to_numpy()
            yerr = sub["metric_std"].fillna(0.0).to_numpy()
            (line,) = ax.plot(xs, ys, marker="o", linewidth=2.0, label=PRETTY.get(key, mname))
            # Tie the bars to the curve explicitly: an errorbar call with
            # fmt="none" and no colour does not inherit the colour of the
            # line drawn just before it.
            ax.errorbar(xs, ys, yerr=yerr, fmt="none", ecolor=line.get_color(), capsize=2, alpha=0.6)

        ax.set_xscale("log")
        ax.set_xlabel("Adaptation time $t_{adapt}$ (s, log)  [shared prep excluded]")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, which="both", linestyle="--", alpha=0.35)
        ax.legend(loc="lower left", ncol=2, framealpha=0.95, fontsize=8)
        fig.tight_layout()
        fig.savefig(out_png, dpi=220)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)



def main():
    """Command-line entry point: build tables and plots for one run directory.

    Reads ``<run-dir>/results.csv`` (keeping only rows whose ``status`` is
    ``"ok"`` when a ``status`` column exists), aggregates the chosen metric
    per ``(method, K)``, and writes into ``<run-dir>/paper_artifacts``:
    CSV tables for all and main methods, a LaTeX version of the main table,
    two PNG plots and a JSON manifest.

    Options:
        --run-dir: directory containing ``results.csv`` (required).
        --title: plot title; defaults to the run directory's name.
        --hide-oracles: drop methods whose key starts with ``oracle``.

    Raises:
        SystemExit: if ``results.csv`` is missing, lacks the ``method`` or
            ``K`` column, or has no usable rows.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--run-dir", required=True)
    ap.add_argument("--title", default=None)
    ap.add_argument("--hide-oracles", action="store_true")
    args = ap.parse_args()

    run_dir = Path(args.run_dir)
    title = args.title or run_dir.name

    results_csv = run_dir / "results.csv"
    if not results_csv.is_file():
        raise SystemExit(f"results.csv not found: {results_csv}")
    df = pd.read_csv(results_csv)

    # Everything downstream groups by these two columns.
    missing = [col for col in ("method", "K") if col not in df.columns]
    if missing:
        raise SystemExit(f"{results_csv} is missing required column(s): {', '.join(missing)}")

    # keep only ok rows if status exists
    if "status" in df.columns:
        df = df[df["status"].astype(str) == "ok"].copy()
    if df.empty:
        raise SystemExit(f"No usable rows in {results_csv} (file is empty or no row has status 'ok').")

    metric_name, ylabel, dfm = pick_metric(df)
    Ks = sorted(int(k) for k in dfm["K"].unique())
    # Table columns: the standard K values that occur, else every K found.
    Ks_tab = [k for k in [1, 2, 5, 20] if k in Ks] or Ks

    aggdf = agg(dfm, metric_name)

    keys_all = ORDER_ALL.copy()
    if args.hide_oracles:
        keys_all = [k for k in keys_all if not k.startswith("oracle")]
    keys_main = [k for k in ORDER_MAIN if k in keys_all]

    art = run_dir / "paper_artifacts"
    art.mkdir(parents=True, exist_ok=True)

    tab_all = make_table(aggdf, keys_all, Ks_tab)
    tab_main = make_table(aggdf, keys_main, Ks_tab)

    tab_all.to_csv(art / "covdrift_mr_table_all.csv", index=False)
    tab_main.to_csv(art / "covdrift_mr_table_main.csv", index=False)
    # The cells contain "±", so pin the encoding instead of relying on the
    # locale default (which may not be able to encode it).
    (art / "covdrift_mr_table_main.tex").write_text(
        tab_main.to_latex(index=False, escape=False), encoding="utf-8"
    )

    plot(aggdf, keys_all, Ks, title, ylabel, art / "covdrift_mr_all.png", show_ref=True)
    plot(aggdf, keys_main, Ks, title, ylabel, art / "covdrift_mr_main.png", show_ref=True)

    manifest = {
        "run_dir": str(run_dir),
        "title": title,
        "metric": metric_name,
        "Ks": Ks,
        "Ks_table": Ks_tab,
        "main_methods": keys_main,
        "all_methods": keys_all,
    }
    (art / "covdrift_mr_manifest.json").write_text(
        json.dumps(manifest, indent=2), encoding="utf-8"
    )

    print("WROTE", art)



# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":

    main()

