#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse, json, os, re
from collections import defaultdict
from typing import List, Tuple, Optional, Dict

import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm, LinearSegmentedColormap
from contextlib import contextmanager
from transformers import RwkvConfig, RwkvForCausalLM, AutoTokenizer

# If you have these locally; otherwise comment them out.
from project_classes import ForgetfulT5, CustomDenseReluDense

# ---------------------------- Viz: cmap + guards ---------------------------- #

# Diverging colormap used by render_heatmap / render_band_row together with a
# TwoSlopeNorm centred at 0: the lowest (most negative) values map to yellow
# (#FDE725), zero maps to black, and the highest (most positive) values map to
# purple (#440154). N=256 quantization levels.
_CMAP_BBG = LinearSegmentedColormap.from_list("bbg", ["#FDE725", "#000000", "#440154"], N=256)


def _signed_power(x: np.ndarray, gamma: float) -> np.ndarray:
    gamma=1.5
    if gamma == 1.0: return x
    return np.sign(x) * (np.abs(x) ** gamma)

def _bounds_for_diverging(mat: np.ndarray) -> tuple[float, float]:
    m = float(np.nanmin(mat))
    M = float(np.nanmax(mat))
    if not (m < 0.0 < M):
        a = max(abs(m), abs(M))
        if a == 0.0:
            a = 1e-6
        return -a, a
    return m, M

def render_heatmap(mat: np.ndarray, out_png: str, title: str,
                   vclip: Optional[float]=99.5, gamma: float=1.8):
    """Render a 2-D array as a zero-centred diverging heatmap PNG.

    Steps:
      1. Replace NaN and ±inf with 0.
      2. Optionally clip symmetrically at the ``vclip`` percentile of ``|mat|``.
      3. Apply ``_signed_power(·, gamma)``.
      4. Draw with ``_CMAP_BBG`` and a ``TwoSlopeNorm`` centred at 0.

    Args:
        mat: 2-D array; rows on the y axis, columns on the x axis.
        out_png: Destination PNG path (written at 220 dpi).
        title: Figure title.
        vclip: Percentile (0-100) of ``|mat|`` used as the symmetric clip
            bound; ``None`` disables clipping.
        gamma: Exponent for the signed power transform; also shown on the
            colorbar label.
    """
    m = np.nan_to_num(mat.astype(np.float32, copy=False), nan=0.0, posinf=0.0, neginf=0.0)
    if vclip is not None:
        vmax = np.percentile(np.abs(m), vclip)
        m = np.clip(m, -vmax, vmax)
    m = _signed_power(m, gamma)
    vmin, vmax = _bounds_for_diverging(m)
    norm = TwoSlopeNorm(vcenter=0.0, vmin=vmin, vmax=vmax)

    # The figure is created only after the data prep above (which can raise),
    # and is always closed, so a failed render/save does not leak pyplot figures.
    fig = plt.figure(figsize=(10, 4))
    try:
        plt.imshow(m, aspect='auto', interpolation='nearest', cmap=_CMAP_BBG, norm=norm)
        plt.title(title)
        plt.xlabel("model dims / columns")
        plt.ylabel("ff rows / rows")
        cbar = plt.colorbar()
        cbar.set_label(f"power-normalized units (γ={gamma:.1f})")
        plt.tight_layout()
        plt.savefig(out_png, dpi=220)
    finally:
        plt.close(fig)

def render_band_row(vec: np.ndarray, out_png: str, title: str,
                    repeat: int=6, vclip: Optional[float]=99.5, gamma: float=1.8):
    """Render a 1-D vector as a thin, zero-centred diverging "band" PNG.

    The vector is processed in this order:
      1. Replace NaN and ±inf with 0.
      2. Optionally clip at the ``vclip`` percentile of ``|vec|``.
      3. Apply ``_signed_power(·, gamma)``.
      4. Tile ``repeat`` times vertically so it shows as a strip.

    The result is drawn with ``_CMAP_BBG`` and a ``TwoSlopeNorm`` centred at 0.

    Args:
        vec: 1-D array (one value per x position).
        out_png: Destination PNG path (written at 220 dpi).
        title: Figure title.
        repeat: Number of identical rows in the rendered strip.
        vclip: Percentile (0-100) of ``|vec|`` used as the symmetric clip
            bound; ``None`` disables clipping.
        gamma: Exponent for the signed power transform; also shown on the
            colorbar label.
    """
    v = np.nan_to_num(vec.astype(np.float32, copy=False), nan=0.0, posinf=0.0, neginf=0.0)
    if vclip is not None:
        vmax = np.percentile(np.abs(v), vclip)
        v = np.clip(v, -vmax, vmax)
    v = _signed_power(v, gamma)
    mat = np.tile(v[None, :], (repeat, 1))
    vmin, vmax = _bounds_for_diverging(mat)
    norm = TwoSlopeNorm(vcenter=0.0, vmin=vmin, vmax=vmax)

    # Always close the figure so a failed render/save does not leak it.
    fig = plt.figure(figsize=(10, 2))
    try:
        plt.imshow(mat, aspect='auto', interpolation='nearest', cmap=_CMAP_BBG, norm=norm)
        plt.title(title)
        plt.yticks([])
        plt.xlabel("dimension")
        cbar = plt.colorbar()
        cbar.set_label(f"power-normalized units (γ={gamma:.1f})")
        plt.tight_layout()
        plt.savefig(out_png, dpi=220)
    finally:
        plt.close(fig)

# ---------------------------- Generic utils ---------------------------- #

def gelu(x):
    """Exact (erf-based) GELU; thin wrapper over ``F.gelu``."""
    return F.gelu(x)


def silu(x):
    """SiLU / swish; thin wrapper over ``F.silu``."""
    return F.silu(x)


def _gelu_tanh(x):
    """Tanh-approximated GELU (Hugging Face ``gelu_new``)."""
    return F.gelu(x, approximate="tanh")


def act_fn_for_proj(feed_forward_proj: str):
    """Return the activation callable implied by a T5/mT5 ``feed_forward_proj`` string.

    Matching is case-insensitive:
      * ``"gated-gelu"`` (and names containing ``gelu_new`` /
        ``gelu_pytorch_tanh``) -> tanh-approximated GELU. Hugging Face T5/MT5
        configs map ``"gated-gelu"`` to ``gelu_new`` for backwards
        compatibility, so this matches the model's real FFN activation.
      * any other string containing ``"gelu"`` -> exact GELU.
      * containing ``"silu"`` -> SiLU.
      * containing ``"relu"`` -> ReLU.
      * anything else -> exact GELU.
    """
    proj = feed_forward_proj.lower()
    if proj == "gated-gelu" or "gelu_new" in proj or "gelu_pytorch_tanh" in proj:
        return _gelu_tanh
    if "gelu" in proj:
        return gelu
    if "silu" in proj:
        return silu
    if "relu" in proj:
        return F.relu
    return gelu

def ensure_dir(path):
    """Create directory ``path`` and any missing parents; no error if it already exists."""
    os.makedirs(name=path, exist_ok=True)
def to_cpu_f32(x):
    """Return tensor ``x`` detached from autograd, cast to float32, on the CPU.

    No copy is made when ``x`` is already a float32 CPU tensor, so the result
    may share storage with ``x``.
    """
    detached = x.detach()
    as_f32 = detached.to(torch.float32)
    return as_f32.cpu()

