# lmkit/experiments/pointer_suite.py
from __future__ import annotations

import argparse
import json
import math
import os
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd

from scipy.stats import spearmanr
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

from lmkit.impl import config as config_lib
from lmkit.impl import transformer
from lmkit.impl.caching import TransformerCache, build_rope
from lmkit.impl.hooks import HookRequest, HookType
from lmkit.impl.hooks import capture as capture_hooks
from lmkit.impl.hooks import unpack_captured
from lmkit.tools import compat, train_utils
from lmkit.tools import data as data_tools

# We rely on your existing event extractor
from ..smiles_events import (
    extract_events,
)  # must provide .rings / .parens with .batch, .pred_idx, .open_idx

# ------------------------------
# utils & helpers
# ------------------------------


def _json_default(o):
    import numpy as _np

    if isinstance(o, (_np.integer,)):
        return int(o)
    if isinstance(o, (_np.floating,)):
        return float(o)
    if isinstance(o, (_np.ndarray,)):
        return o.tolist()
    if hasattr(o, "item"):
        try:
            return o.item()
        except Exception:
            pass
    return str(o)


def _clone_params(params):
    p = dict(params)
    p["layers"] = list(params["layers"])
    return p


def ablate_heads_in_Wo(
    params, heads: List[Tuple[int, int]], hidden_size: int, num_heads: int
):
    """Return a *copy* of params with multiple heads zeroed in W_o.

    For every ``(layer_id, head_idx)`` pair, rows
    ``[head_idx * head_dim, (head_idx + 1) * head_dim)`` of that layer's
    ``attn["W_o"]`` are set to zero, where ``head_dim = hidden_size // num_heads``.
    The input ``params`` is not modified: the top-level dict, the layer list,
    and each touched layer / attn dict are shallow-copied before being changed.

    Args:
        params: mapping with a ``"layers"`` sequence; each touched layer must
            provide ``["attn"]["W_o"]``.
        heads: ``(layer_id, head_idx)`` pairs to ablate.
        hidden_size: model hidden size, used to derive the per-head row count.
        num_heads: number of attention heads.

    Returns:
        A new params mapping with the requested W_o row blocks zeroed.

    Raises:
        ValueError: if a ``head_idx`` is outside ``[0, num_heads)``. Without
            this check such an index selects an empty row slice and the call
            silently ablates nothing.
    """
    head_dim = hidden_size // num_heads

    # Validate up front and group by layer so each W_o is pulled to the host,
    # edited, and pushed back exactly once.
    heads_by_layer: Dict[int, List[int]] = {}
    for layer_id, head_idx in heads:
        if not 0 <= head_idx < num_heads:
            raise ValueError(
                f"head index {head_idx} out of range for num_heads={num_heads}"
            )
        heads_by_layer.setdefault(layer_id, []).append(head_idx)

    new_params = dict(params)
    new_params["layers"] = list(params["layers"])
    for layer_id, head_ids in heads_by_layer.items():
        lyr = dict(new_params["layers"][layer_id])
        attn = dict(lyr["attn"])
        W_o = np.array(attn["W_o"])  # host-side mutable copy
        for head_idx in head_ids:
            W_o[head_idx * head_dim : (head_idx + 1) * head_dim, :] = 0.0
        attn["W_o"] = jnp.asarray(W_o, dtype=attn["W_o"].dtype)
        lyr["attn"] = attn
        new_params["layers"][layer_id] = lyr
    return new_params


def run_with_hooks(inputs, positions, params, config, hook_pairs):
    """Run one forward pass and return ``(logits, captured)`` for the requested hooks.

    ``hook_pairs`` is an iterable of two-element items; each is turned into a
    ``HookRequest``. A fresh non-dynamic bfloat16 ``TransformerCache`` is built
    from ``positions`` for the call. Nothing is streamed and no editor is
    installed. The captured values are passed through ``unpack_captured``
    before being returned.
    """
    requests = [HookRequest(layer, kind) for layer, kind in hook_pairs]
    hooks_to_return, _ = capture_hooks(*requests)

    fresh_cache = TransformerCache.create(
        positions, config, dtype=jnp.bfloat16, dynamic=False
    )
    logits, _updated_cache, raw_captured = transformer.run(
        inputs,
        fresh_cache,
        params,
        config,
        hooks_to_return=hooks_to_return,
        hooks_to_stream=frozenset(),
        editor=None,
    )
    unpacked = unpack_captured(hooks_to_return, raw_captured)
    return logits, unpacked


def _pad_time_np(a: np.ndarray, T_target: int, pad_value) -> np.ndarray:
    if a.shape[1] == T_target:
        return a
    pad_shape = (a.shape[0], T_target - a.shape[1]) + a.shape[2:]
    pad_block = np.full(pad_shape, pad_value, dtype=a.dtype)
    return np.concatenate([a, pad_block], axis=1)


# ------------------------------
# attention math & probes
# ------------------------------


def _rope_apply(x, sin, cos):
    # match lmkit.impl.transformer.rope() broadcasting
    if x.ndim == 4 and sin.ndim == 3:
        sin = sin[:, :, None, :]
        cos = cos[:, :, None, :]
    elif x.ndim > sin.ndim and x.shape[-1] == sin.shape[-1]:
        num_broadcast_dims = x.ndim - sin.ndim
        new_shape = list(sin.shape)
        for _ in range(num_broadcast_dims):
            new_shape.insert(-1, 1)
        sin = jnp.reshape(sin, new_shape)
        cos = jnp.reshape(cos, new_shape)
        if sin.shape[:-1] != x.shape[:-1] or cos.shape[:-1] != x.shape[:-1]:
            sin = sin[..., None, :]
            cos = cos[..., None, :]
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    rotated = (x * cos) + (jnp.concatenate((-x2, x1), axis=-1) * sin)
    return rotated.astype(x.dtype)


