import os
from transformers.activations import ACT2FN

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import numpy as np
import random
import torch.nn as nn
import torch.optim as optim
from transformers import MT5ForConditionalGeneration, AutoTokenizer, MT5Tokenizer, AutoModelForSeq2SeqLM, AutoModel
from torch.optim import AdamW
import sacrebleu
from torch.utils.data import DataLoader
from datasets import load_dataset, load_from_disk
from functools import partial
from transformers import DataCollatorWithPadding
from torch.optim.lr_scheduler import StepLR
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
from torch.nn import functional as F
from functools import lru_cache
from dataclasses import dataclass, field
from itertools import groupby  # ← add this
import matplotlib.pyplot as plt
from datasets import load_dataset, interleave_datasets

import torch, gc, random
import math
import re
import torch.fft as fft
from itertools import groupby
from torch.optim.lr_scheduler import LambdaLR


# from mbert_pretraining import BATCH_SIZE

# print("PyTorch version:", torch.__version__)
# print("CUDA available:", torch.cuda.is_available())
# print("CUDA device count:", torch.cuda.device_count())
# print("CUDA version:", torch.version.cuda)


# corruption collator

from project_classes import T5SpanCorruptionCollator, CustomDenseReluDense, ForgetfulT5, DownBandAccumulator, BandSteeringController

from project_classes import load_english, tokenize_fn, lr_lambda, LEAK, build_baseline_and_targets

from collections import Counter

# =========================
# Fast signature matching
# =========================
import torch
from typing import Dict, List, Optional, Tuple

# Optional FAISS acceleration
try:
    import faiss  # type: ignore
    _FAISS_AVAILABLE = True
except Exception:
    _FAISS_AVAILABLE = False


# ---------- Bank builders ----------

def build_layer_bank(
    layer_centers: torch.Tensor,   # [N, d], float32/float16 on CPU or GPU
    layer_powers: torch.Tensor,    # [N]
    use_faiss: bool = False,
    faiss_gpu: bool = False
) -> Dict:
    """
    Package one layer's centers/powers into a reusable "bank" dict.

    The bank keeps the original tensors (on whatever device they live), their
    squared row norms (for the ||c||^2 - 2 c.x + ||x||^2 distance trick), and
    float32 CPU copies. When ``use_faiss`` is requested and FAISS imported
    successfully, an exact ``IndexFlatL2`` over the CPU centers is attached
    (optionally moved to GPU 0 when ``faiss_gpu`` is set).

    Returns:
        dict with keys "centers", "powers", "centers_norm2", "use_faiss",
        "faiss_index", "faiss_res", "cpu_copy".
    """
    assert layer_centers.dim() == 2 and layer_powers.dim() == 1
    num_rows, dim = layer_centers.shape
    assert layer_powers.shape[0] == num_rows

    # FAISS CPU indices consume float32 row-major arrays, so keep such copies around.
    cpu_centers = layer_centers.detach().float().cpu().contiguous()
    cpu_powers = layer_powers.detach().float().cpu().contiguous()

    bank = dict(
        centers=layer_centers,                          # possibly on GPU
        powers=layer_powers,                            # possibly on GPU
        centers_norm2=layer_centers.pow(2).sum(dim=1),  # [N], same device as centers
        use_faiss=False,
        faiss_index=None,
        faiss_res=None,
        cpu_copy={"centers": cpu_centers, "powers": cpu_powers},
    )

    if not (use_faiss and _FAISS_AVAILABLE):
        return bank

    flat_index = faiss.IndexFlatL2(dim)
    if faiss_gpu:
        # The GPU resources object must outlive the index, so it is stored in the bank.
        gpu_res = faiss.StandardGpuResources()
        flat_index = faiss.index_cpu_to_gpu(gpu_res, 0, flat_index)
        bank["faiss_res"] = gpu_res
    flat_index.add(cpu_centers.numpy())
    bank["faiss_index"] = flat_index
    bank["use_faiss"] = True
    return bank


def build_banks_by_layer(
    bank_by_layer_input: Dict[int, Dict[str, torch.Tensor]],
    use_faiss: bool = False,
    faiss_gpu: bool = False
) -> Dict[int, Dict]:
    """
    Build one bank per layer via build_layer_bank.

    Args:
        bank_by_layer_input: mapping L -> {"centers": [N, d], "powers": [N]};
            any other keys in the per-layer dict are ignored.
        use_faiss / faiss_gpu: forwarded unchanged to build_layer_bank.

    Returns:
        mapping L -> bank dict (same keys and iteration order as the input).
    """
    return {
        layer: build_layer_bank(
            layer_centers=entry["centers"],
            layer_powers=entry["powers"],
            use_faiss=use_faiss,
            faiss_gpu=faiss_gpu,
        )
        for layer, entry in bank_by_layer_input.items()
    }


# ---------- Power prefilter ----------

def _power_prefilter(
    cand_power: torch.Tensor,      # [] or [B]
    bank_powers: torch.Tensor,     # [N] (same device as centers)
    tol_frac: float,
    ref: str = "center"
) -> torch.Tensor:
    """
    Returns a boolean mask over [N] (or [B, N] if cand_power is [B])
    indicating which bank items pass the relative power diff tolerance.
    """
    eps = 1e-12
    if cand_power.dim() == 0:  # scalar
        if ref == "center":
            denom = bank_powers.abs().clamp_min(eps)             # [N]
            return (bank_powers - cand_power).abs() <= tol_frac * denom
        else:  # candidate
            denom = cand_power.abs().clamp_min(eps)              # []
            return (bank_powers - cand_power).abs() <= tol_frac * denom
    else:
        # batched [B]
        if ref == "center":
            denom = bank_powers.abs().clamp_min(eps)[None, :]    # [1, N]
        else:
            denom = cand_power.abs().clamp_min(eps)[:, None]     # [B, 1]
        diff = (bank_powers[None, :] - cand_power[:, None]).abs()  # [B, N]
        return diff <= tol_frac * denom                           # [B, N]


