from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import matplotlib as mpl

# Module-level ScalarFormatter configured for scientific notation rendered
# with mathtext. It is not referenced anywhere else in the visible part of
# this file.
fmt = mtick.ScalarFormatter(useMathText=True)
fmt.set_scientific(True)
# (0, 0) power limits request scientific notation for every magnitude.
fmt.set_powerlimits((0, 0))
# Never factor a common additive offset out of the tick labels.
fmt.set_useOffset(False)

def sci_no_zpad(decimals=0, show_plus=True):
    """Build a tick formatter that prints ``1e-4`` style labels.

    The exponent is written without zero padding (``e-4`` rather than
    ``e-04``).

    Args:
        decimals: Number of digits after the decimal point of the mantissa.
        show_plus: When true, the exponent always carries a sign, so
            non-negative exponents are written with a leading ``+``.

    Returns:
        A ``matplotlib.ticker.FuncFormatter``. Non-finite tick values map to
        the empty string and zero maps to ``'0'``.
    """
    exponent_spec = '+d' if show_plus else 'd'

    def _tick_label(x, pos=None):
        if np.isfinite(x) and x != 0:
            # Python always emits exactly one 'e' for a finite float here.
            mantissa, _, raw_exponent = format(x, f".{decimals}e").partition('e')
            return mantissa + 'e' + format(int(raw_exponent), exponent_spec)
        # Only zero and non-finite values reach this point.
        return '0' if x == 0 else ''

    return mtick.FuncFormatter(_tick_label)

# Directories holding the result CSVs referenced by SETUPS (relative to the
# current working directory).
DATA_DIR   = Path("results_vllm_3008_rtx")
DATA_DIR_DNA = Path("results_dna")
DATA_DIR_PTB = Path("results_vllm_ptb_3008_rtx")

# Marker shapes assigned to tau values by index; users of this list wrap
# around with a modulo when there are more tau values than markers.
TAU_MARKERS = ['*','o','s','D','^','v','P','X','h','>','<','d','p']

# Output directory for saved figures; created at import time if missing.
FIG_DIR    = Path("figures")
FIG_DIR.mkdir(exist_ok=True)

# Experiment setups: setup name -> {LaTeX label -> CSV path}.
# load_setup() stores each label in the "model" column, and the plotting
# helpers use that column as the subplot title, so the labels are rendered
# as LaTeX (they rely on \textcolor, i.e. the xcolor package).
SETUPS = {
    "realpha": {
        r"$p_{\mathsf{gpt2}} \circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\alpha}$" : DATA_DIR / "gpt2_large_hf_realpha_.csv",
        r"$p_{\mathsf{lla1B}}\circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\alpha}$"  : DATA_DIR / "meta_llama_Llama_3.2_1B_hf_realpha_.csv",
        r"$p_{\mathsf{lla8B}} \circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\alpha}$"  : DATA_DIR / "meta_llama_Llama_3.1_8B_hf_realpha_.csv",
    },
    "ptb": {
        r"$p_{\mathsf{gpt2}} \circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\alpha}\circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\mathrm{ptb}}$" : DATA_DIR_PTB / "gpt2_large_ptb_.csv",
        r"$p_{\mathsf{lla1B}}\circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\alpha}\circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\mathrm{ptb}}$"  : DATA_DIR_PTB / "meta_llama_Llama_3.2_1B_ptb_.csv",
        r"$p_{\mathsf{lla8B}} \circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\alpha}\circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\mathrm{ptb}}$"  : DATA_DIR_PTB / "meta_llama_Llama_3.1_8B_ptb_.csv",
    },
    "dna": {
        r"$p_{\mathsf{dna}}\circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\mathrm{dna2aa}}$ (max candidates=5000)" : DATA_DIR_DNA / "gpt2_dna_hf_dna2aa__05000_rtx.csv",
        r"$p_{\mathsf{dna}}\circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\mathrm{dna2aa}}$ (max candidates=10000)" : DATA_DIR_DNA / "gpt2_dna_hf_dna2aa__10000_rtx.csv",
        r"$p_{\mathsf{dna}}\circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\mathrm{dna2aa}}$ (max candidates=15000)" : DATA_DIR_DNA / "gpt2_dna_hf_dna2aa__15000_rtx.csv",
        r"$p_{\mathsf{dna}}\circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\mathrm{dna2aa}}$ (max candidates=20000)" : DATA_DIR_DNA / "gpt2_dna_hf_dna2aa__20000_rtx.csv",
    },
}

