import argparse
import re
from pathlib import Path

import git
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D
import scienceplots  # noqa: F401


# Top-level directory of the enclosing git working tree, as reported by
# ``git rev-parse --show-toplevel``; the CSV inputs are resolved relative to it.
_toplevel = git.Repo(search_parent_directories=True).git.rev_parse("--show-toplevel")
repo_root = Path(_toplevel)


# Global plot style, intended to match state_tracking.py: scienceplots
# "science" + "light" styles, Times New Roman math text, small fonts.
plt.style.use(["science", "light"])  # type: ignore[attr-defined]
plt.rcParams.update(
    {
        "figure.constrained_layout.use": True,
        # NOTE(review): this script used to assign "Times New Roman" to
        # font.family and then overwrite it with "serif" a few lines later, so
        # only "serif" ever took effect. The dead assignment is dropped and the
        # effective value kept. If Times New Roman body text is required,
        # confirm that the active style's serif font list resolves to it.
        "font.family": "serif",
        # Math text uses explicitly named Times New Roman faces.
        "mathtext.fontset": "custom",
        "mathtext.rm": "Times New Roman",
        "mathtext.it": "Times New Roman:italic",
        "mathtext.bf": "Times New Roman:bold",
        # Light grid on every axes.
        "axes.grid": True,
        "grid.alpha": 0.3,
        # Font sizes (points).
        "axes.labelsize": 8,
        "axes.titlesize": 8,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "legend.fontsize": 7,
        # No tick marks on the top and right edges.
        "xtick.top": False,
        "ytick.right": False,
    }
)


# Display labels for the three positional-encoding settings being compared.
SELECTIVE_LABEL, ROPE_LABEL, NOPE_LABEL = "Selective RoPE", "RoPE", "NoPE"

# Marker style per architecture; every architecture uses the same circle marker.
ARCH_MARKERS = dict.fromkeys(("GLA", "Transformer", "DeltaNet"), "o")


def _infer_setting_from_name(run_name: str) -> str | None:
    name = run_name.lower()
    # Try to extract token for S2 (gla/transformer)
    m = re.search(r"s2-(?:gla|transformer)-lr_[^-]+-([a-z_]+)-", name)
    if m is None:
        # Try to extract token for A3 DeltaNet
        m = re.search(r"a3-delta_net-lr_[^-]+-([a-z_]+)-", name)
    if m is not None:
        token = m.group(1)
        if token == "selective_rope":
            return SELECTIVE_LABEL
        if token == "nope":
            return NOPE_LABEL
        if token == "rope":
            return ROPE_LABEL
    # Fallback heuristics
    if "selective_rope" in name:
        return SELECTIVE_LABEL
    if "-nope-" in name:
        return NOPE_LABEL
    if "-rope-" in name:
        return ROPE_LABEL
    return None


def _compute_auc_for_run(df_run: pd.DataFrame) -> float:
    """Return the area under one run's accuracy-vs-sequence-length curve.

    Both columns are coerced to numbers; rows where either value is missing or
    non-numeric are dropped, and the remaining points are sorted by
    ``sequence_length`` before trapezoidal integration.

    Args:
        df_run: Rows of a single run with ``sequence_length`` and ``accuracy``
            columns.

    Returns:
        The trapezoidal area as a float, or ``-inf`` when no usable point
        remains.
    """
    x = pd.to_numeric(df_run["sequence_length"], errors="coerce").to_numpy()
    y = pd.to_numeric(df_run["accuracy"], errors="coerce").to_numpy()
    valid = ~(np.isnan(x) | np.isnan(y))
    x, y = x[valid], y[valid]
    if x.size == 0:
        return -np.inf

    order = np.argsort(x)
    # ``np.trapezoid`` only exists in NumPy >= 2.0; older releases expose the
    # same routine as ``np.trapz``. Prefer the new name so the deprecated alias
    # is never touched on modern NumPy.
    integrate = getattr(np, "trapezoid", None)
    if integrate is None:
        integrate = np.trapz
    return float(integrate(y[order], x[order]))


def _select_best_runs(df: pd.DataFrame) -> dict[str, str]:
    """Pick, for each positional-encoding setting, the run with the largest AUC.

    A setting label is inferred from every row's ``name``; rows whose label is
    not one of the three known settings are discarded. The remaining rows are
    grouped by ``(setting, id)`` and each group is scored with
    ``_compute_auc_for_run``.

    Returns:
        Mapping from setting label to the ``id`` of its best-scoring run;
        empty when no row maps to a known setting.
    """
    known_settings = {SELECTIVE_LABEL, ROPE_LABEL, NOPE_LABEL}
    labelled = df.assign(setting=df["name"].apply(_infer_setting_from_name))
    labelled = labelled[labelled["setting"].isin(known_settings)]
    if labelled.empty:
        return {}

    winners: dict[str, str] = {}
    top_score: dict[str, float] = {}
    for (setting, run_id), rows in labelled.groupby(["setting", "id"]):
        auc = _compute_auc_for_run(rows)
        # Strict ">" means the first run encountered keeps the title on ties.
        if setting not in top_score or auc > top_score[setting]:
            top_score[setting] = auc
            winners[setting] = run_id
    return winners


