# lmkit/experiments/pointer_must_adds.py
from __future__ import annotations

import argparse
import math
import os
from dataclasses import dataclass
from typing import Dict, Iterable, List, Sequence, Tuple

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd

from lmkit.impl import config as config_lib
from lmkit.impl import transformer
from lmkit.impl.caching import TransformerCache, build_rope
from lmkit.impl.hooks import HookRequest, HookType
from lmkit.impl.hooks import capture as capture_hooks, unpack_captured
from lmkit.tools import compat, train_utils
from lmkit.tools import data as data_tools

# Reuse the event extraction shared with the pointer_suite / pointer_heads experiments.
from ..smiles_events import extract_events


# ------------------------------
# utilities: params + hooks
# ------------------------------
def _clone_params(params):
    """
    Return a shallow copy of ``params`` whose ``"layers"`` entry is a new list.

    Only the top-level mapping and the ``layers`` list are duplicated; every
    other value (including each per-layer entry) is shared with the input, so
    callers must copy a layer themselves before editing it.
    """
    return dict(params, layers=list(params["layers"]))


def ablate_heads_in_Wo(
    params,
    layer_id: int,
    head_indices: Sequence[int],
    hidden_size: int,
    num_heads: int,
):
    """
    Return a PARAM COPY with the chosen heads' rows zeroed in W_o
    (clean, pre-residual head ablation). Works with JAX arrays.

    Head ``h`` is taken to own rows ``[h * head_dim, (h + 1) * head_dim)`` of
    ``params["layers"][layer_id]["attn"]["W_o"]`` with
    ``head_dim = hidden_size // num_heads``.

    Args:
      params: mapping with a ``"layers"`` sequence; each layer holds an
        ``"attn"`` mapping containing ``"W_o"``.
      layer_id: index of the layer to edit.
      head_indices: heads to ablate; each must satisfy ``0 <= h < num_heads``.
      hidden_size: model hidden size.
      num_heads: number of query heads.

    Returns:
      A new params mapping. Only the containers on the path to the edited
      ``W_o`` are copied; all other entries are shared with ``params``, which
      is left untouched.

    Raises:
      ValueError: if a head index is out of range. Previously such an index
        produced an empty or shifted slice, so the "ablation" could silently
        be a no-op or hit a different head.
    """
    num_heads = int(num_heads)
    head_dim = int(hidden_size) // num_heads

    heads = [int(h) for h in head_indices]
    bad = [h for h in heads if not 0 <= h < num_heads]
    if bad:
        raise ValueError(
            f"head indices {bad} out of range for num_heads={num_heads}"
        )

    # Copy only the containers we are about to mutate.
    new_params = dict(params)
    new_params["layers"] = list(params["layers"])
    lyr = dict(new_params["layers"][layer_id])
    attn = dict(lyr["attn"])

    W_o = jnp.asarray(attn["W_o"])
    for h in heads:
        row_start = h * head_dim
        # Functional update: returns a new array, the original W_o is unchanged.
        W_o = W_o.at[row_start : row_start + head_dim, :].set(0)

    attn["W_o"] = W_o
    lyr["attn"] = attn
    new_params["layers"][layer_id] = lyr
    return new_params


def run_with_hooks(inputs, positions, params, config, hook_pairs):
    """
    Run one forward pass and return the logits plus the requested captures.

    ``hook_pairs`` is an iterable of ``(layer, hook_kind)`` pairs; each is
    turned into a ``HookRequest``. A fresh non-dynamic bfloat16
    ``TransformerCache`` is built from ``positions`` for the call. Nothing is
    streamed and no editor is installed.

    Returns:
      ``(logits, captured)`` where ``captured`` is the result of
      ``unpack_captured`` for the requested hooks.
    """
    requests = [HookRequest(layer, kind) for layer, kind in hook_pairs]
    hooks_to_return, _ = capture_hooks(*requests)

    fresh_cache = TransformerCache.create(
        positions=positions,
        model_config=config,
        dtype=jnp.bfloat16,
        dynamic=False,
    )

    outputs = transformer.run(
        inputs,
        fresh_cache,
        params,
        config,
        hooks_to_return=hooks_to_return,
        hooks_to_stream=frozenset(),
        editor=None,
    )
    logits, _unused_cache, raw_captures = outputs
    return logits, unpack_captured(hooks_to_return, raw_captures)