# Colors of matplotlib's "Set2" colormap; installed as the axes color cycle
# by mpl_style().
PALETTE = plt.get_cmap("Set2").colors
# Font size used for body text, tick labels and figure-level labels.
BASE_FONTSIZE = 30
# Line/marker keyword bundle; not referenced in the visible part of this file.
MARKER_KW = dict(marker="o", markersize=8, lw=3)

def fmt_e(x, digits=1, *, show_plus=False, trim_mantissa=True):
    """Format *x* in compact scientific notation, e.g. ``1.5e3`` or ``1e-4``.

    Args:
        x: Value to format. Missing values (anything ``pd.isna`` accepts as
            missing) are returned unchanged so the function can be mapped
            over a column that contains gaps.
        digits: Number of digits after the decimal point of the mantissa.
        show_plus: Prefix non-negative exponents with ``+``.
        trim_mantissa: Strip trailing zeros (and a dangling ``.``) from the
            mantissa, so ``2.0e0`` becomes ``2e0``.

    Returns:
        The formatted string, or *x* itself when it is missing. Zero is
        rendered in fixed notation (``"0"`` when ``digits`` is 0, otherwise
        e.g. ``"0.0"``). Infinities are rendered as ``"inf"`` / ``"-inf"``.
    """
    if pd.isna(x):
        return x
    x = float(x)
    if x == 0.0:
        return f"{0:.{digits}f}" if digits else "0"
    if np.isinf(x):
        # f"{inf:e}" is just "inf": there is no 'e' to split on, so the
        # mantissa/exponent unpacking below would raise ValueError.
        return "inf" if x > 0 else "-inf"
    m, e = f"{x:.{digits}e}".split('e')
    if trim_mantissa:
        m = m.rstrip('0').rstrip('.')
    # int() drops the zero padding Python puts on exponents ("e-04" -> -4);
    # a negative exponent keeps its own '-' sign.
    exp = int(e)
    sign = '+' if (show_plus and exp >= 0) else ''
    return f"{m}e{sign}{exp}"

def fmt_x_string(x, digits=0, *, style="thousands"):
    r"""Format an x-axis tick value as a LaTeX-friendly string.

    Args:
        x: Tick value.
        digits: Decimal places kept in ``"si"`` style. Ignored by the
            default ``"thousands"`` style, which rounds to an integer.
        style: ``"si"`` scales by 1e3/1e6/1e9 and appends ``k``/``M``/``G``
            inside ``$\mathrm{...}$``. Any other value selects the
            thousands style: an integer with ``\,`` (LaTeX thin space) as
            the group separator.

    Returns:
        The label text, or ``''`` for NaN/inf.
    """
    if not np.isfinite(x):
        return ''

    def _trim(text):
        # Only strip zeros that follow a decimal point. Stripping
        # unconditionally turns "10" into "1" when digits == 0.
        return text.rstrip('0').rstrip('.') if '.' in text else text

    if style == "si":
        for base, suf in ((1e9, 'G'), (1e6, 'M'), (1e3, 'k')):
            if abs(x) >= base:
                s = _trim(f"{round(x / base, digits):.{digits}f}")
                return rf'$\mathrm{{{s}{suf}}}$'
        s = _trim(f"{round(x, digits):.{digits}f}")
        return rf'$\mathrm{{{s}}}$'
    return f"{int(round(x)):,}".replace(",", r"\,")

