import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# Column names read from every training-log CSV.
EPOCH_COL = "Train/Epoch"    # x-axis of every panel
RET_COL = "Metrics/EpRet"    # plotted in the "Return" column
COST_COL = "Metrics/EpCost"  # plotted in the "Constraint" column

# Candidate names for the Lagrange-multiplier column. pick_lambda_col()
# returns the first one present in a frame, so earlier entries win.
LAMBDA_COL_CANDIDATES = [
    "Loss/lambda",
    "Metrics/Lambda",
    "Train/Lambda",
    "Loss/Lambda",
    "Misc/LagrangeMultiplier",
    "Metrics/LagrangeMultiplier",
]

# Algorithm labels drawn with a thicker, fully opaque line.
HIGHLIGHT = [
    "CSPO",
    "APPO",
]

def pick_lambda_col(df):
    """Return the multiplier column name to plot for ``df``, or ``None``.

    Scans ``LAMBDA_COL_CANDIDATES`` in order and returns the first entry that
    is a column of ``df``. ``None`` means no recognised multiplier column.
    """
    present = df.columns
    return next((name for name in LAMBDA_COL_CANDIDATES if name in present), None)

def downsample(x, y, k=1):
    """Keep every ``k``-th element of two aligned sequences.

    Args:
        x, y: Sliceable sequences (e.g. lists or NumPy arrays) of equal length.
        k: Slice step; ``1`` keeps every element.

    Returns:
        ``(x[::k], y[::k])``.
    """
    stride = slice(None, None, k)
    return x[stride], y[stride]

