#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Copyright (c) 2020 Tongzhou Wang
import argparse
import builtins
import os
import random
import shutil
import time
import socket
import warnings
import pickle
import copy
import numpy as np

from opacus.accountants.utils import get_noise_multiplier


import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.nn.init as init
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets

import utils
import moco.loader
from server_model import serverModel
from custom_dataset import MultiViewDataSet

class SplitImageTransform:
    """Apply each transform in a list to the same input and collect the outputs.

    Typical use: two random-crop pipelines yield two views (query and key) of
    one image. Any number of transforms is supported.
    """

    def __init__(self, transforms):
        # Callables applied independently, in order, to every input.
        self.transforms = transforms

    def __call__(self, x):
        """Return a list with one output per transform, in transform order."""
        return [fn(x) for fn in self.transforms]

model_names = sorted(name for name in torchvision.models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(torchvision.models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
# parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
#                     choices=model_names,
#                     help='model architecture: ' +
#                         ' | '.join(model_names) +
#                         ' (default: resnet50)')
parser.add_argument('-a', '--arch', metavar='CL_ARCH', default='resnet18',
                    choices=model_names,
                    help='client architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet18)')
# The server architecture has its own small set of choices, so the help text
# lists those (not the torchvision model names) and the actual default.
parser.add_argument('-sv_a', '--server_arch', metavar='SV_ARCH', default='mlp',
                    choices=['mlp', 'linear'],
                    help='server architecture: mlp | linear (default: mlp)')
parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
                    help='number of data loading workers (default: 32)')
parser.add_argument('--epochs', default=100, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--server_lr', '--server-learning-rate', default=0.005, type=float,
                    metavar='SVLR', help='learning rate for server', dest='server_lr')
parser.add_argument('--lr_un', '--unsupervised-learning-rate', default=10.0, type=float,
                    metavar='LR', help='initial learning rate for final linear layer', dest='lr_un')
#parser.add_argument('--schedule', default=[60, 80], nargs='*', type=int,
#                    help='learning rate schedule (when to drop lr by a ratio)')
parser.add_argument('--schedule', default=[], nargs='*', type=int,
                    help='learning rate schedule (when to drop lr by a ratio)')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=0., type=float,
                    metavar='W', help='weight decay (default: 0.)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=50, type=int,
                    metavar='N', help='print frequency (default: 50)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--rank', default=-1, type=int,
                    help='node rank for distributed training')
parser.add_argument('--dist-url', default='tcp://localhost:10001', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpus', default=None, nargs='+', type=int,
                    help='GPU id(s) to use. Default is all visible GPUs.')
parser.add_argument('--multiprocessing-distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')

parser.add_argument('--num_clients', default=12, type=int,
                    help='number of clients')
# Modes dispatched explicitly in main_worker: dpzv, zoovfl, zofo, vafl,
# vafl2, vafl3; any other value goes through train().
parser.add_argument('--mode', default="flex", type=str,
                    help='VFL algorithm to use: flex, sync, vafl, vafl2, vafl3, '
                         'pbcd, dpzv, zoovfl, or zofo')
parser.add_argument('--labeled_frac', default=1.0, type=float,
                    help='fraction of training data that is labeled')
#parser.add_argument('--local_epochs', default=1, type=float,
#                    help='Number of local iterations')
parser.add_argument('--server_time', default=1.0, type=float,
                    help='How long roundtrip server communication takes')

# Added for DPZV

# parser.add_argument('--no_dp', action='store_true', help='Do not use DP')

parser.add_argument('--warmup_rate', default=0.1, type=float,
                    help='Warmup rate for linear LR warmup')

parser.add_argument('--zo_mu', default=1e-3, type=float,
                    help='the scale for the perturbation of parameters in zero order update')
parser.add_argument('--dp_clip_threshold', default=10., type=float,
                    help='the clipping threshold of the gradient for achieving DP')
parser.add_argument('--dp_epsilon', default=6., type=float,
                    help='DP level parameter epsilon')
parser.add_argument('--dp_delta', default=1e-5, type=float,
                    help='DP level parameter delta')
parser.add_argument('--grad_estimate_method', default='central', type=str,
                    help='The method for estimating zeroth-order gradient: central or forward')
parser.add_argument('--min_lr', default=1e-7, type=float,
                    help='Minimum learning rate')
parser.add_argument('--patience', default=3, type=int,
                    help='Patience for learning rate scheduler')


# Added for ZOFO
parser.add_argument('--num_purt', default=5, type=int,
                    help='number of perturbations for ZOFO')

# store_true flag: passing it turns MeZO *off* (main_worker passes
# use_mezo=not args.no_mezo to the DPZV trainer).
parser.add_argument('--no_mezo', action='store_true',
                    help='do not use MeZO (MeZO is used by default)')

args = parser.parse_args()

def clip_tensor(tensor, max_norm):
    """
    Rescale `tensor` so that its L2 norm does not exceed `max_norm`.

    The norm is computed over all elements at once, i.e. the whole tensor is
    treated as a single flat vector (it is not clipped row by row). A tensor
    whose norm is already within the bound is multiplied by 1, so its values
    are unchanged.

    Args:
        tensor (torch.Tensor): The input tensor.
        max_norm (float): The maximum allowed norm.

    Returns:
        torch.Tensor: The (possibly) scaled-down tensor.
    """
    # 1e-6 keeps the ratio finite for an all-zero tensor.
    ratio = max_norm / (torch.norm(tensor, p=2) + 1e-6)
    factor = ratio if ratio < 1 else 1
    return tensor * factor

def embedding_dp(embedding, args):
    """Clip an embedding and perturb it with Gaussian noise for DP.

    The whole ``embedding`` tensor is first clipped to L2 norm
    ``args.dp_clip_threshold`` with ``clip_tensor``. Then i.i.d. Gaussian noise
    with standard deviation ``args.dpzero_gaussian_std`` is added element-wise,
    using the same device and dtype as the embedding.

    ``args.dpzero_gaussian_std`` is only assigned in ``main_worker`` when
    ``args.dp_epsilon > 0``. When it is absent (DP disabled), the clipped
    embedding is returned without noise instead of raising AttributeError.

    Args:
        embedding (torch.Tensor): client output to privatize.
        args: run configuration (reads ``dp_clip_threshold`` and, if present,
            ``dpzero_gaussian_std``).

    Returns:
        torch.Tensor: the clipped (and, if configured, noised) embedding.
    """
    embedding = clip_tensor(embedding, args.dp_clip_threshold)
    std = getattr(args, 'dpzero_gaussian_std', None)
    if std is None:
        # No noise scale was computed (dp_epsilon <= 0): skip the noise.
        return embedding
    noise = torch.normal(
        mean=0,
        std=std,
        size=embedding.size(),
        device=embedding.device,
        dtype=embedding.dtype,
    )
    return embedding + noise

