# import pandas as pd
# import matplotlib.pyplot as plt
#
# df_cpo = pd.read_csv("./results/CPO_SafetyGoal.csv")
# df_cup = pd.read_csv("./results/CUP_SafetyGoal.csv")
# df_focops = pd.read_csv("./results/FOCOPS_SafetyGoal.csv")
# df_p3o = pd.read_csv("./results/P3O_SafetyGoal.csv")
# df_pcpo = pd.read_csv("./results/PCPO_SafetyGoal.csv")
# df_pcpo_kl = pd.read_csv("./results/PCPO_KL_SafetyGoal.csv")
# df_crpo = pd.read_csv("./results/CRPO_SafetyGoal.csv")
# df_pid = pd.read_csv("./results/CPPOPID.csv")
# df_ctrpo = pd.read_csv("./results/CTRPO_SafetyGoal.csv")
# df_c3po = pd.read_csv("./results/C3PO_SafetyGoal.csv")
# # Store all labeled dataframes in a dictionary
# dataframes = {
#     "CPO": df_cpo,
#     "CUP": df_cup,
#     "FOCOPS": df_focops,
#     "P3O": df_p3o,
#     "PCPO": df_pcpo,
#     "PCPO_kl": df_pcpo_kl,
#     "CRPO": df_crpo,
#     "CPPOPID": df_pid,
#     "CTRPO": df_ctrpo,
#     "C3PO": df_c3po,
# }
#
# # Create vertically stacked plots
# fig, axs = plt.subplots(2, 1, figsize=(14, 10), sharex=True)
#
# # Plot Episode Return
# for algo, df in dataframes.items():
#     axs[0].plot(df['Train/Epoch'], df['Metrics/EpRet'], label=algo)
# axs[0].set_title('Episode Return over Epochs')
# axs[0].set_ylabel('Episode Return')
# axs[0].legend()
# axs[0].grid(True)
#
# # Plot Episode Cost
# for algo, df in dataframes.items():
#     axs[1].plot(df['Train/Epoch'], df['Metrics/EpCost'], label=algo)
# axs[1].axhline(y=25, color='red', linestyle='--', label='Cost Limit (25)')
# axs[1].set_title('Episode Cost over Epochs')
# axs[1].set_xlabel('Epoch')
# axs[1].set_ylabel('Episode Cost')
# axs[1].legend()
# axs[1].grid(True)
#
# plt.tight_layout()
# plt.show()

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Time constant handed to ``ema`` as ``tau`` when smoothing curves.
SMOOTH_TAU: int = 30
# Plot every DOWNSAMPLE-th (already smoothed) point.
DOWNSAMPLE: int = 3
# Episode-cost limit (the constraint threshold drawn on cost plots).
COST_LIMIT: int = 25

def ema(x, tau=20):
    """Return an exponential moving average of ``x``.

    The decay factor is ``a = exp(-1 / max(tau, 1))``: a larger ``tau`` gives
    heavier smoothing, and any ``tau`` below 1 is clamped to 1. The recursion
    is ``y[0] = x[0]`` and ``y[t] = a * y[t-1] + (1 - a) * x[t]``, applied
    along the first axis of ``x``.

    Args:
        x: Sequence of numbers; converted to a float ``numpy`` array.
        tau: Smoothing time constant, in samples.

    Returns:
        A float array with the same shape as ``x``. An empty input yields an
        empty array rather than raising ``IndexError``.
    """
    x = np.asarray(x, dtype=float)
    y = np.empty_like(x)
    if x.size == 0:
        # Nothing to seed the recursion with; y[0] would raise IndexError.
        return y
    a = np.exp(-1.0 / max(tau, 1))
    y[0] = x[0]
    for t in range(1, len(x)):
        y[t] = a * y[t - 1] + (1 - a) * x[t]
    return y

def downsample(x, y, k=1):
    """Thin two parallel sequences by keeping every ``k``-th element.

    ``x`` and ``y`` are sliced with the same step, starting at index 0, so
    paired samples stay aligned.

    Returns:
        The tuple ``(x[::k], y[::k])``.
    """
    every_kth = slice(None, None, k)
    return x[every_kth], y[every_kth]

