import os
import json
import argparse
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

# Shared figure style; keep these values in sync with plot_results.py so the
# combined graph-environment figure matches the per-experiment plots.
_RC_STYLE = {
    # Default figure size in inches (width, height).
    "figure.figsize": (7.0, 4.6),
    # Text sizes.
    "font.size": 14,
    "axes.labelsize": 16,
    "axes.titlesize": 16,
    "xtick.labelsize": 13,
    "ytick.labelsize": 13,
    "legend.fontsize": 12,
    # Lines and background grid.
    "lines.linewidth": 2.5,
    "axes.grid": True,
    "grid.alpha": 0.25,
    # Output: raster resolution, and TrueType (Type 42) font embedding for
    # PDF/PS output instead of Type 3 fonts.
    "savefig.dpi": 200,
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
}
mpl.rcParams.update(_RC_STYLE)

def _marker_indices(T: int, n_markers: int = 20) -> np.ndarray:
    """Return roughly ``n_markers`` evenly spaced integer indices in ``[0, T-1]``.

    Both endpoints (0 and T-1) are always included. The requested count is
    clamped to ``[2, T]``, and duplicates produced by rounding are removed, so
    fewer than ``n_markers`` indices may be returned.

    Args:
        T: Length of the series being indexed.
        n_markers: Desired number of marker positions.

    Returns:
        Sorted 1-D integer array of unique indices; empty when ``T <= 0``.
    """
    if T <= 0:
        # Without this guard linspace(0, T - 1, ...) yields negative indices
        # (e.g. [-1, 0] for T == 0) that would wrap around or fail on indexing.
        return np.empty(0, dtype=int)
    n_markers = max(2, min(n_markers, T))
    return np.unique(np.round(np.linspace(0, T - 1, n_markers)).astype(int))

def _load_exp(exp_dir: str):
    """Load per-algorithm curves and a panel label from one experiment directory.

    Files read from ``exp_dir``:

    * ``runs.npz`` (required) -- one array per algorithm, each 2-D with shape
      ``(R, T)`` (runs x rounds);
    * ``algo_keymap.json`` (required) -- maps algorithm name -> key in
      ``runs.npz``;
    * ``config.json`` (required) -- its ``"graph_type"`` entry becomes the
      panel label (falls back to the directory name);
    * ``metadata.json`` (optional) -- its ``"algos"`` list selects which
      algorithms are loaded and in what order (default: keymap order).

    Algorithms with no keymap entry, or whose key is absent from
    ``runs.npz``, are skipped silently.

    Returns:
        ``(graph_label, curves, T)`` where ``curves`` maps algorithm name to
        its ``(R, T)`` array and ``T`` is the number of rounds shared by all
        curves.

    Raises:
        FileNotFoundError: if a required file is missing.
        ValueError: if a curve array is not 2-D or curves disagree on ``T``.
        RuntimeError: if no algorithm curve could be found.
    """
    runs_path = os.path.join(exp_dir, "runs.npz")
    keymap_path = os.path.join(exp_dir, "algo_keymap.json")
    meta_path = os.path.join(exp_dir, "metadata.json")
    config_path = os.path.join(exp_dir, "config.json")

    for required in (runs_path, keymap_path, config_path):
        if not os.path.exists(required):
            raise FileNotFoundError(f"Missing {required}")

    with open(keymap_path, "r", encoding="utf-8") as f:
        keymap = json.load(f)

    meta = {}
    if os.path.exists(meta_path):
        with open(meta_path, "r", encoding="utf-8") as f:
            meta = json.load(f)

    with open(config_path, "r", encoding="utf-8") as f:
        cfg = json.load(f)

    # Panel title from graph_type, fallback to dir name
    graph_label = cfg.get("graph_type", os.path.basename(os.path.normpath(exp_dir)))

    # Algorithms list from metadata if present, else from keymap
    algos = meta.get("algos", list(keymap.keys()))

    curves = {}
    T = None
    # np.load on an .npz returns a lazily-read archive that holds the file
    # open until closed; use it as a context manager. Indexing it reads each
    # array fully into memory, so the arrays stay valid after closing.
    with np.load(runs_path) as data:
        for algo in algos:
            arr_key = keymap.get(algo)
            if arr_key is None or arr_key not in data:
                continue
            arr = data[arr_key]
            if arr.ndim != 2:
                raise ValueError(
                    f"Expected a 2-D (runs, rounds) array for {algo!r} in "
                    f"{runs_path}, got shape {arr.shape}"
                )
            if T is None:
                T = arr.shape[1]
            elif arr.shape[1] != T:
                raise ValueError(
                    f"Inconsistent number of rounds in {exp_dir}: {algo!r} has "
                    f"T={arr.shape[1]}, expected T={T}"
                )
            curves[algo] = arr

    if not curves:
        raise RuntimeError(f"No curves found in {exp_dir}")

    return graph_label, curves, T

def _load_all(exp_dirs):
    """Load every experiment and check they share the algorithm set and horizon T.

    Returns ``(panel_results, algo_list, T)``. ``panel_results`` is a list of
    ``(graph_label, curves)`` pairs in input order. ``algo_list`` is the
    algorithm order of the first experiment.
    """
    panel_results = []
    Ts = []
    algo_list = None
    for exp_dir in exp_dirs:
        graph_label, curves, T = _load_exp(exp_dir)
        panel_results.append((graph_label, curves))
        Ts.append(T)

        if algo_list is None:
            algo_list = list(curves)
        elif set(curves) != set(algo_list):
            # Every panel must show the same algorithms with the same colors.
            raise ValueError(
                f"Algorithms mismatch between experiments. "
                f"Expected {algo_list}, got {list(curves.keys())} in {exp_dir}"
            )

    if len(set(Ts)) != 1:
        raise ValueError(f"Inconsistent T across experiments: {Ts}")
    return panel_results, algo_list, Ts[0]