def main():
    """Script entry point.

    Seeds the RNGs (if ``--seed`` is given) and creates a results folder whose
    name encodes the main hyper-parameters. It then resolves the GPU list and
    the distributed flags on the global ``args`` and runs ``main_worker``,
    either in-process or once per GPU via ``mp.spawn``.
    """
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)

    # Folder name = key hyper-parameters joined by underscores.
    run_tag = '_'.join([
        f'ds{args.data}',
        f'mu{args.zo_mu}',
        f'b{args.batch_size}',
        f'cla{args.arch}',
        f'sva{args.server_arch}',
        f'lr{args.lr:g}',
        f'cthr{args.dp_clip_threshold}',
        f'mode{args.mode}',
        f'dp_eps{args.dp_epsilon}',
        f'mom{args.momentum}',
        f'warmup_rate{args.warmup_rate}',
    ])
    args.save_folder = os.path.join('results', run_tag)
    os.makedirs(args.save_folder, exist_ok=True)
    print(f"save_folder: '{args.save_folder}'")

    # With env:// initialisation the world size may come from the environment.
    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    if args.gpus is None:
        args.gpus = list(range(torch.cuda.device_count()))
    gpu_count = len(args.gpus)

    if args.multiprocessing_distributed and gpu_count == 1:
        warnings.warn('You have chosen to use multiprocessing distributed '
                      'training. But only one GPU is available on this node. '
                      'The training will start within the launching process '
                      'instead to minimize process start overhead.')
        args.multiprocessing_distributed = False

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    if not args.multiprocessing_distributed:
        # Single process: run the worker here, on the first listed GPU.
        main_worker(0, args)
        return

    # One process per local GPU, so scale the node count to a process count.
    args.world_size *= gpu_count
    mp.spawn(main_worker, nprocs=gpu_count, args=(args,))


# Metric histories persisted next to the checkpoints, one pickle file each.
_HISTORY_NAMES = ('train_loss', 'train_acc1', 'train_acc5',
                  'test_loss', 'test_acc1', 'test_acc5')


def _history_path(save_folder, name):
    """Return the path of the pickle file holding the metric list ``name``."""
    return os.path.join(save_folder, f'{name}.pkl')


def _load_history(save_folder):
    """Load every metric list in ``_HISTORY_NAMES`` from ``save_folder``."""
    history = {}
    for name in _HISTORY_NAMES:
        # NOTE: pickle.load can execute arbitrary code; only resume from
        # result folders you trust.
        with open(_history_path(save_folder, name), 'rb') as f:
            history[name] = pickle.load(f)
    return history


def _dump_history(save_folder, history):
    """Pickle each metric list to its own file, closing every file handle."""
    for name in _HISTORY_NAMES:
        with open(_history_path(save_folder, name), 'wb') as f:
            pickle.dump(history[name], f)


def _record(history, split, loss, acc1, acc5):
    """Append one (loss, acc1, acc5) triple to the ``split`` ('train'/'test') lists."""
    history[f'{split}_loss'].append(loss)
    history[f'{split}_acc1'].append(acc1)
    history[f'{split}_acc5'].append(acc5)


