import os
from transformers.activations import ACT2FN

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import numpy as np
import random
import torch.nn as nn
import torch.optim as optim
from transformers import MT5ForConditionalGeneration, AutoTokenizer, MT5Tokenizer, AutoModelForSeq2SeqLM, AutoModel
from torch.optim import AdamW
import sacrebleu
from torch.utils.data import DataLoader
from datasets import load_dataset, load_from_disk
from functools import partial
from transformers import DataCollatorWithPadding
from torch.optim.lr_scheduler import StepLR
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
from torch.nn import functional as F
from functools import lru_cache
from dataclasses import dataclass, field
from itertools import groupby          # ← add this
import matplotlib.pyplot as plt

import torch, gc
import math

#from mbert_pretraining import BATCH_SIZE

#print("PyTorch version:", torch.__version__)
#print("CUDA available:", torch.cuda.is_available())
#print("CUDA device count:", torch.cuda.device_count())
#print("CUDA version:", torch.version.cuda)

import torch
import torch.fft as fft

# corruption collator
from itertools import groupby
import torch, random

@torch.no_grad()
def renorm_toward_rms(param, target_rms, max_change=0.10):
    """Nudge ``param`` in place so that its RMS moves toward ``target_rms``.

    The multiplicative correction ``target_rms / rms`` is clamped to
    ``[1 - max_change, 1 + max_change]`` so one call never rescales the tensor
    by more than ``max_change`` (10 % with the default). Tensors whose RMS is
    not positive are left untouched.
    """
    rms_now = param.pow(2).mean().sqrt()
    if rms_now <= 0:
        return
    lo, hi = 1.0 - max_change, 1.0 + max_change
    ratio = target_rms / (rms_now + 1e-8)
    param.mul_(ratio.clamp(lo, hi))

class T5SpanCorruptionCollator:
    """
    Collator for T5/mT5-style span corruption.

    • Sentinel ids are derived numerically as ``vocab_size - 1 - i`` (i = 0..99),
      so it works even if <extra_id_*> strings are not in tokenizer.special_tokens.
    • Each example's ``input_ids`` (tensor or list) is truncated to
      ``input_length``; all rows must then have the same length L because they
      are stacked into one [B, L] tensor.
    • The corruption mask is an i.i.d. Bernoulli(noise_density) draw per position
      (at least one masked position per row); every maximal run of masked
      positions becomes one span. ``mean_span_len`` is stored but not used.
    • Encoder inputs replace each span by one sentinel and are padded with
      ``pad_token_id`` back to L.
    • Decoder labels are ``<sentinel> span-tokens ... </s>``, truncated to L;
      unused label positions are filled with -100 so
      CrossEntropyLoss(ignore_index=-100) works directly.
    • Any extra example columns are passed through as Python lists.

    Raises ValueError if a row contains more masked spans than sentinel ids.
    """

    def __init__(self, tokenizer,
                 noise_density=0.15,
                 mean_span_len=3,
                 input_length=64):

        self.tok = tokenizer
        self.noise = noise_density
        self.mean = mean_span_len
        self.L = input_length

        self.pad = tokenizer.pad_token_id
        self.eos = tokenizer.eos_token_id
        # sentinel_id(i) = vocab_size - 1 - i   (T5 design choice), 100 sentinels
        self.sentinels = torch.arange(tokenizer.vocab_size - 1,
                                      tokenizer.vocab_size - 101, -1)

    # ------------------------------------------------------------------
    def _random_mask(self, B, L, device):
        """Generate a (B,L) boolean mask with ≈noise_density True values."""
        bern = torch.rand(B, L, device=device) < self.noise

        # ensure every row has at least one True
        row_mask = bern.sum(1) == 0  # rows with 0 masked
        if row_mask.any():
            rand_cols = torch.randint(0, L, (row_mask.sum(),), device=device)
            bern[row_mask, rand_cols] = True
        return bern

    # ------------------------------------------------------------------
    def __call__(self, examples):
        # as_tensor lets plain-list input_ids (un-formatted HF datasets) through;
        # it returns tensors unchanged.
        ids = torch.stack([torch.as_tensor(ex["input_ids"][: self.L]) for ex in examples])
        B, L = ids.shape
        device = ids.device
        n_sentinels = self.sentinels.numel()

        mask = self._random_mask(B, L, device)  # corruption mask, True = corrupt

        inputs = ids.clone()
        labels = torch.full_like(ids, self.pad)

        for b in range(B):
            # Run-length encode the row: [(is_masked, run_length), ...]
            spans = [(flag, len(list(run)))
                     for flag, run in groupby(mask[b].tolist())]

            # Fail loudly instead of an opaque IndexError on self.sentinels.
            n_spans = sum(1 for flag, _ in spans if flag)
            if n_spans > n_sentinels:
                raise ValueError(
                    f"row {b} has {n_spans} masked spans but only {n_sentinels} "
                    "sentinel ids are available; lower noise_density or input_length"
                )

            in_row, lb_row, cur, s_idx = [], [], 0, 0
            for is_mask, span_len in spans:
                segment = ids[b, cur:cur + span_len].tolist()
                if is_mask:
                    sent_id = int(self.sentinels[s_idx])
                    in_row.append(sent_id)   # encoder: one sentinel replaces the span
                    lb_row.append(sent_id)   # decoder: sentinel followed by the span
                    lb_row.extend(segment)
                    s_idx += 1
                else:
                    in_row.extend(segment)
                cur += span_len
            lb_row.append(self.eos)

            # Encoder row can only shrink (spans collapse to one token): pad back to L.
            if len(in_row) < L:
                in_row.extend([self.pad] * (L - len(in_row)))
            inputs[b] = torch.tensor(in_row[:L], device=device)

            # Decoder row: truncate to L (may drop the trailing EOS), ignore the rest.
            lb_row = lb_row[:L]
            labels[b, : len(lb_row)] = torch.tensor(lb_row, device=device)
            labels[b, len(lb_row):] = -100  # ignored by the loss

        batch = {
            "input_ids": inputs,
            "attention_mask": (inputs != self.pad).long(),
            "labels": labels,
        }

        # pass‑through any extra columns untouched (k_alts, k_confs, …)
        for k in examples[0]:
            if k not in batch:
                batch[k] = [ex[k] for ex in examples]

        return batch


