#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse, json, os, re
from collections import defaultdict
from typing import List, Tuple, Optional, Dict

import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm, LinearSegmentedColormap
from contextlib import contextmanager

from transformers import AutoTokenizer
import re, torch
from transformers import RwkvConfig, RwkvForCausalLM

# ---------------------------- Utils ---------------------------- #

def gelu(x):
    """Apply the GELU activation to ``x`` using ``torch.nn.functional.gelu``."""
    activated = F.gelu(x)
    return activated
def silu(x):
    """Apply the SiLU activation to ``x`` using ``torch.nn.functional.silu``."""
    activated = F.silu(x)
    return activated
def act_fn_for_proj(feed_forward_proj: str):
    """Pick the activation callable named inside a feed-forward projection string.

    The lower-cased string is searched for "gelu", then "silu", then "relu"
    (first hit wins). When none of them occurs the GELU wrapper is returned.
    """
    name = feed_forward_proj.lower()
    for needle in ("gelu", "silu", "relu"):
        if needle in name:
            break
    else:
        # No known activation name found: fall back to GELU.
        needle = "gelu"
    if needle == "relu":
        return F.relu
    return silu if needle == "silu" else gelu

def ensure_dir(path):
    """Create directory ``path`` (including parents); do nothing if it already exists."""
    os.makedirs(path, exist_ok=True)
def to_cpu_f32(x):
    """Return ``x`` detached from autograd, cast to float32, and moved to the CPU."""
    detached = x.detach()
    as_f32 = detached.to(torch.float32)
    return as_f32.cpu()

# ----- Purple → Black → Yellow diverging cmap + gamma emphasis -----
# The stops are #440154 (dark purple) for the negative end, black at the
# midpoint and #FDE725 (yellow) for the positive end; the registered colormap
# name "bbg" is kept unchanged.
_CMAP_BBG = LinearSegmentedColormap.from_list("bbg", ["#440154", "#000000", "#FDE725"], N=256)

def _signed_power(x: np.ndarray, gamma: float) -> np.ndarray:
    gamma = 1.5
    if gamma == 1.0: return x
    return np.sign(x) * (np.abs(x) ** gamma)

def _bounds_for_diverging(mat: np.ndarray) -> tuple[float, float]:
    # safe min/max with NaNs handled
    m = np.nanmin(mat)
    M = np.nanmax(mat)
    # if all values are the same, or there’s no sign change, make symmetric bounds around 0
    if not (m < 0.0 < M):
        a = float(max(abs(m), abs(M)))
        if a == 0.0:
            a = 1e-6  # tiny guard to satisfy TwoSlopeNorm
        return -a, a
    return float(m), float(M)

def render_heatmap(mat: np.ndarray, out_png: str, title: str, vclip: Optional[float]=99.5, gamma: float = 1.8):
    """Save a zero-centred diverging heatmap of ``mat`` to ``out_png``.

    Args:
        mat: Matrix to draw; NaN and +/-Inf entries are replaced by 0.
        out_png: Destination image path.
        title: Figure title.
        vclip: Percentile of ``|mat|`` used as a symmetric clipping limit;
            ``None`` disables clipping.
        gamma: Exponent forwarded to ``_signed_power`` before plotting.
    """
    plt.figure(figsize=(10, 4))
    data = np.nan_to_num(mat.astype(np.float32, copy=False), nan=0.0, posinf=0.0, neginf=0.0)
    if vclip is not None:
        limit = np.percentile(np.abs(data), vclip)
        data = np.clip(data, -limit, limit)
    data = _signed_power(data, gamma)
    lo, hi = _bounds_for_diverging(data)
    zero_centred = TwoSlopeNorm(vcenter=0.0, vmin=lo, vmax=hi)
    plt.imshow(data, aspect='auto', interpolation='nearest', cmap=_CMAP_BBG, norm=zero_centred)
    plt.title(title)
    plt.xlabel("model dims / columns")
    plt.ylabel("ff rows")
    colorbar = plt.colorbar()
    colorbar.set_label("Activation Strength")
    plt.tight_layout()
    plt.savefig(out_png, dpi=220)
    plt.close()

def render_band_row(vec: np.ndarray, out_png: str, title: str, repeat: int = 6, vclip: Optional[float]=99.5, gamma: float = 1.8):
    """Save a 1-D vector as a horizontal colour band image at ``out_png``.

    Args:
        vec: Vector to draw; NaN and +/-Inf entries are replaced by 0.
        out_png: Destination image path.
        title: Figure title.
        repeat: Number of identical rows the vector is tiled into.
        vclip: Percentile of ``|vec|`` used as a symmetric clipping limit;
            ``None`` disables clipping.
        gamma: Exponent forwarded to ``_signed_power`` before plotting.
    """
    values = np.nan_to_num(vec.astype(np.float32, copy=False), nan=0.0, posinf=0.0, neginf=0.0)
    if vclip is not None:
        limit = np.percentile(np.abs(values), vclip)
        values = np.clip(values, -limit, limit)
    values = _signed_power(values, gamma)
    # Stack the same row `repeat` times so the band has visible height.
    band = np.tile(values[None, :], (repeat, 1))
    lo, hi = _bounds_for_diverging(band)
    plt.figure(figsize=(10, 2))
    zero_centred = TwoSlopeNorm(vcenter=0.0, vmin=lo, vmax=hi)
    plt.imshow(band, aspect='auto', interpolation='nearest', cmap=_CMAP_BBG, norm=zero_centred)
    plt.title(title)
    plt.yticks([])
    plt.xlabel("dimension")
    colorbar = plt.colorbar()
    colorbar.set_label("Activation Strength")
    plt.tight_layout()
    plt.savefig(out_png, dpi=220)
    plt.close()

