import argparse
import pickle
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import numpy as np
import os
import glob
import re
import random
from opacus.accountants.utils import get_noise_multiplier
from torch.profiler import profile, record_function, ProfilerActivity
from fvcore.nn import FlopCountAnalysis, parameter_count_table



# Import your custom dataset & models
from models.client_model import Bert
from models.server_model import MLPCombiner
from prepare_data import AmazonPolarityPreprocessor
from torch.utils.data import Dataset
# from transformers import BertTokenizer, BlipImageProcessor

parser = argparse.ArgumentParser(
    description='Cost measurement for vertical federated learning '
                '(VAFL / ZOFO / DPZV) on Amazon Polarity')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
                    help='number of data loading workers (default: 32)')
parser.add_argument('--epochs', default=100, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
                    metavar='LR', help='initial (peak) learning rate', dest='lr')
parser.add_argument('--server_lr', '--server-learning-rate', default=0.005, type=float,
                    metavar='SVLR', help='learning rate for server', dest='server_lr')
parser.add_argument('--momentum', default=0, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--warmup_rate', default=0.1, type=float,
                    help='fraction of the total training steps used for linear LR warmup')
parser.add_argument('--wd', '--weight-decay', default=0., type=float,
                    metavar='W', help='weight decay (default: 0.)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=50, type=int,
                    metavar='N', help='print frequency (default: 50)')
parser.add_argument('--resume', default='./log', type=str, metavar='PATH',
                    help='root directory for run outputs / checkpoints (default: ./log)')
parser.add_argument('--no_resume', action='store_true', help='Do not resume from checkpoint')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')

parser.add_argument('--num_clients', default=4, type=int,
                    help='number of clients')
parser.add_argument('--mode', default="dpzv", type=str,
                    help='VFL algorithm to use: vafl, zofo or dpzv')

# --- DPZV options ---
parser.add_argument('--no_dp', action='store_true', help='Do not use DP')
parser.add_argument('--zo_mu', default=1e-3, type=float,
                    help='the scale for the perturbation of parameters in zero order update')
parser.add_argument('--dp_clip_threshold', default=10., type=float,
                    help='the clipping threshold of the gradient for achieving DP')
parser.add_argument('--dp_epsilon', default=6., type=float,
                    help='DP level parameter epsilon (a value <= 0 disables DP)')
parser.add_argument('--dp_delta', default=1e-5, type=float,
                    help='DP level parameter delta')
parser.add_argument('--grad_estimate_method', default='central', type=str,
                    help='The method for estimating zeroth-order gradient: central or forward')
parser.add_argument('--min_lr', default=1e-7, type=float,
                    help='Minimum learning rate')
parser.add_argument('--patience', default=3, type=int,
                    help='Patience for learning rate scheduler')

# --- ZOFO options ---
parser.add_argument('--num_purt', default=5, type=int,
                    help='number of perturbations for ZOFO')
parser.add_argument('--no_mezo', action='store_true',
                    help='DPZV: use the buffered zo_update instead of the '
                         'in-place MeZO update')

args = parser.parse_args()


def clip_tensor(tensor, max_norm):
    """
    Clips a tensor to have a norm at most `max_norm`.

    The whole tensor is treated as a single vector: one global L2 norm is
    computed, not one norm per row. When that norm is below `max_norm`, the
    values are returned unchanged. Otherwise the tensor is rescaled by
    `max_norm / (norm + 1e-6)`.

    Args:
        tensor (torch.Tensor): The input tensor.
        max_norm (float): The maximum allowed norm.

    Returns:
        torch.Tensor: The clipped tensor.
    """
    norm = torch.norm(tensor, p=2)  # global L2 norm
    # The 1e-6 avoids division by zero. torch.clamp keeps the scale on the
    # tensor's device. Python's builtin min() would convert a tensor
    # comparison to bool, which forces a host/device sync on GPU.
    scale = torch.clamp(max_norm / (norm + 1e-6), max=1.0)
    return tensor * scale



