#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse, json, os, re
from collections import defaultdict
from typing import List, Tuple, Optional, Dict

import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib.colors import TwoSlopeNorm, LinearSegmentedColormap
from contextlib import contextmanager
from transformers import RwkvConfig, RwkvForCausalLM, AutoTokenizer

from project_classes import ForgetfulT5, CustomDenseReluDense

# ---------------------------- Viz: cmap + guards ---------------------------- #

# Diverging colormap used by every heatmap/band render: negative -> blue, zero -> black,
# positive -> green, interpolated over 256 levels.
_CMAP_BBG = LinearSegmentedColormap.from_list(
    name="bbg",
    colors=["#2b6cb0", "#000000", "#2ca02c"],
    N=256,
)

def _signed_power(x: np.ndarray, gamma: float) -> np.ndarray:
    if gamma == 1.0: return x
    return np.sign(x) * (np.abs(x) ** gamma)

def _bounds_for_diverging(mat: np.ndarray) -> tuple[float, float]:
    m = float(np.nanmin(mat))
    M = float(np.nanmax(mat))
    # ensure vmin < 0 < vmax
    if not (m < 0.0 < M):
        a = max(abs(m), abs(M))
        if a == 0.0:
            a = 1e-6
        return -a, a
    return m, M

def render_heatmap(mat: np.ndarray, out_png: str, title: str,
                   vclip: Optional[float]=99.5, gamma: float=1.8):
    """Render a 2-D array as a zero-centered diverging heatmap and save it as a PNG.

    Non-finite entries are replaced by 0. If ``vclip`` is given, values are clipped
    symmetrically at that percentile of ``|mat|``. A signed power ``sign(x)*|x|**gamma`` is
    then applied and the result is drawn with ``_CMAP_BBG`` and a TwoSlopeNorm centered at 0.

    Args:
        mat: 2-D array to draw.
        out_png: output image path.
        title: plot title.
        vclip: percentile (0-100) for symmetric clipping, or None to disable clipping.
        gamma: exponent of the signed power transform.
    """
    m = np.nan_to_num(mat.astype(np.float32, copy=False), nan=0.0, posinf=0.0, neginf=0.0)
    if vclip is not None:
        clip_at = np.percentile(np.abs(m), vclip)
        m = np.clip(m, -clip_at, clip_at)
    m = _signed_power(m, gamma)
    vmin, vmax = _bounds_for_diverging(m)

    # Prepare the data before opening a figure, and always close the figure so a failure
    # while drawing or saving does not leak pyplot figures during long batch exports.
    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        im = ax.imshow(m, aspect='auto', interpolation='nearest', cmap=_CMAP_BBG,
                       norm=TwoSlopeNorm(vcenter=0.0, vmin=vmin, vmax=vmax))
        ax.set_title(title)
        ax.set_xlabel("model dims / columns")
        ax.set_ylabel("ff rows / rows")
        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label(f"power-normalized units (γ={gamma:.1f})")
        fig.tight_layout()
        fig.savefig(out_png, dpi=220)
    finally:
        plt.close(fig)

def render_band_row(vec: np.ndarray, out_png: str, title: str,
                    repeat: int=6, vclip: Optional[float]=99.5, gamma: float=1.8):
    """Render a 1-D vector as a thin diverging "band" image and save it as a PNG.

    Preprocessing matches ``render_heatmap``: non-finite -> 0, optional symmetric percentile
    clip, then signed power ``sign(x)*|x|**gamma``. The vector is tiled ``repeat`` times
    vertically so it is visible as a strip.

    Args:
        vec: 1-D array to draw.
        out_png: output image path.
        title: plot title.
        repeat: number of identical rows in the band.
        vclip: percentile (0-100) for symmetric clipping, or None to disable clipping.
        gamma: exponent of the signed power transform.
    """
    v = np.nan_to_num(vec.astype(np.float32, copy=False), nan=0.0, posinf=0.0, neginf=0.0)
    if vclip is not None:
        clip_at = np.percentile(np.abs(v), vclip)
        v = np.clip(v, -clip_at, clip_at)
    v = _signed_power(v, gamma)
    mat = np.tile(v[None, :], (repeat, 1))
    vmin, vmax = _bounds_for_diverging(mat)

    # Open the figure only after the data is ready and always close it (see render_heatmap).
    fig, ax = plt.subplots(figsize=(10, 2))
    try:
        im = ax.imshow(mat, aspect='auto', interpolation='nearest', cmap=_CMAP_BBG,
                       norm=TwoSlopeNorm(vcenter=0.0, vmin=vmin, vmax=vmax))
        ax.set_title(title)
        ax.set_yticks([])
        ax.set_xlabel("dimension")
        cbar = fig.colorbar(im, ax=ax)
        cbar.set_label(f"power-normalized units (γ={gamma:.1f})")
        fig.tight_layout()
        fig.savefig(out_png, dpi=220)
    finally:
        plt.close(fig)

# ---------------------------- Generic utils ---------------------------- #

def gelu(x):
    """Exact (erf-based) GELU via ``torch.nn.functional.gelu``."""
    return F.gelu(x)
def silu(x):
    """SiLU / swish activation via ``torch.nn.functional.silu``."""
    return F.silu(x)
def act_fn_for_proj(feed_forward_proj: str):
    """Map a T5/MT5 ``feed_forward_proj`` string to the FFN activation callable.

    The string has the form ``"{act}"`` or ``"gated-{act}"`` (case-insensitive here).
    Hugging Face T5/MT5 configs run ``"gated-gelu"`` with the tanh-approximated GELU
    ("gelu_new"), so that case and explicit tanh-GELU names use ``F.gelu(approximate="tanh")``.
    Any other name containing "gelu" uses exact GELU; "silu"/"swish" -> SiLU; "relu" -> ReLU.
    Unknown names fall back to exact GELU (the previous default).
    """
    proj = feed_forward_proj.lower()
    act_name = proj.split("-")[-1]

    if proj == "gated-gelu" or act_name in ("gelu_new", "gelu_fast", "gelu_pytorch_tanh"):
        def gelu_tanh(x):
            return F.gelu(x, approximate="tanh")
        return gelu_tanh
    if "gelu" in proj:
        return F.gelu
    if "silu" in proj or "swish" in proj:
        return F.silu
    if "relu" in proj:
        return F.relu
    return F.gelu

def ensure_dir(path):
    """Create ``path`` (with any missing parents); an existing directory is not an error."""
    os.makedirs(path, exist_ok=True)
def to_cpu_f32(x):
    """Return a detached float32 copy of tensor ``x`` on the CPU."""
    detached = x.detach()
    return detached.float().cpu()

