import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# Time constant (in samples) handed to ema() when smoothing curves for plotting.
SMOOTH_TAU = 30

# Stride handed to downsample_xy(): every DOWNSAMPLE-th point is drawn.
DOWNSAMPLE = 2

# Cost threshold per environment, drawn as a horizontal line on the cost panels.
COST_LIMITS = dict(
    PointGoal=25,
    Ant=25,  # change if needed
)

def ema(x, tau=20):
    """Smooth a sequence with an exponential moving average.

    The recurrence is ``y[0] = x[0]`` and
    ``y[t] = a * y[t-1] + (1 - a) * x[t]`` with
    ``a = exp(-1 / max(tau, 1))``.

    Args:
        x: Array-like of numbers; it is converted to a float array.
        tau: Smoothing time constant, in samples. Values below 1 are
            treated as 1.

    Returns:
        A float ``numpy.ndarray`` with the same shape as ``x``. An empty
        input yields an empty array.
    """
    x = np.asarray(x, dtype=float)
    y = np.empty_like(x)
    if x.size == 0:
        # An empty series has no first sample to seed the recurrence with.
        return y
    a = np.exp(-1.0 / max(tau, 1))
    y[0] = x[0]
    for t in range(1, len(x)):
        y[t] = a * y[t - 1] + (1 - a) * x[t]
    return y

def downsample_xy(x, y, k=1):
    """Return ``(x, y)`` thinned to every ``k``-th element.

    Both sequences are sliced with the same stride, starting at index 0, so
    equally long inputs stay aligned. ``k=1`` keeps every element.
    """
    stride = slice(None, None, k)
    x_thin = x[stride]
    y_thin = y[stride]
    return x_thin, y_thin

# ============
# CONFIG
# ============
# Values swept in the ablation. The list order decides which colour each
# value gets in the figure.
LR_VALS = [
    0.1,
    0.01,
    0.05,
    0.001,
]

# Methods being compared; each name appears in the CSV path and selects the
# line style of its curves.
METHODS = ["GSPO", "PPOLag"]

# Environments to load and plot (one column each); each name appears in the
# CSV file name, titles its column, and keys COST_LIMITS.
ENVS = ["PointGoal", "Ant"]

# Adapt the path template below to your own file layout.
def load_csv(method, env, lr):
    """Read the logged metrics of one run into a DataFrame.

    The file is looked up at
    ``GSPO-ablation/lambda-ablation/<method>/<env>-<lr>.csv``, for example
    ``GSPO-ablation/lambda-ablation/GSPO/PointGoal-0.1.csv``.

    Returns:
        The ``pandas.DataFrame`` produced by ``pd.read_csv`` for that file.
    """
    csv_path = f"GSPO-ablation/lambda-ablation/{method}/{env}-{lr}.csv"
    return pd.read_csv(csv_path)

def load_curves():
    """Load every (env, method, lr) run listed in the config.

    Returns:
        A dict keyed by ``(env, method, lr)``. Each value is a dict with
        keys ``"x"``, ``"ret"`` and ``"cost"`` holding the ``.values`` of
        the ``Train/Epoch``, ``Metrics/EpRet`` and ``Metrics/EpCost``
        columns of that run's CSV.
    """
    # Output key -> CSV column it is read from.
    column_of = {
        "x": "Train/Epoch",
        "ret": "Metrics/EpRet",
        "cost": "Metrics/EpCost",
    }

    loaded = {}
    for env_name in ENVS:
        for method in METHODS:
            for lr_value in LR_VALS:
                frame = load_csv(method, env_name, lr_value)
                loaded[(env_name, method, lr_value)] = {
                    key: frame[column].values
                    for key, column in column_of.items()
                }
    return loaded