def vafl_measure(loader, models, optimizers, criterion, args, num_batches=5):
    """Run a few VAFL-style training steps (used for cost measurement).

    For every batch:
      1. Every client optimizer and the server optimizer get the scheduled LR
         from ``adjust_lr``.
      2. Each client computes its shared embedding without autograd. When
         ``args.dp_epsilon > 0`` the embedding is clipped with
         ``clip_tensor`` and Gaussian noise with std ``args.dpzero_gaussian_std``
         is added.
      3. Client by client, the server is fed the other clients' shared
         embeddings plus this client's live forward output (recomputed with
         autograd and not noised). The mean loss is backpropagated, and both
         this client's optimizer and the server optimizer take a step.

    Args:
        loader: yields ``(texts, labels)``; ``texts`` is indexed per client.
        models: client models followed by the server model (``models[-1]``).
        optimizers: one optimizer per entry of ``models``, in the same order.
        criterion: loss function; its output is reduced with ``.mean()``.
        args: reads ``device``, ``num_clients``, ``dp_epsilon``,
            ``dp_clip_threshold``, ``dpzero_gaussian_std`` (only when DP is
            on) and ``global_step`` (incremented once per batch). It also
            supplies whatever ``adjust_lr`` reads.
        num_batches: number of batches to process before stopping (default 5,
            the previously hard-coded limit).
    """
    device = args.device

    for batch_idx, (texts, labels) in enumerate(loader):
        if batch_idx >= num_batches:
            break
        for client in range(args.num_clients + 1):
            adjust_lr(args, optimizer=optimizers[client])
        texts = [t.to(device) for t in texts]
        labels = labels.to(device)

        # Embeddings each client shares with the server (no autograd).
        embeddings = []
        with torch.no_grad():
            for client in range(args.num_clients):
                embedding = models[client](texts[client])
                if args.dp_epsilon > 0:
                    embedding = clip_tensor(embedding, args.dp_clip_threshold)
                    embedding = embedding + torch.normal(
                        mean=0,
                        std=args.dpzero_gaussian_std,
                        size=embedding.size(),
                        device=device,
                        dtype=embedding.dtype,
                    )
                embeddings.append(embedding)
        args.global_step += 1

        for client in range(args.num_clients):
            optimizers[client].zero_grad()
            optimizers[-1].zero_grad()
            embedding_view = [shared.detach().clone() for shared in embeddings]
            embedding_view[client] = models[client](texts[client])
            # Feed the view, not `embeddings`. Only the view carries this
            # client's autograd graph, so only then does the client receive a
            # gradient.
            output = models[-1](embedding_view)
            loss = criterion(output, labels).mean()
            loss.backward()
            optimizers[client].step()
            optimizers[-1].step()



def perturb_embedding(embedding, random_seed: int, args, scaling_factor=1):
    """Return ``embedding`` shifted by ``scaling_factor * args.zo_mu`` along a seeded Gaussian ``z``.

    The global torch RNG is reseeded with ``random_seed`` before ``z`` is
    drawn, using the shape, device and dtype of ``embedding``. A later call
    with the same seed, e.g. from ``project_gradient``, can therefore
    regenerate the same direction. The shift is applied under
    ``torch.no_grad()``, so the result is not tracked by autograd.
    """
    torch.manual_seed(random_seed)
    direction = torch.normal(
        mean=0,
        std=1,
        size=embedding.size(),
        device=embedding.device,
        dtype=embedding.dtype,
    )
    with torch.no_grad():
        shifted = embedding + scaling_factor * direction * args.zo_mu
    return shifted

def project_gradient(loss_diff, size, random_seed: int):
    """Return ``loss_diff * z`` for the Gaussian direction ``z`` tied to ``random_seed``.

    The global torch RNG is reseeded and ``z`` is drawn with shape ``size`` on
    ``loss_diff``'s device and dtype. With matching seed, shape, device and
    dtype this is the same ``z`` that ``perturb_embedding`` drew.
    """
    torch.manual_seed(random_seed)
    return loss_diff * torch.normal(
        mean=0, std=1, size=size, device=loss_diff.device, dtype=loss_diff.dtype
    )

