import argparse
import re
from pathlib import Path

import git
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scienceplots  # noqa: F401


# Top-level directory of the enclosing Git working tree, as reported by
# `git rev-parse --show-toplevel`. Resolved once at import time via GitPython,
# which searches parent directories for the repository. `_load_csv` builds its
# CSV paths relative to this directory.
repo_root = Path(
    git.Repo(search_parent_directories=True).git.rev_parse("--show-toplevel")
)


# Plot style to match plot_zoology.ipynb (Times New Roman, scienceplots light).
# These are global, import-time side effects: every figure created after this
# module is imported inherits the settings below.
plt.style.use(["science", "light"])  # type: ignore[attr-defined]
plt.rcParams["figure.constrained_layout.use"] = True
# NOTE(review): this "font.family" value is overwritten by the "serif"
# assignment a few lines below, so on its own it has no effect. Confirm whether
# Times New Roman was meant to be selected via "font.serif" instead.
plt.rcParams["font.family"] = "Times New Roman"
# Math text uses a custom font set mapped onto Times New Roman variants.
plt.rcParams["mathtext.fontset"] = "custom"
plt.rcParams["mathtext.rm"] = "Times New Roman"
plt.rcParams["mathtext.it"] = "Times New Roman:italic"
plt.rcParams["mathtext.bf"] = "Times New Roman:bold"
# Effective font family (replaces the "Times New Roman" value set above).
plt.rcParams["font.family"] = "serif"
# Light grid on by default.
plt.rcParams["axes.grid"] = True
plt.rcParams["grid.alpha"] = 0.3
# Small font sizes, sized for the small figure created in _plot_best_runs.
plt.rcParams["axes.labelsize"] = 8
plt.rcParams["axes.titlesize"] = 8
plt.rcParams["xtick.labelsize"] = 8
plt.rcParams["ytick.labelsize"] = 8
plt.rcParams["legend.fontsize"] = 7
# No tick marks on the top and right sides of the axes.
plt.rcParams["xtick.top"] = False
plt.rcParams["ytick.right"] = False


# Display labels for the three positional-encoding settings. They are returned
# by _infer_setting_from_name, used as keys when selecting the best run per
# setting, and shown verbatim as legend entries.
SELECTIVE_LABEL = "Selective RoPE"
ROPE_LABEL = "RoPE"
NOPE_LABEL = "NoPE"


def _infer_setting_from_name(run_name: str) -> str | None:
    name = run_name.lower()
    # Prefer precise regex to extract the token following lr_...-
    m = re.search(r"s2-(?:gla|transformer)-lr_[^-]+-([a-z_]+)-", name)
    if m is not None:
        token = m.group(1)
        if token == "selective_rope":
            return SELECTIVE_LABEL
        if token == "nope":
            return NOPE_LABEL
        if token == "rope":
            return ROPE_LABEL
    # Fallback heuristics
    if "selective_rope" in name:
        return SELECTIVE_LABEL
    if "-nope-" in name:
        return NOPE_LABEL
    if "-rope-" in name:
        return ROPE_LABEL
    return None


def _compute_auc_for_run(df_run: pd.DataFrame) -> float:
    """Return the area under the accuracy-vs-sequence-length curve of one run.

    Args:
        df_run: Rows of a single run; must contain ``sequence_length`` and
            ``accuracy`` columns. Values that cannot be parsed as numbers are
            treated as missing.

    Returns:
        The trapezoidal-rule integral of accuracy over sequence length, after
        dropping rows with a missing value in either column and sorting by
        sequence length. ``-inf`` when no usable row remains, so such a run
        never beats a run that has data.
    """
    # Force float arrays so np.isnan is well-defined even for empty/object columns.
    x = pd.to_numeric(df_run["sequence_length"], errors="coerce").to_numpy(dtype=float)
    y = pd.to_numeric(df_run["accuracy"], errors="coerce").to_numpy(dtype=float)

    valid = ~(np.isnan(x) | np.isnan(y))
    x, y = x[valid], y[valid]
    if x.size == 0:
        return -np.inf

    order = np.argsort(x)

    # np.trapezoid only exists in NumPy >= 2.0; older releases call it np.trapz.
    # Prefer the new name so NumPy 2.x never touches the deprecated alias.
    integrate = getattr(np, "trapezoid", None)
    if integrate is None:
        integrate = np.trapz
    return float(integrate(y[order], x[order]))


