import os
from transformers.activations import ACT2FN

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import numpy as np
import random
import torch.nn as nn
import torch.optim as optim
from transformers import MT5ForConditionalGeneration, AutoTokenizer, MT5Tokenizer, AutoModelForSeq2SeqLM, AutoModel
from torch.optim import AdamW
import sacrebleu
from torch.utils.data import DataLoader
from datasets import load_dataset, load_from_disk
from functools import partial
from transformers import DataCollatorWithPadding
from torch.optim.lr_scheduler import StepLR
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
from torch.nn import functional as F
from functools import lru_cache
from dataclasses import dataclass, field
from itertools import groupby  # ← add this
import matplotlib.pyplot as plt
from datasets import load_dataset, interleave_datasets

import torch, gc, random
import math
import re
import torch.fft as fft
from itertools import groupby
from torch.optim.lr_scheduler import LambdaLR


# from mbert_pretraining import BATCH_SIZE

# print("PyTorch version:", torch.__version__)
# print("CUDA available:", torch.cuda.is_available())
# print("CUDA device count:", torch.cuda.device_count())
# print("CUDA version:", torch.version.cuda)


# corruption collator

from project_classes import T5SpanCorruptionCollator, CustomDenseReluDense, ForgetfulT5, DownBandAccumulator, BandSteeringController

from project_classes import load_english, tokenize_fn, lr_lambda, LEAK, build_baseline_and_targets

from project_classes import build_two_letter_code_map_from_model, token_to_two_letter_code