def zofo_measure(loader, models, optimizers, criterion, args):
    """Profile one client forward pass, then run one ZOFO step on a single batch.

    The FLOPs of ``models[0]`` on the first client's input are collected with
    ``torch.profiler`` and printed.

    Each client then gets a zeroth-order estimate of the gradient w.r.t. its
    *embedding* ``e``. For each of ``args.num_purt`` seeded Gaussian
    directions ``z``, the server loss is evaluated at ``e + mu z`` and
    ``e - mu z``. The central difference is projected back onto that same
    ``z``, and the average is backpropagated through the client model.
    The server is trained by ordinary backpropagation on the unperturbed
    embeddings.

    Args:
        loader: only its first ``(texts, labels)`` batch is used; ``texts`` is
            indexed per client.
        models: client models followed by the server model (``models[-1]``).
        optimizers: one optimizer per entry of ``models``, in the same order.
        criterion: loss function; its output is reduced with ``.mean()``.
        args: reads ``device``, ``num_clients``, ``num_purt``, ``zo_mu`` and
            ``global_step`` (incremented once). It also supplies whatever
            ``adjust_lr`` reads.

    Raises:
        ValueError: if ``args.num_purt`` is smaller than 1.
    """
    if args.num_purt < 1:
        raise ValueError(f"num_purt must be >= 1, got {args.num_purt}")
    device = args.device

    texts, labels = next(iter(loader))
    for client in range(args.num_clients + 1):
        adjust_lr(args, optimizer=optimizers[client])
    texts = [t.to(device) for t in texts]
    labels = labels.to(device)

    # NOTE(review): only CUDA activity is recorded. On a CPU-only run this
    # likely reports 0 FLOPs — confirm.
    with profile(
        activities=[ProfilerActivity.CUDA],
        record_shapes=True,
        with_flops=True,
        with_stack=True,
    ) as prof:
        with record_function("zofo"):
            models[0](texts[0])
    flops = sum(item.flops for item in prof.key_averages() if item.flops is not None)
    print(prof.key_averages().table(
        sort_by="flops",  # Use "cuda_time_total" for GPU
        row_limit=10,
        top_level_events_only=True,
    ))
    print(f"FLOPs: {flops}")

    with torch.no_grad():
        embeddings = [models[client](texts[client]) for client in range(args.num_clients)]
    args.global_step += 1

    # Central difference: the two probes e + mu z and e - mu z are 2 * mu apart.
    mu_multiplier = 2
    for client in range(args.num_clients):
        optimizers[client].zero_grad()
        optimizers[-1].zero_grad()
        embeddings_view_plus = embeddings.copy()
        embeddings_view_minus = embeddings.copy()
        embedding = models[client](texts[client])  # live graph for the client update
        base = embedding.detach()
        partial_grad = torch.zeros_like(base)
        for _ in range(args.num_purt):
            random_seed = np.random.randint(1000000000)
            # Each probe starts from the unperturbed `base`. Hence -1 (not -2)
            # gives the symmetric point e - mu z.
            embeddings_view_plus[client] = perturb_embedding(base, random_seed, args, scaling_factor=1)
            embeddings_view_minus[client] = perturb_embedding(base, random_seed, args, scaling_factor=-1)
            with torch.no_grad():
                loss_plus = criterion(models[-1](embeddings_view_plus), labels).mean()
                loss_minus = criterion(models[-1](embeddings_view_minus), labels).mean()
                delta = (loss_plus - loss_minus) / (mu_multiplier * args.zo_mu)
                # Project each difference onto the direction it was measured along.
                partial_grad += project_gradient(
                    delta.to(dtype=base.dtype), base.size(), random_seed=random_seed
                )
        partial_grad /= args.num_purt

        # Frozen parameters are excluded: backward() rejects `inputs` that do
        # not require grad.
        trainable = [p for p in models[client].parameters() if p.requires_grad]
        if trainable:
            embedding.backward(gradient=partial_grad, inputs=trainable)
        optimizers[client].step()

        output = models[-1](embeddings)
        loss = criterion(output, labels).mean()
        loss.backward()
        optimizers[-1].step()
        


            