def _load_csv(backend: str) -> pd.DataFrame:
    """Read the S2 results CSV that belongs to ``backend``.

    Args:
        backend: ``"gla"`` or ``"softmax"``; surrounding whitespace and letter
            case are ignored.

    Returns:
        The CSV under ``<repo_root>/plotting/state_tracking`` as a DataFrame.

    Raises:
        ValueError: If ``backend`` is neither of the two supported values.
    """
    csv_by_backend = {"gla": "S2-GLA.csv", "softmax": "S2-Softmax.csv"}
    csv_name = csv_by_backend.get(backend.strip().lower())
    if csv_name is None:
        raise ValueError("backend must be 'gla' or 'softmax'")
    return pd.read_csv(repo_root / "plotting" / "state_tracking" / csv_name)


def _select_best_runs(df: pd.DataFrame) -> dict[str, str]:
    """Pick, for every recognised setting, the run id with the largest AUC.

    Each row is tagged with the setting inferred from its ``name`` column;
    rows whose setting is not one of the three known labels are discarded.
    The remaining rows are grouped by ``(setting, id)`` and scored with
    ``_compute_auc_for_run``. On equal scores the run that comes first in
    group order wins.

    Returns:
        Mapping of setting label -> best run id; empty if no row could be
        assigned a known setting.
    """
    known_settings = {SELECTIVE_LABEL, ROPE_LABEL, NOPE_LABEL}

    # assign() returns a new frame, so the caller's DataFrame is left untouched.
    tagged = df.assign(setting=df["name"].apply(_infer_setting_from_name))
    tagged = tagged[tagged["setting"].isin(known_settings)]
    if tagged.empty:
        return {}

    # setting -> {run id -> AUC}, filled in group order.
    auc_table: dict[str, dict[str, float]] = {}
    for (setting, run_id), rows in tagged.groupby(["setting", "id"]):
        auc_table.setdefault(setting, {})[run_id] = _compute_auc_for_run(rows)

    # max() keeps the first maximal entry, matching a strict ">" update rule.
    return {
        setting: max(run_scores, key=run_scores.get)
        for setting, run_scores in auc_table.items()
    }


def _clean_xy(df_run: pd.DataFrame) -> tuple[np.ndarray, np.ndarray]:
    """Return a run's (sequence_length, accuracy) arrays, NaN-free and sorted by x."""
    x = pd.to_numeric(df_run["sequence_length"], errors="coerce").to_numpy()
    y = pd.to_numeric(df_run["accuracy"], errors="coerce").to_numpy()
    keep = ~np.isnan(x) & ~np.isnan(y)
    x, y = x[keep], y[keep]
    idx = np.argsort(x)
    return x[idx], y[idx]


def _add_stacked_legends(fig: "plt.Figure", ax: "plt.Axes") -> None:
    """Add two figure-level legend rows below the axes (Selective RoPE; NoPE + RoPE)."""
    handles, labels = ax.get_legend_handles_labels()
    handle_by_label = dict(zip(labels, handles))

    shared_kwargs = {
        "loc": "lower center",
        "frameon": False,
        "handlelength": 1.7,
        "handletextpad": 0.6,
        "columnspacing": 1.0,
    }
    # (labels in the row, vertical anchor in figure coordinates)
    rows = (
        ([SELECTIVE_LABEL], -0.14),
        ([NOPE_LABEL, ROPE_LABEL], -0.24),
    )
    for row_labels, y_anchor in rows:
        present = [name for name in row_labels if handle_by_label.get(name) is not None]
        if not present:
            continue
        fig.legend(
            [handle_by_label[name] for name in present],
            present,
            bbox_to_anchor=(0.55, y_anchor),
            ncol=len(row_labels),
            **shared_kwargs,
        )


