import os
from transformers.activations import ACT2FN

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import numpy as np
import random
import torch.nn as nn
import torch.optim as optim
from transformers import MT5ForConditionalGeneration, AutoTokenizer, MT5Tokenizer, AutoModelForSeq2SeqLM, AutoModel
from torch.optim import AdamW
import sacrebleu
from torch.utils.data import DataLoader
from datasets import load_dataset, load_from_disk
from functools import partial
from transformers import DataCollatorWithPadding
from torch.optim.lr_scheduler import StepLR
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
from torch.nn import functional as F
from functools import lru_cache
from dataclasses import dataclass, field
from itertools import groupby  # ← add this
import matplotlib.pyplot as plt
from datasets import load_dataset, interleave_datasets
from torch.optim.lr_scheduler import LambdaLR

import torch, gc, random
import math
import re
import torch.fft as fft
from itertools import groupby

# ── 1 · global hyper-parameters ────────────────────────────────────────────
@dataclass
class LeakParams:
    """Bundle of tunable settings for the weight-leak / recall experiments.

    The per-field notes below restate the original author's annotations;
    how each value is consumed is decided by code outside this class.
    """

    # --- leak strengths -------------------------------------------------
    alpha_val: float = 0.030    # strength for the value projection
    alpha_gate: float = 0.030   # strength for the gate projection
    alpha_down: float = 0.022
    a_val_dorm: float = 0.007   # strength of the leak used to maintain order
    a_gate_dorm: float = 0.007

    # --- kernel shape ---------------------------------------------------
    sigma0: float = 2.0
    lam: float = 4.0            # close-vs-far preference: lower prefers close

    # --- scheduling -----------------------------------------------------
    tau: int = 50               # batches between leaks
    tau_mod: int = 0
    start_tau: int = 241

    # Order of entries: [entry, up_gate, up_val, down]
    recall_alpha: list = field(default_factory=lambda: [0.0, 0.0, 0.0, 0.0])
    use_fft: bool = True
    # Author's note: was 11.0 at start; recorded BLEU per value:
    # 2.0 -> 78.1, 2.5 -> 76.5, 3.0 -> 70.5, 5.75 -> 56.95, 9.0 -> 14.0
    max_variance: float = 10.0
    num_batches: int = 0
    # 24 flags, all initially False; a new list is built for every instance.
    safe_to_forget: list = field(default_factory=lambda: [False] * 24)

    # --- feature toggles ------------------------------------------------
    apply_recall: bool = True
    apply_blur: bool = False
    apply_noise: bool = False
    noise_strength: float = 0.0


# Module-wide settings instance built with the defaults; read by helpers such as `_alpha_eff`.
LEAK = LeakParams()


# Blend enabled?   No          Yes        90%
# Base Model:     78.5 BLEU    25.0 BLEU
# Whole Stream 6: 78.5 bleu    78.7 BLEU  76.1 BLEU

"""
VARIANCE 1.7999999999999998
Batch 1000 of 50000
Loss: 0.279685721218586
Gold loss: 0.4591213328242302
BLEU score: 78.52
Perplexity: 1.0001150089654989
Overall: 78.66
"""



# finds clean first two letter comparison for forgiveness
def token_to_two_letter_code(tok: str) -> int:
    """Map a token to an integer built from its first two ASCII letters.

    The word-boundary markers "Ġ" and "▁" are turned into spaces and the
    result is stripped. The first two characters matching ``[A-Za-z]``
    anywhere in what remains are lower-cased and combined as
    ``26 * first + second`` (each letter counted from ``a`` = 0), so the
    result lies in 0..675.

    Returns:
        The two-letter code, or -1 when the token has fewer than two ASCII
        letters.
    """
    cleaned = tok.replace("Ġ", " ").replace("▁", " ").strip()
    ascii_letters = re.findall(r'[A-Za-z]', cleaned)
    if len(ascii_letters) < 2:
        return -1

    first, second = (ch.lower() for ch in ascii_letters[:2])
    base = 97  # ord("a")
    return 26 * (ord(first) - base) + (ord(second) - base)


def build_two_letter_code_map_from_model(tokenizer, model_vocab_size: int) -> torch.Tensor:
    """Build a lookup table from token id to two-letter code.

    Args:
        tokenizer: Object exposing ``convert_ids_to_tokens(int)`` and either a
            ``vocab_size`` attribute or ``__len__``.
        model_vocab_size: Length of the returned table (one slot per model
            vocabulary row).

    Returns:
        A 1-D ``torch.long`` tensor of length ``model_vocab_size``. Slot ``i``
        holds ``token_to_two_letter_code`` of token ``i``, or -1 when the id
        has no token, conversion failed, or the id lies beyond the tokenizer's
        vocabulary.
    """
    sp_size = getattr(tokenizer, "vocab_size", None)
    if sp_size is None:
        sp_size = len(tokenizer)  # fallback

    # Fill a plain list first: one tensor construction at the end is far
    # cheaper than an indexed tensor write per vocabulary entry.
    codes = [-1] * model_vocab_size

    # Never index past the table: the tokenizer may know more ids than the
    # model has rows, and those ids simply have no slot here.
    for tid in range(min(sp_size, model_vocab_size)):
        try:
            tok = tokenizer.convert_ids_to_tokens(tid)
        except Exception:
            # Deliberate best effort: an id the tokenizer cannot resolve
            # keeps the -1 sentinel.
            continue
        if tok is None:
            continue
        codes[tid] = token_to_two_letter_code(tok)

    return torch.tensor(codes, dtype=torch.long)

@torch.no_grad()
def renorm_toward_rms(param, target_rms, max_change=0.10):
    """Scale ``param`` in place so its RMS moves toward ``target_rms``.

    The multiplicative factor ``target_rms / current_rms`` is clamped to
    ``[1 - max_change, 1 + max_change]``, so a single call changes the scale
    by at most ``max_change`` (10 % by default). Nothing happens when the
    current RMS is not positive.
    """
    rms_now = param.pow(2).mean().sqrt()
    if rms_now <= 0:
        return

    lower, upper = 1.0 - max_change, 1.0 + max_change
    factor = (target_rms / (rms_now + 1e-8)).clamp(lower, upper)
    param.mul_(factor)


@torch.no_grad()
def leak_once_uniform(wei: torch.Tensor,
                      K7: torch.Tensor,
                      frac: float = 0.01,
                      max_frac: float = 2.0):
    """Blend every weight toward its kernel-weighted neighbourhood, in place.

    Steps:
      1. Convolve ``wei`` with ``K7`` using wrap-around (circular) padding.
      2. Move each weight ``frac`` of the way toward that neighbourhood value,
         limiting the change to ``max_frac * |weight|`` in either direction.
      3. Rescale the matrix so its RMS matches the pre-leak RMS (the rescale
         factor is clamped to [0.1, 10]).
      4. Replace NaN with 0 and +/-inf with +/-1e4.

    Args:
        wei: 2-D weight matrix, modified in place.
        K7: Convolution kernel handed to ``F.conv2d``; it is cast to the dtype
            and device of ``wei``. Its last dimension sets the padding width.
        frac: Fraction of the gap to the neighbourhood value applied per call.
        max_frac: Cap on the per-weight change, relative to ``|weight|``.
    """
    kernel = K7.to(dtype=wei.dtype, device=wei.device)

    # Remember the scale before touching anything.
    rms_before = wei.pow(2).mean().sqrt()

    halo = kernel.shape[-1] // 2
    wrapped = F.pad(wei[None, None],
                    pad=(halo, halo, halo, halo), mode="circular")
    neighbour_mean = F.conv2d(wrapped, kernel).squeeze(0).squeeze(0)

    # Per-weight step, bounded relative to the weight's own magnitude.
    step = torch.clamp(frac * (neighbour_mean - wei),
                       -max_frac * wei.abs(),
                       max_frac * wei.abs())
    wei.add_(step)

    # Restore the original RMS so the leak does not shrink or inflate the matrix.
    rms_after = wei.pow(2).mean().sqrt()
    wei.mul_((rms_before / (rms_after + 1e-8)).clamp(0.1, 10.0))

    wei.nan_to_num_(nan=0.0, posinf=1e4, neginf=-1e4)



# 2d kernel for weight operations
def make_kernel(size: int, lam: float, centre_zero: bool = True) -> torch.Tensor:
    """Build a normalised 2-D kernel with exponential distance decay.

    Each cell holds ``exp(-r / lam)`` where ``r`` is the Euclidean distance
    from the kernel centre; the kernel is then divided by its sum.

    Args:
        size: Side length; must be odd (e.g. 7, 9, 11).
        lam: Decay length; larger values spread weight further out.
        centre_zero: When True the centre cell is zeroed before normalising,
            so a cell never contributes to itself.

    Returns:
        A ``(size, size)`` tensor summing to 1.
    """
    assert size % 2 == 1, "size must be odd"
    half = size // 2

    offsets = torch.arange(-half, half + 1)
    rows, cols = torch.meshgrid(offsets, offsets, indexing='ij')
    distance = torch.sqrt(cols ** 2 + rows ** 2)

    kernel = torch.exp(-distance / lam)
    if centre_zero:
        kernel[half, half] = 0.0
    kernel /= kernel.sum()
    return kernel


