import os
import json
import argparse
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

# Global matplotlib defaults, applied once at import time. The values are
# intended to match the style used by plot_results.py.
mpl.rcParams.update({
    # Default figure size; main() passes its own figsize for the multi-panel figure.
    "figure.figsize": (7.0, 4.6),
    # Font sizes for body text, axis labels/titles, tick labels and legend.
    "font.size": 14,
    "axes.labelsize": 16,
    "axes.titlesize": 16,
    "xtick.labelsize": 13,
    "ytick.labelsize": 13,
    "legend.fontsize": 12,
    # Default line width for plotted curves.
    "lines.linewidth": 2.5,
    # Faint background grid on every axes.
    "axes.grid": True,
    "grid.alpha": 0.25,
    # Resolution used by savefig for raster output.
    "savefig.dpi": 200,
    # Font type 42 selects TrueType embedding in PDF/PS output.
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
})

def _marker_indices(T: int, n_markers: int = 20) -> np.ndarray:
    """Return roughly ``n_markers`` evenly spaced indices in ``[0, T - 1]``.

    Args:
        T: Length of the curve the indices will be used on.
        n_markers: Requested number of markers. It is clamped to at most
            ``T`` and then to at least 2, so both endpoints are requested.

    Returns:
        A sorted integer array without duplicates, every entry inside
        ``[0, T - 1]``. The array is empty when ``T <= 0`` because there is
        nothing to index.
    """
    if T <= 0:
        # linspace(0, T - 1, ...) would otherwise yield negative indices.
        return np.array([], dtype=int)
    n_markers = max(2, min(n_markers, T))
    # Rounding can map neighbouring samples to the same index; unique() drops
    # those duplicates and keeps the result sorted.
    return np.unique(np.round(np.linspace(0, T - 1, n_markers)).astype(int))