def plot_twocol_oscillations_twoenv_raw(
    dfs_top,
    dfs_bottom,
    cost_limit_top=25,
    cost_limit_bottom=25,
    downsample_k=1,
    savepath_pdf=None,
    savepath_eps=None,
    highlight=None,
):
    """Plot return / cost / multiplier curves for two environments in a 2x3 grid.

    Row 0 shows ``dfs_top``, row 1 shows ``dfs_bottom``. The columns are
    episode return, episode cost (with a dashed cost-limit line) and the
    Lagrange multiplier. Each algorithm keeps one colour across its row.

    Args:
        dfs_top, dfs_bottom: Mapping of algorithm label -> DataFrame. Each frame
            must have ``EPOCH_COL``, ``RET_COL`` and ``COST_COL``. The
            multiplier curve uses the column chosen by ``pick_lambda_col`` and
            is omitted for frames where it returns ``None``.
        cost_limit_top, cost_limit_bottom: y-value of the dashed limit line in
            the cost panel of the corresponding row.
        downsample_k: When > 1, plot only every ``downsample_k``-th point.
        savepath_pdf, savepath_eps: Optional paths passed to ``fig.savefig``.
        highlight: Labels drawn thicker, opaque and above the other curves.
            Defaults to the module-level ``HIGHLIGHT``.

    Returns:
        ``(fig, axes)``, where ``axes`` is the 2x3 array from ``plt.subplots``.

    Note:
        Updates the global ``mpl.rcParams`` (font sizes, serif font, Computer
        Modern math text). The change persists after this call returns.
    """
    highlight = HIGHLIGHT if highlight is None else highlight
    # Any value <= 1 means "no downsampling" (same as the original `> 1` test).
    step = downsample_k if downsample_k > 1 else 1

    mpl.rcParams.update({
        "font.size": 8,
        "axes.titlesize": 7,
        "axes.labelsize": 6,
        "xtick.labelsize": 6,
        "ytick.labelsize": 6,
        "lines.linewidth": 1.0,
    })
    plt.rcParams.update({
        "font.family": "serif",
        "mathtext.fontset": "cm",
    })

    fig, axes = plt.subplots(
        2, 3,
        figsize=(7.2, 2.0),
        sharex=False,
        gridspec_kw={"wspace": 0.28, "hspace": 0.22},
    )

    for ax, title in zip(axes[0], ("Return", "Constraint", "Multiplier")):
        ax.set_title(title, pad=2)

    def plot_env_row(row_axes, dfs_env, cost_limit):
        """Draw one environment's return, cost (+ limit) and multiplier panels."""
        ax_ret, ax_cost, ax_lam = row_axes

        for name, df in dfs_env.items():
            # Highlighted runs: thicker, opaque, and drawn above the others.
            # Both zorders stay above matplotlib's default gridline zorder
            # (1.5 with axes.axisbelow="line"), so the grid never covers data.
            if name in highlight:
                style = {"lw": 1.2, "alpha": 1.0, "zorder": 3}
            else:
                style = {"lw": 1.0, "alpha": 0.60, "zorder": 2}

            epochs = df[EPOCH_COL].values
            x, ret = downsample(epochs, df[RET_COL].values, step)
            _, cost = downsample(epochs, df[COST_COL].values, step)

            # The return panel takes the next colour from its cycle; the cost
            # and multiplier panels reuse it so one algorithm = one colour.
            line = ax_ret.plot(x, ret, label=name, **style)[0]
            color = line.get_color()
            ax_cost.plot(x, cost, color=color, **style)

            lam_col = pick_lambda_col(df)
            if lam_col is not None:
                _, lam = downsample(epochs, df[lam_col].values, step)
                ax_lam.plot(x, lam, color=color, **style)

        # Explicit neutral colour: without it axhline falls back to
        # rcParams["lines.color"] ("C0"), i.e. the first algorithm's colour.
        ax_cost.axhline(cost_limit, color="k", ls="--", lw=1.0, alpha=0.9)

        ax_ret.set_ylabel("Return")
        ax_cost.set_ylabel("Cost")
        ax_lam.set_ylabel(r"$\lambda$")

        for ax in (ax_ret, ax_cost, ax_lam):
            ax.set_xlabel("Epoch")
            ax.grid(True, lw=0.4, alpha=0.30)
            ax.tick_params(length=2, pad=1)
            for spine in ax.spines.values():
                spine.set_linewidth(0.7)

    plot_env_row(axes[0, :], dfs_top, cost_limit_top)
    plot_env_row(axes[1, :], dfs_bottom, cost_limit_bottom)

    # Hide the top row's x label and tick labels to save vertical space.
    # NOTE(review): sharex=False, so if the two environments cover different
    # epoch ranges the top row's x-scale differs from the one labelled below.
    for ax in axes[0, :]:
        ax.set_xlabel("")
        ax.tick_params(labelbottom=False)

    # row labels (a)/(b) on the left of first panel
    # axes[0, 0].text(
    #     -0.20, 1.08, "(a) PointGoal",
    #     transform=axes[0, 0].transAxes,
    #     fontsize=9, fontweight="bold", va="bottom"
    # )
    # axes[1, 0].text(
    #     -0.20, 1.08, "(b) Ant",
    #     transform=axes[1, 0].transAxes,
    #     fontsize=9, fontweight="bold", va="bottom"
    # )

    # Shared legend built from the top-left panel, so it lists only the
    # algorithms in dfs_top. Skipped when there is nothing to label.
    handles, labels = axes[0, 0].get_legend_handles_labels()
    if labels:
        fig.legend(
            handles, labels,
            loc="upper center",
            ncol=len(labels),
            frameon=False,
            fontsize=6,
            handlelength=1.6,
            columnspacing=1.0,
            bbox_to_anchor=(0.5, 1.02),
        )

    # Reserve the top 5% of the figure for the legend.
    fig.tight_layout(rect=[0, 0, 1, 0.95])

    if savepath_pdf is not None:
        fig.savefig(savepath_pdf, bbox_inches="tight")
    if savepath_eps is not None:
        # NOTE(review): PostScript has no transparency. matplotlib renders
        # alpha < 1 artists opaque in .eps (and logs a warning), so the faded
        # non-highlighted curves and grid only look faded in the PDF.
        fig.savefig(savepath_eps, bbox_inches="tight")

    return fig, axes




# Algorithms to compare, in plotting order, paired with the CSV file stem each
# run was exported under. Both environments are built from this one list, so
# the two rows always list algorithms in the same order. Line colours come
# from each Axes' colour cycle in that order, so they match across rows.
_OSC_DIR = "./GSPO-ablation/oscillations"
_ALGO_FILES = [
    ("PPO-Lag", "PPOLag"),
    ("CUP", "CUP"),
    ("FOCOPS", "FOCOPS"),
    ("APPO", "APPO"),
    ("CSPO", "CSPO"),
    ("CPPOPID", "CPPOPID"),
]

# PointGoal runs: <stem>.csv
dfs_pointgoal = {
    label: pd.read_csv(f"{_OSC_DIR}/{stem}.csv") for label, stem in _ALGO_FILES
}

# Ant runs: <stem>_Ant.csv
dfs_ant = {
    label: pd.read_csv(f"{_OSC_DIR}/{stem}_Ant.csv") for label, stem in _ALGO_FILES
}

plot_twocol_oscillations_twoenv_raw(
    dfs_top=dfs_pointgoal,
    dfs_bottom=dfs_ant,
    cost_limit_top=25,
    cost_limit_bottom=25,
    downsample_k=1,
    savepath_pdf="oscillations_2x3_pointgoal_ant.pdf",
    savepath_eps="oscillations_2x3_pointgoal_ant.eps",
)

plt.show()