class DPZV_Trainer():
    """Zeroth-order client trainer used for the DPZV cost measurement.

    Client models are updated from a two-point (central) finite difference of
    the loss along a seeded Gaussian parameter perturbation. That difference
    is clipped, and noised when ``args.dp_epsilon > 0``. The server model is
    updated by ordinary backpropagation through ``server_optimizer``.
    """

    def __init__(self, args, server_optimizer):
        self.args = args
        self.server_optimizer = server_optimizer
        # Step size of the zeroth-order client updates; refreshed by zo_adjust_lr().
        self.lr = args.lr

    def _get_learning_rate(self):
        """Return the current zeroth-order step size."""
        return self.lr

    def zo_perturb_parameters(self, model: nn.Module, random_seed: int, scaling_factor=1):
        """Shift every parameter of ``model`` in place by ``scaling_factor * zo_mu * z``.

        ``z`` is drawn parameter by parameter from the global torch RNG after
        reseeding it with ``random_seed``, so the same seed reproduces the
        same direction. Applying factors +1, -2, +1 therefore returns the
        parameters to their starting values, up to rounding.
        """
        args = self.args
        torch.manual_seed(random_seed)
        with torch.no_grad():
            for _, param in model.named_parameters():
                z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
                param.data = param.data + scaling_factor * z * args.zo_mu

    def dpzero_clip(self, loss_diff, C=1.):
        """Clip every entry of ``loss_diff`` to ``[-C, C]`` and average.

        Returns:
            ``(mean_of_clipped_entries, clipping_rate)``. ``clipping_rate`` is
            the fraction of entries whose magnitude exceeded ``C``.
        """
        abs_loss_diff = torch.abs(loss_diff)
        clipped_mask = abs_loss_diff > C
        clipping_rate = clipped_mask.float().mean().item()
        # min(1, C / |d|) scales oversized entries down to magnitude C; a zero
        # entry gives C / 0 = inf, which the min turns back into 1.
        tmp = torch.min(torch.ones_like(loss_diff), torch.div(C * torch.ones_like(loss_diff), abs_loss_diff))
        return torch.mul(tmp, loss_diff).mean(), clipping_rate

    def zo_forward(self, model, inputs):
        """
        Get the (no gradient) output of the model with dropout turned off.

        The model runs in eval mode under ``torch.inference_mode``. Its
        previous train/eval mode is restored afterwards.
        """
        was_training = model.training
        model.eval()
        try:
            with torch.inference_mode():
                output = model(inputs)
        finally:
            model.train(was_training)
        return output

    def zo_adjust_lr(self):
        """Apply the scheduled LR to the server optimizer and reuse it as the ZO step size."""
        self.lr = adjust_lr(self.args, self.server_optimizer)

    def zo_update(self, model, seed, projected_grad):
        """
        Update the parameters with the estimated gradients.

        The direction ``z`` is regenerated from ``seed``, and each parameter
        is updated as
        ``param -= lr * (projected_grad * z + weight_decay * param)``.
        Weight decay is skipped when the parameter name contains ``bias``,
        ``layer_norm`` or ``layernorm``. All updates are computed before any
        is applied; ``mezo_update`` applies them one parameter at a time.
        """
        args = self.args
        with torch.no_grad():
            torch.manual_seed(seed)
            grad = {}
            for name, param in model.named_parameters():
                # Resample z
                z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
                if "bias" not in name and "layer_norm" not in name and "layernorm" not in name:
                    grad[name] = projected_grad * z + args.weight_decay * param.data
                else:
                    grad[name] = projected_grad * z
            for name, param in model.named_parameters():
                param.data = param.data - self._get_learning_rate() * grad[name]

    def mezo_update(self, model, seed, projected_grad):
        """
        Update the parameters with the estimated gradients.

        This is the same update rule as ``zo_update``. Here each parameter is
        updated right after its ``z`` is regenerated, so no per-parameter
        gradient buffer is kept.
        """
        args = self.args
        with torch.no_grad():
            torch.manual_seed(seed)
            for name, param in model.named_parameters():
                # Resample z
                z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
                if "bias" not in name and "layer_norm" not in name and "layernorm" not in name:
                    grad = projected_grad * z + args.weight_decay * param.data
                else:
                    grad = projected_grad * z
                param.data = param.data - self._get_learning_rate() * grad

    def measure(self, loader, models, criterion, args, num_batches=5):
        """Run up to ``num_batches`` DPZV training steps (used for cost measurement).

        For each batch and each client:
          1. The client's parameters are moved to ``theta + mu z`` and
             ``theta - mu z``, and the server loss is evaluated at both points.
          2. The loss difference is clipped with ``dpzero_clip``; when
             ``dp_epsilon > 0``, Gaussian noise with std
             ``dpzero_gaussian_std`` is added.
          3. The client model is updated with ``zo_update`` or
             ``mezo_update``.
          4. The server takes one backpropagation step on the current
             embeddings.

        Args:
            loader: yields ``(texts, labels)``; ``texts`` is indexed per client.
            models: client models followed by the server model (``models[-1]``).
            criterion: loss function. The clip averages over its output, so
                presumably it returns per-sample losses (reduction='none').
            args: only ``args.device`` is read from this argument; every other
                setting comes from ``self.args``.
            num_batches: number of batches to process (default 5, the
                previously hard-coded limit).
        """
        device = args.device
        args = self.args  # using the trainer's own args
        mu_multiplier = 2  # the two probes are 2 * mu apart

        for batch_idx, (texts, labels) in enumerate(loader):
            if batch_idx >= num_batches:
                break

            self.zo_adjust_lr()
            texts = [t.to(device) for t in texts]
            labels = labels.to(device)
            with torch.no_grad():
                embeddings = [models[c](texts[c]) for c in range(args.num_clients)]
            args.global_step += 1
            for client in range(args.num_clients):
                text_local = texts[client]
                model = models[client]
                embeddings_view_plus = embeddings.copy()
                embeddings_view_minus = embeddings.copy()
                seed = np.random.randint(100000)
                with torch.no_grad():
                    # Probe theta + mu z, then theta - mu z, then restore theta.
                    self.zo_perturb_parameters(model=model, random_seed=seed, scaling_factor=1)
                    embeddings_view_plus[client] = model(text_local)
                    self.zo_perturb_parameters(model=model, random_seed=seed, scaling_factor=-2)
                    embeddings_view_minus[client] = model(text_local)
                    self.zo_perturb_parameters(model=model, random_seed=seed, scaling_factor=1)
                    embeddings[client] = model(text_local)
                    output_plus = models[-1](embeddings_view_plus)
                    output_minus = models[-1](embeddings_view_minus)
                    loss_1 = criterion(output_plus, labels)
                    loss_2 = criterion(output_minus, labels)
                    loss_diff = (loss_1 - loss_2) / (mu_multiplier * args.zo_mu)
                projected_grad, _clipping_rate = self.dpzero_clip(loss_diff, args.dp_clip_threshold)
                if args.dp_epsilon > 0:
                    projected_grad += torch.randn(1).item() * args.dpzero_gaussian_std
                if args.no_mezo:
                    self.zo_update(model, seed, projected_grad)
                else:
                    self.mezo_update(model, seed, projected_grad)

                # Clear the server gradients first. Otherwise they accumulate
                # across clients and batches, and every step applies their sum.
                self.server_optimizer.zero_grad()
                output = models[-1](embeddings)
                loss = criterion(output, labels).mean()
                loss.backward()
                self.server_optimizer.step()
        