def power_norm_1d(v: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Scale ``v`` so its mean absolute value is ~1 (divides by ``mean|v| + eps``); float32 result."""
    mean_abs = np.mean(np.abs(v))
    normalized = v / (mean_abs + eps)
    return normalized.astype(np.float32)

def power_norm_2d(m: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Scale matrix ``m`` so its mean absolute entry is ~1 (divides by ``mean|m| + eps``); float32 result."""
    mean_abs = np.mean(np.abs(m))
    normalized = m / (mean_abs + eps)
    return normalized.astype(np.float32)

def tokenize_single_piece(tok, s: str) -> List[int]:
    """Encode ``s`` with ``tok.encode`` without adding special tokens and return the ids."""
    encoded = tok.encode(s, add_special_tokens=False)
    return encoded

def strings_to_ids_list(tok, s: str) -> List[int]:
    """Map a comma-separated string of pieces to one token id per piece.

    Blank pieces are ignored. A piece that encodes to nothing is skipped with a warning; a
    piece that encodes to several ids contributes only its first id (also with a warning).
    A non-string ``s`` yields an empty list.
    """
    if not isinstance(s, str):
        return []
    pieces = [raw.strip() for raw in s.split(",") if raw.strip()]
    result: List[int] = []
    for piece in pieces:
        encoded = tokenize_single_piece(tok, piece)
        if not len(encoded):
            print(f"[warn] '{piece}' tokenized to nothing; skipping.")
            continue
        if len(encoded) > 1:
            print(f"[warn] '{piece}' tokenized to {encoded}; using first id {encoded[0]}.")
        result.append(encoded[0])
    return result

def strings_to_ids_multi(tok, s_or_list) -> List[int]:
    """Map comma-separated pieces (a string, or an iterable of strings joined with ",") to ids.

    Each non-blank piece contributes the first id it encodes to; pieces that encode to
    nothing are dropped silently.
    """
    joined = s_or_list if isinstance(s_or_list, str) else ",".join(s_or_list)
    result = []
    for raw in joined.split(","):
        piece = raw.strip()
        if not piece:
            continue
        encoded = tokenize_single_piece(tok, piece)
        if len(encoded) > 0:
            result.append(encoded[0])
    return result

def read_pairs(path: Optional[str]) -> List[Tuple[str,str]]:
    """Read (source, target) pairs from a UTF-8 TSV file.

    Only lines containing a tab are used; each is split on its first tab after removing the
    trailing newline. When ``path`` is None or does not exist, a small built-in
    English/German sample list is returned instead.
    """
    if path is not None and os.path.exists(path):
        with open(path, "r", encoding="utf-8") as fh:
            return [tuple(line.rstrip("\n").split("\t", 1)) for line in fh if "\t" in line]
    return [
        ("translate English to German: A small test.", "Ein kleiner Test."),
        ("Call me at 12:30 tomorrow.", "Ruf mich morgen um 12:30 an."),
        ("He said, \"Hello!\"", "Er sagte: „Hallo!“"),
        ("The price is $19.99.", "Der Preis beträgt 19,99 $."),
        ("Version 2.0 was released in 2024.", "Version 2.0 wurde 2024 veröffentlicht."),
    ]

def read_text_as_pairs(path: str, max_examples: int):
    """Read up to ``max_examples`` non-blank lines of a UTF-8 text file as (line, line) pairs.

    Lines are stripped of surrounding whitespace; blank lines are skipped. A
    ``max_examples`` of zero or less returns an empty list without opening the file
    (previously one pair was still returned because the limit was checked after appending).
    """
    pairs = []
    if max_examples <= 0:
        return pairs
    with open(path, "r", encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            pairs.append((line, line))
            if len(pairs) >= max_examples:
                break
    return pairs

def _slugify_token_label(s: str) -> str:
    return (s.replace(" ", "_").replace("/", "_").replace("\\", "_")
              .replace("\t", "_").replace("\n", "_").replace("<", "")
              .replace(">", "").replace(":", "").replace("|", "")
              .replace("*", "").replace("?", "").replace('"', ""))[:64] or "token"

def compute_pockets(A_vec: np.ndarray, c_vec: np.ndarray, WoT: np.ndarray | None,
                    mode: str = "wo_times_eff") -> np.ndarray:
    if mode == "outer_eff_down":
        return (A_vec[:, None] * c_vec[None, :]).astype(np.float32)
    elif mode == "wo_times_eff":
        if WoT is None: raise ValueError("compute_pockets(mode='wo_times_eff') requires WoT")
        return (WoT * A_vec[:, None]).astype(np.float32)
    else:
        raise ValueError(f"Unknown pockets mode: {mode}")

def wo_locality_weights(Wo: torch.Tensor) -> tuple[np.ndarray, np.ndarray]:
    """Return (column L2 norms, row L2 norms) of ``Wo`` as float32 arrays.

    With ``Wo`` laid out as [d_model, d_ff] (as the callers' comments state), the first array
    has one entry per FFN unit and the second one per model dimension.
    """
    weight = Wo.detach().float().cpu().numpy()
    col_norms = np.linalg.norm(weight, axis=0).astype(np.float32)
    row_norms = np.linalg.norm(weight, axis=1).astype(np.float32)
    return col_norms, row_norms

# ---------------------------- Local smoothing helpers ---------------------------- #

def _moving_stats_1d(x: np.ndarray, window: int) -> tuple[np.ndarray, np.ndarray]:
    w = int(max(1, window)); pad = w // 2
    k = np.ones(w, dtype=np.float32)
    xp = np.pad(x.astype(np.float32), (pad, pad), mode="reflect")
    mx = np.convolve(xp, k, mode="valid"); mx2 = np.convolve(xp * xp, k, mode="valid")
    mean = mx / w; var = np.maximum(mx2 / w - mean * mean, 0.0); std = np.sqrt(var + 1e-8)
    return mean, std

def local_zscore_1d(x: np.ndarray, window: int, robust: bool=True) -> np.ndarray:
    """Standardize each sample of a 1-D signal against its local neighbourhood.

    A centered window of ``window`` samples (forced odd and >= 3; edges reflect-padded)
    provides a center and a scale for every position i, and the output is
    ``(x[i] - center[i]) / (scale[i] + 1e-8)``:

      robust=True  -> center = window median, scale = 1.4826 * MAD; positions whose MAD is 0
                      (e.g. sparse windows) fall back to the window std
      robust=False -> center = window mean,   scale = window std

    The previous body returned a plain moving average and ignored ``robust``.

    Returns:
        float32 array with the same length as ``x`` (empty input -> empty output).
    """
    w = int(max(3, window))
    if w % 2 == 0:
        w += 1
    x32 = np.asarray(x, dtype=np.float32)
    if x32.size == 0:
        return x32.copy()
    pad = w // 2
    xp = np.pad(x32, (pad, pad), mode="reflect")
    # One row per output position: row i is the window centered on x[i] (a strided view).
    windows = np.lib.stride_tricks.sliding_window_view(xp, w)
    std = windows.std(axis=1)
    if robust:
        center = np.median(windows, axis=1)
        mad = np.median(np.abs(windows - center[:, None]), axis=1)
        scale = np.where(mad > 0, 1.4826 * mad, std)
    else:
        center = windows.mean(axis=1)
        scale = std
    return ((x32 - center) / (scale + 1e-8)).astype(np.float32)

# ---------------------------- Attention capture (MT5 only) ---------------------------- #

def _attn_module_path(side: str, layer_idx: int, kind: str) -> str:
    if side == "enc":
        assert kind == "self"; return f"encoder.block.{layer_idx}.layer.0.SelfAttention"
    else:
        if kind == "self": return f"decoder.block.{layer_idx}.layer.0.SelfAttention"
        elif kind == "cross": return f"decoder.block.{layer_idx}.layer.1.EncDecAttention"
        else: raise ValueError("kind must be 'self' or 'cross'")

@contextmanager
def hook_attn_outputs(model, names, store_dict):
    """Temporarily capture the outputs of the named submodules of ``model``.

    For each name in ``names`` (a key of ``model.named_modules()``) a forward hook is
    registered that writes a detached CPU float32 tensor into ``store_dict[name]``. For
    tuple/list outputs it uses the first element, and for dict outputs the first of
    "last_hidden_state", "hidden_states" or "output". Any other output raises TypeError.
    All hooks registered so far are removed on exit, including when an unknown name raises
    KeyError.
    """
    modules = dict(model.named_modules())

    def _as_cpu(obj):
        if isinstance(obj, torch.Tensor):
            return obj.detach().float().cpu()
        if isinstance(obj, (tuple, list)) and len(obj) > 0:
            return _as_cpu(obj[0])
        if isinstance(obj, dict):
            for field in ("last_hidden_state", "hidden_states", "output"):
                if field in obj:
                    return _as_cpu(obj[field])
        raise TypeError(f"Unexpected attention hook output type: {type(obj)}")

    def _make_hook(module_name):
        def _hook(module, args, output):
            store_dict[module_name] = _as_cpu(output)
        return _hook

    registered = []
    try:
        for module_name in names:
            target = modules[module_name]
            registered.append(target.register_forward_hook(_make_hook(module_name)))
        yield
    finally:
        for handle in registered:
            handle.remove()

def compute_attn_band_for_token(attn_tensor: torch.Tensor, labels: torch.Tensor, target_id: int) -> np.ndarray:
    """Average the [B, T, D] ``attn_tensor`` over every position whose label equals ``target_id``.

    ``labels`` must hold B*T entries (it is moved to ``attn_tensor``'s device). Returns a
    length-D NumPy vector, or zeros of length D when ``target_id`` never occurs.
    ``reshape`` is used instead of ``view`` so non-contiguous inputs (e.g. transposed
    tensors) work, and bfloat16 results are upcast because NumPy has no bfloat16.
    """
    dev = attn_tensor.device
    B, T, D = attn_tensor.shape
    flat = attn_tensor.detach().reshape(B * T, D)
    lab_flat = labels.detach().to(dev).reshape(B * T)
    mask = lab_flat == int(target_id)
    if mask.any():
        v = flat[mask].mean(dim=0)
    else:
        v = torch.zeros(D, dtype=attn_tensor.dtype, device=dev)
    if v.dtype == torch.bfloat16:
        v = v.float()
    return v.cpu().numpy()

def export_attn_band_maps(out_dir: str, side: str, layer_idx: int, kind: str, vec_raw: np.ndarray, tag: str,
                          enable_local_z: bool=False, local_z_window: int=31):
    """Save an attention-output vector and its band images under ``out_dir/{side}_L{layer_idx}_{kind}_{tag}/``.

    Always writes ``attn_out.npy`` (raw vector) and ``attn_out_band.png`` (power-normalized
    band). When ``enable_local_z`` is true it also writes ``attn_out_band_localz.png``, built
    from ``local_zscore_1d(vec_raw, local_z_window, robust=True)``.
    """
    ensure_dir(out_dir)
    target_dir = os.path.join(out_dir, f"{side}_L{layer_idx}_{kind}_{tag}")
    ensure_dir(target_dir)
    np.save(os.path.join(target_dir, "attn_out.npy"), vec_raw)

    normalized = power_norm_1d(vec_raw)
    render_band_row(normalized, os.path.join(target_dir, "attn_out_band.png"),
                    f"{side} L{layer_idx} {kind} {tag}: band (attn out, power-normed)", gamma=1.8)
    if not enable_local_z:
        return

    local_z = local_zscore_1d(vec_raw, local_z_window, robust=True)
    render_band_row(local_z, os.path.join(target_dir, "attn_out_band_localz.png"),
                    f"{side} L{layer_idx} {kind} {tag}: band (attn out) local-z", gamma=1.8)

# ---------------------------- FFN Aggregators (MT5) ---------------------------- #

class TokenAggregators:
    """
    Aggregates FFN streams (MT5) and hidden-state maps (MT5 & RWKV) and supports A/B comparison.
    Keys are (side, layer) where side in {"enc","dec"}.
    """

    def __init__(self, n_layers_enc: int, n_layers_dec: int, d_ff: int, d_model: int,
                 target_single_ids: List[int], set_a_ids: List[int], set_b_ids: List[int],
                 special_ids: List[int], want_baseline: bool, baseline_exclude_ids: set):
        self.n_layers_enc = n_layers_enc; self.n_layers_dec = n_layers_dec
        self.d_ff = d_ff; self.d_model = d_model
        self.target_single_ids = set(target_single_ids)
        self.set_a_ids = set(set_a_ids); self.set_b_ids = set(set_b_ids)
        self.special_ids = set(special_ids)
        self.want_baseline = want_baseline; self.baseline_exclude_ids = set(baseline_exclude_ids)

        # Raw sums (FFN path)
        self.first_single_eff = defaultdict(lambda: None)  # np[d_ff]
        self.first_single_down = defaultdict(lambda: None) # np[d_model]
        self.sum_single_eff = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.sum_single_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.cnt_single = defaultdict(int)
        self.sum_generic_eff = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.sum_generic_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.cnt_generic = defaultdict(int)
        self.sum_setA_eff = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.sum_setA_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.cnt_setA = defaultdict(int)
        self.sum_setB_eff = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.sum_setB_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.cnt_setB = defaultdict(int)

        # Baseline / target-single (ĉ & rowmass) for FFN path
        self.base_sum_dir_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.base_cnt_dir_down = defaultdict(int)
        self.base_sum_rowmass = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.base_cnt_rowmass = defaultdict(int)
        self.baseline_total_tokens = 0
        self.single_sum_dir_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.single_cnt_dir_down = defaultdict(int)
        self.single_sum_rowmass = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.single_cnt_rowmass = defaultdict(int)

        # FFN token-vs-token comparison accumulators (per-occurrence normalized maps)
        self.compA_sum_S = defaultdict(lambda: None)  # np[d_ff, d_model]
        self.compA_cnt   = defaultdict(int)
        self.compB_sum_S = defaultdict(lambda: None)
        self.compB_cnt   = defaultdict(int)

        # Hidden-state comparison accumulators (per-occurrence maps)
        self.hidA_sum_S = defaultdict(lambda: None)
        self.hidA_cnt   = defaultdict(int)
        self.hidB_sum_S = defaultdict(lambda: None)
        self.hidB_cnt   = defaultdict(int)

        self._cached_wo_norms: Dict[Tuple[str,int], Tuple[np.ndarray,np.ndarray]] = {}

        # Device and CUDA-side id tensors for fast masking (used in FFN CUDA path now; OK to keep)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.set_a_ids_t = torch.as_tensor(sorted(self.set_a_ids), device=self.device, dtype=torch.long)
        self.set_b_ids_t = torch.as_tensor(sorted(self.set_b_ids), device=self.device, dtype=torch.long)

        # running sums / counts ON DEVICE for fast accumulation (FFN CUDA path)
        self.compA_sum_S_t = defaultdict(lambda: None)  # torch.Tensor [F,M] on device
        self.compB_sum_S_t = defaultdict(lambda: None)
        self.compA_cnt_t = defaultdict(int)
        self.compB_cnt_t = defaultdict(int)

    @torch.no_grad()
    def update_ffn_batch_cuda(self, side, layer, A_btf, C_btm, token_ids_bt):
        # A: [B,T,F], C: [B,T,M], ids: [B,T]  (all on same CUDA device)
        key = (side, layer)
        B, T, F = A_btf.shape
        M = C_btm.shape[-1]
        ids = token_ids_bt.to(A_btf.device)

        # power-normalize per occurrence (vectorized)
        a = A_btf / (A_btf.abs().mean(dim=-1, keepdim=True) + 1e-8)  # [B,T,F]
        c = C_btm / (C_btm.abs().mean(dim=-1, keepdim=True) + 1e-8)  # [B,T,M]

        # build masks for A/B (vectorized isin)
        maskA = torch.isin(ids, self.set_a_ids_t)
        maskB = torch.isin(ids, self.set_b_ids_t)

        def _acc(mask, tgt_sum, tgt_cnt):
            if mask.any():
                A_occ = a[mask]  # [K,F]
                C_occ = c[mask]  # [K,M]
                S_sum = A_occ.t() @ C_occ  # [F,M]
                if tgt_sum[key] is None:
                    tgt_sum[key] = S_sum
                else:
                    tgt_sum[key] = tgt_sum[key] + S_sum
                tgt_cnt[key] += A_occ.shape[0]

        _acc(maskA, self.compA_sum_S_t, self.compA_cnt_t)
        _acc(maskB, self.compB_sum_S_t, self.compB_cnt_t)

    # ---------- FFN updates (MT5 only) ----------
    def _update_positions(self, side: str, layer: int, A_btF: torch.Tensor, c_btD: torch.Tensor,
                          token_ids_bt: torch.Tensor, count_for_baseline: bool):
        A = A_btF.numpy(); C = c_btD.numpy(); TIDs = token_ids_bt.numpy()
        B, T, _ = A.shape; key = (side, layer)

        def _pn1d(v: np.ndarray, eps: float=1e-8) -> np.ndarray:
            s = np.mean(np.abs(v)) + eps; return (v / s).astype(np.float32)

        for b in range(B):
            for t in range(T):
                tid = int(TIDs[b, t])
                if tid in self.special_ids: continue
                a = A[b, t]; c = C[b, t]

                # Raw groupings
                if tid in self.target_single_ids:
                    if self.first_single_eff[key] is None:  self.first_single_eff[key]  = a.copy()
                    if self.first_single_down[key] is None: self.first_single_down[key] = c.copy()
                    self.sum_single_eff[key]  += a; self.sum_single_down[key] += c; self.cnt_single[key] += 1
                elif tid in self.set_a_ids:
                    self.sum_setA_eff[key]  += a; self.sum_setA_down[key] += c; self.cnt_setA[key] += 1
                elif tid in self.set_b_ids:
                    self.sum_setB_eff[key]  += a; self.sum_setB_down[key] += c; self.cnt_setB[key] += 1
                else:
                    self.sum_generic_eff[key]  += a; self.sum_generic_down[key] += c; self.cnt_generic[key] += 1

                # Baseline-style normalized views
                cnorm = np.linalg.norm(c) + 1e-8; c_hat = c / cnorm
                denom = np.abs(a).sum() + 1e-8; p = np.abs(a) / denom
                if tid in self.target_single_ids:
                    self.single_sum_dir_down[key] += c_hat; self.single_cnt_dir_down[key] += 1
                    self.single_sum_rowmass[key] += p;      self.single_cnt_rowmass[key] += 1
                if count_for_baseline and (tid not in self.baseline_exclude_ids):
                    self.base_sum_dir_down[key] += c_hat; self.base_cnt_dir_down[key] += 1
                    self.base_sum_rowmass[key] += p;      self.base_cnt_rowmass[key] += 1
                    self.baseline_total_tokens += 1

                # Per-occurrence normalized token map (outer product) for FFN compare
                if tid in self.set_a_ids or tid in self.set_b_ids:
                    a_n = _pn1d(a); c_n = _pn1d(c)
                    S_occ = (a_n[:, None] * c_n[None, :]).astype(np.float32)
                    if tid in self.set_a_ids:
                        if self.compA_sum_S[key] is None: self.compA_sum_S[key] = np.zeros_like(S_occ, dtype=np.float32)
                        self.compA_sum_S[key] += S_occ; self.compA_cnt[key] += 1
                    else:
                        if self.compB_sum_S[key] is None: self.compB_sum_S[key] = np.zeros_like(S_occ, dtype=np.float32)
                        self.compB_sum_S[key] += S_occ; self.compB_cnt[key] += 1

    # ---------- Hidden-state compare (MT5 & RWKV) ----------
    def update_hidden_side(self, side: str, layer: int, H_btD: torch.Tensor, token_ids_bt: torch.Tensor,
                           set_a_ids: set, set_b_ids: set):
        key = (side, layer)
        H = H_btD.numpy(); TIDs = token_ids_bt.numpy()
        B, T, D = H.shape

        def _pn1d(v: np.ndarray, eps: float=1e-8) -> np.ndarray:
            return (v / (np.mean(np.abs(v)) + eps)).astype(np.float32)

        for b in range(B):
            for t in range(T):
                tid = int(TIDs[b, t])
                if tid in self.special_ids:
                    continue
                h = H[b, t]
                a = _pn1d(np.abs(h))
                c = _pn1d(h)
                S_occ = (a[:, None] * c[None, :]).astype(np.float32)
                if tid in set_a_ids:
                    if self.hidA_sum_S[key] is None: self.hidA_sum_S[key] = np.zeros_like(S_occ, dtype=np.float32)
                    self.hidA_sum_S[key] += S_occ; self.hidA_cnt[key] += 1
                elif tid in set_b_ids:
                    if self.hidB_sum_S[key] is None: self.hidB_sum_S[key] = np.zeros_like(S_occ, dtype=np.float32)
                    self.hidB_sum_S[key] += S_occ; self.hidB_cnt[key] += 1

    # ---------- Baseline I/O for FFN ----------
    def baseline_means(self, side: str, n_layers: int) -> Tuple[np.ndarray, np.ndarray]:
        mean_dir, mean_row = [], []
        for L in range(n_layers):
            k = (side, L)
            mean_dir.append(self.base_sum_dir_down[k] / max(1, self.base_cnt_dir_down[k]) if self.base_cnt_dir_down[k] > 0 else np.zeros(self.d_model, np.float32))
            mean_row.append(self.base_sum_rowmass[k] / max(1, self.base_cnt_rowmass[k]) if self.base_cnt_rowmass[k] > 0 else np.zeros(self.d_ff, np.float32))
        return np.stack(mean_dir, 0), np.stack(mean_row, 0)

    def save_baseline(self, path: str, nL_enc: int, nL_dec: int):
        enc_dir, enc_row = self.baseline_means("enc", nL_enc)
        dec_dir, dec_row = self.baseline_means("dec", nL_dec)
        np.savez(path, enc_dir_down_mean=enc_dir, enc_rowmass_mean=enc_row, dec_dir_down_mean=dec_dir, dec_rowmass_mean=dec_row)

    @staticmethod
    def load_baseline(path: str) -> Dict[str, np.ndarray]:
        arrs = np.load(path); return {k: arrs[k] for k in arrs.files}

    def word_baseline(self, side: str, layer: int, ids: List[int]) -> tuple[np.ndarray, np.ndarray]:
        key = (side, layer)
        if self.base_cnt_dir_down[key] > 0 and self.base_cnt_rowmass[key] > 0:
            mu_dir = self.base_sum_dir_down[key] / self.base_cnt_dir_down[key]
            mu_row = self.base_sum_rowmass[key] / self.base_cnt_rowmass[key]
            return mu_dir, mu_row
        return np.zeros(self.d_model, np.float32), np.zeros(self.d_ff, np.float32)

    # ---------- FFN exporter with A/B numbers ----------
    def export_views(self, out_dir: str, side: str, layer: int, Wo_dF: torch.Tensor,
                     baseline_dir_down: Optional[np.ndarray] = None, baseline_rowmass: Optional[np.ndarray] = None,
                     *, enable_local_z: bool=True, local_z_window: int=35, pockets_axis: str="columns",
                     token_single: str="", pockets_mode: str="outer_eff_down", weight_bands_by_wo: bool=False,
                     compA_ids: List[int]=None, compB_ids: List[int]=None) -> dict | None:
        ensure_dir(out_dir); key = (side, layer)
        WoT_np = Wo_dF.detach().float().cpu().numpy().T
        if (side, layer) not in self._cached_wo_norms:
            self._cached_wo_norms[(side, layer)] = wo_locality_weights(Wo_dF)
        per_ff_w, per_mod_w = self._cached_wo_norms[(side, layer)]

        def save_pack(tag: str, A_vec: Optional[np.ndarray], c_vec: Optional[np.ndarray]):
            if A_vec is None or c_vec is None: return
            pack_dir = os.path.join(out_dir, tag); ensure_dir(pack_dir)
            A_vec = power_norm_1d(A_vec); c_vec = power_norm_1d(c_vec)
            if weight_bands_by_wo:
                A_vec = (A_vec * (per_ff_w / (per_ff_w.mean() + 1e-8))).astype(np.float32)
                c_vec = (c_vec * (per_mod_w / (per_mod_w.mean() + 1e-8))).astype(np.float32)
            np.save(os.path.join(pack_dir, "eff_A.npy"), A_vec)
            np.save(os.path.join(pack_dir, "down_c.npy"), c_vec)
            S = compute_pockets(A_vec, c_vec, WoT_np, mode=pockets_mode); S = power_norm_2d(S)
            np.save(os.path.join(pack_dir, "pockets_S.npy"), S)
            render_heatmap(S, os.path.join(pack_dir, "pockets_S.png"),
                           f"{side} L{layer} {tag}: pockets S", gamma=1.6)
            render_band_row(c_vec, os.path.join(pack_dir, "down_c.png"),
                            f"{side} L{layer} {tag}: band (down c, power-normed)", gamma=1.8)
            render_band_row(A_vec, os.path.join(pack_dir, "eff_A.png"),
                            f"{side} L{layer} {tag}: eff A (power-normed)", gamma=1.8)
            if enable_local_z:
                tok_slug = _slugify_token_label(token_single)
                c_lz = local_zscore_1d(c_vec, local_z_window); A_lz = local_zscore_1d(A_vec, local_z_window)
                np.save(os.path.join(pack_dir, f"down_c_localz__{tok_slug}.npy"), c_lz)
                np.save(os.path.join(pack_dir, f"eff_A_localz__{tok_slug}.npy"), A_lz)
                render_band_row(c_lz, os.path.join(pack_dir, f"down_c_localz__{tok_slug}.png"),
                                f"{side} L{layer} {tag}: band (down c) local-z [{tok_slug}]", gamma=1.8)
                render_band_row(A_lz, os.path.join(pack_dir, f"eff_A_localz__{tok_slug}.png"),
                                f"{side} L{layer} {tag}: eff A local-z [{tok_slug}]", gamma=1.8)
                S_lz = (A_lz[:, None] * c_lz[None, :]).astype(np.float32)
                np.save(os.path.join(pack_dir, f"pockets_S_localz__{tok_slug}.npy"), S_lz)
                render_heatmap(S_lz, os.path.join(pack_dir, f"pockets_S_localz__{tok_slug}.png"),
                               f"{side} L{layer} {tag}: pockets S (A_lz ⊗ c_lz)", gamma=1.6)

        # Legacy packs
        if self.first_single_eff[key] is not None and self.first_single_down[key] is not None:
            save_pack("single_token_first", self.first_single_eff[key], self.first_single_down[key])
        if self.cnt_single[key] > 0:
            A = self.sum_single_eff[key] / max(1, self.cnt_single[key])
            c = self.sum_single_down[key] / max(1, self.cnt_single[key])
            save_pack("mean_single_token", A, c)
        if self.cnt_generic[key] > 0:
            A_g = self.sum_generic_eff[key] / max(1, self.cnt_generic[key])
            c_g = self.sum_generic_down[key] / max(1, self.cnt_generic[key])
            if self.cnt_single[key] > 0:
                A_s = self.sum_single_eff[key] / max(1, self.cnt_single[key])
                c_s = self.sum_single_down[key] / max(1, self.cnt_single[key])
                save_pack("mean_single_minus_generic", A_s - A_g, c_s - c_g)
            if self.cnt_setA[key] > 0 and self.cnt_setB[key] > 0:
                A_a = self.sum_setA_eff[key]  / max(1, self.cnt_setA[key])
                c_a = self.sum_setA_down[key] / max(1, self.cnt_setA[key])
                A_b = self.sum_setB_eff[key]  / max(1, self.cnt_setB[key])
                c_b = self.sum_setB_down[key] / max(1, self.cnt_setB[key])
                save_pack("mean_setA_minus_setB", A_a - A_b, c_a - c_b)

        # Baseline contrasts
        if baseline_dir_down is not None and baseline_rowmass is not None and self.single_cnt_dir_down[key] > 0:
            mu_dir = self.single_sum_dir_down[key] / max(1, self.single_cnt_dir_down[key])
            mu_p   = self.single_sum_rowmass[key] / max(1, self.single_cnt_rowmass[key])
            delta_dir = power_norm_1d(mu_dir - baseline_dir_down)
            delta_p   = power_norm_1d(mu_p   - baseline_rowmass)
            if weight_bands_by_wo:
                per_ff_w, per_mod_w = self._cached_wo_norms[(side, layer)]
                delta_dir = (delta_dir * (per_mod_w / (per_mod_w.mean() + 1e-8))).astype(np.float32)
                delta_p   = (delta_p   * (per_ff_w  / (per_ff_w.mean()  + 1e-8))).astype(np.float32)
            pack_dir = os.path.join(out_dir, "single_vs_baseline_dir"); ensure_dir(pack_dir)
            np.save(os.path.join(pack_dir, "delta_dir_down.npy"), delta_dir)
            np.save(os.path.join(pack_dir, "delta_rowmass_eff.npy"), delta_p)
            S_dir = compute_pockets(delta_p, delta_dir, WoT_np, mode=pockets_mode); S_dir = power_norm_2d(S_dir)
            np.save(os.path.join(pack_dir, "pockets_S_dir.npy"), S_dir)
            render_heatmap(S_dir, os.path.join(pack_dir, "pockets_S_dir.png"),
                           f"{side} L{layer} single vs baseline: pockets", gamma=1.6)
            render_band_row(delta_dir, os.path.join(pack_dir, "delta_dir_down.png"),
                            f"{side} L{layer} single vs baseline: band Δĉ", gamma=1.8)

        # A/B numbers from FFN maps
        layer_diff_sum, layer_overlap = None, None
        if (self.compA_cnt[key] > 0) and (self.compB_cnt[key] > 0):
            SA = (self.compA_sum_S[key] / max(1, self.compA_cnt[key])).astype(np.float32)
            SB = (self.compB_sum_S[key] / max(1, self.compB_cnt[key])).astype(np.float32)
            D = np.abs(SA - SB).astype(np.float32)
            layer_diff_sum = float(D.sum())
            denom = (np.abs(SA) + np.abs(SB)).sum() + 1e-8
            layer_overlap = float(D.sum() / denom)
            pack_dir = os.path.join(out_dir, "compare_token_diff"); ensure_dir(pack_dir)
            np.save(os.path.join(pack_dir, "mapA.npy"), SA)
            np.save(os.path.join(pack_dir, "mapB.npy"), SB)
            np.save(os.path.join(pack_dir, "diff_map.npy"), D)
            render_heatmap(power_norm_2d(SA), os.path.join(pack_dir, "mapA.png"),
                           f"{side} L{layer} token A map (avg per-occurrence norm)", gamma=1.6)
            render_heatmap(power_norm_2d(SB), os.path.join(pack_dir, "mapB.png"),
                           f"{side} L{layer} token B map (avg per-occurrence norm)", gamma=1.6)
            render_heatmap(D, os.path.join(pack_dir, "diff_map.png"),
                           f"{side} L{layer} |A - B| (difference) • overlap={layer_overlap:.3f}", gamma=1.6)
            plt.figure(figsize=(6,4))
            plt.hist(D.flatten(), bins=60, density=False)
            plt.title(f"{side} L{layer} difference histogram")
            plt.xlabel("|A-B| value"); plt.ylabel("count"); plt.tight_layout()
            plt.savefig(os.path.join(pack_dir, "diff_hist.png"), dpi=200); plt.close()

        return {"layer_diff_sum": layer_diff_sum, "layer_overlap": layer_overlap}

# ---------------------------- FFN Capture (MT5) ---------------------------- #

class FFNCapture:
    """
    Hooks all FFNs (encoder layer[1], decoder layer[2]) and streams:
      A = act(wi_0(x))*wi_1(x)   (or act(wi(x)))
      c = A @ W_o^T
    """
    def __init__(self, model, tokenizer, aggregators: TokenAggregators,
                 capture_encoder: bool, capture_decoder: bool):
        self.model = model; self.tok = tokenizer; self.agg = aggregators
        self.capture_encoder = capture_encoder; self.capture_decoder = capture_decoder
        self.act = act_fn_for_proj(model.config.feed_forward_proj)
        self.handles = []; self._enc_ids_bt = None; self._dec_ids_bt = None
        if capture_encoder:
            for li, block in enumerate(model.encoder.block):
                self._register_ffn_hook(block.layer[1].DenseReluDense, "enc", li)
        if capture_decoder:
            for li, block in enumerate(model.decoder.block):
                self._register_ffn_hook(block.layer[2].DenseReluDense, "dec", li)

    def _register_ffn_hook(self, ffn, side: str, layer: int):
        has_gated = hasattr(ffn, "wi_0") and hasattr(ffn, "wi_1")

        def hook_fn(module, inputs, output):
            x = inputs[0]
            with torch.no_grad():
                if has_gated:
                    u = ffn.wi_0(x)
                    v = ffn.wi_1(x)
                    A = self.act(u) * v  # [B,T,F]
                else:
                    u = ffn.wi(x)
                    A = self.act(u)  # [B,T,F]
                Wo = ffn.wo.weight  # [M,F]
                c = torch.einsum("btf,mf->btm", A, Wo)  # [B,T,M]

                ids = self._enc_ids_bt if side == "enc" else self._dec_ids_bt
                if ids is None: return
                self.agg.update_ffn_batch_cuda(side, layer, A, c, ids)
        self.handles.append(ffn.register_forward_hook(hook_fn))

    def set_batch_token_ids(self, enc_ids_bt: Optional[torch.Tensor], dec_ids_bt: Optional[torch.Tensor]):
        self._enc_ids_bt = enc_ids_bt; self._dec_ids_bt = dec_ids_bt

    def close(self):
        for h in self.handles: h.remove()
        self.handles = []

# ---------------------------- Hidden-state Capture (MT5 & RWKV) ---------------------------- #

class HiddenStateCapture:
    """
    Capture per-layer hidden states (post-block outputs) for MT5 and HF RWKV.

    Each hooked block stores the first element of its output (or the output
    itself when it is not a tuple/list), detached and converted to a CPU float32
    tensor, in ``.enc`` / ``.dec`` keyed by layer index. Presumably these are
    (B, T, D) -- TODO confirm per architecture. A later forward pass overwrites
    the previous entry, so callers should consume and clear the dicts per batch.

    Hooked modules:
      * MT5: ``model.encoder.block[i]`` -> ``.enc``, ``model.decoder.block[i]`` -> ``.dec``
      * RWKV: ``model.rwkv.blocks[i]`` -> ``.dec`` (RWKV is treated as the "decoder" side)
    """
    def __init__(self, model, capture_encoder: bool, capture_decoder: bool):
        self.model = model
        self.capture_encoder = capture_encoder
        self.capture_decoder = capture_decoder
        self.handles = []
        self.enc = {}
        self.dec = {}

        # MT5
        if capture_encoder and hasattr(model, "encoder") and hasattr(model.encoder, "block"):
            self._hook_stack(model.encoder.block, self.enc)
        if capture_decoder and hasattr(model, "decoder") and hasattr(model.decoder, "block"):
            self._hook_stack(model.decoder.block, self.dec)

        # RWKV (HF: model.rwkv.blocks) -- only a "decoder" side exists
        if capture_decoder and hasattr(model, "rwkv") and hasattr(model.rwkv, "blocks"):
            self._hook_stack(model.rwkv.blocks, self.dec)

    @staticmethod
    def _make_store_hook(dst: dict, layer_idx: int):
        """Return a forward hook writing the block's (first) output into ``dst[layer_idx]``."""
        def fn(module, inputs, output):
            x = output[0] if isinstance(output, (tuple, list)) else output
            dst[layer_idx] = x.detach().float().cpu()
        return fn

    def _hook_stack(self, blocks, dst: dict) -> None:
        """Register a store hook on each block of ``blocks``, keyed by its position."""
        for i, blk in enumerate(blocks):
            self.handles.append(blk.register_forward_hook(self._make_store_hook(dst, i)))

    def close(self):
        """Remove all registered hooks."""
        for h in self.handles:
            h.remove()
        self.handles = []

# ---------------------------- RWKV helpers ---------------------------- #

def _align_time_mix_names(sd: dict) -> dict:
    """
    Rename short RWKV time-mix keys to the Hugging Face ``RwkvForCausalLM`` names.

    '.time_mix_k' -> '.time_mix_key', '.time_mix_v' -> '.time_mix_value' and
    '.time_mix_r' -> '.time_mix_receptance'. A single regex pass with a trailing
    word boundary is used, so names that are already long (e.g. '.time_mix_key')
    are left alone and nothing is double-replaced into things like 'keyey'.
    Values are passed through unchanged.

    Prints the number of renames, and a warning whenever two source keys map to
    the same target key (the later one wins).
    """
    full_names = {"k": "key", "v": "value", "r": "receptance"}
    pat = re.compile(r'(\.time_mix_)([kvr])\b')

    def repl(m: re.Match) -> str:
        return m.group(1) + full_names[m.group(2)]

    out = {}
    changed = 0
    for k, v in sd.items():
        k2, n = pat.subn(repl, k)
        changed += n
        if k2 in out:
            print(f"[rwkv] WARN: key collision after time_mix remap: '{k}' -> '{k2}' (overwriting)")
        out[k2] = v
    print(f"[rwkv] time_mix remap: changed={changed}")
    return out

def _strip_prefix_rwkv(k: str) -> str:
    """
    Drop one leading wrapper prefix ('module.' from DDP/DataParallel, or 'model.') from a key.

    The 'rwkv.' prefix is kept because the HF RWKV model's own keys include it.
    """
    for wrapper in ("module.", "model."):
        if k.startswith(wrapper):
            return k[len(wrapper):]
    return k

def _infer_rwkv_cfg(sd: dict, tokenizer, fallback_hidden=None, fallback_layers=None) -> RwkvConfig:
    """
    Build an ``RwkvConfig`` whose dimensions match the tensors found in ``sd``.

    * vocab_size: rows of the first embedding matrix found, else ``len(tokenizer)``.
    * hidden_size: length of the first 1-D ``ln1``/``ln2``/``pre_ln`` weight, or the
      last dim of the first ``time_mix_*`` tensor, else ``fallback_hidden`` (or 1024).
    * num_hidden_layers: 1 + highest ``blocks.<i>.`` index (with or without the
      ``rwkv.`` prefix), else ``fallback_layers`` (or 24).
    * attention_hidden_size / intermediate_size: output rows of the first 2-D
      ``attention.key.weight`` / ``feed_forward.key.weight``; ``None`` (HF default)
      when absent. Inferring these avoids size-mismatch errors for checkpoints
      whose FFN width is not the default.

    Non-tensor entries in ``sd`` are ignored by the shape probes.
    """
    keys = list(sd.keys())

    # vocab size
    vocab_size = None
    for cand in ("rwkv.embeddings.weight", "rwkv.emb.weight", "embeddings.weight", "emb.weight"):
        if cand in sd:
            vocab_size = sd[cand].shape[0]
            break
    if vocab_size is None:
        vocab_size = len(tokenizer)

    # hidden size
    hidden_size = None
    for k in keys:
        t = sd[k]
        ndim = getattr(t, "ndim", None)
        if ndim is None:
            continue
        if k.endswith(("ln1.weight", "pre_ln.weight", "ln2.weight")) and ndim == 1:
            hidden_size = int(t.shape[0])
            break
        if "time_mix_" in k and ndim >= 1:
            hidden_size = int(t.shape[-1])
            break
    if hidden_size is None:
        hidden_size = fallback_hidden or 1024  # typical 430M

    # layers (accept keys with or without the 'rwkv.' prefix)
    layer_pat = re.compile(r'(?:^|\.)blocks\.(\d+)\.')
    layer_idxs = []
    for k in keys:
        m = layer_pat.search(k)
        if m:
            layer_idxs.append(int(m.group(1)))
    num_layers = (max(layer_idxs) + 1) if layer_idxs else (fallback_layers or 24)

    def _rows_of(suffix: str):
        """Output rows of the first 2-D tensor whose key ends with `suffix`, else None."""
        for k in keys:
            t = sd[k]
            if k.endswith(suffix) and getattr(t, "ndim", None) == 2:
                return int(t.shape[0])
        return None

    attention_hidden_size = _rows_of("attention.key.weight")
    intermediate_size = _rows_of("feed_forward.key.weight")

    # NOTE(review): RwkvConfig's default rescale_every=6 makes HF RWKV rescale
    # weights and halve hidden states every 6 blocks in eval mode, so per-layer
    # hidden-state magnitudes captured later are not on a common scale -- confirm
    # whether rescale_every=0 is wanted for cross-layer comparisons.
    return RwkvConfig(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        num_hidden_layers=num_layers,
        attention_hidden_size=attention_hidden_size,
        intermediate_size=intermediate_size,
        layer_norm_epsilon=1e-5,
        bos_token_id=getattr(tokenizer, "bos_token_id", 0) or 0,
        eos_token_id=getattr(tokenizer, "eos_token_id", 0) or 0,
        pad_token_id=getattr(tokenizer, "pad_token_id", 0) or 0,
    )

def load_rwkv_from_sd(sd_path: str, tokenizer, *, fallback_hidden=None, fallback_layers=None, device="cpu"):
    """
    Build an HF ``RwkvForCausalLM`` sized from a checkpoint and load its weights leniently.

    Steps:
      1. ``torch.load`` the file; unwrap a top-level ``"state_dict"`` entry if present.
      2. Strip one DDP/``model.`` prefix per key and remap short time-mix names.
      3. Infer the config from tensor shapes (``fallback_*`` used when not inferable).
      4. Load only keys present in both the checkpoint and the model *with matching
         shapes*. Shape-mismatched keys are reported and skipped instead of aborting
         ``load_state_dict`` (which raises on size mismatch even with strict=False).

    Key-overlap statistics, missing/unused keys and skipped keys are printed.

    Returns:
        The model moved to ``device`` and set to eval mode.
    """
    # NOTE: torch.load unpickles the file -- only use trusted checkpoints.
    sd_raw = torch.load(sd_path, map_location="cpu")
    if isinstance(sd_raw, dict) and "state_dict" in sd_raw:
        sd_raw = sd_raw["state_dict"]
    sd = {_strip_prefix_rwkv(k): v for k, v in sd_raw.items()}
    sd = _align_time_mix_names(sd)

    cfg = _infer_rwkv_cfg(sd, tokenizer, fallback_hidden=fallback_hidden, fallback_layers=fallback_layers)
    model = RwkvForCausalLM(cfg)
    ms = model.state_dict()

    keys_sd = set(sd.keys())
    keys_ms = set(ms.keys())
    inter   = keys_sd & keys_ms
    miss    = sorted(keys_ms - keys_sd)
    extra   = sorted(keys_sd - keys_ms)

    print(f"[rwkv] sd keys: {len(keys_sd)} | model keys: {len(keys_ms)} | intersect: {len(inter)} "
          f"({100.0*len(inter)/max(1,len(keys_ms)):.1f}% of model)")
    print(f"[rwkv] missing in sd: {len(miss)}")
    for k in miss[:20]: print("   -", k)
    print(f"[rwkv] sd-only (unused): {len(extra)}")
    for k in extra[:20]: print("   -", k)

    def _shape(t):
        return tuple(getattr(t, "shape", ()))

    shape_bad = sorted(k for k in inter if _shape(sd[k]) != tuple(ms[k].shape))
    if shape_bad:
        print(f"[rwkv] shape mismatch (skipped, model keeps its init): {len(shape_bad)}")
        for k in shape_bad[:20]:
            print(f"   - {k}: sd {_shape(sd[k])} vs model {tuple(ms[k].shape)}")
    skip = set(shape_bad)

    loadable = {k: sd[k] for k in inter if k not in skip}
    missing, unexpected = model.load_state_dict(loadable, strict=False)
    print(f"[rwkv] load_state_dict: missing_after_load={len(missing)} unexpected_after_load={len(unexpected)}")

    return model.to(device).eval()

def load_rwkv_tokenizer(tok_id: str | None):
    """
    Load a tokenizer for the RWKV backbone.

    Tries ``tok_id`` first (if given), then the ``EleutherAI/gpt-neox-20b`` fallback.
    The first one that loads is announced and returned. If none loads, raises
    OSError listing each attempted source with its error, plus an offline hint.
    """
    fallback_id = "EleutherAI/gpt-neox-20b"

    # (source, label used in the success message, source-specific kwargs)
    candidates = []
    if tok_id:
        candidates.append((tok_id, "", {"use_forced_bos_token": False}))
    candidates.append((fallback_id, "fallback ", {}))

    failures = []
    for source, label, extra_kwargs in candidates:
        try:
            tokenizer = AutoTokenizer.from_pretrained(source, **extra_kwargs, use_fast=True, trust_remote_code=True)
            print(f"[rwkv] tokenizer loaded from {label}'{source}' ({tokenizer.__class__.__name__})")
            return tokenizer
        except Exception as err:
            failures.append((source, str(err)))

    report_lines = [f"  - {source}: {err}" for source, err in failures]
    raise OSError(
        "[rwkv] Failed to load a tokenizer. Tried:\n" + "\n".join(report_lines)
        + "\nIf you’re offline, download '20B_tokenizer.json' and point --rwkv_tokenizer "
          "to its directory, or clone EleutherAI/gpt-neox-20b locally."
    )

def neox_ids_for_words(tok, csv: str, *, allow_multi_head=True, auto_space_prefix=True, all_variants=True):
    """
    Map comma-separated words to a sorted list of token ids.

    Each word is encoded as ``'word'`` and, when ``auto_space_prefix`` is set, also
    as ``' word'``. A variant contributes:
      * its id, if it encodes to exactly one token; otherwise
      * when ``allow_multi_head`` is set, the first piece whose token string starts
        with 'Ġ' or '▁' (the word-initial "head" piece).

    ``all_variants=True`` (default) keeps the ids of *every* variant that resolves.
    With GPT-NeoX-style byte-level BPE a word inside running text is normally the
    space-prefixed token (e.g. 'Ġthe'), while the bare form ('the') only appears at
    the very start of a text; keeping just the first resolved variant therefore
    misses most occurrences. Pass ``all_variants=False`` to stop at the first
    resolved variant (previous behaviour).

    Prints a warning for words that resolve to nothing and a short mapping summary.
    """
    out = set()
    words = [w.strip() for w in csv.split(",") if w.strip()]
    for w in words:
        variants = [w]
        if auto_space_prefix and not w.startswith(" "):
            variants = [w, " " + w]
        picked = False
        for v in variants:
            ids = tok.encode(v, add_special_tokens=False)
            chosen = None
            if len(ids) == 1:
                chosen = ids[0]
            elif allow_multi_head:
                pieces = tok.convert_ids_to_tokens(ids)
                for tid, piece in zip(ids, pieces):
                    if isinstance(piece, str) and piece.startswith(("Ġ", "▁")):
                        chosen = tid
                        break
            if chosen is None:
                continue
            out.add(chosen)
            picked = True
            if not all_variants:
                break
        if not picked:
            print(f"[tokens] WARN: '{w}' did not yield a single/head token id (variants tried: {variants})")
    out = sorted(out)
    if out:
        names = [tok.convert_ids_to_tokens([i])[0] for i in out]
        print(f"[tokens] mapped {len(words)} words -> {len(out)} ids: {list(zip(out, names))[:12]}")
    else:
        print(f"[tokens] ERROR: no ids resolved from: {words}")
    return out

def _clean_words_csv(s: str) -> str:
    """
    Normalise a comma-separated word list.

    Each entry is lower-cased and stripped of leading/trailing non-word characters
    and whitespace; entries left empty are dropped. Returns the survivors re-joined with ','.
    """
    kept = []
    for raw in s.split(","):
        cleaned = re.sub(r"^[^\w]+|[^\w]+$", "", raw.lower()).strip()
        if cleaned:
            kept.append(cleaned)
    return ",".join(kept)

# ---------------------------- Main ---------------------------- #

def _main_build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser used by `main` (flag names/defaults are the script's interface)."""
    ap = argparse.ArgumentParser()

    # Base / weights
    ap.add_argument("--ckpt", type=str, default="google/mt5-base", help="Base HF checkpoint for mt5")
    # Checkpoints used in earlier runs (pass via --state_dict_path):
    #   mt5_base_pretuned.pt
    #   mt5_base_standard_FaF_11x_noise_two_let_retrain1_batch_10000.pt  -- end with 2let, no FZR
    #   mt5_base_twolet_FaF_12x_noise0_batch_5000_o.pt                   -- end 2let with FZR
    #   rwkv3_checkpoint_0_batch_14000.pt                                 -- RWKV 2let
    ap.add_argument("--state_dict_path", type=str,
                    default="mt5_base_standard_FaF_11x_noise.pt",
                    help="Path to .pt/.bin state_dict to load on top of --ckpt")

    # Architecture switch
    ap.add_argument("--arch", type=str, choices=["mt5", "rwkv"], default="mt5",
                    help="Backbone to load. 'mt5' or 'rwkv'.")
    ap.add_argument("--use_forgetfult5", default=False, action="store_true", help="Use custom ForgetfulT5 forward (import must exist)")

    # RWKV: hidden size / layer count are fallbacks used only when the checkpoint does
    # not reveal them. --rwkv_vocab_size is currently not consulted (the vocab size
    # is taken from the checkpoint embedding or the tokenizer).
    ap.add_argument("--rwkv_vocab_size", type=int, default=50277)
    ap.add_argument("--rwkv_hidden_size", type=int, default=1024)
    ap.add_argument("--rwkv_n_layers", type=int, default=24)
    ap.add_argument("--rwkv_tokenizer", type=str, default="BlinkDL/rwkv-4-pile-430m")

    # Compare mode
    ap.add_argument("--compare_mode", type=str, choices=["ffn", "hidden"], default="hidden",
                    help="Use 'ffn' (T5-only) or 'hidden' (layer outputs; MT5 & RWKV).")

    # Modeling / options
    ap.add_argument("--token_baseline_multi", type=str, default="", help="Comma list averaged as baseline 'word'")
    ap.add_argument("--token_compare_a", type=str, default="the", help="Token(s) for compare A")
    ap.add_argument("--token_compare_b", type=str, default="them", help="Token(s) for compare B")
    ap.add_argument("--weight_bands_by_wo", action="store_true", help="Multiply bands by local Wo norms (FFN path)")

    ap.add_argument("--out_dir", type=str, default="Interpretation_Records/map_comparison/cat_call_2let")

    # z/vis (on by default; disable with --no-enable_local_z)
    ap.add_argument("--enable_local_z", default=True, action=argparse.BooleanOptionalAction,
                    help="Also save local averaged variants")
    ap.add_argument("--local_z_window", type=int, default=19, help="Odd window size for local avg")
    ap.add_argument("--local_z_axis_for_pockets", type=str, default="columns", choices=["columns","rows","both"], help="Axis for local z in pockets")

    # Data
    ap.add_argument("--pairs_file", type=str, default=None, help="TSV src<TAB>tgt")
    ap.add_argument("--text_file", type=str, default="Interpretation_Records/interpret_sub_texts.txt", help="Plaintext one-per-line")
    ap.add_argument("--max_examples", type=int, default=5000)
    ap.add_argument("--batch_size", type=int, default=32)
    ap.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")

    # Capture sides (decoder on by default; disable with --no-capture_decoder)
    ap.add_argument("--capture_encoder", action="store_true")
    ap.add_argument("--capture_decoder", default=True, action=argparse.BooleanOptionalAction)

    # Baseline (FFN path)
    ap.add_argument("--baseline_build", action="store_true", help="stream and save baseline means (ĉ, p)")
    ap.add_argument("--baseline_use", action="store_true", help="load baseline and emit single-vs-baseline outputs")
    ap.add_argument("--baseline_file", type=str, default=None, help="path to .npz for baseline")
    ap.add_argument("--baseline_tokens_target", type=int, default=50000, help="approx tokens to include in baseline")
    ap.add_argument("--baseline_exclude_targets", action="store_true", help="exclude single/setA/setB from baseline")

    # Pockets (FFN path)
    ap.add_argument("--pockets_mode", type=str, default="outer_eff_down", choices=["outer_eff_down","wo_times_eff"], help="How to form pockets")

    # Tokenizer lengths (fix truncation warnings)
    ap.add_argument("--max_src_len", type=int, default=512)
    ap.add_argument("--max_tgt_len", type=int, default=128)

    # Injection helper to guarantee A/B occurrences
    ap.add_argument("--inject_compare_sentences", action="store_true",
                    help="Prepend simple pairs so token_compare_* appear in the stream")

    # Debug
    ap.add_argument("--debug_tokens", action="store_true", help="Print batch-seen ids vs. A/B id sets")
    return ap


def _main_overlay_state_dict(model, path: str) -> None:
    """Leniently load a (possibly wrapped) mt5 state_dict from `path` into `model`, printing key statistics."""
    # NOTE: torch.load unpickles the file -- only use trusted checkpoints.
    sd_raw = torch.load(path, map_location="cpu")
    if isinstance(sd_raw, dict) and "state_dict" in sd_raw:
        sd_raw = sd_raw["state_dict"]

    prefix_re = re.compile(r'^(?:module\.|model\.|base_model\.|transformer\.|t5\.|mt5\.)')
    sd = {prefix_re.sub('', k, count=1): v for k, v in sd_raw.items()}
    ms = model.state_dict()

    keys_sd = set(sd.keys())
    keys_ms = set(ms.keys())
    inter = keys_sd & keys_ms
    only_sd = sorted(keys_sd - keys_ms)
    only_ms = sorted(keys_ms - keys_sd)

    print(
        f"[ckpt] sd keys (after strip): {len(keys_sd)} | model keys: {len(keys_ms)} | intersect: {len(inter)} "
        f"({100.0 * len(inter) / max(1, len(keys_ms)):.1f}% of model)")
    if only_sd:
        print(" [ckpt] example in sd but not model (dropped):")
        for k in only_sd[:10]:
            print(f"   - {k}")
    if only_ms:
        print(" [ckpt] example expected by model but missing in sd:")
        for k in only_ms[:10]:
            print(f"   - {k}")

    # load_state_dict raises on size mismatch even with strict=False; skip and report instead.
    def _shape(t):
        return tuple(getattr(t, "shape", ()))

    shape_bad = sorted(k for k in inter if _shape(sd[k]) != tuple(ms[k].shape))
    if shape_bad:
        print(f" [ckpt] shape mismatch (skipped, base weights kept): {len(shape_bad)}")
        for k in shape_bad[:10]:
            print(f"   - {k}: sd {_shape(sd[k])} vs model {tuple(ms[k].shape)}")
    skip = set(shape_bad)

    sd_f = {k: sd[k] for k in inter if k not in skip}
    missing, unexpected = model.load_state_dict(sd_f, strict=False)
    print(f"[mt5] load_state_dict: loaded={len(sd_f)}  missing={len(missing)}  unexpected={len(unexpected)}")


def _main_load_mt5(args):
    """Return (tokenizer, model) for mt5: base --ckpt plus optional --state_dict_path overlay, in eval mode on --device."""
    from transformers import MT5ForConditionalGeneration as _T5Cls
    tok = AutoTokenizer.from_pretrained(args.ckpt)
    model = _T5Cls.from_pretrained(args.ckpt)
    if args.state_dict_path and os.path.exists(args.state_dict_path):
        _main_overlay_state_dict(model, args.state_dict_path)
    else:
        print("[mt5] no state_dict_path found; using base weights.")
    model.eval().to(args.device)
    return tok, model


def _main_load_rwkv(args):
    """Return (tokenizer, model) for RWKV; forces hidden-state mode and disables encoder capture on `args`."""
    tok = load_rwkv_tokenizer(args.rwkv_tokenizer or None)
    if tok.pad_token is None and tok.eos_token is not None:
        tok.pad_token = tok.eos_token

    model = load_rwkv_from_sd(args.state_dict_path, tok,
                              fallback_hidden=args.rwkv_hidden_size,
                              fallback_layers=args.rwkv_n_layers)
    model.to(args.device).eval()

    if args.compare_mode != "hidden":
        print("[rwkv] Forcing compare_mode=hidden.")
        args.compare_mode = "hidden"
    if args.capture_encoder:
        print("[rwkv] Encoder capture not applicable; disabling.")
    args.capture_encoder = False
    return tok, model


def _main_feed_hidden(args, aggregators, hidcap, enc, decoder_input_ids, setA_ids, setB_ids) -> None:
    """Push this batch's captured hidden states into `aggregators`, then clear the capture dicts."""
    set_a, set_b = set(setA_ids), set(setB_ids)

    if args.arch == "rwkv":
        # RWKV: the input_ids label the "decoder" side
        dec_ids_bt_cpu = enc["input_ids"].detach().cpu()
        if args.debug_tokens:
            seen = set(dec_ids_bt_cpu.view(-1).tolist())
            print("[debug] RWKV batch seen ids (sample):", sorted(seen)[:40])
            print("[debug] A ids:", setA_ids, "B ids:", setB_ids)
            print("[debug] A∩seen:", seen & set_a, "B∩seen:", seen & set_b)
        for L, H in hidcap.dec.items():
            aggregators.update_hidden_side("dec", L, to_cpu_f32(H), dec_ids_bt_cpu,
                                           set_a_ids=set_a, set_b_ids=set_b)
        hidcap.dec.clear()
        return

    if args.capture_encoder and hidcap.enc:
        enc_ids_bt_cpu = enc["input_ids"].detach().cpu()
        for L, H in hidcap.enc.items():
            aggregators.update_hidden_side("enc", L, to_cpu_f32(H), enc_ids_bt_cpu,
                                           set_a_ids=set_a, set_b_ids=set_b)
    if args.capture_decoder and hidcap.dec and decoder_input_ids is not None:
        dec_ids_bt_cpu = decoder_input_ids.detach().cpu()
        for L, H in hidcap.dec.items():
            aggregators.update_hidden_side("dec", L, to_cpu_f32(H), dec_ids_bt_cpu,
                                           set_a_ids=set_a, set_b_ids=set_b)
    hidcap.enc.clear()
    hidcap.dec.clear()


def _main_ffn_numbers(args, model, aggregators, baseline, single_ids, setA_ids, setB_ids):
    """FFN path: export per-layer views for every captured side; return the per-layer diff records."""
    def _from_baseline(name):
        return baseline[name] if (baseline and name in baseline) else None

    # (side, blocks, index of the FFN sub-layer in each block, baseline dir means, baseline row means)
    sides = []
    if args.capture_encoder:
        sides.append(("enc", model.encoder.block, 1,
                      _from_baseline("enc_dir_down_mean"), _from_baseline("enc_rowmass_mean")))
    if args.capture_decoder:
        sides.append(("dec", model.decoder.block, 2,
                      _from_baseline("dec_dir_down_mean"), _from_baseline("dec_rowmass_mean")))

    token_single = ",".join([args.token_compare_a, args.token_compare_b])
    layer_diffs = []
    for side, blocks, ffn_slot, base_dir, base_row in sides:
        for li, block in enumerate(blocks):
            ffn = block.layer[ffn_slot].DenseReluDense
            out_dir = os.path.join(args.out_dir, f"{side}_L{li:02d}")
            bd, br = None, None
            if len(single_ids) > 0:
                bd, br = aggregators.word_baseline(side, li, single_ids)
            elif base_dir is not None and base_row is not None:
                bd, br = base_dir[li], base_row[li]
            res = aggregators.export_views(out_dir, side, li, ffn.wo.weight, bd, br,
                                           enable_local_z=args.enable_local_z, local_z_window=args.local_z_window,
                                           pockets_axis=args.local_z_axis_for_pockets,
                                           token_single=token_single,
                                           pockets_mode=args.pockets_mode, weight_bands_by_wo=args.weight_bands_by_wo,
                                           compA_ids=setA_ids, compB_ids=setB_ids)
            if res and res.get("layer_diff_sum") is not None:
                overlap = res["layer_overlap"]
                layer_diffs.append({"side": side, "layer": li, "diff_sum": float(res["layer_diff_sum"]),
                                    "overlap": float(overlap) if overlap is not None else None})
    return layer_diffs


def _main_hidden_numbers(args, aggregators, nL_enc: int, nL_dec: int):
    """Hidden-state path (MT5 & RWKV): write |A-B| maps per layer; return the per-layer diff records."""
    layer_diffs = []
    for side, nL in (("enc", nL_enc), ("dec", nL_dec)):
        if (side == "enc" and not args.capture_encoder) or (side == "dec" and not args.capture_decoder):
            continue
        for L in range(nL):
            key = (side, L)
            if aggregators.hidA_cnt[key] > 0 and aggregators.hidB_cnt[key] > 0:
                SA = aggregators.hidA_sum_S[key] / max(1, aggregators.hidA_cnt[key])
                SB = aggregators.hidB_sum_S[key] / max(1, aggregators.hidB_cnt[key])
                D = np.abs(SA - SB).astype(np.float32)
                diff_sum = float(D.sum())
                overlap = float(D.sum() / (np.abs(SA).sum() + np.abs(SB).sum() + 1e-8))
                layer_diffs.append({"side": side, "layer": L, "diff_sum": diff_sum, "overlap": overlap})
                out_dir = os.path.join(args.out_dir, f"{side}_L{L:02d}", "compare_token_diff_hidden")
                ensure_dir(out_dir)
                np.save(os.path.join(out_dir, "diff_map.npy"), D)
                render_heatmap(D, os.path.join(out_dir, "diff_map.png"),
                               f"{side} L{L} |A-B| (hidden)", gamma=1.6)
    return layer_diffs


def _main_report(out_dir: str, layer_diffs, total_diff: float) -> None:
    """Write token_compare_summary.json and print the layer-wise summary (overlap may be None)."""
    overlap_values = [x["overlap"] for x in layer_diffs if x.get("overlap") is not None]
    overall_overlap = float(np.mean(overlap_values)) if overlap_values else None

    if not layer_diffs:
        print("[compare] No A/B counts found; check tokens, capture side, or use --inject_compare_sentences.")
        return

    report = {"layers": layer_diffs, "total_diff_sum": float(total_diff), "overall_overlap_mean": overall_overlap}
    with open(os.path.join(out_dir, "token_compare_summary.json"), "w") as f:
        json.dump(report, f, indent=2)
    print("[compare] layer-wise:")
    for d in layer_diffs:
        ov = d.get("overlap")
        ov_txt = f"{ov:.3f}" if ov is not None else "n/a"  # FFN path may report no overlap
        print(f"  {d['side']} L{d['layer']:02d}: diff_sum={d['diff_sum']:.3f}, overlap={ov_txt}")
    print(f"[compare] TOTAL diff_sum: {total_diff:.3f}")
    if overall_overlap is not None:
        print(f"[compare] OVERALL overlap (mean of layers): {overall_overlap:.3f}")


def main():
    """
    Command-line entry point.

    Loads an mt5 or RWKV model (plus optional checkpoint weights) and streams
    text pairs through it. Hooks aggregate FFN activations or per-layer hidden
    states for the token sets --token_compare_a / --token_compare_b. Per-layer
    maps and a JSON summary are written under --out_dir.
    """
    args = _main_build_parser().parse_args()

    ensure_dir(args.out_dir)
    with open(os.path.join(args.out_dir, "meta.json"), "w") as f:
        json.dump({"args": vars(args)}, f, indent=2)

    # ---------------- tokenizer + model (arch switch) ----------------
    if args.arch == "mt5":
        tok, model = _main_load_mt5(args)
    else:  # RWKV
        tok, model = _main_load_rwkv(args)

    # ---------------- parse tokens ----------------
    single_ids = strings_to_ids_list(tok, args.token_baseline_multi)  # optional baseline "word"

    if args.arch == "rwkv":
        setA_ids_for_agg = neox_ids_for_words(tok, _clean_words_csv(args.token_compare_a))
        setB_ids_for_agg = neox_ids_for_words(tok, _clean_words_csv(args.token_compare_b))
    else:
        setA_ids_for_agg = strings_to_ids_list(tok, args.token_compare_a)
        setB_ids_for_agg = strings_to_ids_list(tok, args.token_compare_b)

    specials = {tok.pad_token_id, getattr(tok, "eos_token_id", None), getattr(tok, "unk_token_id", None)}
    specials = {i for i in specials if i is not None}

    # Widths / layer counts (d_ff only matters for the FFN path)
    if args.arch == "mt5":
        d_model = model.config.d_model
        probe_ffn = model.decoder.block[0].layer[2].DenseReluDense
        d_ff = (probe_ffn.wi_0.out_features if hasattr(probe_ffn, "wi_0") else probe_ffn.wi.out_features)
        nL_enc = len(model.encoder.block)
        nL_dec = len(model.decoder.block)
    else:
        d_model = model.config.hidden_size
        d_ff = d_model  # placeholder; FFN path disabled
        nL_enc = 0
        nL_dec = model.config.num_hidden_layers or 0

    # ---------------- optional wrapper + device placement ----------------
    if args.use_forgetfult5:
        if args.arch == "mt5":
            # NOTE(review): the wrapper is not put in eval mode here (unchanged behaviour);
            # confirm whether ForgetfulT5's forward should run in train or eval mode.
            model = ForgetfulT5(model)
        else:
            print("[forgetful] --use_forgetfult5 only wraps mt5 models; ignored for rwkv.")
    model = model.to(args.device)  # honour --device (was hard-coded to 'cuda')

    baseline_exclude = set()
    if args.baseline_exclude_targets:
        baseline_exclude |= set(single_ids) | set(setA_ids_for_agg) | set(setB_ids_for_agg)

    want_baseline = args.baseline_build or args.baseline_use
    aggregators = TokenAggregators(
        n_layers_enc=nL_enc, n_layers_dec=nL_dec,
        d_ff=d_ff, d_model=d_model,
        target_single_ids=[],  # not used in hidden mode
        set_a_ids=setA_ids_for_agg, set_b_ids=setB_ids_for_agg,
        special_ids=list(specials),
        want_baseline=want_baseline, baseline_exclude_ids=baseline_exclude
    )

    # ---------------- data ----------------
    if args.pairs_file:
        pairs = read_pairs(args.pairs_file)[:args.max_examples]
        print(f"[data] using TSV pairs: {len(pairs)}")
    elif args.text_file and os.path.exists(args.text_file):
        pairs = read_text_as_pairs(args.text_file, args.max_examples)
        print(f"[data] using text file: {len(pairs)} from {args.text_file}")
    else:
        pairs = read_pairs(None)[:args.max_examples]
        print(f"[data] using built-in sample: {len(pairs)}")

    # Optional injection to guarantee occurrences
    if args.inject_compare_sentences:
        A = [s.strip() for s in args.token_compare_a.split(",") if s.strip()]
        B = [s.strip() for s in args.token_compare_b.split(",") if s.strip()]
        probe_pairs = [(w, w) for w in (A + B)]
        pairs = probe_pairs + pairs
        print(f"[inject] added {len(probe_pairs)} compare probe pairs")

    # ---------------- capture setup ----------------
    capturer = None
    if args.arch == "mt5" and args.compare_mode == "ffn":
        capturer = FFNCapture(model, tok, aggregators,
                              capture_encoder=args.capture_encoder,
                              capture_decoder=args.capture_decoder)

    hidcap = HiddenStateCapture(model,
                                capture_encoder=args.capture_encoder,
                                capture_decoder=args.capture_decoder) if args.compare_mode == "hidden" else None

    # ---------------- main pass ----------------
    last_enc = last_labels = last_dec_in = None
    try:
        for start in range(0, len(pairs), args.batch_size):
            batch = pairs[start:start + args.batch_size]
            src_texts = [s for s, _ in batch]
            tgt_texts = [t for _, t in batch]

            # Inputs (explicit max lengths)
            enc = tok(
                src_texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=args.max_src_len
            ).to(args.device)

            # Targets (labels) with text_target
            tgt = tok(
                text_target=tgt_texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=args.max_tgt_len
            ).to(args.device)

            labels = tgt["input_ids"].clone()
            pad_id = tok.pad_token_id if tok.pad_token_id is not None else -100
            labels[labels == pad_id] = -100

            decoder_input_ids = None
            if hasattr(model, "prepare_decoder_input_ids_from_labels"):
                decoder_input_ids = model.prepare_decoder_input_ids_from_labels(labels=labels)

            last_enc, last_labels, last_dec_in = enc, labels, decoder_input_ids

            if capturer is not None:
                capturer.set_batch_token_ids(
                    enc_ids_bt=enc["input_ids"].detach().cpu() if args.capture_encoder else None,
                    dec_ids_bt=decoder_input_ids.detach().cpu() if args.capture_decoder and decoder_input_ids is not None else None,
                )

            fwd_kwargs = {
                "input_ids": enc["input_ids"],
                "attention_mask": enc.get("attention_mask", None),
                "use_cache": False,
            }
            # Decoder-only RWKV does not take decoder_input_ids; only pass it to seq2seq models.
            if args.arch != "rwkv":
                fwd_kwargs["decoder_input_ids"] = decoder_input_ids if hasattr(model, "decoder") else None
            with torch.no_grad():
                model(**fwd_kwargs)

            # Hidden-state path: feed aggregators
            if hidcap is not None:
                _main_feed_hidden(args, aggregators, hidcap, enc, decoder_input_ids,
                                  setA_ids_for_agg, setB_ids_for_agg)

            if args.baseline_build and aggregators.baseline_total_tokens >= args.baseline_tokens_target:
                print(f"[baseline] token budget reached: {aggregators.baseline_total_tokens}")
                break
    finally:
        # Always detach hooks, even if a batch fails.
        if hidcap is not None:
            hidcap.close()
        if capturer is not None:
            capturer.close()

    # ---------------- Baseline I/O (FFN path) ----------------
    baseline = None
    if args.baseline_build and args.arch == "mt5":
        assert args.baseline_file is not None, "Provide --baseline_file to save baseline"
        aggregators.save_baseline(args.baseline_file, nL_enc, nL_dec)
        print(f"[baseline] saved to {args.baseline_file}")

    if args.baseline_use and args.arch == "mt5":
        assert args.baseline_file and os.path.exists(args.baseline_file), "Provide existing --baseline_file to load"
        baseline = TokenAggregators.load_baseline(args.baseline_file)
        print(f"[baseline] loaded from {args.baseline_file}")

    # ---------------- Optional MT5 attention band maps ----------------
    if args.arch == "mt5" and last_enc is not None and last_labels is not None and last_dec_in is not None:
        probe_layers = [9, 10, 11]
        probe_kinds = ["self", "cross"]
        names = [_attn_module_path("dec", L, kind) for L in probe_layers for kind in probe_kinds]
        captures = {}
        with torch.no_grad():
            with hook_attn_outputs(model, names, captures):
                model(
                    input_ids=last_enc["input_ids"],
                    attention_mask=last_enc["attention_mask"],
                    decoder_input_ids=last_dec_in,
                    labels=last_labels,
                    use_cache=False,
                    output_attentions=False,
                    output_hidden_states=False,
                )
        single_ids_vis = strings_to_ids_list(tok, args.token_compare_a or "the")
        if len(single_ids_vis) > 0:
            target_tok = single_ids_vis[0]
            for L in probe_layers:
                for kind in probe_kinds:
                    key = _attn_module_path("dec", L, kind)
                    if key not in captures:
                        continue
                    vec = compute_attn_band_for_token(captures[key], last_labels, target_tok)
                    export_attn_band_maps(
                        out_dir=args.out_dir, side="dec", layer_idx=L, kind=kind,
                        vec_raw=vec, tag="single_token_first",
                        enable_local_z=args.enable_local_z,
                        local_z_window=args.local_z_window,
                    )

    # ---------------- Export & Numbers ----------------
    print("Exporting maps / numbers ...")
    if args.arch == "mt5" and args.compare_mode == "ffn":
        layer_diffs = _main_ffn_numbers(args, model, aggregators, baseline, single_ids,
                                        setA_ids_for_agg, setB_ids_for_agg)
    else:
        # Hidden-state numbers (MT5 & RWKV), plus quick diff maps
        layer_diffs = _main_hidden_numbers(args, aggregators, nL_enc, nL_dec)
    total_diff = sum((d["diff_sum"] for d in layer_diffs), 0.0)

    _main_report(args.out_dir, layer_diffs, total_diff)
    print(f"Done. Outputs in {args.out_dir}")

if __name__ == "__main__":
    # Script entry point: parse the CLI flags and run the capture/export pipeline in main().
    main()
