#!/usr/bin/env python3
import argparse
import json
from pathlib import Path

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

from scipy.stats import spearmanr, rankdata
import re

# Global font setup (applied at import time): use a sans-serif family and list
# several CJK-named fonts first, with DejaVu Sans last as the fallback.
# Presumably this is so token labels containing non-Latin characters render —
# confirm against the fonts installed on the target machine.
mpl.rcParams["font.family"] = "sans-serif"
mpl.rcParams["font.sans-serif"] = [
    "Noto Sans CJK SC",
    "Source Han Sans CN",
    "Microsoft YaHei",
    "SimHei",
    "PingFang SC",
    "Heiti SC",
    "Arial Unicode MS",
    "DejaVu Sans",
]
# Draw negative tick labels with an ASCII hyphen instead of the Unicode minus
# sign, which some of the fonts above may not provide.
mpl.rcParams["axes.unicode_minus"] = False

# Series name -> (npz key of the probability matrix, npz key of the token strings).
# The matching token-id array is looked up separately as f"token_{name}".
SERIES_MAP = {
    "alphamax": ("probs_alphamax", "token_strs_alphamax"),
    "alphamin": ("probs_alphamin", "token_strs_alphamin"),
}
# Shared rcParams overrides; applied via plt.rcParams.update(PLOT_STYLE) by
# the plotting functions that want the "regular" (non-poster) font sizes.
PLOT_STYLE = {
    "font.size": 14,
    "axes.labelsize": 16,
    "axes.titlesize": 18,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "legend.fontsize": 14,
}

def check_margin_prob_consistency(plot_npz_path: str, margin_npz_path: str, eps: float = 1e-3, prob_tol: float = 0.05):
    """
    Flag token pairs with near-equal margins but clearly different probabilities.

    Tokens present in both files are paired up; every pair whose margins differ
    by <= eps is counted, and the pair is recorded as inconsistent when the
    probabilities (last row of ``probs_alphamax``) differ by more than prob_tol.

    Args:
      plot_npz_path: .npz with 'token_alphamax', 'token_strs_alphamax' and
                     'probs_alphamax' (last row is used).
      margin_npz_path: .npz with 'token_ids' and 'margins'.
      eps: maximum margin difference for two tokens to count as "same margin".
      prob_tol: probability difference above which a pair is reported.

    Returns:
      dict with the same keys on every path:
        kept_tokens            - number of tokens found in both files
        pairs_within_eps       - number of pairs with margin difference <= eps
        total_pairs_within_eps - alias of pairs_within_eps (legacy key that the
                                 short-circuit return used to expose)
        inconsistent_pairs     - list of dicts describing each offending pair
    """
    # NOTE: allow_pickle=True executes pickle on load; only use on trusted files.
    # The context managers close the underlying zip handles once the arrays
    # have been materialised.
    with np.load(plot_npz_path, allow_pickle=True) as p:
        plot_ids = p["token_alphamax"]
        plot_strs = p["token_strs_alphamax"]
        plot_probs = p["probs_alphamax"][-1]  # last alpha row = alpha_max

    with np.load(margin_npz_path, allow_pickle=True) as m:
        margin_lookup = {int(t): float(v) for t, v in zip(m["token_ids"], m["margins"])}

    rows = [
        (int(tid), str(tstr), margin_lookup[int(tid)], float(prob))
        for tid, tstr, prob in zip(plot_ids, plot_strs, plot_probs)
        if int(tid) in margin_lookup
    ]

    # Sorting by margin lets the inner loop stop at the first partner that is
    # more than eps away: every later row is even further.
    rows.sort(key=lambda row: row[2])
    inconsistent = []
    pairs_within_eps = 0
    n = len(rows)
    for i in range(n):
        tid_i, tstr_i, m_i, p_i = rows[i]
        for j in range(i + 1, n):
            tid_j, tstr_j, m_j, p_j = rows[j]
            if m_j - m_i > eps:
                break
            pairs_within_eps += 1
            if abs(p_i - p_j) > prob_tol:
                inconsistent.append({
                    "token_i": {"id": tid_i, "str": tstr_i, "margin": m_i, "prob": p_i},
                    "token_j": {"id": tid_j, "str": tstr_j, "margin": m_j, "prob": p_j},
                    "margin_diff": m_j - m_i,
                    "prob_diff": abs(p_i - p_j),
                })

    # With fewer than two rows the loops above do nothing, so a single return
    # covers the degenerate case and keeps the result schema uniform.
    return {
        "kept_tokens": n,
        "pairs_within_eps": pairs_within_eps,
        "total_pairs_within_eps": pairs_within_eps,
        "inconsistent_pairs": inconsistent,
    }

def plot_margin_and_plot_tokens(margin_npz_path: str, plot_npz_path: str, top_n: int = 30):
    """
    Show two stacked bar charts and display them with plt.show().

      - Upper panel: the first N entries of 'margins' from the margins npz,
        labelled with 'token_strs'.
      - Lower panel: the first N entries of the last row of 'probs_alphamax'
        from the steering-plot npz, labelled with 'token_strs_alphamax'.

    No sorting is performed here; "top" relies on the stored order of the
    arrays (presumably best-first — verify against the producer).
    """
    margin_file = np.load(margin_npz_path, allow_pickle=True)
    margin_scores = margin_file["margins"]
    margin_strs = margin_file["token_strs"]
    margin_count = min(top_n, len(margin_scores))

    plot_file = np.load(plot_npz_path, allow_pickle=True)
    plot_strs = plot_file["token_strs_alphamax"]
    # The last alpha row of the [A, K] matrix is the alpha_max slice.
    final_row = plot_file["probs_alphamax"][-1]
    plot_count = min(top_n, len(final_row))

    # One tuple per panel: (bar count, heights, tick labels, colour, y label, title).
    panels = [
        (
            margin_count,
            margin_scores[:margin_count],
            [str(tok) for tok in margin_strs[:margin_count]],
            "#2563eb",
            "Margin",
            f"Top {margin_count} margin tokens",
        ),
        (
            plot_count,
            final_row[:plot_count],
            [str(tok) for tok in plot_strs[:plot_count]],
            "#f97316",
            "Prob (alpha_max)",
            f"Top {plot_count} plot tokens (alpha_max)",
        ),
    ]

    fig, axes = plt.subplots(2, 1, figsize=(12, 6), constrained_layout=True)
    for panel_ax, (count, heights, tick_names, colour, y_text, heading) in zip(axes, panels):
        positions = range(count)
        panel_ax.bar(positions, heights, color=colour)
        panel_ax.set_xticks(positions)
        panel_ax.set_xticklabels(tick_names, rotation=75, ha="right", fontsize=8)
        panel_ax.set_ylabel(y_text)
        panel_ax.set_title(heading)

    plt.show()