def main_worker(index, args):
    """Build, optionally resume, train, evaluate and checkpoint all VFL models.

    Called directly by ``main`` (single process) or once per GPU via
    ``mp.spawn``. Models ``0 .. args.num_clients - 1`` are torchvision client
    networks with a 128-dimensional output. Model ``args.num_clients`` is the
    server (``serverModel``).

    We will do a bunch of `setattr`s such that

    args.rank               the global rank of this process in distributed training
    args.index              the process index to this node
    args.gpus               the GPU ids for this node
    args.gpu                the default GPU id for this node
    args.batch_size         the batch size for this process
    args.workers            the data loader workers for this process
    args.seed               if not None, the seed for this specific process, computed as `args.seed + args.rank`
    """
    args.index = index
    args.gpu = args.gpus[index]
    assert args.gpu is not None
    torch.cuda.set_device(args.gpu)

    # suppress printing for all but one device per node
    if args.multiprocessing_distributed and args.index != 0:
        def print_pass(*args, **kwargs):
            pass
        builtins.print = print_pass

    print(f"Use GPU(s): {args.gpus} for training on '{socket.gethostname()}'")

    # init distributed training if needed
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            ngpus_per_node = len(args.gpus)
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + index
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size and data
            # loader workers based on the total number of GPUs we have.
            assert args.batch_size % ngpus_per_node == 0
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
    else:
        args.rank = 0

    if args.seed is not None:
        args.seed = args.seed + args.rank
        random.seed(args.seed)
        torch.manual_seed(args.seed)

    # NOTE(review): cudnn.benchmark = True lets cuDNN pick algorithms by
    # timing, which can defeat cudnn.deterministic = True; set
    # benchmark = False if bit-exact reproducibility is required.
    cudnn.deterministic = True
    cudnn.benchmark = True

    # build data loaders before initializing model, since we need num_classes for the latter
    train_loader, val_loader, classes = create_data_loaders(args)

    # Zeroth-order modes create no torch optimizers and use an unreduced
    # (per-sample) loss.
    zo_mode = args.mode in ("dpzv", "zoovfl")
    if zo_mode:
        criterion = nn.CrossEntropyLoss(reduction='none').cuda(args.gpu)
    else:
        criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # Create models: clients first, the server last.
    models = []
    optimizers = []
    for m in range(args.num_clients+1):
        is_server = m == args.num_clients
        if not is_server:
            print(f"=> creating model '{args.arch}' with {len(classes)} classes")
            model = torchvision.models.__dict__[args.arch](num_classes=128)
        else:
            print(f"=> creating server model")
            model = serverModel(args, num_clients=args.num_clients, num_classes=len(classes), dim=128)
            print("Number of classes:",len(classes))
            # init the fc layer
            model.apply(init_weights)

        model.cuda(args.gpu)
        if args.distributed:
            # For multiprocessing distributed, DistributedDataParallel constructor
            # should always set the single device scope, otherwise,
            # DistributedDataParallel will use all available devices.
            if args.multiprocessing_distributed:
                model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
            else:
                model = torch.nn.parallel.DistributedDataParallel(model, device_ids=args.gpus)
        else:
            # DataParallel will divide and allocate batch_size to all available GPUs.
            # The alexnet/vgg `features`-only wrapping is only for the
            # torchvision client models; the server is always wrapped whole.
            if not is_server and (args.arch.startswith('alexnet') or args.arch.startswith('vgg')):
                model.features = torch.nn.DataParallel(model.features, device_ids=args.gpus)
            else:
                model = torch.nn.DataParallel(model, device_ids=args.gpus)

        if not zo_mode:
            # Clients train with args.lr, the server with args.server_lr.
            lr = args.server_lr if is_server else args.lr
            optimizers.append(torch.optim.SGD(model.parameters(), lr,
                                              momentum=args.momentum,
                                              weight_decay=args.weight_decay))
        models.append(model)

    best_acc1 = 0
    args.start_epoch = 0
    history = {name: [] for name in _HISTORY_NAMES}

    # optionally resume from a checkpoint
    resumed = False
    for client in range(args.num_clients+1):
        save_filename = os.path.join(args.save_folder, f"client{client}.pth.tar")
        if not os.path.isfile(save_filename):
            continue
        print("=> loading checkpoint '{}'".format(save_filename))
        # Map model to be loaded to specified single gpu.
        checkpoint = torch.load(save_filename, map_location=torch.device('cuda', args.gpu))
        args.start_epoch = checkpoint['epoch']
        best_acc1 = checkpoint['best_acc1']
        if isinstance(best_acc1, torch.Tensor):
            # best_acc1 may be from a checkpoint from a different GPU
            best_acc1 = best_acc1.to(args.gpu)
        models[client].load_state_dict(checkpoint['state_dict'])
        if not zo_mode:
            optimizers[client].load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(save_filename, checkpoint['epoch']))
        resumed = True
    if resumed:
        # The metric histories are shared by all models, so load them once.
        history = _load_history(args.save_folder)

    if args.start_epoch == 0:
        loss, acc1, acc5 = validate(val_loader, models, criterion, args)
        _record(history, 'test', loss, acc1, acc5)

    if args.dp_epsilon>0:
        sample_rate = args.batch_size / len(train_loader.dataset)
        try:
            multiplier = get_noise_multiplier(target_epsilon=args.dp_epsilon,
                                                target_delta=args.dp_delta,
                                                epochs=args.epochs,
                                                sample_rate=sample_rate,
                                                accountant='gdp'
                                                )
        except ValueError:
            # Retry with opacus' default accountant if the GDP accountant
            # raises ValueError.
            multiplier = get_noise_multiplier(target_epsilon=args.dp_epsilon,
                                                target_delta=args.dp_delta,
                                                epochs=args.epochs,
                                                sample_rate=sample_rate,
                                                )
        dpzero_gaussian_std =multiplier * 2 * args.dp_clip_threshold / args.batch_size
        args.dpzero_gaussian_std = dpzero_gaussian_std

    args.total_steps = args.epochs * len(train_loader)
    args.warmup_steps = int(args.warmup_rate * args.total_steps)
    args.global_step = 0
    # Main training loop
    if args.mode == "dpzv":
        dpzv_trainer = DPZV_trainer(train_loader, args, use_mezo=not args.no_mezo)
    elif args.mode == "zoovfl":
        zoovfl_trainer = ZOOVFL_Trainer(train_loader, args)
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_loader.sampler.set_epoch(epoch)

        # train for one epoch
        if args.mode == "dpzv":
            # NOTE(review): the trainers' return values are discarded, so the
            # "train" metrics recorded below are the most recent validation
            # values (and undefined on a resumed run); confirm what
            # DPZV_trainer.train / ZOOVFL_Trainer.train return.
            dpzv_trainer.train(train_loader, models, criterion, epoch, args)
        elif args.mode == "zoovfl":
            zoovfl_trainer.train(train_loader, models, criterion, epoch, args)
        elif args.mode == "zofo":
            loss, acc1, acc5 = zofo_train(train_loader, models, criterion, optimizers, epoch, args)
        elif args.mode in ["vafl", "vafl2", "vafl3"]:
            # Precompute every client's (clipped + noised) embedding per batch.
            # NOTE(review): train_vafl iterates train_loader again; with
            # shuffling and random augmentation, its batch i is not the batch
            # these embeddings were computed from.
            embeddings = []
            for i, (images, _) in enumerate(train_loader):
                embeddings.append([])
                for client in range(args.num_clients):
                    images[client] = images[client].cuda(args.gpu, non_blocking=True)

                # compute initial embeddings
                for client in range(args.num_clients):
                    image_local = images[client]
                    with torch.no_grad():
                        embedding = models[client](image_local)
                        embedding = embedding_dp(embedding, args)
                        embeddings[i].append(embedding)
            loss, acc1, acc5 = train_vafl(train_loader, models, criterion, optimizers, epoch, args, embeddings)
        else:
            loss, acc1, acc5 = train(train_loader, models, criterion, optimizers, epoch, args)
        _record(history, 'train', loss, acc1, acc5)

        # evaluate on validation set
        loss, acc1, acc5 = validate(val_loader, models, criterion, args)
        _record(history, 'test', loss, acc1, acc5)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if is_best:
            print(f"New best Acc1 {best_acc1:.4f}")

        if (args.distributed and args.rank == 0) or (args.index == 0):
            _dump_history(args.save_folder, history)

            for client in range(args.num_clients+1):
                if not zo_mode:
                    # Reset optimizers (fresh momentum state each epoch), with the
                    # same per-model learning rate and weight decay as at creation:
                    # args.lr for clients, args.server_lr for the server.
                    is_server = client == args.num_clients
                    lr = args.server_lr if is_server else args.lr
                    optimizers[client] = torch.optim.SGD(models[client].parameters(), lr,
                                                         momentum=args.momentum,
                                                         weight_decay=args.weight_decay)

                # Save client models
                save_filename = os.path.join(args.save_folder, f"client{client}.pth.tar")
                state = {
                    'epoch': epoch + 1,
                    'state_dict': models[client].state_dict(),
                    'best_acc1': best_acc1,
                    'acc1': acc1,
                    'acc5': acc5,
                }
                if not zo_mode:
                    state['optimizer'] = optimizers[client].state_dict()
                save_checkpoint(state, is_best, save_filename)
                print(f"saved to '{save_filename}'")

