import argparse
import json
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker



# Shared styling constants used by the plotting functions in this file.
FIGSIZE = (6.0, 3.5)  # figsize passed to plt.subplots for every figure
LINEWIDTH = 2.0  # line width of the plotted mean curve
ALPHA_BAND = 0.22  # opacity of the shaded fill_between band
DPI_DEFAULT = 300  # default savefig DPI; the CLI exposes it as --dpi

def set_fontsizes(base=11):
    """Apply a uniform font-size scheme to matplotlib's global rcParams.

    General text, axis labels and the figure title use ``base``; tick labels
    on both axes and legend entries use one size smaller (``base - 1``).

    Args:
        base: Base font size applied to the rcParams keys listed above.
    """
    smaller = base - 1
    full_size_keys = ("font.size", "axes.labelsize", "figure.titlesize")
    reduced_size_keys = ("xtick.labelsize", "ytick.labelsize", "legend.fontsize")

    sizes = {key: base for key in full_size_keys}
    sizes.update({key: smaller for key in reduced_size_keys})
    plt.rcParams.update(sizes)

def shift_xlabel_left(ax, frac=0.46, y=-0.08):
    """Place the x-axis label of ``ax`` at label coordinates ``(frac, y)``.

    The current label text is re-applied unchanged, the label is moved with
    ``set_label_coords(frac, y)``, and its horizontal alignment is set to
    ``"center"``. (Not called by any code visible in this file.)

    Args:
        ax: Axes whose x-axis label is repositioned.
        frac: Horizontal label coordinate passed to ``set_label_coords``.
        y: Vertical label coordinate passed to ``set_label_coords``.
    """
    current_text = ax.get_xlabel()
    # NOTE(review): re-setting the same text looks redundant — confirm intent.
    ax.set_xlabel(current_text)

    xaxis = ax.xaxis
    xaxis.set_label_coords(frac, y)
    label_artist = xaxis.get_label()
    label_artist.set_ha("center")



def apply_y_offset_on_top(ax, fontsize=None):
    """Force scientific y tick labels and restyle the y-axis offset text.

    Installs a ``ScalarFormatter`` (offset enabled, plain-text exponent) with
    power limits ``(0, 0)``. Per matplotlib's ScalarFormatter docs, this
    forces scientific notation for every data range. The offset-text artist
    is then aligned bottom/left and positioned at x=0.0, y=1.02.

    Args:
        ax: Axes whose y-axis is reformatted.
        fontsize: Optional font size for the offset text. ``None`` leaves it
            unchanged.
    """
    yaxis = ax.yaxis

    formatter = mticker.ScalarFormatter(useOffset=True, useMathText=False)
    formatter.set_scientific(True)
    formatter.set_powerlimits((0, 0))
    yaxis.set_major_formatter(formatter)

    # Request a canvas draw before styling the offset text. Presumably this
    # makes the offset text reflect the new formatter — TODO confirm it is
    # still needed.
    ax.figure.canvas.draw_idle()

    offset_label = yaxis.get_offset_text()
    if fontsize is not None:
        offset_label.set_fontsize(fontsize)
    offset_label.set_va('bottom')
    offset_label.set_ha('left')
    # NOTE(review): matplotlib may recompute the y-offset text position at
    # draw time, which would override the set_y below — verify visually.
    offset_label.set_x(0.0)
    offset_label.set_y(1.02)




def mean_ci_2sigma(arr_2d):
    """Return the column-wise mean and a band of plus/minus two standard deviations.

    NaN entries are ignored (``nanmean``/``nanstd``). The standard deviation
    uses ``ddof=1`` when there is more than one row and ``ddof=0`` otherwise.
    The band is +/-2 std of the rows themselves (their spread), not +/-2
    standard errors of the mean.

    Args:
        arr_2d: 2-D array of shape ``(R, T)``, reduced over axis 0.

    Returns:
        Tuple ``(mean, lower, upper)`` of 1-D arrays of length ``T``.

    Raises:
        ValueError: If ``arr_2d`` is not two-dimensional.
    """
    if arr_2d.ndim == 2:
        n_rows = arr_2d.shape[0]
        center = np.nanmean(arr_2d, axis=0)
        spread = np.nanstd(arr_2d, axis=0, ddof=(1 if n_rows > 1 else 0))
        half_width = 2.0 * spread
        return center, center - half_width, center + half_width
    raise ValueError(f"Expected 2D array (R,T), got shape {arr_2d.shape}")


def _running_avg(arr):
    cumsum = np.cumsum(arr, dtype=float)
    t = np.arange(1, len(arr) + 1, dtype=float)
    return cumsum / t


