from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np

# Module-level scalar tick formatter: math-text rendering, no additive offset,
# and power limits of (0, 0) with scientific notation switched on.
# NOTE(review): `fmt` is not referenced anywhere else in the visible source;
# confirm it is still needed before relying on it.
fmt = mtick.ScalarFormatter(useOffset=False, useMathText=True)
fmt.set_powerlimits((0, 0))
fmt.set_scientific(True)

def sci_no_zpad(decimals=0, show_plus=True):
    """Build a tick formatter that writes values as ``<mantissa>e<exponent>``.

    The exponent is printed as a plain integer, i.e. without the zero padding
    that Python's ``e`` format produces (``1e-04`` becomes ``1e-4``).

    Args:
        decimals: Number of digits after the decimal point in the mantissa.
        show_plus: When true, non-negative exponents carry an explicit ``+``.

    Returns:
        A ``matplotlib.ticker.FuncFormatter`` wrapping the formatting function.
        Non-finite tick values are rendered as an empty string and an exact
        zero as ``'0'``.
    """
    # Integer format spec for the exponent: "+d" forces a sign, "d" does not.
    exponent_spec = "+d" if show_plus else "d"

    def _format_tick(value, pos=None):
        if not np.isfinite(value):
            return ''
        if value == 0:
            return '0'
        mantissa, raw_exponent = f"{value:.{decimals}e}".split('e')
        # int() drops the zero padding of the textual exponent.
        return mantissa + "e" + format(int(raw_exponent), exponent_spec)

    return mtick.FuncFormatter(_format_tick)

def _strip_trailing_zeros(text):
    """Drop trailing fractional zeros (and a dangling '.') from fixed-point text."""
    if '.' not in text:
        # Without a decimal point every zero is significant: "10" must not
        # collapse to "1" and "0" must not collapse to "".
        return text
    return text.rstrip('0').rstrip('.')


def fmt_x_string(x, digits=0, *, style="thousands"):
    r"""Format a tick value as a compact label string.

    Args:
        x: Numeric tick value.
        digits: Maximum number of decimals kept in the ``"si"`` style; ignored
            by the thousands style, which rounds to an integer.
        style: ``"si"`` scales by 1e3/1e6/1e9, appends ``k``/``M``/``G`` and
            wraps the result in ``$\mathrm{...}$``. Any other value (default
            ``"thousands"``) yields the rounded integer with ``\,`` as the
            thousands separator.

    Returns:
        The label string, or ``''`` when ``x`` is not finite.
    """
    if not np.isfinite(x):
        return ''
    if style == "si":
        # Largest unit first so that e.g. 2e6 becomes "2M" rather than "2000k".
        for base, suf in ((1e9, 'G'), (1e6, 'M'), (1e3, 'k')):
            if abs(x) >= base:
                s = _strip_trailing_zeros(f"{round(x / base, digits):.{digits}f}")
                return rf'$\mathrm{{{s}{suf}}}$'
        s = _strip_trailing_zeros(f"{round(x, digits):.{digits}f}")
        return rf'$\mathrm{{{s}}}$'
    return f"{int(round(x)):,}".replace(",", r"\,")

# Input directories holding the per-model result CSVs.
DATA_DIR = Path("results_vllm_3008_rtx")
DATA_DIR_DNA = Path("results_dna")
DATA_DIR_PTB = Path("results_vllm_ptb_3008_rtx")

# Output directory for rendered figures; created when the module is executed.
FIG_DIR = Path("figures")
FIG_DIR.mkdir(exist_ok=True)

# LaTeX building blocks shared by every panel label.
_CIRC = r"\circ"
_F = r"\textcolor[rgb]{0.384,0.451,0.074}{\texttt{f}}"
_F_ALPHA = _F + r"_{\alpha}"
_F_PTB = _F + r"_{\mathrm{ptb}}"
_F_DNA = _F + r"_{\mathrm{dna}}"

# (label prefix, CSV file stem) per language model. The uneven spacing inside
# the prefixes is deliberate: the labels double as dictionary keys, so they are
# kept character-for-character.
_LM_MODELS = (
    (r"$p_{\mathsf{gpt2}} ", "gpt2_large"),
    (r"$p_{\mathsf{llama1B}}", "meta_llama_Llama_3.2_1B"),
    (r"$p_{\mathsf{llama8B}} ", "meta_llama_Llama_3.1_8B"),
)

# Numeric prefixes of the DNA panel labels; also the zero-padded part of the
# corresponding CSV file names.
_DNA_SIZES = (5000, 10000, 15000, 20000)