def init_weights(m):
    """Initialise an ``nn.Linear`` module in place; ignore every other module.

    Weights get Kaiming-normal init (``mode='fan_out'``, ``nonlinearity='relu'``)
    and biases, when present, are set to zero. Intended for ``Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    if m.bias is None:
        return
    init.constant_(m.bias, 0)

def create_data_loaders(args):
    """Build the train and test loaders over the multi-view dataset.

    Both the 'train' and 'test' splits are read from the same root,
    ``args.data``. Training images get RandomResizedCrop(224) plus a random
    horizontal flip. Test images are resized to 256 and center-cropped to 224.
    Both use the same Normalize statistics.

    Returns:
        (train_loader, val_loader, classes), where ``classes`` is
        ``train_dataset.classes``.
    """
    data_root = args.data
    norm = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])

    train_pipeline = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        norm,
    ])
    train_dataset = MultiViewDataSet(data_root, 'train', transform=train_pipeline)

    # In distributed runs the sampler shards and shuffles the data, so the
    # loader itself must not shuffle.
    train_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler,
    )

    test_pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        norm,
    ])
    val_dataset = MultiViewDataSet(data_root, 'test', transform=test_pipeline)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    return train_loader, val_loader, train_dataset.classes

def train_vafl(train_loader, models, criterion, optimizers, epoch, args, embeddings):
    """Run one epoch of asynchronous VFL training (modes "vafl", "vafl2", "vafl3").

    For each batch, every client in turn recomputes its own (clipped + noised)
    embedding with gradients. The other clients' entries are taken as detached
    copies from ``embeddings``. The server and that client then take one SGD
    step on the loss.

    Args:
        train_loader: yields ``(images, target)``; ``images[c]`` is the input
            for client ``c``.
        models: the ``args.num_clients`` client models followed by the server.
        criterion: loss applied to ``(output, target)``; assumed to return a
            scalar (the mean-reduced loss these modes use in main_worker).
        optimizers: one per model; client entries are replaced by a fresh SGD
            optimizer before each client step.
        epoch: epoch index, used only in the progress prefix.
        args: run configuration.
        embeddings: ``embeddings[i][c]`` is client ``c``'s cached embedding for
            batch ``i``; refreshed in place as clients train.

    Returns:
        ``(loss, acc1, acc5)`` averaged over the epoch.
    """
    batch_time = utils.AverageMeter('Time', '6.3f')
    data_time = utils.AverageMeter('Data', '6.3f')
    losses = utils.AverageMeter('Loss', '.4e')
    top1 = utils.AverageMeter('Acc1', '6.2f')
    top5 = utils.AverageMeter('Acc5', '6.2f')
    progress = utils.ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, utils.ProgressMeter.BR, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # Simulated per-client CPU loads for the "vafl2" timing model. They are
    # loop-invariant, so the derived local-epoch counts are computed once here.
    cpus = [0.3638092, 0.17014983, 0.14333789, 0.33265191, 0.17415424, 0.06804619,
            0.14446286, 0.29251583, 0.2207424, 0.12800219, 0.27062089, 0.13972783,
            0.30291598]
    vafl2_epochs = (20*(1-np.array(cpus))).astype(int)

    wait = []
    for client in range(args.num_clients+1):
        models[client].train()
        wait.append(0.0)

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        for client in range(args.num_clients):
            images[client] = images[client].cuda(args.gpu, non_blocking=True)
        target = target.cuda(args.gpu, non_blocking=True)

        for client in range(args.num_clients):
            image_local = images[client]
            # NOTE(review): a new SGD optimizer is created for every client
            # step, so client momentum never accumulates and weight decay is
            # not applied; confirm this is intended.
            optimizers[client] = torch.optim.SGD(models[client].parameters(), args.lr, momentum=args.momentum)
            optimizers[client].zero_grad()
            optimizers[-1].zero_grad()

            # Other clients contribute their cached (stale) embeddings.
            embedding_view = [client_view.detach().clone() for client_view in embeddings[i]]
            embedding_view[client] = models[client](image_local)
            embedding_view[client] = embedding_dp(embedding_view[client], args)
            output = models[-1](torch.cat(embedding_view,axis=1))

            # compute gradient and do SGD step
            loss = criterion(output, target)
            loss.backward()

            # NOTE(review): the cache is refreshed before optimizer.step(),
            # i.e. with the client's pre-update parameters.
            with torch.no_grad():
                embeddings[i][client] = embedding_dp(models[client](image_local), args)

            optimizers[client].step()
            optimizers[-1].step()

            # measure accuracy and record loss; .item() stores a plain float
            # rather than a tensor still attached to the autograd graph.
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images[0].size(0))
            top1.update(acc1, images[0].size(0))
            top5.update(acc5, images[0].size(0))

            # Simulated per-client wall-clock bookkeeping ("vafl" adds none).
            if args.mode == "vafl2":
                wait[client] += (20/vafl2_epochs[client]) + args.server_time
            elif args.mode == "vafl3":
                if client < 3:
                    wait[client] += (20/5)+args.server_time
                elif client < 6:
                    wait[client] += (20/10)+args.server_time
                elif client < 9:
                    wait[client] += (20/15)+args.server_time
                else:
                    wait[client] += 1.0+args.server_time

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

    loss, acc1, acc5 = losses.avg, top1.avg, top5.avg
    print(f'Training * Loss {loss:.5f} Acc1 {acc1:.3f} Acc5 {acc5:.3f}')
    return loss, acc1, acc5

class DPZV_trainer:
    """Zeroth-order VFL trainer with optional DPZero-style differential privacy.

    Client models (``models[:num_clients]``) are updated with a two-point
    zeroth-order (ZO) estimate. Their parameters are perturbed along a Gaussian
    direction ``z`` that is regenerated from a stored seed, and the server loss
    is evaluated at the perturbed point(s). The clipped (and, with DP, noised)
    scalar loss difference times ``z`` is then used as the gradient. The server
    model (``models[-1]``) is trained with first-order SGD on the concatenated
    client embeddings.
    """

    # Number of most recent seeds / projected gradients kept per client.
    HISTORY_LEN = 100

    def __init__(self, train_loader, args, use_mezo=True):
        self.use_mezo = use_mezo
        self.args = args
        # Server optimizer; created lazily on the first training step.
        self.optimizer = None
        sample_rate = args.batch_size / len(train_loader.dataset)
        if args.dp_epsilon > 0:
            try:
                multiplier = get_noise_multiplier(target_epsilon=args.dp_epsilon,
                                                  target_delta=args.dp_delta,
                                                  epochs=args.epochs,
                                                  sample_rate=sample_rate,
                                                  accountant='gdp')
            except ValueError:
                # Retry with opacus' default accountant if the GDP one fails.
                multiplier = get_noise_multiplier(target_epsilon=args.dp_epsilon,
                                                  target_delta=args.dp_delta,
                                                  epochs=args.epochs,
                                                  sample_rate=sample_rate)
            self.dpzero_gaussian_std = multiplier * 2 * self.args.dp_clip_threshold / args.batch_size
            # Private RNG for the DP noise. The global torch RNG is reseeded with
            # the perturbation seed in zo_perturb_parameters right before the noise
            # is drawn, so using it would make the "noise" a deterministic function
            # of that seed.
            self.noise_generator = torch.Generator()
            self.noise_generator.seed()
        self.lr = args.lr
        # Most recent perturbation seeds and projected gradients, per client.
        self.random_seeds = [[] for _ in range(args.num_clients)]
        self.history_diff = [[] for _ in range(args.num_clients)]
        # Momentum buffers used by mezo_update, one dict per client:
        # {client_index: {parameter_name: tensor}}.
        self.grad = {}
        self.patience = args.patience
        self.threshold = 0.1
        self.min_lr = args.min_lr
        self.best_metric = None
        self.num_bad_epochs = 0

    def _get_learning_rate(self):
        """Return the current client learning rate."""
        return self.lr

    @staticmethod
    def _uses_weight_decay(name):
        """Return True if parameter ``name`` should receive weight decay (not bias/layer-norm)."""
        return "bias" not in name and "layer_norm" not in name and "layernorm" not in name

    def zo_perturb_parameters(self, model: nn.Module, random_seed: int, scaling_factor=1):
        """Shift every parameter by ``scaling_factor * zo_mu * z`` with ``z`` drawn from ``random_seed``.

        Reseeds torch's global RNG, so calling this with factors +1, -2, +1 and
        the same seed visits theta+mu*z, theta-mu*z and returns (up to rounding)
        to theta.
        """
        args = self.args
        torch.manual_seed(random_seed)
        with torch.no_grad():
            for _, param in model.named_parameters():
                z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
                param.data = param.data + scaling_factor * z * args.zo_mu

    def compute_l2_norm(self, model):
        """Return the global L2 norm of all of ``model``'s parameters.

        This is ``sqrt(sum_p ||p||_2^2)``, the norm of the flattened parameter
        vector (summing per-tensor norms would over-estimate it).
        """
        squared = 0
        for param in model.parameters():
            squared += torch.norm(param.data) ** 2
        return squared ** 0.5

    def dpzero_clip(self, loss_diff, C=1.):
        """Clip each entry of ``loss_diff`` to ``[-C, C]`` and average.

        Returns ``(mean_of_clipped_values, clipping_rate)``. ``clipping_rate``
        is the fraction of entries whose magnitude exceeded ``C``.
        """
        abs_loss_diff = torch.abs(loss_diff)
        clipping_rate = (abs_loss_diff > C).float().mean().item()
        # min(1, C/|x|) * x: entries with |x| > C are scaled down to magnitude C.
        scale = torch.min(torch.ones_like(loss_diff), torch.div(C * torch.ones_like(loss_diff), abs_loss_diff))
        return torch.mul(scale, loss_diff).mean(), clipping_rate

    def zo_forward(self, model, inputs):
        """Return ``model(inputs)`` computed under ``torch.inference_mode``.

        The model is switched to eval mode (dropout off, BN running stats used)
        and is left in eval mode afterwards.
        """
        model.eval()
        with torch.inference_mode():
            output = model(inputs)
        return output

    def zo_update(self, model, client):
        """Apply the latest ZO gradient estimate of ``client`` to ``model`` (plain SGD).

        ``z`` is regenerated from the client's last seed and multiplied by its
        last projected gradient. Weight decay is added for non-bias/non-layer-norm
        parameters.
        """
        args = self.args
        lr = self._get_learning_rate()
        with torch.no_grad():
            torch.manual_seed(self.random_seeds[client][-1])
            projected_grad = self.history_diff[client][-1]
            for name, param in model.named_parameters():
                # Resample the same z used when perturbing.
                z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
                grad = projected_grad * z
                if self._uses_weight_decay(name):
                    grad = grad + args.weight_decay * param.data
                param.data = param.data - lr * grad

    def mezo_update(self, model, client):
        """Apply the latest ZO gradient estimate of ``client`` to ``model``, with momentum.

        With ``args.momentum > 0`` the step direction is an exponential moving
        average kept in a momentum buffer private to ``client``. The client's
        first step uses the raw estimate. Otherwise this matches ``zo_update``.
        Weight decay is added for non-bias/non-layer-norm parameters.
        """
        args = self.args
        momentum = args.momentum
        lr = self._get_learning_rate()
        buffers, first_step = None, True
        if momentum > 0:
            if self.grad is None:  # tolerate a reset-to-None attribute
                self.grad = {}
            buffers = self.grad.get(client)
            first_step = buffers is None
            if first_step:
                buffers = self.grad[client] = {}
        with torch.no_grad():
            torch.manual_seed(self.random_seeds[client][-1])
            projected_grad = self.history_diff[client][-1]
            for name, param in model.named_parameters():
                # Resample the same z used when perturbing.
                z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
                if first_step:
                    grad = projected_grad * z
                else:
                    grad = buffers[name] * momentum + projected_grad * z * (1 - momentum)
                if self._uses_weight_decay(name):
                    grad = grad + args.weight_decay * param.data
                if buffers is not None:
                    buffers[name] = grad
                param.data = param.data - lr * grad

    def zo_adjust_lr(self):
        """Refresh ``self.lr`` from ``adjust_lr(self.args)``."""
        self.lr = adjust_lr(self.args)

    def train(self, train_loader, models, criterion, epoch, args):
        """Run one DPZV training epoch and return ``(loss, acc1, acc5)`` averages.

        For every batch and every client ``k``:

        1. perturb client ``k``'s parameters along a seed-generated direction
           and evaluate the server loss at the perturbed point(s), according to
           ``args.grad_estimate_method`` (``'central'`` or ``'forward'``);
        2. clip the loss difference with ``dpzero_clip``, add Gaussian noise when
           ``args.dp_epsilon > 0``, and update client ``k`` via ``mezo_update``
           or ``zo_update``;
        3. take one SGD step on the server with the embeddings, where client
           ``k``'s entry is recomputed at its restored (pre-update) parameters.

        Raises:
            ValueError: if ``args.grad_estimate_method`` is not supported.
        """
        if args.grad_estimate_method not in ('central', 'forward'):
            raise ValueError(f"Unsupported grad_estimate_method: {args.grad_estimate_method!r}")

        batch_time = utils.AverageMeter('Time', '6.3f')
        data_time = utils.AverageMeter('Data', '6.3f')
        losses = utils.AverageMeter('Loss', '.4e')
        top1 = utils.AverageMeter('Acc1', '6.2f')
        top5 = utils.AverageMeter('Acc5', '6.2f')
        cliprate = utils.AverageMeter('Clip Rate', '.2%')
        progress = utils.ProgressMeter(
            len(train_loader),
            [batch_time, data_time, losses, utils.ProgressMeter.BR, top1, top5, cliprate],
            prefix="Epoch: [{}]".format(epoch))
        for client in range(args.num_clients + 1):
            models[client].train()

        mu_multiplier = 2 if args.grad_estimate_method == 'central' else 1
        end = time.time()
        for i, (images, target) in enumerate(train_loader):
            data_time.update(time.time() - end)
            args.global_step += 1
            self.zo_adjust_lr()

            for client in range(args.num_clients):
                images[client] = images[client].cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
            batch_size = images[0].size(0)

            # Initial (unperturbed) embeddings of every client.
            with torch.no_grad():
                embeddings = [models[client](images[client]) for client in range(args.num_clients)]

            server_model = models[-1]
            if self.optimizer is None:
                self.optimizer = torch.optim.SGD(server_model.parameters(), args.lr,
                                                 momentum=args.momentum,
                                                 weight_decay=args.weight_decay)

            for client in range(args.num_clients):
                model = models[client]
                image_local = images[client]
                embeddings_view_plus = embeddings.copy()
                embeddings_view_minus = embeddings.copy()

                seed = np.random.randint(1000000000)
                self.random_seeds[client].append(seed)
                if len(self.random_seeds[client]) > self.HISTORY_LEN:
                    self.random_seeds[client].pop(0)

                with torch.no_grad():
                    self.zo_perturb_parameters(model=model, random_seed=seed, scaling_factor=1)
                    embeddings_view_plus[client] = model(image_local)
                    if args.grad_estimate_method == 'central':
                        # Move from theta+mu*z to theta-mu*z, then back to theta.
                        self.zo_perturb_parameters(model=model, random_seed=seed, scaling_factor=-2)
                        embeddings_view_minus[client] = model(image_local)
                        self.zo_perturb_parameters(model=model, random_seed=seed, scaling_factor=1)
                    else:  # 'forward': restore theta and evaluate there.
                        self.zo_perturb_parameters(model=model, random_seed=seed, scaling_factor=-1)
                        embeddings_view_minus[client] = model(image_local)

                    embeddings[client] = model(image_local)

                    output_plus = server_model(torch.cat(embeddings_view_plus, axis=1))
                    output_minus = server_model(torch.cat(embeddings_view_minus, axis=1))
                    loss_1 = criterion(output_plus, target)
                    loss_2 = criterion(output_minus, target)
                    loss_diff = (loss_1 - loss_2) / (mu_multiplier * args.zo_mu)
                    projected_grad, clipping_rate = self.dpzero_clip(loss_diff, args.dp_clip_threshold)
                    if args.dp_epsilon > 0:
                        noise = torch.randn(1, generator=self.noise_generator).item()
                        projected_grad += noise * self.dpzero_gaussian_std

                    self.history_diff[client].append(projected_grad)
                    if len(self.history_diff[client]) > self.HISTORY_LEN:
                        self.history_diff[client].pop(0)
                    if self.use_mezo:
                        self.mezo_update(model, client)
                    else:
                        self.zo_update(model, client)

                acc1, acc5 = accuracy(output_plus, target, topk=(1, 5))
                losses.update(loss_1.mean(), batch_size)
                top1.update(acc1, batch_size)
                top5.update(acc5, batch_size)
                cliprate.update(clipping_rate, batch_size)

                # One server SGD step; gradients are cleared first so that each
                # step uses only the current loss (no accumulation across clients).
                output = server_model(torch.cat(embeddings, axis=1))
                loss = criterion(output, target).mean()
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # Detach so the meter does not keep the autograd graph alive.
            losses.update(loss.detach(), batch_size)
            top1.update(acc1, batch_size)
            top5.update(acc5, batch_size)
            cliprate.update(clipping_rate, batch_size)

            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        loss, acc1, acc5, clipping_rate = losses.avg, top1.avg, top5.avg, cliprate.avg
        print(f'Training * Loss {loss:.5f} Acc1 {acc1:.3f} Acc5 {acc5:.3f} Learning_rate {self.lr} Clipping_rate {clipping_rate:.2%}')
        return loss, acc1, acc5


class ZOOVFL_Trainer(DPZV_trainer):
    """Synchronous ZOO-VFL baseline built on ``DPZV_trainer``.

    It follows the same per-batch loop as ``DPZV_trainer.train``, with these differences:

    * When ``args.dp_epsilon > 0``, every client embedding is passed through
      ``embedding_dp``; no noise is added to the projected gradient.
    * The estimate always uses a central difference.
    * Clients are updated with ``zo_update`` (no momentum).
    """

    def __init__(self, train_loader, args):
        super().__init__(train_loader, args)

    @staticmethod
    def _client_embedding(model, image, args):
        """Forward ``image`` through ``model`` and apply ``embedding_dp`` when DP is enabled."""
        embedding = model(image)
        if args.dp_epsilon > 0:
            embedding = embedding_dp(embedding, args)
        return embedding

    def train(self, train_loader, models, criterion, epoch, args):
        """Run one synchronous ZOO-VFL epoch and return ``(loss, acc1, acc5)`` averages."""
        batch_time = utils.AverageMeter('Time', '6.3f')
        data_time = utils.AverageMeter('Data', '6.3f')
        losses = utils.AverageMeter('Loss', '.4e')
        top1 = utils.AverageMeter('Acc1', '6.2f')
        top5 = utils.AverageMeter('Acc5', '6.2f')
        cliprate = utils.AverageMeter('Clip Rate', '.2%')
        progress = utils.ProgressMeter(
            len(train_loader),
            [batch_time, data_time, losses, utils.ProgressMeter.BR, top1, top5, cliprate],
            prefix="Epoch: [{}]".format(epoch))
        for client in range(args.num_clients + 1):
            models[client].train()

        end = time.time()
        for i, (images, target) in enumerate(train_loader):
            data_time.update(time.time() - end)
            args.global_step += 1
            self.zo_adjust_lr()

            for client in range(args.num_clients):
                images[client] = images[client].cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
            batch_size = images[0].size(0)

            # Initial (unperturbed) embeddings of every client.
            with torch.no_grad():
                embeddings = [self._client_embedding(models[c], images[c], args)
                              for c in range(args.num_clients)]

            server_model = models[-1]
            if self.optimizer is None:
                self.optimizer = torch.optim.SGD(server_model.parameters(), args.lr,
                                                 momentum=args.momentum,
                                                 weight_decay=args.weight_decay)

            for client in range(args.num_clients):
                model = models[client]
                image_local = images[client]
                embeddings_view_plus = embeddings.copy()
                embeddings_view_minus = embeddings.copy()

                seed = np.random.randint(1000000000)
                self.random_seeds[client].append(seed)
                if len(self.random_seeds[client]) > 100:
                    self.random_seeds[client].pop(0)

                with torch.no_grad():
                    # Central difference: theta+mu*z, then theta-mu*z, then back to
                    # theta (perturbations are cumulative, hence +1, -2, +1).
                    # NOTE(review): embedding_dp runs right after the global RNG was
                    # reseeded with `seed`; if it draws from that RNG its noise is
                    # determined by the seed. Confirm against embedding_dp.
                    self.zo_perturb_parameters(model=model, random_seed=seed, scaling_factor=1)
                    embeddings_view_plus[client] = self._client_embedding(model, image_local, args)
                    self.zo_perturb_parameters(model=model, random_seed=seed, scaling_factor=-2)
                    embeddings_view_minus[client] = self._client_embedding(model, image_local, args)
                    self.zo_perturb_parameters(model=model, random_seed=seed, scaling_factor=1)

                    embeddings[client] = self._client_embedding(model, image_local, args)

                    output_plus = server_model(torch.cat(embeddings_view_plus, axis=1))
                    output_minus = server_model(torch.cat(embeddings_view_minus, axis=1))
                    loss_1 = criterion(output_plus, target)
                    loss_2 = criterion(output_minus, target)
                    # The two evaluation points are always 2*mu apart, independent
                    # of args.grad_estimate_method.
                    projected_grad = ((loss_1 - loss_2) / (2 * args.zo_mu)).mean()

                    self.history_diff[client].append(projected_grad)
                    if len(self.history_diff[client]) > 100:
                        self.history_diff[client].pop(0)
                    self.zo_update(model, client)

                acc1, acc5 = accuracy(output_plus, target, topk=(1, 5))
                losses.update(loss_1.mean(), batch_size)
                top1.update(acc1, batch_size)
                top5.update(acc5, batch_size)

                # One server SGD step; clear gradients first so steps do not
                # accumulate gradients from previous clients.
                output = server_model(torch.cat(embeddings, axis=1))
                loss = criterion(output, target).mean()
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            # Detach so the meter does not keep the autograd graph alive.
            losses.update(loss.detach(), batch_size)
            top1.update(acc1, batch_size)
            top5.update(acc5, batch_size)

            batch_time.update(time.time() - end)
            end = time.time()

            # if i % args.print_freq == 0:
            #     progress.display(i)

        loss, acc1, acc5 = losses.avg, top1.avg, top5.avg
        print(f'Training * Loss {loss:.5f} Acc1 {acc1:.3f} Acc5 {acc5:.3f} Learning_rate {self.lr}')
        return loss, acc1, acc5


def zofo_train(train_loader, models, criterion, optimizers, epoch, args):
    """Train one epoch using ZO estimates taken with respect to client *embeddings*.

    For each batch, every client k does the following:

    * It perturbs only its own embedding ``args.num_purt`` times, along
      seed-generated Gaussian directions ``z_j``, and queries the server loss.
    * It forms ``g = mean_j(delta_j * z_j)``, where ``delta_j`` is the finite
      difference of the loss along ``z_j``.
    * It back-propagates ``g`` through its local model (``embedding.backward``)
      and takes an SGD step.

    Afterwards the server takes one first-order SGD step on the embeddings
    computed at the start of the batch.

    Returns ``(loss, acc1, acc5)`` epoch averages.

    Raises:
        ValueError: if ``args.num_purt`` is smaller than 1.
    """
    if args.num_purt < 1:
        raise ValueError(f"args.num_purt must be >= 1, got {args.num_purt}")

    batch_time = utils.AverageMeter('Time', '6.3f')
    data_time = utils.AverageMeter('Data', '6.3f')
    losses = utils.AverageMeter('Loss', '.4e')
    top1 = utils.AverageMeter('Acc1', '6.2f')
    top5 = utils.AverageMeter('Acc5', '6.2f')
    progress = utils.ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, utils.ProgressMeter.BR, top1, top5],
        prefix="Epoch: [{}]".format(epoch))
    for client in range(args.num_clients + 1):
        models[client].train()

    mu_multiplier = 2 if args.grad_estimate_method == 'central' else 1
    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        data_time.update(time.time() - end)

        for client in range(args.num_clients):
            images[client] = images[client].cuda(args.gpu, non_blocking=True)
        target = target.cuda(args.gpu, non_blocking=True)
        batch_size = images[0].size(0)

        # Initial (gradient-free) embeddings shared with the server.
        embeddings = []
        for client in range(args.num_clients):
            with torch.no_grad():
                embeddings.append(embedding_dp(models[client](images[client]), args))

        for client in range(args.num_clients):
            # NOTE(review): a fresh optimizer every batch discards SGD momentum
            # state; kept to preserve the original training dynamics.
            optimizers[client] = torch.optim.SGD(models[client].parameters(), args.lr,
                                                 momentum=args.momentum,
                                                 weight_decay=args.weight_decay)
            embedding = embedding_dp(models[client](images[client]), args)
            # Perturb a detached copy; gradients reach the local model only via
            # embedding.backward below.
            base = embedding.detach()
            embeddings_view_plus = embeddings.copy()
            embeddings_view_minus = embeddings.copy()
            partial_grad = torch.zeros_like(base)
            for _ in range(args.num_purt):
                random_seed = np.random.randint(1000000000)
                embeddings_view_plus[client] = perturb_embedding(base, random_seed, args, scaling_factor=1)
                if args.grad_estimate_method == 'central':
                    # perturb_embedding does not modify its input, so the mirror
                    # point base - mu*z needs scaling_factor=-1 (not -2).
                    embeddings_view_minus[client] = perturb_embedding(base, random_seed, args, scaling_factor=-1)
                else:
                    # Forward difference against the same unperturbed embedding.
                    embeddings_view_minus[client] = base

                with torch.no_grad():
                    output_plus = models[-1](torch.cat(embeddings_view_plus, axis=1))
                    output_minus = models[-1](torch.cat(embeddings_view_minus, axis=1))
                    loss_1 = criterion(output_plus, target)
                    loss_2 = criterion(output_minus, target)
                    delta = (loss_1 - loss_2) / (mu_multiplier * args.zo_mu)
                # Each finite difference must be paired with its own direction z_j.
                partial_grad += project_gradient(delta, base.size(), random_seed=random_seed)
            partial_grad /= args.num_purt

            optimizers[client].zero_grad()
            embedding.backward(gradient=partial_grad, inputs=list(models[client].parameters()))
            optimizers[client].step()

            acc1, acc5 = accuracy(output_plus, target, topk=(1, 5))
            losses.update(loss_1, batch_size)
            top1.update(acc1, batch_size)
            top5.update(acc5, batch_size)

        # Server step. NOTE(review): the optimizer is also rebuilt every batch here.
        optimizers[-1] = torch.optim.SGD(models[-1].parameters(), args.server_lr,
                                         momentum=args.momentum,
                                         weight_decay=args.weight_decay)
        output = models[-1](torch.cat(embeddings, axis=1))
        loss = criterion(output, target)
        optimizers[-1].zero_grad()
        loss.backward()
        optimizers[-1].step()

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        # Detach so the meter does not keep the autograd graph alive.
        losses.update(loss.detach(), batch_size)
        top1.update(acc1, batch_size)
        top5.update(acc5, batch_size)

        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

    loss, acc1, acc5 = losses.avg, top1.avg, top5.avg

    print(f'Training * Loss {loss:.5f} Acc1 {acc1:.3f} Acc5 {acc5:.3f} Learning_rate {args.lr}')
    return loss, acc1, acc5

def perturb_embedding(embedding, random_seed: int, args, scaling_factor=1, ):
    """Return ``embedding + scaling_factor * z * args.zo_mu`` with ``z ~ N(0, I)``.

    ``z`` is drawn right after reseeding torch's global RNG with ``random_seed``,
    so a given seed always produces the same direction. The input tensor is
    not modified, and the result is computed under ``torch.no_grad``.
    """
    torch.manual_seed(random_seed)
    direction = torch.normal(mean=0, std=1, size=embedding.size(),
                             device=embedding.device, dtype=embedding.dtype)
    with torch.no_grad():
        shifted = embedding + scaling_factor * direction * args.zo_mu
    return shifted

def project_gradient(loss_diff, size, random_seed: int):
    """Expand a directional finite difference into a gradient estimate of shape ``size``.

    Reseeds torch's global RNG with ``random_seed``, regenerates ``z ~ N(0, I)``
    on ``loss_diff``'s device/dtype, and returns ``loss_diff * z``.
    """
    torch.manual_seed(random_seed)
    return loss_diff * torch.normal(mean=0, std=1, size=size,
                                    device=loss_diff.device, dtype=loss_diff.dtype)

def _local_epochs_for_mode(mode, num_clients):
    """Return the number of local SGD iterations per party for ``mode``, or None if unknown.

    Entry ``k < num_clients`` is client ``k``; the last entry is the server.
    The ``'flex'``, ``'flex2'`` and ``'flex3'`` schedules are fixed 13-entry
    lists (presumably 12 clients + server).
    """
    cpus = [0.3638092, 0.17014983, 0.14333789, 0.33265191, 0.17415424, 0.06804619, 0.14446286, 0.29251583, 0.2207424, 0.12800219, 0.27062089, 0.13972783, 0.30291598]
    if mode == 'sync':
        return [20] * (num_clients + 1)
    if mode == 'sync2':
        per_party = (20 * (1 - torch.tensor(cpus))).type(torch.int)
        return [int(torch.min(per_party))] * (num_clients + 1)
    if mode == 'sync3':
        return [5] * (num_clients + 1)
    if mode == 'flex':
        return [2 * k + 1 for k in range(10)] + [20, 20, 20]
    if mode == 'flex2':
        return (20 * (1 - torch.tensor(cpus))).type(torch.int).tolist()
    if mode == 'flex3':
        return [5] * 3 + [10] * 3 + [15] * 3 + [20] * 3 + [20]
    if mode == 'pbcd':
        return [1] * (num_clients + 1)
    return None


def train(train_loader, models, criterion, optimizers, epoch, args):
    """Run one epoch of first-order VFL with per-party local iterations.

    All parties are put in train mode, so BatchNorm statistics are updated.
    Each batch proceeds as follows:

    * The client embeddings are computed once, without gradients.
    * Each party (clients first, server last) then takes
      ``local_epochs[party]`` SGD steps. In a client's step, only that client
      re-computes its embedding with gradients.

    Returns ``(loss, acc1, acc5)`` epoch averages, or ``None`` if ``args.mode``
    is not recognised.
    """
    local_epochs = _local_epochs_for_mode(args.mode, args.num_clients)
    if local_epochs is None:
        print("Invalid algorithm chosen:", args.mode)
        return None

    batch_time = utils.AverageMeter('Time', '6.3f')
    data_time = utils.AverageMeter('Data', '6.3f')
    losses = utils.AverageMeter('Loss', '.4e')
    top1 = utils.AverageMeter('Acc1', '6.2f')
    top5 = utils.AverageMeter('Acc5', '6.2f')
    progress = utils.ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, utils.ProgressMeter.BR, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    for party in range(args.num_clients + 1):
        models[party].train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        data_time.update(time.time() - end)

        for client in range(args.num_clients):
            images[client] = images[client].cuda(args.gpu, non_blocking=True)
        target = target.cuda(args.gpu, non_blocking=True)
        batch_size = images[0].size(0)

        # Initial, gradient-free embeddings of every client.
        with torch.no_grad():
            embeddings = [models[client](images[client]) for client in range(args.num_clients)]

        for party in range(args.num_clients + 1):
            optimizers[party] = torch.optim.SGD(models[party].parameters(), args.lr, momentum=args.momentum)
            adjust_learning_rate(optimizers[party], epoch, args.lr)
            is_server = party == args.num_clients
            for _ in range(local_epochs[party]):
                if is_server:
                    embedding_view = embeddings
                else:
                    # Only this client's embedding carries gradients.
                    embedding_view = embeddings.copy()
                    embedding_view[party] = models[party](images[party])
                output = models[-1](torch.cat(embedding_view, axis=1))
                loss = criterion(output, target)

                optimizers[party].zero_grad()
                loss.backward()
                optimizers[party].step()

                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                # Detach so the meter does not keep the autograd graph alive.
                losses.update(loss.detach(), batch_size)
                top1.update(acc1, batch_size)
                top5.update(acc5, batch_size)

        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)

    loss, acc1, acc5 = losses.avg, top1.avg, top5.avg
    print(f'Training * Loss {loss:.5f} Acc1 {acc1:.3f} Acc5 {acc5:.3f}')
    return loss, acc1, acc5


def validate(val_loader, models, criterion, args):
    """Evaluate the split model on every batch of ``val_loader``.

    Each of the ``args.num_clients`` client models embeds its own view
    ``images[client]``. The embeddings are concatenated along dim 1 and fed
    to the server model ``models[-1]``.

    Args:
        val_loader: iterable of ``(images, target)`` batches, where ``images``
            is indexable by client id (one tensor per client view).
        models: sequence of client models followed by the server model; the
            first ``args.num_clients + 1`` entries are put in eval mode.
        criterion: loss callable; its result is reduced with ``.mean()``.
        args: namespace providing ``num_clients``, ``gpu`` and ``print_freq``.

    Returns:
        ``(loss, acc1, acc5)``, the ``.avg`` values of the loss/top-1/top-5
        ``utils.AverageMeter`` instances after the last batch.
    """
    batch_time = utils.AverageMeter('Time', '6.3f')
    losses = utils.AverageMeter('Loss', '.4e')
    top1 = utils.AverageMeter('Acc1', '6.2f')
    top5 = utils.AverageMeter('Acc5', '6.2f')
    progress = utils.ProgressMeter(
        len(val_loader),
        [batch_time, losses, utils.ProgressMeter.BR, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode (all clients plus the server)
    for idx in range(args.num_clients + 1):
        models[idx].eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            for client in range(args.num_clients):
                images[client] = images[client].cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

            # compute output. Iterate with `client` rather than `i` so the
            # batch index used for progress printing below is not clobbered.
            embeddings = [models[client](images[client])
                          for client in range(args.num_clients)]
            output = models[-1](torch.cat(embeddings, axis=1))
            loss = criterion(output, target).mean()

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss, images[0].size(0))
            top1.update(acc1, images[0].size(0))
            top5.update(acc5, images[0].size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

    # TODO: this should also be done with the ProgressMeter
    loss, acc1, acc5 = losses.avg, top1.avg, top5.avg
    print(f'Test * Loss {loss:.5f} Acc1 {acc1:.3f} Acc5 {acc5:.3f}')

    return loss, acc1, acc5


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Write ``state`` to ``filename`` using ``torch.save``.

    When ``is_best`` is truthy, the saved file is also copied to
    ``model_best.pth.tar`` in the same directory as ``filename``.
    """
    torch.save(state, filename)
    if not is_best:
        return
    best_path = os.path.join(os.path.dirname(filename), 'model_best.pth.tar')
    shutil.copyfile(filename, best_path)