# ---------- Exact, vectorized distance (fallback / baseline) ----------

def fast_exact_layer_hits(
    cand_center: torch.Tensor,     # [d]
    cand_power:  torch.Tensor,     # [] or [1]
    bank: Dict,
    dist_thresh: float = 5.0,
    power_tol_frac: float = 0.10,
    power_ref: str = "center"
) -> Dict:
    """
    Fully vectorized exact check of one candidate against a whole layer bank (no ANN).

    Bank row i matches when BOTH hold:
      * the relative power difference passes ``power_tol_frac`` (see _power_prefilter;
        the reference power is P[i] for power_ref == "center", else the candidate's), and
      * the Euclidean distance ||C[i] - x|| <= dist_thresh.
    The cheap power test runs first so distances are only computed for survivors,
    using the precomputed squared row norms from build_layer_bank.

    Args:
        cand_center: candidate center, shape [d]; cast to the bank's device and dtype.
        cand_power: candidate power, a 0-d or single-element tensor (or Python number).
        bank: dict from build_layer_bank (reads "centers", "centers_norm2", "powers").
        dist_thresh: maximum L2 distance for a match.
        power_tol_frac: allowed relative power difference.
        power_ref: "center" or any other value (treated as "candidate").

    Returns:
        dict with "any_match", "match_indices" (bank row indices) and "counts"
        ({"num_matches", "num_candidates"}); on a match also "best_idx" and
        "best" ({"dist", "power_diff_frac"}) for the closest matching row.

    Raises:
        RuntimeError if cand_power has more than one element.
    """
    C = bank["centers"]           # [N, d]
    Cn2 = bank["centers_norm2"]   # [N]
    P = bank["powers"]            # [N]
    device = C.device
    num_candidates = int(C.shape[0])

    def _no_match() -> Dict:
        return {"any_match": False, "match_indices": [],
                "counts": {"num_matches": 0, "num_candidates": num_candidates}}

    # Match the bank's device AND dtype; otherwise C_sub @ x raises on e.g. float64 vs float32.
    x = torch.as_tensor(cand_center, device=device, dtype=C.dtype)
    # Collapse a [1]-shaped power to a 0-d tensor: _power_prefilter treats any 1-D input as
    # a batch and would return a [1, N] mask, which breaks the 1-D index_select below.
    xp = torch.as_tensor(cand_power, device=P.device).reshape(())

    # Power prefilter (cheap)
    pw_ok = _power_prefilter(xp, P, power_tol_frac, ref=power_ref)  # [N]
    if not pw_ok.any():
        return _no_match()

    # Survivor row indices, moved to the centers' device in case powers live elsewhere.
    idx = torch.nonzero(pw_ok, as_tuple=False).squeeze(1).to(device)  # [M]
    C_sub = C.index_select(dim=0, index=idx)         # [M, d]
    Cn2_sub = Cn2.index_select(dim=0, index=idx)     # [M]

    # ||c - x||^2 = ||c||^2 - 2 c.x + ||x||^2 ; clamp guards tiny negatives from rounding.
    d2 = Cn2_sub - 2.0 * (C_sub @ x) + (x * x).sum()  # [M]
    d = torch.sqrt(torch.clamp(d2, min=0.0))
    ok = d <= dist_thresh
    if not ok.any():
        return _no_match()

    valid_idx = idx[ok]
    d_ok = d[ok]
    # Pick closest for convenience
    best_local = int(torch.argmin(d_ok).item())
    best_idx = int(valid_idx[best_local].item())

    # Recompute power diff frac for the best (reporting)
    if power_ref == "center":
        denom = P[best_idx].abs().clamp_min(1e-12)
    else:
        denom = xp.abs().clamp_min(1e-12)
    best_power_diff_frac = float(((P[best_idx] - xp).abs() / denom).item())

    return {
        "any_match": True,
        "match_indices": valid_idx.tolist(),
        "best_idx": best_idx,
        "best": {
            "dist": float(d_ok[best_local].item()),
            "power_diff_frac": best_power_diff_frac,
        },
        "counts": {
            "num_matches": int(ok.sum().item()),
            "num_candidates": num_candidates,
        }
    }


# ---------- FAISS-accelerated (k-NN or exact radius) ----------