def _ci95_halfwidth(arr):
    """Half-width of the normal-approximation 95% CI of the mean over axis 0.

    Returns zeros when fewer than two runs (rows) are available, since no
    spread can be estimated.
    """
    R = arr.shape[0]
    if R < 2:
        return np.zeros(arr.shape[1])
    return 1.96 * arr.std(axis=0, ddof=1) / np.sqrt(R)


def _y_limits(panel_results, algo_list, idx, pad_frac=0.05):
    """Return padded ``(y_min, y_max)`` shared by all panels.

    The range covers every mean curve and the CI error bars drawn at the
    marker indices ``idx``, so nothing drawn is clipped by the shared limits.
    """
    lo, hi = np.inf, -np.inf
    for _, curves in panel_results:
        for algo in algo_list:
            arr = curves[algo]
            mean = arr.mean(axis=0)
            lo = min(lo, float(mean.min()))
            hi = max(hi, float(mean.max()))
            if arr.shape[0] >= 2 and len(idx):
                half = _ci95_halfwidth(arr)[idx]
                lo = min(lo, float((mean[idx] - half).min()))
                hi = max(hi, float((mean[idx] + half).max()))
    # Fixed padding of 1.0 when every curve is flat at the same value.
    pad = pad_frac * (hi - lo) if hi > lo else 1.0
    return lo - pad, hi + pad


def _resolve_out_path(out_path, exp_dirs):
    """Return the output file path, creating its parent directory if needed.

    Defaults to ``<first exp dir>/plots/cumreg_graph_envs.png``.
    """
    if out_path is None:
        out_path = os.path.join(exp_dirs[0], "plots", "cumreg_graph_envs.png")
    parent = os.path.dirname(out_path)
    # A bare filename has no directory part; os.makedirs("") would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)
    return out_path


def main():
    """Plot cumulative regret of several experiments side by side, one panel each.

    Each ``--exp_dirs`` entry becomes one subplot titled by its graph type.
    Every panel shows the mean curve per algorithm and 95% CI error bars at
    about ``--markers`` rounds. All panels share the y-axis and its limits.

    Raises:
        ValueError: if experiments disagree on algorithms or number of rounds.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument(
        "--exp_dirs",
        nargs="+",
        required=True,
        help=(
            "List of experiment dirs, e.g. "
            "out/medium_kernelA_graph_ER out/medium_kernelA_graph_RBF out/medium_kernelA_graph_SBM"
        ),
    )
    ap.add_argument(
        "--markers",
        type=int,
        default=20,
        help="Approx. number of errorbar markers along each curve",
    )
    ap.add_argument(
        "--out_path",
        type=str,
        default=None,
        help="Optional path for the combined plot "
             "(default: first_exp_dir/plots/cumreg_graph_envs.png)",
    )
    args = ap.parse_args()

    panel_results, algo_list, T = _load_all(args.exp_dirs)
    x_full = np.arange(1, T + 1)
    idx = _marker_indices(T, n_markers=args.markers)
    y_min, y_max = _y_limits(panel_results, algo_list, idx)

    # One subplot per graph model (environment), each the size of a single
    # default-style figure. squeeze=False gives a 2-D Axes array even for
    # one panel.
    n_panels = len(panel_results)
    panel_w, panel_h = mpl.rcParams["figure.figsize"]
    fig, axes = plt.subplots(
        1, n_panels, figsize=(panel_w * n_panels, panel_h), sharey=True, squeeze=False
    )
    axes = axes[0]

    colors = ["C0", "C1", "C2", "C3", "C4", "C5"]

    for ax, (graph_label, curves) in zip(axes, panel_results):
        for k, algo in enumerate(algo_list):
            arr = curves[algo]  # shape (R, T)
            mean = arr.mean(axis=0)
            # Same color for an algorithm in every panel.
            color = colors[k % len(colors)]

            ax.plot(x_full, mean, label=algo, color=color)

            # 95% CI error bars at a subset of rounds. fmt="o" already draws
            # no connecting line. Passing linewidth=0 here would also zero the
            # error-bar lines (errorbar reuses it when elinewidth is unset),
            # so set elinewidth explicitly instead.
            if arr.shape[0] >= 2:
                ax.errorbar(
                    x_full[idx],
                    mean[idx],
                    yerr=_ci95_halfwidth(arr)[idx],
                    fmt="o",
                    color=color,
                    markersize=3,
                    elinewidth=1.5,
                    capsize=3,
                )

        ax.set_title(graph_label)
        ax.set_xlabel("Round")
        ax.set_ylim(y_min, y_max)

    axes[0].set_ylabel("Cumulative Regret")
    axes[0].legend(frameon=False, title="Algorithm")

    fig.tight_layout()

    out_path = _resolve_out_path(args.out_path, args.exp_dirs)
    fig.savefig(out_path)
    # Release the figure; nothing else uses it after saving.
    plt.close(fig)
    print(f"[INFO] Saved {out_path}")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