def spearman_margin_vs_plot(plot_npz_path: str, margin_npz_path: str, limit = -1) -> dict:
    """
    Compare the token ordering from steering_plot (token_alphamax) against the
    ordering induced by margins on those same tokens.
    - plot ordering: the order of token_alphamax (best first) → ranks 1..K
    - margin ordering: ranks derived from the margin scores for those tokens
    Only tokens present in both files are compared.
    """
    plot_npz = np.load(plot_npz_path, allow_pickle=True)
    plot_ids = plot_npz["token_alphamax"][:limit]

    margin_npz = np.load(margin_npz_path, allow_pickle=True)
    margin_ids = margin_npz["token_ids"]
    margin_vals = margin_npz["margins"]
    margin_lookup = {int(t): float(s) for t, s in zip(margin_ids, margin_vals)}

    aligned = np.array([margin_lookup.get(int(t), np.nan) for t in plot_ids], dtype=float)
    mask = ~np.isnan(aligned)
    if mask.sum() < 2:
        return {
            "rho": np.nan,
            "pval": np.nan,
            "overlap": int(mask.sum()),
            "n_plot_tokens": int(len(plot_ids)),
        }

    plot_rank = np.arange(1, mask.sum() + 1, dtype=int)

    margin_rank = rankdata(-aligned[mask], method="average")

    rho, pval = spearmanr(plot_rank, margin_rank)
    return {
        "rho": float(rho),
        "pval": float(pval),
        "overlap": int(mask.sum()),
        "n_plot_tokens": int(len(plot_ids)),
    }

def r2_logodds_vs_probs(
    plot_npz_path: str,
    margin_npz_path: str,
    which: str = "alphamax",
    limit: int | None = None,
    annotate: int = 10,
    save: str | None = None,
) -> dict:
    """
    Scatter token probabilities from steering_plot against margin log-odds and return R^2.

    The R^2, slope and intercept come from an ordinary least-squares line
    fitted to (probability, margin) in linear space; the x-axis of the scatter
    is drawn on a log scale and the fitted line itself is not drawn.

    Args:
      plot_npz_path: .npz produced by steering_plot_actor.py
      margin_npz_path: margins_topk.npz produced by generate_margin.py / MarginActor
      which: 'alphamax' or 'alphamin' token set to use (probabilities are taken at the corresponding extreme α)
      limit: optionally restrict to the first N tokens from the stored top-K (order preserved)
      annotate: number of points to label (tokens with largest |log-odds|)
      save: optional path to save the scatter; otherwise plt.show() is called

    Returns:
      dict with 'r2', 'slope', 'intercept' (NaN when fewer than two tokens
      overlap; 'r2' is also NaN when all margins are equal), 'overlap' and
      'n_plot_tokens'.

    Raises:
      ValueError: unknown `which`, or non-positive `limit`.
      KeyError: the plot npz lacks one of the required arrays.
    """
    if which not in SERIES_MAP:
        raise ValueError(f"'which' must be one of {list(SERIES_MAP)}")

    # NOTE: allow_pickle=True executes pickle on load; only use on trusted files.
    plot_npz = np.load(plot_npz_path, allow_pickle=True)
    margin_npz = np.load(margin_npz_path, allow_pickle=True)

    prob_key, tok_str_key = SERIES_MAP[which]
    tok_id_key = f"token_{which}"
    if tok_id_key not in plot_npz or prob_key not in plot_npz or tok_str_key not in plot_npz:
        raise KeyError(f"Missing '{tok_id_key}', '{prob_key}', or '{tok_str_key}' in {plot_npz_path}")

    probs = np.asarray(plot_npz[prob_key], dtype=float)  # [A, K]
    token_ids = np.asarray(plot_npz[tok_id_key], dtype=int)  # [K]
    token_strs = np.asarray(plot_npz[tok_str_key]).astype(object)  # [K]

    if limit is not None:
        limit = int(limit)
        if limit <= 0:
            raise ValueError("limit must be positive when provided")
        limit = min(limit, token_ids.shape[0])
        probs = probs[:, :limit]
        token_ids = token_ids[:limit]
        token_strs = token_strs[:limit]

    margins = np.asarray(margin_npz["margins"], dtype=float)
    margin_ids = np.asarray(margin_npz["token_ids"], dtype=int)
    margin_lookup = {int(t): float(m) for t, m in zip(margin_ids, margins)}

    # Probabilities are read at the extreme alpha matching the token set.
    if which == "alphamax":
        x_probs = probs[-1, :]  # α = max
        alpha_desc = "α=max"
    else:
        x_probs = probs[0, :]  # α = min
        alpha_desc = "α=min"

    # Keep only tokens that also have a margin entry.
    xs, ys, labels = [], [], []
    for tok_id, tok_str, p in zip(token_ids, token_strs, x_probs):
        m = margin_lookup.get(int(tok_id))
        if m is None:
            continue
        xs.append(float(p))
        ys.append(float(m))
        labels.append(str(tok_str))

    xs = np.asarray(xs, dtype=float)
    ys = np.asarray(ys, dtype=float)
    if xs.size < 2:
        # A line fit needs at least two points; no figure is produced here.
        return {
            "r2": np.nan,
            "slope": np.nan,
            "intercept": np.nan,
            "overlap": int(xs.size),
            "n_plot_tokens": int(len(token_ids)),
        }

    slope, intercept = np.polyfit(xs, ys, deg=1)
    y_pred = slope * xs + intercept
    ss_res = float(np.sum((ys - y_pred) ** 2))
    ss_tot = float(np.sum((ys - np.mean(ys)) ** 2))
    # R^2 is undefined when the margins have zero variance.
    r2 = float(1.0 - ss_res / ss_tot) if ss_tot > 0 else np.nan

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(xs, ys, alpha=0.7, s=36, label=f"overlap={xs.size}")
    ax.set_xlabel(f"token probability @ {alpha_desc}")
    ax.set_ylabel("margin log-odds")
    ax.grid(True, alpha=0.4)
    ax.legend()
    ax.set_xscale('log')

    if annotate and annotate > 0:
        # argsort is ascending, so the tail holds the largest |margin| points.
        top_idx = np.argsort(np.abs(ys))[-int(annotate) :]
        for idx in top_idx:
            ax.annotate(
                labels[idx],
                (xs[idx], ys[idx]),
                textcoords="offset points",
                xytext=(4, 4),
                fontsize=8,
            )

    fig.tight_layout()
    if save:
        save_path = Path(save)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
        # pyplot keeps every figure alive until it is closed; release it so
        # repeated calls (e.g. batch export) do not accumulate figures.
        plt.close(fig)
    else:
        plt.show()

    return {
        "r2": float(r2),
        "slope": float(slope),
        "intercept": float(intercept),
        "overlap": int(xs.size),
        "n_plot_tokens": int(len(token_ids)),
    }