def _load_csv_s2_gla() -> pd.DataFrame:
    """Read ``plotting/state_tracking/S2-GLA.csv`` under the repo root into a DataFrame."""
    return pd.read_csv(repo_root.joinpath("plotting", "state_tracking", "S2-GLA.csv"))


def _load_csv_s2_softmax() -> pd.DataFrame:
    """Read ``plotting/state_tracking/S2-Softmax.csv`` under the repo root into a DataFrame."""
    return pd.read_csv(repo_root.joinpath("plotting", "state_tracking", "S2-Softmax.csv"))


def _load_csv_a3_deltanet() -> pd.DataFrame:
    """Read ``plotting/state_tracking/A3-DeltaNet.csv`` under the repo root into a DataFrame."""
    return pd.read_csv(repo_root.joinpath("plotting", "state_tracking", "A3-DeltaNet.csv"))


def _clean_sorted_xy(df_run: pd.DataFrame) -> tuple[np.ndarray, np.ndarray]:
    """Return a run's (sequence_length, accuracy) arrays, cleaned and sorted by length.

    Non-numeric entries are coerced to NaN and every row where either column is
    NaN is dropped. Both arrays are empty when nothing usable remains.
    """
    x = pd.to_numeric(df_run["sequence_length"], errors="coerce").to_numpy()
    y = pd.to_numeric(df_run["accuracy"], errors="coerce").to_numpy()
    keep = ~(np.isnan(x) | np.isnan(y))
    x, y = x[keep], y[keep]
    order = np.argsort(x)
    return x[order], y[order]


def _draw_runs(
    ax: plt.Axes,
    df: pd.DataFrame,
    run_ids: dict[str, str],
    *,
    order: list[str],
    colors: dict[str, str],
    linestyle: str,
    zorder: int,
    show_labels: bool,
) -> None:
    """Draw one accuracy curve per setting in ``order`` that has a selected run."""
    for label in order:
        run_id = run_ids.get(label)
        if run_id is None:
            continue
        x, y = _clean_sorted_xy(df[df["id"] == run_id])
        if x.size == 0:
            continue
        ax.plot(
            x,
            y,
            color=colors.get(label, "C0"),
            # "_nolegend_" keeps the artist out of any automatic legend.
            label=label if show_labels else "_nolegend_",
            linestyle=linestyle,
            lw=2,
            alpha=0.7,
            zorder=zorder,
        )


def _plot_dataset(
    ax: plt.Axes,
    df: pd.DataFrame,
    train_len: int,
    title: str,
    show_ylabel: bool = True,
    df_overlay: pd.DataFrame | None = None,
    main_arch: str = "GLA",
    overlay_arch: str | None = None,
) -> None:
    """Plot the best run of each positional-encoding setting on ``ax``.

    Args:
        ax: Axes to draw on.
        df: Rows with ``name``, ``id``, ``sequence_length`` and ``accuracy``
            columns; the best run per setting is chosen by
            ``_select_best_runs``.
        train_len: x position of the dashed vertical line marking the training
            context length.
        title: Axes title.
        show_ylabel: Whether to label the y axis "Accuracy" (otherwise the
            label is cleared).
        df_overlay: Optional second dataset drawn as dashed, unlabelled curves
            beneath the main ones.
        main_arch: Accepted for call-site readability; not used by the body.
        overlay_arch: Accepted for call-site readability; not used by the body.

    Raises:
        RuntimeError: If no run in ``df`` maps to a known setting.
    """
    order = [NOPE_LABEL, ROPE_LABEL, SELECTIVE_LABEL]

    best_ids = _select_best_runs(df)
    if not best_ids:
        raise RuntimeError("Could not identify any best runs for settings")

    # Colors to match Figure 1 theme from state_tracking.py
    colors = {
        SELECTIVE_LABEL: "#E41A1C",
        ROPE_LABEL: "#FF7F00",
        NOPE_LABEL: "black",
    }

    _draw_runs(
        ax,
        df,
        best_ids,
        order=order,
        colors=colors,
        linestyle="solid",
        zorder=3,
        show_labels=True,
    )

    # Optional overlay (e.g., Transformer softmax): dashed lines without
    # duplicating labels. The overlay shares colors, width and alpha with the
    # main curves, so the dashed style is what tells the two apart.
    if df_overlay is not None:
        _draw_runs(
            ax,
            df_overlay,
            _select_best_runs(df_overlay),
            order=order,
            colors=colors,
            linestyle="dashed",
            zorder=2,
            show_labels=False,
        )

    # Vertical dashed line at training context length
    ax.axvline(train_len, color="k", linestyle="--", linewidth=1.6, dashes=(6, 6))

    ax.set_title(title)
    ax.set_xlabel("Sequence Length")
    ax.set_ylabel("Accuracy" if show_ylabel else "")
    ax.set_ylim(0.0, 1.02)

    # x range follows the main dataset, clamped so the left edge is never
    # negative and the right edge always reaches the 512 tick.
    all_x = pd.to_numeric(df["sequence_length"], errors="coerce").to_numpy()
    if np.isfinite(all_x).any():
        xmin = np.nanmin(all_x)
        xmax = np.nanmax(all_x)
        ax.set_xlim(left=max(0, xmin), right=max(512, xmax))
    ax.set_xticks([128, 512])

    # Horizontal grid only, plus faint fixed guide lines at x=256 and y=0.5.
    ax.grid(True, axis="y", linestyle="--", linewidth=1.1, alpha=0.35)
    ax.axvline(256, color="0.5", linestyle="--", linewidth=1.1, alpha=0.35, zorder=0)
    ax.axhline(0.5, color="0.5", linestyle="--", linewidth=1.1, alpha=0.35, zorder=0)

    # Open frame: hide the top/right spines and thicken the remaining two.
    ax.tick_params(axis="both", which="both", width=1)
    for spine in ("top", "right"):
        ax.spines[spine].set_visible(False)
    for spine in ("bottom", "left"):
        ax.spines[spine].set_linewidth(1.5)