def build_tau_color_map(df, cmap_name="viridis"):
    """Assign one discrete colormap color to every distinct value of ``df["K"]``.

    Args:
        df: Frame with a ``K`` column; missing values are ignored.
        cmap_name: Name of the matplotlib colormap to sample.

    Returns:
        A 4-tuple ``(k2color, k2label, cmap, norm)``:
        ``k2color`` maps each K value to its RGBA color, ``k2label`` maps the
        ``fmt_e(k, digits=0)`` label of each K value to that same color,
        ``cmap`` is the colormap resampled to one entry per K value, and
        ``norm`` is a ``BoundaryNorm`` with one bin per K value.
    """
    tau_values = np.asarray(sorted(df["K"].dropna().unique()))
    count = len(tau_values)
    cmap = plt.get_cmap(cmap_name, count)
    rgba = cmap(np.linspace(0, 1, count))

    if count != 1:
        # Bin edges sit halfway between neighbouring taus in log10 space; the
        # two outer edges mirror the nearest interior edge around the first
        # and last tau.
        log_tau = np.log10(tau_values)
        inner = (log_tau[:-1] + log_tau[1:]) / 2
        lowest = 2 * log_tau[0] - inner[0]
        highest = 2 * log_tau[-1] - inner[-1]
        edges = 10 ** np.r_[lowest, inner, highest]
    else:
        # A single tau gets one bin spanning a decade on either side.
        only = tau_values[0]
        edges = np.array([only / 10, only * 10])
    norm = mpl.colors.BoundaryNorm(edges, count)

    k2color = {}
    k2label = {}
    for tau, color in zip(tau_values, rgba):
        k2color[tau] = color
        k2label[fmt_e(tau, digits=0)] = color
    return k2color, k2label, cmap, norm

def build_tau_marker_map(df):
    """Map the ``fmt_e(k, digits=0)`` label of each distinct ``K`` to a marker.

    Markers come from ``TAU_MARKERS`` in ascending-K order and wrap around
    when there are more K values than markers.
    """
    marker_for_label = {}
    for position, tau in enumerate(sorted(df["K"].dropna().unique())):
        marker_for_label[fmt_e(tau, digits=0)] = TAU_MARKERS[position % len(TAU_MARKERS)]
    return marker_for_label


def load_setup(name: str, file_map: dict[str, Path], *, k_max: float = 0.0003) -> pd.DataFrame:
    """Load and concatenate the result CSVs of one setup.

    Args:
        name: Setup name. Currently unused; kept for the callers' signature.
        file_map: Mapping of model label -> CSV path. Each label is stored
            in a ``model`` column of the rows read from its file.
        k_max: Exclusive upper bound on ``K``; rows with ``K >= k_max`` (or a
            missing/non-numeric ``K``) are dropped.

    Returns:
        A new DataFrame with numeric ``mean_metric`` and ``K`` columns and a
        ``K_label`` column holding ``fmt_e(K, digits=0)``. Rows keep the
        index labels they had after concatenation.

    Raises:
        ValueError: If ``file_map`` is empty (raised by ``pd.concat``).
    """
    frames = []
    for model, path in file_map.items():
        print("Model here:", model)
        frames.append(pd.read_csv(path).assign(model=model))

    df = pd.concat(frames, ignore_index=True)
    # Coerce before filtering so the comparison below always sees numbers;
    # unparsable values become NaN and are dropped by the filter.
    df["mean_metric"] = pd.to_numeric(df["mean_metric"], errors="coerce")
    df["K"] = pd.to_numeric(df["K"], errors="coerce")
    # Explicit copy: K_label is added to the filtered frame, and assigning
    # onto a filtered selection can otherwise raise SettingWithCopyWarning.
    df = df.loc[df["K"] < k_max].copy()
    df["K_label"] = df["K"].map(lambda v: fmt_e(v, digits=0))
    print("DF:\n", df)
    return df


def mpl_style():
    """Return a ``plt.rc_context`` that applies this script's figure style.

    Intended to be used as ``with mpl_style(): ...``. The rcParams enable
    LaTeX text rendering with a serif font, set all font and tick-label
    sizes to ``BASE_FONTSIZE``, hide the top/right spines, turn on a faint
    grid, use ``PALETTE`` as the color cycle and default to an 18x3 figure.
    """
    return plt.rc_context({
        #"mathtext.fontset": "cm",
        # Text is typeset by an external LaTeX run, so a LaTeX installation
        # providing the packages named in the preamble below is required.
        "text.usetex": True,
        "font.family": "serif",
        "font.size":        BASE_FONTSIZE,
        "axes.spines.right": False,
        "axes.spines.top"  : False,
        "axes.grid"        : True,
        "xtick.labelsize"  : BASE_FONTSIZE,
        "ytick.labelsize"  : BASE_FONTSIZE,
        "grid.alpha"       : 0.25,
        "axes.prop_cycle"  : plt.cycler(color=PALETTE),
        "figure.figsize"   : (18, 3),
        # Font type 42 embeds fonts as TrueType in PDF/PS output.
        "pdf.fonttype"     : 42,
        "ps.fonttype"      : 42,
        "pgf.texsystem": "pdflatex", 
        # xcolor provides the \textcolor command used in the SETUPS labels.
        "text.latex.preamble": r"""
        \usepackage{amsmath,amssymb,xcolor}
        \usepackage{inconsolata}
        \renewcommand{\rmdefault}{ptm} 
        %\renewcommand{\sfdefault}{phv}
        %\renewcommand{\ttdefault}{pcr}
        """,
        "pgf.rcfonts":  False

    })