def _pad_time_np(a: np.ndarray, T_target: int, pad_value) -> np.ndarray:
    """
    Right-pad ``a`` along axis 1 with ``pad_value`` until it has ``T_target`` steps.

    The input is returned as-is (same object) when it already has the target
    length; otherwise a new array of the same dtype is produced.
    """
    current_len = a.shape[1]
    if current_len == T_target:
        return a

    filler_shape = (a.shape[0], T_target - current_len, *a.shape[2:])
    filler = np.full(filler_shape, pad_value, dtype=a.dtype)
    return np.concatenate((a, filler), axis=1)


# ------------------------------
# attention computation (per-event, per-head)
# ------------------------------
def compute_pointer_mass_for_events(
    resid_pre_layer: jnp.ndarray,  # (B,T,HIDDEN)
    positions: jnp.ndarray,  # (B,T)  -1 for pad
    params_layer_attn: Dict,  # dict with W_q/W_k/W_v/W_o
    config,  # FrozenDict
    events: List[Tuple[int, int, int]],  # list of (b, pred_idx, opener_idx)
) -> np.ndarray:
    """
    Return pointer mass per head: shape (num_heads,) averaged across provided events.
    Robust to GQA (num_kv_heads < num_heads): each K head is repeated across
    its group of query heads.

    For every event ``(b, pred_idx, opener_idx)`` the attention distribution of
    the query at ``pred_idx`` over keys ``0..pred_idx`` (clipped to the number
    of non-pad positions in row ``b``) is recomputed from ``resid_pre_layer``
    with ``W_q``/``W_k`` and RoPE, and the weight placed on ``opener_idx`` is
    recorded per head.

    Events are skipped when any index is negative, ``b`` or ``pred_idx`` falls
    outside the array, or the opener is not inside the visible key window.

    Returns:
      float32 array of shape (num_heads,); all zeros when no event is usable.

    Raises:
      ValueError: if ``num_heads`` is not a multiple of ``num_kv_heads``.
    """
    hidden = int(config["hidden_size"])
    num_heads = int(config["num_heads"])
    num_kv_heads = int(config.get("num_kv_heads", num_heads))
    head_dim = hidden // num_heads

    B, T, _ = resid_pre_layer.shape

    # build RoPE
    sin, cos = build_rope(positions, head_dim, config["rope_base"])  # (B,T,head_dim)

    # project once, then split into heads
    Q = resid_pre_layer @ params_layer_attn["W_q"]  # (B,T, H*D)
    K = resid_pre_layer @ params_layer_attn["W_k"]  # (B,T, H_kv*D)
    Q = jnp.reshape(Q, (B, T, num_heads, head_dim))
    K = jnp.reshape(K, (B, T, num_kv_heads, head_dim))

    Qr = transformer.rope(Q, sin, cos)  # (B,T,H,D)
    Kr = transformer.rope(K, sin, cos)  # (B,T,H_kv,D)

    # If GQA, repeat each K head so that it lines up with its query heads
    if num_kv_heads != num_heads:
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be a multiple of num_kv_heads")
        Kr = jnp.repeat(Kr, repeats=num_heads // num_kv_heads, axis=2)  # (B,T,H,D)

    inv_sqrt_d = 1.0 / math.sqrt(head_dim)

    # Count non-pad positions per row once on the host, instead of launching a
    # device reduction (and a blocking transfer) for every single event.
    # NOTE(review): using the count as a key-window length assumes pads are on
    # the right — confirm against the data pipeline.
    valid_len = np.sum(np.asarray(positions) >= 0, axis=1)  # (B,)

    per_head_scores = []
    for b, i_pred, j_open in events:
        if b < 0 or b >= B or i_pred < 0 or j_open < 0:
            continue
        # JAX clamps out-of-bounds indices rather than raising, so an
        # out-of-range prediction index would silently read the last row.
        if i_pred >= T:
            continue
        kv_len = min(int(valid_len[b]), i_pred + 1)
        if kv_len <= 0 or j_open >= kv_len:
            continue

        q = Qr[b, i_pred]  # (H,D)
        k = Kr[b, :kv_len]  # (kv_len,H,D)
        logits = jnp.einsum("hd,thd->ht", q, k) * inv_sqrt_d
        weights = jax.nn.softmax(logits, axis=-1)  # (H, kv_len)
        per_head_scores.append(weights[:, j_open])  # (H,)

    if not per_head_scores:
        return np.zeros((num_heads,), dtype=np.float32)
    stacked = jnp.stack(per_head_scores, axis=0)  # (N_events, H)
    return np.array(jnp.mean(stacked, axis=0), dtype=np.float32)  # (H,)


# ------------------------------
# metrics helpers (margin, acc)
# ------------------------------
def gold_and_margin(
    logits: np.ndarray, targets: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Safe computation of gold logit and margin that tolerates out-of-range targets
    at padded positions. Target indices are clipped into [0, V-1] before the
    gather, so callers should only aggregate positions whose targets are real.

    Args:
      logits:  (B,T,V) array. Non-floating inputs are promoted to float32.
      targets: (B,T) integer array of gold token ids.

    Returns:
      gold_logit: (B,T) float32
      margin:     (B,T) float32  where margin = gold - max_{v != gold} logit

    Raises:
      ValueError: if ``logits`` is not 3-D or ``targets`` is not 2-D.
    """
    logits = np.asarray(logits)
    targets = np.asarray(targets)
    if logits.ndim != 3 or targets.ndim != 2:
        raise ValueError("shapes: logits (B,T,V), targets (B,T)")
    V = logits.shape[-1]

    # -inf cannot be stored in an integer array, so promote those inputs.
    if not np.issubdtype(logits.dtype, np.floating):
        logits = logits.astype(np.float32)

    # Clipped indices with a trailing axis, as the *_along_axis helpers expect.
    gold_idx = np.clip(targets, 0, V - 1)[..., None]  # (B,T,1)
    gold = np.take_along_axis(logits, gold_idx, axis=-1).squeeze(-1)  # (B,T)

    # Work on a copy (the input may be read-only or reused by the caller),
    # knock the gold entry out with -inf, and take the max over what is left.
    # That is max_{v != gold} whether or not gold was the argmax, so no
    # separate argmax / where pass is needed.
    masked = logits.copy()
    np.put_along_axis(masked, gold_idx, -np.inf, axis=-1)
    max_other = masked.max(axis=-1)  # (B,T)

    margin = gold - max_other
    return gold.astype(np.float32, copy=False), margin.astype(np.float32, copy=False)


def accuracy_at_positions(
    logits: np.ndarray, targets: np.ndarray, pos_list: List[Tuple[int, int]]
) -> float:
    """
    Top-1 accuracy restricted to the given ``(batch, time)`` positions.

    Returns the fraction of listed positions where the argmax over the last
    axis of ``logits`` equals ``targets``; 0.0 when ``pos_list`` is empty.
    """
    if not pos_list:
        return 0.0

    coords = np.array([(p[0], p[1]) for p in pos_list], dtype=np.int32)  # (N,2)
    batch_idx, time_idx = coords.T

    predicted = np.argmax(logits[batch_idx, time_idx], axis=-1)
    hits = predicted == targets[batch_idx, time_idx]
    return float(np.mean(hits.astype(np.float32)))


def mean_over_positions(mat: np.ndarray, pos_list: List[Tuple[int, int]]) -> float:
    """
    Average ``mat[b, t]`` over the listed ``(b, t)`` positions.

    Returns 0.0 when ``pos_list`` is empty.
    """
    if not pos_list:
        return 0.0

    coords = np.array([(p[0], p[1]) for p in pos_list], dtype=np.int32)  # (N,2)
    batch_idx, time_idx = coords.T
    picked = mat[batch_idx, time_idx]
    return float(picked.mean())


# ------------------------------
# main driver
# ------------------------------
# frozen=True makes instances immutable and hashable. A plain @dataclass keeps
# the generated __eq__ but sets __hash__ to None, so a HeadKey could not be
# used as a dict key or set member.
@dataclass(frozen=True)
class HeadKey:
    """Identifies one attention head by its layer index and head index."""

    layer: int  # index of the transformer layer
    head: int  # index of the head within that layer


def run_must_adds(
    model_dir: str,
    ckpt_id: int,
    dataset_dir: str,
    split: str,
    num_examples: int,
    batch_size: int,
    out_dir: str,
    max_events: int = 2000,
    topk_report: int = 5,
    seq_length: int = 256,
):
    """
    Run the event-specificity scatter and the two-head additivity analysis.

    Pipeline:
      1. Load tokenizer, config and checkpoint from ``model_dir``.
      2. Run the model teacher-forced over the dataset, capturing RESID_PRE at
         every layer together with the logits.
      3. Extract ring / paren events and sample an equal number of control
         positions (valid, non-event) with a fixed seed.
      4. Compute per-head pointer mass per layer and rank heads per task.
      5. For each of the top ``topk_report`` heads per task, ablate it in W_o
         and record margin / accuracy / gold-logit deltas at event and control
         positions -> ``pointer_vs_global_scatter.csv``.
      6. For the top-2 heads per task, compare single vs joint ablation
         -> ``redundancy_table.csv``.

    All deltas are ``baseline - ablated``. Control accuracy deltas are
    measured against the baseline accuracy at the *control* positions.

    Args:
      model_dir: directory with tokenizer/config files and ``checkpoints/``.
      ckpt_id: checkpoint number to load.
      dataset_dir: dataset location passed to the loader.
      split: accepted for CLI compatibility; see the NOTE below.
      num_examples: dataset limit and the point at which batch iteration stops.
      batch_size: loader batch size.
      out_dir: output directory (created if missing).
      max_events: cap on events per task (``None`` keeps all).
      topk_report: number of top heads per task to ablate and print.
      seq_length: tokenizer truncation length.

    Raises:
      ValueError: if the dataset yields no batches, or the model has fewer
        than two heads in total (additivity needs a pair).
    """
    os.makedirs(out_dir, exist_ok=True)

    # NOTE(review): `split` is never forwarded to the dataset loader below.
    # Confirm whether load_and_tokenize should receive it.

    # --- tokenizer & model config
    tokenizer = compat.load_tokenizer(
        tokenizer_path=f"{model_dir}/tokenizer.json",
        generation_config_file=f"{model_dir}/generation_config.json",
        mode="train",
        trunc_length=seq_length,
    )
    cfg = config_lib.load_from_dir(model_dir).copy(
        dict(
            bos_id=tokenizer.bos_token_id,
            eos_id=tokenizer.eos_token_id,
            pad_id=tokenizer.pad_token_id,
        )
    )
    params, *_ = train_utils.load_checkpoint(
        f"{model_dir}/checkpoints/checkpoint_{ckpt_id}.pkl"
    )

    # --- dataset (tokenized)
    ds = data_tools.load_and_tokenize(
        dataset_dir=dataset_dir,
        tokenizer=tokenizer,
        batch_size=batch_size,
        num_processes=8,
        seed=2002,
        target_column="smiles",
        caching=True,
        limit=num_examples,
    )

    # --- capture RESID_PRE and logits (teacher-forced)
    hook_pairs = [(i, HookType.RESID_PRE) for i in range(cfg["num_layers"])]
    batch_inputs, batch_targets, batch_positions = [], [], []
    resid_pre_per_batch: List[Dict[Tuple[int, HookType], np.ndarray]] = []
    logits_per_batch: List[np.ndarray] = []
    processed_examples = 0

    for batch in ds:
        inputs = jnp.asarray(batch["inputs"], dtype=jnp.int32)
        targets = jnp.asarray(batch["targets"], dtype=jnp.int32)
        positions = jnp.asarray(batch["positions"], dtype=jnp.int32)
        logits, captured = run_with_hooks(inputs, positions, params, cfg, hook_pairs)

        batch_inputs.append(np.asarray(inputs))
        batch_targets.append(np.asarray(targets))
        batch_positions.append(np.asarray(positions))

        resid_pre_per_batch.append({k: np.asarray(v) for k, v in captured.items()})
        logits_per_batch.append(np.asarray(logits, dtype=np.float32))

        processed_examples += inputs.shape[0]
        if processed_examples >= num_examples:
            break

    if not batch_inputs:
        raise ValueError("dataset produced no batches; nothing to analyse")

    # --- normalize time dimension across batches to a common T
    T_max = max(arr.shape[1] for arr in batch_inputs)
    pad_id = tokenizer.pad_token_id

    for i, arr in enumerate(batch_inputs):
        if arr.shape[1] == T_max:
            continue
        batch_inputs[i] = _pad_time_np(arr, T_max, pad_id)
        batch_targets[i] = _pad_time_np(batch_targets[i], T_max, 0)
        batch_positions[i] = _pad_time_np(batch_positions[i], T_max, -1)
        logits_per_batch[i] = _pad_time_np(logits_per_batch[i], T_max, 0.0)
        resid_pre_per_batch[i] = {
            key: _pad_time_np(act, T_max, 0.0)
            for key, act in resid_pre_per_batch[i].items()
        }

    # --- concatenate across batches (B, T, ...)
    all_inputs = np.concatenate(batch_inputs, axis=0).astype(np.int32, copy=False)
    all_targets = np.concatenate(batch_targets, axis=0).astype(np.int32, copy=False)
    all_positions = np.concatenate(batch_positions, axis=0).astype(np.int32, copy=False)
    logits_all = np.concatenate(logits_per_batch, axis=0).astype(np.float32, copy=False)

    # --- event extraction
    events_all = extract_events(tokenizer, all_inputs)

    def take_first_n(ev_list, n):
        return ev_list[:n] if n is not None else ev_list

    ring_events = take_first_n(events_all.rings, max_events)
    paren_events = take_first_n(events_all.parens, max_events)

    # (batch, pred_idx, open_idx) per usable event; built once and reused for
    # every layer below.
    ring_tuples = [
        (ev.batch, ev.pred_idx, ev.open_idx) for ev in ring_events if ev.pred_idx >= 0
    ]
    paren_tuples = [
        (ev.batch, ev.pred_idx, ev.open_idx) for ev in paren_events if ev.pred_idx >= 0
    ]
    ring_pred_positions = [(b, i) for b, i, _ in ring_tuples]
    paren_pred_positions = [(b, i) for b, i, _ in paren_tuples]

    # --- control sampling (exclude all event positions)
    B, T = all_positions.shape
    ctrl_mask = all_positions >= 0
    for bb, tt in set(ring_pred_positions + paren_pred_positions):
        if 0 <= bb < B and 0 <= tt < T:
            ctrl_mask[bb, tt] = False
    # argwhere walks the mask in row-major order, i.e. the same candidate order
    # as a nested (batch, time) loop, so seeded draws stay reproducible.
    ctrl_candidates = [(int(bb), int(tt)) for bb, tt in np.argwhere(ctrl_mask)]

    rng = np.random.RandomState(42)

    def sample_controls(n_wanted: int) -> List[Tuple[int, int]]:
        # With no candidates there is nothing to draw from.
        if not ctrl_candidates:
            return []
        n = min(n_wanted, len(ctrl_candidates))
        picks = rng.choice(len(ctrl_candidates), size=n, replace=False)
        return [ctrl_candidates[i] for i in picks]

    # Order matters: ring first, then paren, so both draws are deterministic.
    ctrl_ring_positions = sample_controls(len(ring_pred_positions))
    ctrl_paren_positions = sample_controls(len(paren_pred_positions))

    # --- compute gold & margin once for baseline
    gold_base, margin_base = gold_and_margin(logits_all, all_targets)

    # --- compute pointer mass per head (layer by layer)
    L = int(cfg["num_layers"])
    H = int(cfg["num_heads"])

    positions_dev = jnp.asarray(all_positions)
    ring_pointer = []
    paren_pointer = []

    for layer in range(L):
        resid_pre = jnp.asarray(
            np.concatenate(
                [cap[(layer, HookType.RESID_PRE)] for cap in resid_pre_per_batch],
                axis=0,
            )
        )  # (B,T,HIDDEN)
        attn_params = params["layers"][layer]["attn"]

        ring_pointer.append(
            compute_pointer_mass_for_events(
                resid_pre, positions_dev, attn_params, cfg, ring_tuples
            )
        )
        paren_pointer.append(
            compute_pointer_mass_for_events(
                resid_pre, positions_dev, attn_params, cfg, paren_tuples
            )
        )

    ring_pointer = (
        np.stack(ring_pointer, axis=0)
        if ring_pointer
        else np.zeros((0, H), dtype=np.float32)
    ).astype(np.float32)
    paren_pointer = (
        np.stack(paren_pointer, axis=0)
        if paren_pointer
        else np.zeros((0, H), dtype=np.float32)
    ).astype(np.float32)

    # --- rank heads and choose top-k per task
    entries = []
    for l in range(L):
        for h in range(H):
            pr = float(ring_pointer[l, h]) if ring_pointer.size else 0.0
            pp = float(paren_pointer[l, h]) if paren_pointer.size else 0.0
            entries.append(dict(layer=l, head=h, pointer_ring=pr, pointer_paren=pp))
    rank_df = pd.DataFrame(entries)

    top_ring = rank_df.sort_values("pointer_ring", ascending=False).head(topk_report)
    top_paren = rank_df.sort_values("pointer_paren", ascending=False).head(topk_report)
    unique_heads = pd.concat([top_ring, top_paren]).drop_duplicates(["layer", "head"])

    # --- helper: forward pass (no hooks) for a given param set
    def forward_logits_for_all(params_now) -> np.ndarray:
        out = []
        for inputs, positions in zip(batch_inputs, batch_positions):
            cache = TransformerCache.create(
                jnp.asarray(positions),
                model_config=cfg,
                dtype=jnp.bfloat16,
                dynamic=False,
            )
            l, *_ = transformer.run(jnp.asarray(inputs), cache, params_now, cfg)
            out.append(np.asarray(l, dtype=np.float32))
        return np.concatenate(out, axis=0).astype(np.float32, copy=False)

    def ablate_one(base_params, layer: int, head: int):
        return ablate_heads_in_Wo(
            base_params, layer, [head], cfg["hidden_size"], cfg["num_heads"]
        )

    # Per-task bundle: positions, pointer-mass table and baseline accuracies.
    # Event and control positions each get their OWN baseline accuracy, so a
    # control delta is baseline-at-controls minus ablated-at-controls.
    def make_task(ev, ctr, pointer):
        return dict(
            events=ev,
            controls=ctr,
            pointer=pointer,
            acc_event_base=accuracy_at_positions(logits_all, all_targets, ev),
            acc_control_base=accuracy_at_positions(logits_all, all_targets, ctr),
        )

    tasks = {
        "ring": make_task(ring_pred_positions, ctrl_ring_positions, ring_pointer),
        "paren": make_task(paren_pred_positions, ctrl_paren_positions, paren_pointer),
    }

    # ===============================
    # 1) Event-specificity scatter
    # ===============================
    scatter_rows = []

    for _, r in unique_heads.iterrows():
        l, h = int(r["layer"]), int(r["head"])
        logits_abl = forward_logits_for_all(ablate_one(params, l, h))
        gold_abl, margin_abl = gold_and_margin(logits_abl, all_targets)
        d_margin = margin_base - margin_abl
        d_gold = gold_base - gold_abl

        # One row per task (ring first, then paren) for every ablated head.
        for task, spec in tasks.items():
            ev, ctr = spec["events"], spec["controls"]
            scatter_rows.append(
                dict(
                    task=task,
                    layer=l,
                    head=h,
                    pointer_mass=float(spec["pointer"][l, h]),
                    delta_margin_event=mean_over_positions(d_margin, ev),
                    delta_margin_control=mean_over_positions(d_margin, ctr),
                    delta_acc_event=spec["acc_event_base"]
                    - accuracy_at_positions(logits_abl, all_targets, ev),
                    delta_acc_control=spec["acc_control_base"]
                    - accuracy_at_positions(logits_abl, all_targets, ctr),
                    delta_goldlogit_event=mean_over_positions(d_gold, ev),
                    delta_goldlogit_control=mean_over_positions(d_gold, ctr),
                    n_events=len(ev),
                    n_controls=len(ctr),
                )
            )

    scatter_df = pd.DataFrame(scatter_rows)
    scatter_csv = os.path.join(out_dir, "pointer_vs_global_scatter.csv")
    scatter_df.to_csv(scatter_csv, index=False)

    # ===============================
    # 2) Two-head additivity
    # ===============================
    def top2(df: pd.DataFrame, col: str) -> Tuple[Tuple[int, int], Tuple[int, int]]:
        best = df.sort_values(col, ascending=False).head(2)
        heads = [(int(row.layer), int(row.head)) for row in best.itertuples(index=False)]
        if len(heads) < 2:
            raise ValueError("two-head additivity needs at least two attention heads")
        return heads[0], heads[1]

    ring_h1, ring_h2 = top2(rank_df, "pointer_ring")
    par_h1, par_h2 = top2(rank_df, "pointer_paren")

    red_rows = []

    def collect_additivity(task: str, hpair: Tuple[Tuple[int, int], Tuple[int, int]]):
        (l1, h1), (l2, h2) = hpair
        spec = tasks[task]
        ev, ctr = spec["events"], spec["controls"]
        pointer = spec["pointer"]

        # single ablations, then the joint one (each call returns a fresh copy)
        logits1 = forward_logits_for_all(ablate_one(params, l1, h1))
        logits2 = forward_logits_for_all(ablate_one(params, l2, h2))
        logitsj = forward_logits_for_all(
            ablate_one(ablate_one(params, l1, h1), l2, h2)
        )

        _, m1 = gold_and_margin(logits1, all_targets)
        _, m2 = gold_and_margin(logits2, all_targets)
        _, mj = gold_and_margin(logitsj, all_targets)

        d1_evt = mean_over_positions(margin_base - m1, ev)
        d2_evt = mean_over_positions(margin_base - m2, ev)
        dj_evt = mean_over_positions(margin_base - mj, ev)
        dsum_evt = d1_evt + d2_evt

        d1_ctr = mean_over_positions(margin_base - m1, ctr)
        d2_ctr = mean_over_positions(margin_base - m2, ctr)
        dj_ctr = mean_over_positions(margin_base - mj, ctr)
        dsum_ctr = d1_ctr + d2_ctr

        evt_base = spec["acc_event_base"]
        acc1_evt = evt_base - accuracy_at_positions(logits1, all_targets, ev)
        acc2_evt = evt_base - accuracy_at_positions(logits2, all_targets, ev)
        accj_evt = evt_base - accuracy_at_positions(logitsj, all_targets, ev)

        ctr_base = spec["acc_control_base"]
        acc1_ctr = ctr_base - accuracy_at_positions(logits1, all_targets, ctr)
        acc2_ctr = ctr_base - accuracy_at_positions(logits2, all_targets, ctr)
        accj_ctr = ctr_base - accuracy_at_positions(logitsj, all_targets, ctr)

        red_rows.append(
            dict(
                task=task,
                layer1=l1,
                head1=h1,
                layer2=l2,
                head2=h2,
                pointer_mass1=float(pointer[l1, h1]),
                pointer_mass2=float(pointer[l2, h2]),
                delta_margin_event_h1=d1_evt,
                delta_margin_event_h2=d2_evt,
                delta_margin_event_both=dj_evt,
                delta_margin_event_sum=dsum_evt,
                synergy_event=dj_evt - dsum_evt,
                delta_margin_control_h1=d1_ctr,
                delta_margin_control_h2=d2_ctr,
                delta_margin_control_both=dj_ctr,
                delta_margin_control_sum=dsum_ctr,
                synergy_control=dj_ctr - dsum_ctr,
                delta_acc_event_h1=acc1_evt,
                delta_acc_event_h2=acc2_evt,
                delta_acc_event_both=accj_evt,
                delta_acc_event_sum=acc1_evt + acc2_evt,
                delta_acc_control_h1=acc1_ctr,
                delta_acc_control_h2=acc2_ctr,
                delta_acc_control_both=accj_ctr,
                delta_acc_control_sum=acc1_ctr + acc2_ctr,
                n_events=len(ev),
                n_controls=len(ctr),
            )
        )

    collect_additivity("ring", (ring_h1, ring_h2))
    collect_additivity("paren", (par_h1, par_h2))

    red_df = pd.DataFrame(red_rows)
    red_csv = os.path.join(out_dir, "redundancy_table.csv")
    red_df.to_csv(red_csv, index=False)

    # quick console summary
    print(f"\nWrote: {scatter_csv}")
    print(f"Wrote: {red_csv}")
    print(f"\nTop-{topk_report} ring heads by pointer mass:")
    print(top_ring[["layer", "head", "pointer_ring"]].to_string(index=False))
    print(f"\nTop-{topk_report} paren heads by pointer mass:")
    print(top_paren[["layer", "head", "pointer_paren"]].to_string(index=False))


def main():
    """Parse the command-line flags and forward them to ``run_must_adds``."""
    parser = argparse.ArgumentParser(
        description="Event-specificity scatter + two-head additivity"
    )

    # (flag, add_argument options) in the order they appear in --help.
    flag_specs = [
        ("--model_dir", dict(required=True)),
        ("--ckpt_id", dict(type=int, required=True)),
        ("--dataset_dir", dict(required=True)),
        ("--split", dict(default="train")),
        ("--num_examples", dict(type=int, default=4096)),
        ("--batch_size", dict(type=int, default=512)),
        ("--out_dir", dict(default="experiments/must_add_outputs")),
        ("--max_events", dict(type=int, default=2000)),
        ("--topk_report", dict(type=int, default=5)),
        ("--seq_length", dict(type=int, default=256)),
    ]
    for flag, options in flag_specs:
        parser.add_argument(flag, **options)

    cli = parser.parse_args()
    # Every flag name matches a run_must_adds parameter one-to-one.
    run_must_adds(**vars(cli))


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