def compute_pointer_mass_for_events(
    resid_pre_layer: jnp.ndarray,  # (B,T,HIDDEN)
    positions: jnp.ndarray,  # (B,T) -1 for pad
    params_layer_attn: Dict,  # W_q/W_k/W_v/W_o
    config,  # FrozenDict
    events: List[Tuple[int, int, int]],  # (b, i_pred, j_open)
) -> np.ndarray:
    """Return average pointer mass per head across events (H,).

    For each event ``(b, i_pred, j_open)`` the attention distribution of the
    query at ``i_pred`` over keys ``0 .. kv_len-1`` is recomputed from the
    residual stream (``W_q`` / ``W_k`` projections + rotary embedding, scaled
    by ``1/sqrt(D)``), and the softmax weight placed on ``j_open`` is recorded
    for every head. ``kv_len`` is ``min(#positions >= 0 in row b, i_pred + 1)``.

    Events are skipped when any index is negative, ``b >= B``, ``kv_len <= 0``
    or ``j_open >= kv_len``. If no event survives, zeros of shape ``(H,)`` are
    returned.

    Args:
        resid_pre_layer: residual-stream input of the layer, (B, T, HIDDEN).
        positions: per-token positions, (B, T); negative entries are treated
            as invalid. NOTE(review): using the count of valid entries as a
            prefix length assumes padding is trailing — confirm with the
            data loader.
        params_layer_attn: attention weights; ``"W_q"`` and ``"W_k"`` are read.
        config: mapping providing ``hidden_size``, ``num_heads``, ``rope_base``
            and optionally ``num_kv_heads``.
        events: iterable of ``(b, i_pred, j_open)`` triples.

    Returns:
        NumPy array of shape ``(H,)`` with the mean weight on ``j_open``.
    """
    hidden = config["hidden_size"]
    H = int(config["num_heads"])
    Hkv = int(config.get("num_kv_heads", H))
    D = hidden // H

    B, T, _ = resid_pre_layer.shape
    sin, cos = build_rope(positions, D, config["rope_base"])

    y = resid_pre_layer
    Q = jnp.reshape(y @ params_layer_attn["W_q"], (B, T, H, D))
    K = jnp.reshape(y @ params_layer_attn["W_k"], (B, T, Hkv, D))
    Qr = _rope_apply(Q, sin, cos)  # (B,T,H,D)
    Kr = _rope_apply(K, sin, cos)  # (B,T,Hkv,D)

    if Hkv != H:
        # Grouped KV heads: replicate each KV head across its query heads.
        assert H % Hkv == 0
        repeat = H // Hkv
        Kr = jnp.repeat(Kr, repeats=repeat, axis=2)  # (B,T,H,D)

    inv_sqrt_d = 1.0 / math.sqrt(D)

    # Count valid positions for every row once (a single device->host
    # transfer) instead of issuing a reduction + sync per event.
    valid_lens = np.asarray(jnp.sum(positions >= 0, axis=1))

    per_head_scores = []
    for b, i_pred, j_open in events:
        if i_pred < 0 or j_open < 0 or b < 0 or b >= B:
            continue
        kv_len = min(int(valid_lens[b]), i_pred + 1)
        if kv_len <= 0 or j_open >= kv_len:
            continue
        q = Qr[b, i_pred]  # (H,D)
        k = Kr[b, :kv_len]  # (kv_len,H,D)
        logits = jnp.einsum("hd,thd->ht", q, k) * inv_sqrt_d  # (H,kv_len)
        w = jax.nn.softmax(logits, axis=-1)  # (H,kv_len)
        per_head_scores.append(w[:, j_open])  # (H,)
    if not per_head_scores:
        return np.zeros((H,), dtype=np.float32)
    arr = jnp.stack(per_head_scores, axis=0)  # (N,H)
    return np.array(jnp.mean(arr, axis=0))  # (H,)


def collect_head_outputs_for_events(
    resid_pre_layer: jnp.ndarray,  # (B,T,HIDDEN)
    positions: jnp.ndarray,
    params_layer_attn: Dict,
    config,
    events: List[Tuple[int, int, int]],  # (b, i_pred, j_open)
    head_idx: int,
    post_ov: bool = False,
):
    """
    Return list of head outputs at the close step for a single head:
      pre-OV:   sum_t alpha_t * V_t      (shape D,)
      post-OV:  (sum_t alpha_t * V_t) @ W_o_block   (shape HIDDEN,)

    ``alpha`` is the softmax attention of the query at ``i_pred`` over keys
    ``0 .. kv_len-1`` with ``kv_len = min(#positions >= 0 in row b, i_pred + 1)``,
    recomputed from the residual stream with rotary embedding applied to Q/K.

    Events are skipped when ``b`` is outside ``[0, B)``, ``i_pred < 0`` or
    ``kv_len <= 0``. ``j_open`` is *not* consulted, so the returned list can be
    shorter than ``events`` but is never filtered on the opener index; callers
    that pair outputs with labels must apply the same skip rule.

    Args:
        resid_pre_layer: residual-stream input of the layer, (B, T, HIDDEN).
        positions: per-token positions; negative entries are treated as invalid.
        params_layer_attn: attention weights; ``W_q``, ``W_k``, ``W_v`` are
            read, plus ``W_o`` when ``post_ov`` is true.
        config: mapping providing ``hidden_size``, ``num_heads``, ``rope_base``
            and optionally ``num_kv_heads``.
        events: iterable of ``(b, i_pred, j_open)`` triples.
        head_idx: query-head index to extract.
        post_ov: if true, project through this head's ``W_o`` row block.

    Returns:
        List of float32 NumPy vectors, one per non-skipped event, in event order.
    """
    hidden = config["hidden_size"]
    H = int(config["num_heads"])
    Hkv = int(config.get("num_kv_heads", H))
    D = hidden // H

    B, T, _ = resid_pre_layer.shape
    sin, cos = build_rope(positions, D, config["rope_base"])

    y = resid_pre_layer
    Q = jnp.reshape(y @ params_layer_attn["W_q"], (B, T, H, D))
    K = jnp.reshape(y @ params_layer_attn["W_k"], (B, T, Hkv, D))
    V = jnp.reshape(y @ params_layer_attn["W_v"], (B, T, Hkv, D))

    Qr = _rope_apply(Q, sin, cos)  # (B,T,H,D)
    Kr = _rope_apply(K, sin, cos)  # (B,T,Hkv,D)

    if Hkv != H:
        # Grouped KV heads: replicate each KV head across its query heads.
        assert H % Hkv == 0
        repeat = H // Hkv
        Kr = jnp.repeat(Kr, repeats=repeat, axis=2)
        Vr = jnp.repeat(V, repeats=repeat, axis=2)
    else:
        Vr = V  # already (B,T,H,D) because Hkv == H

    inv_sqrt_d = 1.0 / math.sqrt(D)

    # Row block of W_o belonging to this head; only needed for the post-OV view.
    Wo_block = None
    if post_ov:
        Wo = params_layer_attn["W_o"]  # (H*D, hidden)
        Wo_block = Wo[head_idx * D : (head_idx + 1) * D, :]  # (D, hidden)

    # Count valid positions for every row once (a single device->host
    # transfer) instead of issuing a reduction + sync per event.
    valid_lens = np.asarray(jnp.sum(positions >= 0, axis=1))

    outputs = []
    for b, i_pred, j_open in events:
        if b < 0 or b >= B or i_pred < 0:
            continue
        kv_len = min(int(valid_lens[b]), i_pred + 1)
        if kv_len <= 0:
            continue
        q = Qr[b, i_pred, head_idx]  # (D,)
        k = Kr[b, :kv_len, head_idx]  # (kv_len,D)
        v = Vr[b, :kv_len, head_idx]  # (kv_len,D)

        logits = jnp.einsum("d,td->t", q, k) * inv_sqrt_d  # (kv_len,)
        w = jax.nn.softmax(logits, axis=-1)  # (kv_len,)
        pre = jnp.sum(w[:, None] * v, axis=0)  # (D,)

        vec = pre @ Wo_block if post_ov else pre
        outputs.append(np.asarray(vec, dtype=np.float32))
    return outputs  # list of vectors