def faiss_layer_hits(
    cand_center: torch.Tensor,     # [d]
    cand_power:  torch.Tensor,     # [] or [1]
    bank: Dict,
    dist_thresh: float = 5.0,
    power_tol_frac: float = 0.10,
    power_ref: str = "center",
    k: int = 32,
    use_radius: bool = True
) -> Dict:
    """
    Use FAISS to shortlist bank rows, then apply the exact power + distance rule.

    If use_radius=True, performs a radius search (squared radius = dist_thresh**2),
    so no row within the threshold is missed. Otherwise a k-NN search is used and
    filtered with the rule (set k generously; rows beyond the k nearest are missed).

    Args:
        cand_center: candidate center, shape [d].
        cand_power: candidate power, a 0-d or single-element tensor.
        bank: dict from build_layer_bank built with use_faiss=True.
        dist_thresh, power_tol_frac, power_ref: same match rule as fast_exact_layer_hits.
        k: k-NN shortlist size (only used when use_radius=False).
        use_radius: choose radius search over k-NN.

    Returns:
        same dict layout as fast_exact_layer_hits.

    Raises:
        AssertionError if the bank has no FAISS index or FAISS is unavailable.
    """
    assert bank["use_faiss"] and _FAISS_AVAILABLE, "FAISS index not available in this bank."

    index = bank["faiss_index"]
    powers_cpu = bank["cpu_copy"]["powers"]   # [N] float32 CPU
    num_candidates = int(bank["cpu_copy"]["centers"].shape[0])

    def _no_match() -> Dict:
        return {"any_match": False, "match_indices": [],
                "counts": {"num_matches": 0, "num_candidates": num_candidates}}

    # FAISS expects a float32, C-contiguous [1, d] numpy query.
    query = cand_center.detach().float().cpu().reshape(1, -1).contiguous().numpy()

    if use_radius:
        # IndexFlatL2 works in squared L2, so the radius is squared too.
        lims, _, labels = index.range_search(query, float(dist_thresh) ** 2)
        idxs = labels[lims[0]:lims[1]]   # indices within radius for the single query
    else:
        _, labels = index.search(query, k)   # labels: [1, k]
        idxs = labels[0]

    # k-NN search pads with -1 when the index holds fewer than k vectors. A -1 would
    # silently alias the last row in the power lookup and crash index_select later.
    idxs = idxs[idxs >= 0]
    if idxs.size == 0:
        return _no_match()

    # Power prefilter on the shortlist (CPU). A [1]-shaped power is collapsed to 0-d so
    # _power_prefilter returns a flat [M] mask.
    shortlist = torch.as_tensor(idxs, dtype=torch.long)           # [M] global bank indices
    xp = cand_power.detach().float().cpu().reshape(())
    pw_ok = _power_prefilter(xp, powers_cpu.index_select(0, shortlist), power_tol_frac, ref=power_ref)
    if not pw_ok.any():
        return _no_match()
    kept_idxs = shortlist[pw_ok]                                  # CPU, global indices

    # Exact distances for the survivors using the device tensors.
    C = bank["centers"]                    # [N, d] on device
    Cn2 = bank["centers_norm2"]            # [N]
    P = bank["powers"]                     # [N]
    device = C.device

    # Cast to the bank dtype so the mat-vec cannot fail on a dtype mismatch.
    x_dev = cand_center.to(device=device, dtype=C.dtype)
    xp_dev = cand_power.to(P.device).reshape(())

    dev_idx = kept_idxs.to(device)
    C_sub = C.index_select(0, dev_idx)
    Cn2_sub = Cn2.index_select(0, dev_idx)

    d2 = Cn2_sub - 2.0 * (C_sub @ x_dev) + (x_dev * x_dev).sum()
    d = torch.sqrt(torch.clamp(d2, min=0.0))
    ok = d <= dist_thresh
    if not ok.any():
        return _no_match()

    valid_idx = kept_idxs[ok.cpu()]
    d_ok = d[ok]
    # Use a Python int so a CUDA argmin result never indexes the CPU tensor directly.
    best_local = int(torch.argmin(d_ok).item())
    best_idx = int(valid_idx[best_local].item())

    if power_ref == "center":
        denom = P[best_idx].abs().clamp_min(1e-12)
    else:
        denom = xp_dev.abs().clamp_min(1e-12)
    best_power_diff_frac = float(((P[best_idx] - xp_dev).abs() / denom).item())

    return {
        "any_match": True,
        "match_indices": valid_idx.tolist(),
        "best_idx": best_idx,
        "best": {
            "dist": float(d_ok[best_local].item()),
            "power_diff_frac": best_power_diff_frac,
        },
        "counts": {
            "num_matches": int(ok.sum().item()),
            "num_candidates": num_candidates,
        }
    }


# ---------- Multi-layer wrapper (single candidate) ----------

@torch.no_grad()
def fast_score_across_layers(
    cand_centers_by_layer: Dict[int, torch.Tensor],   # L -> [d]
    cand_powers_by_layer:  Dict[int, torch.Tensor],   # L -> []
    banks_by_layer: Dict[int, Dict],                  # built with build_banks_by_layer
    layers: List[int],
    dist_thresh: float = 5.0,
    power_tol_frac: float = 0.10,
    power_ref: str = "center",
    ann_k: int = 32,
    use_ann_if_available: bool = True,
    use_radius: bool = True
) -> Dict:
    """
    Match one candidate against several layers and report the fraction of layers hit.

    Each layer is checked with faiss_layer_hits when ``use_ann_if_available`` is set
    and that layer's bank carries a FAISS index, otherwise with fast_exact_layer_hits.

    Returns:
        {"score": hits / max(1, len(layers)), "hits", "total_layers",
         "per_layer": {L: per-layer result dict}}

    Raises:
        KeyError if a requested layer is missing from the banks or candidate dicts.
    """
    shared_rule = {
        "dist_thresh": dist_thresh,
        "power_tol_frac": power_tol_frac,
        "power_ref": power_ref,
    }
    per_layer = {}
    hits = 0

    for layer in layers:
        layer_bank = banks_by_layer[layer]
        center = cand_centers_by_layer[layer]
        power = cand_powers_by_layer[layer]

        wants_ann = use_ann_if_available and layer_bank.get("use_faiss", False)
        if wants_ann:
            result = faiss_layer_hits(
                cand_center=center, cand_power=power, bank=layer_bank,
                k=ann_k, use_radius=use_radius, **shared_rule
            )
        else:
            result = fast_exact_layer_hits(
                cand_center=center, cand_power=power, bank=layer_bank, **shared_rule
            )

        per_layer[layer] = result
        hits += int(result.get("any_match", False))

    return {
        "score": float(hits / max(1, len(layers))),
        "hits": int(hits),
        "total_layers": int(len(layers)),
        "per_layer": per_layer
    }


# ---------- Batch scoring (many candidates, one layer) ----------