def plot_npz(
    npz_path: str,
    which: str = "alphamax",
    n_labels: int = 6,
    limit: int | None = None,
    alpha_range: tuple[float, float] | None = None,
    title_suffix: str = "",
    save: str | None = None,
    symlog: bool = False,
    logit: bool = False,
    mode: str = "probs",
    zoom_alpha_range: tuple[float, float] | None = None,  # adds a zoom row above the main panels
    select_bumps: bool = False,
    rank_bumps_by_height: bool = False,
) -> None:
    """
    Plot probability curves for the tokens chosen at alpha-min and/or alpha-max.

    Note: this function updates plt.rcParams with large label/tick sizes, and
    that change persists after the call.

    Args:
      npz_path: path to the .npz produced by steering_plot_actor.py
      which: 'alphamax', 'alphamin', or 'both' (plots side-by-side when 'both')
      n_labels: how many lines to label in each panel (rest are unlabeled)
      limit: plot only the first N tokens from the stored top-K
      alpha_range: tuple (min_alpha, max_alpha) to filter existing alpha values
      title_suffix: extra text appended to the title
      save: if set, save the figure to this path instead of showing
      symlog: use symmetric log scale on the x-axis
      logit: use logit scale on the y-axis
      mode: 'probs' to plot raw probabilities, 'delta' to plot probs - probs at α=0
      zoom_alpha_range: if not None, draw a zoomed panel row over this α-range,
                        e.g. (0.0, 10.0), and shade that range on the main panel.
      select_bumps: when True, split plotted tokens ~half from top prob at the
                    anchor α (max or min) and half with the largest detected bumps.
      rank_bumps_by_height: when True with select_bumps, rank bump candidates by
                            peak height; otherwise keep anchor top-K order and
                            just filter for tokens that exhibit a bump.

    Raises:
      ValueError: invalid `mode` or `which`, mode='delta' without an α≈0 entry,
                  or an `alpha_range` that matches no stored alpha.
      KeyError: a required probability / token array is missing from the npz
                (the title also indexes meta['model'], meta['concept'],
                meta['layer_idx'] and, without a ctx match, meta['seq_len']).
    """
    plt.rcParams.update(
        {
            "axes.labelsize": 40,
            "axes.titlesize": 18,
            "xtick.labelsize": 25,
            "ytick.labelsize": 25,
            "legend.fontsize": 14,
        }
    )

    # NOTE: allow_pickle=True executes pickle on load; only use on trusted files.
    data = np.load(str(npz_path), allow_pickle=True)
    alphas = data["alphas"].astype(np.float32)
    meta = json.loads(str(data["meta"]))
    mode = mode.lower()
    if mode not in {"probs", "delta"}:
        raise ValueError("mode must be either 'probs' or 'delta'")

    # Validate every requested series up front, before any figure is created.
    series_keys = ["alphamin", "alphamax"] if which == "both" else [which]
    for key in series_keys:
        if key not in SERIES_MAP:
            raise ValueError(f"'which' must be one of {list(SERIES_MAP)} or 'both'")
        prob_key, tok_key = SERIES_MAP[key]
        if prob_key not in data or tok_key not in data:
            raise KeyError(f"Missing '{prob_key}' or '{tok_key}' in {npz_path}")

    # Locate the α≈0 row on the FULL grid; the delta baseline is taken from it
    # even when alpha_range later excludes α=0 from the plotted span.
    eps = 1e-5
    zero_indices = np.nonzero(np.isclose(alphas, 0.0, atol=eps))[0]
    zero_idx = int(zero_indices[0]) if zero_indices.size else None
    if mode == "delta" and zero_idx is None:
        raise ValueError("mode='delta' requires an alpha value equal to 0.")

    if alpha_range is not None:
        lo, hi = sorted(alpha_range)
        mask = (alphas >= lo) & (alphas <= hi)
        if not np.any(mask):
            raise ValueError(f"No alpha values in [{lo}, {hi}]")
        alpha_indices = np.nonzero(mask)[0]
    else:
        alpha_indices = np.arange(alphas.shape[0])

    alphas_sel = alphas[alpha_indices]

    # Layout: one column per series; with zooming, a short zoom row sits above
    # the main row.
    n_panels = len(series_keys)
    if zoom_alpha_range is not None:
        fig, axes_grid = plt.subplots(
            2,
            n_panels,
            figsize=(12 if n_panels == 2 else 10, 7),
            sharey="row",
            gridspec_kw={"height_ratios": [1.1, 3.0], "hspace": 0.1},
            constrained_layout=True,
        )
        zoom_axes = np.atleast_1d(axes_grid[0])
        axes = np.atleast_1d(axes_grid[1])
    else:
        fig, axes = plt.subplots(
            1, n_panels, figsize=(12 if n_panels == 2 else 10, 5), sharey=True, constrained_layout=True
        )
        zoom_axes = [None] * n_panels
    # atleast_1d makes the single-panel case iterable like the two-panel one.
    axes = np.atleast_1d(axes)
    zoom_axes = np.atleast_1d(zoom_axes)
    y_label = "probability" if mode == "probs" else "Δ probability vs α=0"

    for idx, (ax, key) in enumerate(zip(axes, series_keys)):
        zax = zoom_axes[idx] if zoom_alpha_range is not None else None
        prob_key, tok_key = SERIES_MAP[key]
        probs = np.asarray(data[prob_key])
        tokens = np.asarray(data[tok_key]).astype(object)
        A, K = probs.shape
        plot_k = K if limit is None else max(0, min(int(limit), K))
        # NOTE(review): label_alpha is assigned but never read in this function.
        label_alpha = alphas_sel[-1] if key == "alphamax" else alphas_sel[0]

        if mode == "delta":
            baseline_probs = probs[zero_idx, :]
            probs_to_plot = probs - baseline_probs
        else:
            probs_to_plot = probs
        probs_to_plot = probs_to_plot[alpha_indices, :]

        if plot_k > 0 and select_bumps:
            # Half of the budget goes to the highest-valued tokens at the
            # anchor α, the other half to tokens that show a bump.
            top_count = plot_k // 2
            bump_count = plot_k - top_count
            # The anchor row is taken from the alpha-filtered matrix, i.e. the
            # last / first alpha inside alpha_range.
            anchor_idx = -1 if key == "alphamax" else 0

            anchor_order = np.argsort(probs_to_plot[anchor_idx])[::-1]
            top_tokens = list(anchor_order[:top_count])

            # Bump candidates come only from the next few tokens in anchor
            # order (at most bump_count + 5 of them).
            candidate_limit = min(K, top_count + bump_count + 5)
            candidate_pool = list(anchor_order[:candidate_limit])
            remaining = [idx for idx in candidate_pool if idx not in top_tokens]
            bump_tokens: list[int] = []
            # find_bumps is not defined in this part of the file; the code
            # below only relies on its third return value being an array that
            # is empty when no bump was found — TODO confirm against its definition.
            # NOTE(review): the loops below rebind `idx`, shadowing the panel
            # index from enumerate(); harmless today because `zax` was already
            # resolved above, but fragile.
            if rank_bumps_by_height:
                bump_scores = []
                for idx in remaining:
                    _, _, bump_y = find_bumps(alphas_sel, probs_to_plot[:, idx])
                    if bump_y.size:
                        bump_scores.append((float(np.max(bump_y)), idx))
                bump_scores.sort(key=lambda t: t[0], reverse=True)
                bump_tokens = [idx for _, idx in bump_scores[:bump_count]]
            else:
                for idx in remaining:
                    _, _, bump_y = find_bumps(alphas_sel, probs_to_plot[:, idx])
                    if bump_y.size:
                        bump_tokens.append(idx)
                    if len(bump_tokens) >= bump_count:
                        break

            # Not enough bumpy tokens: pad with the remaining candidates in
            # anchor order so the panel still shows up to plot_k curves.
            if len(bump_tokens) < bump_count:
                need = bump_count - len(bump_tokens)
                filler = [idx for idx in remaining if idx not in bump_tokens][:need]
                bump_tokens.extend(filler)

            token_indices = np.array(top_tokens + bump_tokens, dtype=int)
            plot_k = token_indices.size
        else:
            # Default: the first plot_k tokens in stored order.
            token_indices = np.arange(plot_k, dtype=int)

        probs_sel = probs_to_plot[:, token_indices]
        tokens_sel = tokens[token_indices]

        # Keep the Line2D handles so the zoom row can reuse colour / alpha / width.
        lines = []
        for j in range(plot_k):
            # Only the first n_labels curves get a legend entry; the rest are
            # drawn fainter and unlabeled.
            lbl = repr(tokens_sel[j]) if j < n_labels else None
            (line,) = ax.plot(
                alphas_sel,
                probs_sel[:, j],
                label=lbl,
                alpha=0.9 if lbl else 0.5,
            )
            lines.append(line)

        if symlog:
            ax.set_xscale("symlog")
        # NOTE(review): a logit y-scale is only meaningful for values in (0, 1);
        # combining logit=True with mode='delta' looks unintended — confirm.
        if logit:
            ax.set_yscale("logit")
        ax.grid(True, alpha=0.4)
        ax.set_xlabel("$\\alpha$")
        ax.set_ylabel(y_label)

        if n_labels > 0 and plot_k > 0:
            # Legend is anchored outside the axes, to the right of the panel.
            ax.legend(
                title=f"{min(n_labels, plot_k)} tokens",
                loc="center left",
                bbox_to_anchor=(1.02, 0.5),
                fontsize=8,
                title_fontsize=9,
                frameon=True,
            )

        if zoom_alpha_range is not None and plot_k > 0 and zax is not None:
            z_lo, z_hi = sorted(zoom_alpha_range)
            zoom_mask = (alphas_sel >= z_lo) & (alphas_sel <= z_hi)

            # If no selected alpha falls in the zoom range the zoom axis is
            # left empty, but the range is still shaded on the main panel.
            if np.any(zoom_mask):
                zoom_alphas = alphas_sel[zoom_mask]
                zoom_probs = probs_sel[zoom_mask, :plot_k]

                for j in range(plot_k):
                    base_line = lines[j]
                    zax.plot(
                        zoom_alphas,
                        zoom_probs[:, j],
                        color=base_line.get_color(),
                        alpha=base_line.get_alpha(),
                        linewidth=base_line.get_linewidth(),
                    )

                zax.set_xlim(zoom_alphas[0], zoom_alphas[-1])

                # Fit the y-limits to the zoomed data with 5% padding; skipped
                # for flat or non-finite data.
                y_min = float(np.nanmin(zoom_probs))
                y_max = float(np.nanmax(zoom_probs))
                if np.isfinite(y_min) and np.isfinite(y_max) and y_max > y_min:
                    pad = 0.05 * (y_max - y_min)
                    zax.set_ylim(y_min - pad, y_max + pad)

                if symlog:
                    zax.set_xscale("symlog")
                if logit:
                    zax.set_yscale("logit")

                zax.grid(True, alpha=0.3)
                zax.tick_params(labelsize=8, labelbottom=True)
                zax.set_xlabel("$\\alpha$", fontsize=9)
                zax.set_ylabel("")
                zax.set_title(f"zoom α∈[{z_lo:.3g}, {z_hi:.3g}]", fontsize=10)

            # Shade the zoomed α-range on the main panel.
            ax.axvspan(z_lo, z_hi, alpha=0.1)

    # Title context: prefer a "ctx_<n>" / "ctx-<n>" index parsed from the file
    # name (then from the full path); otherwise fall back to meta['seq_len'].
    npz_name = Path(npz_path).name
    ctx_match = re.search(r"ctx[_-]+(\d+)", npz_name, flags=re.IGNORECASE)
    if ctx_match is None:
        ctx_match = re.search(r"ctx[_-]+(\d+)", str(npz_path), flags=re.IGNORECASE)
    ctx_desc = f"ctx idx={ctx_match.group(1)}" if ctx_match else f"ctx len={meta['seq_len']}"

    title = f"{meta['model']} | {meta['concept']} | layer={meta['layer_idx']} | {ctx_desc}"
    if title_suffix:
        title += f" | {title_suffix}"
    fig.suptitle(title)
    if save:
        save_path = Path(save)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    else:
        plt.show()

