import os
import pickle
import math

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
)
from peft import LoraConfig, get_peft_model
import evaluate

from tqdm import tqdm
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, DistributedSampler

from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import get_linear_schedule_with_warmup
import numpy as np

from opacus.accountants.prv import PRVAccountant
from opacus.accountants.utils import get_noise_multiplier
from opacus import PrivacyEngine
from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP

from opacus.utils.batch_memory_manager import BatchMemoryManager

from nltk.translate.nist_score import  corpus_nist
from pycocoevalcap.cider.cider import Cider

from transformers.utils.logging import set_verbosity_error
set_verbosity_error()   # removes ALL warnings from HF

import wandb
wandb.login()




# Metric scorers are created once at import time and reused by compute_metrics.
bleu, rouge, meteor, bertscore = (
    evaluate.load(metric_name) for metric_name in ("bleu", "rouge", "meteor", "bertscore")
)
cider_scorer = Cider()

def compute_metrics(val_refs, preds):
    """Score generated texts against their reference sets.

    Args:
        val_refs: for each example, a list of reference strings.
        preds: one generated string per example, aligned with ``val_refs``.

    Returns:
        dict with corpus BLEU, ROUGE-L, METEOR, mean BERTScore F1
        (roberta-large), CIDEr and 4-gram NIST scores.
    """
    shared_kwargs = dict(predictions=preds, references=val_refs)

    bleu_out = bleu.compute(**shared_kwargs)
    rouge_out = rouge.compute(**shared_kwargs)
    meteor_out = meteor.compute(**shared_kwargs)
    bert_out = bertscore.compute(**shared_kwargs, lang="en", model_type="roberta-large")
    bert_f1 = np.mean(bert_out['f1'])

    # CIDEr expects id -> list-of-references and id -> single-element hypothesis list.
    gts = {idx: refs for idx, refs in enumerate(val_refs)}
    res = {idx: [pred] for idx, pred in enumerate(preds)}
    cider_mean, _per_example = cider_scorer.compute_score(gts, res)

    # NIST works on whitespace-tokenised references and hypotheses.
    tokenized_refs = [[ref.split() for ref in refs] for refs in val_refs]
    tokenized_preds = [pred.split() for pred in preds]
    nist = corpus_nist(tokenized_refs, tokenized_preds, n=4)

    return {
        "bleu": bleu_out["bleu"],
        "rouge": rouge_out["rougeL"].item(),
        "meteor": meteor_out["meteor"].item(),
        "bertscore_f1_mean": bert_f1.item(),
        "cider": cider_mean.item(),
        "nist": nist,
    }

def setup():
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    print(f"Initializing rank {rank} (local rank {local_rank}) out of {world_size}...")
    torch.cuda.set_device(local_rank)

    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=world_size,
        rank=rank
    )
    return local_rank, rank, world_size