@torch.no_grad()
def evaluate(loader, models, criterion, device):
    """Evaluate the client models plus the server model over a whole loader.

    Args:
        loader: yields ``(texts, labels)``; ``texts`` holds one input per
            client model. Batches whose ``texts`` is ``None`` are skipped.
        models: client models followed by the server model (``models[-1]``).
            Every model is switched to eval mode and left in it.
        criterion: loss function; its output is reduced with ``.mean()`` per batch.
        device: device the inputs and labels are moved to.

    Returns:
        ``(loss, acc1)`` as Python floats: the sample-weighted mean loss and
        the top-1 accuracy in percent over every evaluated batch. Returns
        ``(nan, nan)`` when no batch was evaluated.
    """
    for model in models:
        model.eval()
    num_clients = len(models) - 1  # the last model is the server

    total_loss = 0.0
    total_correct = 0.0
    total_count = 0
    for texts, labels in loader:
        if texts is None:
            continue  # Skip empty batches
        texts = [t.to(device) for t in texts]
        labels = labels.to(device)
        embeddings = [models[c](texts[c]) for c in range(num_clients)]
        output = models[-1](embeddings)

        batch_size = labels.size(0)
        total_loss += criterion(output, labels).mean().item() * batch_size
        # accuracy() returns a percentage; turn it back into a hit count.
        total_correct += accuracy(output, labels, topk=(1,))[0].item() * batch_size / 100.0
        total_count += batch_size

    if total_count == 0:
        return float("nan"), float("nan")
    return total_loss / total_count, 100.0 * total_correct / total_count

