from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
import matplotlib as mpl

# Shared scalar tick formatter: mathtext rendering, no additive offset, and
# scientific notation for every magnitude (power limits of (0, 0)).
# NOTE(review): nothing in the code shown here references ``fmt``.
fmt = mtick.ScalarFormatter(useOffset=False, useMathText=True)
fmt.set_powerlimits((0, 0))
fmt.set_scientific(True)

def sci_no_zpad(decimals=0, show_plus=True):
    """Build a tick formatter that writes values as ``<mantissa>e<exponent>``.

    The exponent is printed as a plain integer, so Python's default
    zero-padded form (``1e-03``) becomes ``1e-3``.

    Args:
        decimals: Number of digits after the decimal point in the mantissa.
        show_plus: When true, the exponent always carries an explicit sign
            (``+`` or ``-``); otherwise only negative exponents are signed.

    Returns:
        A ``matplotlib.ticker.FuncFormatter`` wrapping the formatting
        function. Non-finite values render as ``''`` and zero as ``'0'``.
    """
    exponent_spec = '+d' if show_plus else 'd'

    def _format_tick(x, pos=None):
        if np.isfinite(x):
            if x == 0:
                return '0'
            mantissa, _, exponent = f"{x:.{decimals}e}".partition('e')
            # int() drops the zero padding Python adds to the exponent.
            return mantissa + 'e' + format(int(exponent), exponent_spec)
        return ''

    return mtick.FuncFormatter(_format_tick)

# Directories holding the result CSVs, one per experiment family.
DATA_DIR = Path("results_vllm_3008_rtx")
DATA_DIR_DNA = Path("results_dna")
DATA_DIR_PTB = Path("results_vllm_ptb_3008_rtx")

# Marker shapes assigned to successive thresholds; callers wrap around with a
# modulo when there are more thresholds than markers.
TAU_MARKERS = "* o s D ^ v P X h > < d p".split()

# Output directory for rendered figures; created at import time if missing.
FIG_DIR = Path("figures")
FIG_DIR.mkdir(exist_ok=True)

# Experiment setups: setup name -> {LaTeX label -> CSV path}.
#
# Each inner key is stored in the ``model`` column by ``load_setup`` and used
# verbatim as the subplot title by ``plot_speed_vs_jsd``; panels are ordered by
# the sorted label strings. The labels rely on ``\textcolor`` and ``\texttt``,
# so they need LaTeX text rendering with ``xcolor`` loaded (see ``mpl_style``).
# The setup name also becomes part of the output file name.
SETUPS = {
    "realpha": {
        r"$p_{\mathsf{gpt2}} \circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\alpha}$" : DATA_DIR / "gpt2_large_hf_realpha_.csv",
        r"$p_{\mathsf{lla1B}}\circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\alpha}$"  : DATA_DIR / "meta_llama_Llama_3.2_1B_hf_realpha_.csv",
        r"$p_{\mathsf{lla8B}} \circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\alpha}$"  : DATA_DIR / "meta_llama_Llama_3.1_8B_hf_realpha_.csv",
    },
    "ptb": {
        r"$p_{\mathsf{gpt2}} \circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\alpha}\circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\mathrm{ptb}}$" : DATA_DIR_PTB / "gpt2_large_ptb_.csv",
        r"$p_{\mathsf{lla1B}}\circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\alpha}\circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\mathrm{ptb}}$"  : DATA_DIR_PTB / "meta_llama_Llama_3.2_1B_ptb_.csv",
        r"$p_{\mathsf{lla8B}} \circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\alpha}\circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\mathrm{ptb}}$"  : DATA_DIR_PTB / "meta_llama_Llama_3.1_8B_ptb_.csv",
    },
    "dna": {
        r"$p_{\mathsf{dna}}\circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\mathrm{dna2aa}}$ (max candidates=5000)" : DATA_DIR_DNA / "gpt2_dna_hf_dna2aa__05000_rtx.csv",
        r"$p_{\mathsf{dna}}\circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\mathrm{dna2aa}}$ (max candidates=10000)" : DATA_DIR_DNA / "gpt2_dna_hf_dna2aa__10000_rtx.csv",
        r"$p_{\mathsf{dna}}\circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\mathrm{dna2aa}}$ (max candidates=15000)" : DATA_DIR_DNA / "gpt2_dna_hf_dna2aa__15000_rtx.csv",
        r"$p_{\mathsf{dna}}\circ\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}_{\mathrm{dna2aa}}$ (max candidates=20000)" : DATA_DIR_DNA / "gpt2_dna_hf_dna2aa__20000_rtx.csv",
    },
}