@torch.no_grad()
def fast_layer_hits_batch(
    cand_centers: torch.Tensor,    # [B, d]
    cand_powers:  torch.Tensor,    # [B]
    bank: Dict,
    dist_thresh: float = 5.0,
    power_tol_frac: float = 0.10,
    power_ref: str = "center",
    topk_per_query: int = 1
) -> Dict:
    """
    Exact vectorized scan of B candidates against one layer bank (fast on GPU).

    A (candidate b, bank row i) pair matches when the relative power difference passes
    ``power_tol_frac`` (see _power_prefilter) AND ||X[b] - C[i]|| <= dist_thresh.

    Args:
        cand_centers: [B, d] candidate centers; cast to the bank's device and dtype.
        cand_powers: [B] candidate powers; a 0-d or [B, 1] tensor is flattened first.
        bank: dict from build_layer_bank ("centers", "centers_norm2", "powers").
        dist_thresh, power_tol_frac, power_ref: match rule parameters.
        topk_per_query: if >= 1, report the closest matching row per query; if > 1,
            also report up to that many closest matching rows per query.

    Returns:
        dict with
          "any_match":   [B] bool
          "num_matches": [B] int
          "best_idx", "best_dist": [B] (when topk_per_query >= 1). best_idx is only
              meaningful where any_match is True; best_dist is +inf otherwise.
          "topk_idx", "topk_dist": [B, min(topk_per_query, N)] (only when
              topk_per_query > 1), ascending by distance; +inf entries are not matches.
    """
    C = bank["centers"]            # [N, d]
    Cn2 = bank["centers_norm2"]    # [N]
    P = bank["powers"]             # [N]
    device = C.device

    # Same dtype as the bank so X @ C.t() cannot fail on a dtype mismatch.
    X = cand_centers.to(device=device, dtype=C.dtype)    # [B, d]
    # Flatten so a [B, 1] input cannot broadcast into a [B, B, N] mask; a 0-d power
    # becomes [1] and still broadcasts across all B queries.
    XP = cand_powers.to(P.device).reshape(-1)             # [B] (or [1])

    # Power prefilter: [B, N] (or [1, N], broadcast below)
    pw_ok = _power_prefilter(XP, P, power_tol_frac, ref=power_ref).to(device)

    # d^2 = ||x||^2 + ||C||^2 - 2 x·C, using the precomputed ||C||^2
    X2 = (X * X).sum(dim=1, keepdim=True)                 # [B,1]
    d2 = Cn2[None, :] - 2.0 * (X @ C.t()) + X2            # [B,N]
    d = torch.sqrt(torch.clamp(d2, min=0.0))              # clamp guards rounding negatives

    ok = (d <= dist_thresh) & pw_ok                       # [B,N]
    out = {
        "any_match": ok.any(dim=1),                       # [B]
        "num_matches": ok.sum(dim=1),                     # [B]
    }

    if topk_per_query >= 1:
        # Mask non-hits with +inf so they won't be chosen
        masked_d = torch.where(ok, d, torch.full_like(d, float("inf")))
        best_d, best_idx = masked_d.min(dim=1)           # [B]
        out["best_idx"] = best_idx                       # [B] (valid only when any_match is True)
        out["best_dist"] = best_d
        if topk_per_query > 1:
            k = min(int(topk_per_query), masked_d.shape[1])
            topk_d, topk_idx = torch.topk(masked_d, k=k, dim=1, largest=False)
            out["topk_idx"] = topk_idx                   # [B, k]
            out["topk_dist"] = topk_d                    # [B, k]
    return out



def collect_top_tokens_from_dataset(dataset, tokenizer, max_nonpad_tokens=1_000_000, top_k=200):
    """
    Count token-ID frequencies over a tokenized dataset and return the most common ones.

    Examples are visited in order and their "input_ids" scanned one by one. IDs in
    ``tokenizer.all_special_ids`` and ``tokenizer.pad_token_id`` are skipped. Counting
    stops as soon as ``max_nonpad_tokens`` non-special tokens have been tallied.

    Returns:
        (top_ids, freq): the ``top_k`` most frequent IDs (most frequent first) and the
        full Counter of IDs seen.
    """
    skip_ids = set(tokenizer.all_special_ids or [])
    if tokenizer.pad_token_id is not None:
        skip_ids.add(tokenizer.pad_token_id)

    counts = Counter()
    remaining = max_nonpad_tokens

    # Walk example-by-example (the dataset yields raw tokenized rows, no span corruption).
    for example in dataset:
        for raw_id in example["input_ids"]:
            token_id = int(raw_id)   # handles tensor elements as well as ints
            if token_id in skip_ids:
                continue
            counts[token_id] += 1
            remaining -= 1
            if remaining <= 0:
                break
        if remaining <= 0:
            break

    return [token_id for token_id, _ in counts.most_common(top_k)], counts



SPACE_PREFIXES = ("▁", "Ġ")  # common SentencePiece/BPE word-boundary markers

def _parse_layer_key(k):
    if isinstance(k, int): return k
    if isinstance(k, str):
        if k.isdigit(): return int(k)
        if k[:1] in ("D", "d") and k[1:].isdigit(): return int(k[1:])
    return None

def _normalize_token_key_exists(signatures, token):
    """
    Find the key under which ``token`` is stored in ``signatures``.

    The token is tried verbatim first; failing that, each word-boundary marker in
    SPACE_PREFIXES that the token starts with is stripped (in order) and the remainder
    is tried. Returns the first key present, or None.
    """
    if token in signatures:
        return token
    stripped_forms = (token[len(marker):] for marker in SPACE_PREFIXES if token.startswith(marker))
    return next((form for form in stripped_forms if form in signatures), None)