def power_norm_1d(v: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Divide ``v`` by its mean absolute value (plus ``eps``) and return float32."""
    mean_abs = np.mean(np.abs(v))
    normalised = v / (mean_abs + eps)
    return normalised.astype(np.float32)

def power_norm_2d(m: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Divide ``m`` by the mean absolute value of all its entries (plus ``eps``); returns float32."""
    mean_abs = np.mean(np.abs(m))
    normalised = m / (mean_abs + eps)
    return normalised.astype(np.float32)

def tokenize_single_piece(tok, s: str) -> List[int]:
    """Encode ``s`` with ``tok.encode`` without special tokens and return the ids."""
    return tok.encode(s, add_special_tokens=False)

def strings_to_ids_list(tok, s: str) -> List[int]:
    """Map a comma-separated string of tokens to one token id per entry.

    Each non-empty, stripped entry is tokenized with ``tokenize_single_piece``:
      * no ids   -> a warning is printed and the entry is skipped;
      * one id   -> that id is used;
      * several  -> a warning is printed and the SECOND id is used.
    """
    picked = []
    for raw in s.split(","):
        it = raw.strip()
        if it == "":
            continue
        id_list = tokenize_single_piece(tok, it)
        n_ids = len(id_list)
        if n_ids == 0:
            print(f"[warn] '{it}' tokenized to nothing; skipping.")
            continue
        if n_ids == 1:
            picked.append(id_list[0])
            continue
        print(f"[warn] '{it}' tokenized to {id_list}; using second id {id_list[1]}.")
        picked.append(id_list[1])
    return picked

def strings_to_ids_multi(tok, s_or_list) -> List[int]:
    """Collect the first token id of every comma-separated entry.

    ``s_or_list`` is either a comma-separated string or an iterable of strings
    (which are joined with commas first). Entries that are empty after
    stripping, or that tokenize to nothing, contribute no id.
    """
    joined = s_or_list if isinstance(s_or_list, str) else ",".join(s_or_list)
    first_ids = []
    for chunk in joined.split(","):
        piece = chunk.strip()
        if not piece:
            continue
        # [:1] keeps at most the first id and tolerates an empty tokenization.
        first_ids.extend(tokenize_single_piece(tok, piece)[:1])
    return first_ids

def read_pairs(path: Optional[str]) -> List[Tuple[str,str]]:
    """Load (source, target) pairs from a tab-separated UTF-8 text file.

    Lines without a tab are ignored; a line is split on its first tab only.
    When ``path`` is ``None`` or does not exist, a small built-in list of
    English/German example pairs is returned instead.
    """
    have_file = path is not None and os.path.exists(path)
    if not have_file:
        return [
            ("translate English to German: A small test.", "Ein kleiner Test."),
            ("Call me at 12:30 tomorrow.", "Ruf mich morgen um 12:30 an."),
            ("He said, \"Hello!\"", "Er sagte: „Hallo!“"),
            ("The price is $19.99.", "Der Preis beträgt 19,99 $."),
            ("Version 2.0 was released in 2024.", "Version 2.0 wurde 2024 veröffentlicht."),
        ]
    loaded = []
    with open(path, "r", encoding="utf-8") as handle:
        for row in handle:
            if "\t" not in row:
                continue
            source, target = row.rstrip("\n").split("\t", 1)
            loaded.append((source, target))
    return loaded

def read_text_as_pairs(path: str, max_examples: int) -> List[Tuple[str, str]]:
    """Read non-blank lines of a UTF-8 text file as identical (line, line) pairs.

    Args:
        path: Text file to read.
        max_examples: Maximum number of pairs to return. A value <= 0 yields
            an empty list (the limit used to be checked only after the first
            pair had already been appended).

    Returns:
        Up to ``max_examples`` pairs, each made of the stripped line twice.
    """
    pairs: List[Tuple[str, str]] = []
    if max_examples <= 0:
        return pairs
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            text = line.strip()
            if not text:
                continue
            pairs.append((text, text))
            if len(pairs) >= max_examples:
                break
    return pairs

def _slugify_token_label(s: str) -> str:
    return (s.replace(" ", "_").replace("/", "_").replace("\\", "_")
              .replace("\t", "_").replace("\n", "_").replace("<", "")
              .replace(">", "").replace(":", "").replace("|", "")
              .replace("*", "").replace("?", "").replace('"', ""))[:64] or "token"

def compute_pockets(A_vec: np.ndarray, c_vec: np.ndarray, WoT: np.ndarray | None, mode: str = "wo_times_eff") -> np.ndarray:
    if mode == "outer_eff_down":
        return (A_vec[:, None] * c_vec[None, :]).astype(np.float32)
    elif mode == "wo_times_eff":
        if WoT is None: raise ValueError("compute_pockets(mode='wo_times_eff') requires WoT")
        return (WoT * A_vec[:, None]).astype(np.float32)
    else:
        raise ValueError(f"Unknown pockets mode: {mode}")

def wo_locality_weights(Wo: torch.Tensor) -> tuple[np.ndarray, np.ndarray]:
    """Return the L2 norms of the columns and of the rows of ``Wo`` as float32 arrays.

    The first array holds one norm per column (norm taken over axis 0), the
    second one norm per row (norm taken over axis 1).
    """
    weights = Wo.detach().float().cpu().numpy()
    column_norms = np.linalg.norm(weights, axis=0).astype(np.float32)
    row_norms = np.linalg.norm(weights, axis=1).astype(np.float32)
    return column_norms, row_norms

# ---------------------------- Local-Z helpers ---------------------------- #

def _moving_stats_1d(x: np.ndarray, window: int) -> tuple[np.ndarray, np.ndarray]:
    w = int(max(1, window)); pad = w // 2
    k = np.ones(w, dtype=np.float32)
    xp = np.pad(x.astype(np.float32), (pad, pad), mode="reflect")
    mx = np.convolve(xp, k, mode="valid"); mx2 = np.convolve(xp * xp, k, mode="valid")
    mean = mx / w; var = np.maximum(mx2 / w - mean * mean, 0.0); std = np.sqrt(var + 1e-8)
    return mean, std

def local_zscore_1d(x: np.ndarray, window: int, robust: bool=True) -> np.ndarray:
    """Z-score each element of a 1-D array against a centred sliding window.

    Previously this returned only the rolling mean of ``x`` (no centring, no
    scaling, ``robust`` ignored); it now returns an actual local z-score.

    Args:
        x: 1-D input array.
        window: Window length; raised to at least 3 and bumped to the next
            odd number so the window is centred on each element.
        robust: If True, centre on the rolling median and scale by
            ``1.4826 * MAD`` (floored at 1e-6). If False, use the rolling
            mean and standard deviation (with a 1e-8 variance floor).

    Returns:
        float32 array with the same length as ``x``.
    """
    w = int(max(3, window))
    if w % 2 == 0:
        w += 1
    pad = w // 2
    x32 = x.astype(np.float32, copy=False)
    xp = np.pad(x32, (pad, pad), mode="reflect")
    # One row per input element, holding that element's neighbourhood.
    windows = np.lib.stride_tricks.sliding_window_view(xp, w)
    if robust:
        center = np.median(windows, axis=1)
        mad = np.median(np.abs(windows - center[:, None]), axis=1)
        scale = np.maximum(1e-6, 1.4826 * mad)
    else:
        center = windows.mean(axis=1)
        scale = np.sqrt(windows.var(axis=1) + 1e-8)
    return ((x32 - center) / scale).astype(np.float32)

def _rolling_median_mad_2d(X: np.ndarray, w: int, axis: int, pad_mode: str = "reflect"):
    assert X.ndim == 2; w = int(max(3, w));
    if w % 2 == 0: w += 1
    pad = w // 2
    if axis == 1:  # columns
        xp = np.pad(X.astype(np.float32), ((0,0),(pad,pad)), mode=pad_mode)
        shape = (X.shape[0], X.shape[1], w); strides = (xp.strides[0], xp.strides[1], xp.strides[1])
        windows = np.lib.stride_tricks.as_strided(xp, shape=shape, strides=strides, writeable=False)
    elif axis == 0:  # rows
        xp = np.pad(X.astype(np.float32), ((pad,pad),(0,0)), mode=pad_mode)
        shape = (X.shape[0], w, X.shape[1]); strides = (xp.strides[0], xp.strides[0], xp.strides[1])
        windows = np.lib.stride_tricks.as_strided(xp, shape=shape, strides=strides, writeable=False)
        windows = np.swapaxes(windows, 1, 2)
    else:
        raise ValueError("axis must be 0 or 1")
    med = np.median(windows, axis=2); mad = np.median(np.abs(windows - med[..., None]), axis=2)
    scale = np.maximum(1e-6, 1.4826 * mad); return med, scale

def local_zscore_2d(mat: np.ndarray, window: int, axis: str = "columns", robust: bool = True) -> np.ndarray:
    """Z-score every entry of a matrix against a sliding window of its neighbours.

    Args:
        mat: 2-D input matrix (cast to float32).
        window: Window length.
        axis: "columns" slides the window within each row (array axis 1),
            "rows" slides it within each column (array axis 0), and "both"
            applies the "columns" pass followed by the "rows" pass.
        robust: If True, use the rolling median / MAD from
            ``_rolling_median_mad_2d``; otherwise a rolling mean / std with
            reflect padding and a 1e-8 variance floor.

    Raises:
        ValueError: if ``axis`` is not one of the three accepted strings.
    """
    data = mat.astype(np.float32, copy=False)

    def _mean_std_z_within_rows(arr):
        # Non-robust path: rolling mean and std inside each row.
        width = window if window % 2 == 1 else window + 1
        half = width // 2
        kernel = np.ones(width, dtype=np.float32)
        padded = np.pad(arr, ((0, 0), (half, half)), mode="reflect")
        mean = np.apply_along_axis(lambda r: np.convolve(r, kernel, mode="valid"), 1, padded) / width
        mean_sq = np.apply_along_axis(lambda r: np.convolve(r**2, kernel, mode="valid"), 1, padded) / width
        std = np.sqrt(np.maximum(mean_sq - mean*mean, 0.0) + 1e-8)
        return (arr - mean) / std

    def _z_pass(arr, ax):
        if robust:
            med, scale = _rolling_median_mad_2d(arr, window, axis=ax)
            return (arr - med) / scale
        if ax == 1:
            return _mean_std_z_within_rows(arr)
        # Axis 0 is handled by transposing, scoring within rows, transposing back.
        return _mean_std_z_within_rows(arr.T).T

    if axis == "both":
        return _z_pass(_z_pass(data, 1), 0)
    if axis == "rows":
        return _z_pass(data, 0)
    if axis == "columns":
        return _z_pass(data, 1)
    raise ValueError(f"axis must be 'columns','rows','both'")

# ---------------------------- Hooks (Attn) ---------------------------- #

def _attn_module_path(side: str, layer_idx: int, kind: str) -> str:
    if side == "enc":
        assert kind == "self"; return f"encoder.block.{layer_idx}.layer.0.SelfAttention"
    else:
        if kind == "self": return f"decoder.block.{layer_idx}.layer.0.SelfAttention"
        elif kind == "cross": return f"decoder.block.{layer_idx}.layer.1.EncDecAttention"
        else: raise ValueError("kind must be 'self' or 'cross'")

@contextmanager
def hook_attn_outputs(model, names, store_dict):
    """Context manager that records the outputs of the named sub-modules.

    While the context is active, every forward pass of a module listed in
    ``names`` (keys of ``model.named_modules()``) stores its output in
    ``store_dict[name]`` as a detached float CPU tensor. If the output is a
    tuple/list the first element is used; if it is a dict, the first of
    "last_hidden_state", "hidden_states", "output" that is present is used.
    All hooks are removed on exit, including when an exception is raised.
    """
    module_by_name = dict(model.named_modules())
    registered = []

    def _first_tensor(obj):
        if isinstance(obj, torch.Tensor):
            return obj.detach().float().cpu()
        if isinstance(obj, (tuple, list)) and len(obj) > 0:
            return _first_tensor(obj[0])
        if isinstance(obj, dict):
            for field in ("last_hidden_state", "hidden_states", "output"):
                if field in obj:
                    return _first_tensor(obj[field])
        raise TypeError(f"Unexpected attention hook output type: {type(obj)}")

    def _recorder(name):
        # Factory so each hook is bound to its own module name.
        def _hook(module, inputs, output):
            store_dict[name] = _first_tensor(output)
        return _hook

    try:
        for name in names:
            handle = module_by_name[name].register_forward_hook(_recorder(name))
            registered.append(handle)
        yield
    finally:
        for handle in registered:
            handle.remove()

def compute_attn_band_for_token(attn_tensor: torch.Tensor, labels: torch.Tensor, target_id: int) -> np.ndarray:
    """Average the attention-output vectors at positions whose label equals ``target_id``.

    Args:
        attn_tensor: Tensor of shape ``[B, T, D]``.
        labels: Tensor with ``B * T`` token ids, aligned with the first two
            dimensions of ``attn_tensor``.
        target_id: Token id to select.

    Returns:
        NumPy vector of length ``D``: the mean over the matching positions,
        or zeros when the id does not occur.
    """
    dev = attn_tensor.device
    B, T, D = attn_tensor.shape
    # reshape() instead of view(): view() raises on non-contiguous inputs
    # (e.g. a transposed activation or a sliced label tensor).
    flat = attn_tensor.reshape(B * T, D)
    lab_flat = labels.detach().to(dev).reshape(B * T)
    mask = lab_flat == int(target_id)
    if mask.any():
        v = flat[mask].mean(dim=0)
    else:
        v = torch.zeros(D, dtype=attn_tensor.dtype, device=dev)
    return v.detach().cpu().numpy()

def export_attn_band_maps(out_dir: str, side: str, layer_idx: int, kind: str, vec_raw: np.ndarray, tag: str, enable_local_z: bool=False, local_z_window: int=31):
    """Write the attention-output band for one layer/kind/tag to disk.

    Creates ``<out_dir>/<side>_L<layer_idx>_<kind>_<tag>/`` and stores:
      * ``attn_out.npy``      - the raw vector;
      * ``attn_out_band.png`` - band plot of the power-normalised vector;
      * ``attn_out_band_localz.png`` - band plot of
        ``local_zscore_1d(vec_raw, local_z_window, robust=True)``, only when
        ``enable_local_z`` is set.
    """
    ensure_dir(out_dir)
    sub = os.path.join(out_dir, f"{side}_L{layer_idx}_{kind}_{tag}")
    ensure_dir(sub)
    np.save(os.path.join(sub, "attn_out.npy"), vec_raw)
    vec_norm = power_norm_1d(vec_raw)
    render_band_row(vec_norm, os.path.join(sub, "attn_out_band.png"), f"{side} L{layer_idx} {kind} {tag}: band (attn out, power-normed)", gamma=1.8)
    if enable_local_z:
        try:
            vec_lz = local_zscore_1d(vec_raw, local_z_window, robust=True)
        except NameError as exc:
            # Best-effort output: keep going, but say so instead of silently
            # dropping the local-z plot.
            print(f"[warn] local-z band skipped for {side} L{layer_idx} {kind} {tag}: {exc}")
        else:
            render_band_row(vec_lz, os.path.join(sub, "attn_out_band_localz.png"), f"{side} L{layer_idx} {kind} {tag}: band (attn out) local-z", gamma=1.8)

# ---------------------------- Aggregators ---------------------------- #

class TokenAggregators:
    """
    Per-layer/side aggregates:
      - Raw means for single/generic/SetA/SetB
      - Baseline means (ĉ direction & row-mass)
      - Per-occurrence normalized maps for token-vs-token comparison (outer product)
    Keys are (side, layer).
    """

    def __init__(self, n_layers_enc: int, n_layers_dec: int, d_ff: int, d_model: int,
                 target_single_ids: List[int], set_a_ids: List[int], set_b_ids: List[int],
                 special_ids: List[int], want_baseline: bool, baseline_exclude_ids: set):
        self.n_layers_enc = n_layers_enc; self.n_layers_dec = n_layers_dec
        self.d_ff = d_ff; self.d_model = d_model
        self.target_single_ids = set(target_single_ids)
        self.set_a_ids = set(set_a_ids); self.set_b_ids = set(set_b_ids)
        self.special_ids = set(special_ids)
        self.want_baseline = want_baseline; self.baseline_exclude_ids = set(baseline_exclude_ids)

        # Raw sums
        self.first_single_eff = defaultdict(lambda: None)  # np[d_ff]
        self.first_single_down = defaultdict(lambda: None) # np[d_model]
        self.sum_single_eff = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.sum_single_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.cnt_single = defaultdict(int)
        self.sum_generic_eff = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.sum_generic_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.cnt_generic = defaultdict(int)
        self.sum_setA_eff = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.sum_setA_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.cnt_setA = defaultdict(int)
        self.sum_setB_eff = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.sum_setB_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.cnt_setB = defaultdict(int)

        # Baseline / target-single (ĉ & rowmass)
        self.base_sum_dir_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.base_cnt_dir_down = defaultdict(int)
        self.base_sum_rowmass = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.base_cnt_rowmass = defaultdict(int)
        self.baseline_total_tokens = 0
        self.single_sum_dir_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.single_cnt_dir_down = defaultdict(int)
        self.single_sum_rowmass = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.single_cnt_rowmass = defaultdict(int)

        # Token-vs-token comparison accumulators (per-occurrence normalized maps)
        self.compA_sum_S = defaultdict(lambda: None)  # np[d_ff, d_model]
        self.compA_cnt   = defaultdict(int)
        self.compB_sum_S = defaultdict(lambda: None)
        self.compB_cnt   = defaultdict(int)

        # cache Wo norms
        self._cached_wo_norms: Dict[Tuple[str,int], Tuple[np.ndarray,np.ndarray]] = {}

    def _update_positions(self, side: str, layer: int, A_btF: torch.Tensor, c_btD: torch.Tensor, token_ids_bt: torch.Tensor, count_for_baseline: bool):
        A = A_btF.numpy(); C = c_btD.numpy(); TIDs = token_ids_bt.numpy()
        B, T, _ = A.shape; key = (side, layer)

        def _pn1d(v: np.ndarray, eps: float=1e-8) -> np.ndarray:
            s = np.mean(np.abs(v)) + eps; return (v / s).astype(np.float32)

        for b in range(B):
            for t in range(T):
                tid = int(TIDs[b, t])
                if tid in self.special_ids: continue
                a = A[b, t]; c = C[b, t]

                # Raw buckets
                if tid in self.target_single_ids:
                    if self.first_single_eff[key] is None:  self.first_single_eff[key]  = a.copy()
                    if self.first_single_down[key] is None: self.first_single_down[key] = c.copy()
                    self.sum_single_eff[key]  += a; self.sum_single_down[key] += c; self.cnt_single[key] += 1
                elif tid in self.set_a_ids:
                    self.sum_setA_eff[key]  += a; self.sum_setA_down[key] += c; self.cnt_setA[key] += 1
                elif tid in self.set_b_ids:
                    self.sum_setB_eff[key]  += a; self.sum_setB_down[key] += c; self.cnt_setB[key] += 1
                else:
                    self.sum_generic_eff[key]  += a; self.sum_generic_down[key] += c; self.cnt_generic[key] += 1

                # Baseline-style normalized views
                cnorm = np.linalg.norm(c) + 1e-8; c_hat = c / cnorm
                denom = np.abs(a).sum() + 1e-8; p = np.abs(a) / denom
                if tid in self.target_single_ids:
                    self.single_sum_dir_down[key] += c_hat; self.single_cnt_dir_down[key] += 1
                    self.single_sum_rowmass[key] += p;      self.single_cnt_rowmass[key] += 1
                if count_for_baseline and (tid not in self.baseline_exclude_ids):
                    self.base_sum_dir_down[key] += c_hat; self.base_cnt_dir_down[key] += 1
                    self.base_sum_rowmass[key] += p;      self.base_cnt_rowmass[key] += 1
                    self.baseline_total_tokens += 1

                # NEW: per-occurrence normalized token map (outer product) for comparison
                if tid in self.set_a_ids or tid in self.set_b_ids:
                    a_n = _pn1d(a); c_n = _pn1d(c)
                    S_occ = (a_n[:, None] * c_n[None, :]).astype(np.float32)
                    if tid in self.set_a_ids:
                        if self.compA_sum_S[key] is None: self.compA_sum_S[key] = np.zeros_like(S_occ, dtype=np.float32)
                        self.compA_sum_S[key] += S_occ; self.compA_cnt[key] += 1
                    else:
                        if self.compB_sum_S[key] is None: self.compB_sum_S[key] = np.zeros_like(S_occ, dtype=np.float32)
                        self.compB_sum_S[key] += S_occ; self.compB_cnt[key] += 1

    # ---- Baseline I/O ----
    def baseline_means(self, side: str, n_layers: int) -> Tuple[np.ndarray, np.ndarray]:
        mean_dir, mean_row = [], []
        for L in range(n_layers):
            k = (side, L)
            mean_dir.append(self.base_sum_dir_down[k] / max(1, self.base_cnt_dir_down[k]) if self.base_cnt_dir_down[k] > 0 else np.zeros(self.d_model, np.float32))
            mean_row.append(self.base_sum_rowmass[k] / max(1, self.base_cnt_rowmass[k]) if self.base_cnt_rowmass[k] > 0 else np.zeros(self.d_ff, np.float32))
        return np.stack(mean_dir, 0), np.stack(mean_row, 0)

    def save_baseline(self, path: str, nL_enc: int, nL_dec: int):
        enc_dir, enc_row = self.baseline_means("enc", nL_enc)
        dec_dir, dec_row = self.baseline_means("dec", nL_dec)
        np.savez(path, enc_dir_down_mean=enc_dir, enc_rowmass_mean=enc_row, dec_dir_down_mean=dec_dir, dec_rowmass_mean=dec_row)

    @staticmethod
    def load_baseline(path: str) -> Dict[str, np.ndarray]:
        arrs = np.load(path); return {k: arrs[k] for k in arrs.files}

    def word_baseline(self, side: str, layer: int, ids: List[int]) -> tuple[np.ndarray, np.ndarray]:
        key = (side, layer)
        if self.base_cnt_dir_down[key] > 0 and self.base_cnt_rowmass[key] > 0:
            mu_dir = self.base_sum_dir_down[key] / self.base_cnt_dir_down[key]
            mu_row = self.base_sum_rowmass[key] / self.base_cnt_rowmass[key]
            return mu_dir, mu_row
        return np.zeros(self.d_model, np.float32), np.zeros(self.d_ff, np.float32)

    # ---------- Exporter ---------- #
    def export_views(self, out_dir: str, side: str, layer: int, Wo_dF: torch.Tensor,
                     baseline_dir_down: Optional[np.ndarray] = None, baseline_rowmass: Optional[np.ndarray] = None,
                     *, enable_local_z: bool=True, local_z_window: int=35, pockets_axis: str="columns",
                     token_single: str="", pockets_mode: str="outer_eff_down", weight_bands_by_wo: bool=False,
                     compA_ids: List[int]=None, compB_ids: List[int]=None) -> dict | None:
        ensure_dir(out_dir); key = (side, layer)
        WoT_np = Wo_dF.detach().float().cpu().numpy().T
        if key not in self._cached_wo_norms: self._cached_wo_norms[key] = wo_locality_weights(Wo_dF)
        per_ff_w, per_mod_w = self._cached_wo_norms[key]

        def save_pack(tag: str, A_vec: Optional[np.ndarray], c_vec: Optional[np.ndarray]):
            if A_vec is None or c_vec is None: return
            pack_dir = os.path.join(out_dir, tag); ensure_dir(pack_dir)
            # 1D power-norm
            A_vec = power_norm_1d(A_vec); c_vec = power_norm_1d(c_vec)
            if weight_bands_by_wo:
                A_vec = (A_vec * (per_ff_w / (per_ff_w.mean() + 1e-8))).astype(np.float32)
                c_vec = (c_vec * (per_mod_w / (per_mod_w.mean() + 1e-8))).astype(np.float32)
            np.save(os.path.join(pack_dir, "eff_A.npy"), A_vec)
            np.save(os.path.join(pack_dir, "down_c.npy"), c_vec)
            S = compute_pockets(A_vec, c_vec, WoT_np, mode=pockets_mode); S = power_norm_2d(S)
            np.save(os.path.join(pack_dir, "pockets_S.npy"), S)
            render_heatmap(S, os.path.join(pack_dir, "pockets_S.png"), f"{side} L{layer} {tag}: pockets S", gamma=1.6)
            render_band_row(c_vec, os.path.join(pack_dir, "down_c.png"), f"Decoder L{layer} downprojection", gamma=1.8)
            render_band_row(A_vec, os.path.join(pack_dir, "eff_A.png"), f"{side} L{layer} {tag}: eff A (power-normed)", gamma=1.8)
            if enable_local_z:
                tok_slug = _slugify_token_label(token_single)
                c_lz = local_zscore_1d(c_vec, local_z_window); A_lz = local_zscore_1d(A_vec, local_z_window)
                np.save(os.path.join(pack_dir, f"down_c_localz__{tok_slug}.npy"), c_lz)
                np.save(os.path.join(pack_dir, f"eff_A_localz__{tok_slug}.npy"), A_lz)
                render_band_row(c_lz, os.path.join(pack_dir, f"down_c_localz__{tok_slug}.png"), f"{side} L{layer} {tag}: band (down c) local-z [{tok_slug}]", gamma=1.8)
                render_band_row(A_lz, os.path.join(pack_dir, f"eff_A_localz__{tok_slug}.png"), f"{side} L{layer} {tag}: eff A local-z [{tok_slug}]", gamma=1.8)
                S_lz = (A_lz[:, None] * c_lz[None, :]).astype(np.float32)
                np.save(os.path.join(pack_dir, f"pockets_S_localz__{tok_slug}.npy"), S_lz)
                render_heatmap(S_lz, os.path.join(pack_dir, f"pockets_S_localz__{tok_slug}.png"), f"{side} L{layer} {tag}: pockets S (A_lz ⊗ c_lz)", gamma=1.6)

        # Legacy packs
        if self.first_single_eff[key] is not None and self.first_single_down[key] is not None:
            save_pack("single_token_first", self.first_single_eff[key], self.first_single_down[key])
        if self.cnt_single[key] > 0:
            A = self.sum_single_eff[key] / max(1, self.cnt_single[key])
            c = self.sum_single_down[key] / max(1, self.cnt_single[key])
            save_pack("mean_single_token", A, c)
        if self.cnt_generic[key] > 0:
            A_g = self.sum_generic_eff[key] / max(1, self.cnt_generic[key])
            c_g = self.sum_generic_down[key] / max(1, self.cnt_generic[key])
            if self.cnt_single[key] > 0:
                A_s = self.sum_single_eff[key] / max(1, self.cnt_single[key])
                c_s = self.sum_single_down[key] / max(1, self.cnt_single[key])
                save_pack("mean_single_minus_generic", A_s - A_g, c_s - c_g)
            if self.cnt_setA[key] > 0 and self.cnt_setB[key] > 0:
                A_a = self.sum_setA_eff[key]  / max(1, self.cnt_setA[key])
                c_a = self.sum_setA_down[key] / max(1, self.cnt_setA[key])
                A_b = self.sum_setB_eff[key]  / max(1, self.cnt_setB[key])
                c_b = self.sum_setB_down[key] / max(1, self.cnt_setB[key])
                save_pack("mean_setA_minus_setB", A_a - A_b, c_a - c_b)

        # Single vs provided baseline contrasts
        if baseline_dir_down is not None and baseline_rowmass is not None and self.single_cnt_dir_down[key] > 0:
            mu_dir = self.single_sum_dir_down[key] / max(1, self.single_cnt_dir_down[key])
            mu_p   = self.single_sum_rowmass[key] / max(1, self.single_cnt_rowmass[key])
            delta_dir = power_norm_1d(mu_dir - baseline_dir_down)
            delta_p   = power_norm_1d(mu_p   - baseline_rowmass)
            if weight_bands_by_wo:
                per_ff_w, per_mod_w = self._cached_wo_norms[key]
                delta_dir = (delta_dir * (per_mod_w / (per_mod_w.mean() + 1e-8))).astype(np.float32)
                delta_p   = (delta_p   * (per_ff_w  / (per_ff_w.mean()  + 1e-8))).astype(np.float32)
            pack_dir = os.path.join(out_dir, "single_vs_baseline_dir"); ensure_dir(pack_dir)
            np.save(os.path.join(pack_dir, "delta_dir_down.npy"), delta_dir)
            np.save(os.path.join(pack_dir, "delta_rowmass_eff.npy"), delta_p)
            S_dir = compute_pockets(delta_p, delta_dir, WoT_np, mode=pockets_mode); S_dir = power_norm_2d(S_dir)
            np.save(os.path.join(pack_dir, "pockets_S_dir.npy"), S_dir)
            render_heatmap(S_dir, os.path.join(pack_dir, "pockets_S_dir.png"), f"{side} L{layer} single vs baseline: pockets", gamma=1.6)
            render_band_row(delta_dir, os.path.join(pack_dir, "delta_dir_down.png"), f"{side} L{layer} single vs baseline: band Δĉ", gamma=1.8)

        # NEW: Token-vs-token difference with normalized overlap
        layer_diff_sum, layer_overlap = None, None
        if compA_ids and compB_ids and (self.compA_cnt[key] > 0) and (self.compB_cnt[key] > 0):
            SA = (self.compA_sum_S[key] / max(1, self.compA_cnt[key])).astype(np.float32)  # averaged map for A
            SB = (self.compB_sum_S[key] / max(1, self.compB_cnt[key])).astype(np.float32)  # averaged map for B
            # Difference map (your spec)
            D = np.abs(SA - SB).astype(np.float32)
            layer_diff_sum = float(D.sum())
            denom = (np.abs(SA) + np.abs(SB)).sum() + 1e-8
            layer_overlap = float(D.sum() / denom)  # ∈ [0,1]

            pack_dir = os.path.join(out_dir, "compare_token_diff"); ensure_dir(pack_dir)
            np.save(os.path.join(pack_dir, "mapA.npy"), SA)
            np.save(os.path.join(pack_dir, "mapB.npy"), SB)
            np.save(os.path.join(pack_dir, "diff_map.npy"), D)
            render_heatmap(power_norm_2d(SA), os.path.join(pack_dir, "mapA.png"), f"{side} L{layer} token A map (avg per-occurrence norm)", gamma=1.6)
            render_heatmap(power_norm_2d(SB), os.path.join(pack_dir, "mapB.png"), f"{side} L{layer} token B map (avg per-occurrence norm)", gamma=1.6)
            render_heatmap(D, os.path.join(pack_dir, "diff_map.png"), f"{side} L{layer} |A - B| (difference) • overlap={layer_overlap:.3f}", gamma=1.6)

            # Histogram of difference values
            plt.figure(figsize=(6,4))
            plt.hist(D.flatten(), bins=60, density=False)
            plt.title(f"{side} L{layer} difference histogram")
            plt.xlabel("|A-B| value"); plt.ylabel("count"); plt.tight_layout()
            plt.savefig(os.path.join(pack_dir, "diff_hist.png"), dpi=200); plt.close()

        return {"layer_diff_sum": layer_diff_sum, "layer_overlap": layer_overlap}

# ---------------------------- FFN Capture ---------------------------- #

class FFNCapture:
    """
    Hooks all FFNs (encoder layer[1], decoder layer[2]) and streams:
      A = act(wi_0(x))*wi_1(x)   (or act(wi(x)))
      c = A @ W_o^T
    """
    def __init__(self, model, tokenizer, aggregators: TokenAggregators, capture_encoder: bool, capture_decoder: bool):
        self.model = model; self.tok = tokenizer; self.agg = aggregators
        self.capture_encoder = capture_encoder; self.capture_decoder = capture_decoder
        self.act = act_fn_for_proj(model.config.feed_forward_proj)
        self.handles = []; self._enc_ids_bt = None; self._dec_ids_bt = None
        if capture_encoder:
            for li, block in enumerate(model.encoder.block):
                self._register_ffn_hook(block.layer[1].DenseReluDense, "enc", li)
        if capture_decoder:
            for li, block in enumerate(model.decoder.block):
                self._register_ffn_hook(block.layer[2].DenseReluDense, "dec", li)

    def _register_ffn_hook(self, ffn, side: str, layer: int):
        has_gated = hasattr(ffn, "wi_0") and hasattr(ffn, "wi_1")
        def hook_fn(module, inputs, output):
            x = inputs[0]
            with torch.no_grad():
                if has_gated: u = ffn.wi_0(x); v = ffn.wi_1(x); A = self.act(u) * v
                else:         u = ffn.wi(x); A = self.act(u)
                Wo = ffn.wo.weight
                c = torch.einsum("btd,md->btm", A, Wo)  # [B,T,d_model]
                ids = self._enc_ids_bt if side=="enc" else self._dec_ids_bt
                if ids is None: return
                self.agg._update_positions(side, layer, to_cpu_f32(A), to_cpu_f32(c), ids, count_for_baseline=self.agg.want_baseline)
        self.handles.append(ffn.register_forward_hook(hook_fn))

    def set_batch_token_ids(self, enc_ids_bt: Optional[torch.Tensor], dec_ids_bt: Optional[torch.Tensor]):
        self._enc_ids_bt = enc_ids_bt; self._dec_ids_bt = dec_ids_bt

    def close(self):
        for h in self.handles: h.remove()
        self.handles = []

# ---------------------------- Main ---------------------------- #

def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the map-export script."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", type=str, default="google/mt5-base", help="Base HF checkpoint")
    # Other checkpoint file names kept for reference:
    #   mt5_base_pretuned.pt
    #   mt5_base_standard_FaF_11x_noise_two_let_retrain1_batch_10000.pt  (end with 2let no FZR)
    #   mt5_base_twolet_FaF_12x_noise0_batch_5000_o.pt                   (end 2let with FZR)
    #   mt5_base_forgive_and_forget_whole_stream6.pt
    ap.add_argument("--state_dict_path", type=str,
                    default="mt5_base_standard_FaF_11x_noise.pt",
                    help="Path to .pt/.bin state_dict to load on top of --ckpt")

    # --- architecture switch ---
    ap.add_argument("--arch", type=str, choices=["mt5", "rwkv"], default="mt5",
                    help="Backbone to load. 'mt5' (default) or 'rwkv'.")

    # Tokens / maps
    ap.add_argument("--token_single", type=str, default="dog", help="single token string (ideally 1 id)")
    ap.add_argument("--out_dir", type=str, default="Interpretation_Records/final_maps/dog_full")

    # Modeling / options
    ap.add_argument("--use_forgetfult5", action="store_true", help="Use custom ForgetfulT5 forward (import must exist)")
    ap.add_argument("--token_baseline_multi", type=str, default="", help="Comma list averaged as baseline 'word'")
    ap.add_argument("--token_compare_a", type=str, default="dog", help="Token(s) for compare A")  # cat
    ap.add_argument("--token_compare_b", type=str, default="dock", help="Token(s) for compare B")  # call
    ap.add_argument("--weight_bands_by_wo", action="store_true", help="Multiply bands by local Wo norms")

    # RWKV specifics (only used when --arch rwkv)
    ap.add_argument("--rwkv_vocab_size", type=int, default=50277)
    ap.add_argument("--rwkv_hidden_size", type=int, default=1024)  # 430M commonly ~1024
    ap.add_argument("--rwkv_n_layers", type=int, default=24)
    ap.add_argument("--rwkv_tokenizer", type=str, default="BlinkDL/rwkv-4-pile-430m",
                    help="Tokenizer to use for RWKV runs")

    # z/vis
    ap.add_argument("--enable_local_z", action="store_true", help="Also save local z-scored variants")
    ap.add_argument("--local_z_window", type=int, default=13, help="Odd window size for local z-score")
    ap.add_argument("--local_z_axis_for_pockets", type=str, default="columns", choices=["columns", "rows", "both"],
                    help="Axis for local z in pockets")

    # Data
    ap.add_argument("--pairs_file", type=str, default=None, help="TSV src<TAB>tgt")
    ap.add_argument("--text_file", type=str, default="Interpretation_Records/interpret_sub_texts.txt",
                    help="Plaintext one-per-line")
    ap.add_argument("--max_examples", type=int, default=2000)
    ap.add_argument("--batch_size", type=int, default=32)
    ap.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")

    ap.add_argument("--max_src_len", type=int, default=512)
    ap.add_argument("--max_tgt_len", type=int, default=128)

    # Sets
    ap.add_argument("--set_a", type=str, default="hill, hive, him, hit, hint, hip")
    ap.add_argument("--set_b", type=str, default="hi, hiss, high, hind, hinge")  # his
    ap.add_argument("--capture_encoder", action="store_true")
    # On by default; BooleanOptionalAction keeps `--capture_decoder` working and
    # adds `--no-capture_decoder` so the default can actually be switched off.
    ap.add_argument("--capture_decoder", default=True, action=argparse.BooleanOptionalAction,
                    help="Capture decoder FFNs (default: on)")

    # Baseline
    ap.add_argument("--baseline_build", action="store_true", help="stream and save baseline means (ĉ, p)")
    ap.add_argument("--baseline_use", action="store_true", help="load baseline and emit single-vs-baseline outputs")
    ap.add_argument("--baseline_file", type=str, default=None, help="path to .npz for baseline")
    ap.add_argument("--baseline_tokens_target", type=int, default=10000, help="approx tokens to include in baseline")
    ap.add_argument("--baseline_exclude_targets", action="store_true", help="exclude single/setA/setB from baseline")

    # Pockets
    ap.add_argument("--pockets_mode", type=str, default="outer_eff_down", choices=["outer_eff_down", "wo_times_eff"],
                    help="How to form pockets")

    # possible sets of three words
    # app apple orange
    # the theme a
    # one three thick
    return ap