# Colours of the "Set2" colormap, used as the default axes colour cycle.
PALETTE = plt.get_cmap("Set2").colors
# Font size applied to body text, tick labels, titles and the shared x label.
BASE_FONTSIZE = 24
# Line/marker styling keyword arguments.
# NOTE(review): not referenced anywhere in the code shown here.
MARKER_KW = {"marker": "o", "markersize": 8, "lw": 3}

def fmt_e(x, digits=1, *, show_plus=False, trim_mantissa=True):
    """Format a number as compact scientific notation, e.g. ``2.5e-3``.

    Unlike Python's ``e`` format, the exponent is not zero padded.

    Args:
        x: Value to format; anything ``float()`` accepts.
        digits: Digits after the decimal point in the mantissa.
        show_plus: Prefix non-negative exponents with ``+``.
        trim_mantissa: Strip trailing zeros (and a dangling ``.``) from the
            mantissa, so ``1.0e-3`` becomes ``1e-3``.

    Returns:
        The formatted string. Missing values (as judged by ``pd.isna``) are
        returned unchanged so the function can be mapped over a column that
        contains NaN. Infinities are returned as ``"inf"`` / ``"-inf"``.
        Zero is rendered in fixed-point with ``digits`` decimals (``"0"``
        when ``digits`` is 0) and is not affected by ``trim_mantissa``.
    """
    if pd.isna(x):
        return x
    x = float(x)
    # An infinity formats as "inf", which has no 'e' to split on; return it
    # as-is instead of failing in the unpacking below.
    if abs(x) == float("inf"):
        return str(x)
    if x == 0.0:
        return f"{0:.{digits}f}" if digits else "0"
    mantissa, exponent = f"{x:.{digits}e}".split('e')
    if trim_mantissa:
        mantissa = mantissa.rstrip('0').rstrip('.')
    exp = int(exponent)  # int() removes the zero padding ("-03" -> -3)
    sign = '+' if (show_plus and exp >= 0) else ''
    return f"{mantissa}e{sign}{exp}"

def fmt_x_string(x, digits=0, *, style="thousands"):
    """Format an axis value either with thousands separators or SI suffixes.

    Args:
        x: Numeric value to format.
        digits: Maximum number of decimals kept in ``"si"`` style. Ignored
            in the default style, which rounds to an integer.
        style: ``"si"`` scales by 1e3/1e6/1e9 and appends ``k``/``M``/``G``,
            wrapping the result in ``$\\mathrm{...}$``. Any other value
            yields the rounded integer with LaTeX thin spaces (``\\,``) as
            thousands separators.

    Returns:
        The label string, or ``''`` when ``x`` is not finite.
    """
    if not np.isfinite(x):
        return ''

    def _fixed(value):
        text = f"{round(value, digits):.{digits}f}"
        # Trim only zeros that follow a decimal point; with digits=0 there is
        # none, and stripping would turn "10" into "1" (or "0" into "").
        return text.rstrip('0').rstrip('.') if '.' in text else text

    if style == "si":
        for base, suffix in ((1e9, 'G'), (1e6, 'M'), (1e3, 'k')):
            if abs(x) >= base:
                return rf'$\mathrm{{{_fixed(x / base)}{suffix}}}$'
        return rf'$\mathrm{{{_fixed(x)}}}$'

    return f"{int(round(x)):,}".replace(",", r"\,")

def build_tau_color_map(df, cmap_name="viridis"):
    """Assign one colour per distinct ``K`` value and build a matching norm.

    Args:
        df: Frame with a ``K`` column; missing values are ignored.
        cmap_name: Name of the Matplotlib colormap to sample.

    Returns:
        A tuple ``(k2color, k2label, cmap, norm)``:

        * ``k2color`` maps each distinct ``K`` value to its RGBA colour.
        * ``k2label`` maps the ``fmt_e(K, digits=0)`` label to the same colour.
        * ``cmap`` is the colormap requested with one entry per value.
        * ``norm`` is a ``BoundaryNorm`` whose bin edges sit halfway between
          neighbouring values in log10 space (outer edges mirrored); a single
          value gets edges one decade below and above it.

    NOTE(review): the log10 step assumes every ``K`` is positive - confirm.
    """
    tau_values = np.asarray(sorted(df["K"].dropna().unique()))
    count = len(tau_values)
    cmap = plt.get_cmap(cmap_name, count)
    palette = cmap(np.linspace(0, 1, count))

    if count != 1:
        log_tau = np.log10(tau_values)
        midpoints = (log_tau[:-1] + log_tau[1:]) / 2
        # Reflect the nearest midpoint about each end value for the outer edges.
        lower = 2*log_tau[0] - midpoints[0]
        upper = 2*log_tau[-1] - midpoints[-1]
        boundaries = 10**np.r_[lower, midpoints, upper]
    else:
        only = tau_values[0]
        boundaries = np.array([only/10, only*10])
    norm = mpl.colors.BoundaryNorm(boundaries, count)

    k2color = dict(zip(tau_values, palette))
    k2label = {fmt_e(tau, digits=0): rgba for tau, rgba in zip(tau_values, palette)}
    return k2color, k2label, cmap, norm

