import os
import pickle
import math

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
)
from peft import LoraConfig, get_peft_model
import evaluate

from tqdm import tqdm
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, DistributedSampler

from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import get_linear_schedule_with_warmup
import numpy as np

from opacus.accountants.prv import PRVAccountant
from opacus.accountants.utils import get_noise_multiplier
from opacus import PrivacyEngine
from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP

from opacus.utils.batch_memory_manager import BatchMemoryManager

from math import sqrt
from torch.nn.functional import cosine_similarity

from nltk.translate.nist_score import  corpus_nist
from pycocoevalcap.cider.cider import Cider

from transformers.utils.logging import set_verbosity_error
set_verbosity_error()   # removes ALL warnings from HF

import wandb
wandb.login()

# Metric objects are created once at import time; compute_metrics() below
# reads them as module-level globals.
bleu = evaluate.load("bleu")
rouge = evaluate.load("rouge")
meteor = evaluate.load("meteor")
bertscore = evaluate.load("bertscore")
# CIDEr scorer instance from pycocoevalcap (not an `evaluate` metric).
cider_scorer = Cider()

def compute_metrics(val_refs, preds):
    """Score generated texts against their references with six metrics.

    Args:
        val_refs: One entry per example; each entry is the collection of
            reference strings for that example (it is iterated and every
            element is ``.split()`` for NIST).
        preds: One generated string per example, aligned with ``val_refs``.

    Returns:
        dict with keys ``bleu``, ``rouge`` (the ROUGE-L value), ``meteor``,
        ``bertscore_f1_mean``, ``cider`` and ``nist``.
    """
    # The three `evaluate` metrics below share the same call signature.
    shared_kwargs = {"predictions": preds, "references": val_refs}
    bleu_out = bleu.compute(**shared_kwargs)
    rouge_out = rouge.compute(**shared_kwargs)
    meteor_out = meteor.compute(**shared_kwargs)

    bert_out = bertscore.compute(
        lang="en",
        model_type="roberta-large",
        **shared_kwargs,
    )
    mean_bert_f1 = np.mean(bert_out["f1"])

    # The CIDEr scorer is fed id -> list-of-strings mappings for both sides.
    cider_refs = {idx: val_refs[idx] for idx in range(len(val_refs))}
    cider_hyps = {idx: [preds[idx]] for idx in range(len(preds))}
    cider_value, _ = cider_scorer.compute_score(cider_refs, cider_hyps)

    # NIST works on whitespace-tokenised references and hypotheses.
    tokenised_refs = [[ref.split() for ref in ref_group] for ref_group in val_refs]
    tokenised_hyps = [hyp.split() for hyp in preds]
    nist_value = corpus_nist(tokenised_refs, tokenised_hyps, n=4)

    scores = {}
    scores["bleu"] = bleu_out["bleu"]
    scores["rouge"] = rouge_out["rougeL"].item()
    scores["meteor"] = meteor_out["meteor"].item()
    scores["bertscore_f1_mean"] = mean_bert_f1.item()
    scores["cider"] = cider_value.item()
    scores["nist"] = nist_value
    return scores

def setup():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=rank
    )
    return local_rank, rank, world_size