def _plot_panel_on_axes(df: pd.DataFrame, axes_list, *, klabel2color, tau2marker, title_prefix: str):
    """Draw one dataset, one model per axes, on a single row of axes.

    Models are taken in sorted order and paired positionally with
    ``axes_list``; pairing stops at the shorter of the two.

    Args:
        df: Rows with ``model``, ``chars_per_sec``, ``mean_metric``,
            ``metric_ci_lower``, ``metric_ci_upper`` and ``K_label`` columns.
        axes_list: Axes to draw on, left to right.
        klabel2color: Mapping of tau label -> face color.
        tau2marker: Mapping of tau label -> marker symbol.
        title_prefix: Text placed before the model label in each title.
    """
    labelled = set()

    for ax, model in zip(axes_list, sorted(df["model"].unique())):
        rows = df.loc[df["model"] == model].sort_values("chars_per_sec")
        speed = rows["chars_per_sec"]
        metric = rows["mean_metric"]

        # Grey connecting line underneath the markers.
        ax.plot(speed, metric, color='0.65', lw=2.5, label='_nolegend_', zorder=1)

        # Asymmetric error bars from the confidence-interval columns.
        err_below = metric - rows["metric_ci_lower"]
        err_above = rows["metric_ci_upper"] - metric
        ax.errorbar(speed, metric, yerr=[err_below, err_above],
                    fmt="none", ecolor='0.4', elinewidth=3, capsize=4, zorder=0)

        # One scatter call per point so each tau keeps its own marker/color.
        for x_val, y_val, tau_label in zip(speed, metric, rows["K_label"]):
            if not (np.isnan(x_val) or np.isnan(y_val)):
                marker = tau2marker[tau_label]
                face = klabel2color[tau_label]
                # Only the first point of each tau contributes a legend entry.
                if tau_label in labelled:
                    legend_text = '_nolegend_'
                else:
                    legend_text = rf'$\tau\mkern-3mu=$' + tau_label
                ax.scatter([x_val], [y_val], s=80, marker=marker,
                           facecolors=face, edgecolors='black',
                           linewidths=1.8, zorder=3, label=legend_text)
                labelled.add(tau_label)

        ax.set_title(f"{title_prefix} · {model}", fontweight="bold", pad=2, fontsize=BASE_FONTSIZE-6)
        ax.xaxis.set_major_locator(mtick.MaxNLocator(4))
        ax.yaxis.set_major_locator(mtick.MaxNLocator(5))

        def _y_label(value, pos):
            return fmt_e(value, digits=1)

        def _x_label(value, pos):
            return fmt_x_string(value, digits=2, style="thousands")

        ax.yaxis.set_major_formatter(mtick.FuncFormatter(_y_label))
        ax.xaxis.set_major_formatter(mtick.FuncFormatter(_x_label))

# NOTE(review): a function with this same name is defined again further down
# in this file; at import time that later definition replaces this one, so
# this copy is effectively dead code.
def _build_union_style_maps(df_list):
    """Unify tau colors/labels/markers across multiple dataframes.

    Concatenates ``df_list``, coerces ``K`` to numeric and assigns each
    distinct K value (ascending) a "viridis" color and a ``TAU_MARKERS``
    entry, both keyed by the value's ``fmt_e(k, digits=0)`` label.

    Returns:
        ``(klabel2color, tau2marker)``.
    """
    df_all = pd.concat(df_list, ignore_index=True)
    df_all["K"] = pd.to_numeric(df_all["K"], errors="coerce")
    # This column is written to the local concatenated frame only and is not
    # read again within this function.
    df_all["K_label"] = df_all["K"].map(lambda v: fmt_e(v, digits=0))

    taus = np.asarray(sorted(df_all["K"].dropna().unique()))
    n = len(taus)
    cmap = plt.get_cmap("viridis", n)
    colors = cmap(np.linspace(0, 1, n))

    # Consistent color per tau label and marker per tau label
    klabel2color = {fmt_e(k, digits=0): colors[i] for i, k in enumerate(taus)}
    tau2marker   = {fmt_e(k, digits=0): TAU_MARKERS[i % len(TAU_MARKERS)] for i, k in enumerate(taus)}
    return klabel2color, tau2marker