def load_aggregate_npz(path_npz):
    """Load per-run loss/cost histories from a single aggregate ``.npz`` file.

    Required keys are ``loss_runs`` and ``cost_runs``. The optional keys
    ``loss_opt_episode`` and ``cost_opt_episode`` are returned as stored, or
    as ``np.nan`` when absent.

    Args:
        path_npz: Path to the ``.npz`` archive.

    Returns:
        Tuple ``(loss_runs, cost_runs, loss_opt, cost_opt)``.

    Raises:
        KeyError: If a required key is missing from the archive.
    """
    # np.load on an .npz returns an NpzFile that holds an open file handle.
    # Use it as a context manager so the handle is released. Indexing the
    # NpzFile reads each array fully into memory, so the returned arrays stay
    # valid after the archive is closed.
    with np.load(path_npz) as data:
        loss_runs = data["loss_runs"]
        cost_runs = data["cost_runs"]
        loss_opt = data["loss_opt_episode"] if "loss_opt_episode" in data else np.nan
        cost_opt = data["cost_opt_episode"] if "cost_opt_episode" in data else np.nan
    return loss_runs, cost_runs, loss_opt, cost_opt


def _load_run_history(npz_path):
    """Return ``(loss_avg, cost_avg)`` running-average histories from one run's NPZ.

    Prefers the precomputed ``loss_avg_history``/``cost_avg_history`` arrays.
    Otherwise it computes running averages from ``loss_history``/``cost_history``.
    Raises KeyError if neither pair is present.
    """
    # Context manager releases the NpzFile handle once the arrays are read.
    with np.load(npz_path) as hist:
        if "loss_avg_history" in hist and "cost_avg_history" in hist:
            loss_avg = np.asarray(hist["loss_avg_history"], dtype=float)
            cost_avg = np.asarray(hist["cost_avg_history"], dtype=float)
        elif "loss_history" in hist and "cost_history" in hist:
            loss_avg = _running_avg(np.asarray(hist["loss_history"], dtype=float))
            cost_avg = _running_avg(np.asarray(hist["cost_history"], dtype=float))
        else:
            raise KeyError(f"{npz_path} must contain loss_avg_history/cost_avg_history or loss_history/cost_history.")
    return loss_avg, cost_avg


def _read_lp_opt(meta_path):
    """Return ``(loss_opt, cost_opt)`` from the ``"lp_opt"`` section of a meta.json.

    An entry that is missing or ``None`` comes back as ``np.nan``. Reading is
    best-effort: an unreadable or malformed file emits a warning instead of
    aborting the whole load. Any value already parsed before the failure is
    kept.
    """
    import warnings

    loss_opt, cost_opt = np.nan, np.nan
    try:
        with open(meta_path, "r", encoding="utf-8") as f:
            meta = json.load(f)
        lp = meta.get("lp_opt", {})
        if lp.get("loss_opt_episode") is not None:
            loss_opt = float(lp["loss_opt_episode"])
        if lp.get("cost_opt_episode") is not None:
            cost_opt = float(lp["cost_opt_episode"])
    except (OSError, ValueError, TypeError, AttributeError) as err:
        # ValueError also covers json.JSONDecodeError and UnicodeDecodeError.
        # AttributeError/TypeError cover a non-dict meta or lp_opt section.
        warnings.warn(f"Could not read LP optimum from {meta_path}: {err}", stacklevel=2)
    return loss_opt, cost_opt


def load_runs_from_dirs(dirs):
    """Load running-average loss/cost histories from several run directories.

    Each directory must contain ``histories.npz``. An optional ``meta.json``
    may provide ``lp_opt.loss_opt_episode`` / ``lp_opt.cost_opt_episode``.
    For each of these two values, the first one found across ``dirs`` (in
    order) is used.

    All histories are truncated to the shortest length seen, so they stack
    into rectangular arrays.

    Args:
        dirs: Iterable of run directory paths.

    Returns:
        Tuple ``(loss_runs, cost_runs, loss_opt, cost_opt)``. The run arrays
        have shape ``(len(dirs), T_min)``. The optima are ``np.nan`` when no
        meta file provided them.

    Raises:
        FileNotFoundError: If a directory lacks ``histories.npz``.
        KeyError: If a ``histories.npz`` lacks the expected arrays.
        ValueError: If ``dirs`` is empty.
    """
    loss_list, cost_list = [], []
    loss_opt, cost_opt = np.nan, np.nan

    for d in dirs:
        d = Path(d)
        npz_path = d / "histories.npz"
        if not npz_path.exists():
            raise FileNotFoundError(f"{npz_path} not found.")

        loss_avg, cost_avg = _load_run_history(npz_path)
        loss_list.append(loss_avg)
        cost_list.append(cost_avg)

        if np.isnan(loss_opt) or np.isnan(cost_opt):
            meta_path = d / "meta.json"
            if meta_path.exists():
                run_loss_opt, run_cost_opt = _read_lp_opt(meta_path)
                # Fill each value only if still missing, so a later run never
                # overwrites an optimum already taken from an earlier run.
                if np.isnan(loss_opt):
                    loss_opt = run_loss_opt
                if np.isnan(cost_opt):
                    cost_opt = run_cost_opt

    if not loss_list:
        raise ValueError("No run directories given; at least one is required.")

    T_min = min(min(len(la), len(ca)) for la, ca in zip(loss_list, cost_list))
    loss_runs = np.vstack([x[:T_min] for x in loss_list])
    cost_runs = np.vstack([x[:T_min] for x in cost_list])
    return loss_runs, cost_runs, loss_opt, cost_opt