if __name__ == "__main__":
    from transformers import MT5Config, MT5ForConditionalGeneration

    # ---- run configuration -------------------------------------------------
    # Both spellings of the device name are read later in the script, so both
    # are defined here.
    DEVICE = 'cuda'
    device = "cuda"

    model_save = "mt5_base_forgive_and_forget_ft"
    # model_source = "stage1_base_step120000"
    model_source = "google/mt5-base"

    # Checkpoint loaded on top of the wrapped model; swap the active line to
    # run the same script against a different checkpoint.
    state_dict_source = "mt5_base_pretuned.pt"  # forgiveness pretune
    # state_dict_source = "mt5_base_forgive_and_forget_whole_stream6.pt"  # end of forgive and forget
    # state_dict_source = "mt5_base_standard_FaF_11x_noise.pt"  # full stack normal model

    state_dict_save = "mt5_base_standard_FaF_11x_noise_two_let_retrain2"  # full stack 2let model

    # ---- tokenizer ----------------------------------------------------------
    tokenizer = AutoTokenizer.from_pretrained(
        model_source,
        use_fast=False,  # keep full SentencePiece behaviour
        legacy=False,
    )

    # ---- base model ---------------------------------------------------------
    # custom=False loads the pretrained weights named by model_source;
    # custom=True builds a small randomly initialised MT5 from an explicit config.
    custom = False
    if not custom:
        base_model = MT5ForConditionalGeneration.from_pretrained(model_source)
    else:
        cfg = MT5Config(
            vocab_size=tokenizer.vocab_size,  # keep full vocab
            d_model=128,
            d_ff=512,  # 4 x d_model
            num_layers=6,  # encoder
            num_decoder_layers=6,  # decoder
            num_heads=4,  # keep d_model % num_heads == 0
            dropout_rate=0.1,
            # the three special-token IDs MT5 needs
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            decoder_start_token_id=tokenizer.pad_token_id,
        )
        base_model = MT5ForConditionalGeneration(cfg)

    base_model = base_model.to(device)

    # Size of the model's output vocabulary (the last dimension of the logits).
    V = base_model.config.vocab_size

    # Lookup built by a project helper and indexed by token id further down;
    # entries equal to -1 are treated there as "no code".
    two_letter_code = build_two_letter_code_map_from_model(tokenizer, V).to(DEVICE)

    # NOTE(review): this message is printed before any checkpoint is loaded.
    print("Checkpoint reloaded!")

    # ---- wrapped model ------------------------------------------------------
    # @todo should both models be on gpu or do i need to merge them?
    model = ForgetfulT5(base_model).to(device)

    print(base_model.num_parameters() / 1e6, "M params")

    # freeze non relevant layers
    # freeze_for_active_layer(model.base_model, LEAK.active_layer)
    # freeze_for_active_layer(model.base_model, LEAK.active_layer-1)

    # ---- checkpoint ---------------------------------------------------------
    custom_state_dict_path = state_dict_source
    reload = True

    if not reload:
        # Randomly overwrites the weights instead of loading a checkpoint.
        # NOTE(review): fully_randomize is not imported in the visible part of
        # this file; confirm it is in scope before setting reload = False.
        fully_randomize(model, seed=50, std=0.2)  # start seed 44
    elif os.path.exists(custom_state_dict_path):
        # strict=False: parameters absent from the checkpoint keep their
        # current values, and both key mismatches are reported below.
        print(f"Loading existing custom state dict from {custom_state_dict_path}...")
        state_dict = torch.load(custom_state_dict_path, map_location="cuda")
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        print("Missing keys:", missing_keys)
        print("Unexpected keys:", unexpected_keys)
    else:
        print("No existing custom state dict found; proceeding with freshly initialized custom model.")

    print("Checkpoint reloaded!")

    # visualize weights
    # model.visualize_ffn(23, which="both")
    # model.visualize_ffn(22, which="both")

    # state_dict = torch.load(your_state_dict_path, map_location="cpu")
    # missing_keys, unexpected_keys = base_model.load_state_dict(state_dict, strict=False)

    # ---- data slices --------------------------------------------------------
    # Training rows are [start, start + train_samples); the held-out rows are
    # the slice that follows, up to start + train_samples + val_samples.
    train_samples = 4000  # 11900000 # 13016
    val_samples = 10000
    start = 6000000

    from datasets import Dataset

    train_stop = start + train_samples
    test_stop = val_samples + train_samples + start

    # Corpus returned by the project helper load_english().
    ds = load_english()  # Dataset.from_file("wiki_en_topk_10/data.arrow")
    ds_test = ds.select(range(train_stop, test_stop))
    ds = ds.select(range(start, train_stop))

    # Two tokenisations of the SAME rows, differing only in the train_english
    # flag passed to tokenize_fn. The "en" variant feeds the encoder input and
    # the "zh" variant supplies the label ids in the loop below.
    # NOTE(review): what train_english changes is defined in project_classes
    # and is not visible here.
    tokenize_fnn = partial(tokenize_fn, tokenizer=tokenizer, train_english=True)
    dataset_en = ds.map(tokenize_fnn, batched=True, num_proc=6)  # , remove_columns=["translation"])
    dataset_en_test = ds_test.map(tokenize_fnn, batched=True, num_proc=6)

    tokenize_fnnn = partial(tokenize_fn, tokenizer=tokenizer, train_english=False)
    dataset_zh = ds.map(tokenize_fnnn, batched=True, num_proc=6)
    dataset_zh_test = ds_test.map(tokenize_fnnn, batched=True, num_proc=6)

    print("Language Files loaded")

    # Expose only the three model-facing columns, as torch tensors.
    # set_format mutates each dataset in place.
    for tokenized_split in (dataset_en, dataset_en_test, dataset_zh, dataset_zh_test):
        tokenized_split.set_format(
            type="torch",
            columns=["input_ids", "attention_mask", "labels"],
        )

    # collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")

    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model = base_model.to("cuda")

    # Create an optimizer
    # @todo remove weight decay
    # adds forgetfulness

    # optimizer = AdamW(groups, lr=3e-4)
    # ---- optimisation -------------------------------------------------------
    # One AdamW over every model parameter; the learning-rate multiplier comes
    # from the project-level lr_lambda.
    optimizer = AdamW(model.parameters(), lr=3e-4)
    scheduler = LambdaLR(optimizer, lr_lambda)

    model.train()  # put in training mode

    # ---- loop sizes ---------------------------------------------------------
    epochs = 20
    batches_per_day = 100
    total_batches = 100  # 1200

    # When True, print token-length statistics of dataset_en before the run.
    check_length = False

    # ---- sampling temperature bounds ---------------------------------------
    tau_start, tau_end = 0.95, 0.6

    # ---- prediction forgiveness ---------------------------------------------
    smoothing_start, smoothing_end = 0.9, 0.15
    top_k = 10
    F_MOD = 1.5

    # Forgiveness schedule, in epochs: length of the warm-up and cool-down
    # ramps, and the first / last epoch of the forgiveness phase.
    warm_fmod = 50  # 20
    cool_fmod = 5  # 50
    start_forgiveness_ep = 0
    end_forgiveness_ep = 25  # 130

    # toggle for magnetism to allow model to recover at times
    mag_check = False
    forget_now = False

    # F_MOD_MULT is ramped by the per-step increments below rather than F_MOD
    # being modified outright.
    F_MOD_MULT = 1.0
    enable_f_mod = True
    fmod_add = 1.0 / warm_fmod
    fmod_dec = 1.0 / cool_fmod
    f_mod_dec_const = F_MOD / cool_fmod

    # ---- optional length report ---------------------------------------------
    if check_length:
        pad_id = tokenizer.pad_token_id
        # Number of non-padding positions in each example's input_ids.
        all_lengths = [
            sum(1 for token_id in example['input_ids'] if token_id != pad_id)
            for example in tqdm(dataset_en)
        ]

        print(f"Average length: {sum(all_lengths) / len(all_lengths):.2f} tokens")
        print(f"90th percentile: {sorted(all_lengths)[int(0.9 * len(all_lengths))]} tokens")
        print(f"Max length: {max(all_lengths)} tokens")

    # The two tokenisations are zipped batch-by-batch later on, so they must
    # have the same number of rows.
    for en_split, zh_split in ((dataset_en, dataset_zh), (dataset_en_test, dataset_zh_test)):
        assert len(en_split) == len(zh_split)


    print(f"Training for {epochs} epochs")

    # ------------------------------------------------------------------------
    # Robustness sweeps
    #
    # Three sweeps run back to back (noise, blur, recall). Each sweep resets
    # the LEAK perturbation knobs to zero, enables exactly one perturbation and
    # raises its strength by a fixed increment after every epoch. Every batch
    # is scored with the "two-letter forgiveness" loss, the plain gold
    # cross-entropy and a BLEU score on generated text.
    #
    # Reported metrics:
    #   * Loss / Gold loss / Perplexity: per-token means over the current
    #     100-batch report window (Perplexity == exp(Loss)).
    #   * BLEU score: mean batch BLEU over the report window.
    #   * Overall: mean batch BLEU since the start of the epoch.
    #   * Day / Epoch "avg train loss": per-token mean over the whole epoch.
    # ------------------------------------------------------------------------

    # The original step condition, `(num_batches + 1) % accumulation_steps ==
    # 50`, can never be true for accumulation_steps = 3, so the optimizer never
    # stepped and the sweeps ran on frozen weights. That behaviour is kept as
    # the default and made explicit; set this to True to apply gradient
    # accumulation and optimizer / scheduler steps.
    # NOTE(review): confirm frozen weights are intended; nothing in this
    # script saves a checkpoint or reloads weights between sweeps.
    update_weights = False

    # Debug aid: dump one label row and its first eight "k_alt_ids" rows.
    debug_alts = False

    # Generated ids outside the tokenizer vocabulary are dropped before decoding.
    vocab_size = tokenizer.vocab_size

    # (label, apply_noise, apply_blur, apply_recall) for each sweep, in run order.
    sweep_plan = (
        ("Noise", True, False, False),
        ("Blur", False, True, False),
        ("Recall", False, False, True),
    )

    for cur_test, (sweep_name, use_noise, use_blur, use_recall) in enumerate(sweep_plan):
        # Per-epoch increase of the active perturbation.
        noise_inc = 0.1
        blur_inc = 0.05
        recall_inc = 0.25

        # reset params
        LEAK.noise_strength = 0.0
        LEAK.max_variance = 0.0
        for band in range(4):
            LEAK.recall_alpha[band] = 0.0

        print(f"\nTest {sweep_name}")
        LEAK.apply_blur = use_blur
        LEAK.apply_noise = use_noise
        LEAK.apply_recall = use_recall

        for epoch in range(epochs):
            print("Reshuffle Data")
            seed = 50 + epoch

            # Message kept for log compatibility; the scheduler itself only
            # advances together with an optimizer step (see update_weights).
            print("Scheduler Step")

            # Forgiveness multiplier schedule: ramp up during warm-up, ramp
            # down after end_forgiveness_ep, and hold at zero once the
            # cool-down has elapsed. The hold branch is tested first; behind
            # the ramp-down test it could never run.
            if start_forgiveness_ep <= epoch < start_forgiveness_ep + warm_fmod:
                F_MOD_MULT += fmod_add
            elif epoch > end_forgiveness_ep + cool_fmod:
                F_MOD_MULT = 0.0
            elif epoch > end_forgiveness_ep:
                F_MOD_MULT -= fmod_dec

            # clamps f_mod to [0, 1]
            F_MOD_MULT = min(1.0, max(0.0, F_MOD_MULT))

            # @todo currently shuffling data every epoch, should look into self derived curriculum learning
            # Apply one seeded permutation to both datasets so that row i of
            # dataset_en still pairs with row i of dataset_zh.
            perm = np.random.RandomState(seed).permutation(len(dataset_en))
            dataset_en = dataset_en.select(perm)
            dataset_zh = dataset_zh.select(perm)

            collator_en = T5SpanCorruptionCollator(
                tokenizer, noise_density=0.15, mean_span_len=3, input_length=64
            )

            print("Set up Dataloaders")

            batch_size = 40
            # shuffle=False on both loaders keeps the two streams aligned.
            train_en_dataloader = DataLoader(dataset_en, batch_size=batch_size, pin_memory=False, shuffle=False,
                                             num_workers=0, collate_fn=collator_en)
            train_zh_dataloader = DataLoader(dataset_zh, batch_size=batch_size, pin_memory=False, shuffle=False,
                                             num_workers=0)

            if epoch == 0:
                print("Shuffle test data")
                dataset_en_test = dataset_en_test.shuffle(seed=seed)
                dataset_zh_test = dataset_zh_test.shuffle(seed=seed)
                test_en_dataloader = DataLoader(dataset_en_test, batch_size=batch_size, pin_memory=False, shuffle=False)
                test_zh_dataloader = DataLoader(dataset_zh_test, batch_size=batch_size, pin_memory=False, shuffle=False)

            # Epoch-wide accumulators (never reset inside the epoch).
            total_loss = 0.0       # summed forgiving NLL over all tokens
            total_base_loss = 0.0  # summed gold-only NLL over all tokens
            total_tokens = 0       # number of tokens behind the two sums
            total_bleu = 0.0
            total_bleu_batches = 0
            num_batches = 0

            # Report-window accumulators (reset after every printed report).
            window_loss = 0.0
            window_base_loss = 0.0
            window_tokens = 0
            overall_bleu = 0.0
            bleu_batches = 0

            accumulation_steps = 3

            if update_weights:
                # Do not carry a partial accumulation over from the last epoch.
                optimizer.zero_grad(set_to_none=True)

            print("Begin Test")
            print("Blur: ", str(LEAK.recall_alpha[0]))
            print("Noise: ", str(LEAK.noise_strength))
            print("Recall: ", str(LEAK.max_variance))

            # iterate all batches
            for batch_en, batch_zh in zip(train_en_dataloader, train_zh_dataloader):

                if debug_alts:
                    debug_parts = [tokenizer.decode(batch_zh["input_ids"][0][2])]
                    for alt in range(8):
                        debug_parts += ['\n', tokenizer.decode(batch_zh["k_alt_ids"][0][2][alt])]
                    print(*debug_parts, '\n\n')

                # Move inputs/labels to GPU
                input_ids = batch_en["input_ids"].to(device)
                attention_mask = batch_en["attention_mask"].to(device)
                # NOTE(review): the labels are the "zh" input_ids. If those
                # carry pad ids rather than -100, the mask below never excludes
                # padding and pad positions count towards every loss -- verify
                # against tokenize_fn.
                labels = batch_zh["input_ids"].to(device)
                torch.cuda.reset_peak_memory_stats()

                with autocast(enabled=True, dtype=torch.bfloat16):
                    out = model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        labels=labels,  # built-in CE loss
                        use_cache=False
                    )

                # logits -> probs in fp32 for numerical stability
                logits = out.logits.float()  # (B, T, V)
                log_probs = torch.log_softmax(logits, dim=-1)  # (B, T, V)
                probs = log_probs.exp()

                IGNORE = -100
                non_pad = (labels != IGNORE)
                labels_safe = labels.masked_fill(~non_pad, 0)  # for gathers

                # gold
                gold_logp = log_probs.gather(-1, labels_safe.unsqueeze(-1)).squeeze(-1)  # (B, T)
                p_gold = gold_logp.exp()  # (B, T)

                # ---- Top-K tokens & two-letter forgiveness ------------------
                K = 32
                top_probs, top_ids = probs.topk(K, dim=-1)  # (B, T, K)

                gold_code = two_letter_code[labels_safe]  # (B, T)
                top_codes = two_letter_code[top_ids]  # (B, T, K)

                # A top-K candidate is forgiven when it shares the gold token's
                # code, has a code at all (!= -1), is not the gold token itself
                # and sits on an unmasked position.
                same_code = (top_codes == gold_code.unsqueeze(-1))
                has_code = (top_codes != -1)
                not_gold = (top_ids != labels_safe.unsqueeze(-1))
                forgive_mask = same_code & has_code & not_gold & non_pad.unsqueeze(-1)  # (B, T, K)

                p_sim_sum = (top_probs * forgive_mask.float()).sum(dim=-1)  # (B, T)

                # effective probability = gold + half the mass of the forgiven
                # neighbours, clamped away from 0 and 1 so the log stays finite
                eps = 1e-8
                p_eff = (p_gold + 0.5 * p_sim_sum).clamp_(eps, 1.0 - eps)

                # losses (masked)
                token_losses = -torch.log(p_eff) * non_pad  # (B, T)
                base_token_losses = -gold_logp * non_pad  # (B, T)

                # reduce to per-batch numerators & denom
                denom = non_pad.sum().clamp_min(1)  # scalar
                batch_nll = token_losses.sum()  # scalar
                batch_base_nll = base_token_losses.sum()  # scalar

                # Backward always runs, as in the original script; the
                # optimizer only steps when update_weights is enabled.
                loss = (batch_nll / denom) / accumulation_steps
                loss.backward()
                if update_weights and (num_batches + 1) % accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad(set_to_none=True)

                # --- running totals: accumulate sums, divide once when reporting ---
                batch_tokens = int(denom.item())
                batch_nll_sum = float(batch_nll.item())
                batch_base_nll_sum = float(batch_base_nll.item())

                total_tokens += batch_tokens
                total_loss += batch_nll_sum
                total_base_loss += batch_base_nll_sum
                window_tokens += batch_tokens
                window_loss += batch_nll_sum
                window_base_loss += batch_base_nll_sum
                num_batches += 1

                # score bleu performance
                # NOTE(review): generation runs while the model is still in
                # train mode (model.train() is never undone) -- confirm that is
                # intended for these measurements.
                with torch.no_grad():
                    outputs = model.generate(input_ids=input_ids,
                                             attention_mask=attention_mask,
                                             max_length=64)

                # Clip each hypothesis to the reference length and drop ids
                # outside the tokenizer vocabulary before decoding.
                hypotheses = []
                for gen_ids, ref_ids in zip(outputs, batch_zh["input_ids"]):
                    clip_len = min(len(ref_ids), len(gen_ids))
                    kept_ids = [tok for tok in gen_ids[:clip_len].tolist() if 0 <= tok < vocab_size]
                    hypotheses.append(tokenizer.decode(kept_ids, skip_special_tokens=True))

                references = [
                    tokenizer.decode(ref_ids, skip_special_tokens=True)
                    for ref_ids in batch_zh["input_ids"]
                ]

                bleu_score = sacrebleu.corpus_bleu(hypotheses, [references]).score
                overall_bleu += bleu_score
                total_bleu += bleu_score
                total_bleu_batches += 1
                bleu_batches += 1

                if num_batches % 100 == 0:
                    window_ce = window_loss / max(1, window_tokens)
                    window_gold_ce = window_base_loss / max(1, window_tokens)

                    print(f"Batch {num_batches} of {total_batches}")
                    print(f"Loss: {window_ce}")
                    print(f"Gold loss: {window_gold_ce}")
                    print(f"BLEU score: {overall_bleu / max(1, bleu_batches):.2f}")
                    print(f"Perplexity: {math.exp(window_ce)}")
                    print(f"Overall: {total_bleu / total_bleu_batches:.2f}\n")

                    # Start a fresh report window; numerators and denominators
                    # are reset together so the next report stays consistent.
                    overall_bleu = 0.0
                    bleu_batches = 0
                    window_loss = 0.0
                    window_base_loss = 0.0
                    window_tokens = 0

                    if F_MOD_MULT < 1.0:
                        F_MOD_MULT += fmod_add

                    # Batch index after which LEAK.tau_mod starts stepping down
                    # towards 1. Kept under its own name so it cannot
                    # overwrite the tau_start sampling hyper-parameter.
                    tau_decay_start = 0  # 10000
                    if num_batches > tau_decay_start and LEAK.tau_mod > 1:
                        LEAK.tau_mod -= 1
                        print("TAU", str(LEAK.tau_mod))

                    if LEAK.tau_mod < 1.0:
                        LEAK.tau_mod = 1

                # enables forgetfulness if not in recovery phase
                if num_batches % LEAK.tau == 0:
                    if (epoch > 2 or num_batches >= 1000) and forget_now:
                        # Cycle through the four residue classes (index mod 4)
                        # in 1250-batch phases of a 5000-batch period.
                        phase = (num_batches % 5000) // 1250
                        for slot in range(phase, len(LEAK.safe_to_forget), 4):
                            LEAK.safe_to_forget[slot] = True

            # continue to next day (one "day" is one pass over the loader here)
            batch_count = batches_per_day

            # Raise the active perturbation for the next epoch of the sweep.
            if LEAK.apply_blur:
                for band in range(4):
                    LEAK.recall_alpha[band] += blur_inc

            if LEAK.apply_noise:
                LEAK.noise_strength += noise_inc

            if LEAK.apply_recall:
                LEAK.max_variance += recall_inc

            # Per-token mean of the forgiving loss over the whole epoch.
            avg_loss = total_loss / max(1, total_tokens)

            # Print average loss over the day
            print(
                f"Day {batch_count / batches_per_day}/{total_batches / batches_per_day} - avg train loss: {avg_loss:.4f}\n")

            # Print average loss over the epoch
            print(f"Epoch {epoch + 1}/{epochs} - avg train loss: {avg_loss:.4f}\n")