def build_tau_marker_map(df):
    """Map each distinct ``K`` label to a marker shape.

    Labels are ``fmt_e(K, digits=0)`` for the sorted, non-missing ``K``
    values in ``df``; markers are taken from ``TAU_MARKERS`` in that order,
    wrapping around when the labels outnumber the markers.
    """
    marker_for = {}
    n_markers = len(TAU_MARKERS)
    for idx, tau in enumerate(sorted(df["K"].dropna().unique())):
        marker_for[fmt_e(tau, digits=0)] = TAU_MARKERS[idx % n_markers]
    return marker_for


def load_setup(name: str, file_map: dict[str, Path], *, k_max: float = 0.003) -> pd.DataFrame:
    """Load and combine the result CSVs of one setup.

    Args:
        name: Setup name. Currently unused; kept for call-site compatibility.
        file_map: Mapping from display label to CSV path. The label is stored
            in a ``model`` column on every row read from that file.
        k_max: Exclusive upper bound on ``K``; only rows with ``K < k_max``
            are kept.

    Returns:
        The concatenated rows with ``mean_metric`` and ``K`` coerced to
        numeric (unparseable entries become NaN), filtered by ``k_max``, plus
        a ``K_label`` column holding ``fmt_e(K, digits=0)``. Rows whose ``K``
        is NaN never satisfy the bound and are dropped. The index keeps the
        positions from the concatenated frame, so it may contain gaps.
    """
    frames = [pd.read_csv(path).assign(model=model) for model, path in file_map.items()]
    df = pd.concat(frames, ignore_index=True)

    # Coerce before filtering so the comparison below always sees numbers.
    df["mean_metric"] = pd.to_numeric(df["mean_metric"], errors="coerce")
    df["K"] = pd.to_numeric(df["K"], errors="coerce")

    # Take an explicit copy: adding a column to a filtered selection can
    # otherwise trigger pandas' chained-assignment warning.
    df = df.loc[df["K"] < k_max].copy()
    df["K_label"] = df["K"].map(lambda v: fmt_e(v, digits=0))
    return df


def mpl_style():
    """Return a context manager that applies this script's Matplotlib style.

    Inside the context, text is rendered through LaTeX with a serif font at
    ``BASE_FONTSIZE``, the top and right spines are hidden, a faint grid is
    drawn, ``PALETTE`` is the colour cycle and figures default to 14 x 3.
    """
    # Preamble passed to LaTeX; xcolor is loaded alongside the AMS packages.
    latex_preamble = r"""
        \usepackage{amsmath,amssymb,xcolor}
        \usepackage{inconsolata}
        \renewcommand{\rmdefault}{ptm} 
        %\renewcommand{\sfdefault}{phv}
        %\renewcommand{\ttdefault}{pcr}
        """
    overrides = {
        "text.usetex": True,
        "font.family": "serif",
        "font.size": BASE_FONTSIZE,
        "axes.spines.right": False,
        "axes.spines.top": False,
        "axes.grid": True,
        "xtick.labelsize": BASE_FONTSIZE,
        "ytick.labelsize": BASE_FONTSIZE,
        "grid.alpha": 0.25,
        "axes.prop_cycle": plt.cycler(color=PALETTE),
        "figure.figsize": (14, 3),
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "pgf.texsystem": "pdflatex",
        "text.latex.preamble": latex_preamble,
        "pgf.rcfonts": False,
    }
    return plt.rc_context(overrides)