def _plot_row_group(df: pd.DataFrame, axes_list, *, klabel2color, tau2marker):
    """Draw one dataset across a contiguous run of axes, one model per axes.

    Sorted model labels are paired positionally with ``axes_list`` (pairing
    stops at the shorter of the two). Each panel shows a grey line through
    the points ordered by ``chars_per_sec``, confidence-interval error bars
    and one styled marker per row. The model label is used as the title.

    Args:
        df: Rows with ``model``, ``chars_per_sec``, ``mean_metric``,
            ``metric_ci_lower``, ``metric_ci_upper`` and ``K_label`` columns.
        axes_list: Axes to draw on, left to right.
        klabel2color: Mapping of tau label -> face color.
        tau2marker: Mapping of tau label -> marker symbol.
    """
    seen_labels = set()
    model_names = sorted(df["model"].unique())

    for panel, model_name in zip(axes_list, model_names):
        subset = df[df["model"] == model_name].sort_values("chars_per_sec")
        xs, ys = subset["chars_per_sec"], subset["mean_metric"]

        panel.plot(xs, ys, color='0.65', lw=2.5, label='_nolegend_', zorder=1)

        lower_gap = ys - subset["metric_ci_lower"]
        upper_gap = subset["metric_ci_upper"] - ys
        panel.errorbar(
            xs, ys, yerr=[lower_gap, upper_gap],
            fmt="none", ecolor='0.4', elinewidth=3, capsize=4, zorder=0,
        )

        for px, py, tau in zip(xs, ys, subset["K_label"]):
            # Rows without a usable coordinate are skipped entirely.
            if np.isnan(px) or np.isnan(py):
                continue
            style = dict(
                s=80, marker=tau2marker[tau],
                facecolors=klabel2color[tau], edgecolors='black',
                linewidths=1.8, zorder=3,
            )
            # The first occurrence of a tau carries the legend label.
            style["label"] = '_nolegend_' if tau in seen_labels else rf'$\tau\mkern-3mu=$' + tau
            panel.scatter([px], [py], **style)
            seen_labels.add(tau)

        # The model strings themselves tell the datasets apart, so the title
        # is just the model label.
        panel.set_title(model_name, fontweight="bold", pad=2, y=1.2, fontsize=BASE_FONTSIZE-2)
        for axis in (panel.xaxis, panel.yaxis):
            axis.set_major_locator(mtick.MaxNLocator(4))
        panel.yaxis.set_major_formatter(
            mtick.FuncFormatter(lambda x, pos: fmt_e(x, digits=1))
        )
        panel.xaxis.set_major_formatter(
            mtick.FuncFormatter(lambda x, pos: fmt_x_string(x, digits=2, style="thousands"))
        )

def _build_union_style_maps(df_list):
    """Build tau color and marker maps shared by several dataframes.

    Every frame's ``K`` column is pooled, coerced to numeric and reduced to
    its distinct values. Each distinct *label* (``fmt_e(k, digits=0)``, in
    ascending-K order) then receives one evenly spaced "viridis" color and
    one ``TAU_MARKERS`` entry (wrapping around when the labels outnumber the
    markers), so the same tau is styled identically in every panel.

    Args:
        df_list: DataFrames that each provide a ``K`` column.

    Returns:
        ``(klabel2color, tau2marker)``: two dicts keyed by tau label. Both
        are empty when no usable K value is present.
    """
    k_values = pd.to_numeric(
        pd.concat(df_list, ignore_index=True)["K"], errors="coerce"
    )
    taus = sorted(k_values.dropna().unique())

    # Styles are looked up by label, so allocate them per label: K values
    # that round to the same label (e.g. float noise between CSVs) must not
    # consume extra colormap slots or marker indices. dict.fromkeys keeps
    # the first-seen (ascending) order while dropping repeats.
    labels = list(dict.fromkeys(fmt_e(k, digits=0) for k in taus))
    if not labels:
        return {}, {}

    cmap = plt.get_cmap("viridis", len(labels))
    colors = cmap(np.linspace(0, 1, len(labels)))

    klabel2color = dict(zip(labels, colors))
    tau2marker = {
        lab: TAU_MARKERS[i % len(TAU_MARKERS)] for i, lab in enumerate(labels)
    }
    return klabel2color, tau2marker