# ------------------------------
# metrics on logits (Δlogit, Δmargin, Δacc)
# ------------------------------


def _gather_at_positions(arr, pos_list):
    b = np.array([p[0] for p in pos_list], dtype=np.int32)
    t = np.array([p[1] for p in pos_list], dtype=np.int32)
    return arr[b, t, :]


def _gold_ids(targets, pos_list):
    b = np.array([p[0] for p in pos_list], dtype=np.int32)
    t = np.array([p[1] for p in pos_list], dtype=np.int32)
    return targets[b, t]


def delta_metrics_at_positions(logits_base, logits_abl, targets, pos_list):
    """Return Δgold_logit, Δmargin, Δacc for a set of token positions.

    All three deltas are ``baseline - ablated``, averaged over ``pos_list``,
    so a positive value means the ablation hurt.

    Args:
        logits_base: baseline logits, indexed as ``[b, t, vocab]``.
        logits_abl: ablated-model logits with the same layout.
        targets: gold token ids, indexed as ``[b, t]``.
        pos_list: list of ``(b, t)`` positions to evaluate.

    Returns:
        Dict with float entries:
          * ``delta_logit``  — mean drop of the gold-token logit.
          * ``delta_margin`` — mean drop of (gold logit - best other logit),
            where each model's margin uses that model's own gold logit.
          * ``delta_acc``    — mean drop of top-1 correctness.
        All zeros when ``pos_list`` is empty.
    """
    if not pos_list:
        return dict(delta_logit=0.0, delta_margin=0.0, delta_acc=0.0)

    # Build the (b, t) index arrays once and reuse them for all three lookups.
    b_idx = np.array([p[0] for p in pos_list], dtype=np.int32)
    t_idx = np.array([p[1] for p in pos_list], dtype=np.int32)
    base = logits_base[b_idx, t_idx, :]  # (N,V)
    abl = logits_abl[b_idx, t_idx, :]  # (N,V)
    gold = targets[b_idx, t_idx]  # (N,)

    rows = np.arange(gold.shape[0], dtype=np.int32)
    base_gold = base[rows, gold]
    abl_gold = abl[rows, gold]
    delta_logit = float(np.mean(base_gold - abl_gold))

    # margins (gold - best_other)
    def margin(mat, gold_logit):
        # Push the gold entry far down so max() yields the best competitor.
        masked = mat.copy()
        masked[rows, gold] = -1e9
        best_other = masked.max(axis=1)
        return gold_logit - best_other

    # Each margin must use the gold logit of its own model; the ablated margin
    # previously reused the baseline gold logit.
    base_margin = margin(base, base_gold)
    abl_margin = margin(abl, abl_gold)
    delta_margin = float(np.mean(base_margin - abl_margin))

    # delta accuracy: positive means accuracy dropped on ablation
    base_acc = (base.argmax(axis=1) == gold).astype(np.float32)
    abl_acc = (abl.argmax(axis=1) == gold).astype(np.float32)
    delta_acc = float(np.mean(base_acc - abl_acc))

    return dict(delta_logit=delta_logit, delta_margin=delta_margin, delta_acc=delta_acc)


# ------------------------------
# main evaluation
# ------------------------------


@dataclass
class HeadKey:
    """Pair of indices naming one attention head.

    NOTE(review): not referenced anywhere else in the visible part of this
    file; the suite passes plain ``(layer, head)`` tuples instead. Because this
    is a default (non-frozen, eq=True) dataclass, instances are not hashable
    and cannot be used as dict keys or set members as-is.
    """

    # Index into the layer list — presumably matches params["layers"]; verify against callers.
    layer: int
    # Head index within that layer.
    head: int


