import os
import json
import argparse
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

# Global plotting style, applied once when this module is imported:
# 12-14 pt fonts, 2.3 pt lines, a faint background grid, 200-dpi saved
# images, and TrueType (Type 42) font embedding for PDF/PS output.
mpl.rcParams.update({
    # typography
    "font.size": 12,
    "axes.labelsize": 14,
    "axes.titlesize": 14,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "legend.fontsize": 12,
})
mpl.rcParams.update({
    # lines and grid
    "lines.linewidth": 2.3,
    "axes.grid": True,
    "grid.alpha": 0.25,
})
mpl.rcParams.update({
    # saved-output settings
    "savefig.dpi": 200,
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
})

def _load_runs(exp_dir):
    """Load the per-algorithm run arrays stored in one experiment directory.

    Reads two files from *exp_dir*:

    * ``runs.npz`` -- a NumPy archive of arrays, one per internal key;
    * ``algo_keymap.json`` -- a JSON object mapping display name -> npz key.

    Returns:
        A list of ``(display_name, array)`` pairs, sorted case-insensitively
        by display name. Keymap entries whose key is not present in the
        archive are skipped.
    """
    # np.load on an .npz returns an NpzFile that keeps the archive open;
    # the ``with`` block closes it. Indexing reads each array fully into
    # memory, so the returned arrays stay valid after the archive closes.
    with np.load(os.path.join(exp_dir, "runs.npz")) as runs:
        # JSON text is UTF-8; don't depend on the platform's default encoding.
        with open(os.path.join(exp_dir, "algo_keymap.json"), "r", encoding="utf-8") as f:
            keymap = json.load(f)
        curves = [(disp, runs[key]) for disp, key in keymap.items() if key in runs]
    curves.sort(key=lambda item: item[0].lower())
    return curves

def _marker_indices(T: int, n_markers: int = 20):
    """Return sorted, unique integer indices in ``[0, T)`` for curve markers.

    About ``n_markers`` indices are spread evenly from the first to the last
    point, inclusive. ``n_markers`` is clamped to the range ``[2, T]``, and
    rounding may merge neighbours, so fewer indices can be returned.

    Returns an empty integer array when ``T <= 0``.
    """
    if T <= 0:
        # linspace(0, T - 1, ...) would produce negative indices here, which
        # are invalid for an empty series.
        return np.empty(0, dtype=int)
    n_markers = max(2, min(n_markers, T))
    return np.unique(np.round(np.linspace(0, T - 1, n_markers)).astype(int))

def _plot_panel(ax, exp_dir, show_ylabel=False, n_markers=20):
    """Draw the mean curve of every algorithm found in *exp_dir* onto *ax*.

    For each ``(display_name, runs)`` pair from ``_load_runs`` the mean over
    axis 0 (runs) is drawn as a labelled line. At roughly ``n_markers`` evenly
    spaced rounds a marker in the line's colour is added. When a curve has at
    least two runs, the marker carries a ``1.96 * SEM`` error bar (a 95%
    normal-approximation interval, SEM from the ``ddof=1`` sample std).

    A 1-D array is treated as a single run. Curves may differ in length; each
    gets its own x range and marker positions. If no curves are found, a
    "No data" placeholder is written in the middle of the axes.

    Args:
        ax: Matplotlib axes to draw on.
        exp_dir: Experiment directory passed to ``_load_runs``.
        show_ylabel: Label the y axis "Cumulative regret" if true, else clear it.
        n_markers: Approximate number of markers per curve.
    """
    curves = _load_runs(exp_dir)
    if not curves:
        ax.text(0.5, 0.5, "No data", ha="center", va="center", transform=ax.transAxes)
        return

    for disp, arr in curves:
        # Treat a 1-D array as one run so that axis 0 is always "runs".
        arr = np.atleast_2d(arr)
        n_runs, T = arr.shape[:2]
        # Per-curve x range and marker positions: curves need not share a length.
        x = np.arange(T)
        idx = _marker_indices(T, n_markers=n_markers)

        mean = arr.mean(axis=0)
        (line_obj,) = ax.plot(x, mean, label=disp)
        color = line_obj.get_color()
        if n_runs >= 2:
            sem = arr.std(axis=0, ddof=1) / np.sqrt(n_runs)
            ax.errorbar(
                x[idx], mean[idx], yerr=1.96 * sem[idx],
                fmt="o",
                color=color, ecolor=color,
                markeredgecolor=color, markerfacecolor=color,
                markersize=4, linewidth=0,
                capsize=3, capthick=1, elinewidth=1.5,
            )
        else:
            # A single run has no spread to show; draw plain markers only.
            ax.plot(x[idx], mean[idx], "o", color=color, markersize=4, linewidth=0)

    ax.set_xlabel("Round")
    ax.set_ylabel("Cumulative regret" if show_ylabel else "")