def _plot_best_runs(df: pd.DataFrame, train_len: int, out_path: Path, show: bool, dpi: int) -> None:
    """Plot accuracy vs. sequence length for the best run of each setting.

    The best run per setting comes from ``_select_best_runs``. Curves are drawn
    in the order NoPE, RoPE, Selective RoPE, with a dashed vertical line at the
    training context length. The figure is written to ``out_path`` (parent
    directories are created as needed) and is always closed afterwards, even
    if drawing or saving fails.

    Args:
        df: Results table with ``name``, ``id``, ``sequence_length`` and
            ``accuracy`` columns.
        train_len: x position of the training-length marker line.
        out_path: Destination file for the saved figure.
        show: If true, also call ``plt.show()`` after saving.
        dpi: Resolution passed to ``savefig``.

    Raises:
        RuntimeError: If no best run could be identified for any setting.
    """
    best_ids = _select_best_runs(df)
    if not best_ids:
        raise RuntimeError("Could not identify any best runs for settings")

    # Colors to match Figure 1 theme:
    # Selective RoPE = red, RoPE = orange, NoPE (GLA) = black
    colors = {
        SELECTIVE_LABEL: "#E41A1C",
        ROPE_LABEL: "#FF7F00",
        NOPE_LABEL: "black",
    }
    draw_order = [NOPE_LABEL, ROPE_LABEL, SELECTIVE_LABEL]

    fig, ax = plt.subplots(figsize=(1.7, 1.6), constrained_layout=True)
    try:
        for label in draw_order:
            run_id = best_ids.get(label)
            if run_id is None:
                continue
            x, y = _clean_xy(df[df["id"] == run_id])
            if x.size == 0:
                continue

            # Subsample markers for Selective RoPE to avoid a "solid markers"
            # look: aim for ~20 markers across the curve.
            marker_every = None
            if label == SELECTIVE_LABEL:
                n_points = int(x.size)
                marker_every = slice(0, n_points, max(1, n_points // 20))

            ax.plot(
                x,
                y,
                color=colors[label],
                label=label,
                linestyle="solid",
                marker="o",
                markersize=2.5,
                lw=1.2,
                alpha=0.7,
                zorder=3,
                markevery=marker_every,
            )

        # Vertical dashed line at training context length
        ax.axvline(train_len, color="k", linestyle="--", linewidth=1.6, dashes=(6, 6))

        ax.set_title(r"Group $S_2$")
        ax.set_xlabel("Sequence Length")
        ax.set_ylabel("Accuracy")
        ax.set_ylim(0.0, 1.02)

        # x-limits follow the data range of the whole table but always include
        # the tick at 512; non-finite values are ignored so they cannot become
        # an axis limit.
        all_x = pd.to_numeric(df["sequence_length"], errors="coerce").to_numpy()
        finite_x = all_x[np.isfinite(all_x)]
        if finite_x.size:
            ax.set_xlim(left=max(0, finite_x.min()), right=max(512, finite_x.max()))
        # Show only ticks at 128 and 512, matching the requested appearance
        ax.set_xticks([128, 512])

        # Only horizontal gridlines by default (avoids a vertical gridline at
        # x=128), plus helper gridlines at x=256 and y=0.5.
        ax.grid(True, axis="y", linestyle="--", linewidth=1.1, alpha=0.35)
        ax.axvline(256, color="0.5", linestyle="--", linewidth=1.1, alpha=0.35, zorder=0)
        ax.axhline(0.5, color="0.5", linestyle="--", linewidth=1.1, alpha=0.35, zorder=0)

        # NOTE(review): the legends are anchored below the figure (negative y in
        # figure coordinates), so they only appear in the saved file when
        # savefig uses a tight bounding box -- presumably provided by the
        # "science" style; confirm.
        _add_stacked_legends(fig, ax)

        # Match spine visibility/weights from Figure 1
        ax.tick_params(axis="both", which="both", width=1)
        for side in ("top", "right"):
            ax.spines[side].set_visible(False)
        for side in ("bottom", "left"):
            ax.spines[side].set_linewidth(1.5)

        out_path.parent.mkdir(parents=True, exist_ok=True)
        # Save this figure explicitly rather than whatever pyplot considers current.
        fig.savefig(out_path, dpi=dpi)
        if show:
            plt.show()
    finally:
        # Release the figure even when drawing or saving raised.
        plt.close(fig)


def main() -> None:
    """Command-line entry point: load one backend's CSV and plot its best runs.

    Options: ``--backend`` (which CSV), ``--train_len`` (marker line position),
    ``--dpi``, ``--ext`` (extension of the default output name), ``--out``
    (explicit output path) and ``--show``. Without ``--out`` the figure is
    written next to this script as ``state_tracking_<backend>.<ext>``.
    """
    cli = argparse.ArgumentParser(description="Plot S2 state tracking best-run curves")
    cli.add_argument(
        "--backend",
        default="gla",
        type=str,
        choices=["gla", "softmax", "GLA", "Softmax"],
        help="Which CSV to use: GLA or Softmax",
    )
    # The two integer options differ only in flag name and default.
    for flag, default in (("--train_len", 128), ("--dpi", 500)):
        cli.add_argument(flag, type=int, default=default)
    cli.add_argument("--ext", type=str, default="pdf", choices=["pdf", "png", "svg"])
    cli.add_argument("--out", type=str, default="")
    cli.add_argument("--show", action="store_true")
    opts = cli.parse_args()

    frame = _load_csv(opts.backend)

    if opts.out:
        target = Path(opts.out)
    else:
        # Default output lives beside the script, named after the backend.
        default_name = f"state_tracking_{opts.backend.lower()}.{opts.ext}"
        target = Path(__file__).parent / default_name

    _plot_best_runs(
        frame,
        train_len=opts.train_len,
        out_path=target,
        show=bool(opts.show),
        dpi=opts.dpi,
    )


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()