def run_pointer_suite(
    model_dir: str,
    ckpt_id: int,
    dataset_dir: str,
    split: str,
    num_examples: int,
    batch_size: int,
    out_dir: str,
    max_events: int = 2000,
    topk_report: int = 5,
    seq_length: int = 256,
    compare_ckpt: Optional[int] = None,
):
    """Run the pointer-head analysis suite for one checkpoint and write reports to ``out_dir``.

    Pipeline, in order:
      1. Load tokenizer, config and checkpoint ``ckpt_id`` from ``model_dir``.
      2. Forward the tokenized dataset batch by batch, keeping logits and the
         RESID_PRE activation of every layer on the host.
      3. Extract ring / paren events from the inputs and compute, per layer and
         head, the mean attention weight ("pointer mass") on the opener token.
      4. Ablate each of the top-``topk_report`` heads per task (zeroing its W_o
         rows) and measure Δlogit / Δmargin / Δacc at event vs. control
         positions -> ``pointer_suite_summary.csv``.
      5. Jointly ablate the top-2 heads per task -> ``redundancy_table.csv``.
      6. Pointer mass and Δmargin binned by ring span / paren depth for the
         top head per task -> ``robustness_curves.json``.
      7. Logistic-regression probe on that head's output -> ``value_probe.json``.
      8. Per-head scatter table -> ``pointer_vs_global_scatter.csv``.
      9. If ``compare_ckpt`` is given, recompute pointer mass with that
         checkpoint's attention weights -> ``stability_summary.json``.

    Args:
        model_dir: directory holding ``tokenizer.json``,
            ``generation_config.json``, the config and ``checkpoints/``.
        ckpt_id: id of the checkpoint file ``checkpoint_{ckpt_id}.pkl``.
        dataset_dir: forwarded to ``data_tools.load_and_tokenize``.
        split: NOTE(review): accepted but never referenced in this function
            body — confirm whether it should be forwarded to the data loader.
        num_examples: passed as ``limit`` to the loader; iteration also stops
            once at least this many rows have been processed.
        batch_size: forwarded to the loader.
        out_dir: output directory, created if missing.
        max_events: keep at most the first ``max_events`` events of each type;
            ``None`` or a non-positive value keeps all of them.
        topk_report: number of top heads per task that get individually ablated.
        seq_length: passed to the tokenizer loader as ``trunc_length``.
        compare_ckpt: optional second checkpoint id for the stability report.

    Returns:
        None. Results are written to files and a summary is printed.
    """
    os.makedirs(out_dir, exist_ok=True)

    # ---- load tokenizer & model
    tokenizer = compat.load_tokenizer(
        tokenizer_path=f"{model_dir}/tokenizer.json",
        generation_config_file=f"{model_dir}/generation_config.json",
        mode="train",
        trunc_length=seq_length,
    )
    # Config from disk, extended with the tokenizer's special-token ids.
    cfg = config_lib.load_from_dir(model_dir).copy(
        dict(
            bos_id=tokenizer.bos_token_id,
            eos_id=tokenizer.eos_token_id,
            pad_id=tokenizer.pad_token_id,
        )
    )
    # Only the first element of the checkpoint tuple (the parameters) is used.
    params, *_ = train_utils.load_checkpoint(
        f"{model_dir}/checkpoints/checkpoint_{ckpt_id}.pkl"
    )

    # ---- dataset slice
    ds = data_tools.load_and_tokenize(
        dataset_dir=dataset_dir,
        tokenizer=tokenizer,
        batch_size=batch_size,
        num_processes=8,
        seed=2002,
        target_column="smiles",
        caching=True,
        limit=num_examples,
    )

    # ---- one pass: logits + RESID_PRE at all layers
    hook_pairs = [(i, HookType.RESID_PRE) for i in range(cfg["num_layers"])]
    batch_inputs, batch_targets, batch_positions = [], [], []
    resid_pre_per_batch: List[Dict[Tuple[int, HookType], np.ndarray]] = []
    logits_per_batch: List[np.ndarray] = []

    processed = 0
    for batch in ds:
        inputs = jnp.asarray(batch["inputs"])
        targets = jnp.asarray(batch["targets"])
        positions = jnp.asarray(batch["positions"])
        logits, captured = run_with_hooks(inputs, positions, params, cfg, hook_pairs)

        # Everything is moved to host NumPy arrays so it can be padded and
        # concatenated across batches below.
        batch_inputs.append(np.asarray(inputs))
        batch_targets.append(np.asarray(targets))
        batch_positions.append(np.asarray(positions))
        resid_pre_per_batch.append({k: np.asarray(v) for k, v in captured.items()})
        logits_per_batch.append(np.asarray(logits))
        processed += inputs.shape[0]
        if processed >= num_examples:
            break

    # ---- normalize time dim
    # Batches may differ in sequence length; right-pad all of them to the
    # longest so they can be concatenated along the batch axis. Padded
    # positions get -1, which downstream code treats as invalid.
    # NOTE(review): the padded arrays also replace batch_inputs/batch_positions,
    # so the ablation forward passes below run on the padded length.
    Ts = [arr.shape[1] for arr in batch_inputs]
    if len(set(Ts)) > 1:
        T_max = max(Ts)
        pad_id = tokenizer.pad_token_id
        for i in range(len(batch_inputs)):
            batch_inputs[i] = _pad_time_np(batch_inputs[i], T_max, pad_id)
            batch_targets[i] = _pad_time_np(batch_targets[i], T_max, 0)
            batch_positions[i] = _pad_time_np(batch_positions[i], T_max, -1)
            logits_per_batch[i] = _pad_time_np(logits_per_batch[i], T_max, 0.0)
            cap = {}
            for key, arr in resid_pre_per_batch[i].items():
                cap[key] = _pad_time_np(arr, T_max, 0.0)
            resid_pre_per_batch[i] = cap

    # ---- concat
    all_inputs = np.concatenate(batch_inputs, axis=0)
    all_targets = np.concatenate(batch_targets, axis=0)
    all_positions = np.concatenate(batch_positions, axis=0)
    logits_all = np.concatenate(logits_per_batch, axis=0)

    # ---- events
    events_all = extract_events(tokenizer, all_inputs)

    def take_first_n(ev_list, n):
        # None or non-positive n means "no cap".
        return ev_list[:n] if (n is not None and n > 0) else ev_list

    ring_events = take_first_n(events_all.rings, max_events)
    paren_events = take_first_n(events_all.parens, max_events)

    # (batch, pred_idx, open_idx) triples; events without a prediction index
    # are dropped. Note that open_idx is NOT filtered here.
    ring_tuples = [
        (ev.batch, ev.pred_idx, ev.open_idx) for ev in ring_events if ev.pred_idx >= 0
    ]
    paren_tuples = [
        (ev.batch, ev.pred_idx, ev.open_idx) for ev in paren_events if ev.pred_idx >= 0
    ]

    # (batch, pred_idx) positions at which logits are evaluated.
    ring_pred_positions = [
        (ev.batch, ev.pred_idx) for ev in ring_events if ev.pred_idx >= 0
    ]
    paren_pred_positions = [
        (ev.batch, ev.pred_idx) for ev in paren_events if ev.pred_idx >= 0
    ]

    # ---- pointer mass per head per layer
    L = int(cfg["num_layers"])
    H = int(cfg["num_heads"])

    ring_pointer = []
    paren_pointer = []
    for layer in range(L):
        resid_batches = [
            cap[(layer, HookType.RESID_PRE)] for cap in resid_pre_per_batch
        ]
        resid_pre = jnp.asarray(np.concatenate(resid_batches, axis=0))
        attn = params["layers"][layer]["attn"]
        ring_pm = compute_pointer_mass_for_events(
            resid_pre, jnp.asarray(all_positions), attn, cfg, ring_tuples
        )
        paren_pm = compute_pointer_mass_for_events(
            resid_pre, jnp.asarray(all_positions), attn, cfg, paren_tuples
        )
        ring_pointer.append(ring_pm)
        paren_pointer.append(paren_pm)
    # Stack to (L, H) float32 matrices.
    ring_pointer = (
        np.stack(ring_pointer, axis=0) if ring_pointer else np.zeros((0, H))
    ).astype(np.float32, copy=False)
    paren_pointer = (
        np.stack(paren_pointer, axis=0) if paren_pointer else np.zeros((0, H))
    ).astype(np.float32, copy=False)

    # ---- ranking & selection
    rows = []
    for l in range(L):
        for h in range(H):
            rows.append(
                dict(
                    layer=l,
                    head=h,
                    pointer_ring=float(ring_pointer[l, h]),
                    pointer_paren=float(paren_pointer[l, h]),
                )
            )
    rank_df = pd.DataFrame(rows)
    # Rank 1 = highest pointer mass; ties share the smallest rank.
    rank_df["rank_ring"] = rank_df["pointer_ring"].rank(ascending=False, method="min")
    rank_df["rank_paren"] = rank_df["pointer_paren"].rank(ascending=False, method="min")

    top_ring = rank_df.sort_values("pointer_ring", ascending=False).head(topk_report)
    top_paren = rank_df.sort_values("pointer_paren", ascending=False).head(topk_report)

    # ---- ablations: Δlogit, Δmargin, Δacc (event vs control)
    def run_abl(layer: int, head: int):
        p_abl = ablate_heads_in_Wo(
            params, [(layer, head)], cfg["hidden_size"], cfg["num_heads"]
        )
        # Only the ablated model is run here; the baseline is reused from logits_all.
        logits_abl_batches = []
        for inputs, positions in zip(batch_inputs, batch_positions):
            cache = TransformerCache.create(
                jnp.asarray(positions), cfg, dtype=jnp.bfloat16, dynamic=False
            )
            l, *_ = transformer.run(jnp.asarray(inputs), cache, p_abl, cfg)
            logits_abl_batches.append(np.asarray(l))
        logits_abl = np.concatenate(logits_abl_batches, axis=0)

        # controls: the first valid non-event positions in row-major (b, t)
        # order, capped at min(#ring + #paren events, 5000). They are matched
        # to the events in count only, not in position or token type.
        # NOTE(review): this selection does not depend on (layer, head) and is
        # rebuilt identically on every call.
        used = set(ring_pred_positions + paren_pred_positions)
        B, T, V = logits_all.shape
        ctrl = []
        for bb in range(B):
            for tt in range(T):
                if all_positions[bb, tt] >= 0 and (bb, tt) not in used:
                    ctrl.append((bb, tt))
                    if len(ctrl) >= min(
                        len(ring_pred_positions) + len(paren_pred_positions), 5000
                    ):
                        break
            if len(ctrl) >= min(
                len(ring_pred_positions) + len(paren_pred_positions), 5000
            ):
                break

        ring_ev = delta_metrics_at_positions(
            logits_all, logits_abl, all_targets, ring_pred_positions
        )
        parn_ev = delta_metrics_at_positions(
            logits_all, logits_abl, all_targets, paren_pred_positions
        )
        ctrl_ev = (
            delta_metrics_at_positions(logits_all, logits_abl, all_targets, ctrl)
            if ctrl
            else dict(delta_logit=0.0, delta_margin=0.0, delta_acc=0.0)
        )

        return dict(
            layer=layer,
            head=head,
            delta_logit_ring=ring_ev["delta_logit"],
            delta_margin_ring=ring_ev["delta_margin"],
            delta_acc_ring=ring_ev["delta_acc"],
            delta_logit_paren=parn_ev["delta_logit"],
            delta_margin_paren=parn_ev["delta_margin"],
            delta_acc_paren=parn_ev["delta_acc"],
            delta_logit_controls=ctrl_ev["delta_logit"],
            delta_margin_controls=ctrl_ev["delta_margin"],
            delta_acc_controls=ctrl_ev["delta_acc"],
        )

    # Ablate the union of the top ring and top paren heads, each head once.
    ablation_rows = []
    for _, r in (
        pd.concat([top_ring, top_paren]).drop_duplicates(["layer", "head"]).iterrows()
    ):
        ablation_rows.append(run_abl(int(r["layer"]), int(r["head"])))

    ablate_df = pd.DataFrame(ablation_rows)
    # merge & save extended summary
    # Left merge: heads that were not ablated keep their row, with NaN in the
    # delta_* columns.
    summary = rank_df.merge(ablate_df, on=["layer", "head"], how="left")
    out_csv = os.path.join(out_dir, "pointer_suite_summary.csv")
    summary.to_csv(out_csv, index=False)

    # ---- redundancy: top-2 joint ablations per task
    def joint_redundancy(task: str, topdf: pd.DataFrame, k=2):
        # Returns None when fewer than two heads are available; otherwise a
        # dict comparing the joint Δacc against the two individual Δacc values.
        top2 = topdf.sort_values(f"pointer_{task}", ascending=False).head(k)
        heads = [(int(r["layer"]), int(r["head"])) for _, r in top2.iterrows()]
        if len(heads) < 2:
            return None
        p_abl = ablate_heads_in_Wo(params, heads, cfg["hidden_size"], cfg["num_heads"])
        logits_abl_batches = []
        for inputs, positions in zip(batch_inputs, batch_positions):
            cache = TransformerCache.create(
                jnp.asarray(positions), cfg, dtype=jnp.bfloat16, dynamic=False
            )
            l, *_ = transformer.run(jnp.asarray(inputs), cache, p_abl, cfg)
            logits_abl_batches.append(np.asarray(l))
        logits_abl = np.concatenate(logits_abl_batches, axis=0)

        pos_list = ring_pred_positions if task == "ring" else paren_pred_positions
        # individual deltas (from ablate_df)
        # Only the first two heads are reported even if k > 2 heads were ablated.
        l1, h1 = heads[0]
        l2, h2 = heads[1]
        r1 = ablate_df[(ablate_df["layer"] == l1) & (ablate_df["head"] == h1)].iloc[0]
        r2 = ablate_df[(ablate_df["layer"] == l2) & (ablate_df["head"] == h2)].iloc[0]
        if task == "ring":
            d1 = r1["delta_acc_ring"]
            d2 = r2["delta_acc_ring"]
        else:
            d1 = r1["delta_acc_paren"]
            d2 = r2["delta_acc_paren"]

        # joint delta
        # NOTE(review): with an empty pos_list the mean below is taken over an
        # empty array (NaN) — delta_metrics_at_positions guards this case, this
        # code path does not.
        base = _gather_at_positions(logits_all, pos_list)
        abl = _gather_at_positions(logits_abl, pos_list)
        gold = _gold_ids(all_targets, pos_list)
        joint = float(
            (
                (base.argmax(axis=1) == gold).astype(np.float32)
                - (abl.argmax(axis=1) == gold).astype(np.float32)
            ).mean()
        )
        # redundancy_index = joint Δacc / (sum of individual Δacc); the 1e-9
        # only avoids an exact division by zero, so the ratio is very large
        # when both individual deltas are ~0.
        return dict(
            task=task,
            head1=f"L{l1}H{h1}",
            head2=f"L{l2}H{h2}",
            delta_acc_1=float(d1),
            delta_acc_2=float(d2),
            delta_acc_joint=joint,
            redundancy_index=joint / ((d1 + d2) + 1e-9),
        )

    red_rows = []
    jr = joint_redundancy("ring", top_ring, k=2)
    if jr:
        red_rows.append(jr)
    jp = joint_redundancy("paren", top_paren, k=2)
    if jp:
        red_rows.append(jp)
    red_df = pd.DataFrame(red_rows)
    red_csv = os.path.join(out_dir, "redundancy_table.csv")
    red_df.to_csv(red_csv, index=False)

    # ---- distance / depth robustness
    # ring span = pred_idx - open_idx
    # NOTE(review): this list is filtered on BOTH pred_idx >= 0 and
    # open_idx >= 0, whereas ring_tuples filters on pred_idx only and
    # ring_events is unfiltered. The bin indices derived from ring_spans are
    # later used to index ring_tuples and ring_events, so the three only line
    # up when no ring event was filtered out.
    ring_spans = np.array(
        [
            ev.pred_idx - ev.open_idx
            for ev in ring_events
            if ev.pred_idx >= 0 and ev.open_idx >= 0
        ]
    )
    # paren depth: compute from tokens ( '(' increments, ')' decrements )
    # Depth is counted over tokens 0..pred_idx inclusive and never goes below 0.
    tok_lparen = tokenizer.token_to_id("(")
    tok_rparen = tokenizer.token_to_id(")")
    depths = []
    for ev in paren_events:
        if ev.pred_idx < 0:
            continue
        seq = all_inputs[ev.batch]
        depth = 0
        for t in range(ev.pred_idx + 1):
            tok = seq[t]
            if tok == tok_lparen:
                depth += 1
            elif tok == tok_rparen:
                depth = max(depth - 1, 0)
        depths.append(depth)
    paren_depths = np.array(depths, dtype=np.int32)

    def bin_edges_quantiles(vals, q=(0, 0.2, 0.4, 0.6, 0.8, 1.0)):
        # Integer, strictly increasing bin edges taken from the given
        # quantiles; duplicate edges (after int truncation) are dropped and
        # the last edge is raised to the maximum value. Empty input -> [0, 1].
        vs = np.asarray(vals)
        if vs.size == 0:
            return [0, 1]
        qs = np.quantile(vs, q)
        edges = [int(qs[0])]
        for v in qs[1:]:
            if int(v) > edges[-1]:
                edges.append(int(v))
        if edges[-1] < vs.max():
            edges[-1] = int(vs.max())
        return edges

    ring_bins = bin_edges_quantiles(ring_spans)  # monotone
    paren_bins = bin_edges_quantiles(paren_depths)

    def bin_indices(vals, edges):
        # One index array per bin; bins are [lo, hi) except the last, which is
        # closed on the right so the maximum value is included.
        # NOTE(review): if all values are equal, `edges` has a single entry and
        # no bins are produced.
        idx = []
        for i in range(len(edges) - 1):
            lo, hi = edges[i], edges[i + 1]
            mask = (vals >= lo) & (vals <= hi if i == len(edges) - 2 else vals < hi)
            idx.append(np.where(mask)[0])
        return idx

    ring_idx_bins = bin_indices(ring_spans, ring_bins)
    paren_idx_bins = bin_indices(paren_depths, paren_bins)

    # choose the prime heads to profile
    # (the single highest-pointer-mass head for each task)
    ring_head = (int(top_ring.iloc[0]["layer"]), int(top_ring.iloc[0]["head"]))
    parn_head = (int(top_paren.iloc[0]["layer"]), int(top_paren.iloc[0]["head"]))

    def pointer_mass_for_head_in_bin(task: str, head: Tuple[int, int], bin_indices_arr):
        # Pointer mass of `head` restricted to the events of each bin.
        # NOTE(review): the residual concatenation below does not depend on the
        # bin and is repeated for every bin.
        l, h = head
        vals = []
        for bin_ids in bin_indices_arr:
            if task == "ring":
                evs = [ring_tuples[i] for i in bin_ids.tolist()]
            else:
                evs = [paren_tuples[i] for i in bin_ids.tolist()]
            resid = jnp.asarray(
                np.concatenate(
                    [cap[(l, HookType.RESID_PRE)] for cap in resid_pre_per_batch],
                    axis=0,
                )
            )
            pm = compute_pointer_mass_for_events(
                resid, jnp.asarray(all_positions), params["layers"][l]["attn"], cfg, evs
            )
            vals.append(float(pm[h]) if pm.size > 0 else 0.0)
        return vals

    # Δmargin vs span/depth for ablation of the prime head
    def delta_margin_curve(task: str, head: Tuple[int, int], bin_indices_arr):
        l, h = head
        p_abl = ablate_heads_in_Wo(params, [head], cfg["hidden_size"], cfg["num_heads"])
        logits_abl_batches = []
        for inputs, positions in zip(batch_inputs, batch_positions):
            cache = TransformerCache.create(
                jnp.asarray(positions), cfg, dtype=jnp.bfloat16, dynamic=False
            )
            lgt, *_ = transformer.run(jnp.asarray(inputs), cache, p_abl, cfg)
            logits_abl_batches.append(np.asarray(lgt))
        logits_abl = np.concatenate(logits_abl_batches, axis=0)
        out = []
        for bin_ids in bin_indices_arr:
            pos_list = []
            # NOTE(review): bin ids index the filtered span/depth arrays but
            # are applied to the unfiltered ring_events / paren_events here;
            # an event with pred_idx < 0 would also contribute a negative
            # time index. Confirm the extractor never emits such events.
            for i in bin_ids.tolist():
                if task == "ring":
                    ev = ring_events[i]
                    pos_list.append((ev.batch, ev.pred_idx))
                else:
                    ev = paren_events[i]
                    pos_list.append((ev.batch, ev.pred_idx))
            if not pos_list:
                out.append(0.0)
                continue
            dm = delta_metrics_at_positions(
                logits_all, logits_abl, all_targets, pos_list
            )["delta_margin"]
            out.append(float(dm))
        return out

    ring_pm_curve = pointer_mass_for_head_in_bin("ring", ring_head, ring_idx_bins)
    ring_dm_curve = delta_margin_curve("ring", ring_head, ring_idx_bins)
    parn_pm_curve = pointer_mass_for_head_in_bin("paren", parn_head, paren_idx_bins)
    parn_dm_curve = delta_margin_curve("paren", parn_head, paren_idx_bins)

    curves = dict(
        ring_bins=ring_bins,
        ring_pointer_mass=ring_pm_curve,
        ring_delta_margin=ring_dm_curve,
        paren_bins=paren_bins,
        paren_pointer_mass=parn_pm_curve,
        paren_delta_margin=parn_dm_curve,
        ring_head=f"L{ring_head[0]}H{ring_head[1]}",
        paren_head=f"L{parn_head[0]}H{parn_head[1]}",
    )
    with open(os.path.join(out_dir, "robustness_curves.json"), "w") as f:
        json.dump(curves, f, indent=2, default=_json_default)

    # ---- value-stream probe (part (b))
    def value_probe(
        task: str, head: Tuple[int, int], use_post_ov: bool, random_state=0
    ):
        # Fit a logistic-regression classifier on the head's output vectors
        # (pre- or post-OV) and report held-out accuracy. Returns acc=0.0
        # without fitting when fewer than 100 samples are available.
        l, h = head
        resid = jnp.asarray(
            np.concatenate(
                [cap[(l, HookType.RESID_PRE)] for cap in resid_pre_per_batch], axis=0
            )
        )
        attn = params["layers"][l]["attn"]
        if task == "ring":
            evs = ring_events
            tuples = ring_tuples
            # class = opener digit (token)
            # opener token id at open_idx
            # NOTE(review): labels skip events with open_idx < 0, but
            # ring_tuples (and collect_head_outputs_for_events) do not, so
            # labels and outputs can fall out of step if such events exist.
            labels = []
            for ev in evs:
                if ev.pred_idx < 0 or ev.open_idx < 0:
                    continue
                lab_tok = all_inputs[ev.batch, ev.open_idx]
                labels.append(int(lab_tok))
        else:
            evs = paren_events
            tuples = paren_tuples
            # class = paren depth bucket at close step
            labels = []
            for depth in paren_depths:
                labels.append(int(depth))
            # bucketize later

        outputs = collect_head_outputs_for_events(
            resid, jnp.asarray(all_positions), attn, cfg, tuples, h, post_ov=use_post_ov
        )
        X = (
            np.stack(outputs, axis=0).astype(np.float32, copy=False)
            if outputs
            else np.zeros((0, 1), dtype=np.float32)
        )

        # NOTE(review): labels[: len(outputs)] pairs labels with outputs by
        # position only; this is correct only if the collector skipped no
        # event (or only trailing ones).
        if task == "paren":
            # bucketize depth into three classes: depth <= 2, 3-4, >= 5
            raw = np.array(labels[: len(outputs)], dtype=np.int32)
            y = np.where(raw <= 2, 0, np.where(raw <= 4, 1, 2))
        else:
            # ring: keep only events whose opener token is one of the digit
            # tokens "0".."9"; all other openers are filtered out
            tok2id = {str(d): tokenizer.token_to_id(str(d)) for d in range(10)}
            valid = np.array([lab in tok2id.values() for lab in labels[: len(outputs)]])
            X = X[valid]
            lab = np.array(labels[: len(outputs)])[valid]
            # map token ids back to their digit (stored as a string key,
            # converted to int32 by the array constructor)
            id2class = {v: k for k, v in tok2id.items()}
            y = np.array([id2class[int(z)] for z in lab], dtype=np.int32)

        if X.shape[0] < 100:
            return dict(n=X.shape[0], acc=0.0, post_ov=use_post_ov)

        # 70/30 stratified split.
        # NOTE(review): stratify=y raises if any class has fewer than 2 samples.
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.3, random_state=random_state, stratify=y
        )
        clf = LogisticRegression(max_iter=1000)
        clf.fit(X_train, y_train)
        acc = accuracy_score(y_test, clf.predict(X_test))
        return dict(n=int(X.shape[0]), acc=float(acc), post_ov=use_post_ov)

    probe_ring_pre = value_probe("ring", ring_head, use_post_ov=False)
    probe_ring_post = value_probe("ring", ring_head, use_post_ov=True)
    probe_parn_pre = value_probe("paren", parn_head, use_post_ov=False)
    probe_parn_post = value_probe("paren", parn_head, use_post_ov=True)

    with open(os.path.join(out_dir, "value_probe.json"), "w") as f:
        json.dump(
            dict(
                ring_head=f"L{ring_head[0]}H{ring_head[1]}",
                paren_head=f"L{parn_head[0]}H{parn_head[1]}",
                ring_pre=probe_ring_pre,
                ring_post=probe_ring_post,
                paren_pre=probe_parn_pre,
                paren_post=probe_parn_post,
            ),
            f,
            indent=2,
            default=_json_default,
        )

    # ---- pointer vs. global scatter (Δmargin_event vs Δmargin_control; size=pointer mass)
    def mk_scatter_csv():
        # NOTE(review): the 0.0 default of r.get() applies only when the
        # column is absent; heads that were not ablated have NaN in these
        # columns after the left merge and are written out as NaN.
        rows = []
        for _, r in summary.iterrows():
            rows.append(
                dict(
                    layer=int(r["layer"]),
                    head=int(r["head"]),
                    pointer_ring=float(r["pointer_ring"]),
                    pointer_paren=float(r["pointer_paren"]),
                    dmargin_ring=float(r.get("delta_margin_ring", 0.0)),
                    dmargin_paren=float(r.get("delta_margin_paren", 0.0)),
                    dmargin_ctrl=float(r.get("delta_margin_controls", 0.0)),
                )
            )
        df = pd.DataFrame(rows)
        df.to_csv(os.path.join(out_dir, "pointer_vs_global_scatter.csv"), index=False)

    mk_scatter_csv()

    # ---- cross-checkpoint stability (optional)
    if compare_ckpt is not None:
        params2, *_ = train_utils.load_checkpoint(
            f"{model_dir}/checkpoints/checkpoint_{compare_ckpt}.pkl"
        )
        # recompute pointer mass in ckpt2 (reuse same batches & events)
        # NOTE(review): the residual-stream inputs are the ones captured from
        # checkpoint `ckpt_id`; only the attention weights come from
        # `compare_ckpt`. No forward pass is run with params2 — confirm this
        # is the intended comparison.
        ring_pointer2 = []
        paren_pointer2 = []
        for layer in range(L):
            resid = jnp.asarray(
                np.concatenate(
                    [cap[(layer, HookType.RESID_PRE)] for cap in resid_pre_per_batch],
                    axis=0,
                )
            )
            attn2 = params2["layers"][layer]["attn"]
            ring_pm = compute_pointer_mass_for_events(
                resid, jnp.asarray(all_positions), attn2, cfg, ring_tuples
            )
            paren_pm = compute_pointer_mass_for_events(
                resid, jnp.asarray(all_positions), attn2, cfg, paren_tuples
            )
            ring_pointer2.append(ring_pm)
            paren_pointer2.append(paren_pm)
        ring_pointer2 = (np.stack(ring_pointer2, axis=0) if ring_pointer2 else np.zeros((0, H))).astype(np.float32, copy=False)
        paren_pointer2 = (np.stack(paren_pointer2, axis=0) if paren_pointer2 else np.zeros((0, H))).astype(np.float32, copy=False)

        # Spearman across all L*H entries
        rp1 = np.asarray(ring_pointer,  dtype=np.float32).ravel()
        rp2 = np.asarray(ring_pointer2, dtype=np.float32).ravel()
        pp1 = np.asarray(paren_pointer,  dtype=np.float32).ravel()
        pp2 = np.asarray(paren_pointer2, dtype=np.float32).ravel()
        r_rho,_ = spearmanr(rp1, rp2)
        p_rho,_ = spearmanr(pp1, pp2)

        # Jaccard@k (5,10)
        # Overlap of the top-k flattened (layer, head) indices of both checkpoints.
        def jaccard_at_k(pm1, pm2, k):
            idx1 = set(np.argsort(pm1.flatten())[::-1][:k])
            idx2 = set(np.argsort(pm2.flatten())[::-1][:k])
            inter = len(idx1 & idx2)
            union = len(idx1 | idx2)
            return inter / max(union, 1)

        stab = dict(
            ckpt_a=int(ckpt_id),
            ckpt_b=int(compare_ckpt),
            spearman_ring=float(r_rho),
            spearman_paren=float(p_rho),
            jaccard5_ring=float(jaccard_at_k(ring_pointer, ring_pointer2, 5)),
            jaccard10_ring=float(jaccard_at_k(ring_pointer, ring_pointer2, 10)),
            jaccard5_paren=float(jaccard_at_k(paren_pointer, paren_pointer2, 5)),
            jaccard10_paren=float(jaccard_at_k(paren_pointer, paren_pointer2, 10)),
        )
        with open(os.path.join(out_dir, "stability_summary.json"), "w") as f:
            json.dump(stab, f, indent=2, default=_json_default)

    print(
        f"[pointer_suite] Wrote:\n  - {out_csv}\n  - {red_csv}\n  - robustness_curves.json\n  - value_probe.json\n  - pointer_vs_global_scatter.csv"
    )
    if compare_ckpt is not None:
        print("  - stability_summary.json")


