
import argparse
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from matplotlib.colors import TwoSlopeNorm
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.gridspec import GridSpecFromSubplotSpec



# Columns the merged CSV must provide: two identifier columns followed by a
# mean/std pair for each metric, i.e. mean_TTS, std_TTS, mean_RP, std_RP,
# mean_num_violations, std_num_violations.
EXPECTED_COLS = ["algorithm", "environment"] + [
    f"{stat}_{metric}"
    for metric in ("TTS", "RP", "num_violations")
    for stat in ("mean", "std")
]


def make_marker_map(algorithms):
    """Map each algorithm name to a matplotlib marker symbol.

    Markers are handed out in iteration order from a fixed palette of twelve
    symbols; once the palette is exhausted it wraps around to the start.

    Args:
        algorithms: Iterable of hashable algorithm identifiers.

    Returns:
        A dict from algorithm identifier to marker string.
    """
    palette = ("o", "s", "^", "D", "v", "P", "X", "*", "<", ">", "h", "8")
    palette_size = len(palette)

    assigned = {}
    for position, name in enumerate(algorithms):
        assigned[name] = palette[position % palette_size]
    return assigned


def _safe_float(x, default=np.nan):
    try:
        if pd.isna(x):
            return default
        return float(x)
    except Exception:
        return default


def plot_icml_2x5(
    merged_csv: Path,
    out_path: Path,
    locomotion_envs=("Ant", "Humanoid", "Hopper", "HalfCheetah", "Swimmer"),
    navigation_envs=("PointGoal", "PointButton", "CarGoal", "CarButton"),
    algo_order=None,
    show_errorbars=True,
    clip_quantiles=(0.05, 0.95),
    manual_drppct=None,
    figsize=(7.0, 2.2),
    dpi=300,
):
    """Render a 2x5 grid of per-environment scatter panels and save it.

    Each panel plots one point per algorithm: x is ``mean_num_violations``
    divided by a fixed scale (labelled "VF"), y is ``mean_TTS``, and the marker
    colour encodes ``mean_RP`` on a fixed 0.80-1.00 ``RdYlGn`` scale. The
    bottom-right cell is reserved for the marker legend and the colorbar.

    Args:
        merged_csv: CSV file containing every column in ``EXPECTED_COLS``.
        out_path: Destination image path; parent directories are created.
        locomotion_envs: Environment names drawn left to right in the top row.
        navigation_envs: Environment names drawn left to right in the bottom
            row. The fifth bottom cell is switched off and used for the legend.
        algo_order: Optional preferred ordering of algorithm names. Listed
            algorithms present in the data come first, the remaining ones
            follow in sorted order.
        show_errorbars: Draw std error bars behind each marker when true.
        clip_quantiles: Accepted for interface compatibility; not read by this
            function.
        manual_drppct: Accepted for interface compatibility; not read by this
            function.
        figsize: Figure size in inches passed to ``plt.subplots``.
        dpi: Figure resolution passed to ``plt.subplots``.

    Raises:
        ValueError: If the CSV lacks any column listed in ``EXPECTED_COLS``.

    Side effects:
        Updates the seaborn theme and matplotlib rcParams process-wide, writes
        the figure to ``out_path`` and prints the saved path.
    """
    df = pd.read_csv(merged_csv)

    missing = [c for c in EXPECTED_COLS if c not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns in merged CSV: {missing}")

    # Divisor turning a violation count into the plotted frequency.
    # NOTE(review): presumably the number of steps/episodes the count was
    # accumulated over -- confirm against the data pipeline.
    violation_scale = 300

    sns.set_theme(style="whitegrid", context="paper")
    mpl.rcParams.update({
        "font.size": 7,
        "axes.titlesize": 5,
        "axes.labelsize": 6,
        "xtick.labelsize": 4,
        "ytick.labelsize": 4,
        "legend.fontsize": 5,
        "axes.linewidth": 0.8,
        "grid.linewidth": 0.6,
    })
    plt.rcParams.update({
        "font.family": "serif",
        "mathtext.fontset": "cm",
    })

    # Fixed colour range for RP so every panel shares the same scale.
    rp_min, rp_max = 0.80, 1.00

    norm = mpl.colors.Normalize(vmin=rp_min, vmax=rp_max)
    cmap = plt.get_cmap("RdYlGn")
    # Standalone mappable so the colorbar does not depend on any one scatter.
    sm = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
    sm.set_array([])

    # Algorithm order
    algorithms = sorted(df["algorithm"].unique().tolist())
    if algo_order is not None:
        algorithms = (
            [a for a in algo_order if a in algorithms]
            + [a for a in algorithms if a not in algo_order]
        )

    marker_map = make_marker_map(algorithms)

    fig, axes = plt.subplots(2, 5, figsize=figsize, dpi=dpi)
    fig.subplots_adjust(left=0.06, right=0.98, bottom=0.12, top=0.92, wspace=0.35, hspace=0.45)
    fig.patch.set_facecolor("white")

    def stylize_ax(ax):
        """Apply the shared background, grid, tick and spine styling."""
        ax.set_facecolor("#fbfbfb")
        ax.grid(True, alpha=0.35)
        ax.tick_params(length=2)
        for spine in ax.spines.values():
            spine.set_alpha(0.9)

    def draw_panel(ax, env):
        """Draw one environment's scatter (one marker per algorithm) on ``ax``."""
        stylize_ax(ax)
        env_rows = df[df["environment"] == env]
        ax.set_title(env, pad=2)

        if env_rows.empty:
            ax.text(0.5, 0.5, "missing", ha="center", va="center", transform=ax.transAxes)
            return

        for alg in algorithms:
            r = env_rows[env_rows["algorithm"] == alg]
            if r.empty:
                continue
            # Only the first matching row is plotted if duplicates exist.
            r = r.iloc[0]

            x = _safe_float(r["mean_num_violations"]) / violation_scale
            y = _safe_float(r["mean_TTS"])
            c = _safe_float(r["mean_RP"])
            if np.isnan(x) or np.isnan(y) or np.isnan(c):
                continue

            if show_errorbars:
                # The std must be scaled exactly like the mean so the bar is
                # expressed in the same units as the x axis.
                xerr = _safe_float(r["std_num_violations"], default=0.0) / violation_scale
                yerr = _safe_float(r["std_TTS"], default=0.0)
                ax.errorbar(
                    x, y,
                    xerr=xerr, yerr=yerr,
                    fmt="none",
                    elinewidth=0.7,
                    capsize=1.6,
                    alpha=0.35,
                    zorder=1,
                )

            ax.scatter(
                x, y,
                s=12,  # bigger markers for readability
                marker=marker_map[alg],
                c=[c], cmap=cmap, norm=norm,
                edgecolors="black",
                linewidths=0.3,
                alpha=0.95,
                zorder=2,
            )

    # Row 1: locomotion (5)
    for j, env in enumerate(locomotion_envs):
        draw_panel(axes[0, j], env)

    # Row 2: navigation (4)
    for j, env in enumerate(navigation_envs):
        draw_panel(axes[1, j], env)

    # Per-axes labels are cleared in favour of the figure-level ones below;
    # every panel shares the same x range and ticks.
    for ax in axes.ravel():
        ax.set_xlabel("")
        ax.set_ylabel("")
        ax.set_xlim(0.0, 0.8)
        ax.set_xticks([0.0, 0.2, 0.4, 0.6, 0.8])

    # Global labels (cleaner)
    fig.supxlabel("VF (Violation Frequency)", y=0.03, fontsize=6)
    fig.supylabel("TTS (Mean Time To Safety)", x=0.01, fontsize=6)

    # Legend + colorbar in last cell
    ax_legcell = axes[1, 4]
    ax_legcell.set_axis_off()

    # Split the hidden cell vertically: legend on top, colorbar underneath.
    legend_spec = GridSpecFromSubplotSpec(
        2, 1,
        subplot_spec=ax_legcell.get_subplotspec(),
        height_ratios=[5.5, 0.45],  # legend tall, colorbar thin
        hspace=0.35,
    )

    ax_legend = fig.add_subplot(legend_spec[0])
    ax_cbar = fig.add_subplot(legend_spec[1])

    ax_legend.axis("off")
    ax_cbar.set_facecolor("white")

    # Legend entries show marker shape only (white fill); colour is explained
    # by the colorbar.
    handles = [
        mpl.lines.Line2D(
            [0], [0],
            marker=marker_map[a],
            linestyle="",
            markerfacecolor="white",
            markeredgecolor="black",
            markeredgewidth=0.4,
            markersize=3,
            label=a,
        )
        for a in algorithms
    ]

    ax_legend.legend(
        handles=handles,
        loc="upper left",
        frameon=True,
        framealpha=0.95,
        borderpad=0.5,
        handletextpad=0.6,
        labelspacing=0.5,
    )

    cbar = fig.colorbar(sm, cax=ax_cbar, orientation="horizontal")
    cbar.set_label("Reward Preservation (RP)", fontsize=5, labelpad=2)
    cbar.set_ticks([0.80, 0.90, 1.00])
    cbar.ax.tick_params(labelsize=4)

    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, bbox_inches="tight", pad_inches=0.02)
    plt.close(fig)
    print(f"Saved: {out_path}")