def plot_behaviour_probs(
    npz_path: str,
    alpha_range: tuple[float, float] | None = None,
    title_suffix: str = "",
    save: str | None = None,
    show_contexts: bool = False,
    max_contexts: int = 200,
    context_alpha: float = 0.15,
    logit: bool = False,
) -> None:
    """
    Plot behaviour-judge probability curves from a behaviour .npz.

    Supports both formats:
      - new: p1_by_ctx + mean_all/mean_negative/mean_positive (mean_match ignored)
      - legacy: p_match_by_item (+ mean_match ignored)

    As written, only the mean_negative curve (plus optional per-context
    traces) is drawn; see the NOTE(review) comments in the body.

    Note: this function updates plt.rcParams with large label/tick sizes, and
    that change persists after the call.

    Args:
      npz_path: behaviour .npz with 'alphas' and one of the two formats above.
      alpha_range: optional (lo, hi) to restrict the plotted α values.
      title_suffix: extra text appended to the (currently unapplied) title.
      save: if set, save the figure to this path instead of showing it.
      show_contexts: also draw faint per-context curves.
      max_contexts: cap on per-context curves (evenly subsampled); None or
                    <= 0 draws all of them.
      context_alpha: opacity of the per-context curves.
      logit: use a logit scale on the y-axis.

    Raises:
      KeyError: neither 'p1_by_ctx' nor 'p_match_by_item' is present.
      ValueError: `alpha_range` matches no stored alpha.
    """
    plt.rcParams.update(
        {
            "axes.labelsize": 40,
            "axes.titlesize": 18,
            "xtick.labelsize": 25,
            "ytick.labelsize": 25,
            "legend.fontsize": 14,
        }
    )

    # NOTE: allow_pickle=True executes pickle on load; only use on trusted files.
    data = np.load(str(npz_path), allow_pickle=True)
    alphas = data["alphas"].astype(np.float32)
    meta = json.loads(str(data["meta"])) if "meta" in data else {}

    if "p1_by_ctx" in data:
        p_by_ctx = data["p1_by_ctx"].astype(np.float32)
        ctx_is_positive = data["ctx_is_positive"] if "ctx_is_positive" in data else None
        mean_all = data["mean_all"] if "mean_all" in data else None
        mean_neg = data["mean_negative"] if "mean_negative" in data else None
        mean_pos = data["mean_positive"] if "mean_positive" in data else None
        y_label = "P(judge=1)"
    elif "p_match_by_item" in data:
        p_by_ctx = data["p_match_by_item"].astype(np.float32)
        ctx_is_positive = None
        mean_all = None
        mean_neg = data["mean_negative"] if "mean_negative" in data else None
        mean_pos = data["mean_positive"] if "mean_positive" in data else None
        y_label = "P(match)"
    else:
        raise KeyError("Expected 'p1_by_ctx' or 'p_match_by_item' in behaviour .npz")

    if alpha_range is not None:
        lo, hi = sorted(alpha_range)
        mask = (alphas >= lo) & (alphas <= hi)
        if not np.any(mask):
            raise ValueError(f"No alpha values in [{lo}, {hi}]")
        alpha_idx = np.nonzero(mask)[0]
    else:
        alpha_idx = np.arange(alphas.shape[0])

    alphas_sel = alphas[alpha_idx]
    p_by_ctx = p_by_ctx[:, alpha_idx]

    def _slice_curve(arr: np.ndarray | None) -> np.ndarray | None:
        """Restrict a stored mean curve to the selected alphas; None if unusable."""
        if arr is None:
            return None
        arr = np.asarray(arr, dtype=np.float32)
        # Stored means must be 1-D and aligned with the full alpha grid.
        if arr.ndim != 1 or arr.shape[0] != alphas.shape[0]:
            return None
        return arr[alpha_idx]

    mean_all = _slice_curve(mean_all)
    mean_neg = _slice_curve(mean_neg)
    mean_pos = _slice_curve(mean_pos)
    if mean_all is None:
        mean_all = p_by_ctx.mean(axis=0)

    # Recompute missing per-class means from the per-context curves.
    if ctx_is_positive is not None:
        ctx_is_positive = np.asarray(ctx_is_positive, dtype=np.int8)
        neg_mask = ctx_is_positive == 0
        pos_mask = ctx_is_positive == 1
        if mean_neg is None and neg_mask.any():
            mean_neg = p_by_ctx[neg_mask].mean(axis=0)
        if mean_pos is None and pos_mask.any():
            mean_pos = p_by_ctx[pos_mask].mean(axis=0)
    fig, ax = plt.subplots(figsize=(9, 5))

    if show_contexts:
        n_ctx = p_by_ctx.shape[0]
        if max_contexts is None or max_contexts <= 0 or max_contexts >= n_ctx:
            ctx_idx = np.arange(n_ctx)
        else:
            # Evenly spaced subsample; unique() removes rounding duplicates.
            ctx_idx = np.unique(
                np.linspace(0, n_ctx - 1, int(max_contexts)).round().astype(int)
            )
        # -1: unknown polarity, 0: negative context, 1: positive context.
        colors = {-1: "#94a3b8", 0: "#60a5fa", 1: "#f59e0b"}
        for i in ctx_idx:
            if ctx_is_positive is None:
                color = colors[-1]
            else:
                flag = int(ctx_is_positive[i]) if i < ctx_is_positive.shape[0] else -1
                color = colors.get(flag, colors[-1])
            ax.plot(
                alphas_sel,
                p_by_ctx[i],
                color=color,
                alpha=float(context_alpha),
                linewidth=0.8,
            )
        if ctx_is_positive is not None:
            # Empty proxy lines that only carry legend labels.
            ax.plot([], [], color=colors[0], label="negative contexts")
            ax.plot([], [], color=colors[1], label="positive contexts")

    def _plot_curve(curve: np.ndarray | None, label: str, color: str, lw: float = 4.0) -> None:
        """Draw one mean curve, skipping missing or entirely non-finite data."""
        if curve is None:
            return
        if not np.isfinite(curve).any():
            return
        # `lw` used to be ignored in favour of a hard-coded width of 4; the
        # default is now 4.0 so existing figures are unchanged while the
        # parameter actually takes effect.
        ax.plot(alphas_sel, curve, label=label, color=color, linewidth=lw)

    # NOTE(review): mean_all and mean_pos are computed above but never drawn;
    # confirm whether plotting only mean_negative is intentional.
    _plot_curve(mean_neg, "mean_negative", "#2563eb")

    ax.grid(True, alpha=0.4)
    ax.set_xlabel("$\\alpha$")
    if logit:
        ax.set_yscale("logit")

    # NOTE(review): `title` and `y_label` are built but never applied, and
    # ax.legend() is never called even though curves carry labels. Left as-is
    # because this looks like deliberate figure styling — confirm.
    model = meta.get("model", "?")
    concept = meta.get("concept_label", meta.get("concept", meta.get("concept_slug", "?")))
    layer = meta.get("layer_idx", "?")
    title = f"{model} | {concept} | layer={layer}"
    if title_suffix:
        title += f" | {title_suffix}"

    fig.tight_layout()
    if save:
        save_path = Path(save)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
        # pyplot keeps every figure alive until it is closed; release it so
        # batch exports do not accumulate figures.
        plt.close(fig)
    else:
        plt.show()