def _load_mt5_backbone(args):
    """Load the mT5 tokenizer and model, optionally overlaying a local state dict.

    Returns:
        ``(tokenizer, model)`` with the model in eval mode on ``args.device``.
    """
    tok = AutoTokenizer.from_pretrained(args.ckpt)

    # model (ForgetfulT5 if you pass --use_forgetfult5)
    if args.use_forgetfult5:
        from ForgetfulT5 import ForgetfulT5ForConditionalGeneration as _T5Cls
    else:
        from transformers import MT5ForConditionalGeneration as _T5Cls
    model = _T5Cls.from_pretrained(args.ckpt)

    if args.state_dict_path and os.path.exists(args.state_dict_path):
        # NOTE(security): torch.load unpickles the file; only load checkpoints you trust.
        sd_raw = torch.load(args.state_dict_path, map_location="cpu")
        if isinstance(sd_raw, dict) and "state_dict" in sd_raw:
            sd_raw = sd_raw["state_dict"]

        # Drop at most one leading wrapper prefix so keys line up with the bare model.
        prefix_re = re.compile(r'^(?:module\.|model\.|base_model\.|transformer\.|t5\.|mt5\.)')
        sd = {prefix_re.sub('', k, count=1): v for k, v in sd_raw.items()}
        ms = model.state_dict()

        keys_sd = set(sd.keys())
        keys_ms = set(ms.keys())
        inter = keys_sd & keys_ms
        # Sorted so the printed examples are stable between runs.
        only_sd = sorted(keys_sd - keys_ms)
        only_ms = sorted(keys_ms - keys_sd)

        print(
            f"[ckpt] sd keys (after strip): {len(keys_sd)} | model keys: {len(keys_ms)} | intersect: {len(inter)} "
            f"({100.0 * len(inter) / max(1, len(keys_ms)):.1f}% of model)")
        if only_sd:
            print(" [ckpt] example in sd but not model (dropped):")
            for k in only_sd[:10]:
                print(f"   - {k}")
        if only_ms:
            print(" [ckpt] example expected by model but missing in sd:")
            for k in only_ms[:10]:
                print(f"   - {k}")

        sd_f = {k: sd[k] for k in inter}
        missing, unexpected = model.load_state_dict(sd_f, strict=False)
        print(f"[mt5] load_state_dict: loaded={len(sd_f)}  missing={len(missing)}  unexpected={len(unexpected)}")
    else:
        print("[mt5] no state_dict_path found; using base weights.")

    model.eval().to(args.device)
    return tok, model