def cleanup():
    """Destroy the default process group if one is currently initialised.

    Safe to call more than once, and safe to call when ``setup()`` never
    completed: ``destroy_process_group()`` itself raises when no group
    exists, so the call is guarded.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

def gather_ordered_predictions(local_preds, rank=None, world_size=None):
    """Collect per-rank prediction batches on rank 0 in global order.

    Every rank pickles its ``local_preds`` to ``pickle/preds_rank_<rank>.pkl``
    (relative to the current working directory), all ranks synchronise on a
    barrier, and rank 0 then reads every file back. Rank 0 must therefore be
    able to see the files written by all other ranks.

    Args:
        local_preds: List of batches for this rank; each batch is a list of
            predictions (possibly empty).
        rank: This process's rank. Defaults to ``dist.get_rank()``.
        world_size: Number of ranks. Defaults to ``dist.get_world_size()``.

    Returns:
        On rank 0, a flat list interleaved batch-by-batch across ranks
        (batch 0 of rank 0, batch 0 of rank 1, ..., batch 1 of rank 0, ...).
        On every other rank, ``None``.
    """
    # The signature advertises None defaults; resolve them instead of
    # failing later in range(None) / writing "preds_rank_None.pkl".
    if rank is None:
        rank = dist.get_rank()
    if world_size is None:
        world_size = dist.get_world_size()

    print(f"Rank {rank} gathering {len(local_preds)} predictions...")

    os.makedirs("pickle", exist_ok=True)
    with open(os.path.join("pickle", f"preds_rank_{rank}.pkl"), "wb") as f:
        pickle.dump(local_preds, f)

    # Make sure every rank has finished writing before rank 0 starts reading.
    dist.barrier()
    if rank != 0:
        return None

    gathered = []
    for src_rank in range(world_size):
        with open(os.path.join("pickle", f"preds_rank_{src_rank}.pkl"), "rb") as f:
            # NOTE: pickle.load can execute arbitrary code; these files are
            # only trusted because this same job wrote them moments ago.
            gathered.append(pickle.load(f))

    # Iterate up to the longest shard so that no rank's trailing batches are
    # dropped when rank 0 happens to hold fewer batches than another rank.
    n_batches = max((len(part) for part in gathered), default=0)
    merged = []
    for batch_idx in range(n_batches):
        for part in gathered:
            if batch_idx < len(part):
                merged.extend(part[batch_idx])
    return merged


def find_largest_positive_lambda(target, n, m, sigma):
    """Solve for the largest positive lambda matching an observed value.

    Finds the largest positive lambda such that

        sqrt((lambda + n*sigma^2/lambda) * (lambda + m*sigma^2/lambda)) = target

    Squaring and multiplying by lambda^2 gives a quadratic in z = lambda^2:

        z^2 + ((n + m)*sigma^2 - target^2) * z + n*m*sigma^4 = 0

    A positive solution exists only when
    ``ratio = target / (sigma * (sqrt(m) + sqrt(n)))`` exceeds 1.

    Args:
        target: Observed value (a number, or a 0-dim tensor).
        n, m: Matrix dimensions (non-negative).
        sigma: Noise standard deviation.

    Returns:
        tuple: ``(lambda, ratio)``. ``lambda`` is a float and is ``0.0`` when
        there is no positive solution. With no noise at all
        (``sigma * (sqrt(m) + sqrt(n)) == 0``) the equation reduces to
        ``lambda = target`` and ``ratio`` is reported as ``inf`` (or ``0.0``
        for a non-positive target).
    """
    threshold = sigma * (sqrt(m) + sqrt(n))
    if threshold == 0:
        # Noiseless case: avoid dividing by zero; the equation is lambda = target.
        if target > 0:
            return float(target), float("inf")
        return 0.0, 0.0

    ratio = target / threshold
    if ratio <= 1:
        return 0.0, ratio

    # Coefficients of z^2 + b*z + c = 0 with z = lambda^2.
    b = (n + m) * sigma**2 - target**2  # negative whenever ratio > 1
    c = n * m * sigma**4

    discriminant = b**2 - 4 * c
    if discriminant < 0:
        # Not expected when ratio > 1; kept as a guard against rounding.
        return 0.0, ratio

    # The larger root of the quadratic gives the largest lambda.
    z1 = (-b + sqrt(discriminant)) / 2
    return (sqrt(z1) if z1 > 0 else 0.0), ratio

def component_sim(lambda_hat, n, m, sigma):
    """Return the per-component similarity factor for an estimated lambda.

    Evaluates

        (lambda^4 - n*m*sigma^4)
        / (lambda^2 * sqrt((lambda^2 + n*sigma^2) * (lambda^2 + m*sigma^2)))

    ``lambda_hat`` must be non-zero, since the denominator contains lambda^2.
    """
    lam_sq = lambda_hat**2
    noise_var = sigma**2

    numerator = lambda_hat**4 - n * m * sigma**4
    spread = sqrt((lam_sq + n * noise_var) * (lam_sq + m * noise_var))
    return numerator / (lam_sq * spread)


def svd_shrinkage(noisy, sigma):
    """Denoise a 2-D matrix by shrinking its singular values.

    Each singular value ``s`` is mapped through
    ``find_largest_positive_lambda`` and, when a positive solution exists,
    replaced by ``lambda * component_sim(lambda, ...)``. Processing stops at
    the first singular value without a positive solution; since singular
    values are returned in descending order, the remaining ones would be
    rejected as well.

    Args:
        noisy: 2-D tensor of shape ``(n, m)``.
        sigma: Noise standard deviation handed to the shrinkage formulas.

    Returns:
        tuple: ``(denoised, k, biggest_ratio)`` where ``denoised`` has the
        shape, dtype and device of ``noisy``, ``k`` is the number of retained
        components, and ``biggest_ratio`` is the largest ratio reported by
        ``find_largest_positive_lambda`` among the retained components, as a
        Python float. When nothing is retained the result is
        ``(zeros_like(noisy), 0, 1.0)``.
    """
    # torch.linalg.svd replaces the deprecated torch.svd; it returns V^H
    # directly, so no transpose is needed when reconstructing.
    U, S, Vh = torch.linalg.svd(noisy, full_matrices=False)
    n, m = noisy.shape

    shrunk_values = []
    ratios = []
    # Pull all singular values to the host in one transfer instead of
    # synchronising with the device once per element.
    for s in S.tolist():
        new_s, ratio = find_largest_positive_lambda(s, n, m, sigma)
        if new_s <= 0:
            break
        shrunk_values.append(new_s * component_sim(new_s, n, m, sigma))
        ratios.append(ratio)

    k = len(shrunk_values)
    if k == 0:
        return torch.zeros_like(noisy), k, 1.0

    # Match the dtype of the decomposition; a default-float32 tensor would
    # make the product below fail for float64 inputs.
    new_singular_tensor = torch.tensor(shrunk_values, dtype=S.dtype, device=S.device)

    # Scaling the columns of U is equivalent to U @ diag(values).
    denoised = (U[:, :k] * new_singular_tensor) @ Vh[:k, :]
    return denoised, k, max(ratios)

def sim(x, y):
    """Cosine similarity of two tensors, each flattened to a 1-D vector."""
    flat_x = torch.flatten(x)
    flat_y = torch.flatten(y)
    return cosine_similarity(flat_x, flat_y, dim=0)

def increment_global_step(step_tensor, step_increment=1):
    """Advance the shared step counter by the sum of all ranks' increments.

    Each rank contributes ``step_increment``; the contributions are summed
    with an all-reduce and the total is added to ``step_tensor`` in place.

    Returns:
        The same ``step_tensor`` object, after the in-place update.
    """
    contribution = torch.tensor(
        [step_increment],
        dtype=torch.long,
        device=step_tensor.device,
    )
    # After this call `contribution` holds the sum over every rank.
    dist.all_reduce(contribution, op=dist.ReduceOp.SUM)
    step_tensor.add_(contribution)
    return step_tensor


# Base checkpoint loaded by launch(), and the directory adapters/tokenizer are saved under.
MODEL_ID = "meta-llama/Llama-3.2-1B"   
OUTPUT_DIR = "./llama32_1b_dart_denoised"

# Dataset identifier passed to load_dataset() in launch().
DATASET_ID = "GEM/dart"

MAX_INPUT_TOKENS = 124   # MR prompt tokens (max_length for prompts at generation time)
MAX_LABEL_TOKENS = 512   # Target text tokens
MAX_SEQ_LEN = MAX_INPUT_TOKENS + MAX_LABEL_TOKENS  # max_length for prompt + target when tokenizing training data

# Training hyperparameters
# EPOCHS = 1
# Training stops once the global step counter reaches this value. The counter
# is advanced by increment_global_step(), which adds the all-reduced sum of
# the per-rank increments.
STEP_LIMIT = 400


# Physical batch size handed to BatchMemoryManager (max_physical_batch_size).
PER_DEVICE_TRAIN_BATCH = 16
# Only used to derive LOGICAL_BATCH_SIZE below.
GRADIENT_ACCUMULATION_STEPS = 4
# Batch size of the training DataLoader; also the numerator of the sample rate
# given to get_noise_multiplier().
LOGICAL_BATCH_SIZE = PER_DEVICE_TRAIN_BATCH * GRADIENT_ACCUMULATION_STEPS
LR = 2e-4
# NOTE(review): the evaluation loop in launch() uses its own local batch size;
# confirm whether this constant is meant to drive it.
PER_DEVICE_EVAL_BATCH = 16
WEIGHT_DECAY = 0.01
LABEL_SMOOTHING_FACTOR=0.0
# Evaluation and adapter checkpointing run whenever the global step is a multiple of this.
LOG_STEPS = 8 * 6 
# NOTE(review): not referenced in the visible part of this file.
SAVE_STEPS = 500

# LoRA Config
LORA_R = 32
LORA_ALPHA = LORA_R * 2
LORA_DROPOUT = 0.1
TARGET_MODULES = ["q_proj", "v_proj"]


# Privacy budget passed to get_noise_multiplier() as target_epsilon / target_delta.
EPSILON = 5.4
DELTA = 1e-5

# A gradient is replaced by its SVD-shrunk version only when the largest ratio
# returned by svd_shrinkage() exceeds this threshold.
KAPPA = 1.02



import os
# Text placed before the MR and between the MR and the target text.
PROMPT_PREFIX = "MR: "
TARGET_PREFIX = "\nText: "

def build_prompt_only(example):
    """Wrap a meaning representation in the prompt prefix and target prefix.

    The result contains no target text, so it can be used as a generation
    prompt.
    """
    return f"{PROMPT_PREFIX}{example}{TARGET_PREFIX}"


def launch():
    """Run DP-LoRA fine-tuning with SVD gradient denoising on one rank.

    Steps, in order:
      1. Join the process group (``setup``) and, on rank 0, start a wandb run.
      2. Load the tokenizer and base model and attach LoRA adapters.
      3. Tokenize the training split, masking prompt tokens in the labels.
      4. Wrap model, optimizer and data loader with Opacus
         (``PrivacyEngine.make_private``) using a noise multiplier derived
         from ``EPSILON`` / ``DELTA`` / ``STEP_LIMIT``.
      5. Train with ``BatchMemoryManager``. On every real optimizer step the
         reduced gradients of trainable matrices are passed through
         ``svd_shrinkage`` and, when the reported ratio exceeds ``KAPPA``,
         replaced by the denoised matrix rescaled to the original norm.
      6. Evaluate (generation + metrics on the validation split) every
         ``LOG_STEPS`` global steps, after each pass over the data, and once
         more when the step limit is reached; save adapters on rank 0.

    NOTE(review): the training DataLoader is built without a
    DistributedSampler, with ``shuffle=False`` and ``poisson_sampling=False``.
    Confirm that every rank iterating the same ordered data, and the sample
    rate used for the noise multiplier, are intended.
    """
    local_rank, rank, world_size = setup()
    if rank == 0:
        logger = wandb.init(entity="dadsetan", project="svd-dp", tags=["script", 'iclr-abstract'], config={
            # Log the dataset that is actually loaded below.
            "dataset_name": DATASET_ID,
            "step_limit": STEP_LIMIT,
            "model_name": MODEL_ID,
            "denoising": True,
        })
    global_step = torch.tensor([0], dtype=torch.long).to(local_rank)
    print(f"Rank {rank}/{world_size} initialized on local rank {local_rank}.")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
    )
    peft_config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        inference_mode=False,
        task_type="CAUSAL_LM",
        target_modules=TARGET_MODULES,
    )
    model = get_peft_model(model, peft_config)

    raw = load_dataset(DATASET_ID, trust_remote_code=True)
    sample_rate = LOGICAL_BATCH_SIZE / len(raw["train"])

    # Training batches are padded on the right; generation overrides this
    # per call with padding_side="left".
    tokenizer.padding_side = "right"

    def tokenize_function(examples):
        """Tokenize prompt + target and mask the prompt positions in the labels."""
        input_ids = []
        labels = []
        attention_masks = []

        for mr, tgt in zip(examples["tripleset"], examples["target"]):
            prompt = f"{PROMPT_PREFIX}{mr}{TARGET_PREFIX}"
            tok = tokenizer(
                f"{prompt}{tgt}{tokenizer.eos_token}",
                truncation=True,
                max_length=MAX_SEQ_LEN,
                padding=False,
            )
            label_ids = tok["input_ids"].copy()
            # Mask the MR part. The mask length is clamped to the (possibly
            # truncated) sequence length: slice-assigning a longer list would
            # otherwise make `labels` longer than `input_ids`.
            prompt_len = len(tokenizer(prompt)["input_ids"])
            n_masked = min(prompt_len, len(label_ids))
            label_ids[:n_masked] = [-100] * n_masked

            input_ids.append(tok["input_ids"])
            attention_masks.append(tok["attention_mask"])
            labels.append(label_ids)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_masks,
            "labels": labels,
        }

    tokenized_train = raw["train"].map(tokenize_function, batched=True, remove_columns=raw["train"].column_names)
    tokenized_train.set_format(type='torch')

    from torch.nn.utils.rnn import pad_sequence

    def clm_collate(batch):
        """Right-pad a list of examples into batch tensors (labels padded with -100)."""
        input_ids = [x["input_ids"] for x in batch]
        labels = [x["labels"] for x in batch]
        attn = [x["attention_mask"] for x in batch]

        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)
        attn = pad_sequence(attn, batch_first=True, padding_value=0)

        return {"input_ids": input_ids, "labels": labels, "attention_mask": attn}

    train_dl = DataLoader(tokenized_train, batch_size=LOGICAL_BATCH_SIZE, shuffle=False, collate_fn=clm_collate, drop_last=False)
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    model.to(local_rank)
    ddp_model = DPDDP(model)
    ddp_model.to(local_rank)
    ddp_model.train()

    # Summed (not averaged) token loss; gradients are divided by the global
    # token count explicitly in the training loop.
    loss_fn = nn.CrossEntropyLoss(ignore_index=-100, label_smoothing=LABEL_SMOOTHING_FACTOR, reduction='sum')

    noise_multiplier = get_noise_multiplier(
        target_epsilon=EPSILON,
        target_delta=DELTA,
        sample_rate=sample_rate,
        steps=STEP_LIMIT,
        accountant='prv'
    )
    print(f"Noise multiplier: {noise_multiplier}")

    privacy_engine = PrivacyEngine()
    ddp_model, optimizer, train_dl = privacy_engine.make_private(
        module=ddp_model,
        optimizer=optimizer,
        data_loader=train_dl,
        criterion=loss_fn,
        noise_multiplier=noise_multiplier,
        max_grad_norm=1.0,
        loss_reduction="sum",
        poisson_sampling=False,
    )

    # Validation inputs do not change during training; build them once
    # instead of on every evaluation.
    val_prompts = [build_prompt_only(e) for e in raw["validation"]["tripleset"]]
    val_refs = raw["validation"]["references"]

    def run_eval(eval_model, step):
        """Generate for the validation prompts, then score and log on rank 0.

        Each rank handles every ``world_size``-th batch of prompts. The model
        is put in eval mode for the duration and back in train mode at the end.
        """
        eval_model.eval()
        GEN_MAX_NEW_TOKENS = 100
        GEN_LENGTH_PENALTY = 0.9   # <1 favors shorter, >1 favors longer
        GEN_NO_REPEAT_NGRAM = 4
        GEN_NUM_BEAMS = 10  # beam search

        def generate_batch(prompts, gen_model, tokenizer):
            """Beam-search continuations for ``prompts`` with the prompt text stripped."""
            device = gen_model.device
            inputs = tokenizer(
                prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=MAX_INPUT_TOKENS,
                padding_side="left",
            ).to(device)
            with torch.no_grad():
                outputs = gen_model.generate(
                    **inputs,
                    max_new_tokens=GEN_MAX_NEW_TOKENS,
                    length_penalty=GEN_LENGTH_PENALTY,
                    no_repeat_ngram_size=GEN_NO_REPEAT_NGRAM,
                    num_beams=GEN_NUM_BEAMS,
                    do_sample=False,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    early_stopping=True,
                )
            decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            del outputs
            cleaned = []
            for prompt, full in zip(prompts, decoded):
                if full.startswith(prompt):
                    cont = full[len(prompt):]
                else:
                    # Fall back to the last occurrence of the prompt, or the
                    # whole text when the prompt cannot be found.
                    idx = full.rfind(prompt)
                    cont = full[idx+len(prompt):] if idx != -1 else full
                idx_end = cont.find(tokenizer.eos_token)
                if idx_end != -1:
                    cont = cont[:idx_end]
                cleaned.append(cont.strip())
            return cleaned

        BATCH = 16
        preds_batch_list = []
        for i in tqdm(range(0, len(val_prompts), BATCH * world_size), disable=(rank != 0)):
            start = i + rank * BATCH
            if start >= len(val_prompts):
                # Keep one (empty) entry per round so that all ranks report
                # the same number of batches.
                preds_batch_list.append([])
                continue
            batch_prompts = val_prompts[start : i + (rank + 1) * BATCH]
            preds_batch_list.append(generate_batch(batch_prompts, eval_model, tokenizer))

        merged_preds = gather_ordered_predictions(preds_batch_list, rank, world_size)
        if rank == 0:
            result = compute_metrics(val_refs, merged_preds)
            logger.log({"step": step, **result})
            print("Validation Metrics:", result)

            example_mr = '[ [ "2013", "REGULAR_SEASON", "4th, Western" ], [ "2013", "LEAGUE", "USL W-League" ], [ "[TABLECONTEXT]", "[TITLE]", "Colorado Rapids Women" ], [ "[TABLECONTEXT]", "YEAR", "2013" ], [ "USL W-League", "DIVISION", "1" ] ]'
            print("MR:", example_mr)
            print("Generated:", generate_batch([build_prompt_only(example_mr)], eval_model, tokenizer)[0])

        # Hold the other ranks until rank 0 has finished scoring.
        dist.barrier()
        eval_model.train()

    class StepLimitReached(Exception):
        """Raised inside the training loop to unwind once the step budget is used up."""

    try:
        while True:
            with BatchMemoryManager(
                data_loader=train_dl,
                max_physical_batch_size=PER_DEVICE_TRAIN_BATCH,
                optimizer=optimizer
            ) as memory_safe_data_loader:
                model.train()
                progress_bar = tqdm(memory_safe_data_loader, desc="", disable=(dist.get_rank() != 0))
                # Target tokens seen since the last real optimizer step.
                batch_tokens = 0
                for batch in progress_bar:
                    batch = {k: v.to(local_rank) for k, v in batch.items()}
                    logits = ddp_model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits
                    # Standard causal-LM shift: position t predicts token t+1.
                    shift_logits = logits[..., :-1, :].contiguous()
                    shift_labels = batch["labels"][..., 1:].contiguous()
                    batch_tokens += (shift_labels != -100).sum().item()
                    loss = loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
                    loss.backward()

                    # pre_step() is falsy for physical batches on which the
                    # optimizer step is skipped; the update below then does
                    # not run for that batch.
                    if optimizer.pre_step():
                        global_step = increment_global_step(global_step)
                        if global_step.item() >= STEP_LIMIT:
                            raise StepLimitReached()
                        optimizer.reduce_gradients()

                        # Post-process the reduced gradients before the update.
                        noise_std = noise_multiplier * optimizer.max_grad_norm
                        with torch.no_grad():
                            for param in model.parameters():
                                if not param.requires_grad:
                                    continue
                                # svd_shrinkage needs a 2-D gradient; leave
                                # anything else untouched.
                                if param.grad is None or param.grad.dim() != 2:
                                    continue
                                denoised, k, biggest_ratio = svd_shrinkage(param.grad, noise_std)
                                if biggest_ratio > KAPPA:
                                    # Keep the original gradient norm, use the denoised direction.
                                    param.grad = denoised * torch.norm(param.grad)/torch.norm(denoised)

                        # Normalise the summed loss gradient by the number of
                        # target tokens across all ranks.
                        batch_tokens_tensor = torch.tensor([batch_tokens], device=local_rank, dtype=torch.long)
                        torch.distributed.all_reduce(batch_tokens_tensor, op=dist.ReduceOp.SUM)
                        batch_tokens = batch_tokens_tensor.item()
                        for param in ddp_model.parameters():
                            if param.grad is not None:
                                with torch.no_grad():
                                    param.grad /= batch_tokens
                        batch_tokens = 0

                        # Step the wrapped optimizer directly: the pre_step /
                        # reduce_gradients calls were already made above.
                        optimizer.original_optimizer.step()
                        if global_step.item() == STEP_LIMIT - 1:
                            raise StepLimitReached()
                        if global_step.item() % LOG_STEPS == 0:
                            dist.barrier()
                            ddp_model.eval()
                            run_eval(ddp_model.module, step=global_step.item())
                            ddp_model.train()
                            if dist.get_rank() == 0:
                                model.save_pretrained(os.path.join(OUTPUT_DIR, "adapter_step_"+str(global_step.item())))
                                tokenizer.save_pretrained(OUTPUT_DIR)
                                print("Adapter saved to:", os.path.join(OUTPUT_DIR, "adapter_step_"+str(global_step.item())))
                    optimizer.zero_grad()
                    if rank == 0:
                        progress_bar.set_postfix(loss=loss.item(), step=global_step.item())

                # Evaluate after each full pass over the data loader.
                dist.barrier()
                torch.cuda.empty_cache()
                ddp_model.eval()
                run_eval(ddp_model.module, step=global_step.item())
    except StepLimitReached:
        print(f"Rank {rank} reached step limit of {STEP_LIMIT}, finishing training.")
        dist.barrier()
        torch.cuda.empty_cache()
        ddp_model.eval()
        run_eval(ddp_model.module, step=global_step.item())
    if rank == 0:
        model.save_pretrained(os.path.join(OUTPUT_DIR, "adapter"))
        tokenizer.save_pretrained(OUTPUT_DIR)
        print("Adapter saved to:", os.path.join(OUTPUT_DIR, "adapter")) 
    cleanup()


# Script entry point. launch() calls setup(), which reads RANK, WORLD_SIZE and
# LOCAL_RANK from the environment, so each process must be started with those
# variables set.
if __name__ == "__main__":
    launch()