def plot_combined_realpha_ptb_one_row_group_sharey(setups: dict, out_name: str = "speed_vs_jsd_realpha_ptb_1row_groupsharey.pdf"):
    """Save a one-row figure: the ReAlpha panels followed by the PTB panels.

    One panel is drawn per model. The y-axis is shared within each group but
    not across the two groups, and only the first panel of each group keeps
    its y tick labels. A single de-duplicated tau legend is placed above the
    figure.

    Args:
        setups: Mapping with ``"realpha"`` and ``"ptb"`` entries, each a
            ``{model label: csv path}`` dict as accepted by ``load_setup``.
        out_name: File name of the saved figure inside ``FIG_DIR``.
    """
    # load_setup already returns numeric mean_metric/K and a K_label column.
    df_realpha = load_setup("realpha", setups["realpha"])
    df_ptb     = load_setup("ptb",     setups["ptb"])

    n_realpha = len(df_realpha["model"].unique())
    n_ptb     = len(df_ptb["model"].unique())

    # Unified tau styles across both datasets
    klabel2color, tau2marker = _build_union_style_maps([df_realpha, df_ptb])

    with mpl_style():
        fig = plt.figure(constrained_layout=True)
        try:
            gs = fig.add_gridspec(1, n_realpha + n_ptb, wspace=0)

            def _group_axes(first_col, count):
                # The first axes of a group anchors the y-axis the rest share.
                group = []
                for offset in range(count):
                    anchor = group[0] if group else None
                    group.append(
                        fig.add_subplot(gs[0, first_col + offset], sharey=anchor)
                    )
                return group

            axes_realpha = _group_axes(0, n_realpha)
            axes_ptb     = _group_axes(n_realpha, n_ptb)

            _plot_row_group(
                df_realpha, axes_realpha, klabel2color=klabel2color, tau2marker=tau2marker
            )
            _plot_row_group(
                df_ptb, axes_ptb, klabel2color=klabel2color, tau2marker=tau2marker
            )

            # Hide redundant y tick labels within each group to reduce clutter
            for ax in [*axes_realpha[1:], *axes_ptb[1:]]:
                ax.tick_params(labelleft=False)

            # Shared figure labels
            fig.supylabel("JSD", fontsize=BASE_FONTSIZE)
            fig.supxlabel("Speed (bytes per second)", fontsize=BASE_FONTSIZE)

            # Single legend; the first handle seen for each label wins.
            legend_entries = {}
            for ax in [*axes_realpha, *axes_ptb]:
                for handle, label in zip(*ax.get_legend_handles_labels()):
                    if label != "_nolegend_":
                        legend_entries.setdefault(label, handle)
            if legend_entries:
                fig.legend(list(legend_entries.values()), list(legend_entries),
                           ncols=len(legend_entries),
                           loc='upper center', bbox_to_anchor=(0.5, 1.22),
                           frameon=False, handletextpad=0.01, labelspacing=0.01,
                           borderaxespad=0.0, scatterpoints=1)

            out_file = FIG_DIR / out_name
            fig.savefig(out_file, bbox_inches="tight")
            print(f"✓ Saved → {out_file}")
        finally:
            # pyplot keeps every figure alive until it is closed; release it
            # so repeated calls do not accumulate open figures.
            plt.close(fig)




