#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse, json, os, re
from collections import defaultdict
from typing import List, Tuple, Optional, Dict

import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from contextlib import contextmanager

from transformers import MT5ForConditionalGeneration, AutoTokenizer

# ---------------------------- Utils ---------------------------- #

def gelu(x):
    """Exact (erf-based) GELU applied element-wise."""
    return F.gelu(x)


def gelu_new(x):
    """Tanh-approximated GELU applied element-wise."""
    return F.gelu(x, approximate="tanh")


def silu(x):
    """SiLU (swish) applied element-wise."""
    return F.silu(x)


def act_fn_for_proj(feed_forward_proj: str):
    """
    Map a T5-style ``feed_forward_proj`` config string to an activation callable.

    Matching is case-insensitive and substring based, so gated variants such as
    "gated-silu" resolve to the same function as their plain counterpart.
    Unrecognized strings fall back to exact GELU.

    Hugging Face T5/mT5 configs translate the legacy value "gated-gelu" to the
    ``gelu_new`` (tanh approximation) activation, so that value — and anything
    naming ``gelu_new`` explicitly — returns the tanh variant; otherwise the
    recomputed activations would differ slightly from what the model computes.
    """
    proj = feed_forward_proj.lower()
    # Must be tested before the generic "gelu" substring check below.
    if proj == "gated-gelu" or "gelu_new" in proj:
        return gelu_new
    if "gelu" in proj:
        return gelu
    if "silu" in proj:
        return silu
    if "relu" in proj:
        return F.relu
    return gelu

def ensure_dir(path):
    """Create directory ``path`` (including missing parents); no error if it already exists."""
    os.makedirs(path, exist_ok=True)
def to_cpu_f32(x):
    """Detach tensor ``x`` from autograd and return it as a float32 tensor on the CPU."""
    detached = x.detach()
    return detached.float().cpu()

def render_heatmap(mat: np.ndarray, out_png: str, title: str, vclip: Optional[float]=99.5):
    """
    Render ``mat`` as a heatmap with a colorbar and save it to ``out_png``.

    Args:
        mat: array handed to ``plt.imshow``.
        out_png: destination image path.
        title: figure title.
        vclip: percentile of ``|mat|`` used as a symmetric clipping bound
            (values are clipped to ``[-vmax, vmax]``); ``None`` disables clipping.
    """
    data = mat
    if vclip is not None:
        vmax = np.percentile(np.abs(data), vclip)
        data = np.clip(data, -vmax, vmax)

    # Clip first, then open the figure, and always close *this* figure: a failure
    # while drawing or saving must not leave an open figure behind.
    fig = plt.figure(figsize=(10, 4))
    try:
        plt.imshow(data, aspect='auto', interpolation='nearest')
        plt.title(title)
        plt.colorbar()
        plt.tight_layout()
        plt.savefig(out_png, dpi=200)
    finally:
        plt.close(fig)

def render_band_row(vec: np.ndarray, out_png: str, title: str, repeat: int = 6, vclip: Optional[float]=99.5):
    """
    Draw a 1-D vector as a band: the vector is stacked ``repeat`` times into
    identical rows and rendered through ``render_heatmap``.

    If ``vclip`` is not None the vector is first clipped symmetrically at that
    percentile of its absolute values; the heatmap itself is then drawn with
    clipping disabled.
    """
    band = vec
    if vclip is not None:
        limit = np.percentile(np.abs(band), vclip)
        band = np.clip(band, -limit, limit)
    stacked = np.repeat(band[np.newaxis, :], repeat, axis=0)
    render_heatmap(stacked, out_png, title, vclip=None)

def tokenize_single_piece(tok, s: str) -> List[int]:
    """Encode ``s`` with ``tok.encode`` (special tokens disabled) and return the result."""
    return tok.encode(s, add_special_tokens=False)

def read_pairs(path: Optional[str]) -> List[Tuple[str,str]]:
    """
    Load (source, target) string pairs from a TSV file.

    Each line containing a tab is split at the first tab after its trailing
    newline is removed; lines without a tab are ignored. When ``path`` is None
    or does not exist, a small built-in English/German sample list is returned.
    """
    if path is not None and os.path.exists(path):
        with open(path, "r", encoding="utf-8") as handle:
            return [
                tuple(row.rstrip("\n").split("\t", 1))
                for row in handle
                if "\t" in row
            ]

    # Fallback samples used when no pairs file is available.
    return [
        ("translate English to German: A small test.", "Ein kleiner Test."),
        ("Call me at 12:30 tomorrow.", "Ruf mich morgen um 12:30 an."),
        ("He said, \"Hello!\"", "Er sagte: „Hallo!“"),
        ("The price is $19.99.", "Der Preis beträgt 19,99 $."),
        ("Version 2.0 was released in 2024.", "Version 2.0 wurde 2024 veröffentlicht."),
    ]

def _attn_module_path(side: str, layer_idx: int, kind: str) -> str:
    """
    side: 'dec' or 'enc'
    kind: 'self' or 'cross' (cross only valid for decoder)
    """
    if side == "enc":
        assert kind == "self"
        return f"encoder.block.{layer_idx}.layer.0.SelfAttention"
    else:
        if kind == "self":
            return f"decoder.block.{layer_idx}.layer.0.SelfAttention"
        elif kind == "cross":
            return f"decoder.block.{layer_idx}.layer.1.EncDecAttention"
        else:
            raise ValueError("kind must be 'self' or 'cross'")

from contextlib import contextmanager
import torch

@contextmanager
def hook_attn_outputs(model, names, store_dict):
    """
    Context manager that records the forward output of the named sub-modules.

    For every name in ``names`` a forward hook is registered on the module
    found under that name in ``model.named_modules()``. Each time the module
    runs, its main output tensor is stored in ``store_dict[name]`` as a
    detached float32 CPU tensor (later calls overwrite earlier ones). All
    hooks are removed when the context exits, including on error.

    The module output may be:
      - a Tensor
      - a non-empty tuple/list whose first item holds the output
      - a dict containing 'last_hidden_state', 'hidden_states' or 'output'
    Anything else raises TypeError from inside the hook.
    """
    modules_by_name = dict(model.named_modules())
    registered = []

    def _main_tensor(value):
        # Peel wrappers one level at a time until a Tensor is reached.
        current = value
        while True:
            if isinstance(current, torch.Tensor):
                return current.detach().float().cpu()
            if isinstance(current, (tuple, list)) and len(current) > 0:
                current = current[0]
                continue
            if isinstance(current, dict):
                hit = next(
                    (k for k in ("last_hidden_state", "hidden_states", "output") if k in current),
                    None,
                )
                if hit is not None:
                    current = current[hit]
                    continue
            raise TypeError(f"Unexpected attention hook output type: {type(current)}")

    try:
        for name in names:
            target = modules_by_name[name]

            # Bind the name through a default argument so each hook keeps its own key.
            def _record(_module, _inputs, output, _key=name):
                store_dict[_key] = _main_tensor(output)

            registered.append(target.register_forward_hook(_record))
        yield
    finally:
        for handle in registered:
            handle.remove()