def _finalize_axes(ax, x, xlabel, ylabel):
    """Label both axes and pin the x-range to the first and last entries of ``x``.

    The horizontal margin is also set to zero.
    """
    first_x, last_x = x[0], x[-1]

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_xlim(first_x, last_x)
    ax.margins(x=0)



def plot_mean_ci(x, runs_2d, ylabel, out_path, dpi=DPI_DEFAULT):
    """Plot the across-run mean of ``runs_2d`` with a shaded +/-2 std band and save it.

    Args:
        x: Values for the horizontal axis, one per column of ``runs_2d``.
        runs_2d: 2-D array ``(R, T)`` reduced by ``mean_ci_2sigma``.
        ylabel: Label for the y-axis. The x-axis is labelled "Episode".
        out_path: Destination file. Missing parent directories are created.
        dpi: Resolution passed to ``savefig``.
    """
    center, lower, upper = mean_ci_2sigma(runs_2d)

    fig, ax = plt.subplots(figsize=FIGSIZE)
    ax.plot(x, center, linewidth=LINEWIDTH)
    ax.fill_between(x, lower, upper, alpha=ALPHA_BAND)
    _finalize_axes(ax, x, xlabel="Episode", ylabel=ylabel)
    fig.tight_layout(pad=0.1)

    destination = Path(out_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(destination, dpi=dpi, bbox_inches="tight", pad_inches=0.02)
    plt.close(fig)


def _plot_cumulative_gap(x, K, gap_avg_runs, ylabel, out_path, dpi):
    """Scale running-average gaps to cumulative gaps, then plot mean +/-2 std and save.

    Column ``t - 1`` of ``gap_avg_runs`` is multiplied by episode index ``t``.
    This assumes the rows are running averages, as produced by this file's
    loaders, so average x t equals the cumulative sum. Only the first
    ``min(K, T)`` episodes are plotted, where ``T`` is the number of columns.
    Previously a ``K`` different from ``T`` raised a broadcasting error.

    Raises:
        ValueError: If ``K`` is given and smaller than 1.
    """
    n_cols = gap_avg_runs.shape[1]
    n_ep = n_cols if K is None else min(int(K), n_cols)
    if n_ep < 1:
        raise ValueError(f"K must be >= 1, got {K}")

    episodes = np.arange(1, n_ep + 1)
    cum_gap_runs = gap_avg_runs[:, :n_ep] * episodes
    x_plot = np.asarray(x)[:n_ep]

    mean, lo, hi = mean_ci_2sigma(cum_gap_runs)
    fig, ax = plt.subplots(figsize=FIGSIZE)
    try:
        ax.plot(x_plot, mean, linewidth=LINEWIDTH)
        ax.fill_between(x_plot, lo, hi, alpha=ALPHA_BAND)
        _finalize_axes(ax, x_plot, xlabel="Episode", ylabel=ylabel)
        apply_y_offset_on_top(ax, fontsize=None)
        fig.tight_layout(pad=0.1)
        out_path = Path(out_path)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, dpi=dpi, bbox_inches="tight", pad_inches=0.02)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


def plot_gap_ci_loss(x, K, runs_2d, baseline_scalar, ylabel, out_path, dpi=DPI_DEFAULT):
    """Plot cumulative loss gap (regret) versus a scalar baseline, with a +/-2 std band.

    The plotted quantity per run is ``t * (runs_2d[:, t-1] - baseline_scalar)``.
    Does nothing when ``baseline_scalar`` is None or NaN.

    Args:
        x: Episode values for the horizontal axis, one per column of ``runs_2d``.
        K: Maximum number of episodes to plot; ``None`` plots all columns.
        runs_2d: 2-D array ``(R, T)`` of running-average losses.
        baseline_scalar: Reference (optimal) loss level.
        ylabel: Label for the y-axis.
        out_path: Destination file. Missing parent directories are created.
        dpi: Resolution passed to ``savefig``.
    """
    if baseline_scalar is None or np.isnan(baseline_scalar):
        return
    gap_avg = np.asarray(runs_2d, dtype=float) - float(baseline_scalar)
    _plot_cumulative_gap(x, K, gap_avg, ylabel, out_path, dpi)