# Run label -> CSV path. Only uncommented entries are loaded; toggle runs in
# and out of the figure by (un)commenting lines here.
_run_csv_paths = {
    # "CPO":      "./results/CPO_SafetyGoal.csv",
    # "CUP":      "./results/CUP_SafetyGoal.csv",
    # "FOCOPS":   "./results/FOCOPS_SafetyGoal.csv",
    # "P3O":      "./results/P3O_SafetyGoal.csv",
    # "PCPO":     "./results/PCPO_SafetyGoal.csv",
    # "PCPO_kl":  "./results/PCPO_KL_SafetyGoal.csv",
    # "CRPO":     "./results/CRPO_SafetyGoal.csv",
    # "CPPOPID":  "./results/CPPOPID.csv",
    # "CTRPO":    "./results/CTRPO_SafetyGoal.csv",
    # "C3PO":     "./results/C3PO_SafetyGoal.csv",
    # "APPO":     "./results/APPO_SafetyGoal.csv",
    # "APPO-unc": "./results/APPO_uncapped_SafetyGoal.csv",
    # "APPO-cap": "./results/APPO_capped_SafetyGoal.csv",
    # "GAPPO-0.1": "./results/GAPPO_scaled_SafetyGoal.csv",
    # "GAPPO": "./results/GAPPO_raw_SafetyGoal.csv",
    # "GAPPO-0.3": "./results/GAPPO-0.3-SafetyGoal.csv",
    # "GAPPO-0.5": "./results/GAPPO-0.5-SafetyGoal.csv",
    # "APPO-geo": "./results/APPO-geo-SafetyGoal.csv",
    # "APPO-0.3": "./results/APPO-0.3-SafetyGoal.csv",
    # "GAPPO-0.2": "./results/GAPPO-0.2-SafetyGoal.csv",
    # "GAPPO-schedule": "./results/GAPPO-sched-SafetyGoal.csv",
    # "GAPPO-schedule-0.1": "./results/GAPPO-sched-0.1-SafetyGoal.csv",
    # "GAPO-0.3": "./results/GAPO-0.3-SafetyGoal.csv",
    # "GAPPO-Adam-0.3": "./results/GAPPO-adam-0.3-SafetyGoal.csv",
    # "GAPPO-Adam-0.2": "./results/GAPPO-adam-0.2-SafetyGoal.csv",
    # "GAPPO-Adam-0.1": "./results/GAPPO-adam-0.1-SafetyGoal.csv",
    # "GAPPO-adam-0.25": "./results/GAPPOA-adam-0.25.csv",
    # "GAPPO-adam-linear": "./results/GAPPO-adam-linear.csv",
    # "GAPPO-adam-linear-0.1": "./results/GAPPO-adam-linear-0.1.csv",
    # "GSPO-0.3": "./results/GSPO-0.3-SafetyGoal.csv",
    # "GSPO-0.2": "./results/GSPO-0.2-SafetyGoal.csv",
    # "GSPO-0.25": "./results/GSPO-0.25-SafetyGoal.csv",
    # "GSPOA-0.3": "./results/GSPOA-0.3-SafetyGoal.csv",
    # "GSPO-0.3-L": "./results/GSPO-0.3-L.csv",
    # "PPO-Lag": "./results/PPOLag.csv.csv",  # NOTE(review): doubled ".csv" suffix - confirm filename
    # "GAPPO-0.3-L": "./results/GAPPO-0.3-L.csv",
    # "alpha-0.1": "GSPO-ablation/alpha-ablation/PointGoal-0.1.csv",
    "alpha-0.3": "GSPO-ablation/alpha-ablation/PointGoal-0.3.csv",
    "alpha-0.5": "GSPO-ablation/alpha-ablation/PointGoal-0.5.csv",
    "alpha-0.7": "GSPO-ablation/alpha-ablation/PointGoal-0.7.csv",
    "alpha-1.0": "GSPO-ablation/alpha-ablation/PointGoal-1.0.csv",
    # "d-10": "./GSPO-ablation/cost-ablation/PointGoal-d10.csv",
    # "d-15": "./GSPO-ablation/cost-ablation/PointGoal-d15-0.5.csv",
    # "d-25": "./GSPO-ablation/cost-ablation/PointGoal-d25-0.3.csv",
    # "d-35": "./GSPO-ablation/cost-ablation/PointGoal-d35-0.02.csv",
}