def adjust_learning_rate(optimizer, epoch, lr, schedule=None):
    """Decay the learning rate based on a step schedule and apply it.

    The rate is multiplied by 0.1 once for every milestone with
    ``epoch >= milestone``, and then written to every param group.

    Args:
        optimizer: optimizer whose ``param_groups`` receive the new rate.
        epoch: current epoch, compared against each milestone.
        lr: base learning rate before decay.
        schedule: iterable of milestone epochs. When ``None`` (the default),
            falls back to the module-level ``args.schedule`` so existing
            callers keep working.
    """
    if schedule is None:
        # Backward-compatible fallback: original behaviour read a global `args`.
        schedule = args.schedule
    for milestone in schedule:
        lr *= 0.1 if epoch >= milestone else 1.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def adjust_lr(args, optimizer=None):
    """Compute a linear-warmup / linear-decay learning rate.

    Reads ``args.global_step``, ``args.warmup_steps``, ``args.total_steps``
    and the peak rate ``args.lr``:

    * ``global_step < warmup_steps``: the rate ramps linearly from 0 to ``lr``.
    * after warmup: the rate decays linearly from ``lr`` to 0 at
      ``total_steps`` and stays at 0.0 afterwards.

    If ``optimizer`` is given, every param group's ``'lr'`` is set to the
    computed value.

    Returns:
        The computed learning rate.
    """
    step = args.global_step
    warmup = args.warmup_steps
    total = args.total_steps

    if step < warmup:
        lr_now = args.lr * float(step) / float(warmup)
    else:
        elapsed = step - warmup
        decay_span = total - warmup
        # Past the end of the schedule the rate is clamped to zero.
        if elapsed >= decay_span:
            lr_now = 0.0
        else:
            lr_now = args.lr * (1.0 - float(elapsed) / float(decay_span))

    if optimizer is not None:
        for group in optimizer.param_groups:
            group['lr'] = lr_now
    return lr_now

def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy (in percent) of ``output`` for each k in ``topk``.

    ``output`` holds one row of class scores per sample (classes on dim 1),
    and ``target`` holds the true class index per sample. The result is a
    list of scalar tensors, one per entry of ``topk``, in the same order.
    """
    with torch.no_grad():
        n_samples = target.size(0)
        top_idx = output.topk(max(topk), 1, True, True)[1].t()
        hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))
        return [hits[:k].reshape(-1).float().sum().mul_(100.0 / n_samples)
                for k in topk]


if __name__ == '__main__':
    # Run the script entry point only when executed directly, not on import.
    main()