@torch.no_grad()
def leak_once_uniform(wei: torch.Tensor,
                      K7:  torch.Tensor,
                      frac: float = 0.01,
                      max_frac: float = 2.0):
    """
    Leak every entry of the 2-D tensor ``wei`` toward its kernel-weighted
    neighbourhood average, in place, with wrap-around (circular) borders.

    Args:
        wei: 2-D weight matrix; modified in place.
        K7: odd-sized 2-D kernel, given as ``[kh, kw]`` or ``[1, 1, kh, kw]``;
            it is cast to ``wei``'s dtype and device.
        frac: fraction of the gap ``(neighbour_avg - wei)`` applied per call.
        max_frac: per-entry cap on the change relative to ``|wei|``
            (2.0 = at most 200 % of the entry's magnitude; entries that are
            exactly 0 therefore never change).

    After the update the whole tensor is rescaled toward its pre-leak RMS
    (scale factor clamped to [0.1, 10]) and NaN/±inf are replaced by 0/±1e4.
    """
    K7 = K7.to(dtype=wei.dtype, device=wei.device)
    kh, kw = K7.shape[-2], K7.shape[-1]
    K7 = K7.reshape(1, 1, kh, kw)  # accept a bare [kh, kw] kernel too

    # --- measure pre-leak scale ---
    pre_rms = wei.pow(2).mean().sqrt()

    # F.pad order for the last two dims: (left, right, top, bottom)
    inp = F.pad(wei.unsqueeze(0).unsqueeze(0),
                pad=(kw // 2, kw // 2, kh // 2, kh // 2), mode="circular")

    neigh_avg = F.conv2d(inp, K7).squeeze_(0).squeeze_(0)

    delta = frac * (neigh_avg - wei)
    # cap each entry's change at max_frac * |w| (default 2.0 -> 200 %)
    limit = max_frac * wei.abs()
    delta = torch.clamp(delta, -limit, limit)

    wei.add_(delta)

    # --- rescale toward the pre-leak RMS (bounded to avoid blow-ups) ---
    post_rms = wei.pow(2).mean().sqrt()
    scale = (pre_rms / (post_rms + 1e-8)).clamp(0.1, 10.0)
    wei.mul_(scale)

    wei.nan_to_num_(nan=0.0, posinf=1e4, neginf=-1e4)

# ── 1 · global hyper-parameters ────────────────────────────────────────────
@dataclass
class LeakParams:
    """Global hyper-parameters for weight leaking and activation recall / noise."""

    # Per-call leak strengths for the value (wi_0), gate (wi_1) and down (wo) weights.
    alpha_val: float = 0.030
    alpha_gate: float = 0.030
    alpha_down: float = 0.022
    # Weaker leak strengths, described as maintaining order.
    a_val_dorm: float = 0.007
    a_gate_dorm: float = 0.007
    sigma0: float = 2.0
    # Kernel decay length: lower values prefer close neighbours over far ones.
    lam: float = 4.0
    # Scheduling.
    tau: int = 50          # batches between leaks
    tau_mod: int = 0
    start_tau: int = 241
    # Recall strengths, ordered [entry, up_gate, up_val, down].
    recall_alpha: list = field(default_factory=lambda: [0.15, 0.3, 0.3, 0.3])
    use_fft: bool = True
    # Noise amplitude for magnetism scrambling. Logged BLEU for earlier values:
    # 2.0 -> 78.1, 2.5 -> 76.5, 3.0 -> 70.5, 5.75 -> 56.95, 9.0 -> 14.0
    max_variance: float = 11.0


# Blend enabled?   No          Yes        90%
# Base Model:     78.5 BLEU    25.0 BLEU
# Whole Stream 6: 78.5 bleu    78.7 BLEU  76.1 BLEU

"""
VARIANCE 1.7999999999999998
Batch 1000 of 50000
Loss: 0.279685721218586
Gold loss: 0.4591213328242302
BLEU score: 78.52
Perplexity: 1.0001150089654989
Overall: 78.66
"""

# Module-wide hyper-parameter instance; read by the kernel helpers and by
# CustomDenseReluDense (leak strengths, lam, tau_mod/start_tau, max_variance).
LEAK = LeakParams()


# 2d kernel for weight operations
def make_kernel(size: int, lam: float, centre_zero: bool = True,
                dtype=None, device=None) -> torch.Tensor:
    """
    Isotropic 2-D exponential-decay kernel, normalised to sum to 1.

    K[y, x] ∝ exp(-sqrt(x² + y²) / lam) on a ``size × size`` grid centred at 0.

    Args:
        size: side length; must be odd (e.g. 7, 9, 11).
        lam: decay length (> 0); smaller values concentrate weight near the centre.
        centre_zero: if True, the centre (self) weight is zeroed before normalising.
        dtype, device: optional dtype/device of the result (default: torch's
            default float dtype, on CPU).

    Raises:
        ValueError: if ``size`` is even, ``lam <= 0``, or the kernel would sum to
            zero (``size == 1`` with ``centre_zero``), which previously yielded NaNs.
    """
    if size % 2 != 1:
        raise ValueError("size must be odd")
    if lam <= 0:
        raise ValueError("lam must be positive")
    if size == 1 and centre_zero:
        raise ValueError("size 1 with centre_zero=True leaves no weight to normalise")
    if dtype is None:
        dtype = torch.get_default_dtype()

    half = size // 2                                   # 3 for 7×7, 4 for 9×9 …
    offsets = torch.arange(-half, half + 1, device=device)
    yy, xx = torch.meshgrid(offsets, offsets, indexing='ij')  # centred grid
    R = torch.sqrt((xx**2 + yy**2).to(dtype))          # Euclidean distance
    K = torch.exp(-R / lam)                            # exponential decay
    if centre_zero:
        K[half, half] = 0.0                            # remove self-weight
    K /= K.sum()                                       # normalise
    return K


# makes a 1d neighborhood kernel
def make_blend_kernel(K: int = 11, sigma: float = 2.0, dtype=None, device=None,
                      neighbor_only: bool = True) -> torch.Tensor:
    """
    1-D Gaussian neighbourhood kernel of odd length ``K`` that sums to 1.

    Args:
        K: kernel length; odd and >= 3.
        sigma: Gaussian width in taps (> 0).
        dtype, device: forwarded to ``torch.arange``.
        neighbor_only: if True (default), the centre tap is zeroed and the
            remaining taps renormalised, so the kernel averages neighbours only.

    Raises:
        ValueError: for invalid ``K``/``sigma``, or when the neighbour taps
            underflow to zero (very small ``sigma``).
    """
    if K % 2 != 1 or K < 3:
        raise ValueError("K must be odd and >=3")
    if sigma <= 0:
        raise ValueError("sigma must be positive")
    c = K // 2

    x = torch.arange(-c, c + 1, dtype=dtype, device=device)
    w = torch.exp(-0.5 * (x / float(sigma)) ** 2)   # centre tap is exp(0) = 1

    if neighbor_only:
        w[c] = 0
    s = w.sum()
    if float(s) <= 1e-12:
        raise ValueError("neighbor_only produced zero-sum kernel; choose larger K or different sigma.")
    return w / s


def get_recall_mask(hidden_states: torch.Tensor) -> torch.Tensor:
    """
    Choose which token positions to blend, as a bool mask.

    Lookup order:
      1) ``LEAK.recall_mask`` if it is set to a tensor,
      2) otherwise ``LEAK.decoder_out_mask`` if it is set to a tensor,
      3) otherwise a [B, T] mask selecting only the last timestep of each sequence.
    External masks are moved to ``hidden_states``' device and cast to bool;
    they are not reshaped or validated.
    """
    batch, steps = hidden_states.shape[:2]
    dev = hidden_states.device

    for attr_name in ("recall_mask", "decoder_out_mask"):
        candidate = getattr(LEAK, attr_name, None)
        if isinstance(candidate, torch.Tensor):
            return candidate.to(dev).bool()

    last_only = torch.zeros((batch, steps), dtype=torch.bool, device=dev)
    last_only[:, -1] = True
    return last_only




def _alpha_eff(a):
    """Return ``a`` scaled by ``(LEAK.start_tau - LEAK.tau_mod) / LEAK.start_tau``, as a float."""
    remaining = float(LEAK.start_tau - LEAK.tau_mod)
    return float(a) * (remaining / float(LEAK.start_tau))


def make_slanted_kernels(K: int = 9, decay: float = 0.6, *, dtype=None, device=None):
    """
    Build two one-sided 1-D kernels of length ``K`` (centre tap 0, taps sum to 1).

    Tap weights fall off geometrically with distance (``decay ** d`` for d = 1..K//2,
    normalised), so the nearest neighbour gets the largest weight.

    Returns ``(k_left, k_right)``:
      - ``k_left`` has weight only on the right half (+1, +2, …): pulls mass from
        right neighbours (moves left);
      - ``k_right`` has weight only on the left half (…, -2, -1): pulls mass from
        left neighbours (moves right).
    """
    assert K % 2 == 1 and K >= 3
    half = K // 2
    distances = torch.arange(1, half + 1, dtype=dtype, device=device)
    taps = decay ** distances
    taps = taps / (taps.sum() + 1e-12)

    k_left = torch.zeros(K, dtype=dtype, device=device)
    k_right = torch.zeros(K, dtype=dtype, device=device)
    k_left[half + 1:] = taps          # offset +1 gets the largest weight
    k_right[:half] = taps.flip(0)     # offset -1 gets the largest weight
    return k_left, k_right


def _circular_conv1d(x, kernel):
    """x: [N, 1, D], kernel: [1, 1, K], circular padding."""
    K = kernel.shape[-1]
    r = K // 2
    xpad = F.pad(x, (r, r), mode='circular')
    return F.conv1d(xpad, kernel)


def _circular_box_smooth(abs_z, win: int):
    """
    abs_z: [N, 1, D] → smoothed |z| with length-'win' box kernel (circular).
    """
    box = torch.ones(1, 1, win, device=abs_z.device, dtype=abs_z.dtype) / float(win)
    return _circular_conv1d(abs_z, box)


# ---- ring distance cache (per D × device) ----
_RING_DIFF_CACHE = {}  # (D, dev_type, dev_index) -> [D,D] int16

# ---- global caches ----
_RING_DIFF_CACHE = {}
_AMP_LUT_CACHE = {}   # key: (D, dev.type, dev.index, dtype, p_dist, dead_zone)

def get_ring_signed_diff(D: int, device):
    """
    Return a cached [D, D] table T with T[c, j] = (c - j) wrapped onto a ring of
    length D, i.e. the signed shortest distance from position j to centre c.

    Values lie in [-(D // 2), D - 1 - D // 2] (for even D: [-D/2, D/2 - 1]).
    Stored as int16 when that fits (D <= 65535), otherwise int32, to save
    bandwidth; callers cast at use. Tables are cached in ``_RING_DIFF_CACHE``.
    """
    dev = torch.device(device)
    if dev.type == "cuda" and dev.index is None:
        # Resolve bare "cuda" to the concrete device the tensor will live on;
        # otherwise the cached table (on cuda:N) never matched torch.device("cuda")
        # and was rebuilt on every call.
        dev = torch.device("cuda", torch.cuda.current_device())
    key = (D, dev.type, dev.index if dev.type == "cuda" else -1)
    tbl = _RING_DIFF_CACHE.get(key)
    if tbl is not None:
        return tbl

    pos = torch.arange(D, device=dev, dtype=torch.int32)      # [D]
    centers = pos.view(D, 1)                                   # [D,1]
    diff = ((centers - pos + D // 2) % D) - D // 2             # [D, D]
    # |values| <= D // 2, so int16 is safe up to D // 2 == 32767.
    store_dtype = torch.int16 if D // 2 <= torch.iinfo(torch.int16).max else torch.int32
    tbl = diff.to(store_dtype).contiguous()
    _RING_DIFF_CACHE[key] = tbl
    return tbl

def get_amp_luts(D: int, device, dtype, p_dist: float):
    """
    Precompute distance -> amplitude lookup tables for the SAME- and
    OPPOSITE-sign cases, indexed by |d| = 0 .. D // 2:

        nd       = clamp(|d| / (D / 6), 0, 1)
        amp_same = nd ** p_dist
        amp_opp  = (1 - nd) ** p_dist

    Returns ``(same, opp)``, two 1-D tensors of length D // 2 + 1 in ``dtype``
    on ``device``; results are cached in ``_AMP_LUT_CACHE``.
    """
    dev = torch.device(device)
    if dev.type == "cuda" and dev.index is None:
        # Resolve bare "cuda" so the cache key matches the tensors' real device;
        # previously the device check failed for "cuda" and every call rebuilt the LUTs.
        dev = torch.device("cuda", torch.cuda.current_device())
    key = (D, dev.type, dev.index if dev.type == "cuda" else -1, str(dtype), float(p_dist))
    cached = _AMP_LUT_CACHE.get(key)
    if cached is not None:
        return cached

    maxd = D // 2
    dist = torch.arange(0, maxd + 1, device=dev, dtype=torch.float32)
    nd = (dist / (D / 6.0)).clamp_(0.0, 1.0)   # saturates at 1 beyond |d| = D/6
    amp_same = nd.pow(p_dist)
    amp_opp = (1.0 - nd).pow(p_dist)

    same = amp_same.to(dtype).contiguous()
    opp = amp_opp.to(dtype).contiguous()
    _AMP_LUT_CACHE[key] = (same, opp)
    return same, opp


device = "cuda"

#get_ring_signed_diff(768, device)
#get_ring_signed_diff(2048, device)

# At the top of your file, outside any class:
_GLOBAL_HUB_KERNELS = {}
_GLOBAL_DIST_CACHE = {}

def _ema_decay_from_halflife(halflife_steps: float) -> float:
    # decay = exp(-ln(2)/halflife)
    return math.exp(-math.log(2.0) / max(1e-6, float(halflife_steps)))

class CustomDenseReluDense(nn.Module):
    """
    Drop-in replacement for an MT5 gated feed-forward block (wi_0 / wi_1 / wo).

    Computes ``wo(dropout(wi_0(h) * gelu_new(wi_1(h))))`` and additionally
      * leaks the three projection weights toward their local neighbourhood
        average (``leak_once_uniform`` with ``kernel7``) whenever the global
        ``safe_to_forget[layer_num]`` flag is set, then clears the flag;
      * adds sign-preserving, hub-aware multiplicative noise either to the merged
        FFN activations (when ``num_batches + layer_num`` is even) or to the layer
        output (when it is odd).

    Uses module-level names from this file (``LEAK``, ``leak_once_uniform``,
    ``make_kernel``, ``make_blend_kernel``, ``_alpha_eff``,
    ``_ema_decay_from_halflife``) and the globals ``num_batches`` /
    ``safe_to_forget``. NOTE(review): the latter two are expected to be defined
    by the training script — confirm.
    """

    # Snapshot of LEAK.tau_mod taken at class-definition time (not read in this class).
    start_tau = LEAK.tau_mod

    def __init__(self, orig_module, layer_num, variance=0.0010, long_var=0.0010, default_std=2.5,
                 forget_factor=0.998, long_forget_factor=0.9999, eps=1e-5, ema_halflife=600.0):
        """
        Wrap the original gated DenseReluDense module, reusing its projections and dropout.

        Args:
            orig_module: the original T5/MT5 feed-forward module; must expose
                ``wi_0``, ``wi_1``, ``wo`` (linear layers) and ``dropout``.
            layer_num: index of this layer; selects the ``safe_to_forget`` flag
                and the even/odd noise schedule.
            variance, long_var, default_std: noise scales, decayed every
                ``memory_boost`` forward passes.
            forget_factor, long_forget_factor: stored forgetting factors.
            eps: small epsilon stored for numerical stability.
            ema_halflife: half-life, in EMA updates, of the hub-density EMAs.
        """
        super().__init__()
        self.eps = eps
        # MT5 gated FFN: wi_0 (value) and wi_1 (gate) map d_model -> d_ff; wo maps d_ff -> d_model.
        self.wi_0 = orig_module.wi_0
        self.wi_1 = orig_module.wi_1
        self.wo = orig_module.wo
        self.forward_passes = 0.0
        self.dropout = orig_module.dropout
        self.activation = ACT2FN["gelu_new"]

        # RMS of each projection at wrap time.
        with torch.no_grad():
            self._rms_target = {
                "wi0": self.wi_0.weight.pow(2).mean().sqrt().item(),
                "wi1": self.wi_1.weight.pow(2).mean().sqrt().item(),
                "wo": self.wo.weight.pow(2).mean().sqrt().item(),
            }

        print(self._rms_target)

        self.layer_num = int(layer_num)

        # Noise scales and forgetting factors.
        self.variance = variance
        self.long_var = long_var
        self.forget_factor = forget_factor
        self.long_forget_factor = long_forget_factor
        self.default_std = default_std
        self.base_intensity = 2.5
        self.memory_boost = 200000.0   # forward passes between noise-scale decays
        self.LN2 = math.log(2.0)

        # (K, dtype, device type, device index, kernel data_ptr) -> [1, 1, K] kernel
        self._kernel_cache = {}

        # 2-D leak kernels (centre weight removed, normalised to sum 1).
        self.register_buffer(
            "kernel11",
            make_kernel(11, LEAK.lam, centre_zero=True).unsqueeze(0).unsqueeze(0),
            persistent=False
        )

        self.register_buffer(
            "kernel7",
            make_kernel(7, (LEAK.lam/8.0), centre_zero=True).unsqueeze(0).unsqueeze(0),
            persistent=False
        )

        # 1-D neighbour-only Gaussian (length 9) for activation recall. As a buffer it
        # follows module.to(); apply_recall_1d additionally casts it to the input dtype.
        self.register_buffer(
            "recall_kernel11",
            make_blend_kernel(K=9, sigma=2.5, dtype=torch.float32)
        )

        # Widths from the wrapped projection instead of mt5-base's hard-coded
        # 2048 / 768, so other model sizes do not hit the ValueError in
        # update_bucketed_ema_and_scores.
        self.d_ff = self.wi_0.out_features
        self.d_model = self.wi_0.in_features

        # bucket size is fixed at 4
        self.bucket_size = 4
        self.buckets_model = self.d_model // self.bucket_size
        self.buckets_ffn = self.d_ff // self.bucket_size

        # EMA decay
        self.ema_decay = _ema_decay_from_halflife(ema_halflife)

        # Per-layer EMAs (pos/neg) for model- and FFN-width hub densities (bucketed)
        self.register_buffer("ema_model_pos", torch.zeros(self.buckets_model))
        self.register_buffer("ema_model_neg", torch.zeros(self.buckets_model))
        self.register_buffer("ema_ffn_pos", torch.zeros(self.buckets_ffn))
        self.register_buffer("ema_ffn_neg", torch.zeros(self.buckets_ffn))

        if layer_num == 23:
            print(layer_num)
            print(self.kernel7)
            print("\nWide Kernel")
            print(self.kernel11)
            print("\nRecall Kernel")
            print(self.recall_kernel11)

    @torch.no_grad()
    def find_hubs_per_token(self, z: torch.Tensor, num_hubs: int = 2,
                            min_sep: int = 0, win: int = 9):
        """
        Locate, per token, the strongest positive and strongest negative hub in ``z``.

        Positive and negative parts are box-smoothed separately (length ``win``,
        circular, so they do not cancel) and evaluated at every 4th feature; the
        argmax of each is returned.

        Args:
            z: [B, T, D] activations.
            num_hubs, min_sep: accepted for API compatibility; currently unused
                (exactly one positive and one negative hub are returned).
            win: odd smoothing window length.

        Returns:
            centers: [B, T, 2] int64 (pos, neg) hub positions in ``z``'s feature
                coordinates, i.e. in ``[0, D)``.
            strengths: [B, T, 2] smoothed magnitudes at those hubs.
        """
        assert win % 2 == 1, "win must be odd"
        B, T, D = z.shape
        N = B * T
        x = z.reshape(N, 1, D)

        # cache a 1×1×win box kernel on the right device/dtype
        key = (win, z.dtype, z.device.index if z.device.type == 'cuda' else -1)
        if not hasattr(self, '_box_k'): self._box_k = {}
        if key not in self._box_k:
            self._box_k[key] = torch.ones(1, 1, win, device=z.device, dtype=z.dtype) / float(win)
        k = self._box_k[key]
        r = win // 2
        stride = 4

        # smooth positives and negatives separately (no cancellation)
        pos = F.relu(x)
        neg = F.relu(-x)

        sm_pos = F.conv1d(F.pad(pos, (r, r), mode='circular'), k, stride=stride).squeeze(1)  # [N, ceil(D/4)]
        sm_neg = F.conv1d(F.pad(neg, (r, r), mode='circular'), k, stride=stride).squeeze(1)  # [N, ceil(D/4)]

        pos_val, pos_idx = sm_pos.max(dim=-1)
        neg_val, neg_idx = sm_neg.max(dim=-1)

        # Output j of the strided conv is the window centred on feature j * stride.
        # Convert back to feature coordinates: callers bucket these as raw
        # positions and measure distances to them in feature units. Without this,
        # every hub landed in the first quarter of the vector.
        centers = (torch.stack([pos_idx, neg_idx], dim=-1) * stride).view(B, T, 2).long()
        strengths = torch.stack([pos_val, neg_val], dim=-1).view(B, T, 2)

        return centers, strengths

    @torch.no_grad()
    def update_bucketed_ema_and_scores(
            self,
            centers: torch.Tensor,  # [B, T, 2]  int feature indices (pos, neg)
            dim: int,  # width that produced centers; d_ff (FFN) or d_model (model)
            ten_x_cap: float = 7.0,  # EMA value that maps to score 1.0
            min_cutoff: float = 0.20
    ) -> torch.Tensor:
        """
        Update this layer's hub-density EMAs from ``centers`` and score each centre.

        Steps:
          1) Map each centre (feature index, clamped to [0, dim-1]) to a bucket of
             ``bucket_size`` (4) features.
          2) Spread weights [1, 2, 3, 2, 1] over the bucket and its ±2 neighbours
             (circular) into separate positive / negative histograms of length dim // 4.
          3) Divide both histograms by 4. This is a fixed scale, not an exact
             mean-1 normalisation: before scaling the histogram mean is
             9 * B * T / (dim // 4).
          4) Blend into the FFN-width or model-width EMA buffers (chosen by ``dim``):
             ema = ema * decay + hist * (1 - decay).
          5) Read back each centre's bucket EMA value v and map it to
             clamp((v - t0) / (ten_x_cap - t0), 0, 1) with t0 = min_cutoff * ten_x_cap.

        ``centers`` is not modified.

        Returns:
            [B, T, 2] scores in [0, 1] (pos, neg), in the EMA buffers' dtype.

        Raises:
            ValueError: if ``dim`` matches neither ``d_ff`` nor ``d_model``, or is
                smaller than one bucket.
        """
        assert centers.dim() == 3 and centers.size(-1) == 2, "centers must be [B,T,2]"
        B, T, _ = centers.shape
        device = centers.device

        # Select target EMA tensors by 'dim'
        if dim == self.d_ff:
            ema_pos = self.ema_ffn_pos
            ema_neg = self.ema_ffn_neg
        elif dim == self.d_model:
            ema_pos = self.ema_model_pos
            ema_neg = self.ema_model_neg
        else:
            raise ValueError(f"Unrecognized dim={dim}; expected one of {{d_ff={self.d_ff}, d_model={self.d_model}}}")

        bucket_count = dim // self.bucket_size
        if bucket_count <= 0:
            raise ValueError(f"dim={dim} too small for bucket_size={self.bucket_size}")

        # Flatten centers -> [N]; out-of-place clamp so the caller's tensor is untouched
        # (reshape/.to may return a view of `centers`).
        centers_flat = centers.reshape(-1, 2).to(torch.int64)
        pos_idx = centers_flat[:, 0].clamp(0, dim - 1)
        neg_idx = centers_flat[:, 1].clamp(0, dim - 1)

        # Convert to bucket indices (size=4 buckets)
        pos_b0 = torch.div(pos_idx, self.bucket_size, rounding_mode='floor')  # [N]
        neg_b0 = torch.div(neg_idx, self.bucket_size, rounding_mode='floor')  # [N]

        # AoE neighborhood (circular): offsets and weights
        offsets = torch.tensor([-2, -1, 0, 1, 2], device=device, dtype=torch.int64)
        weights = torch.tensor([1., 2., 3., 2., 1.], device=device, dtype=torch.float32)  # sums to 9

        # Bucket histograms via weighted bincount
        def build_hist(b0: torch.Tensor) -> torch.Tensor:
            if b0.numel() == 0:
                return torch.zeros(bucket_count, device=device, dtype=torch.float32)
            neigh = (b0.unsqueeze(1) + offsets) % bucket_count  # [N,5]
            w = weights.expand(b0.size(0), -1)  # [N,5]
            return torch.bincount(
                neigh.reshape(-1),
                weights=w.reshape(-1),
                minlength=bucket_count
            ).to(torch.float32)

        hist_pos = build_hist(pos_b0)
        hist_neg = build_hist(neg_b0)

        # Fixed scaling by the bucket size (see docstring step 3).
        hist_pos = hist_pos / 4.0
        hist_neg = hist_neg / 4.0

        # EMA update; cast so half/bfloat16 buffers (after model.half()) still work.
        decay = float(self.ema_decay)
        ema_pos.lerp_(hist_pos.to(ema_pos.dtype), 1.0 - decay)  # ema = ema*decay + hist*(1-decay)
        ema_neg.lerp_(hist_neg.to(ema_neg.dtype), 1.0 - decay)

        # Per-center scores in [0,1]: linear ramp from min_thresh up to ten_x_cap.
        pos_vals = ema_pos.gather(0, pos_b0)
        neg_vals = ema_neg.gather(0, neg_b0)

        min_thresh = min_cutoff * ten_x_cap  # default: 0.20 * 7 = 1.4
        scale = max(ten_x_cap - min_thresh, 1e-6)

        pos_scores = torch.clamp((pos_vals - min_thresh) / scale, min=0.0, max=1.0)
        neg_scores = torch.clamp((neg_vals - min_thresh) / scale, min=0.0, max=1.0)

        return torch.stack([pos_scores, neg_scores], dim=-1).view(B, T, 2)

    @torch.no_grad()
    def scramble_magnetism_noise(self, z: torch.Tensor, leak_scale: float,
                                 beta_max: float = 0.20, radius: int = 12,
                                 win: int = 9):
        """
        Return sign-preserving multiplicative noise shaped like ``z`` ([B, T, D]).

        Per token, a positive and a negative hub are located
        (``find_hubs_per_token``) and scored by ``update_bucketed_ema_and_scores``
        (which also updates this layer's EMAs). Each element receives
        ``u * z * s`` with ``u ~ U(-beta, beta)``, ``beta = beta_max * leak_scale``, where
          - s = 1 for elements farther than ``radius`` (circular distance) from the
            hub matching their sign,
          - s = that hub's score in [0, 1] for elements within ``radius`` of it,
          - s = 0 for exact zeros.
        The noise is clamped so that ``z + noise`` never flips sign.
        Returns zeros when ``leak_scale <= 0`` or ``beta_max <= 0``.
        """
        if leak_scale <= 0.0 or beta_max <= 0.0:
            return torch.zeros_like(z)

        B, T, D = z.shape
        device, dtype = z.device, z.dtype

        # Hubs and per-centre trust in [0,1] (pos, neg)
        centers, _ = self.find_hubs_per_token(z, num_hubs=2, win=win)  # [B,T,2]
        center_trust = self.update_bucketed_ema_and_scores(centers, D)  # [B,T,2]

        z_flat = z.view(B * T, D)
        centers_flat = centers.view(B * T, 2)
        beta = beta_max * leak_scale

        # Circular distances to the pos/neg centres
        idx = torch.arange(D, device=device).unsqueeze(0)  # [1,D]
        pos_c = centers_flat[:, 0] % D
        neg_c = centers_flat[:, 1] % D
        diff_pos = (idx - pos_c.unsqueeze(1)).abs()
        diff_neg = (idx - neg_c.unsqueeze(1)).abs()
        pos_dist = torch.minimum(diff_pos, D - diff_pos)
        neg_dist = torch.minimum(diff_neg, D - diff_neg)

        # Sign masks
        is_pos = (z_flat > 0)
        is_neg = (z_flat < 0)

        # Inside/outside the sign-matched hub
        pos_inside = is_pos & (pos_dist <= radius)
        neg_inside = is_neg & (neg_dist <= radius)
        outside = (is_pos & (pos_dist > radius)) | (is_neg & (neg_dist > radius))

        # Per-element scale: outside hub -> 1, inside hub -> trust, zeros -> 0
        trust_flat = center_trust.view(B * T, 2).to(dtype)
        pos_scale_row = trust_flat[:, 0].unsqueeze(1)  # [N,1]
        neg_scale_row = trust_flat[:, 1].unsqueeze(1)  # [N,1]

        scale = torch.zeros_like(z_flat, dtype=dtype)
        scale = torch.where(outside, torch.ones_like(scale), scale)
        scale = torch.where(pos_inside, pos_scale_row.expand_as(z_flat), scale)
        scale = torch.where(neg_inside, neg_scale_row.expand_as(z_flat), scale)

        if not torch.any(scale > 0):
            return torch.zeros_like(z)

        noise = torch.empty_like(z_flat).uniform_(-beta, beta)
        noise = noise * z_flat * scale  # multiplicative

        # Clamp to avoid sign flips
        noise = torch.where(
            z_flat > 0,
            noise.clamp(min=-z_flat.abs()),
            noise.clamp(max=z_flat.abs())
        )

        return noise.view(B, T, D)

    def forgetful_activation(self, x):
        """Piecewise f(x) = log2(x) + 1 for x >= 1, 2**(x - 1) for x < 1 (x clamped to >= 1e-6)."""
        x_clamp = torch.clamp(x, min=1e-6)
        return torch.where(x_clamp >= 1, torch.log2(x_clamp) + 1,
                           torch.exp2(x_clamp - 1))

    def old_forgetful_activation(self, x, beta: float = 12.0, eps:  float = 1e-6):
        """Smooth version of ``forgetful_activation``: sigmoid(beta * (x - 1)) blends the two branches."""
        # selector in (0,1): ~0 when x << 1, ~1 when x >> 1
        s = torch.sigmoid(beta * (x - 1.0))

        # right-hand branch (x >= 1): log2(x) + 1 via log1p for accuracy near 1
        rhs = torch.log1p(torch.clamp_min(x, eps) - 1.0) / self.LN2 + 1.0

        # left-hand branch (x < 1): 2**(x-1)
        lhs = torch.exp((x - 1.0) * self.LN2)

        return s * rhs + (1.0 - s) * lhs

    def pairwise_activation(self, hidden_states, gate, w_val):
        """
        Per-(input, hidden) contributions ``h_j * W_ij * gate_i``.

        hidden_states : (B, T, d_model)
        gate          : (B, T, d_ff)  — multiplied as given (no activation applied)
        w_val         : (d_ff, d_model)
        returns       : (B, T, d_ff, d_model); note this is very memory hungry.
        """
        contrib = torch.einsum('btm,im->btim', hidden_states, w_val)
        contrib.mul_(gate.unsqueeze(-1))
        return contrib

    @torch.no_grad()
    def _recall_core(self, z: torch.Tensor, k_conv: torch.Tensor,
                     alpha: float, center: int) -> torch.Tensor:
        """Return ``alpha * (circular_conv(z, k_conv) - z)`` over the last dim (zeros if alpha <= 0)."""
        if alpha <= 0.0:
            return torch.zeros_like(z)

        return alpha * (
                F.conv1d(
                    F.pad(z.reshape(-1, 1, z.shape[-1]), (center, center), 'circular'),
                    k_conv
                ).view_as(z) - z
        )

    def apply_recall_1d(self, z: torch.Tensor, kernel: torch.Tensor,
                        alpha: float) -> torch.Tensor:
        """Blend ``z`` toward its 1-D neighbourhood average using ``kernel`` (cached per dtype/device)."""
        K = kernel.numel()
        # Include the kernel's storage so different kernels of the same length
        # do not share one cache entry.
        key = (K, z.dtype, z.device.type, z.device.index, kernel.data_ptr())

        if key not in self._kernel_cache:
            self._kernel_cache[key] = kernel.to(
                z.device, z.dtype
            ).reshape(1, 1, -1).contiguous()

        return self._recall_core(z, self._kernel_cache[key], alpha, K // 2)

    def forward(self, hidden_states):
        """
        Gated FFN forward with periodic weight leaking and hub-aware activation noise.

        Side effects: increments ``forward_passes``; every ``memory_boost`` passes
        multiplies ``variance`` by 0.9 and ``long_var`` / ``default_std`` by 0.95;
        when ``safe_to_forget[layer_num]`` is truthy, leaks wi_0 / wi_1 / wo once
        and clears the flag; the noise helpers update the hub-density EMA buffers.
        """
        self.forward_passes += 1.0

        # Periodic decay of the noise scales (fires once every `memory_boost` passes).
        if (self.forward_passes + 1.0) % self.memory_boost == 1.0:
            self.forward_passes = 2.0
            self.variance = self.variance * 0.9
            self.long_var = self.long_var * 0.95
            self.default_std = self.default_std * 0.95

        global safe_to_forget
        global num_batches

        # Debug dump of the model-width positive EMA (only for a layer numbered 24).
        if num_batches == 4 and self.layer_num == 24:
            for i in self.ema_model_pos:
                print(i)
            print("end\n")

        # One leak pass per flag raised externally; the flag is cleared afterwards.
        if safe_to_forget[self.layer_num]:
            with torch.no_grad():
                leak_once_uniform(self.wi_0.weight, self.kernel7, LEAK.alpha_val)
                leak_once_uniform(self.wi_1.weight, self.kernel7, LEAK.alpha_gate)
                leak_once_uniform(self.wo.weight,   self.kernel7, LEAK.alpha_down)

                safe_to_forget[self.layer_num] = False

        recall_alpha = LEAK.recall_alpha  # [entry, up_gate, up_val, down]
        # NOTE: the entry (input) and separate gate/value recall paths are disabled;
        # only the post-activation and output noise below are active.

        # Up projection and gating
        value = self.wi_0(hidden_states)
        gate = self.wi_1(hidden_states)
        x = value * self.activation(gate)

        mag_scale = _alpha_eff(1.0)

        # Noise on the merged activations (applied after merging so value/gate
        # noise cannot cancel) when (num_batches + layer_num) is even.
        if ((num_batches + self.layer_num) % 2 == 0) and recall_alpha[2] > 0.0:
            with torch.no_grad():
                x_mag = self.scramble_magnetism_noise(x, leak_scale=mag_scale, beta_max=LEAK.max_variance, radius=40, win=15)

            x.add_(x_mag.detach())

        # dropout and down projection
        x = self.dropout(x)
        x = self.wo(x)

        # Noise on the layer output when (num_batches + layer_num) is odd.
        if ((num_batches + self.layer_num) % 2 == 1) and recall_alpha[3] > 0.0:
            with torch.no_grad():
                y_mag = self.scramble_magnetism_noise(x, leak_scale=mag_scale, beta_max=LEAK.max_variance, radius=16, win=9)

            x.add_(y_mag.detach())

        return x


# forgetful t5
class ForgetfulT5(nn.Module):
    """
    MT5 wrapper whose feed-forward sublayers are replaced by ``CustomDenseReluDense``.

    Encoder blocks receive global layer numbers ``0 .. n_enc-1`` and decoder blocks
    continue from ``n_enc``; ``_get_ffn`` and ``visualize_ffn`` use the same numbering.
    """

    def __init__(self, base_model, eps=1e-5):
        """
        Wrap ``base_model`` and swap every block's ``DenseReluDense`` in place.

        Args:
            base_model: A pretrained MT5 model (e.g. ``MT5ForConditionalGeneration``).
                It is mutated: its FFN modules are replaced.
            eps: Epsilon for numerical stability, forwarded to each
                ``CustomDenseReluDense``.
        """
        super().__init__()
        self.base_model = base_model
        self.eps = eps

        layer_num = 0

        # Encoder blocks: the feed-forward sublayer is layer[1].
        for block in self.base_model.encoder.block:
            orig_ff = block.layer[1].DenseReluDense
            block.layer[1].DenseReluDense = CustomDenseReluDense(orig_ff, layer_num, eps=self.eps)
            layer_num += 1

        # @todo implement for cross attention as well
        # @todo is the memory loss enough for the number of layers?
        # Decoder blocks: FFN is layer[2] when cross-attention is present,
        # otherwise layer[1] (fallback; shouldn't be needed for MT5).
        for block in self.base_model.decoder.block:
            ffn_idx = self._decoder_ffn_index(block)
            orig_ff = block.layer[ffn_idx].DenseReluDense
            block.layer[ffn_idx].DenseReluDense = CustomDenseReluDense(orig_ff, layer_num, eps=self.eps)
            layer_num += 1

        print("Custom feed-forward modules have been applied to the MT5 model.")
        print("Total Layers: ", layer_num)

    @staticmethod
    def _decoder_ffn_index(block):
        """Index of the FFN sublayer in a decoder block: 2 with cross-attention, else 1."""
        return 2 if len(block.layer) >= 3 else 1

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Delegate the forward pass to the wrapped MT5 model unchanged."""
        return self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs
        )

    def generate(self, *args, **kwargs):
        """Delegate generation to the wrapped MT5 model."""
        return self.base_model.generate(*args, **kwargs)

    def _get_ffn(self, layer_num: int):
        """
        Return the FFN module at global ``layer_num`` without storing an array.

        Raises:
            IndexError: if ``layer_num`` is outside ``[0, n_enc + n_dec)``.
        """
        enc_blocks = self.base_model.encoder.block
        dec_blocks = self.base_model.decoder.block
        n_enc = len(enc_blocks)

        if not (0 <= layer_num < n_enc + len(dec_blocks)):
            raise IndexError("layer_num out of range")

        if layer_num < n_enc:
            # T5 encoder: FFN is block.layer[1].DenseReluDense
            return enc_blocks[layer_num].layer[1].DenseReluDense
        blk = dec_blocks[layer_num - n_enc]
        return blk.layer[self._decoder_ffn_index(blk)].DenseReluDense

    def _ffn_matrices(self, layer_num: int, which: str):
        """
        Collect the FFN weight matrices to plot, each as a CPU float32 (d_model, d_ff) tensor.

        Returns:
            ``(mats, titles)`` lists in plotting order.

        Raises:
            IndexError: ``layer_num`` out of range (checked first, via ``_get_ffn``).
            ValueError: ``which`` is not one of "both", "val", "gate", "all".
        """
        ffn = self._get_ffn(layer_num)
        if which not in ("both", "val", "gate", "all"):
            raise ValueError(f"which must be one of 'both', 'val', 'gate', 'all'; got {which!r}")

        mats, titles = [], []
        if which in ("both", "val", "all"):
            # nn.Linear stores (out, in) = (d_ff, d_model); transpose to (d_model, d_ff).
            mats.append(ffn.wi_0.weight.detach().to("cpu").float().T)
            titles.append("wi_0 (value up)")
        if which in ("both", "gate", "all"):
            mats.append(ffn.wi_1.weight.detach().to("cpu").float().T)
            titles.append("wi_1 (gate up)")
        if which == "all" and hasattr(ffn, "wo"):
            # wo.weight is already (d_model, d_ff); no transpose, so every panel
            # matches the shared axis labels used in visualize_ffn.
            mats.append(ffn.wo.weight.detach().to("cpu").float())
            titles.append("wo (down)")
        return mats, titles

    def visualize_ffn(self, layer_num: int, which: str = "both",
                      show: bool = True, savepath: str = None):
        """
        Plot heatmaps of the FFN weights at global ``layer_num``.

        Args:
            layer_num: Global layer index (see ``_get_ffn``).
            which: "both" (wi_0 + wi_1), "val" (wi_0), "gate" (wi_1) or
                "all" (also ``wo`` when present).
            show: Call ``plt.show()`` after drawing.
            savepath: If given, save the figure there at 150 dpi.

        Returns:
            The matplotlib Figure. Every panel is drawn as d_model rows × d_ff
            columns. Colours run blue (negative) → white (0) → red (positive),
            with one symmetric range shared by all panels.

        Raises:
            IndexError: ``layer_num`` is out of range.
            ValueError: ``which`` is not a supported value.
        """
        mats, titles = self._ffn_matrices(layer_num, which)

        # Shared, zero-centred colour range; the floor avoids vmin == vmax
        # for an all-zero layer.
        vmax = max(max(m.abs().max().item() for m in mats), 1e-12)
        vmin = -vmax

        n = len(mats)
        fig, axes = plt.subplots(1, n, figsize=(6 * n, 6), constrained_layout=True)
        axes = [axes] if n == 1 else axes

        for ax, M, title in zip(axes, mats, titles):
            im = ax.imshow(M, cmap="seismic", vmin=vmin, vmax=vmax,
                           aspect="auto", interpolation="nearest")
            rows, cols = M.shape
            ax.set_title(f"Forgetful Model\nLayer {layer_num}: {title}\n({rows} × {cols} view)")
            ax.set_xlabel("d_ff (neurons)")
            ax.set_ylabel("d_model (features)")
            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        if savepath:
            fig.savefig(savepath, dpi=150)
        if show:
            plt.show()
        return fig

def generate_output(model_wrapper, tokenizer, text):
    """
    Reconstruct ``text`` with the wrapped encoder-decoder (greedy, no teacher forcing).

    The input is run once through ``model_wrapper.model.encoder`` and the resulting
    encoder states are handed to ``generate``. That leaves a hook point for
    perturbing the latent; see the commented-out noise snippet below. No noise is
    applied at the moment.

    Args:
        model_wrapper: Object exposing ``.model`` with ``.encoder``, ``.generate``
            and ``.parameters()``.
        tokenizer: Used both to encode ``text`` and to decode the output.
        text: A single input string.

    Returns:
        The first generated sequence, decoded with special tokens skipped.
    """
    model = model_wrapper.model

    # Tokenize (on CPU) and move to the model's device.
    enc_in = tokenizer(text, return_tensors="pt")
    device = next(model.parameters()).device
    input_ids = enc_in["input_ids"].to(device)
    attention_mask = enc_in["attention_mask"].to(device)

    input_length = input_ids.shape[1]

    with torch.no_grad():
        # 1) Encode
        encoder_outputs = model.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        # (Optional) noise at inference time:
        # latent = encoder_outputs.last_hidden_state
        # noise = torch.randn_like(latent) * model_wrapper.noise_std
        # latent = latent + noise

        # 2) Generate from the encoder outputs. The encoder mask must be passed
        #    explicitly; otherwise generate() cannot mask cross-attention, because
        #    it never sees the original input_ids.
        generated_ids = model.generate(
            encoder_outputs=encoder_outputs,
            attention_mask=attention_mask,
            input_ids=None,  # not needed since we passed encoder_outputs
            num_beams=1,     # greedy decoding
            max_length=input_length,
            length_penalty=0.7,     # NOTE: only used by beam search (num_beams > 1)
            repetition_penalty=1.2,
            early_stopping=True     # NOTE: only used by beam search (num_beams > 1)
        )

    # Decode the generated tokens (decoding can stay on CPU)
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)


def evaluate_bleu(model_wrapper, tokenizer, texts):
    """
    Score autoencoder reconstructions with corpus-level BLEU.

    Every input string is its own reference, and the model's reconstruction of it
    (from ``generate_output``) is the matching hypothesis.

    Returns:
        ``(bleu_score, hypotheses)``: the sacrebleu corpus score and the
        reconstructions in input order.
    """
    # Reconstruct each text as it is read, pairing it with its reference.
    pairs = [(text, generate_output(model_wrapper, tokenizer, text)) for text in texts]
    references = [ref for ref, _ in pairs]
    hypotheses = [hyp for _, hyp in pairs]

    # sacrebleu takes a list of reference *streams*; there is exactly one here.
    bleu_score = sacrebleu.corpus_bleu(hypotheses, [references]).score
    return bleu_score, hypotheses


def tokenize_fn(batch, tokenizer, train_english, max_length=64):
    """
    Tokenise one side of a parallel batch for reconstruction training.

    Args:
        batch: Mapping with list-valued ``"src"`` and ``"tgt"`` columns.
        tokenizer: HF-style tokenizer (callable; has ``pad_token_id``).
        train_english: Use ``batch["src"]`` when True, else ``batch["tgt"]``.
        max_length: Fixed padded/truncated length (default 64, as before).

    Returns:
        The tokenizer output plus ``"labels"``: a copy of ``input_ids`` in which
        every pad id is replaced by -100, so the loss ignores those positions.
    """
    # Pick the monolingual side to reconstruct (src = English, tgt = Assamese).
    text_list = batch["src"] if train_english else batch["tgt"]

    # @todo adaptive batch sizes
    tokenized = tokenizer(text_list, padding="max_length", truncation=True, max_length=max_length)

    pad_id = tokenizer.pad_token_id
    tokenized["labels"] = [
        [-100 if tid == pad_id else tid for tid in seq]
        for seq in tokenized["input_ids"]
    ]
    return tokenized

def tokenize_forgiveness(batch, tokenizer, train_english, max_length=64):
    """
    Tokenise ``batch["en"]`` and, for the non-English pass, attach padded top-k alternatives.

    Args:
        batch: Batched HF-datasets mapping. Always needs ``"en"``. For the
            neighbour branch it also needs ``"k_alts"`` (presumably
            [sentence][token][k] token strings) and ``"k_confs"`` (matching floats);
            verify against the dataset builder.
        tokenizer: HF-style tokenizer (callable, ``pad_token_id``,
            ``convert_tokens_to_ids``).
        train_english: When False and ``"k_alts"`` is present, the alternative
            columns are added.
        max_length: Fixed sequence length for tokens and alternatives (default 64).

    Returns:
        The tokenizer output with ``input_ids``/``attention_mask``/``labels`` as
        long tensors of shape (B, max_length); pads in ``labels`` are -100.
        In the neighbour branch it also contains:
        ``k_alts`` (string lists padded with '<pad>'), ``k_confs`` (float,
        (B, max_length, k)) and ``k_alt_ids`` (long, (B, max_length, k)).
    """
    max_len = max_length

    # 1. pick the sentences (the en→as variant used batch["as"] when not train_english)
    text_list = batch["en"]

    # 2. normal tokenisation
    tokenised = tokenizer(
        text_list,
        padding="max_length",
        truncation=True,
        max_length=max_len,
    )

    # 3. label tensor (-100 on pads)
    pad_id = tokenizer.pad_token_id
    tokenised["labels"] = [
        [-100 if tid == pad_id else tid for tid in seq]
        for seq in tokenised["input_ids"]
    ]

    # 4. pass through neighbour info
    if (not train_english) and ("k_alts" in batch):
        # [B][T][k] ids in the target tokenizer's vocabulary
        alt_ids = [
            [[tokenizer.convert_tokens_to_ids(tok) for tok in tok_alts] for tok_alts in sent_alts]
            for sent_alts in batch["k_alts"]
        ]

        # k = width of the first non-empty token row (0 if the batch has no tokens at all)
        k_val = next((len(tok_ids) for sent in alt_ids for tok_ids in sent), 0)

        def pad(seq, pad_elem):
            # Truncate as well as pad, so every row ends up exactly max_len long
            # and the numeric columns can be stacked into a rectangular tensor.
            seq = list(seq[:max_len])
            return seq + [pad_elem] * (max_len - len(seq))

        # keep the strings too; some downstream code may still need them
        tokenised["k_alts"] = [pad(s, ['<pad>'] * k_val) for s in batch["k_alts"]]
        tokenised["k_confs"] = torch.tensor(
            [pad(s, [0.0] * k_val) for s in batch["k_confs"]], dtype=torch.float
        )
        tokenised["k_alt_ids"] = torch.tensor(
            [pad(s, [pad_id] * k_val) for s in alt_ids], dtype=torch.long
        )

    for key in ("input_ids", "attention_mask", "labels"):
        tokenised[key] = torch.tensor(tokenised[key], dtype=torch.long)

    return tokenised

# ---------------------------------------------------------------------
# 3. Scoring head (MLP) ------------------------------------------------
class ScoringHead(nn.Module):
    """
    Two-layer GELU MLP that maps features of width ``dim`` back to ``dim``.

    Args:
        dim: Input and output feature size.
        hidden: Width of the intermediate layer.
    """

    def __init__(self, dim, hidden=512):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
        ]
        # Positional Sequential keeps checkpoint keys as mlp.0.* / mlp.2.*.
        self.mlp = nn.Sequential(*layers)

    def forward(self, h):
        """Apply the MLP over the last dimension of ``h``."""
        projected = self.mlp(h)
        return projected


@torch.no_grad()
def compute_similarity_matrix(model, tokenizer, top_k=10, batch_size=16):
    """
    For every vocab id, find the ``top_k`` most cosine-similar Bengali-script tokens.

    A token counts as Bengali-script (used for Assamese) when its decoded text
    contains a character in U+0980–U+09FF. Note that this decodes every
    embedding row once.

    Args:
        model: Wrapper exposing ``base_model.device`` and ``base_model.shared.weight``.
        tokenizer: Provides ``decode([id])``.
        top_k: Number of neighbours per vocab id.
        batch_size: Rows of the similarity matrix computed per step.

    Returns:
        topk_vals: [vocab_size, top_k] cosine similarities.
        topk_idx:  [vocab_size, top_k] full-vocab ids of those neighbours.

    Raises:
        ValueError: if fewer than ``top_k`` Bengali-script tokens exist.
    """
    device = model.base_model.device
    embeddings = model.base_model.shared.weight.detach()  # [vocab_size, d_model]
    embeddings = torch.nn.functional.normalize(embeddings, dim=1)

    vocab_size = embeddings.size(0)

    # --- Step 1: precompute valid Bengali/Assamese-script token ids ---
    bengali_ids = [
        idx for idx in range(vocab_size)
        if any('\u0980' <= ch <= '\u09FF' for ch in tokenizer.decode([idx]))
    ]
    if len(bengali_ids) < top_k:
        raise ValueError(
            f"only {len(bengali_ids)} Bengali-script tokens found; need at least top_k={top_k}"
        )
    # explicit long dtype: an empty list would otherwise become a float tensor
    bengali_tokens = torch.tensor(bengali_ids, dtype=torch.long, device=device)

    assamese_embeddings = embeddings[bengali_tokens]  # [N_assamese, d_model]

    # --- Step 2: top-k search only within the Bengali-script embeddings ---
    topk_vals_list = []
    topk_idx_list = []

    for start in range(0, vocab_size, batch_size):
        end = min(start + batch_size, vocab_size)
        sims = embeddings[start:end] @ assamese_embeddings.T  # [batch, N_assamese]

        topk_vals, topk_relative_idx = sims.topk(top_k, dim=-1)  # indices into bengali_tokens
        topk_vals_list.append(topk_vals)
        topk_idx_list.append(bengali_tokens[topk_relative_idx])  # map back to full vocab ids

    topk_vals = torch.cat(topk_vals_list, dim=0)  # [vocab_size, top_k]
    topk_idx = torch.cat(topk_idx_list, dim=0)    # [vocab_size, top_k]

    return topk_vals, topk_idx


@torch.no_grad()
def build_soft_targets_table(topk_vals, topk_idx, smoothing_strength):
    """
    Turn neighbour similarities into per-row bonus weights.

    Steps: quadratic distance decay on (1 - similarity), a ``1 - x**2`` reshape,
    three shift-to-min-then-sqrt passes (offsets smoothing_strength/15, /20, /30),
    scaling by ``smoothing_strength``, and a clamp to [0, 0.8]. Column 0 (assumed
    to be the token itself) is then forced to 1.0.

    Returns:
        ``(bonus, topk_idx)``. ``bonus`` has the shape of ``topk_vals``;
        ``topk_idx`` is returned as-is (same object).
    """
    # Quadratic distance decay
    gap = 1.0 - topk_vals
    weights = (gap ** 2) * 0.5

    # Smoothing reshaping: invert, then flatten the spread with successive roots.
    weights = 1 - weights ** 2
    for divisor in (15, 20, 30):
        row_min = weights.min(dim=-1, keepdim=True).values
        weights = (weights - row_min + smoothing_strength / divisor) ** 0.5

    weights = weights * smoothing_strength
    weights = weights.clamp(min=0.0, max=0.8)

    # Ensure the first entry (assumed GT) gets 1.0
    weights[:, 0] = 1.0

    return weights, topk_idx


def build_soft_targets(labels, topk_vals, topk_idx, smoothing_strength, top_k=10):
    """
    Expand integer labels into dense "forgiving" target rows.

    Each row gets 1.0 at the label and a smoothed bonus at the label's
    non-identical neighbours from ``topk_idx``.

    Args:
        labels: [B, T] token ids. Negative entries (e.g. -100 from the tokenisers
            in this file) are treated as ignored and produce all-zero rows.
        topk_vals: [vocab_size, K] neighbour similarities.
        topk_idx: [vocab_size, K] neighbour ids.
        smoothing_strength: Scale of the neighbour bonus.
        top_k: Unused; K is taken from ``topk_idx``. Kept for API compatibility.

    Returns:
        [B, T, vocab_size] float tensor on ``labels.device``.
    """
    B, T = labels.size()
    vocab_size = topk_idx.size(0)
    device = labels.device

    # Ignored positions use a dummy id 0 for the lookups and are zeroed at the end.
    # Indexing with -100 would wrap silently, and scatter_ rejects it outright.
    ignore = labels < 0
    safe_labels = labels.masked_fill(ignore, 0)

    soft_targets = torch.zeros(B, T, vocab_size, device=device)

    # Lookup top-k neighbors and similarities
    label_topk_vals = topk_vals[safe_labels]    # [B, T, top_k]
    label_topk_idx = topk_idx[safe_labels]      # [B, T, top_k]

    # Quadratic distance decay
    distance = 1.0 - label_topk_vals
    bonus = (distance ** 2) * 0.5

    # Mask out bonuses for ground-truth token
    correct_mask = (label_topk_idx == safe_labels.unsqueeze(-1))  # [B, T, top_k]
    bonus = bonus.masked_fill(correct_mask, 0.0)

    # Reshape the remaining bonus mass
    bonus = (1 - bonus ** 2)
    bonus = (bonus - bonus.min(dim=-1, keepdim=True).values + smoothing_strength / 15) ** 0.5
    bonus = (bonus - bonus.min(dim=-1, keepdim=True).values + smoothing_strength / 20) ** 0.5
    bonus = (bonus - bonus.min(dim=-1, keepdim=True).values + smoothing_strength / 30) ** 0.5

    bonus *= smoothing_strength

    # The ground-truth slot gets no bonus; it receives the full 1.0 below.
    bonus = bonus.masked_fill(correct_mask, 0.0)

    # Scatter correct token full 1.0
    soft_targets.scatter_(2, safe_labels.unsqueeze(-1), 1.0)

    # Add bonus scores to close non-correct neighbors (dtypes must match for scatter_add_)
    soft_targets.scatter_add_(2, label_topk_idx, bonus.to(soft_targets.dtype))

    # Ignored positions contribute nothing to a soft cross-entropy.
    soft_targets.masked_fill_(ignore.unsqueeze(-1), 0.0)

    return soft_targets


from torch.optim.lr_scheduler import LambdaLR

def lr_lambda(current_step):
    """
    LambdaLR multiplier: linear warm-up, then a constant ``gamma``.

    For ``current_step < warmup_steps`` the factor ramps linearly from 0.
    After warm-up it stays at ``gamma`` for every step. Stepwise exponential
    decay is currently disabled; to re-enable it, return
    ``gamma ** ((current_step - warmup_steps) // step_size)`` instead.

    Reads the module-level ``warmup_steps`` and ``gamma`` at call time.
    """
    if current_step < warmup_steps:
        return float(current_step) / float(max(1, warmup_steps))  # linear warmup
    return gamma


warmup_steps = 10     # adjust as needed
step_size = 1           # same as your StepLR (only used if decay is re-enabled)
gamma = 0.99           # decay factor


@torch.inference_mode()
def batch_topk_substitutions(sentences, k=15, max_len=64, debug=False):
    """
    Predict top-k substitution candidates per token and map them to the target tokenizer.

    Relies on module-level state set up elsewhere: ``bert_model``, ``bert_scorer``,
    ``bert_tokenizer``, ``embed``, ``TEMPERATURE``, ``F_MOD``, ``F_MAX``,
    ``special_ids``, ``tokenizer`` and ``device``.

    Args:
        sentences: A list of strings, or a single string.
        k: Number of candidates per token position.
        max_len: Truncation length for the BERT tokenizer.
        debug: When True, print one sample before/after mapping (only if the
            batch is large enough to contain position [5][2]).

    Returns:
        ``(mapped_alt_ids, alt_probs)``, both (B, T, k).
        ``mapped_alt_ids`` holds target-tokenizer ids. Candidates that do not map
        to exactly one target token become ``unk_token_id``. Positions whose input
        token is in ``special_ids`` get ``pad_token_id`` and probability 0.
    """
    global bert_model, bert_scorer, bert_tokenizer, TEMPERATURE, special_ids, id2tok, embed, F_MOD, F_MAX, tokenizer

    if isinstance(sentences, str):
        sentences = [sentences]

    inputs = bert_tokenizer(
        sentences,
        padding=True,
        truncation=True,
        max_length=max_len,
        return_tensors="pt"
    ).to(device)

    # 1) contextual vectors
    h = bert_model(**inputs).last_hidden_state           # (B, T, D)
    h = bert_scorer(h)                                   # (B, T, D)

    # 2) batched scores: one matmul for the whole batch
    scores = torch.matmul(h, embed.T) / TEMPERATURE   # (B, T, V)
    probs = F.softmax(scores, dim=-1)                 # (B, T, V)

    probs = probs.mul(F_MOD).clamp(min=0.03, max=F_MAX)

    # 3) raw top-k for every (b, t) in one shot
    alt_probs, alt_ids = probs.topk(k, dim=-1)  # both (B, T, k)

    # 4) blank out positions that correspond to special tokens
    input_ids = inputs["input_ids"]  # (B, T)
    mask = torch.zeros_like(input_ids, dtype=torch.bool)
    for sid in special_ids:
        mask |= (input_ids == sid)  # True where special

    mask = mask.unsqueeze(-1).expand_as(alt_ids)  # broadcast to (B, T, k)

    alt_ids[mask] = tokenizer.pad_token_id  # sentinel
    alt_probs[mask] = 0.0

    # 5) map BERT ids → target-tokenizer ids
    B, T, K = alt_ids.shape
    flat_ids = alt_ids.view(-1).tolist()  # (B·T·K,)

    unk_id = tokenizer.unk_token_id
    pad_id = tokenizer.pad_token_id

    # Decoding/encoding is deterministic per id, and the same ids recur heavily
    # across positions, so memoise the mapping. The pad sentinel maps to itself.
    id_map = {pad_id: pad_id}
    mapped = []
    for bid in flat_ids:
        if bid not in id_map:
            # text surface from the BERT id (no clean-up) ...
            piece = bert_tokenizer.decode([bid], clean_up_tokenization_spaces=False)
            # ... and how *our* tokenizer encodes that text
            tgt_ids = tokenizer.encode(piece, add_special_tokens=False)
            # exact 1-token match, else fall back to <unk>
            id_map[bid] = tgt_ids[0] if len(tgt_ids) == 1 else unk_id
        mapped.append(id_map[bid])

    mapped_alt_ids = torch.tensor(
        mapped, dtype=alt_ids.dtype, device=alt_ids.device
    ).view(B, T, K)

    # Opt-in only: the unconditional print used to raise IndexError for B < 6 or T < 3.
    if debug and B > 5 and T > 2:
        print("Token Swap", alt_ids[5][2], "New", mapped_alt_ids[5][2])
    return mapped_alt_ids, alt_probs

def pad_to_len(x: torch.Tensor, target_len: int, pad_val):
    """
    Pad the second-to-last axis of ``x`` to ``target_len`` with ``pad_val``.

    x: (..., T, k) -> (..., target_len, k). Returns ``x`` itself (same object)
    when T already equals ``target_len``. A negative difference is passed to
    ``F.pad`` unchanged.
    """
    missing = target_len - x.shape[-2]
    if missing != 0:
        # F.pad lists pads from the last dim backwards: (k_left, k_right, T_left, T_right).
        x = F.pad(x, (0, 0, 0, missing), value=pad_val)
    return x

def fully_randomize(model, *, seed=50, std=0.02):
    """
    Re-initialise ``model``'s weights deterministically.

    Seeds the CPU and all CUDA RNGs with ``seed``, then visits every sub-module
    through ``model.apply``:
      * ``nn.Linear`` / ``nn.Embedding``: weight ~ N(0, std), bias set to 0;
      * ``nn.LayerNorm``: weight set to 1, bias set to 0;
      * ``nn.Conv2d``: kaiming-uniform weight (a=sqrt(5)) and bias uniform in
        ±1/sqrt(fan_in).
    Other module types are left untouched. Finally calls ``model.tie_weights()``
    when the model has one.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    def _reinit(module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if getattr(module, "bias", None) is not None:
                nn.init.zeros_(module.bias)
            return

        if isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
            return

        if not isinstance(module, nn.Conv2d):
            return

        nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
        if getattr(module, "bias", None) is None:
            return
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
        limit = 1 / math.sqrt(fan_in)
        nn.init.uniform_(module.bias, -limit, limit)

    model.apply(_reinit)  # children first, then the module itself
    if hasattr(model, "tie_weights"):
        model.tie_weights()  # for language models with shared embeddings

import torch
from transformers import MT5ForConditionalGeneration

def freeze_for_active_layer(model: "MT5ForConditionalGeneration", active_layer: int):
    """
    Freeze ``model`` except for what is needed to repair one global layer.

    Global numbering: ``0 .. n_enc-1`` are encoder blocks and
    ``n_enc .. n_enc+n_dec-1`` are decoder blocks (for mt5-base: 0-11 encoder,
    12-23 decoder block active_layer-12).

    Left trainable:
      * ``model.shared`` and ``model.lm_head``;
      * every normalisation module: ``nn.LayerNorm`` and any module whose class
        name contains "LayerNorm" or "RMSNorm" (e.g. T5's ``T5LayerNorm``);
      * the attention sublayers of the active block and of its direct neighbours;
      * the FFN sublayer of the active block only.

    Raises:
        IndexError: if ``active_layer`` is outside ``[0, n_enc + n_dec)``.
            The model is left unchanged in that case.
    """
    def set_grad(module, flag: bool):
        for p in module.parameters():
            p.requires_grad_(flag)

    def is_norm(module) -> bool:
        # T5/MT5 use T5LayerNorm, which is not an nn.LayerNorm subclass, so an
        # isinstance check alone never matches it.
        cls_name = type(module).__name__.lower()
        return (isinstance(module, torch.nn.LayerNorm)
                or "layernorm" in cls_name or "rmsnorm" in cls_name)

    def ffn_index(blk, decoder: bool) -> int:
        # Encoder: [self-attn, FFN]; decoder: [self-attn, cross-attn, FFN],
        # falling back to [self-attn, FFN] for decoder blocks without cross-attention.
        return 2 if decoder and len(blk.layer) >= 3 else 1

    enc_blocks = model.encoder.block
    dec_blocks = model.decoder.block
    n_enc = len(enc_blocks)               # 12 for mt5-base
    n_total = n_enc + len(dec_blocks)
    if not (0 <= active_layer < n_total):
        raise IndexError(f"active_layer {active_layer} out of range [0, {n_total})")

    # 1. freeze everything
    set_grad(model, False)

    # 2. always-trainable parts
    set_grad(model.shared, True)          # token embeddings
    set_grad(model.lm_head, True)         # keeps output layer aligned
    for m in model.modules():             # all normalisation scales / biases
        if is_norm(m):
            set_grad(m, True)

    # 3. active block ± 1 on the relevant side
    decoder = active_layer >= n_enc
    blocks = dec_blocks if decoder else enc_blocks
    idx = active_layer - n_enc if decoder else active_layer

    for k in (idx - 1, idx, idx + 1):
        if not 0 <= k < len(blocks):
            continue
        blk = blocks[k]
        ffn_idx = ffn_index(blk, decoder)
        for attn in blk.layer[:ffn_idx]:  # self-attention (+ cross-attention on the decoder)
            set_grad(attn, True)
        if k == idx:                      # centre block's FFN
            set_grad(blk.layer[ffn_idx], True)

    # sanity print
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total     = sum(p.numel() for p in model.parameters())
    print(f"Layer {active_layer}: {trainable/total:.2%} of parameters are trainable")

# Example usage
# model = MT5ForConditionalGeneration.from_pretrained("google/mt5-base")
# freeze_for_active_layer(model, active_layer=6)   # repairs encoder block 6



def groups_ffn_decay_except_wi0(model, wd_wi1=0.06, wd_wo=0.015, wd_down=0.05):
    """
    Build optimizer parameter groups that decay only the FFN gate and down-projection.

    Parameters whose lowercase name ends in ``wi_1.weight`` get ``wd_wi1``, and
    those ending in ``wo.weight`` get ``wd_wo``. Every other trainable parameter
    (wi_0, biases, norms, attention, embeddings, lm_head, ...) gets no decay.
    Frozen parameters (``requires_grad=False``) are skipped. ``wd_down`` is
    accepted but currently unused.

    Returns:
        Three dicts with keys ``params``, ``weight_decay`` and ``name``, in the
        order "wi1", "wo", "no_decay".
    """
    buckets = {"wi1": [], "wo": [], "no_decay": []}
    for param_name, param in model.named_parameters():
        if param.requires_grad:
            lowered = param_name.lower()
            if lowered.endswith("wi_1.weight"):
                buckets["wi1"].append(param)
            elif lowered.endswith("wo.weight"):
                buckets["wo"].append(param)
            else:
                buckets["no_decay"].append(param)

    decay_for = {"wi1": wd_wi1, "wo": wd_wo, "no_decay": 0.0}
    return [
        {"params": params, "weight_decay": decay_for[group], "name": group}
        for group, params in buckets.items()
    ]


if __name__ == "__main__":
    train_english = True

    from transformers import MT5Config, MT5ForConditionalGeneration


    if train_english:
        model_save = "mt5_base_forgive_and_forget_ft"
        #model_source = "stage1_base_step120000"
        model_source = "google/mt5-base"

        # layer 23, exploding fg
        #state_dict_source = "mt5_base_pretuned.pt"
        #state_dict_source = "mt5_base_forgive_and_forget_fine_tune0_batch_20000.pt" # cut strength from .015 to .008 here
        #state_dict_source = "mt5_base_forgive_and_forget_fine_tune20_batch_20000.pt" # after grouping for one full epoch
        #state_dict_source = "mt5_base_forgive_and_forget_fine_tune30_batch_20000.pt" # after 20000 batches w/o fg or weight decay. Now start both again on cycle

        # layer 22
        #state_dict_source = "mt5_base_forgive_and_forget_fine_tune_lay_220_batch_10000.pt"
        #state_dict_source = "mt5_base_forgive_and_forget_fine_tune_lay_220.pt"

        # full model
        #state_dict_source = "mt5_base_pretuned.pt"
        state_dict_source = "FandF_StateDicts/mt5_base_forgive_and_forget_whole_stream6.pt"
        #state_dict_source = "mt5_base_forgive_and_forget_whole_stream_spread_0.pt"
        #state_dict_source = "FandF_StateDicts/mt5_base_forgive_and_forget_whole_stream00_batch_10000.pt"



        #state_dict_source = "mt5_small_as_to_en_ref190.pt"
        #state_dict_source = "mt5_small_en_to_as_forgive10.pt"
        state_dict_save = "mt5_base_forgive_and_forget_whole_stream"

        # load model
        tokenizer = AutoTokenizer.from_pretrained(model_source,
                                                  use_fast=False,  # keep full SentencePiece behaviour
                                                  legacy=False)

        custom = False
        if custom:
            cfg = MT5Config(
                vocab_size=tokenizer.vocab_size,  # keep full vocab
                d_model=128,  # or 128 / 256
                d_ff=512,  # 4 × d_model is enough
                num_layers=6,  # encoder
                num_decoder_layers=6,  # decoder
                num_heads=4,  # keep d_model % num_heads == 0
                dropout_rate=0.1,

                # ── the three IDs MT5 needs ─────────────────────────────────
                pad_token_id         = tokenizer.pad_token_id,
                eos_token_id         = tokenizer.eos_token_id,
                decoder_start_token_id = tokenizer.pad_token_id,
            )

            base_model = MT5ForConditionalGeneration(cfg)
        else:
            base_model = MT5ForConditionalGeneration.from_pretrained(model_source)



    device = "cuda"
    base_model = base_model.to(device)



    print("Checkpoint reloaded!")
    # model = ForgetfulT5(base_model.config)

    # Initialize your custom (ModifiedT5) model.
    # @todo should both models be on gpu or do i need to merge them?
    model = ForgetfulT5(base_model)
    model = model.to(device)


    print(base_model.num_parameters() / 1e6, "M params")  # 30 M for 3+3 @ d=256

    # freeze non relevant layers
    #freeze_for_active_layer(model.base_model, LEAK.active_layer)
    #freeze_for_active_layer(model.base_model, LEAK.active_layer-1)



    # Path for the custom model state dict.
    custom_state_dict_path = state_dict_source

    reload = True

    if reload:
        # If a custom state dict exists, load it (with strict=False so new parameters are left untouched).
        if os.path.exists(custom_state_dict_path):
            print(f"Loading existing custom state dict from {custom_state_dict_path}...")
            state_dict = torch.load(custom_state_dict_path, map_location="cuda")
            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
            print("Missing keys:", missing_keys)
            print("Unexpected keys:", unexpected_keys)
        else:
            print("No existing custom state dict found; proceeding with freshly initialized custom model.")
    else:
        # randomly overwrites weights
        fully_randomize(model, seed=50, std=0.2) # start seed 44

    print("Checkpoint reloaded!")


    # visualize weights
    model.visualize_ffn(23, which="all")
    model.visualize_ffn(22, which="all")

    #state_dict = torch.load(your_state_dict_path, map_location="cpu")
    #missing_keys, unexpected_keys = base_model.load_state_dict(state_dict, strict=False)

    train_samples = 2000000  #11900000 # 13016
    val_samples = 10000
    start = 5000000

    from datasets import Dataset

    # Load WMT19 training set
    if train_english:
        #dataset = load_dataset("wmt19", "zh-en", split="train")
        #dataset = load_dataset("wmt19", "gu-en", split="train")  # Gujarati–English
        #dataset = load_dataset("ai4bharat/samanantar", "as", split="train")
        #ds = load_from_disk("wiki_en_topk_10")
        ds = Dataset.from_file("wiki_en_topk_10/data.arrow")

        # dataset_test = load_dataset("wmt19", "zh-en", split="validation")

        # shortens dataset to speed training

        #dataset_test = dataset.select(range(start + train_samples, val_samples + train_samples + start))
        #dataset = dataset.select(range(start, train_samples))

        ds_test = ds.select(range(start + train_samples, val_samples + train_samples + start))
        ds = ds.select(range(start, start + train_samples))

    tokenize_fnn = partial(tokenize_forgiveness, tokenizer=tokenizer, train_english=True)  # True when english->assamese

    # 2) Create a DataLoader for mini-batching
    dataset_en = ds.map(tokenize_fnn, batched=True, num_proc=6)  # , remove_columns=["translation"])
    dataset_en_test = ds_test.map(tokenize_fnn, batched=True, num_proc=6)

    # tokenize_fnnn = partial(tokenize_fn, tokenizer=tokenizer, train_english=False)
    tokenize_fnnn = partial(tokenize_forgiveness, tokenizer=tokenizer, train_english=False)

    # 2) Create a DataLoader for mini-batching
    dataset_zh = ds.map(tokenize_fnnn, batched=True, num_proc=6)
    dataset_zh_test = ds_test.map(tokenize_fnnn, batched=True, num_proc=6)

    # 3) Convert to torch format (and select columns to keep)
    #dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

    print("Language Files loaded")

    dataset_en.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "labels"]
    )
    dataset_en_test.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "labels"]
    )

    dataset_zh.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "labels", "k_alts", "k_confs", "k_alt_ids"]
    )
    dataset_zh_test.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "labels", "k_alts", "k_confs", "k_alt_ids"]
    )

    #collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")

    #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #model = base_model.to("cuda")

    # Create an optimizer
    # @todo remove weight decay
    # adds forgetfulness
    groups = groups_ffn_decay_except_wi0(model, wd_wi1=0.006, wd_wo=0.003)

    #optimizer = AdamW(groups, lr=3e-4)
    optimizer = AdamW(model.parameters(), lr=3e-4)
    scheduler = LambdaLR(optimizer, lr_lambda)

    # 9.65 bleu
    model.train()  # put in training mode

    # ── run length ───────────────────────────────────────────────
    epochs = 10
    batches_per_day = 50000
    total_batches = 50000  # previously 1200
    wake_cycles = 1

    check_length = False

    # ── sampling temperature range ──────────────────────────────
    tau_start, tau_end = 0.95, 0.6

    # ── prediction forgiveness ───────────────────────────────────
    smoothing_start, smoothing_end = 0.9, 0.15
    top_k = 10
    F_MOD = 1.5

    # 24 forgetting flags, all disabled until the training loop enables them
    safe_to_forget = [False] * 24
    forget_now = True

    # ── forgiveness scheduling (units: epochs) ──────────────────
    warm_fmod = 50   # previously 20
    cool_fmod = 5    # previously 50
    start_forgiveness_ep = 0
    end_forgiveness_ep = 25  # previously 130

    # toggle for "magnetism" so the model can recover at times
    mag_check = True

    # blend weight between forgiving and plain loss; ramped instead of editing F_MOD
    F_MOD_MULT = 1.0
    enable_f_mod = True
    fmod_dec = 1.0 / cool_fmod
    fmod_add = 1.0 / warm_fmod
    f_mod_dec_const = F_MOD / cool_fmod


    if check_length:
        # Token-length statistics of the English inputs, ignoring padding.
        pad_id = tokenizer.pad_token_id
        all_lengths = [
            sum(1 for token_id in example['input_ids'] if token_id != pad_id)
            for example in tqdm(dataset_en)
        ]

        print(f"Average length: {sum(all_lengths) / len(all_lengths):.2f} tokens")
        print(f"90th percentile: {sorted(all_lengths)[int(0.9 * len(all_lengths))]} tokens")
        print(f"Max length: {max(all_lengths)} tokens")

    # The EN and ZH datasets are zipped batch-by-batch, so they must have equal row counts.
    assert len(dataset_en) == len(dataset_zh)
    assert len(dataset_en_test) == len(dataset_zh_test)

    # @todo precompute similar words for efficiency
    # freeze encoder during first epochs, slowly unthaw from top to bottom
    # Figure out why gradients are nan for autocast (need to use big range low precision)
    # switch back to mt5-Large and pretraining
    # ensure that forgetfulness is being applied properly
    # get vram usage way down
    # figure out why custom modules make model so much slower than standard t5
    #
    # Main training loop. For each epoch it:
    #   * advances the forgiveness blend weight F_MOD_MULT on an epoch schedule,
    #   * reshuffles the paired EN/ZH training sets with one shared permutation,
    #   * trains in "days" of `batches_per_day` batches until `total_batches`,
    #     blending the forgiveness loss with the plain gold-token NLL,
    #   * every 250 batches logs loss/BLEU/perplexity and adjusts LEAK settings,
    #   * every LEAK.tau batches enables entries of `safe_to_forget`,
    #   * every 10000 batches writes a checkpoint,
    #   * finally scores BLEU on up to 101 test batches and saves a checkpoint.

    # Padding id of the label sequences; padded positions are excluded from the
    # loss and from the token count.
    label_pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

    print(f"Training for {epochs} epochs")
    for epoch in range(epochs):
        print("Reshuffle Data")
        seed = 50 + epoch

        # NOTE: the LR scheduler is actually stepped after every optimizer update
        # inside the batch loop; this message is kept for log continuity.
        print("Scheduler Step")

        # Forgiveness blend schedule (per epoch): warm up by `fmod_add` for
        # `warm_fmod` epochs starting at `start_forgiveness_ep`; more than
        # `cool_fmod` epochs after `end_forgiveness_ep` switch forgiveness off;
        # otherwise after `end_forgiveness_ep` cool down by `fmod_dec`.
        # The "switch off" test must come before the "cool down" test, or it is
        # unreachable.
        if start_forgiveness_ep <= epoch < start_forgiveness_ep + warm_fmod:
            F_MOD_MULT += fmod_add
        elif epoch > end_forgiveness_ep + cool_fmod:
            F_MOD_MULT = 0.0
        elif epoch > end_forgiveness_ep:
            F_MOD_MULT -= fmod_dec

        # clamp F_MOD_MULT to [0, 1]
        F_MOD_MULT = min(max(F_MOD_MULT, 0.0), 1.0)

        # @todo currently shuffling data every epoch, should look into self derived curriculum learning
        # Shuffle both training sets with the same permutation so EN/ZH rows stay
        # aligned (deterministic for a given seed).
        perm = np.random.RandomState(seed).permutation(len(dataset_en))
        dataset_en = dataset_en.select(perm)
        dataset_zh = dataset_zh.select(perm)

        collator_en = T5SpanCorruptionCollator(
            tokenizer, noise_density=0.15, mean_span_len=3, input_length=64
        )

        print("Set up Dataloaders")

        batch_size = 38
        # shuffle=False: order comes from the shared permutation above, and the two
        # loaders are zipped, so they must iterate in lockstep.
        train_en_dataloader = DataLoader(dataset_en, batch_size=batch_size, pin_memory=False, shuffle=False, num_workers=0, collate_fn=collator_en)
        train_zh_dataloader = DataLoader(dataset_zh, batch_size=batch_size, pin_memory=False, shuffle=False, num_workers=0)

        if epoch == 0:
            print("Shuffle test data")
            # same seed and same length -> same permutation, so the test pairs stay aligned
            dataset_en_test = dataset_en_test.shuffle(seed=seed)
            dataset_zh_test = dataset_zh_test.shuffle(seed=seed)
            test_en_dataloader = DataLoader(dataset_en_test, batch_size=batch_size, pin_memory=False, shuffle=False)
            test_zh_dataloader = DataLoader(dataset_zh_test, batch_size=batch_size, pin_memory=False, shuffle=False)

        # ── per-epoch running statistics ──
        total_loss = 0.0        # sum of per-batch mean (blended) losses
        total_base_loss = 0.0   # sum of per-batch mean gold-token NLL
        total_base_nll = 0.0    # token-weighted sum of gold-token NLL (for perplexity)
        total_tokens = 0.0      # number of non-pad label tokens seen
        num_batches = 0
        overall_bleu = 0.0      # BLEU sum over the current logging window
        bleu_batches = 0        # batches in the current logging window
        acceptable_cut = 74.0

        total_bleu = 0.0
        total_bleu_batches = 0

        # index of the first batch to train on; set to k * batches_per_day to resume mid-epoch
        batch_count = 0

        # loops through whole year
        while batch_count < total_batches:
            overall_bleu = 0.0
            bleu_batches = 0

            accumulation_steps = 3

            # NOTE(review): autocast runs in bfloat16, which usually does not need
            # loss scaling; the scaler is kept as in the original setup.
            scaler = GradScaler()

            #@todo make sure topk is being loaded correctly and is actually representative
            print("Go through day")
            for j in range(wake_cycles):
                for i, (batch_en, batch_zh) in enumerate(zip(train_en_dataloader, train_zh_dataloader)):
                    # train only on this day's slice of batches
                    if i < batch_count:
                        continue
                    if i >= batch_count + batches_per_day:
                        break

                    debug_alts = False
                    if debug_alts:
                        # decode one label token followed by its first 8 alternatives
                        alt_row = batch_zh["k_alt_ids"][0][2]
                        debug_parts = [tokenizer.decode(batch_zh["input_ids"][0][2])]
                        for alt_idx in range(8):
                            debug_parts += ['\n', tokenizer.decode(alt_row[alt_idx])]
                        debug_parts.append('\n\n')
                        print(*debug_parts)

                    # Move inputs/labels to the training device
                    input_ids = batch_en["input_ids"].to(device)
                    attention_mask = batch_en["attention_mask"].to(device)
                    labels = batch_zh["input_ids"].to(device)
                    zh_alts = batch_zh["k_alt_ids"].to(device)   # alternative ids, presumably (B, T, k) — confirm
                    alt_probs = batch_zh["k_confs"].to(device)   # per-alternative weights, same shape as zh_alts

                    torch.cuda.reset_peak_memory_stats()

                    with autocast(True, dtype=torch.bfloat16):
                        # Forward pass (outputs.loss from `labels` is not used below)
                        outputs = model(
                            input_ids=input_ids,
                            attention_mask=attention_mask,
                            labels=labels,
                            use_cache=False
                        )

                    log_probs = torch.nn.functional.log_softmax(outputs.logits, dim=-1)  # (B, T, V)

                    # ── probability the model assigns to each "forgivable" alternative ──
                    neighbor_probs = log_probs.gather(-1, zh_alts).exp()

                    # scale alternative weights by F_MOD, capped at 0.85
                    a_probs = (alt_probs * F_MOD).clamp(min=0.0, max=0.85)

                    # ── proportional forgiveness credit ──
                    # credit = Σ_k P_model(alt_k) · weight_k, capped at 0.85
                    forgiveness = (neighbor_probs * a_probs).sum(dim=-1).clamp(min=0.0, max=0.85)

                    # log-prob of the gold token at every position
                    target_log_probs = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)  # (B, T)
                    base_token_losses = -target_log_probs

                    #@todo make sure gold token isn't forgiven as that would dilute loss
                    p_gold = target_log_probs.exp()
                    # effective probability = gold mass + forgiveness credit
                    p_eff = torch.clamp(p_gold + forgiveness, min=1e-8, max=1.0 - 1e-8)  # (B, T)

                    # Forgiving loss: −log p_eff (non-negative because p_eff < 1)
                    token_losses = (-torch.log(p_eff)).clamp(min=0.)

                    # mask out padding; clamp avoids 0/0 on an all-pad batch
                    non_pad_mask = (labels != label_pad_id).float()
                    n_real = non_pad_mask.sum().clamp(min=1.0)

                    # blend forgiving and plain gold-token losses by F_MOD_MULT
                    mixed_token_losses = F_MOD_MULT * token_losses + (1.0 - F_MOD_MULT) * base_token_losses
                    loss = (mixed_token_losses * non_pad_mask).sum() / n_real
                    ntok = non_pad_mask.sum().item()

                    # gold-token NLL for reference / perplexity (no gradient)
                    with torch.no_grad():
                        base_nll_sum = (base_token_losses * non_pad_mask).sum()
                        base_loss = base_nll_sum / n_real

                    # scale for gradient accumulation
                    loss = loss / accumulation_steps

                    # Backprop
                    scaler.scale(loss).backward()

                    # perform the optimizer step only after `accumulation_steps` mini-batches
                    if (num_batches + 1) % accumulation_steps == 0:
                        scaler.unscale_(optimizer)  # unscale so clipping sees true gradients
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                        scaler.step(optimizer)
                        scaler.update()
                        scheduler.step()
                        optimizer.zero_grad()

                    total_loss += loss.item() * accumulation_steps  # undo the accumulation scaling
                    total_tokens += ntok
                    total_base_nll += base_nll_sum.item()
                    total_base_loss += base_loss.item()
                    num_batches += 1

                    # score performance on the current training batch
                    # NOTE(review): the model is still in train mode here, so dropout
                    # (if any) is active during generation — confirm this is intended.
                    if j == 0:
                        with torch.no_grad():
                            outputs = model.generate(input_ids=input_ids,
                                                     attention_mask=attention_mask,
                                                     max_length=64)

                        # truncate each hypothesis to its reference length and drop ids
                        # outside the tokenizer vocabulary before decoding
                        vocab_size = tokenizer.vocab_size
                        hypotheses = []
                        for gen_ids, ref_ids in zip(outputs, batch_zh["input_ids"]):
                            safe_ids = [tok for tok in gen_ids[:len(ref_ids)].tolist() if 0 <= tok < vocab_size]
                            hypotheses.append(tokenizer.decode(safe_ids, skip_special_tokens=True))

                        references = [
                            tokenizer.decode(ref_ids, skip_special_tokens=True)
                            for ref_ids in batch_zh["input_ids"]
                        ]

                        bleu_score = sacrebleu.corpus_bleu(hypotheses, [references]).score
                        overall_bleu += bleu_score
                        total_bleu += bleu_score
                        total_bleu_batches += 1
                        bleu_batches += 1

                        if num_batches % 250 == 0:
                            # token-weighted gold-token cross-entropy and its perplexity
                            val_ce = total_base_nll / max(total_tokens, 1.0)
                            val_ppl = math.exp(val_ce)

                            bleu_s = overall_bleu / bleu_batches

                            # hysteresis on the windowed training BLEU
                            if mag_check and bleu_s < 77.7:
                                mag_check = False
                            elif not mag_check and bleu_s > 78.2:
                                mag_check = True

                            if mag_check:
                                if LEAK.max_variance < 3.0:
                                    LEAK.max_variance += 0.075
                                elif LEAK.max_variance < 11.0:
                                    LEAK.max_variance += 0.1
                                elif LEAK.max_variance < 15.0:
                                    LEAK.max_variance += 0.05
                                print("VARIANCE", str(LEAK.max_variance))
                            else:
                                LEAK.max_variance -= 0.05

                            print(f"Batch {num_batches} of {total_batches}")
                            print(f"Loss: {total_loss / (num_batches)}")
                            print(f"Gold loss: {total_base_loss / (num_batches)}")
                            print(f"BLEU score: {overall_bleu/bleu_batches:.2f}")
                            print(f"Perplexity: {val_ppl}")
                            print(f"Overall: {total_bleu/total_bleu_batches:.2f}\n")

                            # start a fresh BLEU logging window
                            overall_bleu = 0.0
                            bleu_batches = 0

                            if F_MOD_MULT < 1.0:
                                F_MOD_MULT += fmod_add

                            tau_start = 0  # previously 10000
                            if num_batches > tau_start:
                                if LEAK.tau_mod > 1:
                                    LEAK.tau_mod -= 1
                                    print("TAU", str(LEAK.tau_mod))

                            if LEAK.tau_mod < 1.0:
                                LEAK.tau_mod = 1

                            F_MOD_MULT = min(max(F_MOD_MULT, 0.0), 1.0)

                            if not enable_f_mod:
                                F_MOD_MULT = 0.0

                        # enables forgetfulness if not in recovery phase
                        if num_batches % LEAK.tau == 0:
                            # guarded: the BLEU window may have just been reset above
                            bleu_s = overall_bleu / max(bleu_batches, 1)

                            #if bleu_s <= (acceptable_cut - 14.0):
                            #    forget_now = False
                            #elif bleu_s >= (acceptable_cut - 10.0):
                            #    forget_now = True

                            if (epoch > 2 or num_batches >= 1000) and forget_now:# and bleu_s > (acceptable_cut - 2.0):
                                # Rotate over four groups every 5000 batches (1250 each):
                                # group g = entries whose index % 4 == g.
                                # NOTE(review): entries are only ever set to True, never
                                # reset, so after one full cycle all of them stay enabled.
                                phase = (num_batches % 5000) // 1250
                                for flag_idx in range(phase, len(safe_to_forget), 4):
                                    safe_to_forget[flag_idx] = True

                        """if safe_to_forget:
                            print(bleu_s)
                            if bleu_s > 65.0:
                                LEAK.alpha_val += 0.001
                                LEAK.alpha_gate += 0.001
                            else:
                                safe_to_forget = False
                                LEAK.alpha_val -= 0.001
                                LEAK.alpha_gate -= 0.001
                                if LEAK.alpha_val < 0.0:
                                    LEAK.alpha_val = 0.0
                                elif LEAK.alpha_gate < 0.0:
                                    LEAK.alpha_gate = 0.0"""

                        if num_batches % 10000 == 0:
                            ckpt_path = f"{state_dict_save}{epoch}_batch_{num_batches}.pt"
                            base_model.save_pretrained(model_save, max_shard_size="1GB")
                            tokenizer.save_pretrained(model_save)
                            torch.save(model.state_dict(), ckpt_path)

                            print("Checkpoint saved!")
                            print("Location: ", ckpt_path)

            # continue to next day
            batch_count += batches_per_day

            # average (blended) loss per training batch so far this epoch
            avg_loss = total_loss / max(num_batches, 1)
            print(f"Day {batch_count/batches_per_day}/{total_batches/batches_per_day} - avg train loss: {avg_loss:.4f}\n")

        # average (blended) loss per training batch over the epoch
        avg_loss = total_loss / max(num_batches, 1)
        print(f"Epoch {epoch+1}/{epochs} - avg train loss: {avg_loss:.4f}\n")

        # ── test-set BLEU (at most 101 batches) ──
        model.eval()
        overall_bleu = 0.0
        current_batch = 0
        vocab_size = tokenizer.vocab_size

        for batch_en, batch_zh in zip(test_en_dataloader, test_zh_dataloader):
            # 1. Move English input to the device
            input_ids = batch_en["input_ids"].to(device)
            attention_mask = batch_en["attention_mask"].to(device)

            # 2. Build references from the target-side batch (already tokenized)
            references = [
                tokenizer.decode(ref_ids, skip_special_tokens=True)
                for ref_ids in batch_zh["input_ids"]
            ]

            # 3. Generate predictions from the English input
            with torch.no_grad():
                outputs = model.generate(input_ids=input_ids,
                                         attention_mask=attention_mask,
                                         max_length=64)

            # truncate to reference length and drop out-of-vocab ids (as in training BLEU)
            hypotheses = []
            for gen_ids, ref_ids in zip(outputs, batch_zh["input_ids"]):
                safe_ids = [tok for tok in gen_ids[:len(ref_ids)].tolist() if 0 <= tok < vocab_size]
                hypotheses.append(tokenizer.decode(safe_ids, skip_special_tokens=True))

            # 4. corpus BLEU for this batch
            bleu_score = sacrebleu.corpus_bleu(hypotheses, [references]).score
            overall_bleu += bleu_score

            current_batch += 1
            if current_batch > 100:
                # Stop early (after 101 batches)
                break

        # mean per-batch BLEU; guarded against an empty test loader
        overall_bleu /= max(current_batch, 1)

        model.train()

        print(f"\nTest BLEU score: {overall_bleu:.2f}")

        # checkpoint after every epoch
        epoch_ckpt_path = f"{state_dict_save}{epoch}.pt"
        base_model.save_pretrained(model_save)
        tokenizer.save_pretrained(model_save)
        torch.save(model.state_dict(), epoch_ckpt_path)

        print("Checkpoint saved!")
        print("Location: ", epoch_ckpt_path)

    # save checkpoint
    """if epochs != 0 and epochs % 5 == 0:
        base_model.save_pretrained(model_save)
        tokenizer.save_pretrained(model_save)
        torch.save(model.state_dict(), str(custom_state_dict_path + str(epoch)))

        print("Checkpoint saved!")"""

    # check score
    # Final test-set evaluation: generate for up to 101 test batches, print every
    # reference/prediction pair, and report the mean per-batch corpus BLEU.
    # NOTE(review): test_en_dataloader / test_zh_dataloader are created inside the
    # training loop at epoch 0, so this section fails if `epochs` is 0.

    # Put model in eval mode for inference
    model.eval()
    overall_bleu = 0.0
    current_batch = 0
    vocab_size = tokenizer.vocab_size

    # it should be high early on when forgetfulness is also high. Then oscillate up or down depending on the day/content
    for batch_en, batch_zh in zip(test_en_dataloader, test_zh_dataloader):
        # 1. Move English input to the device
        input_ids = batch_en["input_ids"].to(device)
        attention_mask = batch_en["attention_mask"].to(device)

        # 2. Build references from the target-side batch (already tokenized)
        references = [
            tokenizer.decode(ref_ids, skip_special_tokens=True)
            for ref_ids in batch_zh["input_ids"]
        ]

        # 3. Generate predictions from the English input
        with torch.no_grad():
            outputs = model.generate(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     max_length=64)

        # Truncate each hypothesis to its reference length and drop ids outside the
        # tokenizer vocabulary, mirroring the filtering used for training-time BLEU.
        hypotheses = []
        for gen_ids, ref_ids in zip(outputs, batch_zh["input_ids"]):
            safe_ids = [tok for tok in gen_ids[:len(ref_ids)].tolist() if 0 <= tok < vocab_size]
            hypotheses.append(tokenizer.decode(safe_ids, skip_special_tokens=True))

        # 4. Corpus BLEU for this batch
        bleu_score = sacrebleu.corpus_bleu(hypotheses, [references]).score
        overall_bleu += bleu_score

        # Print every prediction of this batch
        print("Sample predictions:")
        for ref, hyp in zip(references, hypotheses):
            print(f"Reference : {ref}")
            print(f"Predicted : {hyp}")
            print("-----")

        current_batch += 1
        if current_batch > 100:
            # Stop early (after 101 batches)
            break

    # mean per-batch BLEU; guarded against an empty test loader
    overall_bleu /= max(current_batch, 1)

    print(f"\nTest BLEU score: {overall_bleu:.2f}")