def plot_npz_seed_mean_std(
    npz_paths: list[str],
    which: str = "alphamax",
    top_n: int = 20,
    mode: str = "probs",
    alpha_range: tuple[float, float] | None = None,
    title_suffix: str = "",
    save: str | None = None,
):
    """
    Plot the across-seed mean ± std probability curve of the top tokens.

    Args:
      npz_paths: list of .npz paths for the same model/concept/layer/ctx across seeds.
      which: 'alphamax' or 'alphamin' token set to use.
      top_n: number of tokens (taken from the first npz, ranked by probability
             at the anchor α) to plot.
      mode: 'probs' or 'delta' (delta subtracts the α=0 baseline); case-insensitive.
      alpha_range: optional (lo, hi) to restrict the α domain.
      title_suffix: extra text appended to the title.
      save: if set, save the figure to this path instead of showing it.

    Tokens are matched across runs by token id; a run that lacks a token in
    its stored top-K simply does not contribute to that token's statistics.

    Raises:
      ValueError: unknown `which` or `mode`, empty `npz_paths`, alpha grids
                  that differ between runs, mode='delta' without an α≈0
                  entry, or an `alpha_range` that matches no stored alpha.
    """
    if which not in SERIES_MAP:
        raise ValueError(f"'which' must be one of {list(SERIES_MAP)}")
    # Without validation an unknown mode silently plotted raw probabilities
    # under a "Δ probability" axis label.
    mode = mode.lower()
    if mode not in {"probs", "delta"}:
        raise ValueError("mode must be either 'probs' or 'delta'")
    npz_paths = list(npz_paths)
    if not npz_paths:
        raise ValueError("npz_paths must contain at least one path")

    plt.rcParams.update(PLOT_STYLE)

    # NOTE: allow_pickle=True executes pickle on load; only use on trusted files.
    runs = [np.load(p, allow_pickle=True) for p in npz_paths]
    alphas_full = runs[0]["alphas"].astype(np.float32)
    for r in runs[1:]:
        other_alphas = np.asarray(r["alphas"])
        # Check the shape first so a length mismatch yields the same clear
        # message instead of a broadcasting error from allclose.
        if other_alphas.shape != alphas_full.shape or not np.allclose(alphas_full, other_alphas):
            raise ValueError("alpha grids differ across runs")

    zero_idx = None
    if mode == "delta":
        zero_hits = np.nonzero(np.isclose(alphas_full, 0.0, atol=1e-5))[0]
        if zero_hits.size == 0:
            raise ValueError("mode='delta' requires an alpha value equal to 0.")
        zero_idx = int(zero_hits[0])

    if alpha_range is not None:
        lo, hi = sorted(alpha_range)
        mask = (alphas_full >= lo) & (alphas_full <= hi)
        if not np.any(mask):
            raise ValueError(f"No alpha values in [{lo}, {hi}]")
        alpha_idx = np.nonzero(mask)[0]
    else:
        alpha_idx = np.arange(alphas_full.shape[0])
    alphas = alphas_full[alpha_idx]

    prob_key, tok_str_key = SERIES_MAP[which]
    tok_id_key = f"token_{which}"

    # Indexing an NpzFile re-reads the array from disk on every access, so
    # each run's probability matrix is loaded exactly once here.
    prob_arrays = [np.asarray(r[prob_key], dtype=float) for r in runs]

    # Token selection is driven by the first run: rank by probability at the
    # anchor α (last row for alphamax, first row for alphamin).
    anchor_idx = -1 if which == "alphamax" else 0
    tokens0 = np.asarray(runs[0][tok_id_key], dtype=int)
    token_strs0 = np.asarray(runs[0][tok_str_key]).astype(object)
    order0 = np.argsort(prob_arrays[0][anchor_idx])[::-1]
    sel_idx = order0[: min(top_n, tokens0.shape[0])]
    sel_token_ids = tokens0[sel_idx]
    sel_token_strs = token_strs0[sel_idx]

    # Per-run map: token id -> column in that run's probability matrix.
    col_maps = []
    for r in runs:
        ids = np.asarray(r[tok_id_key], dtype=int)
        col_maps.append({int(t): i for i, t in enumerate(ids)})

    fig, ax = plt.subplots(figsize=(12, 6))
    y_label = "probability" if mode == "probs" else "Δ probability vs α=0"

    for tid, tstr in zip(sel_token_ids, sel_token_strs):
        curves = []
        for run_probs, cmap in zip(prob_arrays, col_maps):
            col = cmap.get(int(tid))
            if col is None:
                continue  # token not present in this run's top-K
            curve = run_probs[:, col]
            if mode == "delta":
                # Baseline is taken on the full grid, before alpha filtering.
                curve = curve - curve[zero_idx]
            curves.append(curve[alpha_idx])
        if not curves:
            continue
        curves = np.stack(curves, axis=0)  # [num_runs, A]
        mean = curves.mean(axis=0)
        std = curves.std(axis=0)

        ax.plot(alphas, mean, label=str(tstr))
        ax.fill_between(alphas, mean - std, mean + std, alpha=0.2)

    meta = json.loads(str(runs[0]["meta"]))
    ax.set_title(f"{meta.get('model')} · {meta.get('concept')} · {meta.get('layer_idx')} {title_suffix}")
    ax.set_xlabel("alpha")
    ax.set_ylabel(y_label)
    ax.grid(True, alpha=0.4)
    ax.legend()

    fig.tight_layout()
    if save:
        save_path = Path(save)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
        # pyplot keeps every figure alive until it is closed; release it so
        # batch exports do not accumulate figures.
        plt.close(fig)
    else:
        plt.show()