def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: scores with classes along dimension 1.
        target: class labels, one per row of ``output``.
        topk: the k values to evaluate.

    Returns:
        A list of scalar tensors, in the same order as ``topk``.
    """
    with torch.no_grad():
        n = target.size(0)
        largest_k = max(topk)
        # After the transpose, row i holds each sample's (i+1)-th best class.
        ranked = output.topk(largest_k, 1, True, True)[1].t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))
        return [
            hits[:k].reshape(-1).float().sum().mul_(100.0 / n)
            for k in topk
        ]

def adjust_lr(
    args,
    optimizer=None,
    base_lr=None,
):
    """
    Linear warmup followed by linear decay to zero, mimicking HF Transformers.

    The schedule reads ``args.global_step`` (current step),
    ``args.warmup_steps`` and ``args.total_steps``. During warmup the LR
    rises linearly from 0 to the peak. After warmup it falls linearly back to
    0, reaching it at ``total_steps``, and stays 0 past that point.

    :param args: namespace carrying ``global_step``, ``warmup_steps``,
        ``total_steps`` and ``lr``.
    :param optimizer: if given, every param group's ``lr`` is set to the new value.
    :param base_lr: peak LR reached after warmup; defaults to ``args.lr``.
    :return: the new learning rate (float).
    """
    peak_lr = args.lr if base_lr is None else base_lr
    step = args.global_step
    warmup = args.warmup_steps
    total = args.total_steps
    if step < warmup:
        # Warmup phase: LR from 0 -> peak
        new_lr = peak_lr * float(step) / float(warmup)
    else:
        # Decay phase: LR from peak -> 0
        elapsed = step - warmup
        decay_steps = total - warmup
        if elapsed >= decay_steps:
            # Past total_steps (or no decay phase at all): LR = 0
            new_lr = 0.0
        else:
            remaining_frac = 1.0 - float(elapsed) / float(decay_steps)
            new_lr = peak_lr * remaining_frac
    if optimizer is not None:
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr
    return new_lr


###############################################
# (E) MAIN
###############################################
def main():
    """Build data, models and optimizers, then run the ZOFO cost measurement.

    Steps:
      - Seed the RNGs when ``--seed`` is given.
      - Load and partition the Amazon Polarity data.
      - Create one ``Bert`` client per ``--num_clients`` plus an
        ``MLPCombiner`` server.
      - Configure the LR schedule.
      - Derive the DP noise std when DP is enabled.
      - Call ``zofo_measure``.
    """
    if args.no_dp:
        # Every DP branch in this script tests `dp_epsilon > 0`, so this
        # honours --no_dp everywhere.
        args.dp_epsilon = 0.

    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)  # perturbation seeds are drawn with np.random
        torch.manual_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device
    print(f"Using device: {device}")

    # 1. Load data
    preprocessor = AmazonPolarityPreprocessor(args)
    preprocessor.preprocess_and_partition()
    train_loader, _val_loader = preprocessor.create_dataloaders()

    # 2. One client model + optimizer per party. The client LR is overwritten
    #    by adjust_lr() before any step.
    models = []
    optimizers = []
    for _ in range(args.num_clients):
        client_model = Bert().to(device)
        models.append(client_model)
        optimizers.append(optim.Adam(
            [p for p in client_model.parameters() if p.requires_grad], lr=args.lr))

    # 3. Linear warmup / decay schedule state (read by adjust_lr)
    args.total_steps = args.epochs * len(train_loader)
    args.warmup_steps = int(args.warmup_rate * args.total_steps)
    args.global_step = 0

    num_classes = 2

    # 4. Late fusion server
    # Assumes each Bert client emits a 768-dim embedding — TODO confirm.
    server_model = MLPCombiner(
        input_size=args.num_clients * 768, hidden_size=768, num_classes=num_classes
    ).to(device)
    server_optimizer = optim.Adam(
        [p for p in server_model.parameters() if p.requires_grad], lr=args.server_lr)
    models.append(server_model)
    optimizers.append(server_optimizer)

    # 5. Per-sample cross-entropy; the measurement functions reduce it themselves.
    criterion = nn.CrossEntropyLoss(reduction='none')

    # 6. Noise multiplier for the DP mechanisms
    if args.dp_epsilon > 0:
        sample_rate = args.batch_size / len(train_loader.dataset)
        noise_kwargs = dict(
            target_epsilon=args.dp_epsilon,
            target_delta=args.dp_delta,
            epochs=args.epochs,
            sample_rate=sample_rate,
        )
        try:
            multiplier = get_noise_multiplier(**noise_kwargs, accountant='gdp')
        except ValueError:
            # Fall back to opacus' default accountant if the GDP one raises.
            multiplier = get_noise_multiplier(**noise_kwargs)
        args.dpzero_gaussian_std = multiplier * 2 * args.dp_clip_threshold / args.batch_size

    # 7. Measure. Only the ZOFO measurement (which profiles FLOPs) is wired up.
    if args.mode != "zofo":
        print(f"Note: --mode {args.mode} is ignored; running the ZOFO measurement.")
    zofo_measure(train_loader, models, optimizers, criterion, args)


if __name__ == "__main__":
    main()
