from transformers import AutoModelForCausalLM
import transformers
import torch
import os
import itertools
import pandas as pd 
import ast 
from collections import OrderedDict
from torch.utils.data import Dataset, DataLoader, Subset, BatchSampler
from tqdm import tqdm
import math
from torch.optim import AdamW
from opacus import PrivacyEngine
from opacus.distributed import DifferentiallyPrivateDistributedDataParallel as DPDDP
import numpy as np
import gc
import math
from torch.utils.data.distributed import DistributedSampler  
import copy
from options import args_parser
from prv_accountant.dpsgd import find_noise_multiplier
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
import dp_transformers
from typing import Any, Dict, List
import datasets
from opacus.utils.batch_memory_manager import BatchMemoryManager
from dataclasses import dataclass


from transformers import TrainingArguments

@dataclass
class PrefixMaskingCollatorWrapper:
    """Collator that hides the label prefix (and padding) from the LM loss.

    Delegates batching to ``dp_transformers.DataCollatorForPrivateCausalLanguageModeling``
    so its batch layout is preserved. It then sets ``labels`` to ``-100`` (the loss
    ``ignore_index``) in two places:
    * the first ``prefix_lengths`` positions of each example;
    * every position whose ``input_ids`` equal ``tokenizer.pad_token_id``.

    Each example dict must carry an integer ``'prefix_lengths'`` entry in addition
    to the fields the base collator consumes. The example dicts are not modified.
    """

    tokenizer: PreTrainedTokenizerBase

    def __post_init__(self):
        # Instantiate the original data collator that we want to preserve
        self.base_collator = dp_transformers.DataCollatorForPrivateCausalLanguageModeling(self.tokenizer)

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Collate ``examples`` and mask prefix and padding positions in ``labels``."""
        # Read the prefix lengths and hand the base collator *copies* without that key.
        # Popping from the caller's dicts would destroy the key in place, so a second
        # call on the same examples (retry, reuse) would raise KeyError.
        prefix_lengths = [example['prefix_lengths'] for example in examples]
        stripped = [
            {key: value for key, value in example.items() if key != 'prefix_lengths'}
            for example in examples
        ]

        batch = self.base_collator(stripped)
        labels = batch['labels']

        # NOTE(review): assumes the base collator keeps example order, so row i is examples[i].
        for row, prefix_len in enumerate(prefix_lengths):
            labels[row, :prefix_len] = -100  # -100 is the ignore_index for loss

        # Also mask padding positions for robustness (skipped if the tokenizer has no pad token).
        pad_id = self.tokenizer.pad_token_id
        if pad_id is not None:
            labels[batch['input_ids'] == pad_id] = -100

        return batch

class LimitedBatchSampler:
    """Cap a batch sampler at a fixed number of batches.

    Iterating yields the wrapped sampler's batches unchanged but stops after
    ``max_batches`` of them. ``len()`` reports the smaller of the two counts.
    """

    def __init__(self, batch_sampler: BatchSampler, max_batches: int):
        self.batch_sampler = batch_sampler
        self.max_batches = max_batches

    def __len__(self):
        available = len(self.batch_sampler)
        return available if available < self.max_batches else self.max_batches

    def __iter__(self):
        # islice draws lazily, so the inner sampler is never advanced past the cap.
        return itertools.islice(self.batch_sampler, self.max_batches)
    


def set_seed(seed):
    """Seed NumPy's global RNG, PyTorch's CPU RNG and all CUDA device RNGs with ``seed``."""
    for seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

# Global Optimizer
class GlobalOptimizer():
    """Server-side AdamW applied to aggregated client updates.

    ``params`` maps names to tensors and is modified IN PLACE by
    :meth:`AdamW_update`. The same dict is returned from every update.
    The aggregated buffer passed to :meth:`AdamW_update` plays the role of the
    gradient, so the step moves each tensor against it.

    First/second-moment buffers are created lazily, only for keys that actually
    receive an update. This means:
    * entries of ``params`` that never appear in an update (e.g. non-trainable
      buffers) get no moment storage;
    * the optimizer works regardless of the ``optim`` argument.
    """

    def __init__(self, glob_lr, params, optim=None):
        self.params = params
        self.glob_lr = glob_lr
        self.optim = optim  # only AdamW is implemented; kept for API compatibility
        self.counter = 1  # 1-based step index used for Adam bias correction
        # Moments are filled on first use. Allocating them only when optim == "AdamW"
        # made AdamW_update crash with AttributeError for the default optim=None.
        self.m = {}
        self.v = {}

    def AdamW_update(self, glob_buffer, beta1=0.9, beta2=0.999, lamb=0.05, eps=1e-6):
        """Apply one decoupled-weight-decay Adam step and return ``self.params``.

        Args:
            glob_buffer: name -> tensor pseudo-gradient; every key must exist in ``params``.
            beta1: decay rate of the first-moment estimate.
            beta2: decay rate of the second-moment estimate.
            lamb: decoupled weight-decay coefficient (scaled by ``glob_lr``).
            eps: added to the denominator for numerical stability.

        Raises:
            KeyError: if ``glob_buffer`` contains a key missing from ``params``.
        """
        # Bias corrections depend only on the step count, so compute them once.
        bias_correction1 = 1 - beta1 ** self.counter
        bias_correction2 = 1 - beta2 ** self.counter
        step_size = self.glob_lr / bias_correction1
        sqrt_bias_correction2 = math.sqrt(bias_correction2)

        for k, grad in glob_buffer.items():
            param = self.params[k]
            if k not in self.m:
                self.m[k] = torch.zeros_like(param)
                self.v[k] = torch.zeros_like(param)
            exp_avg, exp_avg_sq = self.m[k], self.v[k]

            # Decoupled weight decay, applied before the moment update (as in torch AdamW).
            param.mul_(1 - self.glob_lr * lamb)

            exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

            denom = torch.sqrt(exp_avg_sq).div_(sqrt_bias_correction2).add_(eps)
            param.addcdiv_(exp_avg, denom, value=-step_size)

        self.counter += 1
        gc.collect()
        return self.params
    
def save_checkpoint(model, round, args, is_best=False):
    """Save ``model``'s state dict for a 0-based ``round`` (rank 0 only).

    Writes ``{'round': round, 'model_state_dict': ...}`` to
    ``<args.output_dir>/checkpoint_step_<round+1>.pt``. When ``is_best`` is true,
    it also writes the same payload to ``<args.output_dir>/best_model.pt``.
    Ranks other than 0 return immediately.

    Args:
        model: module whose ``state_dict()`` is saved.
        round: 0-based round index (name kept for caller compatibility even
            though it shadows the builtin).
        args: namespace providing ``rank`` and ``output_dir``.
        is_best: whether this round has the best evaluation loss so far. The
            training loop passes it as a fourth argument; without this parameter
            that call raised TypeError.
    """
    if args.rank != 0:
        return  # Prevent other processes from writing to the same file
    try:
        os.makedirs(args.output_dir, exist_ok=True)

        checkpoint_state = {
            'round': round,
            'model_state_dict': model.state_dict(),
        }

        filepath = os.path.join(args.output_dir, f"checkpoint_step_{round+1}.pt")
        torch.save(checkpoint_state, filepath)

        if is_best:
            torch.save(checkpoint_state, os.path.join(args.output_dir, "best_model.pt"))

    except Exception as e:
        # Deliberately best-effort: a failed save is reported but must not abort
        # a long training run.
        print(f"[Rank {args.rank}] Error saving checkpoint: {e}", flush=True)

def get_local_update(glob_params, local_model):
    """Return ``global - local`` for each trainable parameter of ``local_model``.

    Keys follow ``local_model.named_parameters()`` order and must exist in
    ``glob_params``. Each global tensor is moved to the parameter's device for
    the subtraction. The resulting differences are returned on CPU in an
    ``OrderedDict``.
    """
    return OrderedDict(
        (name, (glob_params[name].to(param.device) - param.data).cpu())
        for name, param in local_model.named_parameters()
        if param.requires_grad
    )


def update_buffer(glob_buffer, local_update, weight):
    """Accumulate ``weight * local_update`` into ``glob_buffer`` and return it.

    Tensor entries are updated in place via ``+=``. Every key of
    ``local_update`` must already be present in ``glob_buffer``.
    """
    for name, delta in local_update.items():
        glob_buffer[name] += delta * weight
    return glob_buffer

def ddp_compute_loss(model, batch_size, dataset, rank):
    """Return the sample-weighted mean loss of ``model`` over ``dataset`` across ranks.

    Each rank evaluates its ``DistributedSampler`` shard on ``cuda:{rank}``. The
    per-rank ``sum(batch_loss * batch_len)`` and sample counts are all-reduced,
    so a short final batch is weighted by its real size. ``model`` is left in
    eval mode. Must be called on every rank of an initialized process group.

    NOTE(review): ``rank`` is used both as the sampler rank and as the CUDA device
    index, which only coincide on a single node. Confirm before multi-node use.
    Presumably each batch is a dict of tensors sharing a leading batch dimension,
    and the model returns ``.loss`` (labels included). Verify against callers.
    """
    model.eval()
    device = torch.device(f"cuda:{rank}")
    # Shard over the real world size. A hard-coded 8 covered only part of the data
    # with fewer processes and rejected ranks >= 8 with more.
    sampler = DistributedSampler(
        dataset=dataset,
        num_replicas=torch.distributed.get_world_size(),
        rank=rank,
        shuffle=False,
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, sampler=sampler)

    # stats[0]: sum of (batch mean loss * batch size); stats[1]: number of samples seen.
    stats = torch.zeros(2, device=device, dtype=torch.float64)
    with torch.no_grad():
        for batch in tqdm(dataloader, disable=rank != 0):
            inputs = {key: value.to(device) for key, value in batch.items()}
            n_samples = next(iter(inputs.values())).size(0)
            stats[0] += model(**inputs).loss.double() * n_samples
            stats[1] += n_samples
        torch.distributed.all_reduce(stats)

    return (stats[0] / stats[1]).item()



def evaluate(model, eva_dataset, args):
    """Return the mean evaluation loss of ``model`` on ``eva_dataset`` across all ranks.

    Steps:
    1. A GPU copy of ``model`` on ``args.local_rank`` is wrapped in
       ``DistributedDataParallel``.
    2. Batches go through ``PrefixMaskingCollatorWrapper``, so prefixes and
       padding are excluded from the loss as in training.
    3. Each rank evaluates only its shard of ``eva_dataset`` (``DistributedSampler``).
    4. Per-rank sums of per-batch mean losses and batch counts are all-reduced;
       the result is the mean over all evaluated batches.

    Must be called on every rank after the process group has been initialized.
    ``model`` itself is not moved or modified.
    """
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_checkpoint)
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    data_collator = PrefixMaskingCollatorWrapper(tokenizer=tokenizer)

    device = f'cuda:{args.local_rank}'
    model_gpu = copy.deepcopy(model).to(args.local_rank)
    para_model = torch.nn.parallel.DistributedDataParallel(model_gpu, device_ids=[args.local_rank])

    # Shard the data. Without a sampler every rank evaluated the full dataset,
    # repeating the work world_size times while the all-reduce summed identical
    # numbers. With drop_last=False the sampler pads so every rank gets the same
    # number of samples (a few may repeat), keeping DDP's per-forward collectives
    # in lockstep.
    sampler = DistributedSampler(eva_dataset, shuffle=False)
    eval_loader = DataLoader(
        eva_dataset,
        batch_size=64,  # Use physical batch size for efficiency
        sampler=sampler,
        collate_fn=data_collator,  # Use the SAME collator as training
    )

    para_model.eval()
    total_loss = 0.0
    with torch.no_grad():  # No need to calculate gradients during evaluation
        for batch in eval_loader:
            inputs = {key: value.to(device, non_blocking=True) for key, value in batch.items()}
            total_loss += para_model(**inputs).loss.item()

    # float64 so large summed losses keep full precision through the all-reduce.
    local_stats = torch.tensor([total_loss, len(eval_loader)], dtype=torch.float64, device=device)
    torch.distributed.all_reduce(local_stats, op=torch.distributed.ReduceOp.SUM)
    global_avg_loss = local_stats[0].item() / local_stats[1].item()

    del para_model
    del model_gpu
    torch.cuda.empty_cache()
    return global_avg_loss


def init_ddp():
    """Join an NCCL process group configured through environment variables.

    Reads ``WORLD_SIZE``, ``RANK`` and ``LOCAL_RANK`` (in that order, raising
    ``KeyError`` for the first one missing), initializes the default process
    group with ``init_method="env://"``, and returns
    ``(world_size, rank, local_rank)`` as ints.
    """
    env = os.environ
    world_size, rank, local_rank = (int(env[var]) for var in ("WORLD_SIZE", "RANK", "LOCAL_RANK"))
    torch.distributed.init_process_group(
        backend='nccl',
        init_method="env://",
        rank=rank,
        world_size=world_size,
    )
    return world_size, rank, local_rank



def data_processing(args):
    """Load, tokenize, and partition the CSV datasets for federated training.

    Reads ``args.training_data`` / ``args.test_data`` as the ``train`` /
    ``validation`` splits. Each row becomes the causal-LM string
    ``"<label_1>\\t<label_2>...\\n\\n<text><eos>"``, where the label fields are
    every column whose name contains ``"label"``.

    Tokenization pads/truncates to 128 tokens. Each example also gets
    ``prefix_lengths``: the number of leading tokens belonging to the label
    prefix, used by the collator to mask them out of the loss.

    Returns:
        (train split, validation split, client_list), where ``client_list[c]`` is
        the index list parsed from row ``c`` of the ``indices`` column of
        ``args.partition``.
    """

    def preprocess_function(examples):
        # str() lets numeric label columns (CSV cells are often parsed as ints)
        # through; str.join would otherwise raise TypeError on them.
        prefixes = [
            "\t".join(str(examples[name][i]) for name in label_column_names) + "\n\n"
            for i in range(len(examples['text']))
        ]
        full_texts = [
            prefix + text + tokenizer.eos_token
            for prefix, text in zip(prefixes, examples['text'])
        ]
        encode_kwargs = dict(padding="max_length", truncation=True, max_length=128)

        if tokenizer.is_fast:
            # Count prefix tokens inside the *full* encoding. Tokenizing the prefix
            # alone can split it differently than in context (byte-level BPE can
            # merge a string-final "\n\n" into one token but split it when text
            # follows) and ignores added special tokens such as BOS. Both make the
            # mask off by one. A token is a prefix token if it is attended
            # (not padding) and starts before the prefix's last character.
            tokenized_full = tokenizer(full_texts, return_offsets_mapping=True, **encode_kwargs)
            offset_mapping = tokenized_full.pop("offset_mapping")
            prefix_lengths = [
                sum(
                    1
                    for (start, _end), attended in zip(offsets, attention)
                    if attended and start < len(prefix)
                )
                for offsets, attention, prefix in zip(
                    offset_mapping, tokenized_full["attention_mask"], prefixes
                )
            ]
        else:
            # Slow tokenizers provide no offsets; fall back to tokenizing the prefix
            # on its own (may be off by one token, as explained above).
            tokenized_full = tokenizer(full_texts, **encode_kwargs)
            tokenized_prefixes = tokenizer(prefixes, add_special_tokens=False)
            prefix_lengths = [len(ids) for ids in tokenized_prefixes['input_ids']]

        tokenized_full['prefix_lengths'] = prefix_lengths
        return tokenized_full

    dataset = datasets.load_dataset('csv', data_files={'train': args.training_data, 'validation': args.test_data})
    label_column_names = [name for name in dataset["train"].column_names if "label" in name]
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_checkpoint)
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    # TrainingArguments is used here only for main_process_first(), presumably so
    # one process tokenizes (and populates the cache) before the others.
    train_args = TrainingArguments(output_dir="./temp")
    with train_args.main_process_first(desc="tokenizing dataset"):
        dataset = dataset.map(preprocess_function, batched=True, desc="tokenizing dataset", remove_columns=dataset.column_names['train'])

    # One row per client; the 'indices' cell is a Python-literal list of train-row indices.
    partition_df = pd.read_csv(args.partition)
    client_list = [ast.literal_eval(cell) for cell in partition_df['indices']]

    return dataset['train'], dataset['validation'], client_list


def init_glob(args):
    """Build the initial global model and the server-side optimizer.

    Setup steps:
    * load ``args.model_checkpoint``;
    * add a ``[PAD]`` special token to its tokenizer;
    * resize the embedding matrix to the new vocabulary;
    * set every newly added row to the mean of the pre-trained input embeddings.

    Returns:
        (model, weight, glob_optimizer):
        * ``model`` is the (CPU) model;
        * ``weight`` is the uniform aggregation weight ``1 / args.client_range``;
        * ``glob_optimizer`` is a ``GlobalOptimizer`` over ``model.state_dict()``.
    """
    model = AutoModelForCausalLM.from_pretrained(args.model_checkpoint)
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_checkpoint)
    num_added_toks = tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    # get_input_embeddings() works for any HF model; ``model.transformer.wte``
    # exists only on GPT-2-style architectures. The mean is taken before resizing
    # so the new rows do not influence it.
    mean_tok_emb = model.get_input_embeddings().weight.data.mean(dim=0)
    model.resize_token_embeddings(len(tokenizer))
    if num_added_toks > 0:  # guard: a [-0:] slice would overwrite every row
        # Re-fetch after resizing, which may swap in a new embedding module.
        model.get_input_embeddings().weight.data[-num_added_toks:] = mean_tok_emb

    glob_params = model.state_dict()
    glob_optimizer = GlobalOptimizer(glob_lr=args.glob_lr, params=glob_params, optim="AdamW")
    weight = 1 / args.client_range
    return model, weight, glob_optimizer


def init_buffer(model):
    """Return an ``OrderedDict`` of CPU zero tensors, one per trainable parameter of ``model``.

    Keys and shapes mirror ``model.named_parameters()`` restricted to
    parameters with ``requires_grad``.
    """
    trainable = ((name, p) for name, p in model.named_parameters() if p.requires_grad)
    return OrderedDict((name, torch.zeros_like(p.data, device='cpu')) for name, p in trainable)


def DP_processing(args, model, dataset):
    """Attach Opacus DP-SGD machinery to ``model`` for training on ``dataset``.

    The loader uses batches of ``args.batch_size`` with
    ``PrefixMaskingCollatorWrapper`` and ``shuffle=True``. It is passed to
    ``PrivacyEngine.make_private`` together with:
    * an ``AdamW`` optimizer with ``lr=args.local_lr``;
    * ``args.noise_multiplier``;
    * ``args.clipping_constant`` as the max gradient norm.

    Returns:
        (private_model, private_optimizer, private_data_loader, privacy_engine)
    """
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_checkpoint)
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    collator = PrefixMaskingCollatorWrapper(tokenizer=tokenizer)

    engine = PrivacyEngine()
    base_optimizer = AdamW(model.parameters(), lr=args.local_lr)
    base_loader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=collator, shuffle=True)

    private_model, private_optimizer, private_loader = engine.make_private(
        module=model,
        optimizer=base_optimizer,
        data_loader=base_loader,
        noise_multiplier=args.noise_multiplier,
        max_grad_norm=args.clipping_constant,
    )
    return private_model, private_optimizer, private_loader, engine


def shuffle_subset(subset):
    """Shuffle ``subset.indices`` with NumPy's global RNG and return ``subset``.

    The indices are replaced by a new Python list; ``subset`` itself is
    modified and returned.
    """
    permuted = np.array(subset.indices)
    np.random.shuffle(permuted)  # in place, consumes NumPy's global RNG
    subset.indices = permuted.tolist()
    return subset

def local_training(model, local_dataset, args):
    """Run DP-SGD training for one client and return the trained private model.

    The global ``model`` is deep-copied, so the object passed in is not trained.
    Steps:
    1. Move the copy to GPU ``args.local_rank`` and wrap it in Opacus DPDDP.
    2. Make it private via ``DP_processing``.
    3. Train for at most ``args.local_period`` logical batches.

    The returned object is the wrapper produced by ``DP_processing``; callers in
    this file reach the underlying network through ``.module``. The privacy
    engine returned by ``DP_processing`` is not used here; the noise multiplier
    is expected to be calibrated beforehand in ``args.noise_multiplier``.
    """
    local_model = copy.deepcopy(model)
    local_model.to(args.local_rank)
    local_model = DPDDP(local_model)
    device = f'cuda:{args.local_rank}'

    local_model, optimizer, data_loader, privacy_engine = DP_processing(args, local_model, local_dataset)
    local_model.train()

    # Reuse the private loader's batch sampler, but stop after args.local_period
    # logical batches. NOTE(review): presumably this is the sampler installed by
    # make_private (e.g. Poisson sampling); confirm.
    limited_batch_sampler = LimitedBatchSampler(
    data_loader.batch_sampler,
    max_batches=args.local_period
    )

    # Rebuild a loader around the capped sampler, keeping the private loader's
    # dataset, collate_fn and worker count.
    limited_loader = DataLoader(
        dataset=data_loader.dataset,
        batch_sampler=limited_batch_sampler,
        collate_fn=data_loader.collate_fn,
        num_workers=data_loader.num_workers,
    )

    # Split each logical batch into physical chunks of at most 64 examples.
    # NOTE(review): relies on Opacus's BatchMemoryManager to skip optimizer.step()
    # on all but the last chunk, so one DP step is taken per logical batch;
    # confirm against the installed Opacus version.
    with BatchMemoryManager(
    data_loader=limited_loader,
    max_physical_batch_size=64,
    optimizer=optimizer
    ) as memory_safe_data_loader:
        for batch in memory_safe_data_loader:
            inputs = {key: value.to(device, non_blocking=True) for key, value in batch.items()}
            outputs = local_model(**inputs)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

    return local_model



def fl_training(args):
    """Run the federated DP training loop and return the final global model.

    Setup: initialize the process group, load and partition data, build the
    global model, and calibrate ``args.noise_multiplier`` from client 0's
    partition size.

    Each round:
    * every client in ``range(args.client_range)`` trains a DP copy of the
      global model on its partition;
    * the weighted ``global - local`` differences are accumulated;
    * the total is applied with the server-side AdamW optimizer.

    Every ``args.eval_every`` rounds (default 5 when the attribute is absent):
    * the model is evaluated;
    * a checkpoint is saved (flagging the best so far);
    * progress is printed on rank 0.

    Raises:
        ValueError: if client 0 has fewer than 2 examples, for which the target
            delta ``1 / (n * ln n)`` is undefined.
    """
    _, rank, local_rank = init_ddp()
    args.rank = rank
    args.local_rank = local_rank
    torch.cuda.set_device(local_rank)

    train_dataset, test_dataset, client_list = data_processing(args)
    model, weight, glob_optimizer = init_glob(args)
    glob_params = model.state_dict()

    # Privacy calibration uses client 0's partition size for every client.
    # NOTE(review): assumes all partitions have the same size; confirm.
    dataset_len = len(client_list[0])
    if dataset_len < 2:
        raise ValueError(
            f"client 0 has {dataset_len} examples; at least 2 are needed to compute target_delta"
        )
    args.noise_multiplier = find_noise_multiplier(
        sampling_probability=args.batch_size / dataset_len,
        num_steps=int(args.num_rounds * args.local_period),
        target_epsilon=args.epsilon,
        target_delta=1 / (dataset_len * math.log(dataset_len)),
    )

    eval_every = getattr(args, "eval_every", 5)
    best_loss = float('inf')
    best_round = -1  # 1-based round number of the best evaluation so far

    for i in range(args.num_rounds):
        glob_buffer = init_buffer(model)

        for client_id in range(args.client_range):
            local_dataset = Subset(train_dataset, client_list[client_id])
            local_model = local_training(model, local_dataset, args)

            local_update = get_local_update(glob_params, local_model.module)
            glob_buffer = update_buffer(glob_buffer, local_update, weight=weight)

            del local_model, local_update, local_dataset  # Clean up GPU memory
            torch.cuda.empty_cache()

        glob_params = glob_optimizer.AdamW_update(glob_buffer=glob_buffer)
        model.load_state_dict(glob_params)
        del glob_buffer
        gc.collect()

        if (i + 1) % eval_every == 0:
            eva_loss = evaluate(model, test_dataset, args)

            is_best = eva_loss < best_loss
            if is_best:
                best_loss = eva_loss
                # 1-based, matching the "Round:" line below and the checkpoint
                # file name (checkpoint_step_<i+1>.pt). Storing the 0-based
                # index made the printed best round off by one.
                best_round = i + 1
            save_checkpoint(model, i, args, is_best)

            if rank == 0:
                print("Round:", i+1, flush=True)
                print("The evaluation loss is:", eva_loss, flush=True)
                print("The best round is:", best_round, flush=True)

    return model





if __name__ == "__main__":
    # Script entry point: parse options, seed RNGs, run federated training, then
    # tear down the distributed process group.
    args = args_parser()
    set_seed(args.seed)
    try:
        fl_training(args)
    finally:
        # Release the process group even when training raises, rather than only
        # on a clean return.
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()


    