# makes a 1d neighborhood kernel
def make_blend_kernel(K: int = 11, sigma: float = 2.0, dtype=None, device=None) -> torch.Tensor:
    """Build a 1-D Gaussian neighbourhood kernel with the centre tap removed.

    A Gaussian of width ``sigma`` is sampled at ``K`` integer offsets and
    normalised. If the centre weight is non-zero it is set to 0 and the
    remaining taps are renormalised to sum to 1.

    Args:
        K: Kernel length; must be odd and at least 3.
        sigma: Gaussian standard deviation, in taps.
        dtype, device: Forwarded to ``torch.arange``.

    Raises:
        ValueError: if the taps left after removing the centre sum to
            (almost) zero.
    """
    centre = K // 2

    assert K % 2 == 1 and K >= 3, "K must be odd and >=3"
    offsets = torch.arange(-(K // 2), K // 2 + 1, dtype=dtype, device=device)
    weights = torch.exp(-0.5 * (offsets / float(sigma)) ** 2)
    weights = weights / (weights.sum() + 1e-12)

    # Nothing to remove when the centre weight is not positive.
    if not (weights[centre].abs() > 0):
        return weights

    weights = weights.clone()
    weights[centre] = 0
    remaining = weights.sum()
    if float(remaining) <= 1e-12:
        raise ValueError("neighbor_only produced zero-sum kernel; choose larger K or different sigma.")
    return weights / remaining


def _alpha_eff(a):
    """Scale ``a`` by ``(start_tau - tau_mod) / start_tau`` read from the global ``LEAK``."""
    start = float(LEAK.start_tau)
    remaining = float(LEAK.start_tau - LEAK.tau_mod)
    schedule_factor = remaining / start
    return float(a) * schedule_factor


# old kernel creation for neighborhood vs worm hole magnetism
def make_slanted_kernels(K: int = 9, decay: float = 0.6, *, dtype=None, device=None):
    """Build a pair of one-sided 1-D kernels with geometric fall-off.

    Both kernels have length ``K``, a zero centre tap, and weights that sum
    to (almost exactly) 1. The tap nearest the centre carries the largest
    weight and each further tap is multiplied by ``decay``.

    Returns:
        ``(k_left, k_right)`` where ``k_left`` has its weight on the taps to
        the right of centre (+1, +2, ...) and ``k_right`` on the taps to the
        left of centre (..., -2, -1).
    """
    assert K % 2 == 1 and K >= 3
    half = K // 2

    # Geometric weights for distances 1..half, normalised to sum to 1.
    distances = torch.arange(1, half + 1, dtype=dtype, device=device)
    falloff = decay ** distances
    falloff = falloff / (falloff.sum() + 1e-12)

    # Weight on the left-hand taps; flipped so that -1 gets the largest share.
    rightward = torch.zeros(K, dtype=dtype, device=device)
    rightward[:half] = falloff.flip(0)

    # Weight on the right-hand taps; +1 gets the largest share.
    leftward = torch.zeros(K, dtype=dtype, device=device)
    leftward[half + 1:] = falloff

    return leftward, rightward


# Device string defined at module level (presumably the default device for the script; requires CUDA).
device = "cuda"


# Module-level caches shared across calls.
#   _GLOBAL_HUB_KERNELS: averaging kernels built by the hub finders, keyed by
#                        (window size, dtype, CUDA device index or -1).
#   _GLOBAL_DIST_CACHE:  not referenced in the visible part of this file;
#                        NOTE(review): confirm its user before removing.
_GLOBAL_HUB_KERNELS = {}
_GLOBAL_DIST_CACHE = {}


def _ema_decay_from_halflife(halflife_steps: float) -> float:
    # decay = exp(-ln(2)/halflife)
    return math.exp(-math.log(2.0) / max(1e-6, float(halflife_steps)))



def tokenize_fn(batch, tokenizer, train_english):
    """Tokenise a batch of sentences and attach loss-masked labels.

    Sentences are read from ``batch["text"]`` when that column exists,
    otherwise from ``batch["en"]``. They are padded/truncated to 64 tokens.
    ``labels`` is a copy of ``input_ids`` with every pad id replaced by -100
    so the loss ignores padding.

    ``train_english`` is accepted for signature compatibility and is not
    used here.

    Returns:
        The tokenizer output with an extra ``"labels"`` entry.
    """
    source_column = "text" if "text" in batch else "en"
    sentences = batch[source_column]

    # TODO: adaptive batch sizes
    encoded = tokenizer(sentences, padding="max_length", truncation=True, max_length=64)

    masked_labels = []
    for sequence in encoded["input_ids"]:
        masked_labels.append(
            [-100 if token_id == tokenizer.pad_token_id else token_id for token_id in sequence]
        )
    encoded["labels"] = masked_labels

    return encoded


def tokenize_forgiveness(batch, tokenizer, train_english):
    """Tokenise ``batch["text"]`` and, when present, its per-token alternatives.

    Args:
        batch: Mapping with a ``"text"`` column. When ``train_english`` is
            False it may also carry ``"k_alts"`` (``[B][T][k]`` alternative
            token strings) and ``"k_confs"`` (``[B][T][k]`` confidences).
        tokenizer: Callable tokenizer exposing ``pad_token_id`` and
            ``convert_tokens_to_ids``.
        train_english: When True the alternative columns are ignored.

    Returns:
        The tokenizer output with:
          * ``input_ids`` / ``attention_mask`` / ``labels`` as long tensors of
            length 64 per row (``labels`` has -100 on pad positions);
          * if alternatives were supplied: ``k_alts`` (padded string lists),
            ``k_confs`` (float tensor) and ``k_alt_ids`` (long tensor), each
            exactly 64 positions per sentence.
    """
    max_len = 64

    # 1. pick the sentences
    text_list = batch["text"]

    # 2. normal tokenisation
    tokenised = tokenizer(
        text_list,
        padding="max_length",
        truncation=True,
        max_length=max_len,
    )

    # 3. labels: same ids, with -100 on pads so the loss skips them
    pad_id = tokenizer.pad_token_id
    tokenised["labels"] = [
        [-100 if tid == pad_id else tid for tid in seq]
        for seq in tokenised["input_ids"]
    ]

    # 4. neighbour info: convert, then pad/truncate to the same length as input_ids
    if (not train_english) and ("k_alts" in batch):
        alt_ids = [
            [[tokenizer.convert_tokens_to_ids(tok) for tok in tok_alts] for tok_alts in sent_alts]
            for sent_alts in batch["k_alts"]
        ]

        # k = number of alternatives per token, taken from the first token
        # that exists anywhere in the batch (the first sentence may be empty).
        k_val = next((len(tok_alts) for sent in alt_ids for tok_alts in sent), 0)

        def fit(seq, pad_elem):
            # The tokenizer truncates input_ids to max_len, so the alternatives
            # must be cut to the same length to stay aligned, then padded.
            seq = list(seq)[:max_len]
            return seq + [pad_elem] * (max_len - len(seq))

        pad_tok_vec = ['<pad>'] * k_val   # dummy strings
        pad_ids_vec = [pad_id] * k_val    # pad ids
        pad_conf_vec = [0.0] * k_val      # zero confidence

        # The strings are kept as well; downstream code may still need them.
        tokenised["k_alts"] = [fit(s, pad_tok_vec) for s in batch["k_alts"]]

        # Numeric columns become tensors exactly once so the default collate can stack them.
        tokenised["k_confs"] = torch.tensor(
            [fit(s, pad_conf_vec) for s in batch["k_confs"]], dtype=torch.float)
        tokenised["k_alt_ids"] = torch.tensor(
            [fit(s, pad_ids_vec) for s in alt_ids], dtype=torch.long)

    for key in ("input_ids", "attention_mask", "labels"):
        tokenised[key] = torch.tensor(tokenised[key], dtype=torch.long)

    return tokenised





def lr_lambda(current_step):
    """Learning-rate multiplier: linear warm-up, then the constant ``gamma``.

    Reads the module-level ``warmup_steps``, ``step_size`` and ``gamma``.
    During warm-up the multiplier is ``current_step / warmup_steps``.
    Afterwards it is ``gamma``; the stepped exponent is computed but the
    exponentiation is currently disabled.
    """
    if current_step >= warmup_steps:
        elapsed = current_step - warmup_steps
        step_factor = elapsed // step_size  # kept for the disabled `gamma ** step_factor`
        return gamma  # ** step_factor

    return float(current_step) / float(max(1, warmup_steps))  # linear warmup


# Schedule constants read by `lr_lambda`.
warmup_steps = 10  # length of the linear warm-up, in steps; adjust as needed
step_size = 1  # steps per decay interval (only feeds the currently unused step_factor)
gamma = 0.99  # multiplier returned by lr_lambda once warm-up is over


def pad_to_len(x: torch.Tensor, target_len: int, pad_val):
    """Pad the second-to-last axis of ``x`` at its end up to ``target_len``.

    x : (..., T, k)
    returns a tensor with shape (..., target_len, k); ``x`` itself is
    returned when it already has that length.
    """
    missing = target_len - x.size(-2)
    if missing == 0:
        return x

    # F.pad lists paddings from the last dim backwards: nothing on the k axis,
    # then (before=0, after=missing) on the T axis.
    no_k_padding = (0, 0)
    return F.pad(x, no_k_padding + (0, missing), value=pad_val)


def fully_randomize(model, *, seed=50, std=0.02):
    """Re-initialise every supported sub-module of ``model`` reproducibly.

    The CPU and CUDA generators are seeded with ``seed`` first. Then:
      * ``nn.Linear`` / ``nn.Embedding``: weight ~ N(0, std), bias (if any) = 0
      * ``nn.LayerNorm``: weight = 1, bias = 0
      * ``nn.Conv2d``: Kaiming-uniform weight, uniform bias bounded by
        ``1 / sqrt(fan_in)``
    Other module types are left untouched. If the model exposes
    ``tie_weights`` it is called at the end.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    def _reinit(module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if getattr(module, "bias", None) is not None:
                nn.init.zeros_(module.bias)
            return

        if isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
            return

        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
            if getattr(module, "bias", None) is not None:
                fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
                limit = 1 / math.sqrt(fan_in)
                nn.init.uniform_(module.bias, -limit, limit)

    # `apply` visits the model and every nested sub-module.
    model.apply(_reinit)

    if hasattr(model, "tie_weights"):
        model.tie_weights()


import math, numpy as np, torch, torch.nn.functional as F
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

def tokenize_single_piece(tok, s: str) -> list[int]:
    """Encode ``s`` without special tokens and keep a single id.

    Returns an empty list when nothing is produced, the only id when there is
    exactly one, and otherwise the SECOND id (this file's convention for
    multi-piece encodings).
    """
    pieces = tok.encode(s, add_special_tokens=False)
    count = len(pieces)
    if count == 0:
        return []
    chosen = pieces[0] if count == 1 else pieces[1]
    return [chosen]

def strings_to_ids_list(tok, csv_or_list) -> list[int]:
    """Turn target strings into one token id each.

    Args:
        tok: Tokenizer exposing ``encode(text, add_special_tokens=False)``.
        csv_or_list: Either a list of strings (used as-is) or anything else,
            which is converted with ``str`` and split on commas; blank
            entries are dropped.

    Returns:
        One id per usable item: the only id for single-piece items, the
        second id for multi-piece items. Items that encode to nothing are
        skipped. A warning is printed for skipped and multi-piece items.
    """
    if isinstance(csv_or_list, list):
        items = csv_or_list
    else:
        items = [part.strip() for part in str(csv_or_list).split(",") if part.strip()]

    picked = []
    for item in items:
        pieces = tok.encode(item, add_special_tokens=False)
        if not pieces:
            print(f"[warn] '{item}' tokenized to nothing; skipping.")
            continue
        if len(pieces) <= 1:
            picked.append(pieces[0])
            continue
        print(f"[warn] '{item}' tokenized to {pieces}; using second id {pieces[1]}.")
        picked.append(pieces[1])
    return picked



# Cache of averaging kernels used by the hub finders below.
# NOTE(review): this name is also assigned earlier in the module; this line re-binds it to a fresh empty dict.
_GLOBAL_HUB_KERNELS = {}

@torch.no_grad()
def find_hubs_for_signatures(delta_1d: torch.Tensor, win: int = 9) -> tuple[int, float, int, float]:
    """Locate the strongest positive and negative hub of a 1-D delta vector.

    delta_1d: [D] (Δ = token_avg - baseline) for a layer.

    The positive part and the negative part of Δ are each smoothed with a
    length-``win`` moving average using circular padding; the hub centre is
    the index where the smoothed signal peaks.

    Returns:
        ``(pos_center, pos_mag, neg_center, neg_mag)``. Both magnitudes are
        ``|Δ|`` read from the unsmoothed vector at the respective centre.
    """
    assert delta_1d.dim() == 1
    width = delta_1d.numel()
    signal = delta_1d.view(1, 1, width)

    # One averaging kernel per (window, dtype, device) is cached module-wide.
    dev = delta_1d.device
    cache_key = (win, delta_1d.dtype, dev.index if dev.type == 'cuda' else -1)
    if cache_key not in _GLOBAL_HUB_KERNELS:
        _GLOBAL_HUB_KERNELS[cache_key] = torch.ones(1, 1, win, device=dev, dtype=delta_1d.dtype) / float(win)
    box = _GLOBAL_HUB_KERNELS[cache_key]
    half = win // 2

    def _smoothed(part: torch.Tensor) -> torch.Tensor:
        # Wrap-around moving average; result has shape [D] for odd `win`.
        wrapped = F.pad(part, (half, half), mode='circular')
        return F.conv1d(wrapped, box).squeeze(0).squeeze(0)

    # Smooth the two signs separately so they cannot cancel each other.
    smooth_pos = _smoothed(F.relu(signal))
    smooth_neg = _smoothed(F.relu(-signal))

    _, pos_at = smooth_pos.max(dim=-1)
    _, neg_at = smooth_neg.max(dim=-1)

    pos_center = int(pos_at.item())
    neg_center = int(neg_at.item())
    pos_mag = float(delta_1d[pos_center].abs().item())
    neg_mag = float(delta_1d[neg_center].abs().item())

    return pos_center, pos_mag, neg_center, neg_mag


@torch.no_grad()
def find_hubs_for_signatures_multi(
    delta_1d: torch.Tensor,
    win: int = 9,
    max_hubs: int = 3,
    frac: float = 0.85,
) -> tuple[list[int], list[float], list[int], list[float]]:
    """Pick up to ``max_hubs`` positive and negative hubs from a delta vector.

    delta_1d: [D]  (Δ = token_mean - baseline)

    Returns:
      P_idx:   list[int]   hub centres in the positive part
      P_mag:   list[float] smoothed +prominence at those indices
      N_idx:   list[int]   hub centres in the negative part
      N_mag:   list[float] smoothed -prominence at those indices

    Notes:
      - Selection and reported magnitudes both use the same smoothed metric
        (moving average of ReLU(±Δ) with circular padding), so every reported
        magnitude is at least ``frac`` times the strongest one of its sign.
      - Hubs of the same sign are kept at least ``win`` apart, measured as
        circular distance (a simple non-maximum suppression).
      - A sign whose strongest smoothed value is not positive yields two
        empty lists.
    """
    assert delta_1d.dim() == 1
    D = delta_1d.numel()
    signal = delta_1d.view(1, 1, D)

    # One averaging kernel per (window, dtype, device) is cached module-wide.
    cache_key = (win, delta_1d.dtype, delta_1d.device.index if delta_1d.is_cuda else -1)
    if cache_key not in _GLOBAL_HUB_KERNELS:
        _GLOBAL_HUB_KERNELS[cache_key] = torch.ones(1, 1, win, device=delta_1d.device, dtype=delta_1d.dtype) / float(win)
    box = _GLOBAL_HUB_KERNELS[cache_key]
    half = win // 2

    def _smoothed(part: torch.Tensor) -> torch.Tensor:
        wrapped = F.pad(part, (half, half), mode='circular')
        return F.conv1d(wrapped, box).squeeze(0).squeeze(0)  # [D]

    def _pick_hubs(sm: torch.Tensor) -> tuple[list[int], list[float]]:
        ranked_vals, ranked_idx = torch.sort(sm, descending=True)
        if ranked_vals.numel() == 0 or ranked_vals[0].item() <= 0:
            return [], []

        cutoff = ranked_vals[0].item() * float(frac)
        centres: list[int] = []
        strengths: list[float] = []

        for strength, pos in zip(ranked_vals.tolist(), ranked_idx.tolist()):
            if strength < cutoff:
                break  # sorted descending: nothing further can qualify
            # Skip candidates that sit within `win` of an accepted hub.
            crowded = any(min(abs(pos - kept), D - abs(pos - kept)) < win for kept in centres)
            if crowded:
                continue
            centres.append(int(pos))
            strengths.append(float(strength))
            if len(centres) >= max_hubs:
                break
        return centres, strengths

    P_idx, P_mag = _pick_hubs(_smoothed(F.relu(signal)))
    N_idx, N_mag = _pick_hubs(_smoothed(F.relu(-signal)))

    return P_idx, P_mag, N_idx, N_mag


def _normalized_hub_power(delta: torch.Tensor, center: int, win: int) -> float:
    # mean(|Δ| in window) / mean(|Δ| whole layer)
    D = delta.numel()
    r = max(1, win // 2)
    lo, hi = max(0, center - r), min(D, center + r + 1)
    local_mean = delta[lo:hi].abs().mean()
    global_mean = delta.abs().mean().clamp(min=1e-8)
    return float((local_mean / global_mean).item())


def build_signatures_from_means(
    baseline: dict[int, np.ndarray],                 # baseline[L] -> [D]
    token_means: dict[int, dict[int, np.ndarray]],   # token_means[token_id][L] -> [D]
    id2str_fn,                                       # e.g., lambda ids: tokenizer.decode(ids, skip_special_tokens=True)
    smooth_win: int = 9,
    inverse = True
) -> dict[str, dict[str, dict[str, float | int]]]:
    """Turn per-token layer means into hub signatures.

    For every token and every layer that also has a baseline, both vectors
    are divided by their own mean absolute value, subtracted
    (Δ = token - baseline), and searched for hubs with
    ``find_hubs_for_signatures_multi``.

    Returns:
      {
        "<token_str>": {
          "D<layer>": {"P": [centres], "P_mag": [ratios],
                       "N": [centres], "N_mag": [ratios]},
          ...
        }, ...
      }
    Notes:
      - "P"/"N" are lists of hub centres; the "*_mag" lists hold
        ``_normalized_hub_power`` of Δ at those centres.
      - With ``inverse=True`` the roles are swapped (negative hubs are stored
        under "P", positive hubs under "N"), which the original author uses
        to discourage the token.
      - The key is the stripped decoded token, or the token id as a string
        when decoding yields an empty string.
      - Tokens with no usable layer are omitted.
    """
    signatures: dict[str, dict[str, dict[str, float | int]]] = {}

    for token_id, per_layer_mean in token_means.items():
        label = id2str_fn([token_id]).strip() or str(token_id)

        layers: dict[str, dict[str, float | int]] = {}
        for L, token_vec in per_layer_mean.items():
            base_vec = baseline.get(L, None)
            if base_vec is None:
                continue

            # Bring both vectors to unit mean-|x| scale before differencing.
            tok_scale = float(np.mean(np.abs(token_vec)) + 1e-8)
            base_scale = float(np.mean(np.abs(base_vec)) + 1e-8)
            delta = torch.from_numpy(token_vec / tok_scale - base_vec / base_scale).float()  # [D]

            # Positive and negative parts are smoothed separately (no cancellation).
            P_idx, _, N_idx, _ = find_hubs_for_signatures_multi(delta, win=smooth_win)

            stored_as_p, stored_as_n = (N_idx, P_idx) if inverse else (P_idx, N_idx)
            layers[f"D{L}"] = {
                "P": [int(c) for c in stored_as_p],
                "P_mag": [_normalized_hub_power(delta, int(c), smooth_win) for c in stored_as_p],
                "N": [int(c) for c in stored_as_n],
                "N_mag": [_normalized_hub_power(delta, int(c), smooth_win) for c in stored_as_n],
            }

        if layers:
            signatures[label] = layers

    return signatures




@torch.no_grad()
def build_baseline_and_targets(
    model, tokenizer,
    sentences: list[str],
    target_token_strs: list[str],
    layers_to_track: list[int] | None = None,
    per_token_samples:int=64,
    max_baseline_positions:int=100_000,
    batch_size:int=32,
    device:str="cuda"
):
    """Run sentences through the model and collect baseline / per-token statistics.

    Each batch is used both as encoder input and as decoder target. A
    ``DownBandAccumulator`` (defined elsewhere in this file) is registered on
    the model for the duration of the forward passes and is always
    unregistered afterwards, even if a batch fails.

    Args:
        model: Seq2seq model; switched to eval mode.
        tokenizer: Matching tokenizer.
        sentences: Raw text, processed in slices of ``batch_size``.
        target_token_strs: Strings resolved to ids via ``strings_to_ids_list``.
        layers_to_track: Layer indices handed to the accumulator; defaults to
            ``[9, 10, 11]`` (a fresh list per call).
        per_token_samples: Forwarded to the accumulator.
        max_baseline_positions: Stop once this many valid decoder positions
            have been seen.
        batch_size: Sentences per forward pass.
        device: Device the tokenised tensors are moved to.

    Returns:
        ``(baseline, token_means, signatures)`` where the first two come from
        the accumulator and ``signatures`` from ``build_signatures_from_means``.
    """
    if layers_to_track is None:
        layers_to_track = [9, 10, 11]

    model.eval()
    tgt_ids = strings_to_ids_list(tokenizer, target_token_strs)
    acc = DownBandAccumulator(model, tokenizer, layers=layers_to_track,
                              target_ids=tgt_ids, per_token_samples=per_token_samples)
    acc.register()

    try:
        seen_valid = 0
        for i in range(0, len(sentences), batch_size):
            batch_txt = sentences[i:i+batch_size]

            enc = tokenizer(batch_txt, return_tensors="pt", padding=True, truncation=True, max_length=128).to(device)
            with tokenizer.as_target_tokenizer():
                tgt = tokenizer(batch_txt, return_tensors="pt", padding=True, truncation=True, max_length=128).to(device)

            labels = tgt["input_ids"].clone()              # [B, T_lab]
            labels[labels == tokenizer.pad_token_id] = -100

            # Shift-right to get decoder inputs; [B, T_dec]
            dec_inp = model.base_model.prepare_decoder_input_ids_from_labels(labels=labels)
            T_dec = dec_inp.size(1)

            # Decoder slot 0 is never valid; slot t corresponds to labels[:, t-1].
            valid = torch.zeros_like(dec_inp, dtype=torch.bool)  # [B, T_dec]
            labels_dec = torch.full_like(dec_inp, fill_value=-100)
            if T_dec > 1 and labels.size(1) >= (T_dec - 1):
                shifted = labels[:, :T_dec-1]
                valid[:, 1:] = (shifted != -100)
                labels_dec[:, 1:] = shifted

            # Hand off ids + masks to accumulator (uses labels_dec for token capture)
            acc.set_decoder_ids_and_mask(dec_inp.detach(), valid.detach(), labels_dec.detach())

            # The output is discarded; presumably the accumulator captures what it
            # needs during the forward pass (confirm against DownBandAccumulator).
            _ = model(input_ids=enc["input_ids"],
                      attention_mask=enc["attention_mask"],
                      decoder_input_ids=dec_inp,
                      use_cache=False)

            seen_valid += int(valid.sum().item())
            if seen_valid >= max_baseline_positions:
                break
    finally:
        # Always detach the accumulator so a failed batch does not leave it
        # registered on the model.
        acc.unregister()

    baseline, token_means = acc.build_baseline_and_token_means()

    id2str = lambda ids: tokenizer.decode(ids, skip_special_tokens=True)
    signatures = build_signatures_from_means(baseline, token_means, id2str_fn=id2str)

    return baseline, token_means, signatures



def _gaussian_1d(D:int, centers:list[int], sigma:float)->torch.Tensor:
    x = torch.arange(D, dtype=torch.float32)
    g = torch.zeros(D, dtype=torch.float32)
    for c in centers:
        if c is None: continue
        g += torch.exp(-0.5 * ((x - float(c))/float(sigma))**2)
    return (g / g.max().clamp(min=1e-6)).clamp(0,1)

def _window_unit_direction(delta: torch.Tensor, centers:list[int], signs:list[int], radius:int)->torch.Tensor:
    D = delta.numel()
    u = torch.zeros(D, dtype=torch.float32)
    for c, s in zip(centers, signs):
        if c is None: continue
        lo, hi = max(0, c - radius), min(D, c + radius + 1)
        w = delta[lo:hi].clone()
        w = w - w.mean()  # local detrend
        u[lo:hi] += s * w
    n = u.norm(p=2).clamp(min=1e-6)
    return u / n


def load_english():
    """Load the English corpora and interleave them into one dataset.

    Sources: the first 6,000,000 rows of Wikipedia ("20220301.en") and the
    first 8,000,000 rows of BookCorpus ("plain_text").
    """
    # A C4 slice ("allenai/c4", "en", split="train[:7000]") was used earlier and is disabled.
    wikipedia = load_dataset("wikipedia", "20220301.en", split="train[:6000000]")
    books = load_dataset("bookcorpus", "plain_text", split="train[:8000000]")
    return interleave_datasets([wikipedia, books])

def _hub_dir_and_ratio(delta: torch.Tensor, center: int, radius: int):
    """Return (u, R) where u is mean-free, unit-L2 window direction and
       R = mean(|Δ_window|) / mean(|Δ_layer|)."""
    D = delta.numel()
    lo, hi = max(0, center - radius), min(D, center + radius + 1)
    win = delta[lo:hi]
    u = win - win.mean()
    u = u / (u.norm(p=2) + 1e-8)
    R = win.abs().mean() / (delta.abs().mean() + 1e-8)
    return u.cpu().numpy(), float(R)

class T5SpanCorruptionCollator:
    """
    Collator for T5/mT5 span corruption.

    • The corruption mask is drawn in one vectorised call; encoder/decoder rows
      are then rebuilt one example at a time.
    • Works even if <extra_id_*> strings are not in tokenizer.special_tokens
      (sentinel ids are derived numerically: vocab_size‑1‑i, for i in 0..99).
    • Padding positions are never corrupted, so pad ids never become targets
      and no sentinel is placed inside the padded tail.
    • Encoder inputs are padded to the row length; decoder labels have the
      same length with ‑100 on unused positions, so
      CrossEntropyLoss(ignore_index=-100) works directly.
    """

    def __init__(self, tokenizer,
                 noise_density=0.15,
                 mean_span_len=3,
                 input_length=64):

        self.tok = tokenizer
        self.noise = noise_density
        self.mean = mean_span_len  # stored only; the masking below does not use it
        self.L = input_length

        self.pad = tokenizer.pad_token_id
        self.eos = tokenizer.eos_token_id
        # sentinel_id(i) = vocab_size - 1 - i   (T5 design choice); 100 sentinels
        self.sentinels = torch.arange(tokenizer.vocab_size - 1,
                                      tokenizer.vocab_size - 101, -1)

    # ------------------------------------------------------------------
    def _random_mask(self, B, L, device, valid=None):
        """Generate a (B,L) boolean mask with ≈noise_density True values.

        ``valid`` is an optional (B,L) boolean tensor of positions that may be
        masked. When given, the mask is restricted to it and every row that
        has at least one valid position gets at least one True inside it.
        Without ``valid`` every position is eligible.
        """
        bern = torch.rand(B, L, device=device) < self.noise
        if valid is not None:
            bern &= valid

        # Rows that ended up with nothing masked get one forced position.
        needs_one = bern.sum(1) == 0
        if valid is not None:
            needs_one &= valid.any(1)  # rows made only of padding stay unmasked
        if needs_one.any():
            if valid is None:
                cols = torch.randint(0, L, (int(needs_one.sum()),), device=device)
            else:
                # Uniform draw among the valid positions of each affected row.
                cols = torch.multinomial(valid[needs_one].float(), 1).squeeze(1)
            bern[needs_one, cols] = True
        return bern

    # ------------------------------------------------------------------
    def __call__(self, examples):
        """Corrupt a list of examples and return a model-ready batch.

        Each example must provide ``"input_ids"`` (tensor or list of ints);
        rows are cut to ``input_length`` and must then share one length.
        Returns a dict with ``input_ids``, ``attention_mask`` and ``labels``
        tensors, plus every other example column passed through as a list.
        """
        ids = torch.stack([torch.as_tensor(ex["input_ids"])[: self.L] for ex in examples])
        B, L = ids.shape
        device = ids.device

        # Only real tokens may be corrupted; padding is left alone.
        mask = self._random_mask(B, L, device, valid=(ids != self.pad))

        inputs = torch.full_like(ids, self.pad)
        labels = torch.full_like(ids, -100)  # -100 = ignored by the loss

        for b in range(B):
            row = ids[b].tolist()

            in_row, lb_row, cur, s_idx = [], [], 0, 0
            for is_mask, run in groupby(mask[b].tolist()):
                span_len = sum(1 for _ in run)
                chunk = row[cur:cur + span_len]
                if is_mask:
                    # A masked span becomes one sentinel in the encoder input,
                    # and sentinel + original tokens in the decoder target.
                    sent_id = int(self.sentinels[s_idx])
                    in_row.append(sent_id)
                    lb_row.append(sent_id)
                    lb_row.extend(chunk)
                    s_idx += 1
                else:
                    in_row.extend(chunk)
                cur += span_len
            lb_row.append(self.eos)

            # Fit both rows to exactly L positions.
            in_row = in_row[:L]
            lb_row = lb_row[:L]
            inputs[b, : len(in_row)] = torch.tensor(in_row, dtype=ids.dtype, device=device)
            labels[b, : len(lb_row)] = torch.tensor(lb_row, dtype=ids.dtype, device=device)

        batch = {
            "input_ids": inputs,
            "attention_mask": (inputs != self.pad).long(),
            "labels": labels,
        }

        # pass‑through any extra columns untouched (k_alts, k_confs, …)
        for k in examples[0]:
            if k not in batch:
                batch[k] = [ex[k] for ex in examples]

        return batch

class CustomDenseReluDense(nn.Module):
    def __init__(self, orig_module, layer_num, variance=0.0010, long_var=0.0010, default_std=2.5,
                 forget_factor=0.998, long_forget_factor=0.9999, eps=1e-5, ema_halflife=600.0):
        """
        Wrap an MT5 gated feed-forward block, re-using its layers and adding the
        state needed by the leak / recall / magnetism-noise machinery.

        Args:
            orig_module: The original DenseReluDense module from T5/MT5. Its
                ``wi_0``, ``wi_1``, ``wo`` and ``dropout`` sub-modules are shared
                (not copied).
            layer_num: Global index of this FFN; stored as an int and used to
                index per-layer settings on ``LEAK``.
            variance, long_var, default_std: Stored as attributes; ``forward``
                decays them periodically.
            forget_factor, long_forget_factor: Stored as attributes only.
            eps: Small epsilon; also seeds ``hub_e_counter`` so the running
                averages never divide by zero.
            ema_halflife: Passed to ``_ema_decay_from_halflife`` to obtain the
                EMA decay used for the hub histograms.
        """
        super().__init__()
        self.eps = eps
        # For MT5, the feed-forward module uses gated linear projections:
        self.wi_0 = orig_module.wi_0  # Linear: (d_model -> d_ff)
        self.wi_1 = orig_module.wi_1  # Linear: (d_model -> d_ff)
        self.wo = orig_module.wo  # Linear: (d_ff -> d_model)
        self.forward_passes = 0.0
        self.dropout = orig_module.dropout
        self.activation = ACT2FN["gelu_new"]

        # Root-mean-square of each weight matrix at wrap time (printed for inspection).
        with torch.no_grad():
            self._rms_target = {
                "wi0": self.wi_0.weight.pow(2).mean().sqrt().item(),
                "wi1": self.wi_1.weight.pow(2).mean().sqrt().item(),
                "wo": self.wo.weight.pow(2).mean().sqrt().item(),
            }

        print(self._rms_target)

        self.layer_num = int(layer_num)

        # multiplier for how intense range is from given forgetfulness value
        self.variance = variance
        self.long_var = long_var
        self.forget_factor = forget_factor
        self.long_forget_factor = long_forget_factor
        self.default_std = default_std
        self.base_intensity = 2.5  # @todo try or theorize much higher intensities
        self.memory_boost = 200000.0  # forward() decays the variances every time this many passes elapse
        self.LN2 = math.log(2.0)

        # Prepared recall kernels, filled lazily by apply_recall_1d.
        self._kernel_cache = {}

        self.register_buffer(
            "kernel11",
            make_kernel(11, LEAK.lam, centre_zero=True).unsqueeze(0).unsqueeze(0),
            persistent=False
        )

        self.register_buffer(
            "kernel7",
            make_kernel(7, (LEAK.lam / 8.0), centre_zero=True).unsqueeze(0).unsqueeze(0),
            persistent=False
        )

        # NOTE(review): despite the name this kernel is built with K=9, and unlike
        # the two kernels above it is registered with the default persistent=True,
        # so it is saved in the state_dict - confirm both are intended.
        self.register_buffer(
            "recall_kernel11",
            make_blend_kernel(K=9, sigma=2.5, dtype=torch.float32)
        )

        # Widths are read from the wrapped projection instead of being hard-coded,
        # so the wrapper also works for checkpoints other than the 768/2048 one.
        # nn.Linear stores its weight as (out_features, in_features) = (d_ff, d_model).
        self.d_ff, self.d_model = (int(n) for n in self.wi_0.weight.shape)

        # bucket size is fixed at 4
        self.bucket_size = 4
        self.buckets_model = self.d_model // self.bucket_size
        self.buckets_ffn = self.d_ff // self.bucket_size

        # EMA decay
        self.ema_decay = _ema_decay_from_halflife(ema_halflife)

        # Per-layer EMAs (pos/neg) for model- and FFN-width hub densities (bucketed)
        self.register_buffer("ema_model_pos", torch.zeros(self.buckets_model))
        self.register_buffer("ema_model_neg", torch.zeros(self.buckets_model))
        self.register_buffer("ema_ffn_pos", torch.zeros(self.buckets_ffn))
        self.register_buffer("ema_ffn_neg", torch.zeros(self.buckets_ffn))

        # Running sums for the diagnostics shown by print_metrics(); each is
        # divided by hub_e_counter when printed.
        self.cos_sim1d = 0.0
        self.wa = 0.0
        self.hub_energy = 0.0
        self.hub_e_counter = self.eps
        self.has_printed = False

        # One-off dump of the kernels for a single layer.
        if layer_num == 23:
            print(layer_num)
            print(self.kernel7)
            print("\nWide Kernel")
            print(self.kernel11)
            print("\nRecall Kernel")
            print(self.recall_kernel11)

    # Functions for Metrics

    def print_metrics(self):
        print("\nLayer ", str(self.layer_num))
        print("Morans 2d: ", str(self.morans_2d(self.wo.weight)))
        print("Morans 1d:", str(self.cos_sim1d / self.hub_e_counter))
        print("Hub Energy: ", str(self.hub_energy / self.hub_e_counter))
        print("WA-Alignment ", str(self.wa / self.hub_e_counter))

    # -----------------------------
    # 1) W–A alignment under downprojection (per layer)
    # -----------------------------
    def wa_alignment_downproj(
            self,
            X: torch.Tensor,  # z_flat: (N, D)
            W: torch.Tensor,  # (T, D)
            *,
            center: bool = True,
            normalize_rows: bool = True,
            projector: str = "orth",  # "orth" or "none"
            reduce: str = "mean"  # "mean", "median", or "none"
    ):
        """Measure how well each row of `W` aligns with the mean activation of its token.

        Rows of `X` are grouped by ``self.token_ids`` (one id per row of `X`);
        the id is used directly as the row index into `W`.  For every id that
        occurs, the group mean is optionally projected onto the row space of `W`,
        optionally L2-normalised, and dotted with the matching row of `W`.

        Args:
            X: Activations, ``(N, D)``.
            W: Weight rows, ``(T, D)``.
            center: Subtract the column mean of `X` first.
            normalize_rows: L2-normalise both the rows of `W` and the token means,
                turning the dot product into a cosine.
            projector: ``"orth"`` projects token means with ``Wᵀ (W Wᵀ)⁺ W``;
                ``"none"`` skips the projection.
            reduce: ``"mean"`` / ``"median"`` summarise over the tokens that were
                seen (NaNs ignored); any other value yields ``summary=None``.

        Returns:
            ``{"per_token_alignment": (T,) tensor with NaN for unseen tokens,
            "summary": 0-dim tensor or None}``.

        Raises:
            AssertionError: if ``self.token_ids`` has not been set.
            ValueError: for an unknown `projector`.
        """
        assert hasattr(self, "token_ids"), "self.token_ids required"
        tok = self.token_ids.to(X.device)
        device = X.device
        T, D = W.shape

        Wn = F.normalize(W, p=2, dim=1) if normalize_rows else W
        if projector == "orth":
            WT = W.transpose(0, 1)  # (D, T)
            Pinv = torch.linalg.pinv(W @ WT)  # (T, T)
            P = WT @ (Pinv @ W)  # (D, D)
        elif projector == "none":
            P = None
        else:
            raise ValueError("projector must be 'orth' or 'none'")

        Xc = X - X.mean(dim=0, keepdim=True) if center else X

        # Rows stay NaN for token ids that never occur in `tok`.
        per_token_mu = torch.full((T, D), float('nan'), device=device)
        for t_ in torch.unique(tok):
            t = int(t_.item())
            Xt = Xc[tok == t]
            if Xt.numel() == 0:
                continue
            mu = Xt.mean(dim=0)
            if P is not None:
                mu = mu @ P
            per_token_mu[t] = F.normalize(mu, p=2, dim=0) if normalize_rows else mu

        align = torch.full((T,), float('nan'), device=device)
        valid = ~torch.isnan(per_token_mu).any(dim=1)
        if valid.any():
            align[valid] = (Wn[valid] * per_token_mu[valid]).sum(dim=1)

        if reduce == "mean":
            summary = torch.nanmean(align)
        elif reduce == "median":
            # Without `dim`, nanmedian returns a plain tensor; `.values` on it
            # would be the unrelated bound method Tensor.values.
            summary = torch.nanmedian(align)
        else:
            summary = None
        return {"per_token_alignment": align, "summary": summary}

    def morans_1d(self, x: torch.Tensor, k: int = 1) -> torch.Tensor:
        """
        Moran's I (spatial autocorrelation) of `x` along its last axis.

        x: (L,) or (N, L).  A 1-D input is treated as a single row; for (N, L)
           the per-row statistics are averaged.
        k: neighbourhood radius - each position is compared with the `k`
           positions on either side (zero-padded at the edges).

        Returns a 0-dim tensor.
        """
        rows = x.unsqueeze(0) if x.ndim == 1 else x  # always (N, L)
        centred = rows - rows.mean(dim=1, keepdim=True)

        # Box kernel with the centre tap removed -> sum over the 2k neighbours.
        neighbour_kernel = torch.ones(1, 1, 2 * k + 1, device=centred.device, dtype=centred.dtype)
        neighbour_kernel[0, 0, k] = 0
        neighbour_sum = F.conv1d(centred.unsqueeze(1), neighbour_kernel, padding=k).squeeze(1)

        cross = (centred * neighbour_sum).sum(dim=1)
        energy = (centred * centred).sum(dim=1).clamp_min(1e-12)  # guards constant rows

        length = centred.size(1)
        total_weight = (2 * k) * length  # nominal sum of neighbour weights per row
        per_row = (length / total_weight) * (cross / energy)
        return per_row.mean()

    def morans_2d(self, W2d: torch.Tensor, eight_neigh: bool = True) -> torch.Tensor:
        """Moran's I of a 2-D matrix treated as an image.

        W2d: (H, W) matrix.
        eight_neigh: use the 8-connected neighbourhood when True, otherwise the
            4-connected one.  Borders are zero-padded.

        Returns a 0-dim tensor.
        """
        if eight_neigh:
            stencil = [[1, 1, 1], [1, 0, 1], [1, 1, 1]]
        else:
            stencil = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]

        centred = W2d - W2d.mean()
        kernel = torch.tensor(stencil, dtype=centred.dtype, device=centred.device)[None, None]

        # Per-cell sum over the neighbouring cells.
        neighbour_sum = F.conv2d(centred[None, None], kernel, padding=1).squeeze()

        cross = (centred * neighbour_sum).sum()
        energy = (centred * centred).sum().clamp_min(1e-12)
        cells = centred.numel()
        total_weight = kernel.sum() * (W2d.shape[0] * W2d.shape[1])  # nominal neighbour-weight total
        return (cells / total_weight) * (cross / energy)

    def hub_concentration_masked(
            self,
            X: torch.Tensor,  # (N, D)
            pos_inside: torch.Tensor,  # (N, D) bool
            neg_inside: torch.Tensor,  # (N, D) bool
            outside: torch.Tensor | None = None,
            *, eps: float = 1e-12
    ) -> torch.Tensor:
        assert X.shape == pos_inside.shape == neg_inside.shape
        inmask = pos_inside | neg_inside
        if outside is None:
            outside = ~inmask
        else:
            assert outside.shape == X.shape

        # Net mass per hub region (signs cancel *inside* each hub), then take magnitudes
        pos_net = (X * pos_inside).sum(dim=1).abs()  # |sum of values in positive hubs|
        neg_net = (X * neg_inside).sum(dim=1).abs()  # |sum of values in negative hubs|
        Ein = pos_net + neg_net  # directed in-hub energy (net)

        # Outside uses absolute values directly
        Eout = (X.abs() * outside).sum(dim=1)

        prop = Ein / (Ein + Eout + eps)  # per-sample proportion
        return prop.mean().clamp(0.0, 1.0)

    #Functions for Computation

    @torch.no_grad()
    def find_hubs_per_token(self, z: torch.Tensor, num_hubs: int = 2,
                            min_sep: int = 0, win: int = 9, stride: int = 4):
        """Locate, per token, the strongest positive and strongest negative hub.

        Positive and negative parts of `z` are box-smoothed separately along the
        feature axis (circular padding) and the window with the largest mean is
        selected for each sign.  The smoothing is evaluated only every `stride`
        features; the winning window index is converted back to a feature
        position before it is returned, so centres are directly comparable with
        feature indices in ``[0, D)``.

        Args:
            z: Activations, ``[B, T, D]``.
            num_hubs: Accepted for compatibility; not used (always one hub per sign).
            min_sep: Accepted for compatibility; not used.
            win: Odd width of the box window.
            stride: Step between evaluated window centres.

        Returns:
            ``(centers, strengths)`` - both ``[B, T, 2]`` with the last axis
            ordered ``(positive, negative)``.  `centers` is int64 feature
            positions, `strengths` the corresponding window means.
        """
        assert win % 2 == 1, "win must be odd"
        B, T, D = z.shape
        N = B * T
        x = z.reshape(N, 1, D)

        # cache a 1x1xwin box kernel on the right device/dtype
        key = (win, z.dtype, z.device.index if z.device.type == 'cuda' else -1)
        if not hasattr(self, '_box_k'):
            self._box_k = {}
        if key not in self._box_k:
            self._box_k[key] = torch.ones(1, 1, win, device=z.device, dtype=z.dtype) / float(win)
        k = self._box_k[key]
        r = win // 2

        # smooth positives and negatives separately (no cancellation)
        pos = F.relu(x)
        neg = F.relu(-x)

        # Output index j is the window centred on feature j * stride -> [N, ceil(D / stride)]
        sm_pos = F.conv1d(F.pad(pos, (r, r), mode='circular'), k, stride=stride).squeeze(1)
        sm_neg = F.conv1d(F.pad(neg, (r, r), mode='circular'), k, stride=stride).squeeze(1)

        pos_val, pos_idx = sm_pos.max(dim=-1)
        neg_val, neg_idx = sm_neg.max(dim=-1)

        # Map strided window indices back to feature coordinates.
        centers = (torch.stack([pos_idx, neg_idx], dim=-1) * stride).view(B, T, 2).long()
        strengths = torch.stack([pos_val, neg_val], dim=-1).view(B, T, 2)

        return centers, strengths

    @torch.no_grad()
    def update_bucketed_ema_and_scores(
            self,
            centers: torch.Tensor,  # [B, T, 2]  int indices (pos, neg)
            dim: int,  # width that produced centers; must equal self.d_ff or self.d_model
            ten_x_cap: float = 7.0,  # EMA value that maps to a score of 1.0
            min_cutoff: float = 0.20
    ) -> torch.Tensor:
        """
        Update the hub-location EMAs for this width and score every centre.

        Steps:
          1) Clamp centres to [0, dim-1] and bucket them with size self.bucket_size.
          2) Build one histogram per sign (length dim // bucket_size); each centre
             adds weights {1, 2, 3, 2, 1} to its bucket and the two buckets on
             either side (circular).
          3) Divide the histograms by 4.0.
          4) Blend them into the EMA buffers selected by 'dim' (in place).
          5) For each centre, read back its bucket's EMA value and map it linearly
             to [0, 1]: 0 at (min_cutoff * ten_x_cap), 1 at ten_x_cap.
          6) Return [B, T, 2] scores.

        `centers` itself is left untouched.

        Raises:
            ValueError: if `dim` is neither d_ff nor d_model, or is smaller than
                one bucket.
        """
        assert centers.dim() == 3 and centers.size(-1) == 2, "centers must be [B,T,2]"
        B, T, _ = centers.shape
        device = centers.device

        # Select target EMA tensors by 'dim'
        if dim == self.d_ff:
            ema_pos = self.ema_ffn_pos
            ema_neg = self.ema_ffn_neg
        elif dim == self.d_model:
            ema_pos = self.ema_model_pos
            ema_neg = self.ema_model_neg
        else:
            raise ValueError(f"Unrecognized dim={dim}; expected one of {{d_ff={self.d_ff}, d_model={self.d_model}}}")

        bucket_count = dim // self.bucket_size
        if bucket_count <= 0:
            raise ValueError(f"dim={dim} too small for bucket_size={self.bucket_size}")

        # Flatten centers -> [N, 2].  Clamp out of place: `.to(int64)` is a no-op
        # for int64 input, so an in-place clamp would write into the caller's tensor.
        centers_flat = centers.reshape(-1, 2).to(torch.int64)
        pos_idx = centers_flat[:, 0].clamp(0, dim - 1)
        neg_idx = centers_flat[:, 1].clamp(0, dim - 1)

        # Convert to bucket indices
        pos_b0 = torch.div(pos_idx, self.bucket_size, rounding_mode='floor')  # [N]
        neg_b0 = torch.div(neg_idx, self.bucket_size, rounding_mode='floor')  # [N]

        # AoE neighborhood (circular): offsets and weights
        offsets = torch.tensor([-2, -1, 0, 1, 2], device=device, dtype=torch.int64)
        weights = torch.tensor([1., 2., 3., 2., 1.], device=device, dtype=torch.float32)  # sums to 9

        # --- Build bucket histograms via weighted bincount ---
        def build_hist(b0: torch.Tensor) -> torch.Tensor:
            if b0.numel() == 0:
                return torch.zeros(bucket_count, device=device, dtype=torch.float32)
            neigh = (b0.unsqueeze(1) + offsets) % bucket_count  # [N,5]
            w = weights.expand(b0.size(0), -1)  # [N,5]
            return torch.bincount(
                neigh.reshape(-1),
                weights=w.reshape(-1),
                minlength=bucket_count
            ).to(torch.float32)

        hist_pos = build_hist(pos_b0)
        hist_neg = build_hist(neg_b0)

        # NOTE(review): the histogram total is 9 * N, so dividing by 4 gives a
        # mean of 9N / (4 * bucket_count), which is 1 only for one particular N.
        # Earlier notes mention both "4 (bucket size)" and "36" - confirm the
        # intended normaliser.
        hist_pos = hist_pos / 4.0
        hist_neg = hist_neg / 4.0

        # EMA update
        decay = float(self.ema_decay)
        ema_pos.lerp_(hist_pos, 1.0 - decay)  # ema = ema*decay + hist*(1-decay)
        ema_neg.lerp_(hist_neg, 1.0 - decay)

        # Per-center scores in [0,1]
        pos_vals = ema_pos.gather(0, pos_b0)
        neg_vals = ema_neg.gather(0, neg_b0)

        min_thresh = min_cutoff * ten_x_cap  # with the defaults: 0.20 * 7 = 1.4
        scale = max(ten_x_cap - min_thresh, 1e-6)

        pos_scores = torch.clamp((pos_vals - min_thresh) / scale, min=0.0, max=1.0)
        neg_scores = torch.clamp((neg_vals - min_thresh) / scale, min=0.0, max=1.0)

        return torch.stack([pos_scores, neg_scores], dim=-1).view(B, T, 2)

    # helper for averaging tokens
    def average_per_token(self, values, counts=None, eps: float = 1e-12):
        """Average `values`, skipping NaNs, optionally weighting by `counts`.

        * 0-dim tensor and no `counts`: returned as-is, clamped to [0, 1].
        * With `counts`: weighted mean ``sum(v * c) / sum(c)`` over the non-NaN
          entries (not clamped); the denominator is floored at `eps`.
        * Otherwise: plain mean over the non-NaN entries, clamped to [0, 1].
        """
        weighted = counts is not None
        if not weighted and torch.is_tensor(values) and values.ndim == 0:
            return values.clamp(0.0, 1.0)

        target_device = counts.device if torch.is_tensor(counts) else None
        values = torch.as_tensor(values, device=target_device, dtype=torch.float32)

        present = ~torch.isnan(values)
        cleaned = torch.nan_to_num(values, nan=0.0)

        if weighted:
            weights = counts.to(cleaned.dtype) * present  # zero weight where the value was NaN
            total_weight = weights.sum().clamp_min(eps)
            return (cleaned * weights).sum() / total_weight

        n_present = present.sum().clamp_min(1)
        return (cleaned.sum() / n_present).clamp(0.0, 1.0)

    @torch.no_grad()
    def scramble_magnetism_noise(self, z: torch.Tensor, leak_scale: float,
                                 beta_max: float = 0.20, radius: int = 12,
                                 win: int = 9, track_energy: bool = False):
        """Build additive, sign-preserving noise for `z` that is damped inside "hubs".

        For every token the strongest positive and negative hub centres are
        located, and each centre receives a trust score in [0, 1] from the
        bucketed EMA (calling this method therefore also updates those EMA
        buffers).  Each element then gets multiplicative uniform noise in
        ``[-beta, beta]`` with ``beta = beta_max * leak_scale``, scaled by:

          * 1            for non-zero elements farther than `radius` (circular
                         distance) from the centre matching their sign,
          * centre trust for elements within `radius` of that centre,
          * 0            for elements that are exactly zero.

        Args:
            z: Activations, unpacked as ``[B, T, D]``.
            leak_scale, beta_max: Noise amplitude factors; if either is <= 0 a
                zero tensor is returned immediately (no EMA update).
            radius: Half-width of a hub, in feature positions.
            win: Smoothing window forwarded to ``find_hubs_per_token``.
            track_energy: When True (and D < 1500 and nothing has been printed
                yet) accumulate diagnostics and print them once.

        Returns:
            Tensor shaped like `z` holding the noise to be *added* to `z`.
        """
        if leak_scale <= 0.0 or beta_max <= 0.0:
            return torch.zeros_like(z)

        B, T, D = z.shape
        device, dtype = z.device, z.dtype

        # Get hubs and per-center trust in [0,1] (pos, neg)
        centers, _ = self.find_hubs_per_token(z, num_hubs=2, win=win)  # [B,T,2]
        center_trust = self.update_bucketed_ema_and_scores(centers, D)  # [B,T,2]

        # if self.layer_num == 22:
        #    print(center_trust.mean())

        z_flat = z.view(B * T, D)
        centers_flat = centers.view(B * T, 2)
        beta = beta_max * leak_scale

        # Distances to pos/neg centers (circular)
        idx = torch.arange(D, device=device).unsqueeze(0)  # [1,D]
        pos_c = centers_flat[:, 0] % D
        neg_c = centers_flat[:, 1] % D
        diff_pos = (idx - pos_c.unsqueeze(1)).abs()
        diff_neg = (idx - neg_c.unsqueeze(1)).abs()
        pos_dist = torch.minimum(diff_pos, D - diff_pos)
        neg_dist = torch.minimum(diff_neg, D - diff_neg)

        # Sign masks (elements equal to 0 belong to neither)
        is_pos = (z_flat > 0)
        is_neg = (z_flat < 0)

        # Inside/outside the sign-matched hub; the three masks are mutually exclusive
        pos_inside = is_pos & (pos_dist <= radius)
        neg_inside = is_neg & (neg_dist <= radius)
        outside = (is_pos & (pos_dist > radius)) | (is_neg & (neg_dist > radius))

        # Optional one-shot diagnostics: runs at most once per module because
        # has_printed is set at the end of this branch.
        if track_energy and D < 1500 and not self.has_printed:
            #print(z_flat.shape)
            portions = self.hub_concentration_masked(z_flat, pos_inside, neg_inside, outside)
            avg_portion = self.average_per_token(portions)
            self.hub_energy += avg_portion

            #wa = self.wa_alignment_downproj(z_flat, self.wo, center=True, projector="orth", reduce="mean")["summary"]

            #self.wa_align += wa

            # Despite the attribute name, this accumulates the 1-D Moran's I.
            cos1d_avg = self.morans_1d(z_flat)
            self.cos_sim1d += cos1d_avg

            self.hub_e_counter += 1.0

            # Always true here (already checked by the enclosing condition).
            if not self.has_printed:
                print("Printing Now")
                self.print_metrics()
                self.has_printed = True


        # Build per-position scale:
        # - outside hub: scale = 1
        # - inside hub:  scale = trust (pos or neg), in [0,1]
        # - exactly zero elements keep scale = 0
        trust_flat = center_trust.view(B * T, 2).to(dtype)
        pos_scale_row = trust_flat[:, 0].unsqueeze(1)  # [N,1]
        neg_scale_row = trust_flat[:, 1].unsqueeze(1)  # [N,1]

        scale = torch.zeros_like(z_flat, dtype=dtype)
        scale = torch.where(outside, torch.ones_like(scale), scale)
        scale = torch.where(pos_inside, pos_scale_row.expand_as(z_flat), scale)
        scale = torch.where(neg_inside, neg_scale_row.expand_as(z_flat), scale)

        # Skip the random draw entirely when no element would receive noise
        if not torch.any(scale > 0):
            return torch.zeros_like(z)

        noise = torch.empty_like(z_flat).uniform_(-beta, beta)
        noise = noise * z_flat * scale  # multiplicative; scale∈[0,1], outside=1, inside=trust

        # Clamp to avoid sign flips: z + noise never crosses zero
        # (positive z: noise >= -|z|; otherwise: noise <= |z|)
        noise = torch.where(
            z_flat > 0,
            noise.clamp(min=-z_flat.abs()),
            noise.clamp(max=z_flat.abs())
        )

        return noise.view(B, T, D)

    def forgetful_activation(self, x):
        """Piecewise curve: ``log2(x) + 1`` for x >= 1, ``2 ** (x - 1)`` below.

        Inputs are floored at 1e-6 first, so non-positive values all map to
        roughly 0.5.  The two branches meet at x = 1 with value 1.
        """
        floored = torch.clamp(x, min=1e-6)
        log_branch = torch.log2(floored) + 1
        exp_branch = torch.exp2(floored - 1)
        return torch.where(floored >= 1, log_branch, exp_branch)

    def old_forgetful_activation(self, x, beta: float = 12.0, eps: float = 1e-6):
        """Smooth variant of the log2 / exp2 piecewise curve.

        Instead of a hard switch at x = 1, the two branches are mixed with a
        sigmoid of steepness `beta` centred on 1.  `eps` floors the argument of
        the logarithmic branch.
        """
        # Mixing weight: close to 0 well below 1, close to 1 well above 1.
        blend = torch.sigmoid(beta * (x - 1.0))

        # log2(x) + 1, written with log1p
        above_one = torch.log1p(torch.clamp_min(x, eps) - 1.0) / self.LN2 + 1.0

        # 2 ** (x - 1)
        below_one = torch.exp((x - 1.0) * self.LN2)

        return blend * above_one + (1.0 - blend) * below_one

    # calculates the relationships between weights at each layer
    def pairwise_activation(self, hidden_states, gate, w_val):
        """
        Per-pair contributions ``gate[b,t,i] * hidden_states[b,t,m] * w_val[i,m]``.

        hidden_states : (B, T, M)
        gate          : (B, T, I)   multiplied in as given (no extra activation)
        w_val         : (I, M)
        returns       : (B, T, I, M)
        """
        # Outer product of every input feature with every weight row, then
        # scaled in place by the gate of the matching output unit.
        pairwise = torch.einsum('btm,im->btim', hidden_states, w_val)
        return pairwise.mul_(gate.unsqueeze(-1))

    @torch.no_grad()
    def _recall_core(self, z: torch.Tensor, k_conv: torch.Tensor,
                     alpha: float, center: int) -> torch.Tensor:
        """Core recall logic, separated for better compilation."""
        if alpha <= 0.0:
            return torch.zeros_like(z)

        # Direct computation, no intermediate variables
        return alpha * (
                F.conv1d(
                    F.pad(z.reshape(-1, 1, z.shape[-1]), (center, center), 'circular'),
                    k_conv
                ).view_as(z) - z
        )

    def apply_recall_1d(self, z: torch.Tensor, kernel: torch.Tensor,
                        alpha: float) -> torch.Tensor:
        """Wrapper that handles kernel caching."""
        K = kernel.numel()
        key = (K, z.dtype, z.device.type, z.device.index)

        if key not in self._kernel_cache:
            self._kernel_cache[key] = kernel.to(
                z.device, z.dtype
            ).reshape(1, 1, -1).contiguous()

        return self._recall_core(z, self._kernel_cache[key], alpha, K // 2)

    def scramble_noise(self, t: torch.Tensor, magnitude: float, *, generator: torch.Generator | None = None) -> torch.Tensor:
        """
        Apply element-wise multiplicative noise to a tensor of shape [B, S, T, D].
        magnitude=m -> factor ~ Uniform(max(0, 1-m), 1+m).  (Clamp to 0 to avoid sign flips.)
        Preserves device/dtype. Gradients flow through the multiplication.

        Args:
            t: Input tensor [B, S, T, D] (any shape works; this is element-wise).
            magnitude: Non-negative scalar. e.g., 0.4 -> factors in [0.6, 1.4].
            generator: Optional torch.Generator for reproducibility.

        Returns:
            Tensor of same shape as t with multiplicative noise applied.
        """
        if magnitude < 0:
            raise ValueError("magnitude must be >= 0")

        low = max(0.0, 1.0 - magnitude)  # clamp lower bound at 0.0
        high = 1.0 + magnitude

        # Sample factors ~ U(low, high), element-wise
        noise = torch.empty_like(t, dtype=t.dtype, device=t.device).uniform_(low, high, generator=generator)
        # (low is already clamped, but keep this for numerical safety if low>high edge cases ever arise)
        noise = noise.clamp_min(0.0)

        return t * noise

    # Class-level attribute: evaluated once, when the class body is executed, so it
    # holds whatever LEAK.tau_mod was at that moment and does not follow later
    # changes to LEAK.  NOTE(review): nothing in the surrounding methods reads
    # it - confirm it is still needed.
    start_tau = LEAK.tau_mod

    def forward(self, hidden_states):
        """Gated feed-forward pass with optional leak, blur, magnetism noise and uniform noise.

        Order of operations:
          1. Bookkeeping: count the pass and periodically decay the variance settings.
          2. If ``LEAK.safe_to_forget[layer_num]`` is set, apply ``leak_once_uniform``
             to the three weight matrices once and clear the flag.
          3. Optional blur of the input (``recall_alpha[0]``).
          4. Up-projection into value / gate, each optionally blurred
             (``recall_alpha[2]`` / ``recall_alpha[1]``).
          5. ``value * activation(gate)``.
          6. Magnetism noise on the FFN-width activations when
             ``(LEAK.num_batches + layer_num)`` is even, optional uniform noise.
          7. Dropout and down-projection.
          8. Magnetism noise on the output when ``(LEAK.num_batches + layer_num)``
             is odd, optional blur (``recall_alpha[3]``), optional uniform noise.

        Every perturbation is computed under ``no_grad`` and added as a detached
        delta, so gradients flow only through the unperturbed path.

        Args:
            hidden_states: Input to the feed-forward block.

        Returns:
            The (possibly perturbed) output of the down-projection.
        """
        self.forward_passes += 1.0

        # Every `memory_boost` passes: restart the counter and shrink the
        # variance settings (x0.9 / x0.95 / x0.95).
        # @todo should work more cleanly once nrem is set up to reinforce memory
        if (self.forward_passes + 1.0) % self.memory_boost == 1.0:
            self.forward_passes = 2.0
            self.variance = self.variance * 0.9
            self.long_var = self.long_var * 0.95
            self.default_std = self.default_std * 0.95

        # Debug dump of the model-width EMA for one layer at one batch index.
        if LEAK.num_batches == 4 and self.layer_num == 24:
            for i in self.ema_model_pos:
                print(i)
            print("end\n")

        # Weight leak: runs once each time this layer's flag is raised.
        if LEAK.safe_to_forget[self.layer_num]:
            with torch.no_grad():
                mod_t = 1.0  # (float(LEAK.start_tau - LEAK.tau_mod)**1.2 / float(LEAK.start_tau)**1.2) #@ todo, more stable stall while recall still ramps
                leak_once_uniform(self.wi_0.weight, self.kernel7, LEAK.alpha_val * mod_t)
                leak_once_uniform(self.wi_1.weight, self.kernel7, LEAK.alpha_gate * mod_t)
                leak_once_uniform(self.wo.weight, self.kernel7, LEAK.alpha_down * mod_t)

                # Clear the same flag that was tested above; writing to a bare
                # `safe_to_forget` name would leave LEAK's flag set.
                LEAK.safe_to_forget[self.layer_num] = False

        recall_alpha = LEAK.recall_alpha  # [entry, up_gate, up_val, down]
        k11 = self.recall_kernel11  # moved to device/dtype inside helper

        # Blur on the block input
        if LEAK.apply_blur and recall_alpha[0] > 0.0:
            with torch.no_grad():
                hs_blend = self.apply_recall_1d(hidden_states, k11, _alpha_eff(recall_alpha[0]))

            hidden_states = hidden_states + hs_blend.detach()

        # Perform up projection
        value = self.wi_0(hidden_states)
        gate = self.wi_1(hidden_states)

        # Blur on the up-projections, gate and value handled separately
        if LEAK.apply_blur and recall_alpha[1] > 0.0:
            with torch.no_grad():
                g_blend = self.apply_recall_1d(gate, k11, _alpha_eff(recall_alpha[1]))

            gate = gate + g_blend.detach()

        if LEAK.apply_blur and recall_alpha[2] > 0.0:
            with torch.no_grad():
                v_blend = self.apply_recall_1d(value, k11, _alpha_eff(recall_alpha[2]))

            value = value + v_blend.detach()

        # elementwise multiplication
        x = value * self.activation(gate)

        mag_scale = _alpha_eff(1.0)

        # Magnetism noise after merging value and gate, to make cancellations
        # less likely (cancellations risk ignoring the ffn).  Strong noise
        # outside hubs, weak noise inside.
        if LEAK.apply_recall and ((LEAK.num_batches + self.layer_num) % 2 == 0) and LEAK.max_variance > 0.0:
            with torch.no_grad():
                x_mag = self.scramble_magnetism_noise(x, leak_scale=mag_scale, beta_max=LEAK.max_variance, radius=40,
                                                      win=15)
            x = x + x_mag.detach()

        # noise on all areas evenly for robustness testing
        if LEAK.apply_noise:
            with torch.no_grad():
                y_noise = self.scramble_noise(x, magnitude=LEAK.noise_strength)
            x = x + (y_noise - x).detach()

        # dropout and downprojection
        x = self.dropout(x)
        x = self.wo(x)

        # Magnetism noise on the layer output (alternates with the FFN-width
        # one above, depending on batch/layer parity)
        if LEAK.apply_recall and ((LEAK.num_batches + self.layer_num) % 2 == 1) and LEAK.max_variance > 0.0:
            with torch.no_grad():
                y_mag = self.scramble_magnetism_noise(x, leak_scale=mag_scale, beta_max=LEAK.max_variance, radius=16,
                                                      win=9)
            x = x + y_mag.detach()

        # Blur on the layer output
        if LEAK.apply_blur and recall_alpha[3] > 0.0:
            with torch.no_grad():
                x_blend = self.apply_recall_1d(x, k11, _alpha_eff(recall_alpha[3]))

            x = x + x_blend.detach()

        if LEAK.apply_noise:
            with torch.no_grad():
                x_noise = self.scramble_noise(x, magnitude=LEAK.noise_strength)
            x = x + (x_noise - x).detach()

        return x


# forgetful t5
class ForgetfulT5(nn.Module):
    def __init__(self, base_model, eps=1e-5):
        """
                Wraps the base MT5 model and replaces its feed-forward modules with a custom MT5 dense gated
                variant that incorporates your custom normalization.

                Args:
                    base_model: A pretrained MT5 model (e.g., loaded via AutoModelForConditionalGeneration).
                    eps: Epsilon for numerical stability.
                """
        super().__init__()
        self.base_model = base_model
        self.eps = eps

        layer_num = 0

        # print(layer_num)
        # Modify encoder blocks (typically, feed-forward is located in layer[1]).
        for block in self.base_model.encoder.block:
            orig_ff = block.layer[1].DenseReluDense
            block.layer[1].DenseReluDense = CustomDenseReluDense(orig_ff, layer_num, eps=self.eps)
            layer_num += 1

        # @todo implement for cross attention as well
        # @todo is the memory loss enough for the number of layers?
        # Modify decoder blocks (for MT5 with cross-attention, feed-forward is usually in layer[2]).
        for block in self.base_model.decoder.block:
            if len(block.layer) >= 3:
                orig_ff = block.layer[2].DenseReluDense
                block.layer[2].DenseReluDense = CustomDenseReluDense(orig_ff, layer_num, eps=self.eps)
                # print("Layer ", layer_num)
            else:
                # Fallback (shouldn't be needed for MT5)
                orig_ff = block.layer[1].DenseReluDense
                block.layer[1].DenseReluDense = CustomDenseReluDense(orig_ff, layer_num, eps=self.eps)

            layer_num += 1

        print("Custom feed-forward modules have been applied to the MT5 model.")
        print("Total Layers: ", layer_num)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        # Delegate the forward pass to the underlying MT5 model.
        return self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs
        )

    def generate(self, *args, **kwargs):
        return self.base_model.generate(*args, **kwargs)

    def _get_ffn(self, layer_num: int):
        """Return the FFN module at global layer_num without storing an array."""
        enc_blocks = self.base_model.encoder.block
        dec_blocks = self.base_model.decoder.block
        n_enc = len(enc_blocks)

        if not (0 <= layer_num < n_enc + len(dec_blocks)):
            raise IndexError("layer_num out of range")

        if layer_num < n_enc:
            # T5 encoder: FFN is block.layer[1].DenseReluDense
            return enc_blocks[layer_num].layer[1].DenseReluDense
        else:
            i = layer_num - n_enc
            blk = dec_blocks[i]
            # T5 decoder: FFN usually at layer[2], fallback to [1] if needed
            ffn_idx = 2 if len(blk.layer) >= 3 else 1
            return blk.layer[ffn_idx].DenseReluDense

    def visualize_ffn(self, layer_num: int, which: str = "both",
                      show: bool = True, savepath: str = None):
        """
        Heatmaps for FFN weights at `layer_num`.
        which ∈ {"both","val","gate","all"}; "all" also shows wo.
        Color: blue (neg) → white (0) → red (pos), zero-centered.
        """
        ffn = self._get_ffn(layer_num)

        mats, titles = [], []
        if which in ("both", "val", "all"):
            Wv = ffn.wi_0.weight.detach().to("cpu").float().T  # (d_model x d_ff) ~ 768x2048
            mats.append(Wv);
            titles.append("wi_0 (value up)")
        if which in ("both", "gate", "all"):
            Wg = ffn.wi_1.weight.detach().to("cpu").float().T
            mats.append(Wg);
            titles.append("wi_1 (gate up)")
        if which == "all" and hasattr(ffn, "wo"):
            Wo = ffn.wo.weight.detach().to("cpu").float()  # (d_model x d_ff)
            mats.append(Wo.T);
            titles.append("wo (down)")

        vmax = max(m.abs().max().item() for m in mats)
        vmin = -vmax

        n = len(mats)
        fig, axes = plt.subplots(1, n, figsize=(6 * n, 6), constrained_layout=True)
        axes = [axes] if n == 1 else axes

        for ax, M, title in zip(axes, mats, titles):
            im = ax.imshow(M, cmap="seismic", vmin=vmin, vmax=vmax,
                           aspect="auto", interpolation="nearest")
            ax.set_title(f"Layer {layer_num}: {title}\n(768 × 2048 view)")
            ax.set_xlabel("d_ff (neurons)")
            ax.set_ylabel("d_model (features)")
            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        if savepath:
            fig.savefig(savepath, dpi=150)
        if show:
            plt.show()
        return fig

class DownBandAccumulator:
    """
    Accumulates normalized FFN output vectors per decoder layer over **valid decode positions**.

    Forward hooks are attached to `decoder.block.{L}.layer.2.DenseReluDense` for every
    requested layer. For each hooked forward pass:

    - every valid position's output vector is divided by its own mean absolute value and
      added to a running per-layer sum (with a per-layer position count);
    - for each target token id, the mean raw vector over the valid positions whose label
      equals that id is stored as one sample, until every layer holds
      `per_token_samples` samples for that id.

    Call `set_decoder_ids_and_mask` before each forward pass so the hooks know which
    positions are valid and which labels they carry.
    """

    def __init__(self, model, tokenizer, layers: list[int], target_ids: list[int], per_token_samples: int = 8):
        """
        Args:
            model: object exposing `base_model.config.d_model` and `base_model.named_modules()`.
            tokenizer: stored as `self.tok`; not used by this class itself.
            layers: decoder layer indices to hook (duplicates are ignored).
            target_ids: token ids for which per-token samples are collected.
            per_token_samples: maximum number of samples kept per (token, layer).
        """
        self.model = model
        self.tok = tokenizer
        # De-duplicate so a repeated index cannot hook the same module twice.
        self.layers = sorted(set(layers))
        self.target_ids = set(target_ids)
        self.per_token_samples = per_token_samples

        D = model.base_model.config.d_model
        self.sum_by_L = {L: torch.zeros(D, dtype=torch.float32) for L in self.layers}
        self.cnt_by_L = {L: 0 for L in self.layers}  # **counts positions**, not batches

        # collected[token_id][layer] -> list of CPU tensors of shape [D]
        self.collected = {tid: {L: [] for L in self.layers} for tid in self.target_ids}

        self._last_dec_ids: torch.Tensor | None = None     # [B,T]
        self._last_valid_mask: torch.Tensor | None = None  # [B,T] bool
        self._last_labels_dec: torch.Tensor | None = None  # [B,T]
        self._handles = []

    def set_decoder_ids_and_mask(self, dec_ids_bt: torch.Tensor, valid_mask_bt: torch.Tensor,
                                 labels_dec_bt: torch.Tensor):
        """Store decoder ids, a bool mask of valid decode positions, and decoder labels (all [B,T])."""
        self._last_dec_ids = dec_ids_bt
        self._last_valid_mask = valid_mask_bt
        self._last_labels_dec = labels_dec_bt

    def _accumulate(self, L: int, rows: torch.Tensor) -> None:
        """Add rows ([num, D]), each scaled to mean |.| = 1, to layer L's running sum and count."""
        scale = rows.abs().mean(dim=1, keepdim=True) + 1e-8  # [num,1]; eps avoids div-by-zero
        normed = rows / scale
        self.sum_by_L[L] += normed.sum(dim=0).cpu()
        self.cnt_by_L[L] += normed.size(0)

    def _mk_hook(self, L: int):
        """Build the forward hook for layer L; the hook never alters the module output."""
        def fn(module, inp, out):
            if not isinstance(out, torch.Tensor):
                return out
            y = out.detach().float()  # [B,T,D]

            # 1) Baseline accumulation over valid positions only (normalized per position)
            vm = None
            if self._last_valid_mask is not None:
                vm = self._last_valid_mask.to(y.device)  # [B,T] bool
                if vm.any():
                    self._accumulate(L, y[vm])
            else:
                # fallback (no mask supplied): treat every position as valid
                self._accumulate(L, y.reshape(-1, y.size(-1)))

            # 2) Per-token collection (kept raw here; normalized later per sample)
            if vm is not None and self._last_labels_dec is not None:
                lab = self._last_labels_dec.to(y.device)
                for tid in self.target_ids:
                    # The quota is checked across ALL layers so that every layer
                    # hooked in the same forward pass still receives its sample.
                    if all(len(self.collected[tid][Li]) >= self.per_token_samples for Li in self.layers):
                        continue
                    mask_tid = (lab == tid) & vm
                    if mask_tid.any():
                        self.collected[tid][L].append(y[mask_tid].mean(dim=0).cpu())
            return out

        return fn

    def register(self):
        """
        Attach one forward hook per requested layer.

        Any hooks left from a previous `register()` call are removed first, so
        repeated calls never stack duplicate hooks.

        Raises:
            KeyError: if a layer's FFN module cannot be found. All modules are
                resolved before any hook is attached, so no hook is left behind.
        """
        if self._handles:
            self.unregister()

        named = dict(self.model.base_model.named_modules())
        targets = []
        for L in self.layers:
            path = f"decoder.block.{L}.layer.2.DenseReluDense"
            if path not in named:
                raise KeyError(f"Cannot find module {path}")
            targets.append((L, named[path]))

        for L, mod in targets:
            self._handles.append(mod.register_forward_hook(self._mk_hook(L)))

    def unregister(self):
        """Remove all hooks (best effort) and forget their handles."""
        for h in self._handles:
            try:
                h.remove()
            except Exception:
                # Best-effort cleanup: a failing handle must not block the others.
                pass
        self._handles = []

    def build_baseline_and_token_means(self) -> tuple[dict[int, np.ndarray], dict[int, dict[int, np.ndarray]]]:
        """
        Turn the accumulated statistics into numpy arrays.

        Returns:
            baseline: {layer: mean of the per-position normalized vectors}; a layer with
                no recorded positions yields its (zero) sum divided by 1.
            token_means: {token_id: {layer: mean of that token's samples, each sample
                first scaled to mean |.| = 1}}. Layers without samples are omitted, and
                tokens with no samples in any layer are omitted entirely.
        """
        # Baseline: sum_by_L already holds the sum of (vec / mean(|vec|)).
        baseline = {}
        for L in self.layers:
            c = max(1, self.cnt_by_L[L])
            baseline[L] = (self.sum_by_L[L] / c).detach().cpu().numpy()

        # Token means: normalize each sample before averaging
        token_means = {}
        for tid in self.target_ids:
            perL = {}
            for L in self.layers:
                samples = self.collected[tid][L]  # list[Tensor [D]]
                if not samples:
                    continue
                normed = [t / (t.abs().mean() + 1e-8) for t in samples]
                perL[L] = torch.stack(normed, dim=0).mean(dim=0).detach().cpu().numpy()
            if perL:
                token_means[tid] = perL

        return baseline, token_means