def _load_exp(exp_dir: str) -> tuple:
    """Load the regret curves and graph label of one experiment directory.

    Reads ``runs.npz`` (the arrays), ``algo_keymap.json`` (algorithm name ->
    array key inside the npz), ``config.json`` and, when present,
    ``metadata.json`` from ``exp_dir``.

    Args:
        exp_dir: Path of the experiment directory.

    Returns:
        A tuple ``(graph_label, curves, T)``:
        ``graph_label`` is ``config["graph_type"]``, falling back to the
        directory name; ``curves`` maps each algorithm name to its 2-D array
        (rows are averaged by the caller, columns are rounds); ``T`` is the
        common length of axis 1.

    Raises:
        FileNotFoundError: If ``runs.npz``, ``algo_keymap.json`` or
            ``config.json`` is missing.
        ValueError: If a curve is not 2-D or the curves disagree on ``T``.
        RuntimeError: If no algorithm could be resolved to an array.
    """
    runs_path = os.path.join(exp_dir, "runs.npz")
    keymap_path = os.path.join(exp_dir, "algo_keymap.json")
    meta_path = os.path.join(exp_dir, "metadata.json")
    config_path = os.path.join(exp_dir, "config.json")

    # metadata.json is optional; the other three files are required.
    for required in (runs_path, keymap_path, config_path):
        if not os.path.exists(required):
            raise FileNotFoundError(f"Missing {required}")

    with open(keymap_path, "r", encoding="utf-8") as f:
        keymap = json.load(f)

    meta = {}
    if os.path.exists(meta_path):
        with open(meta_path, "r", encoding="utf-8") as f:
            meta = json.load(f)

    with open(config_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    # Use graph_type from the config as label (e.g. "ER", "RBF", "SBM");
    # fall back to the directory name when the key is absent.
    graph_label = cfg.get("graph_type", os.path.basename(os.path.normpath(exp_dir)))
    # metadata may pin the algorithm order; otherwise follow the keymap order.
    algos = meta.get("algos", list(keymap.keys()))

    curves = {}
    T = None
    # The context manager closes the npz archive once every array is read;
    # indexing the archive materialises each array in memory.
    with np.load(runs_path) as data:
        for algo in algos:
            arr_key = keymap.get(algo)
            if arr_key is None or arr_key not in data:
                # Algorithms without a stored array are skipped.
                continue
            arr = data[arr_key]
            if arr.ndim != 2:
                raise ValueError(
                    f"Expected a 2-D (R, T) array for {algo!r} in {runs_path}, "
                    f"got shape {arr.shape}"
                )
            if T is None:
                T = arr.shape[1]
            elif arr.shape[1] != T:
                raise ValueError(
                    f"Inconsistent T within {exp_dir}: {algo!r} has "
                    f"{arr.shape[1]} rounds, expected {T}"
                )
            curves[algo] = arr

    if not curves:
        raise RuntimeError(f"No curves found in {exp_dir}")

    return graph_label, curves, T

def main(argv=None):
    """Plot mean cumulative regret per algorithm, one line per graph model.

    Loads every directory given via ``--exp_dirs`` with ``_load_exp``, checks
    that all experiments share the same algorithms and horizon ``T``, and
    draws one subplot per algorithm containing one mean curve (plus optional
    error-bar markers) per experiment. The figure is written to ``--out_path``
    or, by default, to ``<first_exp_dir>/plots/cumreg_graph_models.png``.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            lets argparse read ``sys.argv[1:]``.

    Raises:
        ValueError: If the experiments differ in their algorithm sets or in
            ``T``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--exp_dirs",
        nargs="+",
        required=True,
        help="List of experiment dirs, e.g. out/medium_kernelA_graph_ER out/medium_kernelA_graph_RBF out/medium_kernelA_graph_SBM",
    )
    ap.add_argument(
        "--markers",
        type=int,
        default=20,
        help="Approx. number of errorbar markers along each curve",
    )
    ap.add_argument(
        "--out_path",
        type=str,
        default=None,
        help="Optional path for the combined plot (default: first_exp_dir/plots/cumreg_graph_models.png)",
    )
    args = ap.parse_args(argv)

    graph_results = {}
    Ts = []
    algo_list = None

    # Load each experiment
    for exp_dir in args.exp_dirs:
        graph_label, curves, T = _load_exp(exp_dir)

        # Two experiments may report the same graph label; keep both instead
        # of letting the later one silently overwrite the earlier one.
        label = graph_label
        if label in graph_results:
            dir_name = os.path.basename(os.path.normpath(exp_dir))
            label = f"{graph_label} ({dir_name})"
            suffix = 2
            while label in graph_results:
                label = f"{graph_label} ({dir_name}, {suffix})"
                suffix += 1

        graph_results[label] = curves
        Ts.append(T)

        if algo_list is None:
            # The first experiment fixes the subplot order.
            algo_list = list(curves.keys())
        elif set(curves) != set(algo_list):
            # Ensure all experiments have the same algorithms
            raise ValueError(
                f"Algorithms mismatch between experiments. "
                f"Expected {algo_list}, got {list(curves.keys())} in {exp_dir}"
            )

    graph_labels = list(graph_results.keys())

    # Check T consistency
    if len(set(Ts)) != 1:
        raise ValueError(f"Inconsistent T across experiments: {Ts}")
    T = Ts[0]
    x_full = np.arange(1, T + 1)  # rounds are plotted 1-based
    idx = _marker_indices(T, n_markers=args.markers)

    n_algos = len(algo_list)
    fig, axes = plt.subplots(1, n_algos, figsize=(7.0 * n_algos, 4.6), sharey=True)
    if n_algos == 1:
        # plt.subplots returns a bare Axes for a single panel.
        axes = [axes]

    colors = ["C0", "C1", "C2", "C3", "C4", "C5"]

    # One subplot per algorithm (e.g., LK-GP-UCB, LK-GP-TS)
    for ax, algo in zip(axes, algo_list):
        for k, graph_label in enumerate(graph_labels):
            arr = graph_results[graph_label][algo]  # expected shape (R, T)
            R = arr.shape[0]
            mean = arr.mean(axis=0)
            color = colors[k % len(colors)]  # cycle when there are > 6 graphs

            # Main line
            ax.plot(x_full, mean, label=graph_label, color=color)

            # Errorbars at a subset of points (95% CI, normal approximation).
            # The sample standard deviation needs at least two runs.
            if R >= 2:
                sem = arr.std(axis=0, ddof=1) / np.sqrt(R)
                # NOTE(review): matplotlib may apply `linewidth=0` to the
                # error-bar stems as well when `elinewidth` is not given,
                # leaving only markers and caps visible -- confirm intended.
                ax.errorbar(
                    x_full[idx],
                    mean[idx],
                    yerr=1.96 * sem[idx],
                    fmt="o",
                    color=color,
                    markersize=3,
                    linewidth=0,
                    capsize=3,
                )

        ax.set_title(algo)
        ax.set_xlabel("Round")

    # The y axis is shared, so label and legend go on the first panel only.
    axes[0].set_ylabel("Cumulative Regret")
    axes[0].legend(frameon=False, title="Graph model")

    fig.tight_layout()

    # Where to save
    if args.out_path is None:
        out_path = os.path.join(args.exp_dirs[0], "plots", "cumreg_graph_models.png")
    else:
        out_path = args.out_path

    # A bare filename has an empty dirname, which os.makedirs rejects.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    fig.savefig(out_path)
    plt.close(fig)  # release the figure once it is on disk
    print(f"[INFO] Saved {out_path}")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