def plot_appendix_ablation(curves, savepath="ablation_lr_appendix.pdf"):
    """Draw the return/cost ablation grid and save it to ``savepath``.

    The figure has two rows (top: return, bottom: cost) and one column per
    entry of ``ENVS``. Every curve is smoothed with ``ema(..., SMOOTH_TAU)``
    and then thinned with ``downsample_xy(..., DOWNSAMPLE)``. Colour encodes
    the value from ``LR_VALS`` and line style encodes the method. Each cost
    panel also gets a dashed black horizontal line at the environment's
    ``COST_LIMITS`` entry (25 when the environment has no entry).

    Args:
        curves: Mapping ``(env, method, lr) -> {"x", "ret", "cost"}`` with an
            entry for every combination of ``ENVS``, ``METHODS`` and
            ``LR_VALS``.
        savepath: Output path handed to ``Figure.savefig``.

    Raises:
        KeyError: If ``curves`` lacks a required ``(env, method, lr)`` entry,
            or a name in ``METHODS`` has no entry in the local style table.

    Side effects:
        Updates the global ``plt.rcParams``, writes the figure to
        ``savepath`` and prints a confirmation line.
    """
    # Appendix style: bigger fonts + larger canvas
    plt.rcParams.update({
        "font.size": 10,
        "axes.titlesize": 12,
        "axes.labelsize": 11,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "legend.fontsize": 10,
        "lines.linewidth": 1.2,
    })

    # Grid: rows = 2 (Return/Cost), cols = #envs.
    # squeeze=False keeps `axes` two-dimensional even for a single env.
    fig, axes = plt.subplots(
        2, len(ENVS), figsize=(12, 6), sharex=False, squeeze=False
    )

    # Colors = learning rates
    cmap = plt.get_cmap("tab10")
    lr_to_color = {lr: cmap(i) for i, lr in enumerate(LR_VALS)}

    # Styles = method
    style = {
        "GSPO": dict(ls="-", lw=1.8, alpha=0.95),
        "PPOLag": dict(ls="--", lw=1.6, alpha=0.95),
    }
    # Legend display names, keyed like METHODS; unknown methods fall back to
    # their raw name.
    method_labels = {
        "GSPO": "GSPO",
        "PPOLag": "PPO-Lag",
    }

    for j, env in enumerate(ENVS):
        axR = axes[0, j]
        axC = axes[1, j]
        b = COST_LIMITS.get(env, 25)

        for lr in LR_VALS:
            col = lr_to_color[lr]
            for m in METHODS:
                curve = curves[(env, m, lr)]
                R = ema(curve["ret"], SMOOTH_TAU)
                C = ema(curve["cost"], SMOOTH_TAU)

                # Return and cost share one x-axis, so thin x only once.
                xs, R = downsample_xy(curve["x"], R, DOWNSAMPLE)
                _, C = downsample_xy(curve["x"], C, DOWNSAMPLE)

                st = style[m]
                axR.plot(xs, R, color=col, **st)
                axC.plot(xs, C, color=col, **st)

        axR.set_title(env)
        axR.set_ylabel("Return")
        axC.set_ylabel("Cost")
        axC.set_xlabel("Epoch")

        # Cost limit for this environment.
        axC.axhline(b, color="k", ls="--", lw=1.6, alpha=0.7)

        axR.grid(True, alpha=0.3)
        axC.grid(True, alpha=0.3)

    # --------
    # Legends
    # --------
    # Legend A: colors = lr.
    # The doubled braces emit a literal LaTeX "{0}" subscript; single braces
    # would be evaluated as an f-string expression instead.
    handles_lr = [
        Line2D(
            [0], [0],
            color=lr_to_color[lr],
            lw=3.0,
            label=rf"$\lambda_{{0}}={lr}$",
        )
        for lr in LR_VALS
    ]
    # Legend B: styles = methods, built from METHODS and `style` so the
    # legend always matches the line styles actually plotted.
    handles_m = [
        Line2D(
            [0], [0],
            color="k",
            lw=2.6,
            ls=style[m]["ls"],
            label=method_labels.get(m, m),
        )
        for m in METHODS
    ]

    # Place a single combined legend at the bottom center, all entries in
    # one row.
    handles = handles_lr + handles_m
    leg = fig.legend(
        handles=handles,
        loc="lower center",
        ncol=max(1, len(handles)),
        frameon=True,
        framealpha=0.95,
        bbox_to_anchor=(0.5, -0.02),
    )
    leg.get_frame().set_facecolor("white")
    leg.get_frame().set_edgecolor("0.85")

    # Leave a strip at the bottom of the canvas for the legend.
    fig.tight_layout(rect=[0, 0.05, 1, 1])
    fig.savefig(savepath, bbox_inches="tight")
    print(f"Saved: {savepath}")

if __name__ == "__main__":
    # Script entry point: load every configured run, then render the figure.
    all_curves = load_curves()
    plot_appendix_ablation(
        all_curves,
        savepath="ppolag_vs_gspo_lambda_ablation_appendix.pdf",
    )