class BandSteeringController:
    """
    Ratio-space hub pulsing with correct proportional scaling.

    - baseline_norm[L]: avg(|.|) ~= 1 per layer (normalized corpus baseline)
    - signatures_one_token: per-layer centers with ratio magnitudes r>0 for P/N
      (computed from normalized token_mean - normalized baseline)

    Runtime on a targeted cache step:
      g_abs   = mean_D(|y|)
      base_den = baseline_norm * g_abs

      g_delta = mean_D(|y - base_den|)              # <-- deviation scale
      pulse_norm = sum_over_centers(sign * r * α on each apply window)
      pulse_den  = pulse_norm * g_delta             # <-- denorm with deviation scale

      y_out = base_den + pulse_den
    """

    def __init__(self,
                 model,
                 baseline_norm: dict[int, np.ndarray],
                 signatures_one_token: dict[str, dict[str, list]],
                 layers: list[int],
                 win: int = 9,
                 radius: int | None = None,
                 alphas: float | dict[int, float] = 0.3,
                 step_indices: set[int] | None = None,
                 verbose: bool = False):
        import torch

        self.model = model
        self.layers = sorted(int(L) for L in layers)
        self.win = int(win)
        self.radius = int(radius if radius is not None else max(1, self.win // 2))
        self.verbose = verbose

        # Per-layer α
        if isinstance(alphas, dict):
            self.alphas = {int(k): float(v) for k, v in alphas.items()}
        else:
            self.alphas = {L: float(alphas) for L in self.layers}

        self.step_indices = set(step_indices or set())

        # Plan per layer: list of (center, ratio, sign)
        self.plan_by_L: dict[int, list[tuple[int, float, int]]] = {}
        for tagL, entry in (signatures_one_token or {}).items():
            if not (isinstance(tagL, str) and tagL.startswith("D")):
                continue
            L = int(tagL[1:])
            if L not in self.layers:
                continue
            P, Pm = entry.get("P", []) or [], entry.get("P_mag", []) or []
            N, Nm = entry.get("N", []) or [], entry.get("N_mag", []) or []
            plan = []
            for c, r in zip(P, Pm): plan.append((int(c), float(r), +1))
            for c, r in zip(N, Nm): plan.append((int(c), float(r), -1))
            if plan:
                self.plan_by_L[L] = plan

        # Baselines on CPU
        self.baseline_by_L = {
            L: torch.from_numpy(baseline_norm[L]).float()
            for L in self.layers if L in baseline_norm
        }

        self._patches: dict[int, tuple[object, callable]] = {}
        self._curr_step = -1

    def reset(self):
        self._curr_step = -1

    def register(self):
        import torch

        if self._patches:
            if self.verbose:
                print("[controller] Existing patches found; unregistering.")
            self.unregister()

        named = dict(self.model.base_model.named_modules())
        self._curr_step = -1

        # Earliest actually patched layer for step counting
        patched_layers = []
        for L in self.layers:
            path = f"decoder.block.{L}.layer.2.DenseReluDense"
            if (L in self.plan_by_L) and (self.alphas.get(L, 0.0) != 0.0) and (path in named) and (L in self.baseline_by_L):
                patched_layers.append(L)
        first_patched = min(patched_layers) if patched_layers else None

        for L in self.layers:
            plan = self.plan_by_L.get(L)
            alpha = float(self.alphas.get(L, 0.0))
            path = f"decoder.block.{L}.layer.2.DenseReluDense"
            mod = named.get(path)
            b_cpu = self.baseline_by_L.get(L)

            if not plan or alpha == 0.0 or (mod is None) or (b_cpu is None):
                if self.verbose:
                    why = ("no plan" if not plan else
                           ("alpha=0" if alpha == 0.0 else
                            ("missing module" if mod is None else "no baseline")))
                    print(f"[controller] skip L{L}: {why}")
                continue

            # Bind per-layer locals to avoid closure bleed-through
            _plan     = [(int(c), float(r), int(s)) for (c, r, s) in plan]
            _alpha    = alpha
            _rad      = self.radius
            _steps    = set(self.step_indices)
            _L        = L
            _b_cpu    = b_cpu
            _orig_fwd = mod.forward
            _first    = first_patched
            _verbose  = self.verbose

            def wrapped_forward(x, *args,
                                _plan=_plan, _alpha=_alpha, _rad=_rad,
                                _steps=_steps, _L=_L, _b_cpu=_b_cpu,
                                _orig_fwd=_orig_fwd, _first=_first,
                                _verbose=_verbose, **kwargs):
                import torch

                y = _orig_fwd(x, *args, **kwargs)  # [B,T,D]

                if not (isinstance(y, torch.Tensor) and y.dim() == 3):
                    return y
                B, T, D = y.shape
                if T != 1:
                    if _verbose:
                        print(f"[controller] L{_L} SKIP (non-cache pass T={T})")
                    return y

                # Step counting only at earliest patched layer
                if _L == _first:
                    self._curr_step += 1

                if _steps and (self._curr_step not in _steps):
                    if _verbose:
                        print(f"[controller] L{_L} SKIP (cache step {self._curr_step} not targeted)")
                    return y

                # Denormalize baseline with absolute scale (kept for diagnostics; not used to recompose)
                sel = 0
                # ---- Scale reference (diagnostics only; not used to recompose) ----
                b = _b_cpu.to(y.device, y.dtype).view(1, 1, -1)  # normalized baseline (avg|.|≈1)
                g_abs = y.abs().mean(dim=-1, keepdim=True)  # [B,1,1]
                base_den = b * g_abs  # [B,1,D]

                # ---- Uniform pulse in y-space over apply windows ----
                _, _, D = y.shape
                pulse = torch.zeros_like(y)  # [B,1,D]

                for center, ratio, sign in _plan:
                    w = int(_rad)
                    lo = max(0, center - w)
                    hi = min(D, center + w + 1)
                    if lo >= hi:
                        continue

                    # Uniform mask over [lo:hi)
                    u = torch.zeros(1, 1, D, device=y.device, dtype=y.dtype)
                    u[:, :, lo:hi] = 1.0

                    # Amount to add uniformly in that region (no zero-mean)
                    # If you want energy-consistent pulses, optionally divide by sqrt(hi - lo).
                    # scale = _alpha * float(sign) * (float(ratio) - 1.0) * g_abs / math.sqrt(hi - lo)
                    scale = _alpha * float(sign) * (float(ratio) - 1.0) * g_abs  # [B,1,1]

                    pulse = pulse + scale * u  # broadcast across batch

                # Apply the pulse directly to the FFN output
                y_out = y + pulse

                if _verbose:
                    centers = [c for (c, _, _) in _plan]
                    mean_pulse = pulse.abs().mean().item()
                    print(f"[controller] L{_L} APPLIED (step {self._curr_step}) "
                          f"centers={centers} | g_abs={g_abs.mean().item():.4f} | mean|pulse|={mean_pulse:.4f}")

                return y_out

            self._patches[L] = (mod, _orig_fwd)
            mod.forward = wrapped_forward

            if self.verbose:
                centers_str = ", ".join(str(c) for (c, _, _) in _plan)
                print(f"[controller] Patched {path} | α={alpha} | rad={self.radius} | centers=[{centers_str}]")

    def unregister(self):
        for _, (mod, orig) in self._patches.items():
            mod.forward = orig
        self._patches.clear()
        if self.verbose:
            print("[controller] Unpatched all layers.")