import os
import json
import argparse
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

# ---- Global style tweaks (bigger, cleaner) ----
# Larger fonts and thicker lines for readable figures, a faint background grid,
# and 200-dpi raster output. pdf/ps fonttype 42 selects Type 42 (TrueType)
# font embedding instead of the Type 3 default.
_PLOT_STYLE = {
    # Figure geometry (inches).
    "figure.figsize": (7.0, 4.6),
    # Text sizes.
    "font.size": 14,
    "axes.labelsize": 16,
    "axes.titlesize": 16,
    "xtick.labelsize": 13,
    "ytick.labelsize": 13,
    "legend.fontsize": 12,
    # Curves and grid.
    "lines.linewidth": 2.5,
    "axes.grid": True,
    "grid.alpha": 0.25,
    # Output.
    "savefig.dpi": 200,
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
}
mpl.rcParams.update(_PLOT_STYLE)

def _is_coop_kz_ablation(meta_algos):
    """Return True iff this plot contains only Coop-KernelUCB variants (agent-kernel ablation)."""
    if not isinstance(meta_algos, list) or len(meta_algos) < 2:
        return False
    al = [a.lower() for a in meta_algos]
    return all(a.startswith("coop-kernelucb") for a in al)

def _marker_indices(T: int, n_markers: int = 20):
    """Roughly n_markers evenly spaced indices in [0, T-1]."""
    n_markers = max(2, min(n_markers, T))
    return np.unique(np.round(np.linspace(0, T - 1, n_markers)).astype(int))

def _load_meta_algos(meta_path):
    """Return the ``"algos"`` entry of the optional metadata file, or None.

    None is returned when the file does not exist, when its top-level JSON
    value is not an object, or when it has no ``"algos"`` key.
    """
    if not os.path.exists(meta_path):
        return None
    with open(meta_path, "r", encoding="utf-8") as f:
        meta = json.load(f)
    return meta.get("algos") if isinstance(meta, dict) else None


def _load_curves(runs_path, keymap_path):
    """Return ``[(display_name, runs x rounds array), ...]`` sorted case-insensitively by name.

    The JSON file at ``keymap_path`` maps display names to array keys inside
    the ``runs_path`` archive. Keys absent from the archive are skipped.
    A 1-D array is treated as a single run, and arrays with zero runs are
    skipped with a warning.

    Raises:
        ValueError: if an array has more than 2 dimensions, if no usable
            curve remains, or if curves disagree on the number of rounds.
    """
    with open(keymap_path, "r", encoding="utf-8") as f:
        keymap = json.load(f)

    curves = []
    # The context manager closes the archive's file handle. Indexing reads each
    # array into memory, so the arrays remain valid after the block exits.
    with np.load(runs_path) as data:
        for disp_name, arr_key in keymap.items():
            if arr_key not in data:
                continue
            # Promote a single run of shape (T,) to (1, T) so it takes the same path.
            arr = np.atleast_2d(data[arr_key])
            if arr.ndim != 2:
                raise ValueError(
                    f"Array '{arr_key}' for '{disp_name}' must be 1-D or 2-D "
                    f"(runs x rounds), got shape {arr.shape}"
                )
            if arr.shape[0] == 0:
                print(f"[WARN] Skipping '{disp_name}': array '{arr_key}' has no runs")
                continue
            curves.append((disp_name, arr))

    if not curves:
        raise ValueError("No curves found in runs.npz that match algo_keymap.json")

    # Every curve is drawn against one shared x-axis, so lengths must agree.
    lengths = {arr.shape[1] for _, arr in curves}
    if len(lengths) != 1:
        detail = ", ".join(f"{name}: {arr.shape[1]}" for name, arr in curves)
        raise ValueError(f"Curves have different numbers of rounds ({detail})")

    curves.sort(key=lambda item: item[0].lower())
    return curves


def _plot_curve(ax, x_full, arr, idx, label):
    """Draw the across-run mean of ``arr``, plus same-colored markers at ``idx``.

    With at least two runs, the markers carry error bars of 1.96 * standard
    error of the mean. This is a normal-approximation 95% CI.
    """
    mean = arr.mean(axis=0)

    # Plot main line; capture its color for matching markers/error bars.
    (line_obj,) = ax.plot(x_full, mean, label=label)
    line_color = line_obj.get_color()
    marker_style = dict(
        color=line_color,
        markeredgecolor=line_color,
        markerfacecolor=line_color,
        markersize=4,
        linewidth=0,  # markers only; the mean line was drawn above
    )

    n_runs = arr.shape[0]
    if n_runs >= 2:
        sem = arr.std(axis=0, ddof=1) / np.sqrt(n_runs)
        ax.errorbar(
            x_full[idx], mean[idx], yerr=1.96 * sem[idx],
            fmt="o",
            ecolor=line_color,
            capsize=3,
            capthick=1,
            elinewidth=1.5,
            **marker_style,
        )
    else:
        # A single run has no spread estimate: plain markers, no error bars.
        ax.plot(x_full[idx], mean[idx], "o", **marker_style)


def main():
    """Plot mean cumulative regret for every algorithm in an experiment directory.

    Reads ``runs.npz`` and ``algo_keymap.json`` (both required) and
    ``metadata.json`` (optional) from ``--exp_dir``. Writes
    ``<exp_dir>/plots/cumreg_multi.png``.

    Raises:
        FileNotFoundError: if runs.npz or algo_keymap.json is missing.
        ValueError: if no keymap entry matches an array in runs.npz, an
            array is not 1-D/2-D, curves disagree on the number of rounds,
            or the curves have zero rounds.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--exp_dir", required=True, help="Experiment output directory (contains runs.npz)")
    ap.add_argument("--markers", type=int, default=20, help="Approx. number of errorbar markers along the curve")
    args = ap.parse_args()

    exp_dir = args.exp_dir
    runs_path = os.path.join(exp_dir, "runs.npz")
    keymap_path = os.path.join(exp_dir, "algo_keymap.json")
    meta_path = os.path.join(exp_dir, "metadata.json")

    if not os.path.exists(runs_path):
        raise FileNotFoundError(f"Missing {runs_path}")
    if not os.path.exists(keymap_path):
        raise FileNotFoundError(f"Missing {keymap_path}")

    curves = _load_curves(runs_path, keymap_path)
    meta_algos = _load_meta_algos(meta_path)

    # _load_curves guarantees every curve has the same number of rounds.
    T = curves[0][1].shape[1]
    if T == 0:
        raise ValueError("Curves in runs.npz have zero rounds; nothing to plot")
    x_full = np.arange(T)
    idx = _marker_indices(T, n_markers=args.markers)

    fig, ax = plt.subplots()
    for disp_name, arr in curves:
        _plot_curve(ax, x_full, arr, idx, disp_name)

    ax.set_xlabel("Round")
    ax.set_ylabel("Cumulative Regret")

    # Coop-KernelUCB-only (agent-kernel ablation) plots are left untitled;
    # all others are titled with the experiment directory name.
    if not _is_coop_kz_ablation(meta_algos):
        ax.set_title(os.path.basename(os.path.normpath(exp_dir)))

    ax.legend(frameon=False)
    fig.tight_layout()

    # Create the output directory only once there is something to save.
    plot_dir = os.path.join(exp_dir, "plots")
    os.makedirs(plot_dir, exist_ok=True)
    out_path = os.path.join(plot_dir, "cumreg_multi.png")
    fig.savefig(out_path)
    plt.close(fig)
    print(f"[INFO] Saved {out_path}")

if __name__ == "__main__":
    # Script entry point. main() returns None, so the process exits with status 0.
    raise SystemExit(main())