def compute_attn_band_for_token(attn_tensor: torch.Tensor,
                                labels: torch.Tensor,
                                target_id: int) -> np.ndarray:
    """
    Average the attention-output vectors at every position whose label equals
    ``target_id``.

    Args:
        attn_tensor: [B, T_dec, d_model] (whatever device; hooks saved CPU by default)
        labels:      [B, T_dec] (may be on CUDA); moved to ``attn_tensor``'s device.
        target_id:   token id to select.

    Returns:
        A NumPy vector of length d_model: the mean over matching positions, or
        all zeros when no position matches.
    """
    dev = attn_tensor.device
    B, T, D = attn_tensor.shape

    # reshape (not view) so non-contiguous inputs work; detach so the final
    # .numpy() cannot fail on a tensor that still tracks gradients.
    flat = attn_tensor.detach().reshape(B * T, D)
    lab_flat = labels.detach().to(dev).reshape(B * T)

    mask = lab_flat == int(target_id)           # bool mask on same device
    if mask.any():
        v = flat[mask].mean(dim=0)              # [D]
    else:
        v = torch.zeros(D, dtype=attn_tensor.dtype, device=dev)

    return v.cpu().numpy()


def export_attn_band_maps(out_dir: str,
                          side: str,
                          layer_idx: int,
                          kind: str,
                          vec_raw: np.ndarray,
                          tag: str,
                          enable_local_z: bool = False,
                          local_z_window: int = 31):
    """
    Save an attention-output band vector and its visualizations.

    Files are written to ``<out_dir>/<side>_L<layer_idx>_<kind>_<tag>/``:
      - attn_out.npy / attn_out_band.png: the raw vector and its band plot.
      - attn_out_band_localz.png: written when ``enable_local_z`` is set and the
        ``local_zscore_1d`` helper is defined in this module.
      - attn_out_envelope.png: Gaussian-smoothed ``|vec_raw|``; additionally
        requires scipy and is skipped if scipy cannot be imported.

    The optional outputs are best effort: a missing helper or missing scipy
    skips them quietly, but errors raised *inside* those helpers propagate
    instead of being swallowed.
    """
    ensure_dir(out_dir)
    sub = os.path.join(out_dir, f"{side}_L{layer_idx}_{kind}_{tag}")
    ensure_dir(sub)
    np.save(os.path.join(sub, "attn_out.npy"), vec_raw)

    # Raw band (repeat rows just for visualization)
    render_band_row(vec_raw, os.path.join(sub, "attn_out_band.png"),
                    title=f"{side} L{layer_idx} {kind} {tag}: band (attn out)")

    if not enable_local_z:
        return

    # Only the lookup of the optional helper is guarded, not its execution.
    try:
        local_z = local_zscore_1d
    except NameError:
        return  # helper not present; skip the optional plots

    vec_lz = local_z(vec_raw, local_z_window, robust=True)
    render_band_row(vec_lz, os.path.join(sub, "attn_out_band_localz.png"),
                    title=f"{side} L{layer_idx} {kind} {tag}: band (attn out) local-z")

    # Envelope via light low-pass on |vec_raw|; scipy is an optional dependency.
    try:
        from scipy.ndimage import gaussian_filter1d
    except ImportError:
        return
    env = gaussian_filter1d(np.abs(vec_raw), sigma=max(3, local_z_window // 10))
    render_band_row(env, os.path.join(sub, "attn_out_envelope.png"),
                    title=f"{side} L{layer_idx} {kind} {tag}: envelope |attn out| (LP)")

def _slugify_token_label(s: str) -> str:
    # keep it filesystem-friendly
    return (
               s.replace(" ", "_")
               .replace("/", "_")
               .replace("\\", "_")
               .replace("\t", "_")
               .replace("\n", "_")
               .replace("<", "")
               .replace(">", "")
               .replace(":", "")
               .replace("|", "")
               .replace("*", "")
               .replace("?", "")
               .replace('"', "")
           )[:64] or "token"

# ---------------------------- Aggregators ---------------------------- #

class TokenAggregators:
    """
    Maintains per-layer/side aggregates for:
      - First occurrence of target single token
      - Mean(target single)
      - Mean(generic)
      - Mean(SetA) and Mean(SetB)
      - Baseline (mean of per-token normalized down vector ĉ; mean row-mass p)
      - Target single (mean ĉ and mean p) for baseline contrasts

    Keys are (side, layer) where side in {"enc","dec"}.

    Two vectors are tracked per token position: ``a`` (length d_ff, the "eff"
    activation) and ``c`` (length d_model, the "down" projection). All
    accumulators are defaultdicts, so reading an unseen key yields an empty
    accumulator (and inserts it).
    """

    def __init__(self, n_layers_enc: int, n_layers_dec: int, d_ff: int, d_model: int,
                 target_single_ids: List[int], set_a_ids: List[int], set_b_ids: List[int],
                 special_ids: List[int],
                 want_baseline: bool, baseline_exclude_ids: set):
        """
        Args:
            n_layers_enc / n_layers_dec: layer counts (stored only).
            d_ff / d_model: lengths of the accumulated ``a`` / ``c`` vectors.
            target_single_ids: token ids forming the "target single" pool.
            set_a_ids / set_b_ids: token ids for the SetA / SetB pools.
            special_ids: token ids skipped entirely during accumulation.
            want_baseline: stored flag; read by callers to decide whether
                positions should count toward the baseline.
            baseline_exclude_ids: token ids never counted toward the baseline.
        """
        self.n_layers_enc = n_layers_enc
        self.n_layers_dec = n_layers_dec
        self.d_ff = d_ff
        self.d_model = d_model

        self.target_single_ids = set(target_single_ids)
        self.set_a_ids = set(set_a_ids)
        self.set_b_ids = set(set_b_ids)
        self.special_ids = set(special_ids)

        self.want_baseline = want_baseline
        self.baseline_exclude_ids = set(baseline_exclude_ids)

        # Legacy sums (raw)
        self.first_single_eff = defaultdict(lambda: None)  # np[d_ff]
        self.first_single_down = defaultdict(lambda: None) # np[d_model]
        self.sum_single_eff = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.sum_single_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.cnt_single = defaultdict(int)
        self.sum_generic_eff = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.sum_generic_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.cnt_generic = defaultdict(int)
        self.sum_setA_eff = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.sum_setA_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.cnt_setA = defaultdict(int)
        self.sum_setB_eff = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.sum_setB_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.cnt_setB = defaultdict(int)

        # Baseline (direction & rowmass) and target-single (direction & rowmass)
        self.base_sum_dir_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.base_cnt_dir_down = defaultdict(int)
        self.base_sum_rowmass = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.base_cnt_rowmass = defaultdict(int)
        # Incremented once per counted position on every _update_positions call,
        # i.e. it is not keyed by (side, layer).
        self.baseline_total_tokens = 0

        self.single_sum_dir_down = defaultdict(lambda: np.zeros(d_model, dtype=np.float32))
        self.single_cnt_dir_down = defaultdict(int)
        self.single_sum_rowmass = defaultdict(lambda: np.zeros(d_ff, dtype=np.float32))
        self.single_cnt_rowmass = defaultdict(int)

    def _update_positions(self, side: str, layer: int,
                          A_btF: torch.Tensor, c_btD: torch.Tensor,
                          token_ids_bt: torch.Tensor,
                          count_for_baseline: bool):
        """
        Fold one batch of per-position vectors into the accumulators of (side, layer).

        Args:
            A_btF: [B, T, d_ff] activations.
            c_btD: [B, T, d_model] down-projected vectors.
            token_ids_bt: [B, T] token ids aligned with the positions above.
            count_for_baseline: whether non-excluded positions also feed the baseline.

        Positions whose id is in ``special_ids`` are skipped. Every other position
        goes to exactly one raw pool (target single, else SetA, else SetB, else
        generic) and, independently, to the normalized target-single and/or
        baseline accumulators.
        """
        # Inputs may live on another device or still track gradients; bring them
        # to detached CPU memory before converting to NumPy.
        A = A_btF.detach().cpu().numpy()      # [B,T,d_ff]
        C = c_btD.detach().cpu().numpy()      # [B,T,d_model]
        TIDs = token_ids_bt.detach().cpu().numpy()  # [B,T]
        B, T, _ = A.shape
        key = (side, layer)

        for b in range(B):
            for t in range(T):
                tid = int(TIDs[b, t])
                if tid in self.special_ids:
                    continue

                a = A[b, t]
                c = C[b, t]
                is_single = tid in self.target_single_ids

                # Legacy pools (raw means)
                if is_single:
                    if self.first_single_eff[key] is None:
                        self.first_single_eff[key] = a.copy()
                    if self.first_single_down[key] is None:
                        self.first_single_down[key] = c.copy()
                    self.sum_single_eff[key] += a
                    self.sum_single_down[key] += c
                    self.cnt_single[key] += 1
                elif tid in self.set_a_ids:
                    self.sum_setA_eff[key] += a
                    self.sum_setA_down[key] += c
                    self.cnt_setA[key] += 1
                elif tid in self.set_b_ids:
                    self.sum_setB_eff[key] += a
                    self.sum_setB_down[key] += c
                    self.cnt_setB[key] += 1
                else:
                    self.sum_generic_eff[key] += a
                    self.sum_generic_down[key] += c
                    self.cnt_generic[key] += 1

                # Per-token normalized views (epsilon guards all-zero vectors):
                # direction ĉ = c / ||c||_2, row-mass p = |a| / sum|a|.
                c_hat = c / (np.linalg.norm(c) + 1e-8)
                p = np.abs(a) / (np.abs(a).sum() + 1e-8)

                if is_single:
                    self.single_sum_dir_down[key] += c_hat
                    self.single_cnt_dir_down[key] += 1
                    self.single_sum_rowmass[key] += p
                    self.single_cnt_rowmass[key] += 1

                if count_for_baseline and (tid not in self.baseline_exclude_ids):
                    self.base_sum_dir_down[key] += c_hat
                    self.base_cnt_dir_down[key] += 1
                    self.base_sum_rowmass[key] += p
                    self.base_cnt_rowmass[key] += 1
                    self.baseline_total_tokens += 1

    # ---------- Exporters ---------- #

    def export_views(self, out_dir: str, side: str, layer: int, Wo_dF: torch.Tensor,
                     baseline_dir_down: Optional[np.ndarray]=None,
                     baseline_rowmass: Optional[np.ndarray]=None,
                     enable_local_z: bool=True,
                     local_z_window: int=35,
                     pockets_axis: str="columns", token_single=""):
        """
        Write arrays (.npy) and plots (.png) for the aggregates of (side, layer).

        One sub-directory ("pack") is created under ``out_dir`` for each view
        that has data: single_token_first, mean_single_token,
        mean_single_minus_generic, mean_setA_minus_setB and, when both baseline
        vectors are given and target tokens were seen, single_vs_baseline_dir.

        Args:
            Wo_dF: output-projection weight, [d_model, d_ff].
            baseline_dir_down: baseline mean direction for this layer; assumed
                to have length d_model (it is subtracted from the target mean).
            baseline_rowmass: baseline mean row-mass; assumed length d_ff.
            enable_local_z: also write the ``local_zscore_1d``-processed variants.
            local_z_window: window passed to ``local_zscore_1d``.
            pockets_axis: accepted for interface compatibility; currently unused.
            token_single: label whose slug is appended to per-pack local-z filenames.
        """
        ensure_dir(out_dir)
        key = (side, layer)
        # W_o^T as float32 NumPy, computed once and shared by every pack
        # (previously it only existed inside save_pack, so the baseline branch
        # below failed with NameError).
        WoT_np = Wo_dF.detach().float().cpu().numpy().T  # [d_ff, d_model]

        def save_pack(tag: str, A_vec: Optional[np.ndarray], c_vec: Optional[np.ndarray]):
            if A_vec is None or c_vec is None:
                return
            pack_dir = os.path.join(out_dir, tag)
            ensure_dir(pack_dir)

            # Save originals
            np.save(os.path.join(pack_dir, "eff_A.npy"), A_vec)
            np.save(os.path.join(pack_dir, "down_c.npy"), c_vec)
            S = compute_pockets(A_vec, c_vec, WoT_np)
            np.save(os.path.join(pack_dir, "pockets_S.npy"), S)

            # Render originals
            render_heatmap(S, os.path.join(pack_dir, "pockets_S.png"),
                           f"{side} L{layer} {tag}: pockets S (d_ff x d_model)")
            render_band_row(c_vec, os.path.join(pack_dir, "down_c.png"),
                            f"{side} L{layer} {tag}: band (down c)")
            render_band_row(A_vec, os.path.join(pack_dir, "eff_A.png"),
                            f"{side} L{layer} {tag}: eff A")

            # Local-z variants (token suffix in filenames; pockets as outer product)
            if enable_local_z:
                tok_slug = _slugify_token_label(token_single)

                c_lz = local_zscore_1d(c_vec, local_z_window)  # [d_model]
                A_lz = local_zscore_1d(A_vec, local_z_window)  # [d_ff]

                np.save(os.path.join(pack_dir, f"down_c_localz__{tok_slug}.npy"), c_lz)
                np.save(os.path.join(pack_dir, f"eff_A_localz__{tok_slug}.npy"), A_lz)

                render_band_row(
                    c_lz,
                    os.path.join(pack_dir, f"down_c_localz__{tok_slug}.png"),
                    f"{side} L{layer} {tag}: band (down c) local-z [{tok_slug}]",
                    repeat=6
                )
                render_band_row(
                    A_lz,
                    os.path.join(pack_dir, f"eff_A_localz__{tok_slug}.png"),
                    f"{side} L{layer} {tag}: eff A local-z [{tok_slug}]",
                    repeat=6
                )

                # 2-D pockets as the outer product of the two 1-D maps:
                # [d_ff, d_model] = [len(A_lz), len(c_lz)]
                S_lz = (A_lz[:, None] * c_lz[None, :]).astype(np.float32)

                np.save(os.path.join(pack_dir, f"pockets_S_localz__{tok_slug}.npy"), S_lz)
                render_heatmap(
                    S_lz,
                    os.path.join(pack_dir, f"pockets_S_localz__{tok_slug}.png"),
                    f"{side} L{layer} {tag}: pockets S (A_lz ⊗ c_lz) [{tok_slug}]"
                )

        # Legacy packs
        if self.first_single_eff[key] is not None and self.first_single_down[key] is not None:
            save_pack("single_token_first", self.first_single_eff[key], self.first_single_down[key])

        n_single = self.cnt_single[key]
        n_generic = self.cnt_generic[key]
        if n_single > 0:
            A_s = self.sum_single_eff[key] / n_single
            c_s = self.sum_single_down[key] / n_single
            save_pack("mean_single_token", A_s, c_s)
        if n_generic > 0:
            A_g = self.sum_generic_eff[key] / n_generic
            c_g = self.sum_generic_down[key] / n_generic
            if n_single > 0:
                save_pack("mean_single_minus_generic", A_s - A_g, c_s - c_g)
            # NOTE: the SetA-vs-SetB pack is only written when generic tokens were
            # seen as well; this nesting is kept unchanged from earlier behaviour.
            n_a = self.cnt_setA[key]
            n_b = self.cnt_setB[key]
            if n_a > 0 and n_b > 0:
                A_a = self.sum_setA_eff[key] / n_a
                c_a = self.sum_setA_down[key] / n_a
                A_b = self.sum_setB_eff[key] / n_b
                c_b = self.sum_setB_down[key] / n_b
                save_pack("mean_setA_minus_setB", A_a - A_b, c_a - c_b)

        # Baseline contrasts (direction + row-mass)
        if baseline_dir_down is not None and baseline_rowmass is not None and self.single_cnt_dir_down[key] > 0:
            # Target means (direction/rowmass)
            mu_dir = self.single_sum_dir_down[key] / max(1, self.single_cnt_dir_down[key])       # [d_model]
            mu_p   = self.single_sum_rowmass[key] / max(1, self.single_cnt_rowmass[key])         # [d_ff]
            # Contrasts vs baseline
            delta_dir = mu_dir - baseline_dir_down                                               # [d_model]
            delta_p   = mu_p   - baseline_rowmass                                                # [d_ff]
            pack_dir = os.path.join(out_dir, "single_vs_baseline_dir")
            ensure_dir(pack_dir)
            np.save(os.path.join(pack_dir, "delta_dir_down.npy"), delta_dir)
            np.save(os.path.join(pack_dir, "delta_rowmass_eff.npy"), delta_p)
            # Directional pockets: rows weighted by delta_p (cleaner band visualization)
            S_dir = compute_pockets(delta_p, delta_dir, WoT_np)
            np.save(os.path.join(pack_dir, "pockets_S_dir.npy"), S_dir)
            render_heatmap(S_dir, os.path.join(pack_dir, "pockets_S_dir.png"),
                           f"{side} L{layer} single vs baseline: directional pockets")
            render_band_row(delta_dir, os.path.join(pack_dir, "delta_dir_down.png"),
                            f"{side} L{layer} single vs baseline: band direction Δĉ", repeat=6)

            if enable_local_z:
                # The contrast vectors play the roles of "down c" / "eff A" here.
                # (This branch previously referenced names local to save_pack and
                # could never run.)
                tag = "single vs baseline"
                c_lz = local_zscore_1d(delta_dir, local_z_window)  # [d_model]
                A_lz = local_zscore_1d(delta_p, local_z_window)    # [d_ff]
                np.save(os.path.join(pack_dir, "down_c_localz.npy"), c_lz)
                np.save(os.path.join(pack_dir, "eff_A_localz.npy"), A_lz)

                render_band_row(c_lz, os.path.join(pack_dir, "down_c_localz.png"),
                                f"{side} L{layer} {tag}: band (down c) local-z", repeat=6)
                render_band_row(A_lz, os.path.join(pack_dir, "eff_A_localz.png"),
                                f"{side} L{layer} {tag}: eff A local-z", repeat=6)

                # Pockets as the outer product of the two 1-D maps:
                # [d_ff, d_model] = [len(A_lz), len(c_lz)]
                S_lz = (A_lz[:, None] * c_lz[None, :]).astype(np.float32)

                np.save(os.path.join(pack_dir, "pockets_S_localz.npy"), S_lz)
                render_heatmap(S_lz, os.path.join(pack_dir, "pockets_S_localz.png"),
                               f"{side} L{layer} {tag}: pockets S (A_lz ⊗ c_lz)")

    # ---------- Baseline I/O ---------- #

    def baseline_means(self, side: str, n_layers: int) -> Tuple[np.ndarray, np.ndarray]:
        """
        Returns (mean_dir_down[L,d_model], mean_rowmass[L,d_ff]) for the given side.
        Missing layers default to zeros (if no tokens seen).
        """
        mean_dir = []
        mean_row = []
        for L in range(n_layers):
            k = (side, L)
            n_dir = self.base_cnt_dir_down[k]
            n_row = self.base_cnt_rowmass[k]
            if n_dir > 0:
                mean_dir.append(self.base_sum_dir_down[k] / n_dir)
            else:
                mean_dir.append(np.zeros(self.d_model, dtype=np.float32))
            if n_row > 0:
                mean_row.append(self.base_sum_rowmass[k] / n_row)
            else:
                mean_row.append(np.zeros(self.d_ff, dtype=np.float32))
        return np.stack(mean_dir, 0), np.stack(mean_row, 0)

    def save_baseline(self, path: str, nL_enc: int, nL_dec: int):
        """
        Save the encoder and decoder baseline means to ``path`` with ``np.savez``
        under the keys enc_dir_down_mean, enc_rowmass_mean, dec_dir_down_mean
        and dec_rowmass_mean.
        """
        enc_dir, enc_row = self.baseline_means("enc", nL_enc)
        dec_dir, dec_row = self.baseline_means("dec", nL_dec)
        np.savez(path,
                 enc_dir_down_mean=enc_dir, enc_rowmass_mean=enc_row,
                 dec_dir_down_mean=dec_dir, dec_rowmass_mean=dec_row)

    @staticmethod
    def load_baseline(path: str) -> Dict[str, np.ndarray]:
        """Load an archive written by ``save_baseline`` into a plain ``{name: array}`` dict."""
        # The archive object holds an open file handle; read every array and
        # close it before returning.
        with np.load(path) as arrs:
            return {k: arrs[k] for k in arrs.files}


# ---------------------------- Hooking ---------------------------- #

class FFNCapture:
    """
    Hooks all FFNs (encoder layer[1], decoder layer[2]) and streams:
      A = act(wi_0(x))*wi_1(x)   (or act(wi(x)))
      c = A @ W_o^T
    into a TokenAggregators instance.

    Both quantities are recomputed inside a forward hook from the FFN's input.
    Call ``set_batch_token_ids`` before each forward pass so positions can be
    attributed to token ids; while no ids are set for a side, its hooks do
    nothing. Call ``close`` to remove the hooks.
    """

    def __init__(self, model, tokenizer, aggregators: "TokenAggregators",
                 capture_encoder: bool, capture_decoder: bool):
        """
        Args:
            model: model exposing ``config.feed_forward_proj`` and
                ``encoder.block`` / ``decoder.block`` lists of blocks.
            tokenizer: stored as ``self.tok``; not used by this class itself.
            aggregators: receiver of the captured per-position vectors.
            capture_encoder / capture_decoder: which side(s) to hook.
        """
        self.model = model
        self.tok = tokenizer
        self.agg = aggregators
        self.capture_encoder = capture_encoder
        self.capture_decoder = capture_decoder
        self.act = act_fn_for_proj(model.config.feed_forward_proj)
        self.handles = []
        self._enc_ids_bt = None
        self._dec_ids_bt = None

        if capture_encoder:
            for li, block in enumerate(model.encoder.block):
                self._register_ffn_hook(block.layer[1].DenseReluDense, "enc", li)
        if capture_decoder:
            for li, block in enumerate(model.decoder.block):
                self._register_ffn_hook(block.layer[2].DenseReluDense, "dec", li)

    def _register_ffn_hook(self, ffn, side: str, layer: int):
        """Attach a forward hook to ``ffn`` that feeds (A, c) for (side, layer) to the aggregator."""
        has_gated = hasattr(ffn, "wi_0") and hasattr(ffn, "wi_1")

        def hook_fn(module, inputs, output):
            # Check for token ids first: without them nothing can be recorded,
            # so skip the (comparatively expensive) recomputation entirely.
            ids = self._enc_ids_bt if side == "enc" else self._dec_ids_bt
            if ids is None:
                return
            x = inputs[0]
            with torch.no_grad():
                if has_gated:
                    A = self.act(ffn.wi_0(x)) * ffn.wi_1(x)
                else:
                    A = self.act(ffn.wi(x))
                Wo = ffn.wo.weight                        # [d_model, d_ff]
                c = torch.einsum("btf,mf->btm", A, Wo)    # [B,T,d_model]
                self.agg._update_positions(side, layer, to_cpu_f32(A), to_cpu_f32(c), ids,
                                           count_for_baseline=self.agg.want_baseline)

        self.handles.append(ffn.register_forward_hook(hook_fn))

    def set_batch_token_ids(self, enc_ids_bt: Optional[torch.Tensor], dec_ids_bt: Optional[torch.Tensor]):
        """Set the [B,T] token ids the hooks pair with encoder / decoder positions (None disables a side)."""
        self._enc_ids_bt = enc_ids_bt
        self._dec_ids_bt = dec_ids_bt

    def close(self):
        """Remove every registered hook."""
        for h in self.handles:
            h.remove()
        self.handles = []

def _moving_stats_1d(x: np.ndarray, window: int) -> tuple[np.ndarray, np.ndarray]:
    """Reflect-padded moving mean/std for 1-D arrays."""
    w = int(max(1, window))
    pad = w // 2
    k = np.ones(w, dtype=np.float32)
    xp = np.pad(x.astype(np.float32), (pad, pad), mode="reflect")
    mx = np.convolve(xp, k, mode="valid")
    mx2 = np.convolve(xp * xp, k, mode="valid")
    mean = mx / w
    var = np.maximum(mx2 / w - mean * mean, 0.0)
    std = np.sqrt(var + 1e-8)
    return mean, std

def local_zscore_1d(x: np.ndarray, window: int, robust: bool=True) -> np.ndarray:
    """
    Smooth ``x`` with a centered box filter of width ``window``.

    Despite the name, no z-scoring is performed: the return value is the
    uniform moving average. ``robust`` is accepted only to keep the signature
    stable and has no effect.

    The width is raised to at least 3 and made odd; edges are handled by
    reflect padding, so the output has the same length as ``x`` (float32).
    """
    width = int(max(3, window))
    if width % 2 == 0:
        width += 1  # centered window needs an odd width
    half = width // 2

    signal = x.astype(np.float32, copy=False)
    padded = np.pad(signal, (half, half), mode="reflect")

    # Uniform kernel: every sample in the window contributes 1/width.
    box = np.ones(width, dtype=np.float32) / width
    return np.convolve(padded, box, mode="valid")


def _rolling_median_mad_2d(X: np.ndarray, w: int, axis: int, pad_mode: str = "reflect"):
    """
    Robust rolling stats along one axis of a 2D array.
    axis=1 => columns; axis=0 => rows.
    Returns (median, scale) where scale ≈ robust std via MAD.
    """
    assert X.ndim == 2
    w = int(w)
    if w < 3:
        w = 3
    if w % 2 == 0:
        w += 1  # enforce odd
    pad = w // 2

    if axis == 1:  # columns
        xp = np.pad(X.astype(np.float32), ((0, 0), (pad, pad)), mode=pad_mode)
        # windows shape: [rows, cols, w]
        shape = (X.shape[0], X.shape[1], w)
        strides = (xp.strides[0], xp.strides[1], xp.strides[1])
        windows = np.lib.stride_tricks.as_strided(xp, shape=shape, strides=strides, writeable=False)
    elif axis == 0:  # rows
        xp = np.pad(X.astype(np.float32), ((pad, pad), (0, 0)), mode=pad_mode)
        shape = (X.shape[0], w, X.shape[1])       # [rows, w, cols]
        strides = (xp.strides[0], xp.strides[0], xp.strides[1])
        windows = np.lib.stride_tricks.as_strided(xp, shape=shape, strides=strides, writeable=False)
        windows = np.swapaxes(windows, 1, 2)      # -> [rows, cols, w]
    else:
        raise ValueError("axis must be 0 (rows) or 1 (columns)")

    med = np.median(windows, axis=2)
    mad = np.median(np.abs(windows - med[..., None]), axis=2)
    scale = np.maximum(1e-6, 1.4826 * mad)  # MAD -> robust std
    return med, scale

def local_zscore_2d(mat: np.ndarray, window: int, axis: str = "columns", robust: bool = True) -> np.ndarray:
    """
    Local z-score of a 2-D array along a chosen direction.

    Args:
        mat: 2-D input (processed as float32).
        window: sliding-window length.
        axis: "columns" (window slides within each row), "rows" (within each
            column) or "both" (columns first, then rows on that result).
        robust: True uses rolling median / MAD scale via
            ``_rolling_median_mad_2d``; False uses reflect-padded rolling
            mean / std, where spikes dominate the statistics.

    Raises:
        ValueError: for an unknown ``axis`` string.
    """
    X = mat.astype(np.float32, copy=False)

    def _z_along_axis(X, ax):
        if robust:
            center, spread = _rolling_median_mad_2d(X, window, axis=ax)
            return (X - center) / spread

        # mean/std version
        w = window if window % 2 == 1 else window + 1
        half = w // 2
        if ax == 0:
            # Rows are handled by transposing into the column case.
            return _z_along_axis(X.T, 1).T
        if ax != 1:
            raise ValueError

        padded = np.pad(X, ((0, 0), (half, half)), mode="reflect")
        box = np.ones(w, dtype=np.float32)

        def _window_sum(row):
            return np.convolve(row, box, mode="valid")

        def _window_sum_sq(row):
            return np.convolve(row**2, box, mode="valid")

        m = np.apply_along_axis(_window_sum, 1, padded) / w
        m2 = np.apply_along_axis(_window_sum_sq, 1, padded) / w
        s = np.sqrt(np.maximum(m2 - m*m, 0.0) + 1e-8)
        return (X - m) / s

    if axis == "columns":
        return _z_along_axis(X, ax=1)
    if axis == "rows":
        return _z_along_axis(X, ax=0)
    if axis == "both":
        return _z_along_axis(_z_along_axis(X, ax=1), ax=0)
    raise ValueError(f"axis must be 'columns','rows','both', got {axis}")

def read_text_as_pairs(path: str, max_examples: int):
    """
    Read a UTF-8 text file into ``(line, line)`` pairs (source == target).

    Lines are stripped of surrounding whitespace and blank lines are skipped.
    At most ``max_examples`` pairs are returned; a limit of zero or less
    yields an empty list.
    """
    pairs = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            # Test the cap before appending so a non-positive limit returns nothing.
            if len(pairs) >= max_examples:
                break
            line = line.strip()
            if not line:
                continue
            pairs.append((line, line))  # src = tgt = the same sentence
    return pairs

def compute_pockets(A_vec: np.ndarray,
                    c_vec: np.ndarray,
                    WoT: np.ndarray | None,
                    mode: str = "wo_times_eff") -> np.ndarray:
    """
    Return pockets map of shape [d_ff, d_model].

    - outer_eff_down: S = A[:,None] * c[None,:]   (eff ⊗ down)
    - wo_times_eff:  S = Wo.T * A[:,None]         (original: rows of Wo.T scaled by eff)
    """
    if mode == "outer_eff_down":
        # A: [d_ff], c: [d_model]  ->  S: [d_ff, d_model]
        return (A_vec[:, None] * c_vec[None, :]).astype(np.float32)
    elif mode == "wo_times_eff":
        if WoT is None:
            raise ValueError("compute_pockets(mode='wo_times_eff') requires WoT (W_o.T)")
        return (WoT * A_vec[:, None]).astype(np.float32)
    else:
        raise ValueError(f"Unknown pockets mode: {mode}")


# ---------------------------- Main ---------------------------- #

def main():
    """Command-line entry point for the FFN / attention map export.

    Steps, in order:
      1. Parse CLI arguments and dump them to ``<out_dir>/meta.json``.
      2. Load the tokenizer and MT5 model from ``--ckpt`` and, when the file
         exists, overlay the weights found at ``--state_dict_path``.
      3. Resolve the target token ids and build ``TokenAggregators`` and
         ``FFNCapture``.
      4. Run the (src, tgt) pairs through the model batch by batch.
      5. Optionally save and/or load the baseline ``.npz``.
      6. Re-run the last batch with attention-output hooks and export band
         maps for the probed decoder layers.
      7. Export per-layer FFN views for the captured encoder/decoder stacks.

    Raises:
        SystemExit: via ``argparse`` when a baseline mode is requested
            without ``--baseline_file``.
        FileNotFoundError: when ``--baseline_use`` points at a missing file.
        ValueError: when ``--token_single`` yields no token id and the
            attention visualization is about to run.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", type=str, default="google/mt5-base",
                    help="Base HF checkpoint to load before applying --state_dict_path")
    ap.add_argument("--state_dict_path", type=str,
                    #default="mt5_base_pretuned.pt",
                    #default="mt5_base_forgive_and_forget_whole_stream6.pt",
                    default="mt5_base_standard_FaF_11x_noise.pt",
                    #default="mt5_base_twolet_FaF_12x_noise0_batch_5000_o.pt",
                    help="Path to .pt/.bin state_dict to load on top of --ckpt")

    # Token sets (modify for different maps)
    ap.add_argument("--token_single", type=str, default="the", help="single token string (ideally 1 id)")
    ap.add_argument("--out_dir", type=str, default="Interpretation_Records/map_comparison/the_full")

    # z score for visibility.
    # BooleanOptionalAction keeps `--enable_local_z` working and adds
    # `--no-enable_local_z`; with store_true + default=True the flag could
    # never be switched off.
    ap.add_argument("--enable_local_z", default=True, action=argparse.BooleanOptionalAction,
                    help="Also save local z-scored variants of maps (bands, rows, pockets)")
    ap.add_argument("--local_z_window", type=int, default=19,
                    help="Odd window size for local z-score (columns/rows); try 7–13")
    ap.add_argument("--local_z_axis_for_pockets", type=str, default="columns",
                    choices=["columns", "rows", "both"],
                    help="Axis along which to compute local z for pockets S")

    # Data inputs (choose one)
    ap.add_argument("--pairs_file", type=str, default=None,
                    help="TSV: src<TAB>tgt per line (overrides --text_file if given)")
    ap.add_argument("--text_file", type=str, default="Interpretation_Records/interpret_sub_texts.txt",
                    help="Plaintext: one example per line; span corruption will generate (src,tgt) pairs")

    ap.add_argument("--max_examples", type=int, default=10000)
    ap.add_argument("--batch_size", type=int, default=16)
    ap.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")

    ap.add_argument("--set_a", type=str, default="0,1,2,3,4,5,6,7,8,9")
    ap.add_argument("--set_b", type=str, default="0,1,2,3,4,5,6,7,8,9")
    ap.add_argument("--capture_encoder", action="store_true")
    # Same reasoning as --enable_local_z: allow `--no-capture_decoder`.
    ap.add_argument("--capture_decoder", default=True, action=argparse.BooleanOptionalAction)

    # Baseline modes
    ap.add_argument("--baseline_build", action="store_true", help="stream and save baseline means (ĉ, p)")
    ap.add_argument("--baseline_use", action="store_true", help="load baseline and emit single-vs-baseline outputs")
    ap.add_argument("--baseline_file", type=str, default=None, help="path to .npz for saving or loading baseline")
    ap.add_argument("--baseline_tokens_target", type=int, default=100000, help="approx tokens to include in baseline")
    ap.add_argument("--baseline_exclude_targets", action="store_true",
                    help="exclude target sets (single, setA, setB) from baseline")
    # Span corruption knobs (optional; defaults match T5-ish).
    # NOTE: within main() these two values are only recorded in meta.json;
    # nothing below reads them.
    ap.add_argument("--noise_density", type=float, default=0.15, help="Fraction of tokens to corrupt")
    ap.add_argument("--mean_span_length", type=float, default=3.0, help="Poisson mean for span lengths")

    ap.add_argument(
        "--pockets_mode",
        type=str,
        default="outer_eff_down",  # was "wo_times_eff" before
        choices=["outer_eff_down", "wo_times_eff"],
        help="How to form pockets: outer product of eff×down, or Wo.T scaled by eff"
    )

    args = ap.parse_args()

    # Fail fast: both baseline modes need a file path, and discovering that
    # only after the whole data pass wastes the run.
    if (args.baseline_build or args.baseline_use) and args.baseline_file is None:
        ap.error("--baseline_build / --baseline_use require --baseline_file")

    ensure_dir(args.out_dir)
    with open(os.path.join(args.out_dir, "meta.json"), "w", encoding="utf-8") as f:
        json.dump({"args": vars(args)}, f, indent=2)

    tok = AutoTokenizer.from_pretrained(args.ckpt)
    model = MT5ForConditionalGeneration.from_pretrained(args.ckpt)
    # Load your fine-tuned weights on top (strict=False is safer across versions)
    if args.state_dict_path is not None and os.path.exists(args.state_dict_path):
        # NOTE: torch.load unpickles the file; only point this at checkpoints
        # from a trusted source.
        sd = torch.load(args.state_dict_path, map_location="cpu")
        if isinstance(sd, dict) and "state_dict" in sd:
            sd = sd["state_dict"]

        # 1) Strip one leading wrapper prefix (model. / module. / base_model.)
        def strip_prefix(k: str) -> str:
            return re.sub(r'^(?:model\.|module\.|base_model\.)', '', k, count=1)

        sd = {strip_prefix(k): v for k, v in sd.items()}

        # 2) (Optional) resize embeddings if tokenizer size differs
        # model.resize_token_embeddings(len(tok))

        # 3) Keep only keys the HF model actually has (drop custom extras)
        model_keys = set(model.state_dict().keys())
        sd = {k: v for k, v in sd.items() if k in model_keys}

        # 4) Load into the model; missing keys keep their --ckpt values.
        missing, unexpected = model.load_state_dict(sd, strict=False)
        print(f"[state_dict] loaded: kept={len(sd)}  missing={len(missing)}  unexpected={len(unexpected)}")

    else:
        print("[state_dict] no state_dict_path provided or file not found; using base weights only.")

    model.eval().to(args.device)

    d_model = model.config.d_model
    probe_ffn = model.decoder.block[0].layer[2].DenseReluDense
    # Gated FFNs expose wi_0 (and wi_1); plain ones expose a single wi.
    d_ff = (probe_ffn.wi_0.out_features if hasattr(probe_ffn, "wi_0") else probe_ffn.wi.out_features)
    nL_enc = len(model.encoder.block)
    nL_dec = len(model.decoder.block)

    def strings_to_ids_list(s: str) -> List[int]:
        """Map a comma-separated list of strings to one token id each."""
        ids = []
        for it in [x.strip() for x in s.split(",") if x.strip() != ""]:
            id_list = tokenize_single_piece(tok, it)
            if len(id_list) == 0:
                print(f"[warn] '{it}' tokenized to nothing; skipping.")
                continue
            if len(id_list) > 1:
                print(f"[warn] '{it}' tokenized to {id_list}; using second id {id_list[1]}.")
                ids.append(id_list[1])
            else:
                ids.append(id_list[0])
        return ids

    single_ids = strings_to_ids_list(args.token_single)
    setA_ids = strings_to_ids_list(args.set_a)
    setB_ids = strings_to_ids_list(args.set_b)
    specials = {
        i for i in (tok.pad_token_id, tok.eos_token_id, tok.unk_token_id)
        if i is not None
    }

    baseline_exclude = set()
    if args.baseline_exclude_targets:
        baseline_exclude |= set(single_ids) | set(setA_ids) | set(setB_ids)

    want_baseline = args.baseline_build or args.baseline_use
    aggregators = TokenAggregators(
        n_layers_enc=nL_enc, n_layers_dec=nL_dec,
        d_ff=d_ff, d_model=d_model,
        target_single_ids=single_ids, set_a_ids=setA_ids, set_b_ids=setB_ids,
        special_ids=list(specials),
        want_baseline=want_baseline, baseline_exclude_ids=baseline_exclude
    )
    capturer = FFNCapture(model, tok, aggregators,
                          capture_encoder=args.capture_encoder,
                          capture_decoder=args.capture_decoder)

    # Decide where to get (src, tgt) pairs
    if args.pairs_file:
        pairs = read_pairs(args.pairs_file)[:args.max_examples]
        print(f"[data] using TSV pairs: {len(pairs)}")
    elif args.text_file and os.path.exists(args.text_file):
        pairs = read_text_as_pairs(args.text_file, args.max_examples)
        print(f"[data] using text file (src=tgt): {len(pairs)} from {args.text_file}")
    else:
        pairs = read_pairs(None)[:args.max_examples]
        print(f"[data] using built-in sample: {len(pairs)}")

    print(f"Scanning {len(pairs)} pairs ...")

    # -------------------------
    # Pass through data (build FFN aggregates / optional baseline)
    # -------------------------
    last_enc = None
    last_labels = None
    last_dec_in = None

    for i in range(0, len(pairs), args.batch_size):
        batch = pairs[i:i + args.batch_size]
        src_texts = [s for s, t in batch]
        tgt_texts = [t for s, t in batch]

        enc = tok(src_texts, return_tensors="pt", padding=True, truncation=True).to(args.device)
        # NOTE(review): as_target_tokenizer() is deprecated in recent
        # transformers releases in favour of tok(text_target=...) — confirm
        # against the installed version before switching.
        with tok.as_target_tokenizer():
            tgt = tok(tgt_texts, return_tensors="pt", padding=True, truncation=True).to(args.device)

        labels = tgt["input_ids"].clone()
        labels[labels == tok.pad_token_id] = -100  # standard HF loss ignore index
        decoder_input_ids = model.prepare_decoder_input_ids_from_labels(labels=labels)

        # Stash the last batch so we can reuse it for attention visualization
        last_enc = enc
        last_labels = labels
        last_dec_in = decoder_input_ids

        capturer.set_batch_token_ids(
            enc_ids_bt=enc["input_ids"].detach().cpu() if args.capture_encoder else None,
            dec_ids_bt=decoder_input_ids.detach().cpu() if args.capture_decoder else None,
        )

        with torch.no_grad():
            _ = model(
                input_ids=enc["input_ids"],
                attention_mask=enc["attention_mask"],
                decoder_input_ids=decoder_input_ids,
                use_cache=False,
            )

        # Early stop for baseline build once target token budget reached (approximate)
        if args.baseline_build and aggregators.baseline_total_tokens >= args.baseline_tokens_target:
            print(f"[baseline] token budget reached: {aggregators.baseline_total_tokens}")
            break

    # -------------------------
    # Save or load baseline for FFN views
    # -------------------------
    baseline = None
    if args.baseline_build:
        # --baseline_file presence was validated right after parse_args().
        aggregators.save_baseline(args.baseline_file, nL_enc, nL_dec)
        print(f"[baseline] saved to {args.baseline_file}")

    if args.baseline_use:
        # Checked here (not up front) because --baseline_build above may have
        # just created the file.
        if not os.path.exists(args.baseline_file):
            raise FileNotFoundError("Provide existing --baseline_file to load")
        baseline = TokenAggregators.load_baseline(args.baseline_file)
        print(f"[baseline] loaded from {args.baseline_file}")

    # -------------------------
    # Attention-output band maps (decoder self/cross attn)
    # -------------------------
    # Choose layers/kinds to probe (late decoder is most interesting).
    # Layers beyond the decoder depth are dropped so that shallower
    # checkpoints do not request modules that do not exist.
    probe_layers = [L for L in (9, 10, 11) if L < nL_dec]
    probe_kinds = ["self", "cross"]  # set to ["self"] if you only want self-attn

    names = [_attn_module_path("dec", L, kind) for L in probe_layers for kind in probe_kinds]

    if last_enc is None or last_labels is None or last_dec_in is None:
        print("[warn] No batches processed; skipping attention visualization.")
    else:
        # Reuse the ids resolved above; an explicit raise survives `python -O`.
        if not single_ids:
            raise ValueError("token_single yielded no ids")
        target_tok = single_ids[0]

        # NOTE(review): the FFN capturer is only closed at the end of main();
        # if its hooks are still active here, this extra forward pass may feed
        # the last batch into the aggregators a second time — confirm against
        # FFNCapture.
        captures = {}
        with torch.no_grad():
            with hook_attn_outputs(model, names, captures):
                _ = model(
                    input_ids=last_enc["input_ids"],
                    attention_mask=last_enc["attention_mask"],
                    decoder_input_ids=last_dec_in,
                    labels=last_labels,  # aligns positions with decoder time steps
                    use_cache=False,
                    output_attentions=False,
                    output_hidden_states=False,
                )

        for L in probe_layers:
            for kind in probe_kinds:
                key = _attn_module_path("dec", L, kind)
                if key not in captures:
                    continue
                vec = compute_attn_band_for_token(captures[key], last_labels, target_tok)
                export_attn_band_maps(
                    out_dir=args.out_dir,
                    side="dec",
                    layer_idx=L,
                    kind=kind,
                    vec_raw=vec,
                    tag="single_token_first",  # mirror FFN naming
                    enable_local_z=args.enable_local_z,
                    local_z_window=args.local_z_window,
                )

    # -------------------------
    # Export FFN per-layer packs
    # -------------------------
    def baseline_entry(key: str):
        """Return baseline[key] when a baseline is loaded and has it, else None."""
        return baseline[key] if (baseline and key in baseline) else None

    print("Exporting FFN maps and figures ...")
    if args.capture_encoder:
        enc_dir = baseline_entry("enc_dir_down_mean")
        enc_row = baseline_entry("enc_rowmass_mean")
        for li, block in enumerate(model.encoder.block):
            ffn = block.layer[1].DenseReluDense
            out_dir = os.path.join(args.out_dir, f"enc_L{li:02d}")
            bd = enc_dir[li] if enc_dir is not None else None
            br = enc_row[li] if enc_row is not None else None
            aggregators.export_views(out_dir, "enc", li, ffn.wo.weight, bd, br)

    if args.capture_decoder:
        dec_dir = baseline_entry("dec_dir_down_mean")
        dec_row = baseline_entry("dec_rowmass_mean")
        for li, block in enumerate(model.decoder.block):
            ffn = block.layer[2].DenseReluDense
            out_dir = os.path.join(args.out_dir, f"dec_L{li:02d}")
            bd = dec_dir[li] if dec_dir is not None else None
            br = dec_row[li] if dec_row is not None else None
            aggregators.export_views(out_dir, "dec", li, ffn.wo.weight, bd, br, token_single=args.token_single)

    # (No attention baseline subtraction here; add later if you build attn_baseline)
    capturer.close()
    print(f"Done. Outputs in {args.out_dir}")


# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