def plot_gap_ci_cost(x, K, runs_2d, baseline_scalar, ylabel, out_path, dpi=DPI_DEFAULT):
    """Plot cumulative constraint violation versus a budget, with a +/-2 std band.

    The plotted quantity per run is
    ``t * max(runs_2d[:, t-1] - baseline_scalar, 0)``, so only excess cost
    over the budget counts. Does nothing when ``baseline_scalar`` is None or
    NaN.

    Args:
        x: Episode values for the horizontal axis, one per column of ``runs_2d``.
        K: Maximum number of episodes to plot; ``None`` plots all columns.
        runs_2d: 2-D array ``(R, T)`` of running-average costs.
        baseline_scalar: Budget the cost is compared against.
        ylabel: Label for the y-axis.
        out_path: Destination file. Missing parent directories are created.
        dpi: Resolution passed to ``savefig``.
    """
    if baseline_scalar is None or np.isnan(baseline_scalar):
        return
    # Clip at zero: staying under budget is not "negative violation".
    gap_avg = np.maximum(np.asarray(runs_2d, dtype=float) - float(baseline_scalar), 0.0)
    _plot_cumulative_gap(x, K, gap_avg, ylabel, out_path, dpi)


def main():
    """Command-line entry point: load loss/cost histories and write four figures.

    Data comes from either ``--aggregate`` (one NPZ) or ``--runs`` (several
    run directories). The function saves the running-average loss, the
    running-average cost, the regret, and the constraint violation (against
    budget ``--b``) as PNGs under ``--out``.
    """
    parser = argparse.ArgumentParser(description="Paper-style plots with 2-sigma CI from saved NPZ.")
    src = parser.add_mutually_exclusive_group(required=True)
    src.add_argument("--aggregate", type=str, help="Path to aggregate npz (loss_runs, cost_runs, ...).")
    src.add_argument("--runs", nargs="+", help="Run directories containing histories.npz (and optional meta.json).")
    parser.add_argument("--out", type=str, default="results", help="Output directory.")
    parser.add_argument("--prefix", type=str, default="agg", help="Filename prefix (e.g., 'agg').")
    parser.add_argument("--dpi", type=int, default=DPI_DEFAULT, help="Figure DPI.")
    # Default None -> use the loaded history length. The old hard-coded
    # default (100000) crashed the gap plots whenever T != 100000.
    parser.add_argument("--K", type=int, default=None,
                        help="Episode number (defaults to the length of the loaded histories)")
    parser.add_argument("--b", type=float, default=5.6, help="Budget")
    args = parser.parse_args()

    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)

    set_fontsizes(base=15)

    # Load data. Test against None so an empty --aggregate value is not
    # mistaken for "use --runs".
    if args.aggregate is not None:
        loss_runs, cost_runs, loss_opt, cost_opt = load_aggregate_npz(args.aggregate)
    else:
        loss_runs, cost_runs, loss_opt, cost_opt = load_runs_from_dirs(args.runs)

    # Loss and cost share one x-axis. An aggregate file may store them with
    # different lengths, so trim both to the common length.
    T = min(loss_runs.shape[1], cost_runs.shape[1])
    loss_runs = loss_runs[:, :T]
    cost_runs = cost_runs[:, :T]
    x = np.arange(1, T + 1)
    K = T if args.K is None else args.K
    b = args.b

    # Average Loss
    plot_mean_ci(
        x=x, runs_2d=loss_runs,
        ylabel="Running Avg. Loss",
        out_path=out_dir / f"{args.prefix}_avg_loss_ci.png",
        dpi=args.dpi,
    )

    # Average Cost
    plot_mean_ci(
        x=x, runs_2d=cost_runs,
        ylabel="Running Avg. Cost",
        out_path=out_dir / f"{args.prefix}_avg_cost_ci.png",
        dpi=args.dpi,
    )

    # Loss Gap (regret); the plot function silently returns without a baseline.
    if loss_opt is None or np.isnan(loss_opt):
        print("[Skip] Regret plot: no loss_opt_episode baseline available.")
    plot_gap_ci_loss(
        x=x, K=K, runs_2d=loss_runs, baseline_scalar=loss_opt,
        ylabel="Regret",
        out_path=out_dir / f"{args.prefix}_loss_gap_ci.png",
        dpi=args.dpi,
    )

    # Cost Gap (constraint violation against the budget --b)
    plot_gap_ci_cost(
        x=x, K=K, runs_2d=cost_runs, baseline_scalar=b,
        ylabel="Constraint Violation",
        out_path=out_dir / f"{args.prefix}_cost_gap_ci.png",
        dpi=args.dpi,
    )

    print(f"[Done] Saved figures under: {out_dir.resolve()}")


if __name__ == "__main__":
    main()