def _load_rwkv_backbone(args):
    """Build an RWKV model from CLI hyperparameters and overlay a local state dict.

    Returns:
        ``(tokenizer, model)`` with the model in eval mode on ``args.device``.
    """
    from transformers import RwkvConfig, RwkvForCausalLM

    tok = AutoTokenizer.from_pretrained(args.rwkv_tokenizer)

    # model skeleton from config (match your training hyperparams!)
    cfg = RwkvConfig(
        vocab_size=args.rwkv_vocab_size,
        hidden_size=args.rwkv_hidden_size,
        num_hidden_layers=args.rwkv_n_layers,
        layer_norm_epsilon=1e-5,
        bos_token_id=tok.bos_token_id or 0,
        eos_token_id=tok.eos_token_id or 0,
        pad_token_id=tok.pad_token_id or 0,
    )
    model = RwkvForCausalLM(cfg)

    # load state dict (handles DDP prefixes)
    if args.state_dict_path and os.path.exists(args.state_dict_path):
        # NOTE(security): torch.load unpickles the file; only load checkpoints you trust.
        sd = torch.load(args.state_dict_path, map_location="cpu")
        prefix_re = re.compile(r'^(?:module\.|model\.)')
        sd = {prefix_re.sub('', k, count=1): v for k, v in sd.items()}
        missing, unexpected = model.load_state_dict(sd, strict=False)
        print(f"[rwkv] loaded sd: missing={len(missing)} unexpected={len(unexpected)}")
    else:
        print("[rwkv] state_dict_path missing; built fresh RWKV model from config.")

    model.eval().to(args.device)
    return tok, model


