import re
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scienceplots  # noqa
import git

# Top-level directory of the enclosing git work tree, obtained from
# `git rev-parse --show-toplevel`; parent directories are searched for the
# repository.
repo_root = git.Repo(search_parent_directories=True).git.rev_parse(
    "--show-toplevel"
)

# Input CSV read by the plotting script below, located relative to the
# repository root.
CSV_PATH = Path(repo_root).joinpath("plotting/copying/copy_la.csv")


# Global Matplotlib styling, applied once at import time.
plt.style.use(["science", "light"])  # Match state_tracking style
plt.rcParams.update(
    {
        "figure.constrained_layout.use": True,
        # Text uses the generic "serif" family. This is the single effective
        # value: a preceding assignment of "Times New Roman" to font.family
        # was overwritten before it could take effect and has been dropped.
        # NOTE(review): if regular text is meant to be Times New Roman too,
        # that has to go through "font.serif" -- confirm the intended look.
        "font.family": "serif",
        # Math text is pinned to Times New Roman variants.
        "mathtext.fontset": "custom",
        "mathtext.rm": "Times New Roman",
        "mathtext.it": "Times New Roman:italic",
        "mathtext.bf": "Times New Roman:bold",
        # Light grid on by default.
        "axes.grid": True,
        "grid.alpha": 0.3,
        # Font sizes for labels, titles, ticks and legend.
        "axes.labelsize": 8,
        "axes.titlesize": 8,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        "legend.fontsize": 7,
        # No tick marks on the top and right edges.
        "xtick.top": False,
        "ytick.right": False,
    }
)


def _extract_base_and_seed(column: str):
    """Split a metric column name into its run name and seed.

    The expected shape is "<run-name>-<seed> - test/mean_char_acc", optionally
    followed by a "__"-prefixed suffix such as "__MIN" or "__MAX".

    Returns ``(run_name, seed)`` where the run name has the trailing
    "-<seed>" removed and the seed is an ``int``. Returns ``(None, None)``
    when the column name does not have that shape.
    """
    found = re.match(r"(.+)-(\d+)\s+-\s+test/mean_char_acc(?:__.*)?$", column)
    if found is not None:
        # The first group is greedy, so only the last "-<digits>" is the seed.
        run_name, seed_digits = found.groups()
        return run_name, int(seed_digits)
    return None, None


def _display_name(base_run: str) -> str:
    """Translate a seed-less run name into the label used in the plot.

    Names starting with "linear_attn" become "LA". Names starting with "gla"
    become "Selective RoPE", "RoPE" or "NoPE" depending on whether they
    contain "selective_rope", contain "-rope", or neither. Any other name is
    returned unchanged.
    """
    if base_run.startswith("linear_attn"):
        return "LA"
    if not base_run.startswith("gla"):
        return base_run
    # "selective_rope" is tested before "-rope" so it takes precedence when a
    # name contains both.
    if "selective_rope" in base_run:
        return "Selective RoPE"
    return "RoPE" if "-rope" in base_run else "NoPE"


def _collect_groups(df: pd.DataFrame):
    """Bucket the per-seed accuracy columns of ``df`` by display label.

    A column is kept when it is not "Step", contains
    " - test/mean_char_acc", contains no "__" (which drops the __MIN/__MAX
    helper columns), and parses as "<run-name>-<seed> - ...". Kept columns
    are grouped under the label derived from their run name.

    Returns a dict mapping label -> list of column names, in column order.
    """
    metric_tag = " - test/mean_char_acc"
    by_label: dict[str, list[str]] = {}
    for name in df.columns:
        is_seed_metric = (
            name != "Step" and metric_tag in name and "__" not in name
        )
        if not is_seed_metric:
            continue
        run_name, _ = _extract_base_and_seed(name)
        if run_name is not None:
            by_label.setdefault(_display_name(run_name), []).append(name)
    return by_label