def cleanup():
    """Tear down the default process group if one is active.

    Safe to call more than once, or when no process group was ever initialised
    (e.g. after ``setup`` failed part-way); in those cases it does nothing
    instead of raising.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

def gather_ordered_predictions(local_preds, rank=None, world_size=None):
    """Collect per-rank prediction batches on rank 0 in global order.

    Each rank passes ``local_preds`` as a list of prediction batches: one list
    of strings per evaluation round, which may be empty. Rank 0 receives every
    rank's list and interleaves it round by round, then rank by rank. So round
    ``i`` contributes rank 0's batch, then rank 1's batch, and so on. This
    matches a loop in which round ``i`` gives rank ``r`` the ``r``-th
    consecutive slice of that round's prompts.

    Objects are exchanged with ``dist.all_gather_object`` (a collective that
    every rank must call) instead of through pickle files. This means the
    function does not depend on a filesystem shared between nodes. Without an
    initialised process group, the call degrades to a single-process merge.

    Args:
        local_preds: this rank's list of per-round prediction batches.
        rank: rank index used for the log line. Defaults to the process-group
            rank, or 0 without a process group.
        world_size: kept for backward compatibility. The process-group size is
            what determines how many ranks are gathered.

    Returns:
        On rank 0, the flat ordered list of predictions. On other ranks,
        ``None``.
    """
    distributed = dist.is_available() and dist.is_initialized()
    if rank is None:
        rank = dist.get_rank() if distributed else 0
    print(f"Rank {rank} gathering {len(local_preds)} predictions...")

    if distributed:
        gathered = [None] * dist.get_world_size()
        dist.all_gather_object(gathered, local_preds)
        is_main = dist.get_rank() == 0
    else:
        gathered = [local_preds]
        is_main = True

    if not is_main:
        return None

    # Use the longest per-rank list so no rank's trailing rounds are dropped.
    num_rounds = max((len(part) for part in gathered), default=0)
    merged = []
    for round_idx in range(num_rounds):
        for part in gathered:
            if round_idx < len(part):
                merged.extend(part[round_idx])
    return merged


def increment_global_step(step_tensor, step_increment=1):
    """Add the sum of every rank's ``step_increment`` to ``step_tensor`` in place.

    Each rank contributes ``step_increment``. The contributions are all-reduced
    with SUM, so every rank adds the same amount. When every rank passes 1, the
    counter therefore advances by the world size per call.

    This is a collective: all ranks in the default process group must call it.

    Returns:
        The same ``step_tensor`` object, after the in-place update.
    """
    delta = torch.tensor([step_increment], dtype=torch.long, device=step_tensor.device)
    dist.all_reduce(delta, op=dist.ReduceOp.SUM)
    step_tensor.add_(delta)
    return step_tensor


# --- Model / data ------------------------------------------------------------
MODEL_ID = "meta-llama/Llama-3.2-3B"
OUTPUT_DIR = "./llama32_3b_e2e_baseline"  # LoRA adapters and tokenizer are saved under here
DATASET_ID = "GEM/e2e_nlg"

# --- Sequence lengths (tokens) -----------------------------------------------
MAX_INPUT_TOKENS = 124  # prompt (meaning representation) budget; generation-time truncation length
MAX_LABEL_TOKENS = 512  # target-text budget
MAX_SEQ_LEN = MAX_INPUT_TOKENS + MAX_LABEL_TOKENS  # training-time truncation of prompt + target

# --- Training ----------------------------------------------------------------
STEP_LIMIT = 400  # training step budget; also the `steps` used to calibrate the DP noise multiplier

PER_DEVICE_TRAIN_BATCH = 16  # physical batch size handed to BatchMemoryManager
GRADIENT_ACCUMULATION_STEPS = 4
# Logical batch size; the DP sample rate is LOGICAL_BATCH_SIZE / len(train split).
LOGICAL_BATCH_SIZE = GRADIENT_ACCUMULATION_STEPS * PER_DEVICE_TRAIN_BATCH
LR = 2e-4
PER_DEVICE_EVAL_BATCH = 16
WEIGHT_DECAY = 0.01
LABEL_SMOOTHING_FACTOR = 0.0
LOG_STEPS = 48  # evaluate and save an adapter every LOG_STEPS steps
SAVE_STEPS = 500

# --- LoRA --------------------------------------------------------------------
LORA_R = 32
LORA_ALPHA = 2 * LORA_R
LORA_DROPOUT = 0.1
TARGET_MODULES = ["q_proj", "v_proj"]  # modules that receive LoRA adapters

# --- Differential privacy targets --------------------------------------------
EPSILON = 5.4  # target epsilon for the noise-multiplier search
DELTA = 1e-5   # target delta





import os
PROMPT_PREFIX = "MR: "
TARGET_PREFIX = "\nText: "


def build_prompt_only(example):
    """Return the prompt for a meaning representation, without any target text.

    The result is ``PROMPT_PREFIX``, then ``example``, then ``TARGET_PREFIX``.
    Validation generation continues from this prefix.
    """
    return f"{PROMPT_PREFIX}{example}{TARGET_PREFIX}"


def launch():
    """Run DP-SGD LoRA fine-tuning of MODEL_ID on DATASET_ID on every rank.

    Flow:
    1. Join the process group (``setup``).
    2. Build the LoRA model and the tokenised training set.
    3. Wrap the model, optimizer and loader with Opacus.
    4. Train until ``STEP_LIMIT`` optimizer steps have been applied.

    Validation runs every ``LOG_STEPS`` steps, at the end of each pass over the
    training set, and once more when the step limit is reached. All ranks
    generate predictions for it. Rank 0 then computes metrics, logs them to
    wandb and saves the LoRA adapter under ``OUTPUT_DIR``.

    Distributed layout: ``LOGICAL_BATCH_SIZE`` is the number of examples per
    optimizer step summed over all ranks, and ``SAMPLE_RATE`` is computed from
    it. Each rank reads a disjoint ``1/world_size`` share of every logical
    batch, and clipped per-sample gradients are summed across ranks by
    ``optimizer.reduce_gradients()``. Opacus' distributed DP optimizer is
    expected to add the Gaussian noise on a single rank (verify against
    ``DistributedDPOptimizer.add_noise``). Disjoint shards therefore keep each
    example's contribution to a step bounded by ``max_grad_norm``.
    """
    local_rank, rank, world_size = setup()
    if rank == 0:
        logger = wandb.init(entity="dadsetan", project="svd-dp", tags=["script", 'iclr-abstract'], config={
            "dataset_name": "e2e",
            "step_limit": STEP_LIMIT,
            "model_name": MODEL_ID,
            "denoising": False,
        })
    # Number of optimizer steps applied so far; kept identical on all ranks.
    global_step = torch.tensor([0], dtype=torch.long).to(local_rank)
    print(f"Rank {rank}/{world_size} initialized on local rank {local_rank}.")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
    peft_config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        inference_mode=False,
        task_type="CAUSAL_LM",
        target_modules=TARGET_MODULES,
    )
    model = get_peft_model(model, peft_config)

    raw = load_dataset(DATASET_ID, trust_remote_code=True)

    # LOGICAL_BATCH_SIZE is the global batch per optimizer step; every rank reads
    # an equal, disjoint share of it.
    if LOGICAL_BATCH_SIZE % world_size != 0:
        raise ValueError(
            f"LOGICAL_BATCH_SIZE ({LOGICAL_BATCH_SIZE}) must be divisible by world_size ({world_size})"
        )
    per_rank_batch_size = LOGICAL_BATCH_SIZE // world_size
    SAMPLE_RATE = LOGICAL_BATCH_SIZE / len(raw["train"])

    tokenizer.padding_side = "right"

    def tokenize_function(examples):
        """Tokenise (MR, target) pairs into causal-LM inputs with prompt-masked labels.

        The prompt and the target are tokenised separately and then concatenated.
        As a result, the prompt token ids seen in training are exactly the ids
        ``generate_batch`` feeds at inference time. The prompt ends in a space.
        With tokenizers that attach a leading space to the following word,
        tokenising prompt+target as one string would fold that space into the
        first target token. That token would then fall inside the masked prompt
        span, and its ids would differ from the inference prompt.
        """
        input_ids = []
        labels = []
        attention_masks = []

        for mr, tgt in zip(examples["meaning_representation"], examples["target"]):
            prompt_ids = tokenizer(build_prompt_only(mr))["input_ids"]
            target_ids = tokenizer(tgt + tokenizer.eos_token, add_special_tokens=False)["input_ids"]
            ids = (prompt_ids + target_ids)[:MAX_SEQ_LEN]
            # Prompt positions are excluded from the loss (-100). Truncating both
            # lists to the same length keeps labels aligned with input_ids.
            label_ids = ([-100] * len(prompt_ids) + target_ids)[:MAX_SEQ_LEN]

            input_ids.append(ids)
            attention_masks.append([1] * len(ids))
            labels.append(label_ids)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_masks,
            "labels": labels,
        }

    tokenized_train = raw["train"].map(tokenize_function, batched=True, remove_columns=raw["train"].column_names)
    tokenized_train.set_format(type='torch')

    from torch.nn.utils.rnn import pad_sequence

    def clm_collate(batch):
        """Right-pad a list of examples into batch tensors (labels padded with -100)."""
        input_ids = [x["input_ids"] for x in batch]
        labels = [x["labels"] for x in batch]
        attn = [x["attention_mask"] for x in batch]

        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)
        attn = pad_sequence(attn, batch_first=True, padding_value=0)

        return {"input_ids": input_ids, "labels": labels, "attention_mask": attn}

    # Shard the (unshuffled) training set across ranks, so each logical batch is
    # split between ranks instead of being replicated on every rank. Replication
    # would count every example world_size times in the summed DP gradient.
    # drop_last=True avoids the padding-by-duplication DistributedSampler would
    # otherwise use to even out shard sizes.
    train_sampler = DistributedSampler(
        tokenized_train,
        num_replicas=world_size,
        rank=rank,
        shuffle=False,
        drop_last=True,
    )
    train_dl = DataLoader(
        tokenized_train,
        batch_size=per_rank_batch_size,
        sampler=train_sampler,
        collate_fn=clm_collate,
        drop_last=False,
    )
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    model.to(local_rank)
    ddp_model = DPDDP(model)
    ddp_model.to(local_rank)
    ddp_model.train()

    # Summed (not averaged) loss: gradients are normalised by the all-reduced
    # token count manually after noise is added (see the training loop).
    loss_fn = nn.CrossEntropyLoss(ignore_index=-100, label_smoothing=LABEL_SMOOTHING_FACTOR, reduction='sum')

    noise_multiplier = get_noise_multiplier(
        target_epsilon=EPSILON,
        target_delta=DELTA,
        sample_rate=SAMPLE_RATE,
        steps=STEP_LIMIT,
        accountant='prv'
    )
    print(f"Noise multiplier: {noise_multiplier}")

    # NOTE(review): poisson_sampling=False gives fixed, contiguous batches. The
    # PRV calibration above is parameterised by a sample rate, which typically
    # assumes Poisson subsampling. Confirm this approximation is acceptable.
    privacy_engine = PrivacyEngine()
    ddp_model, optimizer, train_dl = privacy_engine.make_private(
        module=ddp_model,
        optimizer=optimizer,
        data_loader=train_dl,
        criterion=loss_fn,
        noise_multiplier=noise_multiplier,
        max_grad_norm=1.0,
        loss_reduction="sum",
        poisson_sampling=False,
    )

    def run_validation(model, step):
        """Generate validation predictions on all ranks; score and log them on rank 0.

        Every rank must call this: it uses collectives (the prediction gather
        and a trailing barrier). The model is left in train mode.
        """
        model.eval()
        GEN_MAX_NEW_TOKENS = 100
        GEN_LENGTH_PENALTY = 0.9   # <1 favors shorter, >1 favors longer
        GEN_NO_REPEAT_NGRAM = 4
        GEN_NUM_BEAMS = 10  # beam search

        def generate_batch(prompts, model, tokenizer):
            """Beam-search continuations for ``prompts`` and return them with the prompt stripped."""
            device = model.device
            inputs = tokenizer(
                prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=MAX_INPUT_TOKENS,
                padding_side="left",
            ).to(device)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=GEN_MAX_NEW_TOKENS,
                    length_penalty=GEN_LENGTH_PENALTY,
                    no_repeat_ngram_size=GEN_NO_REPEAT_NGRAM,
                    num_beams=GEN_NUM_BEAMS,
                    do_sample=False,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    early_stopping=True,
                )
            decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            del outputs
            cleaned = []
            for prompt, full in zip(prompts, decoded):
                if full.startswith(prompt):
                    cont = full[len(prompt):]
                else:
                    idx = full.rfind(prompt)
                    cont = full[idx + len(prompt):] if idx != -1 else full
                idx_end = cont.find(tokenizer.eos_token)
                if idx_end != -1:
                    cont = cont[:idx_end]
                cleaned.append(cont.strip())
            return cleaned

        val_prompts = [build_prompt_only(e) for e in raw["validation"]["meaning_representation"]]
        val_refs = raw["validation"]["references"]

        # Round i hands rank r the prompts [i + r*BATCH, i + (r+1)*BATCH). Every
        # rank appends one entry per round (possibly empty), so the per-rank lists
        # line up for gather_ordered_predictions.
        BATCH = 16
        preds_batch_list = []
        for i in tqdm(range(0, len(val_prompts), BATCH * world_size), disable=(rank != 0)):
            if i + rank * BATCH >= len(val_prompts):
                preds_batch_list.append([])
                continue
            batch_prompts = val_prompts[i + rank * BATCH: i + (rank + 1) * BATCH]
            preds_batch_list.append(generate_batch(batch_prompts, model, tokenizer))

        merged_preds = gather_ordered_predictions(preds_batch_list, rank, world_size)
        if rank == 0:
            result = compute_metrics(val_refs, merged_preds)
            logger.log({"step": step, **result})
            print("Validation Metrics:", result)

            example_mr = "name[The Wrestlers], eatType[coffee shop], food[French], priceRange[moderate], customer rating[5 out of 5], familyFriendly[yes], near[The Sorrento]"
            print("MR:", example_mr)
            print("Generated:", generate_batch([build_prompt_only(example_mr)], model, tokenizer)[0])

        dist.barrier()
        model.train()

    class StepLimitReached(Exception):
        """Raised to leave the nested epoch/batch loops once STEP_LIMIT steps are applied."""

    try:
        while True:
            with BatchMemoryManager(
                data_loader=train_dl,
                max_physical_batch_size=PER_DEVICE_TRAIN_BATCH,
                optimizer=optimizer
            ) as memory_safe_data_loader:
                model.train()
                epoch_loss = 0.0
                processed_samples = 0
                progress_bar = tqdm(memory_safe_data_loader, desc="", disable=(dist.get_rank() != 0))
                # Supervised (non -100) target tokens accumulated on this rank over
                # the physical batches of the current logical batch.
                batch_tokens = 0
                for batch in progress_bar:
                    batch = {k: v.to(local_rank) for k, v in batch.items()}
                    logits = ddp_model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]).logits
                    # Next-token prediction: position t predicts label t+1.
                    shift_logits = logits[..., :-1, :].contiguous()
                    shift_labels = batch["labels"][..., 1:].contiguous()
                    n_tokens = (shift_labels != -100).sum()
                    batch_tokens += n_tokens.item()
                    loss = loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
                    loss.backward()
                    # pre_step clips and accumulates per-sample gradients. With
                    # BatchMemoryManager it returns False for the physical batches
                    # that do not complete a logical batch.
                    if optimizer.pre_step():
                        # Sum the clipped (and noised) gradients over ranks.
                        optimizer.reduce_gradients()
                        # Convert the summed gradient into a per-token average over
                        # the whole logical batch (all ranks).
                        batch_tokens_tensor = torch.tensor([batch_tokens], device=local_rank, dtype=torch.long)
                        dist.all_reduce(batch_tokens_tensor, op=dist.ReduceOp.SUM)
                        total_tokens = max(batch_tokens_tensor.item(), 1)
                        for param in ddp_model.parameters():
                            if param.grad is not None:
                                with torch.no_grad():
                                    param.grad /= total_tokens
                        batch_tokens = 0
                        optimizer.original_optimizer.step()

                        # Only rank 0 contributes, so the all-reduced counter advances
                        # by exactly one per optimizer step on every rank.
                        global_step = increment_global_step(global_step, step_increment=1 if rank == 0 else 0)
                        # Stop once the STEP_LIMIT-th step has been applied; that is
                        # the step count the noise multiplier was calibrated for.
                        if global_step.item() >= STEP_LIMIT:
                            raise StepLimitReached()
                        if global_step.item() % LOG_STEPS == 0:
                            dist.barrier()
                            ddp_model.eval()
                            run_validation(ddp_model.module, step=global_step.item())
                            ddp_model.train()
                            if dist.get_rank() == 0:
                                adapter_dir = os.path.join(OUTPUT_DIR, "adapter_step_" + str(global_step.item()))
                                model.save_pretrained(adapter_dir)
                                tokenizer.save_pretrained(OUTPUT_DIR)
                                print("Adapter saved to:", adapter_dir)
                    # Opacus keeps the accumulated summed_grad across skipped
                    # steps, so this is safe to call after every physical batch.
                    optimizer.zero_grad()
                    epoch_loss += loss.item()
                    processed_samples += batch["input_ids"].size(0)
                    if dist.get_rank() == 0:
                        progress_bar.set_postfix(loss=loss.item(), step=global_step.item())

                # Evaluate after each full pass over the training set.
                dist.barrier()
                torch.cuda.empty_cache()
                ddp_model.eval()
                run_validation(ddp_model.module, step=global_step.item())
                ddp_model.train()

    except StepLimitReached:
        if dist.get_rank() == 0:
            print(f"Step limit of {STEP_LIMIT} reached. Ending training.")
        dist.barrier()
        torch.cuda.empty_cache()
        ddp_model.eval()
        run_validation(ddp_model.module, step=global_step.item())
        if dist.get_rank() == 0:
            model.save_pretrained(os.path.join(OUTPUT_DIR, "adapter"))
            tokenizer.save_pretrained(OUTPUT_DIR)
            print("Adapter saved to:", os.path.join(OUTPUT_DIR, "adapter"))
    cleanup()

if __name__ == "__main__":
    # Script entry point. setup() (called by launch) reads RANK, WORLD_SIZE and
    # LOCAL_RANK from the environment, so run this under a launcher that sets them.
    launch()