def main():
    """Command-line entry point: parse the flags and hand them to ``run_pointer_suite``.

    Every flag name matches a keyword parameter of ``run_pointer_suite``, so the
    parsed namespace is forwarded as keyword arguments unchanged.
    """
    parser = argparse.ArgumentParser(
        description="Pointer head suite: Δmargin, probes, robustness, redundancy, stability."
    )
    # (flag, add_argument keyword options), kept in the order shown by --help.
    options = (
        ("--model_dir", dict(required=True)),
        ("--ckpt_id", dict(type=int, required=True)),
        ("--dataset_dir", dict(required=True)),
        ("--split", dict(default="train")),
        ("--num_examples", dict(type=int, default=4096)),
        ("--batch_size", dict(type=int, default=512)),
        ("--out_dir", dict(default="experiments/pointer_suite_outputs")),
        ("--max_events", dict(type=int, default=2000)),
        ("--topk_report", dict(type=int, default=5)),
        ("--seq_length", dict(type=int, default=256)),
        (
            "--compare_ckpt",
            dict(
                type=int,
                default=None,
                help="Optional: second checkpoint id for stability tests",
            ),
        ),
    )
    for flag, spec in options:
        parser.add_argument(flag, **spec)

    run_pointer_suite(**vars(parser.parse_args()))


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