def plot_xent_delta(
    npz_path: str,
    alpha_range: tuple[float, float] | None = None,
    title_suffix: str = "",
    save: str | None = None,
    symlog: bool = False,
) -> None:
    """
    Draw the change in cross-entropy (relative to α=0) as a function of α.

    Expected arrays in the .npz:

      alphas                            : [A]
      delta_cross_entropy or delta_xent : [A]  (used directly when present)
      cross_entropy or xent             : [A]  (fallback; baseline-subtracted here)
      meta                              : JSON-encoded dict

    Lookup order is delta_cross_entropy, delta_xent, cross_entropy, xent. When
    only a raw cross-entropy curve exists, the value at the first α within
    1e-5 of zero is subtracted from it.

    Args:
      npz_path: path of the .npz to read.
      alpha_range: optional (lo, hi) to restrict the plotted α values.
      title_suffix: extra text appended to the title.
      save: if set, save the figure to this path instead of showing it.
      symlog: use a symmetric log scale on the x-axis.

    Raises:
      KeyError: none of the four cross-entropy arrays is present.
      ValueError: the delta must be derived but no α=0 exists, or
                  `alpha_range` matches no stored alpha.
    """
    plt.rcParams.update(
        {
            "font.size": 14,
            "axes.labelsize": 16,
            "axes.titlesize": 18,
            "xtick.labelsize": 12,
            "ytick.labelsize": 12,
            "legend.fontsize": 14,
        }
    )

    archive = np.load(str(npz_path), allow_pickle=True)
    alpha_grid = archive["alphas"].astype(np.float32)
    info = json.loads(str(archive["meta"]))

    def _first_available(names):
        # Return the first stored array among `names` as float32, else None.
        for name in names:
            if name in archive:
                return archive[name].astype(np.float32)
        return None

    delta_curve = _first_available(("delta_cross_entropy", "delta_xent"))
    if delta_curve is None:
        raw_curve = _first_available(("cross_entropy", "xent"))
        if raw_curve is None:
            raise KeyError(
                f"{npz_path} must contain cross_entropy/xent or delta_cross_entropy/delta_xent to plot cross-entropy deltas."
            )
        tolerance = 1e-5
        baseline_hits = np.nonzero(np.isclose(alpha_grid, 0.0, atol=tolerance))[0]
        if not baseline_hits.size:
            raise ValueError("Cannot compute Δ cross-entropy: no α=0 value found in 'alphas'.")
        baseline_value = float(raw_curve[int(baseline_hits[0])])
        delta_curve = raw_curve - baseline_value

    if alpha_range is None:
        x_values, y_values = alpha_grid, delta_curve
    else:
        lower, upper = sorted(alpha_range)
        keep = (alpha_grid >= lower) & (alpha_grid <= upper)
        if not np.any(keep):
            raise ValueError(f"No alpha values in [{lower}, {upper}]")
        x_values, y_values = alpha_grid[keep], delta_curve[keep]

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(x_values, y_values)

    if symlog:
        ax.set_xscale("symlog")

    ax.grid(True, alpha=0.4)
    ax.set_xlabel("$\\alpha$")
    ax.set_ylabel("Δ cross-entropy vs α=0 (nats)")

    model = info.get("model", "?")
    concept = info.get("concept", info.get("concept_slug", "?"))
    layer = info.get("layer_idx", "?")
    seq_len = info.get("seq_len", None)
    eval_blocks = info.get("eval_blocks", None)

    # Assemble the title from its mandatory head plus optional trailing parts.
    heading = f"{model} | {concept} | layer={layer}"
    if not (eval_blocks is None or seq_len is None):
        heading += f" | {eval_blocks} blocks × {seq_len} tokens"
    if title_suffix:
        heading += f" | {title_suffix}"

    ax.set_title(heading)
    fig.tight_layout()

    if not save:
        plt.show()
        return

    target = Path(save)
    target.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(target, dpi=300, bbox_inches="tight")