# setup name -> {panel label (LaTeX) -> CSV path}
SETUPS = {
    "realpha": {
        prefix + _CIRC + _F_ALPHA + "$": DATA_DIR / f"{stem}_hf_realpha_.csv"
        for prefix, stem in _LM_MODELS
    },
    "ptb": {
        prefix + _CIRC + _F_ALPHA + _CIRC + _F_PTB + "$": DATA_DIR_PTB / f"{stem}_ptb_.csv"
        for prefix, stem in _LM_MODELS
    },
    "dna": {
        f"{size} " + r"$p_{\mathsf{dna}}" + _CIRC + _F_DNA + "$":
            DATA_DIR_DNA / f"gpt2_dna_hf_dna2aa__{size:05d}_rtx.csv"
        for size in _DNA_SIZES
    },
}

# Plot-wide styling constants.
BASE_FONTSIZE = 24  # base font size; tick labels and titles reuse it
MARKER_KW = {"marker": "o", "markersize": 8, "lw": 3}  # line/marker keywords for ax.plot
PALETTE = plt.get_cmap("Set2").colors  # colour sequence used for the axes colour cycle

def fmt_e(x, digits=1, *, show_plus=False, trim_mantissa=True):
    """Format a number in compact scientific notation, e.g. ``1.5e-3``.

    Args:
        x: Value to format; anything ``float()`` accepts. Missing values
            (as judged by ``pd.isna``) are returned unchanged.
        digits: Decimals requested for the mantissa.
        show_plus: Prefix non-negative exponents with ``+``.
        trim_mantissa: Strip trailing zeros (and a dangling ``.``) from the
            mantissa.

    Returns:
        The formatted string, or ``x`` itself when it is missing. Zero is
        rendered in fixed-point form (``"0"`` when ``digits`` is 0).
    """
    if pd.isna(x):
        return x

    value = float(x)
    if value == 0.0:
        return f"{0:.{digits}f}" if digits else "0"

    mantissa, raw_exponent = f"{value:.{digits}e}".split('e')
    if trim_mantissa:
        mantissa = mantissa.rstrip('0').rstrip('.')

    # int() removes the zero padding of the textual exponent; a negative
    # exponent always carries its own '-' sign.
    exponent = int(raw_exponent)
    exponent_text = f"{exponent:+d}" if show_plus else str(exponent)
    return f"{mantissa}e{exponent_text}"

def load_setup(name: str, file_map: dict[str, Path]) -> pd.DataFrame:
    """Load and combine the result CSVs of one setup.

    Args:
        name: Setup identifier. Currently unused; kept so existing callers
            keep working.
        file_map: Mapping from model label to the CSV file with its results.
            Each CSV must provide at least the ``K`` and ``mean_metric``
            columns.

    Returns:
        One frame with the rows of all CSVs whose ``K`` is below 0.003, plus a
        ``model`` column (the label from ``file_map``) and a ``K_label``
        column (``K`` formatted by ``fmt_e`` with ``digits=0``).

    Raises:
        ValueError: Propagated from ``pd.concat`` when ``file_map`` is empty.
    """
    frames = []
    for model, path in file_map.items():
        print("Model here:", model)
        frames.append(pd.read_csv(path).assign(model=model))

    df = pd.concat(frames, ignore_index=True)

    # Coerce BEFORE filtering: a stray non-numeric cell then becomes NaN (and
    # is dropped by the comparison below) instead of breaking the filter.
    df["mean_metric"] = pd.to_numeric(df["mean_metric"], errors="coerce")
    df["K"] = pd.to_numeric(df["K"], errors="coerce")

    # Explicit copy so the column added below lands on an independent frame.
    df = df.query("K < 0.003").copy()
    df["K_label"] = df["K"].map(lambda v: fmt_e(v, digits=0))
    print("DF:\n", df)
    return df


def mpl_style():
    """Return the rc-context manager carrying the figure style of this script.

    The settings enable LaTeX text rendering with a serif font, hide the top
    and right spines, switch on a faint grid, apply the ``PALETTE`` colour
    cycle and set a (14, 3) figure size.
    """
    # Injected verbatim into the LaTeX preamble used for text rendering.
    latex_preamble = r"""
        \usepackage{amsmath,amssymb,xcolor}
        \usepackage{inconsolata}
        \renewcommand{\rmdefault}{ptm} 
        %\renewcommand{\sfdefault}{phv}
        %\renewcommand{\ttdefault}{pcr}
        """

    text_rc = {
        "text.usetex": True,
        "font.family": "serif",
        "font.size": BASE_FONTSIZE,
    }
    axes_rc = {
        "axes.spines.right": False,
        "axes.spines.top": False,
        "axes.grid": True,
        "xtick.labelsize": BASE_FONTSIZE,
        "ytick.labelsize": BASE_FONTSIZE,
        "grid.alpha": 0.25,
        "axes.prop_cycle": plt.cycler(color=PALETTE),
        "figure.figsize": (14, 3),
    }
    export_rc = {
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "pgf.texsystem": "pdflatex",
        "text.latex.preamble": latex_preamble,
        "pgf.rcfonts": False,
    }
    return plt.rc_context({**text_rc, **axes_rc, **export_rc})