def extract_candidate_from_signatures(signatures, token, layers):
    """
    Pull one token's per-layer center/power out of ``signatures`` as tensors.

    The token is resolved with _normalize_token_key_exists, and each of its per-layer
    keys is parsed with _parse_layer_key (L, "L", "D{L}", "d{L}"). Only layers in
    ``layers`` whose entry is a dict with "center" and "power" are kept.

    Returns:
        (centers, powers): dicts layer -> tensor.

    Raises:
        KeyError if the token is absent or none of its layers qualify.
    """
    resolved = _normalize_token_key_exists(signatures, token)
    if resolved is None:
        raise KeyError(f"Token {token!r} not found in signatures")

    token_layers = signatures[resolved]
    wanted = set(layers)
    centers, powers = {}, {}

    for raw_key, entry in token_layers.items():
        layer = _parse_layer_key(raw_key)
        if layer is None or layer not in wanted:
            continue
        well_formed = isinstance(entry, dict) and "center" in entry and "power" in entry
        if not well_formed:
            continue
        center, power = entry["center"], entry["power"]
        centers[layer] = center if torch.is_tensor(center) else torch.as_tensor(center)
        powers[layer] = power if torch.is_tensor(power) else torch.tensor(power)

    if not centers:
        raise KeyError(f"Token {token!r} has no matching layers among {layers}")

    return centers, powers

def decode_token_ids(tokenizer, token_ids):
    """
    Map token IDs to display strings for logging.

    convert_ids_to_tokens is used rather than decode() so word-boundary markers and
    whitespace are preserved. Any piece that comes back as None or "" falls back to
    ``tokenizer.decode([id], skip_special_tokens=True)``.
    """
    pieces = tokenizer.convert_ids_to_tokens(token_ids)
    display = []
    for piece, token_id in zip(pieces, token_ids):
        if piece is None or piece == "":
            piece = tokenizer.decode([token_id], skip_special_tokens=True)
        display.append(piece)
    return display


def make_text_baseline_from_raw(ds, tokenizer, target_token_budget=1_000_000):
    """
    Gather raw sentences from ``ds["text"]`` until ~``target_token_budget`` tokens are covered.

    Each text is tokenized without special tokens, and IDs that are special or padding
    are not counted. Texts are appended in order. The text that reaches the budget is
    still included, so the total may slightly overshoot. The result is meant to be
    passed as ``sentences`` to build_baseline_and_targets (no corruption).
    """
    excluded = set(tokenizer.all_special_ids or [])
    if tokenizer.pad_token_id is not None:
        excluded.add(tokenizer.pad_token_id)

    collected = []
    running_total = 0
    for text in ds["text"]:
        token_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
        running_total += sum(tid not in excluded for tid in token_ids)
        collected.append(text)
        if running_total >= target_token_budget:
            break
    return collected

def signatures_to_bank_input(signatures, layers):
    """
    Transpose token-keyed signatures into per-layer bank inputs.

    signatures: dict[token] -> dict per-layer, where the per-layer key is L, "L",
                "D{L}" or "d{L}" (tried in that order) and each entry is
                {'center': [d], 'power': float}. Entries of any other shape are ignored.
    layers: iterable of int layer indices, e.g. range(12)

    Returns: { L: {"centers": [N,d], "powers": [N], "tokens": [N]} } where N is the
    number of tokens that had a usable entry for layer L and "tokens"[i] is the
    signatures key whose center/power occupies row i. Without it, bank hit indices
    could not be mapped back to tokens. Layers with no entries are omitted.
    """
    bank = {}

    for L in layers:
        centers_list, powers_list, tokens_list = [], [], []

        for tok, target in signatures.items():
            # resolve per-layer key for this token (same spellings _parse_layer_key accepts)
            if L in target:
                key = L
            elif str(L) in target:
                key = str(L)
            elif f"D{L}" in target:
                key = f"D{L}"
            elif f"d{L}" in target:
                key = f"d{L}"
            else:
                continue  # token has no entry for this layer

            sigL = target[key]
            # expected per-token-per-layer schema; ignore unknown shapes
            # (make this a raise if you prefer strictness)
            if not (isinstance(sigL, dict) and "center" in sigL and "power" in sigL):
                continue
            c = sigL["center"]
            centers_list.append(c if torch.is_tensor(c) else torch.as_tensor(c))
            powers_list.append(float(sigL["power"]))
            tokens_list.append(tok)

        if not centers_list:
            continue  # no tokens provided this layer; skip creating bank[L]

        centers = torch.stack(centers_list, dim=0)                 # [N,d]
        powers  = torch.tensor(powers_list, dtype=torch.float32)   # [N]
        assert centers.dim() == 2, f"centers must be [N,d] at layer {L}"
        assert powers.dim() == 1 and powers.shape[0] == centers.shape[0], \
            f"powers shape mismatch at layer {L}"
        bank[L] = {"centers": centers, "powers": powers, "tokens": tokens_list}

    return bank