def main():
    """CLI entry point: scan text through the model and export FFN / attention maps.

    Steps: parse arguments and write ``meta.json``; load the backbone; register
    FFN capture hooks; stream the data in batches so the aggregators fill up;
    optionally save and/or load a baseline; export decoder attention bands for
    the last batch; export per-layer FFN views and a compare summary.

    Raises:
        SystemExit: on invalid baseline arguments, or for ``--arch rwkv``
            (the export pipeline below only works on T5-style stacks).
        ValueError: if ``--token_single`` yields no token ids when the
            attention bands are exported.
        FileNotFoundError: if the baseline file is missing at load time.
    """
    ap = _build_arg_parser()
    args = ap.parse_args()

    # Validate baseline options before any expensive work instead of asserting
    # after the full scan (asserts are also stripped under `python -O`).
    want_baseline = args.baseline_build or args.baseline_use
    if want_baseline and not args.baseline_file:
        ap.error("--baseline_file is required with --baseline_build / --baseline_use")
    if args.baseline_use and not args.baseline_build and not os.path.exists(args.baseline_file):
        ap.error(f"--baseline_file does not exist: {args.baseline_file}")

    ensure_dir(args.out_dir)
    with open(os.path.join(args.out_dir, "meta.json"), "w") as f:
        json.dump({"args": vars(args)}, f, indent=2)

    if args.arch == "rwkv":
        _load_rwkv_backbone(args)
        # Everything below reads model.config.d_model and model.encoder/decoder
        # blocks with DenseReluDense FFNs, none of which exist on RWKV. Stop with
        # a clear message instead of an AttributeError further down.
        raise SystemExit(
            "[rwkv] Model loaded, but the FFN/attention export pipeline only supports "
            "T5-style encoder/decoder stacks (DenseReluDense); nothing to export for --arch rwkv."
        )

    tok, model = _load_mt5_backbone(args)

    d_model = model.config.d_model
    probe_ffn = model.decoder.block[0].layer[2].DenseReluDense
    d_ff = probe_ffn.wi_0.out_features if hasattr(probe_ffn, "wi_0") else probe_ffn.wi.out_features
    nL_enc = len(model.encoder.block)
    nL_dec = len(model.decoder.block)

    # Parse tokens
    single_ids = strings_to_ids_list(tok, args.token_single)
    setA_ids_cli = strings_to_ids_list(tok, args.set_a)
    setB_ids_cli = strings_to_ids_list(tok, args.set_b)
    baseline_multi_ids = strings_to_ids_multi(tok, args.token_baseline_multi)
    compA_ids = strings_to_ids_multi(tok, args.token_compare_a)
    compB_ids = strings_to_ids_multi(tok, args.token_compare_b)
    # Compare tokens take precedence over the --set_a / --set_b lists.
    setA_ids_for_agg = compA_ids if len(compA_ids) > 0 else setA_ids_cli
    setB_ids_for_agg = compB_ids if len(compB_ids) > 0 else setB_ids_cli

    specials = {i for i in (tok.pad_token_id, tok.eos_token_id, tok.unk_token_id) if i is not None}
    baseline_exclude = set()
    if args.baseline_exclude_targets:
        baseline_exclude |= set(single_ids) | set(setA_ids_for_agg) | set(setB_ids_for_agg)

    aggregators = TokenAggregators(nL_enc, nL_dec, d_ff, d_model,
                                   target_single_ids=single_ids,
                                   set_a_ids=setA_ids_for_agg,
                                   set_b_ids=setB_ids_for_agg,
                                   special_ids=list(specials),
                                   want_baseline=want_baseline,
                                   baseline_exclude_ids=baseline_exclude)
    capturer = FFNCapture(model, tok, aggregators,
                          capture_encoder=args.capture_encoder, capture_decoder=args.capture_decoder)

    # Data
    if args.pairs_file:
        pairs = read_pairs(args.pairs_file)[:args.max_examples]
        print(f"[data] using TSV pairs: {len(pairs)}")
    elif args.text_file and os.path.exists(args.text_file):
        pairs = read_text_as_pairs(args.text_file, args.max_examples)
        print(f"[data] using text file: {len(pairs)} from {args.text_file}")
    else:
        pairs = read_pairs(None)[:args.max_examples]
        print(f"[data] using built-in sample: {len(pairs)}")

    print(f"Scanning {len(pairs)} pairs ...")
    last_enc = None
    last_labels = None
    last_dec_in = None
    try:
        for i in range(0, len(pairs), args.batch_size):
            batch = pairs[i:i + args.batch_size]
            src_texts = [s for s, _ in batch]
            tgt_texts = [t for _, t in batch]
            # Honour --max_src_len / --max_tgt_len; truncation=True alone leaves
            # the limit to the tokenizer's own model_max_length.
            enc = tok(src_texts, return_tensors="pt", padding=True, truncation=True,
                      max_length=args.max_src_len).to(args.device)
            # `text_target=` replaces the deprecated `as_target_tokenizer()` context.
            tgt = tok(text_target=tgt_texts, return_tensors="pt", padding=True, truncation=True,
                      max_length=args.max_tgt_len).to(args.device)
            labels = tgt["input_ids"].clone()
            labels[labels == tok.pad_token_id] = -100
            decoder_input_ids = model.prepare_decoder_input_ids_from_labels(labels=labels)
            last_enc, last_labels, last_dec_in = enc, labels, decoder_input_ids
            capturer.set_batch_token_ids(
                enc_ids_bt=enc["input_ids"].detach().cpu() if args.capture_encoder else None,
                dec_ids_bt=decoder_input_ids.detach().cpu() if args.capture_decoder else None)
            with torch.no_grad():
                _ = model(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"],
                          decoder_input_ids=decoder_input_ids, use_cache=False)
            if args.baseline_build and aggregators.baseline_total_tokens >= args.baseline_tokens_target:
                print(f"[baseline] token budget reached: {aggregators.baseline_total_tokens}")
                break
    finally:
        # Remove the FFN hooks before any further forward pass: the attention
        # visualization below re-runs the last batch, and live hooks would feed
        # that batch into the aggregators a second time.
        capturer.close()

    # Baseline I/O
    baseline = None
    if args.baseline_build:
        aggregators.save_baseline(args.baseline_file, nL_enc, nL_dec)
        print(f"[baseline] saved to {args.baseline_file}")
    if args.baseline_use:
        if not os.path.exists(args.baseline_file):
            raise FileNotFoundError(f"Provide existing --baseline_file to load: {args.baseline_file}")
        baseline = TokenAggregators.load_baseline(args.baseline_file)
        print(f"[baseline] loaded from {args.baseline_file}")

    # Attention bands (vis only). Probe layers are limited to those the decoder has.
    probe_layers = [L for L in (9, 10, 11) if L < nL_dec]
    probe_kinds = ["self", "cross"]
    names = [_attn_module_path("dec", L, kind) for L in probe_layers for kind in probe_kinds]
    if last_enc is None or last_labels is None or last_dec_in is None:
        print("[warn] No batches processed; skipping attention visualization.")
    elif not names:
        print("[warn] Decoder has no probe layers (9-11); skipping attention visualization.")
    else:
        if len(single_ids) < 1:
            raise ValueError("token_single yielded no ids")
        target_tok = single_ids[0]
        captures = {}
        with torch.no_grad():
            with hook_attn_outputs(model, names, captures):
                _ = model(input_ids=last_enc["input_ids"], attention_mask=last_enc["attention_mask"],
                          decoder_input_ids=last_dec_in, labels=last_labels, use_cache=False,
                          output_attentions=False, output_hidden_states=False)
        for L in probe_layers:
            for kind in probe_kinds:
                key = _attn_module_path("dec", L, kind)
                if key not in captures:
                    continue
                vec = compute_attn_band_for_token(captures[key], last_labels, target_tok)
                export_attn_band_maps(out_dir=args.out_dir, side="dec", layer_idx=L, kind=kind,
                                      vec_raw=vec, tag="single_token_first",
                                      enable_local_z=args.enable_local_z,
                                      local_z_window=args.local_z_window)

    # Export FFN maps & collect compare metrics
    print("Exporting FFN maps and figures ...")
    layer_diffs = []

    def run_side(side: str, blocks, base_dir, base_row):
        """Export views for every layer of one side and record compare metrics."""
        for li, block in enumerate(blocks):
            ffn = block.layer[2].DenseReluDense if side == "dec" else block.layer[1].DenseReluDense
            out_dir = os.path.join(args.out_dir, f"{side}_L{li:02d}")
            # per-word baseline override if provided, else the loaded baseline (if any)
            bd, br = (None, None)
            if len(baseline_multi_ids) > 0:
                bd, br = aggregators.word_baseline(side, li, baseline_multi_ids)
            elif base_dir is not None and base_row is not None:
                bd, br = base_dir[li], base_row[li]
            res = aggregators.export_views(out_dir, side, li, ffn.wo.weight, bd, br,
                                           enable_local_z=args.enable_local_z, local_z_window=args.local_z_window,
                                           pockets_axis=args.local_z_axis_for_pockets, token_single=args.token_single,
                                           pockets_mode=args.pockets_mode, weight_bands_by_wo=args.weight_bands_by_wo,
                                           compA_ids=compA_ids, compB_ids=compB_ids)
            if res and res.get("layer_diff_sum") is not None:
                overlap = res["layer_overlap"]
                layer_diffs.append({"side": side, "layer": li,
                                    "diff_sum": float(res["layer_diff_sum"]),
                                    "overlap": float(overlap) if overlap is not None else None})

    def baseline_entry(key: str):
        """Return ``baseline[key]`` when a baseline is loaded and has that key."""
        return baseline[key] if (baseline and key in baseline) else None

    if args.capture_encoder:
        run_side("enc", model.encoder.block,
                 baseline_entry("enc_dir_down_mean"), baseline_entry("enc_rowmass_mean"))
    if args.capture_decoder:
        run_side("dec", model.decoder.block,
                 baseline_entry("dec_dir_down_mean"), baseline_entry("dec_rowmass_mean"))

    # Write summary (layer-wise + totals). Per-layer overlap denominators are not
    # available here, so the overall overlap is the unweighted mean of the
    # per-layer overlaps where available.
    total_diff = sum(d["diff_sum"] for d in layer_diffs)
    overlap_values = [d["overlap"] for d in layer_diffs if d.get("overlap") is not None]
    overall_overlap = float(np.mean(overlap_values)) if overlap_values else None

    if layer_diffs:
        report = {"layers": layer_diffs, "total_diff_sum": float(total_diff), "overall_overlap_mean": overall_overlap}
        with open(os.path.join(args.out_dir, "token_compare_summary.json"), "w") as f:
            json.dump(report, f, indent=2)
        print("[compare] layer-wise sums/overlaps:")
        for d in layer_diffs:
            print(f"  {d['side']} L{d['layer']:02d}: diff_sum={d['diff_sum']:.3f}" + (f", overlap={d['overlap']:.3f}" if d['overlap'] is not None else ""))
        print(f"[compare] TOTAL diff_sum: {total_diff:.3f}")
        if overall_overlap is not None:
            print(f"[compare] OVERALL overlap (mean of layers): {overall_overlap:.3f}")

    print(f"Done. Outputs in {args.out_dir}")

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