def _compute_stats(df: pd.DataFrame, cols: list[str]):
    """Aggregate the given columns row by row across seeds.

    Each column is coerced to numeric (unparseable entries become NaN) and
    the stack is cast to float so integer-only columns can still hold NaN.

    Returns ``(y_mean, y_min, y_max, has_any)``, each of length ``len(df)``:

    * ``y_mean`` -- mean over the non-NaN entries of the row, NaN if none.
    * ``y_min`` / ``y_max`` -- extrema over the finite entries of the row
      (NaN and +/-inf are masked out), NaN if none.
    * ``has_any`` -- boolean mask, True where at least one entry is not NaN.

    With an empty ``cols`` the three value arrays are all-NaN and the mask is
    all False.
    """
    n_rows = len(df)
    if not cols:
        # Three independent arrays so callers can modify one safely.
        return (
            np.full(n_rows, np.nan),
            np.full(n_rows, np.nan),
            np.full(n_rows, np.nan),
            np.zeros(n_rows, dtype=bool),
        )
    # Cast to float: an all-integer stack cannot be filled with NaN below.
    arr = np.column_stack(
        [pd.to_numeric(df[c], errors="coerce").to_numpy() for c in cols]
    ).astype(float)
    # Compute row-wise means without raising warnings for all-NaN rows
    valid_counts = np.sum(~np.isnan(arr), axis=1)
    has_any = valid_counts > 0
    sums = np.nansum(arr, axis=1)
    y_mean = np.divide(
        sums,
        valid_counts,
        out=np.full(arr.shape[0], np.nan, dtype=float),
        where=has_any,
    )
    # Use masked arrays for min/max to avoid all-NaN warnings
    masked = np.ma.masked_invalid(arr)
    y_min = masked.min(axis=1).filled(np.nan)
    y_max = masked.max(axis=1).filled(np.nan)
    return y_mean, y_min, y_max, has_any


def _derive_seq_len_for_group(
    series_steps: pd.Series, valid_mask: np.ndarray, start_len: int = 64
) -> np.ndarray:
    """Map wandb Steps to sequence lengths for a specific model group.

    The steps are shifted so that the first row where ``valid_mask`` is True
    maps to ``start_len`` (64 by default; the plotting script passes 0). This
    handles groups that started logging at different absolute steps (e.g.,
    ~5k vs ~15k). If no row is valid, the first step is used as the anchor.

    Steps are coerced to numeric, so unparseable entries become NaN.

    Returns an array with one entry per step, or an empty array when there
    are no steps at all.
    """
    steps = pd.to_numeric(series_steps, errors="coerce").to_numpy()
    if steps.size == 0:
        # No rows means no anchor; indexing steps[0] would raise IndexError.
        return np.empty(0, dtype=float)
    anchor_idx = int(np.argmax(valid_mask)) if np.any(valid_mask) else 0
    return (steps - steps[anchor_idx]) + start_len