def plot_speed_vs_jsd(df: pd.DataFrame, dataset: str) -> None:
    """Plot ``mean_metric`` against ``chars_per_sec`` and save the figure.

    One panel is drawn per distinct ``model`` (in sorted label order), all
    sharing the y axis. Each panel shows a grey line through the points
    sorted by speed, error bars from ``metric_ci_lower``/``metric_ci_upper``
    and one marker per row whose shape and fill colour identify its
    ``K_label``. A single legend above the panels lists every label once.

    Args:
        df: Frame with the columns ``model``, ``K``, ``K_label``,
            ``chars_per_sec``, ``mean_metric``, ``metric_ci_lower`` and
            ``metric_ci_upper``.
        dataset: Name used in the output file,
            ``FIG_DIR / "speed_vs_jsd_<dataset>.pdf"``.
    """
    models = sorted(df["model"].unique())
    # Only the label -> colour mapping is needed from the colour helper.
    _, label2color, _, _ = build_tau_color_map(df, cmap_name="viridis")
    tau2marker = build_tau_marker_map(df)

    with mpl_style():
        fig, axes = plt.subplots(1, len(models), sharey=True, constrained_layout=True)
        try:
            if len(models) == 1:
                axes = [axes]

            # Labels already given a legend entry, shared across panels so
            # each one is listed only once.
            legend_seen = set()

            for ax, model in zip(axes, models):
                d = df[df["model"] == model].sort_values("chars_per_sec")

                ax.plot(d["chars_per_sec"], d["mean_metric"],
                        color='0.65', lw=2.5, label='_nolegend_', zorder=1)

                ax.errorbar(
                    d["chars_per_sec"], d["mean_metric"],
                    yerr=[d["mean_metric"] - d["metric_ci_lower"],
                          d["metric_ci_upper"] - d["mean_metric"]],
                    fmt="none", ecolor='0.4', elinewidth=3, capsize=4, zorder=0
                )

                for xi, yi, lab in zip(d["chars_per_sec"], d["mean_metric"], d["K_label"]):
                    # Skip incomplete rows before the lookups: a missing
                    # label has no entry in the marker or colour maps.
                    if pd.isna(lab) or np.isnan(xi) or np.isnan(yi):
                        continue
                    first_use = lab not in legend_seen
                    legend_seen.add(lab)
                    ax.scatter([xi], [yi], s=80, marker=tau2marker[lab],
                               facecolors=label2color[lab], edgecolors='black',
                               linewidths=1.8, zorder=3,
                               label=(r'$\tau\mkern-3mu=$' + lab if first_use else '_nolegend_'))

                ax.set_title(model, fontweight="bold", pad=2, fontsize=BASE_FONTSIZE)
                ax.xaxis.set_major_locator(mtick.MaxNLocator(4))
                ax.yaxis.set_major_locator(mtick.MaxNLocator(5))
                ax.yaxis.set_major_formatter(mtick.FuncFormatter(lambda x, pos: fmt_e(x, digits=1)))
                ax.xaxis.set_major_formatter(
                    mtick.FuncFormatter(lambda x, pos: fmt_x_string(x, digits=2, style="thousands"))
                )

            axes[0].set_ylabel("JSD")
            fig.supxlabel("Speed (bytes per second)", fontsize=BASE_FONTSIZE)

            # Gather one handle per label across all panels.
            legend_handles, legend_labels = [], []
            for ax in axes:
                for handle, label in zip(*ax.get_legend_handles_labels()):
                    if label != '_nolegend_' and label not in legend_labels:
                        legend_handles.append(handle)
                        legend_labels.append(label)

            # With nothing to show, skip the legend rather than request a
            # zero-column layout.
            if legend_handles:
                fig.legend(legend_handles, legend_labels, ncols=len(legend_handles),
                           loc='upper center', bbox_to_anchor=(0.5, 1.17),
                           frameon=False, handletextpad=0.01, labelspacing=0.01,
                           borderaxespad=0.0, scatterpoints=1)

            out_file = FIG_DIR / f"speed_vs_jsd_{dataset}.pdf"
            fig.savefig(out_file, bbox_inches="tight")
            print(f"✓ Saved → {out_file}")
        finally:
            # Release the figure; this function is called once per setup and
            # pyplot otherwise keeps every figure alive.
            plt.close(fig)


# Driver: for every configured setup, load its CSVs and save one figure named
# after the setup.
# NOTE(review): this runs at import time; there is no ``__main__`` guard.
for setup, files in SETUPS.items():
    df_setup = load_setup(setup, files)
    plot_speed_vs_jsd(df_setup, setup)