def plot_mmlu_delta(
    json_path: str,
    alpha_range: tuple[float, float] | None = None,
    title_suffix: str = "",
    save: str | None = None,
    symlog: bool = False,
    include_tasks: bool = False,
) -> None:
    """Plot Δ MMLU accuracy relative to a baseline α (optionally per-task).

    The JSON payload must provide ``alphas`` and is expected to provide
    ``overall_scores`` of the same length (``null`` entries are treated as NaN).
    Optional keys read here: ``task_scores`` (mapping task id -> score list),
    ``model``, ``concept`` / ``concept_slug`` and ``layer_idx`` (title only).

    Baseline selection: the first α within 1e-5 of zero that has a finite overall
    score; if there is none, the α closest to zero among those with a finite
    overall score. Every plotted curve is the score minus its own value at that
    baseline index, so all curves pass through 0 at the baseline.

    Args:
        json_path: Path to the MMLU results JSON (e.g. ``layer_*_mmlu.json``).
        alpha_range: Optional (min, max) bounds, inclusive; order does not matter.
        title_suffix: Extra text appended to the figure title.
        save: If set, the figure is written to this path instead of being shown.
        symlog: Use a symmetric-log x-axis.
        include_tasks: Also draw one faint Δ curve per entry of ``task_scores``.

    Raises:
        ValueError: If ``overall_scores`` and ``alphas`` differ in length, if no
            overall score is finite, or if ``alpha_range`` selects no α.
    """
    plt.rcParams.update(PLOT_STYLE)

    payload = json.loads(Path(json_path).read_text())
    alphas = np.asarray(payload["alphas"], dtype=np.float32)
    overall = np.asarray(
        [np.nan if v is None else float(v) for v in payload.get("overall_scores", [])],
        dtype=np.float32,
    )
    if overall.shape[0] != alphas.shape[0]:
        raise ValueError("overall_scores length does not match alphas length")

    # Prefer an exact α≈0 baseline, but only if its score is usable.
    eps = 1e-5
    zero_idx = np.nonzero(np.isclose(alphas, 0.0, atol=eps) & np.isfinite(overall))[0]
    if zero_idx.size:
        base_idx = int(zero_idx[0])
    else:
        # Fall back to the finite-score α nearest to zero.
        finite_idx = np.nonzero(np.isfinite(overall))[0]
        if finite_idx.size == 0:
            raise ValueError("No finite overall scores to use as baseline.")
        base_idx = int(finite_idx[np.argmin(np.abs(alphas[finite_idx]))])

    base_alpha = alphas[base_idx]
    base_overall = overall[base_idx]
    # Subtract the baseline so the overall curve is a true delta, on the same
    # footing as the per-task curves below (previously raw accuracy was plotted
    # under the "Δacc" label and mixed with per-task deltas on one axis).
    delta_overall = overall - base_overall

    if alpha_range is not None:
        lo, hi = sorted(alpha_range)
        mask = (alphas >= lo) & (alphas <= hi)
        if not np.any(mask):
            raise ValueError(f"No alpha values in [{lo}, {hi}]")
    else:
        mask = np.ones_like(alphas, dtype=bool)

    alphas_sel = alphas[mask]
    delta_sel = delta_overall[mask]

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(alphas_sel, delta_sel, label="overall Δacc")
    # Zero line: "no change relative to the baseline".
    ax.axhline(0.0, color="#999", linestyle="-", linewidth=1, alpha=0.6)
    # The 0.25 accuracy reference (presumably four-option chance level — confirm)
    # is shifted by the baseline so it stays meaningful on the delta axis.
    ax.axhline(0.25 - float(base_overall), color="#999", linestyle="--", linewidth=1, alpha=0.6)

    if include_tasks and "task_scores" in payload:
        for task_id, scores in payload["task_scores"].items():
            arr = np.asarray([np.nan if v is None else float(v) for v in scores], dtype=np.float32)
            # Skip tasks that cannot be aligned with alphas or lack a baseline score.
            if arr.shape[0] != alphas.shape[0] or not np.isfinite(arr[base_idx]):
                continue
            delta_task = arr - arr[base_idx]
            ax.plot(alphas_sel, delta_task[mask], alpha=0.35, linewidth=1, label=f"{task_id} Δ")

    if symlog:
        ax.set_xscale("symlog")
    ax.grid(True, alpha=0.4)
    ax.set_xlabel("$\\alpha$")
    ax.set_ylabel("Δ MMLU accuracy vs α0")

    model = payload.get("model", "?")
    concept = payload.get("concept", payload.get("concept_slug", "?"))
    layer = payload.get("layer_idx", "?")
    # α0 in the title records which α actually served as the baseline.
    title = f"{model} | {concept} | layer={layer} | α0={base_alpha:g}"
    if title_suffix:
        title += f" | {title_suffix}"
    ax.set_title(title)

    fig.tight_layout()

    if save:
        save_path = Path(save)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    else:
        plt.show()