def main() -> None:
    """Plot copying accuracy against sequence length and save it as a PDF.

    Reads ``CSV_PATH``, groups the per-seed "test/mean_char_acc" columns by
    model variant, draws the across-seed mean of each variant with a shaded
    min-max band, and writes "gla_copying.pdf" next to this script.
    """
    show_la = False  # Set to True to also plot the linear-attention ("LA") runs
    train_len = 64  # Training length, marked by the dashed vertical line
    x_lim = 160  # Right limit of the x axis

    df = pd.read_csv(CSV_PATH)
    groups = _collect_groups(df)

    # Preferred plotting order; LA is dropped unless explicitly enabled.
    order = ["LA", "NoPE", "RoPE", "Selective RoPE"]
    if not show_la:
        order = [label for label in order if label != "LA"]

    # Raw "Step" column; numeric coercion happens in the helper.
    step_series = df["Step"]

    # Colors to match state_tracking theme
    colors = {
        "Selective RoPE": "#E41A1C",  # red
        "RoPE": "#FF7F00",  # orange
        "NoPE": "black",
        "LA": "#377eb8",  # only used when show_la is True
    }

    fig, ax = plt.subplots(figsize=(1.85, 1.8), constrained_layout=True)

    for label in order:
        if label not in groups:
            continue
        y, y_lo, y_hi, has_any = _compute_stats(df, groups[label])
        # With start_len=0 the x value is the step offset from the first row
        # at which this group has data.
        x_seq_len = _derive_seq_len_for_group(step_series, has_any, start_len=0)
        mask = has_any & ~np.isnan(y)
        if not mask.any():
            continue
        x = x_seq_len[mask]
        color = colors.get(label, "C0")

        ax.plot(
            x,
            y[mask],
            lw=2,
            color=color,
            label=(
                r"$\mathit{Selective\ RoPE}$"
                if label == "Selective RoPE"
                else label
            ),
            alpha=0.7,
            zorder=3,
        )
        # Shaded band between the per-row min and max across seeds.
        ax.fill_between(
            x,
            y_lo[mask],
            y_hi[mask],
            color=color,
            alpha=0.30,
            linewidth=0,
        )

    # Vertical dashed line at training length
    ax.axvline(train_len, color="k", linestyle="--", linewidth=1.6, dashes=(6, 6))

    ax.set_xlabel("Sequence Length")
    ax.set_ylabel("Accuracy")
    ax.set_title("GLA: String copying")
    ax.set_ylim(0.0, 1.02)
    ax.set_xlim(left=32, right=x_lim)
    # Grid style to match state_tracking
    ax.grid(True, axis="y", linestyle="--", linewidth=1.1, alpha=0.35)
    # Helper guide lines at twice the training length and at y=0.5
    ax.axvline(
        train_len * 2,
        color="0.5",
        linestyle="--",
        linewidth=1.1,
        alpha=0.35,
        zorder=0,
    )
    ax.axhline(0.5, color="0.5", linestyle="--", linewidth=1.1, alpha=0.35, zorder=0)

    def _legend_rank(text: str) -> int:
        # Legend order: Selective RoPE, RoPE, NoPE, then anything else (e.g. LA).
        if "Selective" in text:
            return 0
        if text == "RoPE":
            return 1
        if text == "NoPE":
            return 2
        return 3

    all_handles, all_labels = ax.get_legend_handles_labels()
    # sorted() is stable, so entries with equal rank keep their plot order.
    ordered = sorted(
        zip(all_handles, all_labels), key=lambda pair: _legend_rank(pair[1])
    )
    ordered_handles = [handle for handle, _ in ordered]
    ordered_labels = [text for _, text in ordered]

    leg = ax.legend(
        ordered_handles,
        ordered_labels,
        framealpha=0.95,
        bbox_to_anchor=(0.4, -0.3),
        loc="upper center",
        ncol=3,
        handlelength=0.5,
        handletextpad=0.4,
        columnspacing=0.8,
        fancybox=True,
        edgecolor="black",
    )

    # Thinner lines in the legend than in the plot. Look up the current
    # attribute name first; reading the old "legendHandles" name first would
    # emit a deprecation warning on Matplotlib versions that have both.
    legend_lines = getattr(leg, "legend_handles", None)
    if legend_lines is None:
        legend_lines = getattr(leg, "legendHandles", None) or []
    for handle in legend_lines:
        # Best effort: skip handle types without a line width.
        if hasattr(handle, "set_linewidth"):
            handle.set_linewidth(1.5)

    # Legend selective entry italic (Times New Roman italic)
    for txt in leg.get_texts():
        if "Selective" in txt.get_text():
            txt.set_fontstyle("italic")
            txt.set_fontfamily("Times New Roman")
            txt.set_fontweight("normal")

    # Spine styling to match state_tracking
    ax.tick_params(axis="both", which="both", width=1)
    for spine in ("top", "right"):
        ax.spines[spine].set_visible(False)
    for spine in ("bottom", "left"):
        ax.spines[spine].set_linewidth(1.5)

    out_path = Path(__file__).resolve().parent / "gla_copying.pdf"
    fig.savefig(out_path, dpi=500)


if __name__ == "__main__":
    main()