def plot_speed_vs_jsd(df: pd.DataFrame, dataset: str) -> None:
    """Save a one-row speed-vs-metric figure with one panel per model.

    Each panel shows ``mean_metric`` (y, labelled "JSD") against
    ``chars_per_sec`` (x) with confidence-interval error bars and one marker
    per row, styled by the row's tau label. All panels share the y-axis and
    a single de-duplicated legend is placed above the figure. The result is
    written to ``FIG_DIR / f"speed_vs_jsd_{dataset}_squash.pdf"``.

    Args:
        df: Rows with ``model``, ``K``, ``K_label``, ``chars_per_sec``,
            ``mean_metric``, ``metric_ci_lower`` and ``metric_ci_upper``.
        dataset: Name inserted into the output file name.
    """
    models = sorted(df["model"].unique())
    # Only the label -> color mapping is needed here.
    _, label2color, _, _ = build_tau_color_map(df, cmap_name="viridis")
    tau2marker = build_tau_marker_map(df)

    with mpl_style():
        fig, axes = plt.subplots(1, len(models), sharey=True,
                                 constrained_layout=True, squeeze=False)
        try:
            # squeeze=False always yields a 2-D array, even for one panel.
            axes = list(axes[0])
            legend_seen = set()

            for ax, model in zip(axes, models):
                d = df[df["model"] == model].sort_values("chars_per_sec")

                ax.plot(d["chars_per_sec"], d["mean_metric"],
                        color='0.65', lw=2.5, label='_nolegend_', zorder=1)
                ax.errorbar(
                    d["chars_per_sec"], d["mean_metric"],
                    yerr=[d["mean_metric"] - d["metric_ci_lower"],
                          d["metric_ci_upper"] - d["mean_metric"]],
                    fmt="none", ecolor='0.4', elinewidth=3, capsize=4, zorder=0
                )

                for xi, yi, lab in zip(d["chars_per_sec"], d["mean_metric"], d["K_label"]):
                    # Skip unusable rows before any style lookup.
                    if np.isnan(xi) or np.isnan(yi):
                        continue
                    ax.scatter([xi], [yi], s=80, marker=tau2marker[lab],
                               facecolors=label2color[lab], edgecolors='black',
                               linewidths=1.8, zorder=3,
                               label=(rf'$\tau\mkern-3mu=$' + lab
                                      if lab not in legend_seen else '_nolegend_'))
                    legend_seen.add(lab)

                ax.set_title(model, fontweight="bold", pad=2, fontsize=BASE_FONTSIZE)
                ax.xaxis.set_major_locator(mtick.MaxNLocator(4))
                ax.yaxis.set_major_locator(mtick.MaxNLocator(5))
                ax.yaxis.set_major_formatter(
                    mtick.FuncFormatter(lambda x, pos: fmt_e(x, digits=1))
                )
                ax.xaxis.set_major_formatter(
                    mtick.FuncFormatter(lambda x, pos: fmt_x_string(x, digits=2, style="thousands"))
                )

            axes[0].set_ylabel("JSD")
            fig.supxlabel("Speed (bytes per second)", fontsize=BASE_FONTSIZE)

            # One legend for the whole figure; the first handle per label wins.
            H, L = [], []
            for ax in axes:
                h, l = ax.get_legend_handles_labels()
                for hh, ll in zip(h, l):
                    if ll != '_nolegend_' and ll not in L:
                        H.append(hh)
                        L.append(ll)
            # Without entries there is nothing to lay out (ncols would be 0).
            if H:
                fig.legend(H, L, ncols=len(H),
                           loc='upper center', bbox_to_anchor=(0.5, 1.17),
                           frameon=False, handletextpad=0.01, labelspacing=0.01,
                           borderaxespad=0.0, scatterpoints=1)

            out_file = FIG_DIR / f"speed_vs_jsd_{dataset}_squash.pdf"
            fig.savefig(out_file, bbox_inches="tight")
            print(f"✓ Saved → {out_file}")
        finally:
            # pyplot keeps every figure alive until it is closed.
            plt.close(fig)


# The combined figure reads the "realpha" and "ptb" setups itself and always
# writes the same file, so it is rendered once rather than once per setup.
plot_combined_realpha_ptb_one_row_group_sharey(SETUPS)

# Per-setup figures are currently disabled; with the plot call commented out
# this loop only loads (and prints) each setup's data.
for setup, files in SETUPS.items():
    df_setup = load_setup(setup, files)
    #plot_speed_vs_jsd(df_setup, setup)