if __name__ == "__main__":
    train_english = True

    from transformers import MT5Config, MT5ForConditionalGeneration

    if train_english:
        model_save = "mt5_base_forgive_and_forget_ft"
        # model_source = "stage1_base_step120000"
        model_source = "google/mt5-base"

        # full model
        state_dict_source = "mt5_base_forgive_and_forget_whole_stream00_batch_10000.pt"

        # state_dict_source = "mt5_small_as_to_en_ref190.pt"
        # state_dict_source = "mt5_small_en_to_as_forgive10.pt"
        state_dict_save = "mt5_base_forgive_and_forget_whole_stream"

        # load model
        tokenizer = AutoTokenizer.from_pretrained(model_source,
                                                  use_fast=False,  # keep full SentencePiece behaviour
                                                  legacy=False)


    device = "cuda"
    base_model = MT5ForConditionalGeneration.from_pretrained(model_source)
    base_model = base_model.to(device)

    print("Checkpoint reloaded!")
    # model = ForgetfulT5(base_model.config)

    # Initialize your custom (ModifiedT5) model.
    # @todo should both models be on gpu or do i need to merge them?
    model = ForgetfulT5(base_model)
    model = model.to(device)

    print(base_model.num_parameters() / 1e6, "M params")  # 30 M for 3+3 @ d=256

    # Path for the custom model state dict.
    custom_state_dict_path = state_dict_source

    reload = True

    if reload:
        # If a custom state dict exists, load it (with strict=False so new parameters are left untouched).
        if os.path.exists(custom_state_dict_path):
            print(f"Loading existing custom state dict from {custom_state_dict_path}...")
            state_dict = torch.load(custom_state_dict_path, map_location="cuda")
            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
            print("Missing keys:", missing_keys)
            print("Unexpected keys:", unexpected_keys)
        else:
            print("No existing custom state dict found; proceeding with freshly initialized custom model.")
    else:
        if custom_state_dict_path is not None and os.path.exists(custom_state_dict_path):
            sd = torch.load(custom_state_dict_path, map_location="cpu")
            if isinstance(sd, dict) and "state_dict" in sd:
                sd = sd["state_dict"]


            # 1) Strip the extra prefix (base_model.) and common wrappers
            def strip_prefix(k: str) -> str:
                return re.sub(r'^(?:model\.|module\.|base_model\.)', '', k, count=1)


            sd = {strip_prefix(k): v for k, v in sd.items()}

            # 2) (Optional) resize embeddings if tokenizer size differs
            # model.resize_token_embeddings(len(tok))

            # 3) Keep only keys the HF model actually has (drop custom extras)
            model_keys = set(model.state_dict().keys())
            sd = {k: v for k, v in sd.items() if k in model_keys}

            # 4) *** LOAD INTO MODEL ***
            missing, unexpected = model.load_state_dict(sd, strict=False)
            print(f"[state_dict] loaded: kept={len(sd)}  missing={len(missing)}  unexpected={len(unexpected)}")

        else:
            print("[state_dict] no state_dict_path provided or file not found; using base weights only.")

    print("Checkpoint reloaded!")

    # visualize weights
    # model.visualize_ffn(23, which="both")
    # model.visualize_ffn(22, which="both")

    # state_dict = torch.load(your_state_dict_path, map_location="cpu")
    # missing_keys, unexpected_keys = base_model.load_state_dict(state_dict, strict=False)

    train_samples = 2000000  # 11900000 # 13016
    val_samples = 10000
    start = 0

    from datasets import Dataset

    # Load WMT19 training set
    if train_english:
        # dataset = load_dataset("wmt19", "zh-en", split="train")
        # dataset = load_dataset("wmt19", "gu-en", split="train")  # Gujarati–English
        # dataset = load_dataset("ai4bharat/samanantar", "as", split="train")
        # ds = load_from_disk("wiki_en_topk_10")
        ds = load_english()

        # dataset_test = load_dataset("wmt19", "zh-en", split="validation")

        # shortens dataset to speed training

        # dataset_test = dataset.select(range(start + train_samples, val_samples + train_samples + start))
        # dataset = dataset.select(range(start, train_samples))

        ds_test = ds.select(range(start + train_samples, val_samples + train_samples + start))
        ds = ds.select(range(start, start + train_samples))

    tokenize_fnn = partial(tokenize_fn, tokenizer=tokenizer, train_english=True)  # True when english->assamese

    # 2) Create a DataLoader for mini-batching
    dataset_en = ds.map(tokenize_fnn, batched=True, num_proc=6)  # , remove_columns=["translation"])
    dataset_en_test = ds_test.map(tokenize_fnn, batched=True, num_proc=6)

    # tokenize_fnnn = partial(tokenize_fn, tokenizer=tokenizer, train_english=False)
    tokenize_fnnn = partial(tokenize_fn, tokenizer=tokenizer, train_english=False)

    # 2) Create a DataLoader for mini-batching
    dataset_zh = ds.map(tokenize_fnnn, batched=True, num_proc=6)
    dataset_zh_test = ds_test.map(tokenize_fnnn, batched=True, num_proc=6)

    # 3) Convert to torch format (and select columns to keep)
    # dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

    print("Language Files loaded")

    dataset_en.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "labels"]
    )
    dataset_en_test.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "labels"]
    )

    dataset_zh.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "labels"]
    )
    dataset_zh_test.set_format(
        type="torch",
        columns=["input_ids", "attention_mask", "labels"]
    )

    # collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")

    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model = base_model.to("cuda")

    # Create an optimizer
    # @todo remove weight decay
    # adds forgetfulness

    # optimizer = AdamW(groups, lr=3e-4)
    optimizer = AdamW(model.parameters(), lr=3e-4)
    scheduler = LambdaLR(optimizer, lr_lambda)

    # 9.65 bleu
    model.train()  # put in training mode
    epochs = 10
    batches_per_day = 50000
    total_batches = 50000  # 1200
    wake_cycles = 1

    check_length = False

    # vars for sampling
    tau_start = 0.95
    tau_end = 0.6

    # vars for prediction forgiveness
    smoothing_start = 0.9
    smoothing_end = 0.15
    top_k = 10
    F_MOD = 1.5


    forget_now = True

    # handles forgiveness scheduling
    warm_fmod = 50  # 20
    cool_fmod = 5  # 50
    start_forgiveness_ep = 0
    end_forgiveness_ep = 25  # 130

    # toggle for magnetism to allow model to recover at times
    mag_check = True

    F_MOD_MULT = 1.0  # slowly scales to 0 instead of modifying f_mod outright
    enable_f_mod = True
    fmod_dec = 1.0 / cool_fmod
    fmod_add = 1.0 / warm_fmod
    f_mod_dec_const = F_MOD / cool_fmod

    if check_length:
        all_lengths = []

        for example in tqdm(dataset_en):
            input_ids = example['input_ids']
            length = sum(1 for token_id in input_ids if token_id != tokenizer.pad_token_id)
            all_lengths.append(length)

        print(f"Average length: {sum(all_lengths) / len(all_lengths):.2f} tokens")
        print(f"90th percentile: {sorted(all_lengths)[int(0.9 * len(all_lengths))]} tokens")
        print(f"Max length: {max(all_lengths)} tokens")

    assert len(dataset_en) == len(dataset_zh)
    assert len(dataset_en_test) == len(dataset_zh_test)

    model.eval()

    print("Reshuffle Data")
    seed = 50


    # clamps f_mod
    if F_MOD_MULT < 0.0:
        F_MOD_MULT = 0.0
    elif F_MOD_MULT > 1.0:
        F_MOD_MULT = 1.0

    # @todo currently shuffling data every epoch, should look into self derived curriculum learning
    # Shuffle the datasets using Hugging Face's shuffle (deterministic with seed)
    perm = np.random.RandomState(seed).permutation(len(dataset_en))
    dataset_en = dataset_en.select(perm)
    dataset_zh = dataset_zh.select(perm)

    collator_en = T5SpanCorruptionCollator(
        tokenizer, noise_density=0.15, mean_span_len=3, input_length=64
    )

    print("Set up Dataloaders")

    batch_size = 38
    train_en_dataloader = DataLoader(dataset_en, batch_size=batch_size, pin_memory=False, shuffle=False,
                                     num_workers=0, collate_fn=collator_en)  #
    train_zh_dataloader = DataLoader(dataset_zh, batch_size=batch_size, pin_memory=False, shuffle=False,
                                     num_workers=0)

    # find nearest neighbors
    alt_k = 10  # how many neighbours to forgive
    forg_idx = []
    forg_prob = []

    # print(forg_idx[0], forg_val[0])

    total_loss = 0.0
    total_tokens = 0.0
    total_base_loss = 0.0
    LEAK.num_batches = 0
    overall_bleu = 0.0
    bleu_batches = 0
    batch_count = 0 * batches_per_day
    next_batch = 10000
    acceptable_cut = 74.0

    total_bleu = 0.0
    total_bleu_batches = 0

    # ======================
    # Config
    # ======================
    TOP_K_TOKENS = 200  # change to 200 if you prefer your earlier test size
    BASELINE_BUDGET = 1_000_000  # non-pad tokens
    LAYERS_TO_TRACK = list(range(12))
    PER_TOKEN_SAMPLES = 8
    ANN_K = 32  # shortlist size when using k-NN (unused if use_radius=True)

    # ======================
    # 1) Baseline + top tokens from your datasets
    # ======================
    print("---- Build ~1M-token baseline from raw ds ----")
    sent_bank = make_text_baseline_from_raw(
        ds=ds,
        tokenizer=tokenizer,
        target_token_budget=BASELINE_BUDGET
    )
    print(f"Baseline sentences collected: {len(sent_bank)} (~{BASELINE_BUDGET:,} non-pad tokens)")

    print(f"---- Count top-{TOP_K_TOKENS} tokens from tokenized dataset_en (no corruption) ----")
    top_ids, id_freq = collect_top_tokens_from_dataset(
        dataset=dataset_en, tokenizer=tokenizer, max_nonpad_tokens=BASELINE_BUDGET
    )
    # keep only K
    top_ids = top_ids[:TOP_K_TOKENS]
    top_tokens = decode_token_ids(tokenizer, top_ids)

    print("Top 20 tokens (id -> token -> freq):")
    for tid in top_ids[:20]:
        tok = tokenizer.convert_ids_to_tokens([tid])[0]
        print(f"{tid:>8} -> {tok!r} -> {id_freq[tid]}")

    import json, os

    os.makedirs("artifacts", exist_ok=True)
    with open(f"artifacts/top_{TOP_K_TOKENS}_tokens.json", "w", encoding="utf-8") as f:
        json.dump({
            "token_ids": top_ids,
            "tokens": top_tokens,
            "frequencies": [id_freq[tid] for tid in top_ids]
        }, f, ensure_ascii=False, indent=2)
    print(f"Saved artifacts/top_{TOP_K_TOKENS}_tokens.json")

    target_token_strs = top_tokens
    layers_to_track = LAYERS_TO_TRACK

    # ======================
    # 2) Signatures
    # ======================
    print("---- Build baseline + token means + signatures (per_token_samples=8) ----")
    baseline, token_means, signatures = build_baseline_and_targets(
        model=model,
        tokenizer=tokenizer,
        sentences=sent_bank,  # baseline text bank
        target_token_strs=target_token_strs,  # top-K frequent tokens
        layers_to_track=layers_to_track,
        per_token_samples=PER_TOKEN_SAMPLES,  # << per your request
        max_baseline_positions=BASELINE_BUDGET,
        batch_size=40,
        device=device
    )


    # ======================
    # 3) Adapter: signatures -> bank_by_layer_input
    #    We support two common layouts:
    #    (A) signatures[L] = {"centers": [N,d], "powers": [N], "tokens": List[str]/List[int]}
    #    (B) signatures[L][token] = {"center": [d], "power": float} (dict keyed by token)
    # ======================



    bank_by_layer_input = signatures_to_bank_input(signatures, layers_to_track)

    # ======================
    # 4) Build fast banks (precompute norms + optional FAISS)
    # ======================
    banks = build_banks_by_layer(bank_by_layer_input, use_faiss=True, faiss_gpu=False)


    # pick a demo token (falls back to first in list)
    demo_token = target_token_strs[0]
    try:
        cand_centers_by_layer, cand_powers_by_layer = extract_candidate_from_signatures(
            signatures, demo_token, layers_to_track
        )
    except Exception as e:
        print(f"[warn] Could not extract demo token '{demo_token}': {e}")
        # As a fallback, synthesize a random normalized candidate from layer 0's dim:
        d = bank_by_layer_input[layers_to_track[f"D{0}"]]["centers"].shape[1]
        rnd = torch.randn(d);
        rnd = rnd / (rnd.norm() + 1e-12)
        cand_centers_by_layer = {L: rnd.clone() for L in layers_to_track}
        cand_powers_by_layer = {L: torch.tensor(1.0) for L in layers_to_track}

    # Score across layers (exact radius via FAISS → same results as full scan)
    print("---- Cross-layer match score for demo candidate ----")
    result = fast_score_across_layers(
        cand_centers_by_layer,
        cand_powers_by_layer,
        banks,
        layers=layers_to_track,
        dist_thresh=5.0,
        power_tol_frac=0.10,
        power_ref="center",
        ann_k=ANN_K,
        use_ann_if_available=True,
        use_radius=True
    )
    print(result["score"], f"{result['hits']}/{result['total_layers']} layers matched")

    # ======================
    # 6) GPU-fast batch example for one layer (optional)
    # ======================
    layer = 7
    print(f"---- Batched single-layer check (layer {layer}) ----")
    # Make a small batch by repeating the candidate 64 times
    cand_centers_batch = torch.stack([cand_centers_by_layer[layer] for _ in range(64)], dim=0)  # [B,d]
    cand_powers_batch = torch.stack([cand_powers_by_layer[layer] for _ in range(64)], dim=0)  # [B]
    B_out = fast_layer_hits_batch(
        cand_centers=cand_centers_batch,
        cand_powers=cand_powers_batch,
        bank=banks[layer],
        dist_thresh=5.0,
        power_tol_frac=0.10,
        power_ref="center",
    )
    print(f"Batch any_match: {B_out['any_match'].sum().item()}/{cand_centers_batch.shape[0]}")

    # Pretty print a small slice of the signature dict (human-friendly)
    #for tok, perL in list(signatures.items()):
    #    print(tok, perL)

    # ---- Quick interactive steering on a few hand-picked sentences ----
    demo_sents = [
        "3 + 1 = <extra_id_0>",
        "They saw <extra_id_0> birds near the lake.",
        "He stacked items in the receptacle and closed the lid.",
        "A small brown dog runs across the yard.",
        "The cat sleeps on the warm window sill.",
        "They saw 12 birds near the lake."
    ]



    model.eval()

    print(demo_sents[0])
    #tok_sent = tokenizer.encode(demo_sents[0])
    #print(tok_sent)
    #sent_ids = torch.tensor([tok_sent], dtype=torch.long, device=device)  # shape [1, T]

    inputs = tokenizer(demo_sents[0], return_tensors="pt").to(device)
    input_ids = inputs.input_ids.to(device)
    attention_mask = inputs.attention_mask.to(device)

    print(input_ids)
    print(attention_mask, "\n")

    out_ids = model.generate(input_ids=input_ids, attention_mask=attention_mask, output_scores=True)
    print(out_ids)
    decoded = []
    for o in out_ids:
        decoded.append(tokenizer.decode(o))

    print("\n--- Baseline (no steering) ---")

    print(decoded)
    #input()

    LEAK.num_batches += 1

    # print(loss, base_loss)

    # score performance
    if False and j == 0:
        with torch.no_grad():
            outputs = model.generate(input_ids=input_ids,
                                     attention_mask=attention_mask,
                                     max_length=64)

        truncated_outputs = []
        for gen_ids, ref_ids in zip(outputs, batch_zh["input_ids"]):
            ref_len = len(ref_ids)
            gen_len = len(gen_ids)
            clip_len = min(ref_len, gen_len)

            truncated_gen_ids = gen_ids[:clip_len]
            truncated_outputs.append(truncated_gen_ids)

        vocab_size = tokenizer.vocab_size
        safe_outputs = []
        for ids in truncated_outputs:
            safe_ids = [tok.item() for tok in ids if 0 <= tok.item() < vocab_size]
            safe_outputs.append(safe_ids)

        hypotheses = [
            tokenizer.decode(ids, skip_special_tokens=True) for ids in safe_outputs
        ]

        references = [
            tokenizer.decode(ref_ids, skip_special_tokens=True)
            for ref_ids in batch_zh["input_ids"]
        ]

        bleu_score = sacrebleu.corpus_bleu(hypotheses, [references]).score
        overall_bleu += bleu_score
        total_bleu += bleu_score
        total_bleu_batches += 1
        bleu_batches += 1

        val_ce = total_loss / total_tokens  # cross‑entropy
        val_ppl = math.exp(val_ce)

        if False and LEAK.num_batches % 250 == 0:
            bleu_s = overall_bleu / bleu_batches


            print(f"Batch {LEAK.num_batches} of {total_batches}")
            print(f"Loss: {total_loss / (LEAK.num_batches)}")
            print(f"Gold loss: {total_base_loss / (LEAK.num_batches)}")
            print(f"BLEU score: {overall_bleu / bleu_batches:.2f}")
            print(f"Perplexity: {val_ppl}")
            print(f"Overall: {total_bleu / total_bleu_batches:.2f}\n")

            overall_bleu = 0.0
            bleu_batches = 1.0

        # continue to next day
        batch_count += batches_per_day

        # Print average loss over the day
        avg_loss = 100.0 * total_loss / batch_count
        print(
            f"Day {batch_count / batches_per_day}/{total_batches / batches_per_day} - avg train loss: {avg_loss:.4f}\n")