def parse_args() -> argparse.Namespace:
    """Parse the plotting script's command-line arguments from ``sys.argv``.

    Returns a namespace with: ``npz_path``, ``plot``, ``which``, ``mode``,
    ``limit``, ``n_labels``, ``alpha_range`` (list of two floats or None),
    ``title_suffix``, ``save``, ``symlog`` and ``logit``.
    """
    parser = argparse.ArgumentParser(description="Plot steering probability curves from .npz")

    # (flags, add_argument options), registered in this order so that the
    # generated --help text keeps its layout.
    argument_specs = [
        (
            ("npz_path",),
            {
                "help": "Path to a single .npz (e.g., layer_5_ctx_0.npz, layer_5__xent.npz, or layer_5_cross_entropy.npz)",
            },
        ),
        (
            ("--plot",),
            {
                "choices": ["probs", "xent"],
                "default": "probs",
                "help": "Plot token probabilities ('probs') or cross-entropy deltas ('xent').",
            },
        ),
        (
            ("--which",),
            {
                "choices": ["alphamax", "alphamin", "both"],
                "default": "alphamax",
                "help": "Which token set to plot (alphamax/alphamin/both)",
            },
        ),
        (
            ("--mode",),
            {
                "choices": ["probs", "delta"],
                "default": "probs",
                "help": "Plot raw probabilities or deltas vs α=0",
            },
        ),
        (
            ("--limit",),
            {"type": int, "default": None, "help": "Plot only the first N tokens from top-K"},
        ),
        (
            ("--n_labels",),
            {"type": int, "default": 6, "help": "How many lines to label in each panel"},
        ),
        (
            ("--alpha_range",),
            {
                "nargs": 2,
                "type": float,
                "metavar": ("ALPHA_MIN", "ALPHA_MAX"),
                "help": "Filter to alphas within [min,max]",
            },
        ),
        (
            ("--title_suffix",),
            {"default": "", "help": "Extra text appended to the title"},
        ),
        (
            ("--save",),
            {"default": None, "help": "If set, save figure to this path instead of showing"},
        ),
        (
            ("--symlog",),
            {"action": "store_true", "help": "Use symmetric log scale on the x-axis"},
        ),
        (
            ("--logit",),
            {"action": "store_true", "help": "Use logit scale on the y-axis"},
        ),
    ]
    for flags, options in argument_specs:
        parser.add_argument(*flags, **options)

    return parser.parse_args()


if __name__ == "__main__":
    # Script entry point: parse the CLI and dispatch to the requested plotter.
    args = parse_args()
    # argparse yields a two-element list (or None); the plotters take a tuple.
    alpha_range = None if not args.alpha_range else tuple(args.alpha_range)

    # Options understood by both plotting functions.
    common_kwargs = {
        "alpha_range": alpha_range,
        "title_suffix": args.title_suffix,
        "save": args.save,
        "symlog": args.symlog,
    }

    if args.plot != "probs":
        # The only other choice accepted by the parser is "xent".
        plot_xent_delta(args.npz_path, **common_kwargs)
    else:
        plot_npz(
            args.npz_path,
            which=args.which,
            n_labels=args.n_labels,
            limit=args.limit,
            logit=args.logit,
            mode=args.mode,
            **common_kwargs,
        )