def main():
    """Parse CLI arguments and save a grid of regret panels to ``<outdir>/grid.png``.

    Each of ``--linear``, ``--kernelA`` and ``--kernelB`` that is supplied adds
    one row of three panels, drawn from the three experiment directories given
    (DIR_SIMPLE, DIR_MEDIUM, DIR_HARD). A single legend, merged from all
    panels, is centred below the grid.

    Raises:
        ValueError: if none of ``--linear``/``--kernelA``/``--kernelB`` is given.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--linear",  nargs=3, metavar=("DIR_SIMPLE","DIR_MEDIUM","DIR_HARD"))
    ap.add_argument("--kernelA", nargs=3, metavar=("DIR_SIMPLE","DIR_MEDIUM","DIR_HARD"))
    ap.add_argument("--kernelB", nargs=3, metavar=("DIR_SIMPLE","DIR_MEDIUM","DIR_HARD"))
    ap.add_argument("--outdir", required=True)
    ap.add_argument("--sharey", action="store_true")
    ap.add_argument("--markers", type=int, default=20, help="~number of errorbar markers per curve")
    args = ap.parse_args()

    # Row order is fixed: linear, kernelA, kernelB (whichever were given).
    rows = [dirs for dirs in (args.linear, args.kernelA, args.kernelB) if dirs]
    if not rows:
        raise ValueError("Provide at least one of --linear/--kernelA/--kernelB")

    # Create the output directory only after the arguments are known to be usable.
    os.makedirs(args.outdir, exist_ok=True)

    nrows, ncols = len(rows), 3
    fig_w = 6.2 * ncols
    fig_h = 3.8 * nrows
    fig, axes = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(fig_w, fig_h),
        squeeze=False, sharey=args.sharey,
    )

    for r, row_dirs in enumerate(rows):
        for c, exp_dir in enumerate(row_dirs):
            _plot_panel(
                axes[r, c], exp_dir,
                show_ylabel=(c == 0),
                n_markers=args.markers,
            )

    # Merge legend entries from every panel (first handle per label wins), so
    # the legend stays complete when the last panel is empty or lacks some
    # algorithms. Labels are sorted case-insensitively, like each panel's curves.
    # NOTE(review): every panel draws colours from its own colour cycle, so a
    # label's colour only matches across panels when all panels contain the
    # same set of algorithms.
    by_label = {}
    for ax in axes.flat:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            by_label.setdefault(label, handle)
    labels = sorted(by_label, key=str.lower)
    handles = [by_label[label] for label in labels]
    legend_ncol = min(8, len(labels))
    legend_rows = (len(labels) + 7) // 8

    # Reserve a fixed height (in inches) under the grid for the bottom row's
    # x-axis labels plus the legend. A fixed figure fraction would grow with
    # the number of rows and leave a large empty gap under tall grids.
    bottom_in = 0.7 + 0.3 * max(1, legend_rows)
    fig.subplots_adjust(bottom=min(0.45, bottom_in / fig_h), wspace=0.22, hspace=0.25)

    # single, centered legend below subplots
    if handles:
        fig.legend(
            handles, labels, loc="lower center",
            ncol=legend_ncol, frameon=False,
            bbox_to_anchor=(0.5, 0.00),
        )

    out_png = os.path.join(args.outdir, "grid.png")
    fig.savefig(out_png, bbox_inches="tight")
    plt.close(fig)
    print(f"[INFO] Saved {out_png}")

# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