# Load every active run eagerly, in the order listed above.
dfs = {label: pd.read_csv(csv_path) for label, csv_path in _run_csv_paths.items()}

import matplotlib as mpl

# Per-run raw series keyed by run label: epoch index, episode return and
# episode cost, taken from the logged CSV columns.
curves = {
    name: {
        "x": frame["Train/Epoch"].values,
        "ret": frame["Metrics/EpRet"].values,
        "cost": frame["Metrics/EpCost"].values,
    }
    for name, frame in dfs.items()
}

#print(f"curves: {curves['d-35']['x']}")

# Run labels ranked by final EMA-smoothed return, best first.
# NOTE(review): not referenced by the visible code below; the plotting helper
# computes its own label ordering.
order = sorted(
    curves,
    key=lambda label: ema(curves[label]["ret"], SMOOTH_TAU)[-1],
    reverse=True,
)

def plot_onecol_threshold_ablation(
    curves_by_d,          # dict: d_label -> {"x":..., "ret":..., "cost":...}
    cost_limits_by_d,     # dict: d_label -> float
    smooth_tau,
    ema_fn,
    downsample_factor=1,
    downsample_fn=None,
    savepath_pdf=None,
):
    """Draw a one-column, two-panel (return | cost) ablation figure.

    Each entry of ``curves_by_d`` is smoothed with ``ema_fn(series, smooth_tau)``,
    optionally thinned with ``downsample_fn(x, y, downsample_factor)``, and
    drawn as one line in each panel. For every label, a dashed horizontal line
    at ``cost_limits_by_d[label]`` is added to the cost panel in the same color
    as that label's cost curve. One figure-level legend sits above both panels.

    Side effect: updates the global ``matplotlib.rcParams`` (font sizes, line
    width).

    Args:
        curves_by_d: Mapping label -> {"x": epochs, "ret": returns, "cost": costs}.
        cost_limits_by_d: Mapping label -> cost-limit value; must contain every
            label of ``curves_by_d``.
        smooth_tau: Smoothing constant forwarded to ``ema_fn``.
        ema_fn: Callable ``(series, tau) -> smoothed series``.
        downsample_factor: When > 1, plot only every ``downsample_factor``-th
            smoothed point.
        downsample_fn: Callable ``(x, y, k) -> (x_k, y_k)``; required when
            ``downsample_factor > 1``.
        savepath_pdf: If not None, the figure is saved to this path with
            ``bbox_inches="tight"``.

    Returns:
        ``(fig, (ax_ret, ax_cost))``.

    Raises:
        ValueError: If ``downsample_factor > 1`` and ``downsample_fn`` is None.
        KeyError: If a label of ``curves_by_d`` is missing from
            ``cost_limits_by_d``.
    """
    # Validate before mutating rcParams or creating a figure, so a bad call
    # leaves no half-built figure behind (an ``assert`` would also vanish
    # under ``python -O``).
    if downsample_factor > 1 and downsample_fn is None:
        raise ValueError("downsample_fn is required when downsample_factor > 1")

    # ---- one-column paper styling (global rcParams) ----
    mpl.rcParams.update({
        "font.size": 7,
        "axes.titlesize": 6,
        "axes.labelsize": 6,
        "xtick.labelsize": 7,
        "ytick.labelsize": 7,
        "lines.linewidth": 1.0,
    })

    # One-column width ~3.25in; height tuned for two small side-by-side panels.
    fig, (ax_ret, ax_cost) = plt.subplots(
        1, 2, figsize=(3.25, 1.25), sharex=True,
        gridspec_kw={"wspace": 0.25}
    )

    def sort_key(lbl):
        """Numeric sort key for labels such as "d-25" or "25".

        Removes "d", "=", "_" and "-" and parses the remainder as a float.
        Labels that do not parse (e.g. LaTeX math labels) map to 0.0; since
        ``sorted`` is stable, such labels keep their dict insertion order.
        """
        s = str(lbl).replace("d", "").replace("=", "").replace("_", "").strip()
        try:
            return float(s.replace("-", ""))
        except ValueError:
            return 0.0

    order = sorted(curves_by_d.keys(), key=sort_key)

    for d_lbl in order:
        x = curves_by_d[d_lbl]["x"]
        r = ema_fn(curves_by_d[d_lbl]["ret"], smooth_tau)
        c = ema_fn(curves_by_d[d_lbl]["cost"], smooth_tau)

        # Smooth first, then thin, so thinning does not change the EMA itself.
        if downsample_factor > 1:
            x_r, r = downsample_fn(x, r, downsample_factor)
            x_c, c = downsample_fn(x, c, downsample_factor)
        else:
            x_r, x_c = x, x

        # Solid curves; only the return panel carries legend labels.
        ax_ret.plot(x_r, r, alpha=0.95, label=str(d_lbl))
        ax_cost.plot(x_c, c, alpha=0.95)

        # Cost limit drawn in the same color as the cost curve just added.
        line_color = ax_cost.lines[-1].get_color()
        ax_cost.axhline(cost_limits_by_d[d_lbl], ls="--", lw=1.0, alpha=0.9, color=line_color)

    # Panel titles and axis labels.
    ax_ret.set_title("Return", pad=2)
    ax_cost.set_title("Constraint", pad=2)

    ax_ret.set_xlabel("Epoch", labelpad=1)
    ax_cost.set_xlabel("Epoch", labelpad=1)
    ax_ret.set_ylabel("Return", labelpad=1)
    ax_cost.set_ylabel("Cost", labelpad=1)

    for ax in (ax_ret, ax_cost):
        ax.grid(True, lw=0.4, alpha=0.35)
        ax.tick_params(length=2, pad=1)
        for spine in ax.spines.values():
            spine.set_linewidth(0.6)

    # One figure-level legend, all entries in a single row, just above the
    # panels. Skipped when nothing was plotted (ncol=0 is invalid).
    handles, labels = ax_ret.get_legend_handles_labels()
    if labels:
        fig.legend(
            handles, labels,
            loc="upper center",
            ncol=len(labels),
            frameon=True,
            framealpha=0.9,
            borderpad=0.25,
            handlelength=1.4,
            handletextpad=0.5,
            columnspacing=0.9,
            bbox_to_anchor=(0.5, 1.02),  # legend slightly above the figure
        )

    # Lower the top of the axes so the legend does not overlap the titles.
    fig.subplots_adjust(top=0.78, wspace=0.25)

    if savepath_pdf is not None:
        fig.savefig(savepath_pdf, bbox_inches="tight")

    return fig, (ax_ret, ax_cost)

# Serif text with Computer Modern math glyphs for the paper figure.
plt.rcParams.update({
    "font.family": "serif",
    "mathtext.fontset": "cm",
})

# Alpha-ablation runs to plot, in insertion (and hence color) order. Each value
# must match an "alpha-<value>" key loaded into ``curves``; the legend label is
# the LaTeX string "$\alpha = <value>$".
ALPHA_VALUES = ("1.0", "0.7", "0.5", "0.3")

curves_by_d = {
    rf"$\alpha = {alpha}$": {
        key: curves[f"alpha-{alpha}"][key] for key in ("x", "ret", "cost")
    }
    for alpha in ALPHA_VALUES
}

# Every alpha run shares the same episode-cost limit; reuse the module
# constant instead of repeating a literal 25 per label.
cost_limits_by_d = {label: COST_LIMIT for label in curves_by_d}

plot_onecol_threshold_ablation(
    curves_by_d=curves_by_d,
    cost_limits_by_d=cost_limits_by_d,
    smooth_tau=SMOOTH_TAU,
    ema_fn=ema,
    downsample_factor=DOWNSAMPLE,
    downsample_fn=downsample,
    savepath_pdf="gspo_alpha_ablation_onecol2.pdf",
)
plt.show()