def main() -> None:
    """Render the three-panel state-tracking figure and save it to disk.

    Command-line options:
        --train_len: x position of the training-length marker (default 128).
        --dpi: Resolution passed to ``savefig`` (default 500).
        --ext: Extension of the default output file; ignored when ``--out``
            is given.
        --out: Explicit output path; defaults to
            ``combined_state_tracking_deltanet.<ext>`` next to this script.
        --show: Also open an interactive window after saving.
    """
    import warnings

    parser = argparse.ArgumentParser(description="Combined plot: S2 GLA and A3 DeltaNet")
    parser.add_argument("--train_len", type=int, default=128)
    parser.add_argument("--dpi", type=int, default=500)
    parser.add_argument("--ext", type=str, default="pdf", choices=["pdf", "png", "svg"])
    parser.add_argument("--out", type=str, default="")
    parser.add_argument("--show", action="store_true")
    args = parser.parse_args()

    df_s2 = _load_csv_s2_gla()
    df_s2_softmax = _load_csv_s2_softmax()
    df_a3 = _load_csv_a3_deltanet()

    # Figure with three subplots: GLA (S2), Transformer (S2), DeltaNet (A3)
    fig, axes = plt.subplots(1, 3, figsize=(4.2, 1.6), constrained_layout=True, sharey=True)
    ax_left, ax_mid, ax_right = axes

    _plot_dataset(
        ax_left,
        df_s2,
        train_len=args.train_len,
        title=r"GLA: Group $S_2$",
        show_ylabel=True,
        main_arch="GLA",
    )
    _plot_dataset(
        ax_mid,
        df_s2_softmax,
        train_len=args.train_len,
        title=r"Transformer: Group $S_2$",
        show_ylabel=False,
        main_arch="Transformer",
    )
    _plot_dataset(
        ax_right,
        df_a3,
        train_len=args.train_len,
        title=r"DeltaNet: Group $A_3$",
        show_ylabel=False,
        main_arch="DeltaNet",
    )

    # Shared legend across figure (single row). Proxy handles are used so the
    # legend does not depend on which curves were actually drawn per panel.
    legend_entries = [
        ("#E41A1C", r"$\mathit{Selective\ RoPE}$"),
        ("#FF7F00", ROPE_LABEL),
        ("black", NOPE_LABEL),
    ]
    handles = [
        Line2D([0], [0], color=color, marker=None, linestyle="solid", lw=2, label=text)
        for color, text in legend_entries
    ]
    legend_settings = fig.legend(
        handles,
        [h.get_label() for h in handles],
        loc="lower center",
        bbox_to_anchor=(0.5, -0.14),
        ncol=3,
        frameon=False,
        handlelength=0.8,
        handletextpad=0.6,
        columnspacing=1.4,
    )

    # Make the "Selective RoPE" legend entry italic (Times New Roman italic).
    # This is cosmetic, so a failure is reported as a warning instead of
    # aborting the run -- but it is no longer swallowed silently.
    try:
        for txt in legend_settings.get_texts():
            if "Selective" in txt.get_text():
                txt.set_fontstyle("italic")
                txt.set_fontfamily("Times New Roman")
                txt.set_fontweight("normal")
    except (AttributeError, ValueError) as exc:
        warnings.warn(f"Could not italicize the legend entry: {exc}", stacklevel=2)

    if args.out:
        out_path = Path(args.out)
    else:
        out_path = Path(__file__).parent / f"combined_state_tracking_deltanet.{args.ext}"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Save through the figure object so the output never depends on which
    # figure pyplot currently considers active.
    fig.savefig(out_path, dpi=args.dpi)
    if args.show:
        plt.show()
    plt.close(fig)


# Run the plotting CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()