def power_norm_1d(v: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Divide ``v`` by its mean absolute value (plus ``eps``) and return float32.

    Afterwards the mean ``|value|`` is about 1, unless ``v`` is (nearly) all zero.
    """
    mean_abs = np.mean(np.abs(v))
    normalized = v / (mean_abs + eps)
    return normalized.astype(np.float32)

def power_norm_2d(m: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Divide a matrix by its global mean absolute value (plus ``eps``); float32 result.

    A single scale is used for the whole array (not per row or column).
    """
    mean_abs = np.mean(np.abs(m))
    normalized = m / (mean_abs + eps)
    return normalized.astype(np.float32)

def tokenize_single_piece(tok, s: str) -> List[int]:
    """Encode ``s`` with tokenizer ``tok`` using ``add_special_tokens=False``."""
    encoded = tok.encode(s, add_special_tokens=False)
    return encoded

def strings_to_ids_list(tok, s: str) -> List[int]:
    """Map a comma-separated string of pieces to one token id per piece.

    Each whitespace-stripped, non-empty piece is encoded with
    ``tokenize_single_piece``:
      * a piece that encodes to nothing is skipped, with a warning;
      * a piece that encodes to several ids contributes only its first id,
        also with a warning.

    A non-string ``s`` yields an empty list.
    """
    if not isinstance(s, str):
        return []
    stripped = [chunk.strip() for chunk in s.split(",")]
    result: List[int] = []
    for piece in filter(None, stripped):
        encoded = tokenize_single_piece(tok, piece)
        if len(encoded) == 0:
            print(f"[warn] '{piece}' tokenized to nothing; skipping.")
            continue
        if len(encoded) > 1:
            print(f"[warn] '{piece}' tokenized to {encoded}; using first id {encoded[0]}.")
        result.append(encoded[0])
    return result

def strings_to_ids_multi(tok, s_or_list) -> List[int]:
    """Map comma-separated pieces to token ids, keeping the first id of each piece.

    ``s_or_list`` may be:
      * a single string (``"a, b"``);
      * an iterable of strings, which is joined with commas first, so items
        may themselves contain commas;
      * ``None``, which yields ``[]`` (as ``strings_to_ids_list`` does for
        non-string input).

    Empty pieces, and pieces that tokenize to nothing, are skipped silently.
    """
    if s_or_list is None:
        return []
    s = s_or_list if isinstance(s_or_list, str) else ",".join(s_or_list)
    out: List[int] = []
    for piece in (chunk.strip() for chunk in s.split(",")):
        if not piece:
            continue
        ids = tokenize_single_piece(tok, piece)
        if len(ids) == 0:
            continue
        out.append(ids[0])
    return out

def read_pairs(path: Optional[str]) -> List[Tuple[str,str]]:
    """Load ``(source, target)`` pairs from a tab-separated UTF-8 text file.

    Each line that contains a tab is split at its FIRST tab. The trailing
    newline is removed, and the target keeps any further tabs. Lines without a
    tab are ignored. When ``path`` is None or does not exist, a small built-in
    English->German sample is returned instead.
    """
    if path is None or not os.path.exists(path):
        return [
            ("translate English to German: A small test.", "Ein kleiner Test."),
            ("Call me at 12:30 tomorrow.", "Ruf mich morgen um 12:30 an."),
            ("He said, \"Hello!\"", "Er sagte: „Hallo!“"),
            ("The price is $19.99.", "Der Preis beträgt 19,99 $."),
            ("Version 2.0 was released in 2024.", "Version 2.0 wurde 2024 veröffentlicht."),
        ]
    pairs: List[Tuple[str, str]] = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            if "\t" not in raw:
                continue
            source, _tab, target = raw.rstrip("\n").partition("\t")
            pairs.append((source, target))
    return pairs

def read_text_as_pairs(path: str, max_examples: int):
    """Read non-empty lines of a UTF-8 text file as identity pairs ``(line, line)``.

    Lines are whitespace-stripped and blank lines are skipped.

    Args:
        path: Text file to read.
        max_examples: Maximum number of pairs to return. ``None`` means no
            limit, and a value ``<= 0`` yields an empty list.

    Returns:
        List of ``(line, line)`` tuples in file order.
    """
    pairs = []
    with open(path, "r", encoding="utf-8") as f:
        for raw in f:
            # Check the cap before appending so max_examples <= 0 yields nothing
            # (checking after the append always admitted at least one pair).
            if max_examples is not None and len(pairs) >= max_examples:
                break
            line = raw.strip()
            if not line:
                continue
            pairs.append((line, line))
    return pairs

def _slugify_token_label(s: str) -> str:
    return (s.replace(" ", "_").replace("/", "_").replace("\\", "_")
              .replace("\t", "_").replace("\n", "_").replace("<", "")
              .replace(">", "").replace(":", "").replace("|", "")
              .replace("*", "").replace("?", "").replace('"', ""))[:64] or "token"

def compute_pockets(A_vec: np.ndarray, c_vec: np.ndarray, WoT: np.ndarray | None,
                    mode: str = "wo_times_eff") -> np.ndarray:
    if mode == "outer_eff_down":
        return (A_vec[:, None] * c_vec[None, :]).astype(np.float32)
    elif mode == "wo_times_eff":
        if WoT is None: raise ValueError("compute_pockets(mode='wo_times_eff') requires WoT")
        return (WoT * A_vec[:, None]).astype(np.float32)
    else:
        raise ValueError(f"Unknown pockets mode: {mode}")

def wo_locality_weights(Wo: torch.Tensor) -> tuple[np.ndarray, np.ndarray]:
    """Return per-column and per-row L2 norms of a 2-D weight tensor as float32 numpy.

    Axis 0 gives one norm per column (labelled d_ff). Axis 1 gives one norm
    per row (labelled d_model). ``Wo`` is therefore expected as
    ``[d_model, d_ff]``; confirm against the caller.

    Returns:
        ``(per_ff, per_mod)``.
    """
    weights = Wo.detach().float().cpu().numpy()
    column_norms = np.linalg.norm(weights, axis=0).astype(np.float32)  # one per column (d_ff)
    row_norms = np.linalg.norm(weights, axis=1).astype(np.float32)     # one per row (d_model)
    return column_norms, row_norms

# ---------------------------- Local smoothing helpers ---------------------------- #

def _moving_stats_1d(x: np.ndarray, window: int) -> tuple[np.ndarray, np.ndarray]:
    w = int(max(1, window)); pad = w // 2
    k = np.ones(w, dtype=np.float32)
    xp = np.pad(x.astype(np.float32), (pad, pad), mode="reflect")
    mx = np.convolve(xp, k, mode="valid"); mx2 = np.convolve(xp * xp, k, mode="valid")
    mean = mx / w; var = np.maximum(mx2 / w - mean * mean, 0.0); std = np.sqrt(var + 1e-8)
    return mean, std

def local_zscore_1d(x: np.ndarray, window: int, robust: bool=True) -> np.ndarray:
    """Local z-score of a 1-D array over a sliding, centred window.

    Each sample is standardized against its own neighbourhood:
    ``(x[i] - center_i) / (spread_i + 1e-8)``.

    * ``robust=True``: center = window median; spread = ``1.4826 * MAD``
      (scaled median absolute deviation). Where the MAD is 0 (a window that is
      mostly one repeated value), the window standard deviation is used
      instead, so isolated spikes are not divided by the bare epsilon.
    * ``robust=False``: center = window mean; spread = population std.

    Args:
        x: 1-D array.
        window: Window length; raised to at least 3 and bumped to the next odd
            number so the window is centred.
        robust: Use median/MAD statistics (True) or mean/std (False).

    Returns:
        float32 array with the same length as ``x``. Edges use reflect padding.
    """
    w = int(max(3, window))
    if w % 2 == 0:
        w += 1
    x32 = np.asarray(x, dtype=np.float32)
    if x32.size == 0:
        return x32.copy()
    pad = w // 2
    xp = np.pad(x32, (pad, pad), mode="reflect")
    windows = np.lib.stride_tricks.sliding_window_view(xp, w)  # [len(x), w]
    std = windows.std(axis=1)
    if robust:
        center = np.median(windows, axis=1)
        mad = np.median(np.abs(windows - center[:, None]), axis=1)
        spread = np.where(mad > 0, 1.4826 * mad, std)
    else:
        center = windows.mean(axis=1)
        spread = std
    return ((x32 - center) / (spread + 1e-8)).astype(np.float32)

# ---------------------------- Attention capture (MT5 only) ---------------------------- #

def _attn_module_path(side: str, layer_idx: int, kind: str) -> str:
    if side == "enc":
        assert kind == "self"; return f"encoder.block.{layer_idx}.layer.0.SelfAttention"
    else:
        if kind == "self": return f"decoder.block.{layer_idx}.layer.0.SelfAttention"
        elif kind == "cross": return f"decoder.block.{layer_idx}.layer.1.EncDecAttention"
        else: raise ValueError("kind must be 'self' or 'cross'")

@contextmanager
def hook_attn_outputs(model, names, store_dict):
    """Temporarily capture the forward outputs of named sub-modules of ``model``.

    While the context is active, every module whose ``named_modules()`` name
    is in ``names`` gets a forward hook. On each call the hook stores its
    output, detached, as float32 and on the CPU, in ``store_dict[name]``,
    overwriting any previous value. Conversion rules:
      * tuple/list outputs: the first element is used;
      * dict outputs: the first of ``last_hidden_state`` / ``hidden_states`` /
        ``output`` that is present is used.

    All hooks are removed on exit, including on error.

    Raises:
        KeyError: a name in ``names`` is not a sub-module of ``model``.
        TypeError: raised from the hook when an output cannot be converted.
    """
    modules_by_name = dict(model.named_modules())
    registered = []

    def _extract(value):
        if isinstance(value, torch.Tensor):
            return value.detach().float().cpu()
        if isinstance(value, (tuple, list)) and value:
            return _extract(value[0])
        if isinstance(value, dict):
            for field in ("last_hidden_state", "hidden_states", "output"):
                if field in value:
                    return _extract(value[field])
        raise TypeError(f"Unexpected attention hook output type: {type(value)}")

    def _capture_into(slot):
        # Factory so each hook keeps its own slot name (avoids late binding).
        def _hook(_module, _inputs, output):
            store_dict[slot] = _extract(output)
        return _hook

    try:
        for name in names:
            target = modules_by_name[name]
            registered.append(target.register_forward_hook(_capture_into(name)))
        yield
    finally:
        for handle in registered:
            handle.remove()

def compute_attn_band_for_token(attn_tensor: torch.Tensor, labels: torch.Tensor, target_id: int) -> np.ndarray:
    """Average a ``[B, T, D]`` tensor over the positions whose label is ``target_id``.

    Args:
        attn_tensor: ``[B, T, D]`` tensor; may be non-contiguous.
        labels: Token ids with ``B * T`` elements, aligned with the first two
            dims of ``attn_tensor``.
        target_id: Token id to select.

    Returns:
        Length-``D`` numpy vector: the mean over all ``(b, t)`` with
        ``labels[b, t] == target_id``, or zeros if the id never occurs. The
        dtype follows ``attn_tensor``.
    """
    dev = attn_tensor.device
    B, T, D = attn_tensor.shape
    # reshape (not view) so non-contiguous inputs (transposed or sliced) work;
    # detach so the result can always be converted to numpy.
    flat = attn_tensor.detach().reshape(B * T, D)
    lab_flat = labels.detach().to(dev).reshape(B * T)
    mask = lab_flat == int(target_id)
    if mask.any():
        v = flat[mask].mean(dim=0)
    else:
        v = torch.zeros(D, dtype=attn_tensor.dtype, device=dev)
    return v.cpu().numpy()

def export_attn_band_maps(out_dir: str, side: str, layer_idx: int, kind: str, vec_raw: np.ndarray, tag: str,
                          enable_local_z: bool=False, local_z_window: int=31):
    """Write one attention-output vector to disk as ``.npy`` plus band PNGs.

    Files are written to ``out_dir/{side}_L{layer_idx}_{kind}_{tag}/``:
      * ``attn_out.npy``: ``vec_raw`` as given;
      * ``attn_out_band.png``: band of ``power_norm_1d(vec_raw)``;
      * ``attn_out_band_localz.png``: band of
        ``local_zscore_1d(vec_raw, local_z_window, robust=True)``, written
        only when ``enable_local_z`` is true.
    """
    ensure_dir(out_dir)
    run_dir = os.path.join(out_dir, f"{side}_L{layer_idx}_{kind}_{tag}")
    ensure_dir(run_dir)
    label = f"{side} L{layer_idx} {kind} {tag}"

    np.save(os.path.join(run_dir, "attn_out.npy"), vec_raw)
    render_band_row(power_norm_1d(vec_raw), os.path.join(run_dir, "attn_out_band.png"),
                    f"{label}: band (attn out, power-normed)", gamma=1.8)

    if not enable_local_z:
        return
    local_z = local_zscore_1d(vec_raw, local_z_window, robust=True)
    render_band_row(local_z, os.path.join(run_dir, "attn_out_band_localz.png"),
                    f"{label}: band (attn out) local-z", gamma=1.8)

# ---------------------------- Fast tail & hidden baseline ---------------------------- #

class TailFilterFast:
    """
    Fast two-sided tail selector for the last dimension.
    """
    def __init__(
        self,
        pct: float = 5.0,
        mode: str = "signed",           # 'signed' or 'abs'
        *,
        min_k: int = 1,
        inplace: bool = True,
        eps: float = 1e-12
    ):
        assert mode in ("signed", "abs")
        self.pct = max(0.0, float(pct))
        self.mode = mode
        self.min_k = int(max(1, min_k))
        self.inplace = bool(inplace)
        self.eps = float(eps)

    @torch.inference_mode()
    def apply_torch(
        self,
        V: torch.Tensor,               # shape [..., D]
        *,
        return_mask: bool = False,
        center: torch.Tensor | None = None,
        scale: torch.Tensor | None = None
    ) -> torch.Tensor:
        if center is not None:
            V = V - self._broadcast_to_last(center, V)
        if scale is not None:
            V = V / (self._broadcast_to_last(scale, V) + self.eps)

        if self.pct <= 0.0:
            return V if not return_mask else torch.ones_like(V, dtype=torch.bool)

        V = V.contiguous()
        D = V.shape[-1]
        k = max(self.min_k, int(np.ceil(D * (self.pct / 100.0))))
        if k >= D:
            return V if not return_mask else torch.ones_like(V, dtype=torch.bool)

        if self.mode == "signed":
            tau_pos = torch.kthvalue(-V, k, dim=-1).values.neg().unsqueeze(-1)
            tau_neg = torch.kthvalue( V, k, dim=-1).values.unsqueeze(-1)
            M = (V >= tau_pos) | (V <= tau_neg)
        else:
            A = V.abs()
            tau_abs = torch.kthvalue(-A, k, dim=-1).values.neg().unsqueeze(-1)
            M = (A >= tau_abs)

        if return_mask:
            return M

        if self.inplace and V.is_floating_point() and V.requires_grad is False:
            V.mul_(M)
            return V
        else:
            return V * M

    @staticmethod
    def _broadcast_to_last(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        if x.shape == ref.shape:
            return x
        if x.dim() == ref.dim() - 1 and x.shape == ref.shape[:-1]:
            return x.unsqueeze(-1)
        if x.numel() == ref.shape[-1] and x.dim() == 1:
            shape = (1,) * (ref.dim() - 1) + (ref.shape[-1],)
            return x.view(shape)
        return x


class HiddenBaseline:
    """
    Maintains / serves per-(side, layer) baseline vectors for hidden states.

    Modes (``mode``):
      * ``"running"``: ``accumulate`` records the mean of each call's selected
        rows, and ``get_mu`` returns the average of those per-call means.
      * ``"load"``: ``get_mu`` serves vectors from ``loaded`` keyed by
        ``f"{side}_{layer}"``, returning zeros for missing keys.
      * ``"none"``: ``get_mu`` always returns zeros.

    Any other mode string behaves like ``"none"``: ``accumulate`` is a no-op
    and ``get_mu`` finds no running statistics, so it returns zeros.

    NOTE(review): in running mode every ``accumulate`` call carries equal
    weight, however many rows it selected. This is a mean of per-call means,
    not a per-token mean; confirm that this weighting is intended.
    """
    def __init__(self, mode: str = "running", loaded: Optional[dict] = None):
        self.mode = mode
        self.loaded = loaded or {}
        # (side, layer) -> running sum of per-call mean vectors (None until the first call)
        self.sum = defaultdict(lambda: None)
        # (side, layer) -> number of accumulate calls that contributed to ``sum``
        self.cnt = defaultdict(int)

    @staticmethod
    def _norm_occ(H: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        """Divide ``H`` by its mean absolute value along the last dim (plus ``eps``)."""
        scale = H.abs().mean(dim=-1, keepdim=True) + eps
        return H / scale

    @torch.no_grad()
    def accumulate(self, side: str, layer: int, Hn_btD: torch.Tensor, eligible_mask_bt: Optional[torch.Tensor] = None):
        """Add the mean of the (optionally masked) rows of ``Hn_btD`` to the running sum.

        Only active in ``"running"`` mode.

        Args:
            side, layer: Key of the baseline being accumulated.
            Hn_btD: Tensor whose last dim is D. All leading dims are flattened
                into rows, and non-contiguous inputs are accepted.
            eligible_mask_bt: Optional boolean mask with one entry per
                flattened row. Only rows marked True contribute, and a call
                with no True entries is ignored.
        """
        if self.mode != "running":
            return
        key = (side, layer)
        # reshape (not view) so sliced/transposed inputs do not raise.
        rows = Hn_btD.reshape(-1, Hn_btD.shape[-1])
        if eligible_mask_bt is not None:
            sel = eligible_mask_bt.reshape(-1).to(rows.device)
            if not sel.any():
                return
            rows = rows[sel]
        if rows.numel() == 0:
            return
        m = rows.mean(dim=0).detach()
        if self.sum[key] is None:
            self.sum[key] = m
        else:
            self.sum[key] = self.sum[key] + m
        self.cnt[key] += 1

    @torch.no_grad()
    def get_mu(self, side: str, layer: int, device: torch.device, D: int) -> torch.Tensor:
        """Return the length-``D`` baseline vector for ``(side, layer)`` on ``device``.

        Zeros are returned when nothing is available: ``"none"`` mode, a
        missing key in ``"load"`` mode, or no accumulated calls.

        Raises:
            ValueError: in ``"load"`` mode, when the stored vector's last
                dim differs from ``D``.
        """
        if self.mode == "none":
            return torch.zeros(D, device=device)
        if self.mode == "load":
            name = f"{side}_{layer}"
            if name in self.loaded:
                v = torch.as_tensor(self.loaded[name], device=device, dtype=torch.float32)
                if v.shape[-1] != D:
                    raise ValueError(f"[hidden_baseline] Shape mismatch for {name}: file {v.shape[-1]} vs D {D}")
                return v
            return torch.zeros(D, device=device)
        key = (side, layer)
        if self.cnt[key] > 0 and self.sum[key] is not None:
            return (self.sum[key] / float(self.cnt[key])).to(device)
        return torch.zeros(D, device=device)

# ---------------------------- FFN Aggregators (MT5) ---------------------------- #

class TokenAggregators:
    """
    Aggregates FFN streams (MT5) and hidden-state maps (MT5 & RWKV) and supports A/B comparison.
    Now also supports self-partitioning for self-consistency checks (part1 vs part2).
    """

    def __init__(self, n_layers_enc: int, n_layers_dec: int, d_ff: int, d_model: int,
                 target_single_ids: list[int], set_a_ids: list[int], set_b_ids: list[int],
                 special_ids: list[int], want_baseline: bool, baseline_exclude_ids: set):
        self.n_layers_enc = n_layers_enc
        self.n_layers_dec = n_layers_dec
        self.d_ff = d_ff
        self.d_model = d_model

        self.target_single_ids = set(target_single_ids or [])
        self.set_a_ids = set(set_a_ids or [])
        self.set_b_ids = set(set_b_ids or [])
        self.special_ids = set(special_ids or [])
        self.want_baseline = bool(want_baseline)
        self.baseline_exclude_ids = set(baseline_exclude_ids or [])

        # ---------------- Raw sums (FFN legacy CPU path) ----------------
        self.first_single_eff = defaultdict(lambda: None)
        self.first_single_down = defaultdict(lambda: None)
        self.sum_single_eff = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.sum_single_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.cnt_single = defaultdict(int)

        self.sum_generic_eff = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.sum_generic_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.cnt_generic = defaultdict(int)

        self.sum_setA_eff = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.sum_setA_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.cnt_setA = defaultdict(int)

        self.sum_setB_eff = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.sum_setB_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.cnt_setB = defaultdict(int)

        # ---------------- Baseline accumulators (FFN legacy) ----------------
        self.base_sum_dir_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.base_cnt_dir_down = defaultdict(int)
        self.base_sum_rowmass  = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.base_cnt_rowmass  = defaultdict(int)
        self.baseline_total_tokens = 0

        self.single_sum_dir_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.single_cnt_dir_down = defaultdict(int)
        self.single_sum_rowmass  = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.single_cnt_rowmass  = defaultdict(int)

        # ---------------- FFN token-vs-token compare (legacy CPU) ----------------
        self.compA_sum_S = defaultdict(lambda: None)   # np[d_ff, d_model]
        self.compA_cnt   = defaultdict(int)
        self.compB_sum_S = defaultdict(lambda: None)
        self.compB_cnt   = defaultdict(int)

        # ---------------- GPU device & cached Wo norms ----------------
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._cached_wo_norms = {}  # (side,layer) -> (per_ff, per_mod)

        # ---------------- FFN token-vs-token compare (CUDA path) ----------------
        self.compA_sum_S_t = defaultdict(lambda: None)  # torch [F,M]
        self.compB_sum_S_t = defaultdict(lambda: None)
        self.compA_cnt_t   = defaultdict(int)
        self.compB_cnt_t   = defaultdict(int)

        # -------- NEW: FFN self partitions (CUDA path) --------
        self.selfA1_sum_S_t = defaultdict(lambda: None)
        self.selfA2_sum_S_t = defaultdict(lambda: None)
        self.selfA1_cnt_t   = defaultdict(int)
        self.selfA2_cnt_t   = defaultdict(int)
        self.selfB1_sum_S_t = defaultdict(lambda: None)
        self.selfB2_sum_S_t = defaultdict(lambda: None)
        self.selfB1_cnt_t   = defaultdict(int)
        self.selfB2_cnt_t   = defaultdict(int)

        # ---------------- Hidden-state compare (CUDA path) ----------------
        self.hidA_sum_S_t = defaultdict(lambda: None)  # torch [D,D]
        self.hidB_sum_S_t = defaultdict(lambda: None)
        self.hidA_cnt_t   = defaultdict(int)
        self.hidB_cnt_t   = defaultdict(int)

        # -------- NEW: Hidden self partitions (CUDA path) --------
        self.hidA1_sum_S_t = defaultdict(lambda: None)  # torch [D,D]
        self.hidA2_sum_S_t = defaultdict(lambda: None)
        self.hidA1_cnt_t   = defaultdict(int)
        self.hidA2_cnt_t   = defaultdict(int)
        self.hidB1_sum_S_t = defaultdict(lambda: None)
        self.hidB2_sum_S_t = defaultdict(lambda: None)
        self.hidB1_cnt_t   = defaultdict(int)
        self.hidB2_cnt_t   = defaultdict(int)

        # Set ids on device for fast masking
        self.set_a_ids_t = torch.tensor(sorted(self.set_a_ids), device=self.device, dtype=torch.long) \
                           if len(self.set_a_ids) else torch.empty(0, device=self.device, dtype=torch.long)
        self.set_b_ids_t = torch.tensor(sorted(self.set_b_ids), device=self.device, dtype=torch.long) \
                           if len(self.set_b_ids) else torch.empty(0, device=self.device, dtype=torch.long)

    # ----------------------------------------------------------------------- #
    # FFN: fast CUDA update (MT5 only; called from FFN hooks)
    # ----------------------------------------------------------------------- #
    @torch.inference_mode()
    def update_ffn_batch_cuda(self, side: str, layer: int,
                              A_btf: torch.Tensor, C_btm: torch.Tensor, token_ids_bt: torch.Tensor,
                              do_self_partitions: bool = False):
        """
        A: [B,T,F] (power-normed here), C: [B,T,M] (power-normed here), ids: [B,T]
        Accumulates sum over occurrences of (A^T @ C) separately for set A and set B.
        Also (optionally) splits each set into two partitions by time-step parity for self-consistency.
        """
        key = (side, layer)
        if A_btf.numel() == 0:
            return

        a = A_btf / (A_btf.abs().mean(dim=-1, keepdim=True).clamp_min_(1e-8))  # [B,T,F]
        c = C_btm / (C_btm.abs().mean(dim=-1, keepdim=True).clamp_min_(1e-8))  # [B,T,M]

        ids = token_ids_bt.to(A_btf.device)
        mA = torch.isin(ids, self.set_a_ids_t) if self.set_a_ids_t.numel() else torch.zeros_like(ids, dtype=torch.bool)
        mB = torch.isin(ids, self.set_b_ids_t) if self.set_b_ids_t.numel() else torch.zeros_like(ids, dtype=torch.bool)

        def _acc(mask, dst_sum, dst_cnt):
            if mask.any():
                A_occ = a[mask]                  # [K,F]
                C_occ = c[mask]                  # [K,M]
                S_sum = A_occ.transpose(0, 1) @ C_occ   # [F,M]
                if dst_sum[key] is None:
                    dst_sum[key] = S_sum
                else:
                    dst_sum[key].add_(S_sum)
                dst_cnt[key] += A_occ.shape[0]

        _acc(mA, self.compA_sum_S_t, self.compA_cnt_t)
        _acc(mB, self.compB_sum_S_t, self.compB_cnt_t)

        # Self partitions by even/odd timestep (deterministic)
        if do_self_partitions:
            T = A_btf.shape[1]
            pos = torch.arange(T, device=A_btf.device).view(1, T).expand_as(ids)
            even = (pos % 2 == 0)

            def _acc_self(mask_total, even_mask, dst_sum1, dst_cnt1, dst_sum2, dst_cnt2):
                m1 = mask_total & even_mask
                m2 = mask_total & (~even_mask)
                _acc(m1, dst_sum1, dst_cnt1)
                _acc(m2, dst_sum2, dst_cnt2)

            _acc_self(mA, even, self.selfA1_sum_S_t, self.selfA1_cnt_t,
                               self.selfA2_sum_S_t, self.selfA2_cnt_t)
            _acc_self(mB, even, self.selfB1_sum_S_t, self.selfB1_cnt_t,
                               self.selfB2_sum_S_t, self.selfB2_cnt_t)

    # ----------------------------------------------------------------------- #
    # FFN: slower CPU update (legacy / baseline building); kept as-is
    # ----------------------------------------------------------------------- #
    def _update_positions(self, side: str, layer: int,
                          A_btF: torch.Tensor, c_btD: torch.Tensor,
                          token_ids_bt: torch.Tensor, count_for_baseline: bool):
        A = A_btF.detach().cpu().numpy()
        C = c_btD.detach().cpu().numpy()
        TIDs = token_ids_bt.detach().cpu().numpy()
        B, T, _ = A.shape
        key = (side, layer)

        def _pn1d(v: np.ndarray, eps: float = 1e-8) -> np.ndarray:
            s = np.mean(np.abs(v)) + eps
            return (v / s).astype(np.float32)

        for b in range(B):
            for t in range(T):
                tid = int(TIDs[b, t])
                if tid in self.special_ids:
                    continue
                a = A[b, t]; c = C[b, t]

                if tid in self.target_single_ids:
                    if self.first_single_eff[key] is None:  self.first_single_eff[key]  = a.copy()
                    if self.first_single_down[key] is None: self.first_single_down[key] = c.copy()
                    self.sum_single_eff[key]  += a; self.sum_single_down[key] += c; self.cnt_single[key] += 1
                elif tid in self.set_a_ids:
                    self.sum_setA_eff[key]  += a; self.sum_setA_down[key] += c; self.cnt_setA[key] += 1
                elif tid in self.set_b_ids:
                    self.sum_setB_eff[key]  += a; self.sum_setB_down[key] += c; self.cnt_setB[key] += 1
                else:
                    self.sum_generic_eff[key]  += a; self.sum_generic_down[key] += c; self.cnt_generic[key] += 1

                cnorm = np.linalg.norm(c) + 1e-8
                c_hat = c / cnorm
                denom = np.abs(a).sum() + 1e-8
                p = np.abs(a) / denom
                if tid in self.target_single_ids:
                    self.single_sum_dir_down[key] += c_hat; self.single_cnt_dir_down[key] += 1
                    self.single_sum_rowmass[key] += p;      self.single_cnt_rowmass[key] += 1
                if count_for_baseline and (tid not in self.baseline_exclude_ids):
                    self.base_sum_dir_down[key] += c_hat; self.base_cnt_dir_down[key] += 1
                    self.base_sum_rowmass[key] += p;      self.base_cnt_rowmass[key] += 1
                    self.baseline_total_tokens += 1

                if (tid in self.set_a_ids) or (tid in self.set_b_ids):
                    a_n = _pn1d(a); c_n = _pn1d(c)
                    S_occ = (a_n[:, None] * c_n[None, :]).astype(np.float32)
                    if tid in self.set_a_ids:
                        if self.compA_sum_S[key] is None: self.compA_sum_S[key] = np.zeros_like(S_occ, dtype=np.float32)
                        self.compA_sum_S[key] += S_occ; self.compA_cnt[key] += 1
                    else:
                        if self.compB_sum_S[key] is None: self.compB_sum_S[key] = np.zeros_like(S_occ, dtype=np.float32)
                        self.compB_sum_S[key] += S_occ; self.compB_cnt[key] += 1

    # ----------------------------------------------------------------------- #
    # Hidden: fast CUDA update (MT5 & RWKV) with optional self partitions
    # ----------------------------------------------------------------------- #
    @torch.inference_mode()
    def update_hidden_side_gpu(self, side: str, layer: int,
                               H_btD: torch.Tensor, token_ids_bt: torch.Tensor,
                               do_self_partitions: bool = False):
        """
        H_btD: [B,T,D] float32 on device
        token_ids_bt: [B,T] long on same device
        S_occ = (|h|_pn ⊗ h_pn), accumulate A/B buckets and (optionally) A/B self partitions.
        """
        key = (side, layer)
        if H_btD.numel() == 0:
            return

        ids = token_ids_bt.to(H_btD.device)
        denom = H_btD.abs().mean(dim=-1, keepdim=True).clamp_min_(1e-8)
        Hn = H_btD / denom            # signed normalized
        An = H_btD.abs() / denom      # magnitude normalized

        mA = torch.isin(ids, self.set_a_ids_t) if self.set_a_ids_t.numel() else torch.zeros_like(ids, dtype=torch.bool)
        mB = torch.isin(ids, self.set_b_ids_t) if self.set_b_ids_t.numel() else torch.zeros_like(ids, dtype=torch.bool)

        def _acc(mask, dst_sum, dst_cnt):
            if not mask.any():
                return
            An_sel = An[mask]  # [K,D]
            Hn_sel = Hn[mask]  # [K,D]
            S_sum = An_sel.transpose(0, 1) @ Hn_sel
            if dst_sum[key] is None:
                dst_sum[key] = S_sum
            else:
                dst_sum[key].add_(S_sum)
            dst_cnt[key] += An_sel.shape[0]

        _acc(mA, self.hidA_sum_S_t, self.hidA_cnt_t)
        _acc(mB, self.hidB_sum_S_t, self.hidB_cnt_t)

        if do_self_partitions:
            T = H_btD.shape[1]
            pos = torch.arange(T, device=H_btD.device).view(1, T).expand_as(ids)
            even = (pos % 2 == 0)

            def _acc_self(mask_total, even_mask, dst_sum1, dst_cnt1, dst_sum2, dst_cnt2):
                _acc(mask_total & even_mask, dst_sum1, dst_cnt1)
                _acc(mask_total & (~even_mask), dst_sum2, dst_cnt2)

            _acc_self(mA, even, self.hidA1_sum_S_t, self.hidA1_cnt_t,
                              self.hidA2_sum_S_t, self.hidA2_cnt_t)
            _acc_self(mB, even, self.hidB1_sum_S_t, self.hidB1_cnt_t,
                              self.hidB2_sum_S_t, self.hidB2_cnt_t)

    # ----------------------------------------------------------------------- #
    # Baseline I/O for FFN (unchanged)
    # ----------------------------------------------------------------------- #
    def baseline_means(self, side: str, n_layers: int):
        mean_dir, mean_row = [], []
        for L in range(n_layers):
            k = (side, L)
            mean_dir.append(self.base_sum_dir_down[k] / max(1, self.base_cnt_dir_down[k]) if self.base_cnt_dir_down[k] > 0 else np.zeros(self.d_model, np.float32))
            mean_row.append(self.base_sum_rowmass[k] / max(1, self.base_cnt_rowmass[k]) if self.base_cnt_rowmass[k] > 0 else np.zeros(self.d_ff, np.float32))
        return np.stack(mean_dir, 0), np.stack(mean_row, 0)

    def save_baseline(self, path: str, nL_enc: int, nL_dec: int):
        enc_dir, enc_row = self.baseline_means("enc", nL_enc)
        dec_dir, dec_row = self.baseline_means("dec", nL_dec)
        np.savez(path, enc_dir_down_mean=enc_dir, enc_rowmass_mean=enc_row,
                 dec_dir_down_mean=dec_dir, dec_rowmass_mean=dec_row)

    @staticmethod
    def load_baseline(path: str) -> dict[str, np.ndarray]:
        arrs = np.load(path)
        return {k: arrs[k] for k in arrs.files}

    def word_baseline(self, side: str, layer: int, ids: list[int]):
        key = (side, layer)
        if self.base_cnt_dir_down[key] > 0 and self.base_cnt_rowmass[key] > 0:
            mu_dir = self.base_sum_dir_down[key] / self.base_cnt_dir_down[key]
            mu_row = self.base_sum_rowmass[key] / self.base_cnt_rowmass[key]
            return mu_dir, mu_row
        return np.zeros(self.d_model, np.float32), np.zeros(self.d_ff, np.float32)

    # ----------------------------------------------------------------------- #
    # FFN exporter (compatibility with both CPU & CUDA compare accumulators)
    # ----------------------------------------------------------------------- #
    def export_views(self, out_dir: str, side: str, layer: int, Wo_dF: torch.Tensor,
                     baseline_dir_down: np.ndarray | None = None,
                     baseline_rowmass: np.ndarray | None = None,
                     *, enable_local_z: bool = True, local_z_window: int = 35,
                     pockets_axis: str = "columns", token_single: str = "",
                     pockets_mode: str = "outer_eff_down", weight_bands_by_wo: bool = False,
                     compA_ids: list[int] | None = None, compB_ids: list[int] | None):
        import os
        ensure_dir(out_dir)
        key = (side, layer)

        if (side, layer) not in self._cached_wo_norms:
            self._cached_wo_norms[(side, layer)] = wo_locality_weights(Wo_dF)
        per_ff_w, per_mod_w = self._cached_wo_norms[(side, layer)]
        WoT_np = Wo_dF.detach().float().cpu().numpy().T

        def save_pack(tag: str, A_vec: np.ndarray | None, c_vec: np.ndarray | None):
            if A_vec is None or c_vec is None:
                return
            pack_dir = os.path.join(out_dir, tag)
            ensure_dir(pack_dir)
            A_vec = power_norm_1d(A_vec)
            c_vec = power_norm_1d(c_vec)
            if weight_bands_by_wo:
                A_vec = (A_vec * (per_ff_w / (per_ff_w.mean() + 1e-8))).astype(np.float32)
                c_vec = (c_vec * (per_mod_w / (per_mod_w.mean() + 1e-8))).astype(np.float32)
            np.save(os.path.join(pack_dir, "eff_A.npy"), A_vec)
            np.save(os.path.join(pack_dir, "down_c.npy"), c_vec)
            S = compute_pockets(A_vec, c_vec, WoT_np, mode=pockets_mode)
            S = power_norm_2d(S)
            np.save(os.path.join(pack_dir, "pockets_S.npy"), S)
            if not getattr(save_pack, "_skip_png", False):
                render_heatmap(S, os.path.join(pack_dir, "pockets_S.png"),
                               f"{side} L{layer} {tag}: pockets S", gamma=1.6)
                render_band_row(c_vec, os.path.join(pack_dir, "down_c.png"),
                                f"{side} L{layer} {tag}: band (down c, power-normed)", gamma=1.8)
                render_band_row(A_vec, os.path.join(pack_dir, "eff_A.png"),
                                f"{side} L{layer} {tag}: eff A (power-normed)", gamma=1.8)
                if enable_local_z:
                    tok_slug = _slugify_token_label(token_single)
                    c_lz = local_zscore_1d(c_vec, local_z_window)
                    A_lz = local_zscore_1d(A_vec, local_z_window)
                    np.save(os.path.join(pack_dir, f"down_c_localz__{tok_slug}.npy"), c_lz)
                    np.save(os.path.join(pack_dir, f"eff_A_localz__{tok_slug}.npy"), A_lz)
                    render_band_row(c_lz, os.path.join(pack_dir, f"down_c_localz__{tok_slug}.png"),
                                    f"{side} L{layer} {tag}: band (down c) local-z [{tok_slug}]", gamma=1.8)
                    render_band_row(A_lz, os.path.join(pack_dir, f"eff_A_localz__{tok_slug}.png"),
                                    f"{side} L{layer} {tag}: eff A local-z [{tok_slug}]", gamma=1.8)
                    S_lz = (A_lz[:, None] * c_lz[None, :]).astype(np.float32)
                    np.save(os.path.join(pack_dir, f"pockets_S_localz__{tok_slug}.npy"), S_lz)
                    render_heatmap(S_lz, os.path.join(pack_dir, f"pockets_S_localz__{tok_slug}.png"),
                                   f"{side} L{layer} {tag}: pockets S (A_lz ⊗ c_lz)", gamma=1.6)

        if self.first_single_eff[key] is not None and self.first_single_down[key] is not None:
            save_pack("single_token_first", self.first_single_eff[key], self.first_single_down[key])
        if self.cnt_single[key] > 0:
            A = self.sum_single_eff[key] / max(1, self.cnt_single[key])
            c = self.sum_single_down[key] / max(1, self.cnt_single[key])
            save_pack("mean_single_token", A, c)
        if self.cnt_generic[key] > 0:
            A_g = self.sum_generic_eff[key] / max(1, self.cnt_generic[key])
            c_g = self.sum_generic_down[key] / max(1, self.cnt_generic[key])
            if self.cnt_single[key] > 0:
                A_s = self.sum_single_eff[key] / max(1, self.cnt_single[key])
                c_s = self.sum_single_down[key] / max(1, self.cnt_single[key])
                save_pack("mean_single_minus_generic", A_s - A_g, c_s - c_g)
            if self.cnt_setA[key] > 0 and self.cnt_setB[key] > 0:
                A_a = self.sum_setA_eff[key] / max(1, self.cnt_setA[key])
                c_a = self.sum_setA_down[key] / max(1, self.cnt_setA[key])
                A_b = self.sum_setB_eff[key] / max(1, self.cnt_setB[key])
                c_b = self.sum_setB_down[key] / max(1, self.cnt_setB[key])
                save_pack("mean_setA_minus_setB", A_a - A_b, c_a - c_b)

        if baseline_dir_down is not None and baseline_rowmass is not None and self.single_cnt_dir_down[key] > 0:
            mu_dir = self.single_sum_dir_down[key] / max(1, self.single_cnt_dir_down[key])
            mu_p   = self.single_sum_rowmass[key] / max(1, self.single_cnt_rowmass[key])
            delta_dir = power_norm_1d(mu_dir - baseline_dir_down)
            delta_p   = power_norm_1d(mu_p   - baseline_rowmass)
            if weight_bands_by_wo:
                per_ff_w, per_mod_w = self._cached_wo_norms[(side, layer)]
                delta_dir = (delta_dir * (per_mod_w / (per_mod_w.mean() + 1e-8))).astype(np.float32)
                delta_p   = (delta_p   * (per_ff_w  / (per_ff_w.mean()  + 1e-8))).astype(np.float32)
            pack_dir = os.path.join(out_dir, "single_vs_baseline_dir")
            ensure_dir(pack_dir)
            np.save(os.path.join(pack_dir, "delta_dir_down.npy"), delta_dir)
            np.save(os.path.join(pack_dir, "delta_rowmass_eff.npy"), delta_p)
            S_dir = compute_pockets(delta_p, delta_dir, WoT_np, mode=pockets_mode)
            S_dir = power_norm_2d(S_dir)
            np.save(os.path.join(pack_dir, "pockets_S_dir.npy"), S_dir)
            render_heatmap(S_dir, os.path.join(pack_dir, "pockets_S_dir.png"),
                           f"{side} L{layer} single vs baseline: pockets", gamma=1.6)
            render_band_row(delta_dir, os.path.join(pack_dir, "delta_dir_down.png"),
                            f"{side} L{layer} single vs baseline: band Δĉ", gamma=1.8)

        layer_diff_sum, layer_overlap = None, None

        has_cuda_A = (self.compA_sum_S_t[key] is not None) and (self.compA_cnt_t[key] > 0)
        has_cuda_B = (self.compB_sum_S_t[key] is not None) and (self.compB_cnt_t[key] > 0)

        if has_cuda_A and has_cuda_B:
            SA = self.compA_sum_S_t[key] / float(self.compA_cnt_t[key])
            SB = self.compB_sum_S_t[key] / float(self.compB_cnt_t[key])
            D  = (SA - SB).abs()
            layer_diff_sum = float(D.sum().item())
            denom = (SA.abs().sum() + SB.abs().sum()).clamp_min(1e-8)
            layer_overlap = float((D.sum() / denom).item())

            pack_dir = os.path.join(out_dir, "compare_token_diff")
            ensure_dir(pack_dir)
            SA_np = power_norm_2d(SA.detach().cpu().numpy())
            SB_np = power_norm_2d(SB.detach().cpu().numpy())
            D_np  = D.detach().cpu().numpy()
            np.save(os.path.join(pack_dir, "mapA.npy"), SA_np)
            np.save(os.path.join(pack_dir, "mapB.npy"), SB_np)
            np.save(os.path.join(pack_dir, "diff_map.npy"), D_np)
            render_heatmap(SA_np, os.path.join(pack_dir, "mapA.png"),
                           f"{side} L{layer} token A map (avg per-occurrence norm)", gamma=1.6)
            render_heatmap(SB_np, os.path.join(pack_dir, "mapB.png"),
                           f"{side} L{layer} token B map (avg per-occurrence norm)", gamma=1.6)
            render_heatmap(D_np, os.path.join(pack_dir, "diff_map.png"),
                           f"{side} L{layer} |A - B| (difference) • overlap={layer_overlap:.3f}", gamma=1.6)

        else:
            if (self.compA_cnt[key] > 0) and (self.compB_cnt[key] > 0):
                SA = (self.compA_sum_S[key] / max(1, self.compA_cnt[key])).astype(np.float32)
                SB = (self.compB_sum_S[key] / max(1, self.compB_cnt[key])).astype(np.float32)
                D = np.abs(SA - SB).astype(np.float32)
                layer_diff_sum = float(D.sum())
                denom = (np.abs(SA) + np.abs(SB)).sum() + 1e-8
                layer_overlap = float(D.sum() / denom)

                pack_dir = os.path.join(out_dir, "compare_token_diff")
                ensure_dir(pack_dir)
                np.save(os.path.join(pack_dir, "mapA.npy"), power_norm_2d(SA))
                np.save(os.path.join(pack_dir, "mapB.npy"), power_norm_2d(SB))
                np.save(os.path.join(pack_dir, "diff_map.npy"), D)
                render_heatmap(power_norm_2d(SA), os.path.join(pack_dir, "mapA.png"),
                               f"{side} L{layer} token A map (avg per-occurrence norm)", gamma=1.6)
                render_heatmap(power_norm_2d(SB), os.path.join(pack_dir, "mapB.png"),
                               f"{side} L{layer} token B map (avg per-occurrence norm)", gamma=1.6)
                render_heatmap(D, os.path.join(pack_dir, "diff_map.png"),
                               f"{side} L{layer} |A - B| (difference) • overlap={layer_overlap:.3f}", gamma=1.6)

        return {"layer_diff_sum": layer_diff_sum, "layer_overlap": layer_overlap}

# ---------------------------- FFN Capture (MT5) ---------------------------- #

class FFNCapture:
    """
    Hook every mT5 FFN (``DenseReluDense``) and stream per-token FFN signals into
    the token aggregators.

    Encoder FFNs are ``block.layer[1].DenseReluDense`` and decoder FFNs are
    ``block.layer[2].DenseReluDense``. For each forward pass the hook recomputes the
    intermediate activation ``A`` (gated: ``act(wi_0 x) * wi_1 x``; ungated:
    ``act(wi x)``) and the per-token down-projection ``c = einsum(A, wo.weight)``.
    It passes both to ``aggregators.update_ffn_batch_cuda`` together with the ids
    registered via :meth:`set_batch_token_ids`. Hooks are no-ops while the ids for
    their side are unset.
    """
    def __init__(self, model, tokenizer, aggregators: "TokenAggregators",
                 capture_encoder: bool, capture_decoder: bool,
                 do_self_partitions: bool = False):
        self.model = model; self.tok = tokenizer; self.agg = aggregators
        self.capture_encoder = capture_encoder; self.capture_decoder = capture_decoder
        self.act = act_fn_for_proj(model.config.feed_forward_proj)
        self.handles = []; self._enc_ids_bt = None; self._dec_ids_bt = None
        self.do_self_partitions = do_self_partitions
        if capture_encoder:
            for li, block in enumerate(model.encoder.block):
                self._register_ffn_hook(block.layer[1].DenseReluDense, "enc", li)
        if capture_decoder:
            for li, block in enumerate(model.decoder.block):
                self._register_ffn_hook(block.layer[2].DenseReluDense, "dec", li)

    def _register_ffn_hook(self, ffn, side: str, layer: int):
        """Attach a forward hook to ``ffn`` that feeds ``(A, c, ids)`` to the aggregator."""
        has_gated = hasattr(ffn, "wi_0") and hasattr(ffn, "wi_1")

        def hook_fn(module, inputs, output):
            # Without token ids the aggregator cannot attribute activations, so bail
            # out before paying for the activation recompute and the einsum.
            ids = self._enc_ids_bt if side == "enc" else self._dec_ids_bt
            if ids is None:
                return
            x = inputs[0]
            with torch.no_grad():
                if has_gated:
                    A = self.act(ffn.wi_0(x)) * ffn.wi_1(x)
                else:
                    A = self.act(ffn.wi(x))
                Wo = ffn.wo.weight
                # Linear weight is (out, in); contract over the FFN axis of A.
                c = torch.einsum("btf,mf->btm", A, Wo)
                self.agg.update_ffn_batch_cuda(side, layer, A, c, ids,
                                               do_self_partitions=self.do_self_partitions)
        self.handles.append(ffn.register_forward_hook(hook_fn))

    def set_batch_token_ids(self, enc_ids_bt: Optional[torch.Tensor], dec_ids_bt: Optional[torch.Tensor]):
        """Register the token ids of the next forward pass (``None`` disables a side)."""
        self._enc_ids_bt = enc_ids_bt; self._dec_ids_bt = dec_ids_bt

    def close(self):
        """Remove every registered hook."""
        for h in self.handles: h.remove()
        self.handles = []

# ---------------------------- Hidden-state Capture (MT5 & RWKV) ---------------------------- #

class HiddenStateCapture:
    """
    Record each transformer block's output as a detached float32 CPU tensor.

    When ``capture_encoder`` is set, outputs of ``model.encoder.block[i]`` go to
    ``self.enc[i]``. When ``capture_decoder`` is set, outputs of
    ``model.decoder.block[i]`` go to ``self.dec[i]``, and so do outputs of
    ``model.rwkv.blocks[i]`` for RWKV models. A tuple/list output contributes its
    first element. Each forward pass overwrites the previous entry.
    """
    def __init__(self, model, capture_encoder: bool, capture_decoder: bool):
        self.model = model
        self.capture_encoder = capture_encoder
        self.capture_decoder = capture_decoder
        self.handles = []
        self.enc = {}
        self.dec = {}

        def _make_recorder(store, idx):
            def _record(_module, _inputs, output):
                hidden = output[0] if isinstance(output, (tuple, list)) else output
                store[idx] = hidden.detach().float().cpu()
            return _record

        # Collect (block list, destination dict) pairs in registration order.
        targets = []
        if capture_encoder and hasattr(model, "encoder") and hasattr(model.encoder, "block"):
            targets.append((model.encoder.block, self.enc))
        if capture_decoder and hasattr(model, "decoder") and hasattr(model.decoder, "block"):
            targets.append((model.decoder.block, self.dec))
        if capture_decoder and hasattr(model, "rwkv") and hasattr(model.rwkv, "blocks"):
            targets.append((model.rwkv.blocks, self.dec))

        for blocks, store in targets:
            for idx, block in enumerate(blocks):
                self.handles.append(block.register_forward_hook(_make_recorder(store, idx)))

    def close(self):
        """Detach every hook registered by this capture."""
        for handle in self.handles:
            handle.remove()
        self.handles = []

# ---------------------------- RWKV helpers ---------------------------- #

def _align_time_mix_names(sd: dict) -> dict:
    """
    Rename short RWKV time-mix parameter names to the Hugging Face RWKV names.

    ``.time_mix_k`` -> ``.time_mix_key``, ``.time_mix_v`` -> ``.time_mix_value`` and
    ``.time_mix_r`` -> ``.time_mix_receptance``. A rename only happens when the letter
    is followed by a word boundary, so names that are already long (``time_mix_key``,
    ``time_mix_value``) are left untouched. Values are passed through unchanged and
    key order is preserved.

    Returns:
        A new dict with remapped keys. The total number of substitutions is printed.
    """
    full_names = {"k": "key", "v": "value", "r": "receptance"}
    pat = re.compile(r"(\.time_mix_)([kvr])\b")

    def repl(m: re.Match) -> str:
        return m.group(1) + full_names[m.group(2)]

    out = {}
    changed = 0
    for k, v in sd.items():
        k2, n = pat.subn(repl, k)
        changed += n
        out[k2] = v
    print(f"[rwkv] time_mix remap: changed={changed}")
    return out

def _strip_prefix_rwkv(k: str) -> str:
    """Drop a single leading ``module.`` or ``model.`` wrapper prefix from a state-dict key."""
    for wrapper in ("module.", "model."):
        if k.startswith(wrapper):
            return k[len(wrapper):]
    return k

def _infer_rwkv_cfg(sd: dict, tokenizer, fallback_hidden=None, fallback_layers=None) -> RwkvConfig:
    """
    Build an ``RwkvConfig`` whose sizes match the tensors in state dict ``sd``.

    * ``vocab_size``: rows of the first known embedding weight, else ``len(tokenizer)``.
    * ``hidden_size``: in key order, the length of the first 1-D ``ln1``/``pre_ln``/
      ``ln2`` weight or the last dim of the first ``time_mix_*`` tensor. If neither
      is found, ``fallback_hidden`` or 1024.
    * ``num_hidden_layers``: 1 + the highest ``rwkv.blocks.<i>.`` index, else
      ``fallback_layers`` or 24.
    * ``attention_hidden_size`` / ``intermediate_size``: output width of the first
      2-D ``*.attention.key.weight`` / ``*.feed_forward.key.weight``. ``None`` when
      absent, which lets ``RwkvConfig`` apply its own defaults.

    Special-token ids are read from ``tokenizer``; a missing or falsy id becomes 0.
    """
    keys = list(sd.keys())

    vocab_size = None
    for cand in ("rwkv.embeddings.weight", "rwkv.emb.weight", "embeddings.weight", "emb.weight"):
        if cand in sd:
            vocab_size = sd[cand].shape[0]
            break
    if vocab_size is None:
        vocab_size = len(tokenizer)

    hidden_size = None
    for k in keys:
        t = sd[k]
        if k.endswith(("ln1.weight", "pre_ln.weight", "ln2.weight")) and t.ndim == 1:
            hidden_size = int(t.shape[0])
            break
        if "time_mix_" in k and t.ndim >= 1:
            hidden_size = int(t.shape[-1])
            break
    if hidden_size is None:
        hidden_size = fallback_hidden or 1024

    layer_idxs = []
    for k in keys:
        m = re.search(r'rwkv\.blocks\.(\d+)\.', k)
        if m:
            layer_idxs.append(int(m.group(1)))
    num_layers = (max(layer_idxs) + 1) if layer_idxs else (fallback_layers or 24)

    def _out_width(suffix: str):
        # nn.Linear weights are stored (out_features, in_features).
        for k in keys:
            t = sd[k]
            if k.endswith(suffix) and getattr(t, "ndim", 0) == 2:
                return int(t.shape[0])
        return None

    # Checkpoints whose projection widths differ from the config defaults would
    # otherwise hit a size mismatch in load_state_dict (which raises even when
    # strict=False).
    attention_hidden_size = _out_width(".attention.key.weight")
    intermediate_size = _out_width(".feed_forward.key.weight")

    return RwkvConfig(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        num_hidden_layers=num_layers,
        attention_hidden_size=attention_hidden_size,
        intermediate_size=intermediate_size,
        layer_norm_epsilon=1e-5,
        bos_token_id=getattr(tokenizer, "bos_token_id", 0) or 0,
        eos_token_id=getattr(tokenizer, "eos_token_id", 0) or 0,
        pad_token_id=getattr(tokenizer, "pad_token_id", 0) or 0,
    )

def load_rwkv_from_sd(sd_path: str, tokenizer, *, fallback_hidden=None, fallback_layers=None, device="cpu"):
    """
    Build an HF ``RwkvForCausalLM`` from a raw checkpoint and load its weights.

    Steps:
    1. Load ``sd_path`` on CPU, unwrapping a ``{"state_dict": ...}`` container.
    2. Strip a leading ``module.``/``model.`` prefix and remap short ``time_mix_*``
       names.
    3. Infer the config from tensor shapes.
    4. Load every tensor whose name AND shape match the fresh model, with
       ``strict=False``.

    Key and shape discrepancies are printed (first 20 of each) rather than raised.

    Returns:
        The model moved to ``device`` and switched to eval mode.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    sd_raw = torch.load(sd_path, map_location="cpu")
    if isinstance(sd_raw, dict) and "state_dict" in sd_raw:
        sd_raw = sd_raw["state_dict"]
    sd = {_strip_prefix_rwkv(k): v for k, v in sd_raw.items()}
    sd = _align_time_mix_names(sd)

    cfg = _infer_rwkv_cfg(sd, tokenizer, fallback_hidden=fallback_hidden, fallback_layers=fallback_layers)
    model = RwkvForCausalLM(cfg)
    ms = model.state_dict()

    keys_sd = set(sd.keys())
    keys_ms = set(ms.keys())
    inter   = keys_sd & keys_ms
    miss    = sorted(keys_ms - keys_sd)
    extra   = sorted(keys_sd - keys_ms)
    # Same-name tensors with a different shape make load_state_dict raise even with
    # strict=False; report them and load the compatible subset instead.
    shape_mismatch = sorted(
        k for k in inter
        if hasattr(sd[k], "shape") and tuple(sd[k].shape) != tuple(ms[k].shape)
    )

    print(f"[rwkv] sd keys: {len(keys_sd)} | model keys: {len(keys_ms)} | intersect: {len(inter)} "
          f"({100.0*len(inter)/max(1,len(keys_ms)):.1f}% of model)")
    print(f"[rwkv] missing in sd: {len(miss)}")
    for k in miss[:20]: print("   -", k)
    print(f"[rwkv] sd-only (unused): {len(extra)}")
    for k in extra[:20]: print("   -", k)
    print(f"[rwkv] shape mismatch (skipped): {len(shape_mismatch)}")
    for k in shape_mismatch[:20]:
        print(f"   - {k}: sd {tuple(sd[k].shape)} vs model {tuple(ms[k].shape)}")

    skip = set(shape_mismatch)
    loadable = {k: sd[k] for k in inter if k not in skip}
    missing, unexpected = model.load_state_dict(loadable, strict=False)
    print(f"[rwkv] load_state_dict: missing_after_load={len(missing)} unexpected_after_load={len(unexpected)}")

    return model.to(device).eval()

def load_rwkv_tokenizer(tok_id: str | None):
    tried = []
    if tok_id:
        try:
            tok = AutoTokenizer.from_pretrained(tok_id, use_forced_bos_token=False, use_fast=True, trust_remote_code=True)
            print(f"[rwkv] tokenizer loaded from '{tok_id}' ({tok.__class__.__name__})")
            return tok
        except Exception as e:
            tried.append((tok_id, str(e)))
    try:
        tok = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b", use_fast=True, trust_remote_code=True)
        print(f"[rwkv] tokenizer loaded from fallback 'EleutherAI/gpt-neox-20b' ({tok.__class__.__name__})")
        return tok
    except Exception as e:
        tried.append(("EleutherAI/gpt-neox-20b", str(e)))

    msg = "[rwkv] Failed to load a tokenizer. Tried:\n" + "\n".join([f"  - {k}: {err}" for k, err in tried])
    msg += ("\nIf you’re offline, download '20B_tokenizer.json' and point --rwkv_tokenizer "
            "to its directory, or clone EleutherAI/gpt-neox-20b locally.")
    raise OSError(msg)

def _tok_clean(s: str) -> str:
    """Lower-case ``s`` after removing any leading run of ``Ġ``/``▁`` word-start markers and whitespace."""
    lead = re.match(r"[Ġ▁\s]*", s)
    return s[lead.end():].lower()

def rwkv_ids_for_words(tok, csv: str, *, include_head=True, include_bare=True, scan_vocab=True):
    """
    Resolve a comma-separated word list to a sorted list of tokenizer ids.

    For each lower-cased word, the bare form (``include_bare``) and/or the
    space-prefixed form (``include_head``) is encoded. A single-id encoding is taken
    as-is. For a multi-id encoding, only the first piece starting with ``Ġ``, ``▁``
    or a space is kept. With ``scan_vocab`` (and a tokenizer exposing
    ``get_vocab``), every vocab entry whose ``_tok_clean`` form equals one of the
    words is added too. A short summary is printed.
    """
    words = [w.strip().lower() for w in csv.split(",") if w.strip()]
    wanted = set(words)
    resolved = set()

    def _surface_forms(word):
        if include_bare:
            yield word
        if include_head and not word.startswith(" "):
            yield " " + word

    for word in words:
        for text in _surface_forms(word):
            enc = tok.encode(text, add_special_tokens=False)
            pieces = tok.convert_ids_to_tokens(enc)
            if len(enc) == 1:
                resolved.add(enc[0])
                continue
            # Multi-piece encoding: keep the first piece carrying a word-start marker.
            for pos, piece in enumerate(pieces):
                piece_text = piece if isinstance(piece, str) else ""
                if piece_text.startswith(("Ġ", "▁", " ")):
                    resolved.add(enc[pos])
                    break

    if scan_vocab and hasattr(tok, "get_vocab"):
        resolved.update(
            int(tid) for piece, tid in list(tok.get_vocab().items())
            if _tok_clean(piece) in wanted
        )

    if not resolved:
        print(f"[tokens] RWKV: no ids resolved for {words} (check tokenizer)")
    else:
        sample = list(resolved)[:24]
        names = tok.convert_ids_to_tokens(sample)
        print(f"[tokens] RWKV resolved {len(resolved)} ids for {words}: sample {list(zip(sample, names))[:8]}")

    return sorted(resolved)

def _clean_words_csv(s: str) -> str:
    """
    Normalise a comma-separated word list. Each entry is lower-cased and has
    leading/trailing non-word characters trimmed; empty entries are dropped and the
    rest re-joined with bare commas, e.g. ``" Dog!, ,cat"`` -> ``"dog,cat"``.
    """
    edge_junk = re.compile(r"^[^\w]+|[^\w]+$")
    cleaned = (edge_junk.sub("", part.lower()).strip() for part in s.split(","))
    return ",".join(filter(None, cleaned))

# ---------------------------- Main ---------------------------- #

def main():
    ap = argparse.ArgumentParser()

    # Base / weights
    ap.add_argument("--ckpt", type=str, default="google/mt5-base", help="Base HF checkpoint for mt5")
    ap.add_argument("--state_dict_path", type=str,
                    default="mt5_base_pretuned.pt",
                    #default="mt5_base_standard_FaF_11x_noise.pt",
                    #default="mt5_base_standard_FaF_11x_noise_two_let_retrain1_batch_10000.pt",
                    #default="mt5_base_twolet_FaF_12x_noise0_batch_5000_o.pt",
                    #default="rwkv3_checkpoint_0_batch_14000.pt",
                    help="Path to .pt/.bin state_dict to load on top of --ckpt")

    # --- NEW: optional second model for cross-model self-overlap ---
    ap.add_argument("--ckpt_b", type=str, default=None, help="Optional base checkpoint for model B (defaults to --ckpt)")
    ap.add_argument("--state_dict_path_b", type=str, #default="mt5_base_twolet_FaF_12x_noise0_batch_5000_o.pt",
                    default=None,#"mt5_base_standard_FaF_11x_noise_two_let_retrain1_batch_10000.pt",
                    #default="mt5_base_standard_FaF_11x_noise.pt",
                    help="Optional second state_dict for model B")

    ap.add_argument("--out_dir", type=str, default="Interpretation_Records/final_maps/dog_base")


    ap.add_argument("--token_compare_a", type=str, default="dog", help="Token(s) for compare A")
    ap.add_argument("--token_compare_b", type=str, default="dog", help="Token(s) for compare B")
    ap.add_argument("--tail_pct", type=float, default=3.0,
                    help="Percent for two-sided tail on hidden diffs (0 disables tail masking).")

    # --- NEW: enable intra-model self consistency check (even/odd partitions) ---
    ap.add_argument("--self_compare", action="store_true", default=False,
                    help="Split same-token occurrences into two partitions and report overlap (consistency).")

    # Architecture switch
    ap.add_argument("--arch", type=str, choices=["mt5", "rwkv"],
                    default="mt5",
                    help="Backbone to load. 'mt5' or 'rwkv'.")
    ap.add_argument("--use_forgetfult5", default=False, action="store_true",
                    help="Use custom ForgetfulT5 forward (import must exist)")

    ap.add_argument("--rwkv_vocab_size", type=int, default=50277)
    ap.add_argument("--rwkv_hidden_size", type=int, default=1024)
    ap.add_argument("--rwkv_n_layers", type=int, default=24)
    ap.add_argument("--rwkv_tokenizer", type=str, default="BlinkDL/rwkv-4-pile-430m")

    # Compare mode
    ap.add_argument("--compare_mode", type=str, choices=["ffn", "hidden"], default="hidden",
                    help="Use 'ffn' (T5-only) or 'hidden' (layer outputs; MT5 & RWKV).")

    # Modeling / options
    ap.add_argument("--token_baseline_multi", type=str, default="", help="Comma list averaged as baseline 'word'")
    ap.add_argument("--weight_bands_by_wo", action="store_true", help="Multiply bands by local Wo norms (FFN path)")


    # z/vis
    ap.add_argument("--enable_local_z", default=True, action="store_true", help="Also save local averaged variants")
    ap.add_argument("--local_z_window", type=int, default=19, help="Odd window size for local avg")
    ap.add_argument("--local_z_axis_for_pockets", type=str, default="columns", choices=["columns","rows","both"], help="Axis for local z in pockets")

    # Data
    ap.add_argument("--pairs_file", type=str, default=None, help="TSV src<TAB>tgt")
    ap.add_argument("--text_file", type=str, default="Interpretation_Records/interpret_sub_texts.txt", help="Plaintext one-per-line")
    ap.add_argument("--max_examples", type=int, default=10000)
    ap.add_argument("--batch_size", type=int, default=32)
    ap.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")

    # Capture sides
    ap.add_argument("--capture_encoder", action="store_true")
    ap.add_argument("--capture_decoder", default=True, action="store_true")

    # Baseline (FFN path)
    ap.add_argument("--baseline_build", action="store_true", help="stream and save baseline means (ĉ, p)")
    ap.add_argument("--baseline_use", action="store_true", help="load baseline and emit single-vs-baseline outputs")
    ap.add_argument("--baseline_file", type=str, default=None, help="path to .npz for baseline")
    ap.add_argument("--baseline_tokens_target", type=int, default=50000, help="approx tokens to include in baseline")
    ap.add_argument("--baseline_exclude_targets", action="store_true", help="exclude single/setA/setB from baseline")

    # Pockets (FFN path)
    ap.add_argument("--pockets_mode", type=str, default="outer_eff_down", choices=["outer_eff_down","wo_times_eff"], help="How to form pockets")

    # Tokenizer lengths
    ap.add_argument("--max_src_len", type=int, default=512)
    ap.add_argument("--max_tgt_len", type=int, default=128)

    # Injection helper
    ap.add_argument("--inject_compare_sentences", action="store_true",
                    help="Prepend simple pairs so token_compare_* appear in the stream")

    # Debug
    ap.add_argument("--debug_tokens", action="store_true", help="Print batch-seen ids vs. A/B id sets")

    # --- Hidden-state baseline & tail selection ---
    ap.add_argument("--tail_mode", type=str, default="signed", choices=["signed","abs"],
                    help="'signed' keeps top/bottom by value; 'abs' keeps largest by magnitude.")
    ap.add_argument("--hidden_baseline_mode", type=str, default="running",
                    choices=["none","running","load"],
                    help="none: no baseline subtraction; running: build mean of normalized H; load: use file.")
    ap.add_argument("--hidden_baseline_file", type=str, default=None,
                    help="Path to npz with hidden baselines (keys like 'dec_0'). Used when mode=load.")
    ap.add_argument("--hidden_baseline_exclude_targets", action="store_true",
                    help="Exclude target tokens (single/A/B) from the running hidden baseline.")

    args = ap.parse_args()
    if args.compare_mode is None:
        args.compare_mode = "ffn" if args.arch == "mt5" else "hidden"

    ensure_dir(args.out_dir)
    with open(os.path.join(args.out_dir, "meta.json"), "w") as f:
        json.dump({"args": vars(args)}, f, indent=2)

    # ---------------- tokenizer + model A (arch switch) ----------------
    if args.arch == "mt5":
        tok = AutoTokenizer.from_pretrained(args.ckpt)
        from transformers import MT5ForConditionalGeneration as _T5Cls
        model = _T5Cls.from_pretrained(args.ckpt)

        if args.state_dict_path and os.path.exists(args.state_dict_path):
            sd_raw = torch.load(args.state_dict_path, map_location="cpu")
            if isinstance(sd_raw, dict) and "state_dict" in sd_raw:
                sd_raw = sd_raw["state_dict"]

            def strip_prefix(k: str) -> str:
                return re.sub(
                    r'^(?:module\.|model\.|base_model\.|transformer\.|t5\.|mt5\.)',
                    '',
                    k,
                    count=1
                )

            sd = {strip_prefix(k): v for k, v in sd_raw.items()}
            ms = model.state_dict()

            keys_sd = set(sd.keys())
            keys_ms = set(ms.keys())
            inter = keys_sd & keys_ms
            only_sd = list(keys_sd - keys_ms)
            only_ms = list(keys_ms - keys_sd)

            print(
                f"[ckpt] sd keys (after strip): {len(keys_sd)} | model keys: {len(keys_ms)} | intersect: {len(inter)} "
                f"({100.0 * len(inter) / max(1, len(keys_ms)):.1f}% of model)")
            if only_sd:
                print(" [ckpt] example in sd but not model (dropped):")
                for k in only_sd[:10]:
                    print(f"   - {k}")
            if only_ms:
                print(" [ckpt] example expected by model but missing in sd:")
                for k in only_ms[:10]:
                    print(f"   - {k}")

            sd_f = {k: sd[k] for k in inter}
            missing, unexpected = model.load_state_dict(sd_f, strict=False)
            print(f"[mt5] load_state_dict: loaded={len(sd_f)}  missing={len(missing)}  unexpected={len(unexpected)}")
        else:
            print("[mt5] no state_dict_path found; using base weights.")

        model.eval().to(args.device)

    else:  # RWKV
        tok = load_rwkv_tokenizer(args.rwkv_tokenizer or None)
        if tok.pad_token is None and tok.eos_token is not None:
            tok.pad_token = tok.eos_token

        model = load_rwkv_from_sd(args.state_dict_path, tok)
        model.to(args.device).eval()

        if args.compare_mode != "hidden":
            print("[rwkv] Forcing compare_mode=hidden.")
            args.compare_mode = "hidden"
        if args.capture_encoder:
            print("[rwkv] Encoder capture not applicable; disabling.")
        args.capture_encoder = False

    # ---------------- Optional model B for cross-model self overlap ----------------
    model_b = None
    hidcap_b = None
    capturer_b = None
    if args.state_dict_path_b is not None:
        if args.arch == "mt5":
            ckpt_b = args.ckpt_b or args.ckpt
            tok_b = tok  # keep same tokenizer for alignment
            from transformers import MT5ForConditionalGeneration as _T5Cls
            model_b = _T5Cls.from_pretrained(ckpt_b)
            sd_raw = torch.load(args.state_dict_path_b, map_location="cpu")
            if isinstance(sd_raw, dict) and "state_dict" in sd_raw:
                sd_raw = sd_raw["state_dict"]
            def strip_prefix_b(k: str) -> str:
                return re.sub(r'^(?:module\.|model\.|base_model\.|transformer\.|t5\.|mt5\.)','',k,1)
            sd_b = {strip_prefix_b(k): v for k, v in sd_raw.items()}
            ms_b = model_b.state_dict()
            inter_b = set(sd_b.keys()) & set(ms_b.keys())
            model_b.load_state_dict({k: sd_b[k] for k in inter_b}, strict=False)
            model_b.eval().to(args.device)
        else:
            # RWKV
            tok_b = tok  # keep same tokenizer for alignment
            model_b = load_rwkv_from_sd(args.state_dict_path_b, tok_b)
            model_b.to(args.device).eval()

    # ---------------- parse tokens ----------------
    single_ids = strings_to_ids_list(tok, args.token_baseline_multi)  # optional baseline "word"

    if args.arch == "rwkv":
        a_csv = _clean_words_csv(args.token_compare_a)
        b_csv = _clean_words_csv(args.token_compare_b)
        setA_ids_for_agg = rwkv_ids_for_words(tok, a_csv, include_head=True, include_bare=True, scan_vocab=True)
        setB_ids_for_agg = rwkv_ids_for_words(tok, b_csv, include_head=True, include_bare=True, scan_vocab=True)
    else:
        setA_ids_for_agg = strings_to_ids_list(tok, args.token_compare_a)
        setB_ids_for_agg = strings_to_ids_list(tok, args.token_compare_b)

    specials = {tok.pad_token_id, getattr(tok, "eos_token_id", None), getattr(tok, "unk_token_id", None)}
    specials = {i for i in specials if i is not None}

    # Need d_model, d_ff
    if args.arch == "mt5":
        d_model = model.config.d_model
        probe_ffn = model.decoder.block[0].layer[2].DenseReluDense
        d_ff = (probe_ffn.wi_0.out_features if hasattr(probe_ffn, "wi_0") else probe_ffn.wi.out_features)
        nL_enc = len(model.encoder.block); nL_dec = len(model.decoder.block)
    else:
        d_model = model.config.hidden_size
        d_ff = d_model
        nL_enc = 0
        nL_dec = getattr(model, "config", None).num_hidden_layers or 0

    if args.use_forgetfult5 and args.arch == "mt5":
        model = ForgetfulT5(model).to(args.device)

    baseline_exclude = set()
    if getattr(args, "baseline_exclude_targets", False):
        baseline_exclude |= set(single_ids) | set(setA_ids_for_agg) | set(setB_ids_for_agg)

    want_baseline = getattr(args, "baseline_build", False) or getattr(args, "baseline_use", False)
    aggregators = TokenAggregators(
        n_layers_enc=nL_enc, n_layers_dec=nL_dec,
        d_ff=d_ff, d_model=d_model,
        target_single_ids=[],  # not used in hidden mode
        set_a_ids=setA_ids_for_agg, set_b_ids=setB_ids_for_agg,
        special_ids=list(specials),
        want_baseline=want_baseline, baseline_exclude_ids=baseline_exclude
    )
    aggregators_b = None
    if model_b is not None:
        aggregators_b = TokenAggregators(
            n_layers_enc=nL_enc, n_layers_dec=nL_dec,
            d_ff=d_ff, d_model=d_model,
            target_single_ids=[],
            set_a_ids=setA_ids_for_agg, set_b_ids=setB_ids_for_agg,
            special_ids=list(specials),
            want_baseline=want_baseline, baseline_exclude_ids=baseline_exclude
        )

    # ---- Tail filter config (hidden) ----
    aggregators.tail = TailFilterFast(pct=getattr(args, "tail_pct", 5.0),
                                      mode=getattr(args, "tail_mode", "signed"))

    loaded_mu = None
    if args.hidden_baseline_mode == "load":
        assert args.hidden_baseline_file and os.path.exists(args.hidden_baseline_file), \
            "--hidden_baseline_file must exist when --hidden_baseline_mode=load"
        arrs = np.load(args.hidden_baseline_file, allow_pickle=False)
        loaded_mu = {k: arrs[k] for k in arrs.files}

    aggregators.hidden_baseline = HiddenBaseline(
        mode=args.hidden_baseline_mode,
        loaded=loaded_mu
    )
    if aggregators_b is not None:
        aggregators_b.tail = aggregators.tail
        aggregators_b.hidden_baseline = HiddenBaseline(
            mode=args.hidden_baseline_mode,
            loaded=loaded_mu
        )

    exclude_ids = set()
    if getattr(args, "hidden_baseline_exclude_targets", False):
        exclude_ids |= set(single_ids) | set(setA_ids_for_agg) | set(setB_ids_for_agg)
    exclude_ids |= set(aggregators.special_ids)
    aggregators.exclude_for_hidden_baseline = exclude_ids
    if aggregators_b is not None:
        aggregators_b.exclude_for_hidden_baseline = exclude_ids

    # ---------------- data ----------------
    if args.pairs_file:
        pairs = read_pairs(args.pairs_file)[:args.max_examples]
        print(f"[data] using TSV pairs: {len(pairs)}")
    elif args.text_file and os.path.exists(args.text_file):
        pairs = read_text_as_pairs(args.text_file, args.max_examples)
        print(f"[data] using text file: {len(pairs)} from {args.text_file}")
    else:
        pairs = read_pairs(None)[:args.max_examples]
        print(f"[data] using built-in sample: {len(pairs)}")

    if args.inject_compare_sentences:
        A = [s.strip() for s in args.token_compare_a.split(",") if s.strip()]
        B = [s.strip() for s in args.token_compare_b.split(",") if s.strip()]
        probe_pairs = [(w, w) for w in (A + B)]
        pairs = probe_pairs + pairs
        print(f"[inject] added {len(probe_pairs)} compare probe pairs")

    # ---------------- capture setup ----------------
    if args.arch == "mt5" and args.compare_mode == "ffn":
        capturer = FFNCapture(model, tok, aggregators,
                              capture_encoder=args.capture_encoder,
                              capture_decoder=args.capture_decoder,
                              do_self_partitions=args.self_compare)
        if model_b is not None:
            capturer_b = FFNCapture(model_b, tok, aggregators_b,
                                    capture_encoder=args.capture_encoder,
                                    capture_decoder=args.capture_decoder,
                                    do_self_partitions=args.self_compare)
    else:
        capturer = None

    hidcap = HiddenStateCapture(model,
                                capture_encoder=args.capture_encoder,
                                capture_decoder=args.capture_decoder) if args.compare_mode == "hidden" else None
    if model_b is not None and args.compare_mode == "hidden":
        hidcap_b = HiddenStateCapture(model_b,
                                      capture_encoder=args.capture_encoder,
                                      capture_decoder=args.capture_decoder)

    # ---------------- main pass ----------------
    total_scanned = 0
    last_enc = None; last_labels = None; last_dec_in = None

    for i in range(0, len(pairs), args.batch_size):
        batch = pairs[i:i + args.batch_size]
        src_texts = [s for s, t in batch]
        tgt_texts = [t for s, t in batch]

        enc = tok(
            src_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=args.max_src_len
        ).to(args.device)

        tgt = tok(
            text_target=tgt_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=args.max_tgt_len
        ).to(args.device)

        labels = tgt["input_ids"].clone()
        pad_id = tok.pad_token_id if tok.pad_token_id is not None else -100
        labels[labels == pad_id] = -100

        decoder_input_ids = None
        if hasattr(model, "prepare_decoder_input_ids_from_labels"):
            decoder_input_ids = model.prepare_decoder_input_ids_from_labels(labels=labels)

        last_enc = enc; last_labels = labels; last_dec_in = decoder_input_ids

        if capturer is not None:
            capturer.set_batch_token_ids(
                enc_ids_bt=enc["input_ids"].detach().cpu() if args.capture_encoder else None,
                dec_ids_bt=decoder_input_ids.detach().cpu() if args.capture_decoder and decoder_input_ids is not None else None,
            )
        if capturer_b is not None:
            capturer_b.set_batch_token_ids(
                enc_ids_bt=enc["input_ids"].detach().cpu() if args.capture_encoder else None,
                dec_ids_bt=decoder_input_ids.detach().cpu() if args.capture_decoder and decoder_input_ids is not None else None,
            )

        with torch.no_grad():
            _ = model(
                input_ids=enc["input_ids"],
                attention_mask=enc.get("attention_mask", None),
                decoder_input_ids=decoder_input_ids if hasattr(model, "decoder") else None,
                use_cache=False,
            )
            if model_b is not None:
                _ = model_b(
                    input_ids=enc["input_ids"],
                    attention_mask=enc.get("attention_mask", None),
                    decoder_input_ids=decoder_input_ids if hasattr(model_b, "decoder") else None,
                    use_cache=False,
                )

        if hidcap is not None:
            if args.arch == "rwkv":
                dec_ids_bt = enc["input_ids"].to(args.device)

                if args.debug_tokens:
                    seen_sample = torch.unique(dec_ids_bt).detach().cpu().tolist()[:40]
                    print("[debug] RWKV batch seen ids (sample):", seen_sample)
                    print("[debug] A ids:", setA_ids_for_agg, "B ids:", setB_ids_for_agg)

                for L, H in hidcap.dec.items():
                    aggregators.update_hidden_side_gpu("dec", L, H.to(args.device), dec_ids_bt,
                                                       do_self_partitions=args.self_compare)
                hidcap.dec.clear()

                if hidcap_b is not None:
                    for L, H in hidcap_b.dec.items():
                        aggregators_b.update_hidden_side_gpu("dec", L, H.to(args.device), dec_ids_bt,
                                                             do_self_partitions=args.self_compare)
                    hidcap_b.dec.clear()

            else:
                if args.capture_encoder and hidcap.enc:
                    enc_ids_bt = enc["input_ids"].to(args.device)
                    for L, H in hidcap.enc.items():
                        aggregators.update_hidden_side_gpu("enc", L, H.to(args.device), enc_ids_bt,
                                                           do_self_partitions=args.self_compare)

                if args.capture_decoder and hidcap.dec and decoder_input_ids is not None:
                    dec_ids_bt = decoder_input_ids.to(args.device)
                    for L, H in hidcap.dec.items():
                        aggregators.update_hidden_side_gpu("dec", L, H.to(args.device), dec_ids_bt,
                                                           do_self_partitions=args.self_compare)

                hidcap.enc.clear(); hidcap.dec.clear()

                if hidcap_b is not None:
                    if args.capture_encoder and hidcap_b.enc:
                        enc_ids_bt = enc["input_ids"].to(args.device)
                        for L, H in hidcap_b.enc.items():
                            aggregators_b.update_hidden_side_gpu("enc", L, H.to(args.device), enc_ids_bt,
                                                                 do_self_partitions=args.self_compare)
                    if args.capture_decoder and hidcap_b.dec and decoder_input_ids is not None:
                        dec_ids_bt = decoder_input_ids.to(args.device)
                        for L, H in hidcap_b.dec.items():
                            aggregators_b.update_hidden_side_gpu("dec", L, H.to(args.device), dec_ids_bt,
                                                                 do_self_partitions=args.self_compare)
                    hidcap_b.enc.clear(); hidcap_b.dec.clear()

        total_scanned += sum(len(s) for s in src_texts)

        if getattr(args, "baseline_build", False) and \
                aggregators.baseline_total_tokens >= getattr(args, "baseline_tokens_target", 0):
            print(f"[baseline] token budget reached: {aggregators.baseline_total_tokens}")
            break

    if hidcap is not None:
        hidcap.close()
    if capturer is not None:
        capturer.close()
    if hidcap_b is not None:
        hidcap_b.close()
    if capturer_b is not None:
        capturer_b.close()

    # ---------------- Baseline I/O (FFN path) ----------------
    baseline = None
    if getattr(args, "baseline_build", False) and args.arch == "mt5":
        assert args.baseline_file is not None, "Provide --baseline_file to save baseline"
        aggregators.save_baseline(args.baseline_file, nL_enc, nL_dec)
        print(f"[baseline] saved to {args.baseline_file}")

    if getattr(args, "baseline_use", False) and args.arch == "mt5":
        assert args.baseline_file and os.path.exists(args.baseline_file), "Provide existing --baseline_file to load"
        baseline = TokenAggregators.load_baseline(args.baseline_file)
        print(f"[baseline] loaded from {args.baseline_file}")

    # ---------------- Optional MT5 attention band maps ----------------
    if args.arch == "mt5" and last_enc is not None and last_labels is not None and last_dec_in is not None:
        probe_layers = [9, 10, 11]; probe_kinds = ["self", "cross"]
        names = [_attn_module_path("dec", L, kind) for L in probe_layers for kind in probe_kinds]
        captures = {}
        with torch.no_grad():
            with hook_attn_outputs(model, names, captures):
                _ = model(
                    input_ids=last_enc["input_ids"],
                    attention_mask=last_enc["attention_mask"],
                    decoder_input_ids=last_dec_in,
                    labels=last_labels,
                    use_cache=False,
                    output_attentions=False,
                    output_hidden_states=False,
                )
        single_ids_vis = strings_to_ids_list(tok, args.token_compare_a or "the")
        if len(single_ids_vis) > 0:
            target_tok = single_ids_vis[0]
            for L in probe_layers:
                for kind in probe_kinds:
                    key = _attn_module_path("dec", L, kind)
                    if key not in captures: continue
                    vec = compute_attn_band_for_token(captures[key], last_labels, target_tok)
                    export_attn_band_maps(
                        out_dir=args.out_dir, side="dec", layer_idx=L, kind=kind,
                        vec_raw=vec, tag="single_token_first",
                        enable_local_z=getattr(args,"enable_local_z",False),
                        local_z_window=getattr(args,"local_z_window",31),
                    )

    # ---------------- Export & Numbers (A vs B) ----------------
    print("Exporting maps / numbers ...")

    layer_diffs = []
    total_diff = 0.0

    if args.arch == "mt5" and args.compare_mode == "ffn":
        enc_dir = baseline["enc_dir_down_mean"] if (baseline and "enc_dir_down_mean" in baseline) else None
        enc_row = baseline["enc_rowmass_mean"] if (baseline and "enc_rowmass_mean" in baseline) else None
        dec_dir = baseline["dec_dir_down_mean"] if (baseline and "dec_dir_down_mean" in baseline) else None
        dec_row = baseline["dec_rowmass_mean"] if (baseline and "dec_rowmass_mean" in baseline) else None

        for li, block in enumerate(model.encoder.block if args.capture_encoder else []):
            ffn = block.layer[1].DenseReluDense
            out_dir = os.path.join(args.out_dir, f"enc_L{li:02d}")
            bd, br = (None, None)
            if len(single_ids) > 0:
                bd, br = aggregators.word_baseline("enc", li, single_ids)
            elif enc_dir is not None and enc_row is not None:
                bd, br = enc_dir[li], enc_row[li]
            res = aggregators.export_views(
                out_dir, "enc", li, ffn.wo.weight, bd, br,
                enable_local_z=args.enable_local_z, local_z_window=args.local_z_window,
                pockets_axis=args.local_z_axis_for_pockets,
                token_single=",".join([args.token_compare_a, args.token_compare_b]),
                pockets_mode=args.pockets_mode, weight_bands_by_wo=args.weight_bands_by_wo,
                compA_ids=setA_ids_for_agg, compB_ids=setB_ids_for_agg
            )
            if res and res.get("layer_diff_sum") is not None:
                layer_diffs.append({
                    "side": "enc", "layer": li,
                    "diff_sum": float(res["layer_diff_sum"]),
                    "overlap": float(res["layer_overlap"]) if res["layer_overlap"] is not None else None
                })
                total_diff += float(res["layer_diff_sum"])

        for li, block in enumerate(model.decoder.block if args.capture_decoder else []):
            ffn = block.layer[2].DenseReluDense
            out_dir = os.path.join(args.out_dir, f"dec_L{li:02d}")
            bd, br = (None, None)
            if len(single_ids) > 0:
                bd, br = aggregators.word_baseline("dec", li, single_ids)
            elif dec_dir is not None and dec_row is not None:
                bd, br = dec_dir[li], dec_row[li]
            res = aggregators.export_views(
                out_dir, "dec", li, ffn.wo.weight, bd, br,
                enable_local_z=args.enable_local_z, local_z_window=args.local_z_window,
                pockets_axis=args.local_z_axis_for_pockets,
                token_single=",".join([args.token_compare_a, args.token_compare_b]),
                pockets_mode=args.pockets_mode, weight_bands_by_wo=args.weight_bands_by_wo,
                compA_ids=setA_ids_for_agg, compB_ids=setB_ids_for_agg
            )
            if res and res.get("layer_diff_sum") is not None:
                layer_diffs.append({
                    "side": "dec", "layer": li,
                    "diff_sum": float(res["layer_diff_sum"]),
                    "overlap": float(res["layer_overlap"]) if res["layer_overlap"] is not None else None
                })
                total_diff += float(res["layer_diff_sum"])

    else:
        # Hidden-state numbers (MT5 & RWKV), GPU path with tail-filter + baseline-centering
        dev = torch.device(args.device)

        def _get_map_and_count(aggr, side, L):
            A_sum = getattr(aggr, "hidA_sum_S_t", {}).get((side, L), None)
            B_sum = getattr(aggr, "hidB_sum_S_t", {}).get((side, L), None)
            A_cnt = getattr(aggr, "hidA_cnt_t", {}).get((side, L), 0)
            B_cnt = getattr(aggr, "hidB_cnt_t", {}).get((side, L), 0)
            if A_sum is None and hasattr(aggr, "hidA_sum_S"):
                A_sum = aggr.hidA_sum_S.get((side, L), None); A_cnt = aggr.hidA_cnt.get((side, L), 0)
            if B_sum is None and hasattr(aggr, "hidB_sum_S"):
                B_sum = aggr.hidB_sum_S.get((side, L), None); B_cnt = aggr.hidB_cnt.get((side, L), 0)
            return A_sum, A_cnt, B_sum, B_cnt

        def _to_torch(x):
            """Place *x* on ``dev`` as a tensor; ``None`` passes through unchanged.

            Existing tensors are only relocated, so their dtype is kept.
            Anything else (e.g. a NumPy array) is converted to a float32 tensor.
            """
            if x is None:
                return None
            if not isinstance(x, torch.Tensor):
                return torch.as_tensor(x, device=dev, dtype=torch.float32)
            return x.to(dev)

        def _tail_mask(x, pct: float, mode: str):
            if pct <= 0 or x.numel() == 0:
                return torch.ones_like(x, dtype=torch.bool)
            k = max(1, int(x.numel() * (pct / 100.0)))
            v = x.view(-1)
            if mode == "sym":
                mag = v.abs()
                topk_vals, _ = torch.topk(mag, k, largest=True)
                thr = topk_vals.min()
                return x.abs() >= thr
            else:
                topk_vals, _ = torch.topk(v, k, largest=True)
                botk_vals, _ = torch.topk(v, k, largest=False)
                hi, lo = topk_vals.min(), botk_vals.max()
                return (x >= hi) | (x <= lo)

        for side, nL in (("enc", nL_enc), ("dec", nL_dec)):
            if (side == "enc" and not args.capture_encoder) or (side == "dec" and not args.capture_decoder):
                continue
            for L in range(nL):
                A_sum, A_cnt, B_sum, B_cnt = _get_map_and_count(aggregators, side, L)
                if (A_sum is None) or (B_sum is None) or (A_cnt <= 0) or (B_cnt <= 0):
                    continue
                SA = _to_torch(A_sum) / max(1, A_cnt)
                SB = _to_torch(B_sum) / max(1, B_cnt)

                if hasattr(aggregators, "hidden_baseline") and aggregators.hidden_baseline is not None:
                    hb = aggregators.hidden_baseline
                    if getattr(hb, "mode", None) in ("running", "fixed"):
                        mu = hb.get_mu(side, L, device=dev, D=SA.shape[-1])
                        SA = SA - mu; SB = SB - mu

                maskA = _tail_mask(SA, args.tail_pct, args.tail_mode)
                maskB = _tail_mask(SB, args.tail_pct, args.tail_mode)
                SA_m = torch.where(maskA, SA, torch.zeros(1, device=dev, dtype=SA.dtype))
                SB_m = torch.where(maskB, SB, torch.zeros(1, device=dev, dtype=SB.dtype))
                D = (SA_m - SB_m).abs()
                diff_sum_t = D.sum()
                denom_t = SA_m.abs().sum() + SB_m.abs().sum() + 1e-8
                diff_sum = float(diff_sum_t.detach().cpu().item())
                overlap = float((diff_sum_t / denom_t).detach().cpu().item())
                layer_diffs.append({"side": side, "layer": L, "diff_sum": diff_sum, "overlap": overlap})
                total_diff += diff_sum

                out_dir = os.path.join(args.out_dir, f"{side}_L{L:02d}", "compare_token_diff_hidden")
                ensure_dir(out_dir)
                D_np = D.detach().cpu().numpy().astype(np.float32)
                np.save(os.path.join(out_dir, "diff_map.npy"), D_np)
                render_heatmap(
                    D_np,
                    os.path.join(out_dir, "diff_map.png"),
                    f"{side} L{L} |A-B| (hidden; baseline-centered tail={args.tail_pct:.1f}%/{args.tail_mode})",
                    gamma=1.6
                )

    # Save the running hidden baseline (if used)
    if args.compare_mode == "hidden" and hasattr(aggregators, "hidden_baseline") \
            and aggregators.hidden_baseline is not None and aggregators.hidden_baseline.mode == "running":
        hb = aggregators.hidden_baseline
        out_npz = os.path.join(args.out_dir, "hidden_baseline_learned.npz")
        pack = {}
        for side, nL in (("enc", nL_enc), ("dec", nL_dec)):
            for L in range(nL):
                mu = hb.get_mu(side, L, device=torch.device("cpu"), D=d_model).cpu().numpy()
                pack[f"{side}_{L}"] = mu
        np.savez(out_npz, **pack)
        print(f"[hidden_baseline] exported running baseline → {out_npz}")

    # ---------------- NEW: Intra-model self consistency (A, B) ----------------
    self_summary = {"intra_model": [], "cross_model": []}
    if args.self_compare:
        def _self_pair_metrics_ffn(aggr, side, L, which: str):
            # which in {"A","B"}
            if which == "A":
                S1 = aggr.selfA1_sum_S_t.get((side, L), None); c1 = aggr.selfA1_cnt_t.get((side, L), 0)
                S2 = aggr.selfA2_sum_S_t.get((side, L), None); c2 = aggr.selfA2_cnt_t.get((side, L), 0)
            else:
                S1 = aggr.selfB1_sum_S_t.get((side, L), None); c1 = aggr.selfB1_cnt_t.get((side, L), 0)
                S2 = aggr.selfB2_sum_S_t.get((side, L), None); c2 = aggr.selfB2_cnt_t.get((side, L), 0)
            if S1 is None or S2 is None or c1 == 0 or c2 == 0:
                return None
            A1 = S1 / float(max(1, c1)); A2 = S2 / float(max(1, c2))
            D = (A1 - A2).abs()
            denom = (A1.abs().sum() + A2.abs().sum()).clamp_min(1e-8)
            return float((D.sum() / denom).item())

        def _self_pair_metrics_hidden(aggr, side, L, which: str):
            if which == "A":
                S1 = aggr.hidA1_sum_S_t.get((side, L), None); c1 = aggr.hidA1_cnt_t.get((side, L), 0)
                S2 = aggr.hidA2_sum_S_t.get((side, L), None); c2 = aggr.hidA2_cnt_t.get((side, L), 0)
            else:
                S1 = aggr.hidB1_sum_S_t.get((side, L), None); c1 = aggr.hidB1_cnt_t.get((side, L), 0)
                S2 = aggr.hidB2_sum_S_t.get((side, L), None); c2 = aggr.hidB2_cnt_t.get((side, L), 0)
            if S1 is None or S2 is None or c1 == 0 or c2 == 0:
                return None
            A1 = (S1 / float(max(1, c1)))
            A2 = (S2 / float(max(1, c2)))
            D = (A1 - A2).abs()
            denom = (A1.abs().sum() + A2.abs().sum()).clamp_min(1e-8)
            return float((D.sum() / denom).item())

        for side, nL in (("enc", nL_enc), ("dec", nL_dec)):
            if (side == "enc" and not args.capture_encoder) or (side == "dec" and not args.capture_decoder):
                continue
            for L in range(nL):
                if args.compare_mode == "ffn":
                    ovA = _self_pair_metrics_ffn(aggregators, side, L, "A")
                    ovB = _self_pair_metrics_ffn(aggregators, side, L, "B")
                else:
                    ovA = _self_pair_metrics_hidden(aggregators, side, L, "A")
                    ovB = _self_pair_metrics_hidden(aggregators, side, L, "B")
                entry = {"side": side, "layer": L, "self_overlap_A": ovA, "self_overlap_B": ovB}
                self_summary["intra_model"].append(entry)

    # ---------------- NEW: Cross-model same-token overlap (A and B) ----------------
    if aggregators_b is not None:
        def _avg_map(aggr, side, L, which: str):
            """Mean map of token set *which* for ``(side, L)`` in one aggregator.

            In ``"ffn"`` compare mode it reads ``compX_sum_S_t`` /
            ``compX_cnt_t``. In any other mode it reads ``hidX_sum_S_t`` /
            ``hidX_cnt_t``. X is ``"A"`` when *which* is ``"A"`` and ``"B"``
            otherwise.

            Returns ``sum / count``, or ``None`` when the sum is missing or the
            count is 0.
            """
            family = "comp" if args.compare_mode == "ffn" else "hid"
            tag = "A" if which == "A" else "B"
            key = (side, L)
            total = getattr(aggr, f"{family}{tag}_sum_S_t").get(key, None)
            count = getattr(aggr, f"{family}{tag}_cnt_t").get(key, 0)
            if total is None or count == 0:
                return None
            return total / float(count)

        for side, nL in (("enc", nL_enc), ("dec", nL_dec)):
            if (side == "enc" and not args.capture_encoder) or (side == "dec" and not args.capture_decoder):
                continue
            for L in range(nL):
                for which in ("A","B"):
                    SA = _avg_map(aggregators, side, L, which)
                    SB = _avg_map(aggregators_b, side, L, which)
                    if SA is None or SB is None:
                        self_summary["cross_model"].append({"side": side, "layer": L, "set": which, "overlap": None})
                        continue
                    D = (SA - SB).abs()
                    denom = (SA.abs().sum() + SB.abs().sum()).clamp_min(1e-8)
                    ov = float((D.sum() / denom).item())
                    self_summary["cross_model"].append({"side": side, "layer": L, "set": which, "overlap": ov})

    # Summary (A/B)
    overlap_values = [x["overlap"] for x in layer_diffs if x.get("overlap") is not None]
    overall_overlap = float(np.mean(overlap_values)) if overlap_values else None

    if layer_diffs:
        report = {"layers": layer_diffs, "total_diff_sum": float(total_diff), "overall_overlap_mean": overall_overlap}
        with open(os.path.join(args.out_dir, "token_compare_summary.json"), "w") as f:
            json.dump(report, f, indent=2)
        print("[compare] layer-wise:")
        for d in layer_diffs:
            print(f"  {d['side']} L{d['layer']:02d}: diff_sum={d['diff_sum']:.3f}, overlap={d['overlap']:.3f}")
        print(f"[compare] TOTAL diff_sum: {total_diff:.3f}")
        if overall_overlap is not None:
            print(f"[compare] OVERALL overlap (mean of layers): {overall_overlap:.3f}")
    else:
        print("[compare] No A/B counts found; check tokens, capture side, or use --inject_compare_sentences.")

    # --- already computing intra-model means above ---
    valsA = [e["self_overlap_A"] for e in self_summary["intra_model"] if e["self_overlap_A"] is not None]
    valsB = [e["self_overlap_B"] for e in self_summary["intra_model"] if e["self_overlap_B"] is not None]
    meanA = float(np.mean(valsA)) if valsA else None
    meanB = float(np.mean(valsB)) if valsB else None
    self_summary["mean_self_overlap_A"] = meanA
    self_summary["mean_self_overlap_B"] = meanB

    # --- NEW: cross-model means per set (A/B) ---
    crossA = [e["overlap"] for e in self_summary["cross_model"] if e.get("set") == "A" and e.get("overlap") is not None]
    crossB = [e["overlap"] for e in self_summary["cross_model"] if e.get("set") == "B" and e.get("overlap") is not None]
    mean_cross_A = float(np.mean(crossA)) if crossA else None
    mean_cross_B = float(np.mean(crossB)) if crossB else None
    self_summary["mean_cross_model_overlap_A"] = mean_cross_A
    self_summary["mean_cross_model_overlap_B"] = mean_cross_B


    # NEW: dump self compare summaries if requested / available
    if args.self_compare or aggregators_b is not None:
        with open(os.path.join(args.out_dir, "self_compare_summary.json"), "w") as f:
            json.dump(self_summary, f, indent=2)
        # brief console print
        if args.self_compare and self_summary["intra_model"]:
            print("[self] Intra-model self-overlap (lower = more similar):")
            for e in self_summary["intra_model"]:
                sA = "None" if e["self_overlap_A"] is None else f"{e['self_overlap_A']:.3f}"
                sB = "None" if e["self_overlap_B"] is None else f"{e['self_overlap_B']:.3f}"
                print(f"  {e['side']} L{e['layer']:02d}: A={sA}  B={sB}")
        if aggregators_b is not None and self_summary["cross_model"]:
            print("[self] Cross-model same-token overlap (lower = more similar):")
            for e in self_summary["cross_model"]:
                s = "None" if e["overlap"] is None else f"{e['overlap']:.3f}"
                print(f"  {e['side']} L{e['layer']:02d} set {e['set']}: {s}")

                # NEW: print cross-model means
                if mean_cross_A is not None:
                    print(f"[self] MEAN cross-model overlap (A): {mean_cross_A:.3f}")
                if mean_cross_B is not None:
                    print(f"[self] MEAN cross-model overlap (B): {mean_cross_B:.3f}")

        # Compute simple means over layers (ignore None)
        valsA = [e["self_overlap_A"] for e in self_summary["intra_model"] if e["self_overlap_A"] is not None]
        valsB = [e["self_overlap_B"] for e in self_summary["intra_model"] if e["self_overlap_B"] is not None]

        meanA = float(np.mean(valsA)) if valsA else None
        meanB = float(np.mean(valsB)) if valsB else None

        if meanA is not None:
            print(f"[self] MEAN self-overlap (A): {meanA:.3f}")
        if meanB is not None:
            print(f"[self] MEAN self-overlap (B): {meanB:.3f}")

        # also drop into the JSON for convenience
        self_summary["mean_self_overlap_A"] = meanA
        self_summary["mean_self_overlap_B"] = meanB

    print(f"Done. Outputs in {args.out_dir}")

if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()