def plot_speed_vs_jsd(df: pd.DataFrame, dataset: str) -> None:
    """Draw one speed-vs-JSD panel per model and save the figure as a PDF.

    Args:
        df: Frame with the columns ``model``, ``chars_per_sec``,
            ``mean_metric``, ``metric_ci_lower``, ``metric_ci_upper`` and
            ``K_label``.
        dataset: Name inserted into the output file name
            ``speed_vs_jsd_<dataset>_no_legend.pdf`` inside ``FIG_DIR``.

    Raises:
        ValueError: If ``df`` contains no rows (there would be no panel to
            draw).
    """
    # unique() keeps first-appearance order. Sorting the labels as strings
    # would put "5000 ..." after "20000 ...".
    models = list(df["model"].unique())
    n_models = len(models)
    if n_models == 0:
        raise ValueError(f"no rows to plot for dataset {dataset!r}")

    label_fontsize = BASE_FONTSIZE - 8

    with mpl_style():
        print("num models: ", n_models)
        fig, axes = plt.subplots(
            1, n_models, sharey=True,
            constrained_layout=True,
        )
        try:
            # With a single column plt.subplots returns a bare Axes.
            if n_models == 1:
                axes = [axes]

            for ax, model in zip(axes, models):
                # NOTE(review): none of the labels in the visible SETUPS
                # contains "DNA GPT", so these special cases look stale.
                is_dna_gpt = "DNA GPT" in model

                d = df[df["model"] == model].sort_values("chars_per_sec")
                ax.plot(d["chars_per_sec"], d["mean_metric"], **MARKER_KW, zorder=3)

                # Asymmetric error bars: distances from the mean to each CI bound.
                ax.errorbar(
                    d["chars_per_sec"], d["mean_metric"],
                    yerr=[d["mean_metric"] - d["metric_ci_lower"],
                          d["metric_ci_upper"] - d["mean_metric"]],
                    fmt="none", capsize=5, lw=3, zorder=2
                )

                ha = "center" if is_dna_gpt else "left"
                for xi, yi, lab in zip(d["chars_per_sec"], d["mean_metric"], d["K_label"]):
                    # The "1e-4" label gets its own offset (in points).
                    xytext = (-30, 0) if lab == "1e-4" else (-12, 14)
                    ax.annotate(
                        r'$\tau\mkern-3mu=$' + lab, (xi, yi), xytext=xytext,
                        textcoords="offset points", rotation=65,
                        fontsize=label_fontsize, color="dimgray",
                        ha=ha, va="bottom"
                    )

                ax.set_title(model, fontweight="bold", pad=2, fontsize=BASE_FONTSIZE)
                ax.xaxis.set_major_locator(mtick.MaxNLocator(4))
                ax.yaxis.set_major_locator(mtick.MaxNLocator(5))
                ax.yaxis.set_major_formatter(sci_no_zpad(decimals=0, show_plus=True))
                ax.xaxis.set_major_formatter(
                    mtick.FuncFormatter(lambda x, pos: fmt_x_string(x, digits=2, style="thousands"))
                )

                if is_dna_gpt:
                    ax.set_ylim(0.0, 0.30)

            axes[0].set_ylabel("JSD")
            fig.supxlabel("Speed (bytes per second)", fontsize=BASE_FONTSIZE)
            out_file = FIG_DIR / f"speed_vs_jsd_{dataset}_no_legend.pdf"
            fig.savefig(out_file, bbox_inches="tight")
            print(f"✓ Saved → {out_file}")
        finally:
            # pyplot keeps every figure alive until it is closed explicitly;
            # release it even when plotting or saving fails.
            plt.close(fig)


# Script driver: load every configured setup and render its figure.
for setup in SETUPS:
    files = SETUPS[setup]
    df_setup = load_setup(setup, files)
    plot_speed_vs_jsd(df_setup, setup)