def main():
    """Parse command-line options and render the figure with ``plot_icml_2x5``.

    ``--q_lo``/``--q_hi`` are forwarded as ``clip_quantiles`` and
    ``--manual_drppct`` as ``manual_drppct``; the algorithm ordering is fixed
    here rather than exposed as an option.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--merged_csv", type=str, required=True)
    parser.add_argument("--out", type=str, required=True)
    parser.add_argument("--no_errorbars", action="store_true")
    parser.add_argument("--dpi", type=int, default=300)
    # The float-valued options all have the same shape; register them together.
    for flag, fallback in (
        ("--fig_w", 7.0),
        ("--fig_h", 3.2),
        ("--q_lo", 0.83),
        ("--q_hi", 1.0),
    ):
        parser.add_argument(flag, type=float, default=fallback)
    parser.add_argument(
        "--manual_drppct",
        type=float,
        nargs=2,
        default=None,
        help="Manually set symmetric color range in ΔRP%%: provide vmin vmax (will be symmetrized).",
    )
    opts = parser.parse_args()

    # argparse yields a list for nargs=2; the plot function is given a tuple.
    if opts.manual_drppct is None:
        drppct_range = None
    else:
        drppct_range = tuple(opts.manual_drppct)

    preferred_algorithms = ["APPO", "CPPOPID", "CUP", "FOCOPS", "CSPO", "PPO-Lag"]

    plot_icml_2x5(
        merged_csv=Path(opts.merged_csv),
        out_path=Path(opts.out),
        algo_order=preferred_algorithms,
        show_errorbars=not opts.no_errorbars,
        clip_quantiles=(opts.q_lo, opts.q_hi),
        manual_drppct=drppct_range,
        figsize=(opts.fig_w, opts.fig_h),
        dpi=opts.dpi,
    )


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
