import numpy as np
import torch
import torch.nn.functional as F
import os
import random
from torch.backends import cudnn
from random import sample
import math
import torch.optim as optim
import torch.nn as nn
from utils import load_data
import copy
from swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR
from utils import init_model

##############################################################################
# General server function
##############################################################################

def receive_client_models(args, client_nodes, select_list, size_weights):
    '''
    Collect the models of the selected clients and their aggregation weights.

    Args:
        args: run config; only ``args.server_method`` is read here.
        client_nodes: indexable collection of client nodes, each exposing ``.model``.
        select_list: indices of the clients taking part in this round.
        size_weights: per-client weights, indexed like ``client_nodes``.

    Returns:
        (agg_weights, client_params):
        - ``agg_weights`` holds the selected clients' ``size_weights``, rescaled to sum to 1.
        - ``client_params`` holds one entry per selected client. For 'fedawo' methods it is
          ``model.get_param(clone = True)``; otherwise it is a deep copy of ``model.state_dict()``.

    Raises:
        ZeroDivisionError: if the selected weights sum to zero (non-empty selection).
    '''
    # FedAWO variants work on whatever get_param returns; baselines use state dicts.
    use_get_param = 'fedawo' in args.server_method

    client_params = []
    for idx in select_list:
        model = client_nodes[idx].model
        if use_get_param:
            client_params.append(model.get_param(clone = True))
        else:
            client_params.append(copy.deepcopy(model.state_dict()))

    selected_weights = [size_weights[idx] for idx in select_list]
    # Compute the normalizer once rather than once per element.
    total = sum(selected_weights)
    agg_weights = [w / total for w in selected_weights]

    return agg_weights, client_params

def get_model_updates(client_params, prev_para):
    '''
    Express every client's parameters as a delta from the previous global ones.

    Args:
        client_params: iterable of objects supporting ``.sub`` (e.g. tensors).
        prev_para: the reference parameters; a deep copy is used, so the
            caller's object is never handed to ``.sub`` directly.

    Returns:
        list with ``client.sub(copy_of_prev_para)`` for each client, in order.
    '''
    baseline = copy.deepcopy(prev_para)
    return [client.sub(baseline) for client in client_params]

def get_client_params_with_serverlr(server_lr, prev_param, client_updates):
    '''
    Rebuild client parameters from updates scaled by a server learning rate.

    Each result is ``prev_param + update * server_lr``. The computation runs
    under ``torch.no_grad()``, so no autograd graph is recorded.

    Returns:
        list of new tensors, one per entry of ``client_updates``.
    '''
    with torch.no_grad():
        rescaled = [prev_param.add(delta * server_lr) for delta in client_updates]
    return rescaled

##############################################################################
# FedAWO function
##############################################################################

def fedawo_generate_global_model(gamma, optmized_weights, client_params, central_node):
    '''
    Fuse client parameters into the global model and load the result.

    The fused parameter is ``sum_i gamma * optmized_weights[i] * client_params[i]``,
    accumulated left to right. A detached deep copy of it is passed to
    ``central_node.model.load_param``.

    Returns:
        ``central_node`` (the same object, with its model updated).
    '''
    for idx, param in enumerate(client_params):
        term = gamma * optmized_weights[idx] * param
        if idx == 0:
            fused = term
        else:
            fused = fused.add(term)
    # detach so the loaded parameters carry no autograd history
    central_node.model.load_param(copy.deepcopy(fused.detach()))
    return central_node



# ## SWA: first gamma then lambda
# def FedAWO_optimization(args, size_weights, parameters, central_node):
#     '''
#     fedawo optimization functions for optimize both gamma and lambdas
#     '''
#     if args.dataset == 'cifar10':
#         server_lr = 0.01
#     else:
#         server_lr = 0.005

#     cohort_size = len(parameters)

#     if args.whether_swa == 'none':
#         # initialize gamma and lambdas
#         # the last element is gamma
#         if args.server_funct == 'exp':
#             optimizees = torch.tensor([torch.log(torch.tensor(j)) for j in size_weights] + [0.0], device='cuda', requires_grad=True)
#         elif args.server_funct == 'quad':
#             optimizees = torch.tensor([math.sqrt(1.0/cohort_size) for j in size_weights]+ [1.0], device='cuda', requires_grad=True)
#         optimizee_list = [optimizees]

#         if args.server_optimizer == 'adam':
#             optimizer = optim.Adam(optimizee_list, lr=server_lr, betas=(0.5, 0.999))
#         elif args.server_optimizer == 'sgd':
#             optimizer = optim.SGD(optimizee_list, lr=server_lr, momentum=0.9)
#         else:
#             raise ValueError('fusion optimizer is not defined!')

#         # set the scheduler
#         scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20,
#                                         gamma=0.5)

#         # clear grad
#         for i in range(len(optimizee_list)):
#             optimizee_list[i].grad = torch.zeros_like(optimizee_list[i])

#         # Train optimizees
#         softmax = nn.Softmax(dim=0)
#         # set the model as train to update the buffers for normalization layers
#         central_node.model.train()
#         for epoch in range(args.server_epochs): 
#             # the training data is the small dataset on the server
#             train_loader = central_node.validate_set 
#             for itr, (data, target) in enumerate(train_loader):
#                 for i in range(cohort_size):
#                     if i == 0:
#                         if args.server_funct == 'exp':
#                             model_param = torch.exp(optimizees[-1])*softmax(optimizees[:-1])[i]*parameters[i]
#                         elif args.server_funct == 'quad':
#                             model_param = optimizees[-1]*optimizees[-1]*optimizees[i]*optimizees[i]/sum(optimizees[:-1]*optimizees[:-1])*parameters[i]
#                     else:
#                         if args.server_funct == 'exp':
#                             model_param = model_param.add(torch.exp(optimizees[-1])*softmax(optimizees[:-1])[i]*parameters[i])
#                         elif args.server_funct == 'quad':
#                             model_param = model_param.add(optimizees[-1]*optimizees[-1]*optimizees[i]*optimizees[i]/sum(optimizees[:-1]*optimizees[:-1])*parameters[i])

#                 # train model
#                 data, target = load_data(args, central_node.cluster_id, data, target)
#                 data, target = data.cuda(), target.cuda()

#                 # Update optimizees
#                 # zero_grad
#                 optimizer.zero_grad()
#                 # update models according to the lr
#                 output = central_node.model.forward_with_param(data, model_param)
#                 loss =  F.cross_entropy(output, target)
#                 loss.backward()
#                 optimizer.step()
#             # scheduling
#             scheduler.step()
#         # record and print current lam
#         if args.server_funct == 'exp':
#             optmized_weights = [j for j in softmax(optimizees[:-1]).detach().cpu().numpy()]
#             learned_gamma = torch.exp(optimizees[-1])
#         elif args.server_funct == 'quad':
#             optmized_weights = [j*j/sum(optimizees[:-1]*optimizees[:-1]) for j in optimizees[:-1].detach().cpu().numpy()]
#             learned_gamma = optimizees[-1]*optimizees[-1]

#     elif args.whether_swa == 'swa':
#         # Two stage strategy: first, train gamma; second, train lambdas

#         ## Train gamma ##
#         # initialize fusion weights
#         optimizees = []
#         if args.server_funct == 'exp':
#             gamma = torch.tensor(0.0, device='cuda', requires_grad=True)
#         elif args.server_funct == 'quad':
#             gamma = torch.tensor(1.0, device='cuda', requires_grad=True)
#         optimizees.append(gamma)

#         # set the optimizer
#         if args.server_optimizer == 'adam':
#             optimizer = optim.Adam(optimizees, lr=server_lr, betas=(0.5, 0.999))
#         elif args.server_optimizer == 'sgd':
#             optimizer = optim.SGD(optimizees, lr=server_lr, momentum=0.9)
#         else:
#             raise ValueError('fusion optimizer is not defined!')
#         scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20,
#                                         gamma=0.5)

#         # clear grad
#         for i in range(len(optimizees)):
#             optimizees[i].grad = torch.zeros_like(optimizees[i])

#         # train optimizees
#         central_node.model.train()
#         for epoch in range(args.server_epochs//2): 
#             # the training data is the small dataset on the server
#             train_loader = central_node.validate_set 

#             for _, (data, target) in enumerate(train_loader):

#                 for i in range(cohort_size):
#                     if i == 0:
#                         if args.server_funct == 'exp':
#                             model_param = torch.exp(optimizees[-1])*size_weights[i]*parameters[i]
#                         elif args.server_funct == 'quad':
#                             model_param = optimizees[-1]*optimizees[-1]*size_weights[i]*parameters[i]
#                     else:
#                         if args.server_funct == 'exp':
#                             model_param = model_param.add(torch.exp(optimizees[-1])*size_weights[i]*parameters[i])
#                         elif args.server_funct == 'quad':
#                             model_param = model_param.add(optimizees[-1]*optimizees[-1]*size_weights[i]*parameters[i])

#                 # train model
#                 data, target = load_data(args, central_node.cluster_id, data, target)
#                 data, target = data.cuda(), target.cuda()

#                 # update optimizees
#                 # zero_grad
#                 optimizer.zero_grad()
#                 # update models according to the lr
#                 output = central_node.model.forward_with_param(data, model_param)
#                 loss =  F.cross_entropy(output, target)
#                 loss.backward()
#                 optimizer.step()
#             scheduler.step()
    
#         if args.server_funct == 'exp':
#             learned_gamma = copy.deepcopy(torch.exp(optimizees[-1]).detach())
#         elif args.server_funct == 'quad':
#             learned_gamma = copy.deepcopy(optimizees[-1]*optimizees[-1].detach())

#         ## optimize lambdas ##
#         # initialize fusion weights
#         optimizees = []
#         if args.server_funct == 'exp':
#             lam = torch.tensor([torch.log(torch.tensor(j)) for j in size_weights], device='cuda', requires_grad=True)
#         elif args.server_funct == 'quad':
#             lam = torch.tensor([math.sqrt(1.0/cohort_size) for j in size_weights], device='cuda', requires_grad=True)
#         optimizees.append(lam)

#         # set the optimizer
#         if args.server_optimizer == 'adam':
#             optimizer = optim.Adam(optimizees, lr=server_lr, betas=(0.5, 0.999))
#         elif args.server_optimizer == 'sgd':
#             optimizer = optim.SGD(optimizees, lr=server_lr, momentum=0.9)
#         else:
#             raise ValueError('fusion optimizer is not defined!')

#         # set the scheduler
#         scheduler = CosineAnnealingLR(optimizer, T_max=100)
#         swa_model = AveragedModel(lam)
#         swa_start = 5
#         swa_scheduler = SWALR(optimizer, swa_lr=0.05)

#         # clear grad
#         for i in range(len(optimizees)):
#             optimizees[i].grad = torch.zeros_like(optimizees[i])

#         # train optimizees
#         softmax = nn.Softmax(dim=0)
#         central_node.model.train()
#         for epoch in range(args.server_epochs//2): 
#             # the training data is the small dataset on the server
#             train_loader = central_node.validate_set 
#             for _, (data, target) in enumerate(train_loader):
#                 for i in range(cohort_size):
#                     if i == 0:
#                         if args.server_funct == 'exp':
#                             model_param = learned_gamma*softmax(lam)[i]*parameters[i]
#                         elif args.server_funct == 'quad':
#                             model_param = learned_gamma*lam[i]*lam[i]/sum(lam*lam)*parameters[i]
#                         # print(learned_gamma)
#                     else:
#                         if args.server_funct == 'exp':
#                             model_param = model_param.add(learned_gamma*softmax(lam)[i]*parameters[i])
#                         elif args.server_funct == 'quad':
#                             model_param = model_param.add(learned_gamma*lam[i]*lam[i]/sum(lam*lam)*parameters[i])
#                 # train model
#                 data, target = load_data(args, central_node.cluster_id, data, target)
#                 data, target = data.cuda(), target.cuda()

#                 # update optimizees
#                 # zero_grad
#                 optimizer.zero_grad()
#                 # update models according to the lr
#                 output = central_node.model.forward_with_param(data, model_param)
#                 loss =  F.cross_entropy(output, target)
#                 loss.backward()
#                 optimizer.step()
#             # scheduling
#             if epoch > swa_start:
#                 swa_model.update_parameters(lam)
#                 swa_scheduler.step()
#             else:
#                 scheduler.step()
#         if args.server_funct == 'exp':
#             optmized_weights = [j for j in softmax(swa_model.module).detach().cpu().numpy()]
#         elif args.server_funct == 'quad':
#             optmized_weights = [j*j/sum(swa_model.module*swa_model.module) for j in swa_model.module.detach().cpu().numpy()]

#     return learned_gamma, optmized_weights



## SWA: first lambdas then gamma
def FedAWO_optimization(args, size_weights, parameters, central_node):
    '''
    Learn FedAWO fusion weights (lambdas) and a global scale (gamma) on the
    server proxy data ``central_node.validate_set``.

    The fused parameter fed to ``central_node.model.forward_with_param`` is
    ``gamma * sum_i w_i * parameters[i]``.

    ``args.server_funct`` selects the parameterization:
      * 'exp':  ``w = softmax(lam)`` and ``gamma = exp(g)``. Initial lam is
        ``log(size_weights)`` and initial g is 0.
      * 'quad': ``w = lam**2 / sum(lam**2)`` and ``gamma = g**2``. Initial lam
        is uniform and initial g is 1.

    ``args.whether_swa`` selects the schedule:
      * 'none': lambdas and g are optimized jointly for ``args.server_epochs``
        epochs with a StepLR schedule.
      * 'swa': two stages of ``args.server_epochs // 2`` epochs each.
        Stage 1 trains the lambdas alone (no gamma factor) and averages them
        with SWA. Stage 2 trains g with the averaged weights frozen.

    Args:
        args: run config. Reads ``dataset``, ``server_funct``, ``whether_swa``,
            ``server_optimizer`` and ``server_epochs``, and is forwarded to
            ``load_data``.
        size_weights: per-client initial weights. For 'exp' their log is taken,
            so they must be positive (presumably normalized client sizes;
            TODO confirm against caller).
        parameters: one parameter object per selected client, combined
            linearly and passed to ``forward_with_param``.
        central_node: server node exposing ``model``, ``validate_set`` and
            ``cluster_id``.

    Returns:
        (learned_gamma, optmized_weights):
        - ``learned_gamma`` is a detached 0-d tensor.
        - ``optmized_weights`` is a list of numpy floats, one per entry of
          ``size_weights``, summing to 1.

    Raises:
        ValueError: unknown ``server_funct``, ``whether_swa`` or
            ``server_optimizer``, or an empty ``parameters`` list.
    '''
    if args.server_funct not in ('exp', 'quad'):
        raise ValueError('fusion function is not defined!')
    if args.whether_swa not in ('none', 'swa'):
        raise ValueError('swa strategy is not defined!')
    cohort_size = len(parameters)
    if cohort_size == 0:
        raise ValueError('FedAWO_optimization needs at least one client model!')

    server_lr = 0.01 if args.dataset == 'cifar10' else 0.005
    use_exp = args.server_funct == 'exp'
    softmax = nn.Softmax(dim=0)

    def normalize(lam):
        # map raw lambdas to non-negative fusion weights that sum to 1
        if use_exp:
            return softmax(lam)
        squared = lam * lam
        return squared / squared.sum()

    def scale(g):
        # map the raw gamma variable to a non-negative scaling factor
        return torch.exp(g) if use_exp else g * g

    if use_exp:
        init_lam = [torch.log(torch.tensor(j)) for j in size_weights]
        init_g = 0.0
    else:
        init_lam = [math.sqrt(1.0/cohort_size) for _ in size_weights]
        init_g = 1.0

    if args.whether_swa == 'none':
        # a single vector: the lambdas followed by g (the last element)
        optimizees = torch.tensor(init_lam + [init_g], device='cuda', requires_grad=True)
        optimizer = _fedawo_make_optimizer(args, [optimizees], server_lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
        optimizees.grad = torch.zeros_like(optimizees)

        def joint_param():
            # weights are rebuilt every batch so gradients reach the optimizees
            coeffs = scale(optimizees[-1]) * normalize(optimizees[:-1])
            return _fedawo_weighted_sum(coeffs, parameters)

        _fedawo_train(args, central_node, optimizer, joint_param,
                      args.server_epochs, lambda _epoch: scheduler.step())

        # detach: callers get plain values, not a live autograd graph
        optmized_weights = list(normalize(optimizees[:-1]).detach().cpu().numpy())
        learned_gamma = scale(optimizees[-1]).detach()
        return learned_gamma, optmized_weights

    # 'swa' -- stage 1: learn the lambdas (no gamma factor), averaged with SWA
    lam = torch.tensor(init_lam, device='cuda', requires_grad=True)
    lam_optimizer = _fedawo_make_optimizer(args, [lam], server_lr)
    lam_scheduler = CosineAnnealingLR(lam_optimizer, T_max=100)
    swa_model = AveragedModel(lam)
    swa_start = 5
    swa_scheduler = SWALR(lam_optimizer, swa_lr=0.05)
    lam.grad = torch.zeros_like(lam)

    def lam_epoch_end(epoch):
        # only epochs after swa_start contribute to the average
        if epoch > swa_start:
            swa_model.update_parameters(lam)
            swa_scheduler.step()
        else:
            lam_scheduler.step()

    _fedawo_train(args, central_node, lam_optimizer,
                  lambda: _fedawo_weighted_sum(normalize(lam), parameters),
                  args.server_epochs // 2, lam_epoch_end)
    optmized_weights = list(normalize(swa_model.module).detach().cpu().numpy())

    # stage 2: learn gamma with the averaged fusion weights frozen
    g = torch.tensor(init_g, device='cuda', requires_grad=True)
    gamma_optimizer = _fedawo_make_optimizer(args, [g], server_lr)
    gamma_scheduler = optim.lr_scheduler.StepLR(gamma_optimizer, step_size=20, gamma=0.5)
    g.grad = torch.zeros_like(g)

    def gamma_param():
        factor = scale(g)
        return _fedawo_weighted_sum([factor * w for w in optmized_weights], parameters)

    _fedawo_train(args, central_node, gamma_optimizer, gamma_param,
                  args.server_epochs // 2, lambda _epoch: gamma_scheduler.step())

    learned_gamma = scale(g).detach().clone()
    return learned_gamma, optmized_weights


def _fedawo_make_optimizer(args, optimizees, lr):
    '''Build the fusion-variable optimizer chosen by ``args.server_optimizer``.

    Raises:
        ValueError: if ``args.server_optimizer`` is neither 'adam' nor 'sgd'.
    '''
    if args.server_optimizer == 'adam':
        return optim.Adam(optimizees, lr=lr, betas=(0.5, 0.999))
    if args.server_optimizer == 'sgd':
        return optim.SGD(optimizees, lr=lr, momentum=0.9)
    raise ValueError('fusion optimizer is not defined!')


def _fedawo_weighted_sum(coeffs, parameters):
    '''Return ``sum_i coeffs[i] * parameters[i]``, accumulated left to right.

    Differentiable w.r.t. ``coeffs``. Returns None if ``parameters`` is empty.
    '''
    total = None
    for i, param in enumerate(parameters):
        term = coeffs[i] * param
        total = term if total is None else total.add(term)
    return total


def _fedawo_train(args, central_node, optimizer, build_model_param, epochs, end_of_epoch):
    '''Run ``epochs`` passes over the server proxy data, stepping ``optimizer``.

    ``build_model_param()`` is called once per batch. It must return the fused
    parameter used by ``central_node.model.forward_with_param``.
    ``end_of_epoch(epoch)`` is called after every pass (LR scheduling / SWA).
    '''
    # train mode so normalization-layer buffers are updated
    central_node.model.train()
    for epoch in range(epochs):
        # the training data is the small dataset on the server
        for data, target in central_node.validate_set:
            model_param = build_model_param()
            data, target = load_data(args, central_node.cluster_id, data, target)
            data, target = data.cuda(), target.cuda()

            optimizer.zero_grad()
            output = central_node.model.forward_with_param(data, model_param)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
        end_of_epoch(epoch)


def optimize_normalized_aw_quad(size_weights, args, parameters, central_node):
    '''
    Learn FedAWO fusion weights (lambdas) with the quadratic parameterization.

    Client ``i`` is weighted by ``lam[i]**2 / sum(lam**2)``. ``lam`` starts
    uniform (``sqrt(1/len(parameters))`` per entry) and is trained for
    ``args.server_epochs`` epochs on ``central_node.validate_set``. The loss is
    the cross-entropy of ``central_node.model.forward_with_param`` with the
    weighted sum of ``parameters``.

    With ``args.whether_swa == 'swa'``, the result is taken from the
    SWA-averaged lambdas, which are updated in epochs after ``swa_start``.

    Returns:
        list of numpy floats, one per entry of ``size_weights``, summing to 1.

    Raises:
        ValueError: unknown ``args.server_optimizer`` or ``args.whether_swa``.
    '''
    if args.whether_swa not in ('none', 'swa'):
        raise ValueError('swa strategy is not defined!')
    use_swa = args.whether_swa == 'swa'

    cohort_size = len(parameters)

    # initialize fusion weights uniformly
    lam = torch.tensor([math.sqrt(1.0/cohort_size) for _ in size_weights], device='cuda', requires_grad=True)
    optimizees = [lam]

    # set the optimizer
    if args.server_optimizer == 'adam':
        optimizer = optim.Adam(optimizees, lr=0.01, betas=(0.5, 0.999))
    elif args.server_optimizer == 'sgd':
        optimizer = optim.SGD(optimizees, lr=0.01, momentum=0.9, weight_decay=5e-4)
    else:
        raise ValueError('fusion optimizer is not defined!')

    # set the scheduler
    if use_swa:
        scheduler = CosineAnnealingLR(optimizer, T_max=100)
        swa_model = AveragedModel(lam)
        swa_start = 5
        swa_scheduler = SWALR(optimizer, swa_lr=0.05)
    else:
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20,
                                        gamma=0.5)

    # clear grad
    lam.grad = torch.zeros_like(lam)

    central_node.model.train()
    for epoch in range(args.server_epochs):
        # the training data is the small dataset on the server
        for data, target in central_node.validate_set:
            # normalize once per batch (not once per client); rebuilt each
            # batch so gradients flow back to lam
            squared = lam * lam
            weights = squared / squared.sum()
            for i, param in enumerate(parameters):
                term = weights[i] * param
                model_param = term if i == 0 else model_param.add(term)

            data, target = load_data(args, central_node.cluster_id, data, target)
            data, target = data.cuda(), target.cuda()

            optimizer.zero_grad()
            output = central_node.model.forward_with_param(data, model_param)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()

        # scheduling
        if use_swa and epoch > swa_start:
            swa_model.update_parameters(lam)
            swa_scheduler.step()
        else:
            scheduler.step()

    # Normalize the final lambdas by their OWN squared norm. In SWA mode that
    # is the averaged vector, so the returned weights always sum to 1.
    final_lam = swa_model.module if use_swa else lam
    squared = (final_lam * final_lam).detach()
    fus_weights = list((squared / squared.sum()).cpu().numpy())

    return fus_weights


def optimize_normalized_aw(size_weights, args, parameters, central_node):
    '''
    Learn FedAWO fusion weights (lambdas) with the softmax parameterization.

    Client ``i`` is weighted by ``softmax(lam)[i]``, with ``lam`` initialized
    to ``log(size_weights)``, so every size weight must be positive. ``lam`` is
    trained for ``args.server_epochs`` epochs on ``central_node.validate_set``.
    The loss is the cross-entropy of ``central_node.model.forward_with_param``
    with the weighted sum of ``parameters``.

    With ``args.whether_swa == 'swa'``, the result is taken from the
    SWA-averaged lambdas, which are updated in epochs after ``swa_start``.

    Returns:
        list of numpy floats, one per entry of ``size_weights``, summing to 1.

    Raises:
        ValueError: unknown ``args.server_optimizer`` or ``args.whether_swa``.
    '''
    if args.whether_swa not in ('none', 'swa'):
        raise ValueError('swa strategy is not defined!')
    use_swa = args.whether_swa == 'swa'

    # initialize fusion weights from the size weights (log space)
    lam = torch.tensor([torch.log(torch.tensor(j)) for j in size_weights], device='cuda', requires_grad=True)
    optimizees = [lam]

    # set the optimizer
    if args.server_optimizer == 'adam':
        optimizer = optim.Adam(optimizees, lr=0.01, betas=(0.5, 0.999))
    elif args.server_optimizer == 'sgd':
        optimizer = optim.SGD(optimizees, lr=0.01, momentum=0.9, weight_decay=5e-4)
    else:
        raise ValueError('fusion optimizer is not defined!')

    # set the scheduler
    if use_swa:
        scheduler = CosineAnnealingLR(optimizer, T_max=100)
        swa_model = AveragedModel(lam)
        swa_start = 5
        swa_scheduler = SWALR(optimizer, swa_lr=0.05)
    else:
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20,
                                        gamma=0.5)

    # clear grad
    lam.grad = torch.zeros_like(lam)

    softmax = nn.Softmax(dim=0)
    central_node.model.train()
    for epoch in range(args.server_epochs):
        # the training data is the small dataset on the server
        for data, target in central_node.validate_set:
            # one softmax per batch instead of one per client
            weights = softmax(lam)
            for i, param in enumerate(parameters):
                term = weights[i] * param
                model_param = term if i == 0 else model_param.add(term)

            data, target = load_data(args, central_node.cluster_id, data, target)
            data, target = data.cuda(), target.cuda()

            optimizer.zero_grad()
            output = central_node.model.forward_with_param(data, model_param)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()

        # scheduling
        if use_swa and epoch > swa_start:
            swa_model.update_parameters(lam)
            swa_scheduler.step()
        else:
            scheduler.step()

    final_lam = swa_model.module if use_swa else lam
    fus_weights = list(softmax(final_lam).detach().cpu().numpy())

    return fus_weights


def optimize_gamma(size_weights, args, parameters, central_node):
    '''
    Learn a global scaling factor ``gamma = exp(g)`` for fixed fusion weights.

    The fused parameter is ``sum_i gamma * size_weights[i] * parameters[i]``.
    g starts at 0 (gamma = 1) and is trained for ``args.server_epochs`` epochs
    on ``central_node.validate_set`` with a StepLR schedule.

    Returns:
        detached 0-d tensor ``exp(g)``. It is detached so downstream fusion of
        full parameter vectors does not build an autograd graph.

    Raises:
        ValueError: unknown ``args.server_optimizer`` or empty ``parameters``.
    '''
    if len(parameters) == 0:
        raise ValueError('optimize_gamma needs at least one client model!')

    # g is the log of gamma
    g = torch.tensor(0.0, device='cuda', requires_grad=True)
    optimizees = [g]

    # set the optimizer
    if args.server_optimizer == 'adam':
        optimizer = optim.Adam(optimizees, lr=0.01, betas=(0.5, 0.999))
    elif args.server_optimizer == 'sgd':
        optimizer = optim.SGD(optimizees, lr=0.01, momentum=0.9, weight_decay=5e-4)
    else:
        raise ValueError('fusion optimizer is not defined!')
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20,
                                    gamma=0.5)

    # clear grad
    g.grad = torch.zeros_like(g)

    central_node.model.train()
    for _ in range(args.server_epochs):
        # the training data is the small dataset on the server
        for data, target in central_node.validate_set:
            # exp(g) once per batch instead of once per client
            factor = torch.exp(g)
            for i, param in enumerate(parameters):
                term = factor * size_weights[i] * param
                model_param = term if i == 0 else model_param.add(term)

            data, target = load_data(args, central_node.cluster_id, data, target)
            data, target = data.cuda(), target.cuda()

            optimizer.zero_grad()
            output = central_node.model.forward_with_param(data, model_param)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
        scheduler.step()

    return torch.exp(g).detach()


def optimize_serverlr(size_weights, args, client_updates, prev_para, central_node):
    '''
    Learn a server learning rate ``exp(g)`` for the client updates, using the proxy data.

    The candidate model is ``prev_para + sum_i exp(g) * size_weights[i] * client_updates[i]``.
    g starts at 0 (rate 1) and is trained for ``args.server_epochs`` epochs on
    ``central_node.validate_set`` with a StepLR schedule.

    Returns:
        detached 0-d tensor ``exp(g)`` (no autograd graph attached).

    Raises:
        ValueError: unknown ``args.server_optimizer``.
    '''
    # work on a copy so the caller's prev_para is never part of the graph
    prev_param = copy.deepcopy(prev_para)

    # g is the log of the server learning rate
    g = torch.tensor(0.0, device='cuda', requires_grad=True)
    optimizees = [g]

    # set the optimizer
    if args.server_optimizer == 'adam':
        optimizer = optim.Adam(optimizees, lr=0.01, betas=(0.5, 0.999))
    elif args.server_optimizer == 'sgd':
        optimizer = optim.SGD(optimizees, lr=0.01, momentum=0.9, weight_decay=5e-4)
    else:
        raise ValueError('fusion optimizer is not defined!')
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20,
                                    gamma=0.5)

    # clear grad
    g.grad = torch.zeros_like(g)

    central_node.model.train()
    for _ in range(args.server_epochs):
        # the training data is the small dataset on the server
        for data, target in central_node.validate_set:
            # exp(g) once per batch instead of once per client
            factor = torch.exp(g)
            model_param = prev_param
            for i, update in enumerate(client_updates):
                model_param = model_param.add(factor * size_weights[i] * update)

            data, target = load_data(args, central_node.cluster_id, data, target)
            data, target = data.cuda(), target.cuda()

            optimizer.zero_grad()
            output = central_node.model.forward_with_param(data, model_param)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
        scheduler.step()

    return torch.exp(g).detach()


##############################################################################
# Baselines function (FedAvg, FedDF, FedBE, Finetune, etc.)
##############################################################################

def Server_update(args, central_node, client_nodes, select_list, size_weights):
    '''
    Run one server-side aggregation round for the baseline methods.

    The selected client models are collected via ``receive_client_models``
    (which renormalises ``size_weights`` over ``select_list``) and combined
    according to ``args.server_method``:

    - 'fedavg'   : weighted parameter average.
    - 'feddf'    : weighted average, then ensemble distillation (``feddf``).
    - 'fedbe'    : weighted average, then distillation from an ensemble that
                   includes a SWAG sample anchored at the pre-round weights
                   (``fedbe``).
    - 'finetune' : weighted average, then fine-tuning on the server data.
    - 'feddyn'   : FedDyn aggregation (``feddyn``).
    - 'fedadam'  : Adam-style server step towards the average (``fedadam``).

    Raises ValueError for any other method name. Returns ``central_node``;
    note that 'feddyn' and 'fedadam' replace ``central_node.model`` with a new
    module object.
    '''
    agg_weights, client_params = receive_client_models(args, client_nodes, select_list, size_weights)
    method = args.server_method

    # FedDyn works on the client modules directly and needs no FedAvg average.
    if method == 'feddyn':
        return feddyn(args, central_node, agg_weights, client_nodes, select_list)

    if method not in ('fedavg', 'feddf', 'fedbe', 'finetune', 'fedadam'):
        raise ValueError('Undefined server method...')

    # FedBE anchors its SWAG sampling at the weights from *before* this round,
    # so they must be snapshotted prior to overwriting the global model.
    prev_global_param = None
    if method == 'fedbe':
        prev_global_param = copy.deepcopy(central_node.model.state_dict())

    avg_global_param = fedavg(client_params, agg_weights)

    # FedAdam consumes the average as a pseudo-gradient target instead of
    # loading it directly.
    if method == 'fedadam':
        return fedadam(args, central_node, avg_global_param)

    central_node.model.load_state_dict(avg_global_param)

    if method == 'feddf':
        return feddf(args, central_node, client_nodes, select_list)
    if method == 'fedbe':
        return fedbe(args, prev_global_param, central_node, client_nodes, select_list)
    if method == 'finetune':
        return server_finetune(args, central_node)
    return central_node

def server_finetune(args, central_node):
    '''
    Fine-tune the global model on the server's small dataset.

    Trains ``central_node.model`` with cross-entropy on
    ``central_node.validate_set`` for ``args.server_epochs`` epochs, stepping
    ``central_node.optimizer`` once per batch. Each batch goes through
    ``load_data`` and is moved to CUDA. The model is left in train mode.

    Returns ``central_node`` (its model is updated in place).
    '''
    model = central_node.model
    optimizer = central_node.optimizer
    model.train()

    epoch = 0
    while epoch < args.server_epochs:
        # the training data is the small dataset held by the server
        for batch_data, batch_target in central_node.validate_set:
            optimizer.zero_grad()

            batch_data, batch_target = load_data(args, central_node.cluster_id, batch_data, batch_target)
            batch_data = batch_data.cuda()
            batch_target = batch_target.cuda()

            logits = model(batch_data)
            F.cross_entropy(logits, batch_target).backward()
            optimizer.step()
        epoch += 1

    return central_node

def fedavg(parameters, list_nums_local_data):
    '''
    Weighted average of a list of state dicts.

    Args:
        parameters: sequence of state dicts; the keys of ``parameters[0]``
            are averaged, so every dict must contain them.
        list_nums_local_data: one weight per state dict (sample counts or
            already-normalised weights); the result is divided by their sum.

    Returns:
        A deep copy of ``parameters[0]`` (same mapping type) in which every
        entry is replaced by ``sum_i w_i * p_i[key] / sum_i w_i``. The inputs
        are not modified.
    '''
    total_weight = sum(list_nums_local_data)
    averaged = copy.deepcopy(parameters[0])
    for key in parameters[0]:
        weighted_terms = [
            local[key] * weight
            for local, weight in zip(parameters, list_nums_local_data)
        ]
        averaged[key] = sum(weighted_terms) / total_weight
    return averaged

# FedDF
def divergence(student_logits, teacher_logits):
    '''
    Forward KL divergence KL(teacher || student) between class distributions.

    Both inputs are raw logits whose class dimension is dim 1. They are turned
    into (log-)probabilities with softmax, and the sum over classes is
    averaged over the batch (``reduction="batchmean"``).
    '''
    student_log_probs = F.log_softmax(student_logits, dim=1)
    teacher_probs = F.softmax(teacher_logits, dim=1)
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")

def feddf(args, central_node, client_nodes, select_list):
    '''
    FedDF server update: ensemble distillation into the global model.

    In ``Server_update`` the global model has already been set to the FedAvg
    average before this is called. The selected client models act as frozen
    teachers. Their logits are averaged, and ``central_node.model`` is
    trained to minimise the forward KL ``divergence`` to that average.
    Training uses ``central_node.optimizer`` on ``central_node.validate_set``
    for ``args.server_epochs`` epochs.

    Side effects: the global model and all selected client models are moved
    to CUDA; the client models are left in eval mode.

    Returns ``central_node`` (its model is updated in place).
    '''
    central_node.model.cuda().train()
    nets = []
    for client_idx in select_list:
        client_nodes[client_idx].model.cuda().eval()
        nets.append(client_nodes[client_idx].model)

    for _ in range(args.server_epochs):
        # the training data is the small dataset held by the server
        train_loader = central_node.validate_set

        for data, target in train_loader:
            central_node.optimizer.zero_grad()
            data, target = load_data(args, central_node.cluster_id, data, target)
            data, target = data.cuda(), target.cuda()

            output = central_node.model(data)
            # Teachers are frozen: evaluate them without autograd so no graph
            # (and no activation memory) is built for their forward passes.
            with torch.no_grad():
                teacher_logits = sum(net(data) for net in nets) / len(nets)

            loss = divergence(output, teacher_logits)
            loss.backward()
            central_node.optimizer.step()

    return central_node

# FedBE
class SWAG_server(torch.nn.Module):
    '''
    Build FedBE-style teacher weights from a set of client state dicts.

    Args:
        base_model: state dict that client deltas are measured from (in
            ``fedbe`` this is the global weights from before the round).
        avg_model: stored as ``self.avg_model``; not used by this class.
        max_num_models: stored as ``self.max_num_models``; not used by this class.
        var_clamp: lower bound applied to the per-element variance.
        concentrate_num: number of Gaussian samples averaged in "gaussian" mode.

    Every tensor returned by ``construct_models`` lives on the CPU.
    '''

    def __init__(self, base_model, avg_model=None, max_num_models=25, var_clamp=1e-5, concentrate_num=1):
        # nn.Module subclasses must run the base initialiser before assigning
        # attributes (it creates _parameters/_buffers/_modules/training).
        super().__init__()
        self.base_model = base_model
        self.max_num_models = max_num_models
        self.var_clamp = var_clamp
        self.concentrate_num = concentrate_num
        self.avg_model = avg_model

    def compute_var(self, mean, sq_mean):
        '''Element-wise ``E[x^2] - E[x]^2``, clamped from below at ``var_clamp``.'''
        return {k: torch.clamp(sq_mean[k] - mean[k] ** 2, self.var_clamp) for k in mean.keys()}

    def compute_mean_sq(self, teachers):
        '''
        Per-tensor statistics of the normalised teacher deltas.

        For each key of ``teachers[0]`` except ``*batches_tracked``, the delta
        ``teacher[k] - base_model[k]`` (on CPU) is divided by its L2 norm.

        Returns ``(w_avg, w_sq_avg, w_norm)``: the mean and mean-square of
        these unit deltas, and the mean norm. A zero delta (teacher identical
        to the base for that tensor) is kept as zero rather than divided by 0,
        which would produce NaN weights.
        '''
        w_avg = {}
        w_sq_avg = {}
        w_norm = {}
        num_teachers = len(teachers)

        for k in teachers[0].keys():
            if "batches_tracked" in k:
                continue
            avg = torch.zeros(teachers[0][k].size())
            sq_avg = torch.zeros(teachers[0][k].size())
            norm_sum = 0.0
            base = self.base_model[k].cpu()

            for teacher in teachers:
                grad = teacher[k].cpu() - base
                norm = torch.norm(grad, p=2)
                if norm > 0:
                    grad = grad / norm
                avg += grad
                sq_avg += grad ** 2
                norm_sum += norm

            w_avg[k] = torch.div(avg, num_teachers)
            w_sq_avg[k] = torch.div(sq_avg, num_teachers)
            w_norm[k] = torch.div(norm_sum, num_teachers)

        return w_avg, w_sq_avg, w_norm

    def construct_models(self, teachers, mean=None, mode="dir"):
        '''
        Combine the teacher state dicts into a single state dict (on CPU).

        Modes:
            "gaussian": sample unit deltas around their mean with std
                ``0.1 * sqrt(var)``, average ``concentrate_num`` samples,
                rescale by the mean delta norm and add ``base_model``.
                ``*batches_tracked`` keys are omitted from the result.
            "random": plain average of 3 teachers drawn without replacement
                (requires at least 3 teachers).
            "dir": average weighted by a Dirichlet(1, ..., 1) draw.

        ``mean`` is accepted for backward compatibility and is unused.
        Raises ValueError for an unknown ``mode``.
        '''
        if mode == "gaussian":
            w_avg, w_sq_avg, w_norm = self.compute_mean_sq(teachers)
            w_var = self.compute_var(w_avg, w_sq_avg)

            mean_grad = copy.deepcopy(w_avg)
            for i in range(self.concentrate_num):
                for k in w_avg.keys():
                    mu = w_avg[k]
                    var = torch.clamp(w_var[k], 1e-6)
                    eps = torch.randn_like(mu)
                    sample_grad = mu + torch.sqrt(var) * eps * 0.1
                    # running average of the samples drawn so far
                    mean_grad[k] = (i * mean_grad[k] + sample_grad) / (i + 1)

            for k in w_avg.keys():
                mean_grad[k] = mean_grad[k] * 1.0 * w_norm[k] + self.base_model[k].cpu()

            return mean_grad

        if mode == "random":
            num_t = 3
            # sample indices (same draw as sampling the list itself) so numpy
            # never has to build an array out of the state dicts
            chosen = np.random.choice(len(teachers), num_t, replace=False)
            ts = [teachers[j] for j in chosen]
            mean_grad = {}
            for k in ts[0].keys():
                mean_grad[k] = torch.zeros(ts[0][k].size())
                for t in ts:
                    # teachers may live on the GPU; accumulate on the CPU
                    mean_grad[k] += t[k].cpu()

            for k in ts[0].keys():
                mean_grad[k] /= num_t

            return mean_grad

        if mode == "dir":
            proportions = np.random.dirichlet(np.repeat(1.0, len(teachers)))
            mean_grad = {}
            for k in teachers[0].keys():
                mean_grad[k] = torch.zeros(teachers[0][k].size())
                for i, t in enumerate(teachers):
                    # teachers may live on the GPU; accumulate on the CPU
                    mean_grad[k] += t[k].cpu() * proportions[i]

            for k in teachers[0].keys():
                mean_grad[k] /= sum(proportions)

            return mean_grad

        raise ValueError(f"unknown SWAG construction mode: {mode!r}")


def fedbe(args, prev_global_param, central_node, client_nodes, select_list):
    '''
    FedBE server update: distill a teacher ensemble into the global model.

    The ensemble contains:
      * the current global weights (in ``Server_update`` this is the FedAvg
        result loaded just before the call),
      * a copy of every selected client model (in eval mode),
      * one model whose weights come from
        ``SWAG_server(prev_global_param).construct_models(..., mode='gaussian')``
        applied to the selected clients' state dicts.
    Teacher logits are averaged, and ``central_node.model`` is trained with
    ``central_node.optimizer`` to minimise the forward KL ``divergence`` to
    that average. Training runs on ``central_node.validate_set`` for
    ``args.server_epochs`` epochs.

    Args:
        args: run config; reads ``local_model`` and ``server_epochs`` (plus
            whatever ``init_model`` / ``load_data`` read).
        prev_global_param: state dict that the SWAG deltas are measured from.
        central_node: server node whose model is trained in place.
        client_nodes, select_list: all client nodes and the selected indices.

    Returns ``central_node``.
    '''
    nets = []
    base_teachers = []

    # Both teacher models are built up front, in the original order, so any
    # RNG consumed by init_model is drawn before the SWAG sampling below.
    fedavg_model = init_model(args.local_model, args).cuda()
    swag_model = init_model(args.local_model, args).cuda()
    # load_state_dict copies the values, so later training of the global
    # model does not affect this teacher and no deepcopy is needed.
    fedavg_model.load_state_dict(central_node.model.state_dict())
    nets.append(fedavg_model)

    for client_idx in select_list:
        client_nodes[client_idx].model.cuda().eval()
        nets.append(copy.deepcopy(client_nodes[client_idx].model))
        base_teachers.append(copy.deepcopy(client_nodes[client_idx].model.state_dict()))

    # generate the SWAG teacher
    swag_server = SWAG_server(prev_global_param, avg_model=copy.deepcopy(central_node.model.state_dict()), concentrate_num=1)
    w_swag = swag_server.construct_models(base_teachers, mode='gaussian')
    swag_model.load_state_dict(w_swag)
    nets.append(swag_model)

    # NOTE(review): unlike the client teachers, fedavg_model and swag_model
    # are never switched to eval mode. Presumably init_model returns a module
    # in training mode, in which case their BatchNorm layers (if any) use
    # batch statistics here. TODO confirm this is intended.

    # train and update
    central_node.model.cuda().train()
    for _ in range(args.server_epochs):
        train_loader = central_node.validate_set

        for data, target in train_loader:
            central_node.optimizer.zero_grad()
            data, target = load_data(args, central_node.cluster_id, data, target)
            data, target = data.cuda(), target.cuda()

            output = central_node.model(data)
            # Teachers are frozen: no autograd graph for their forward passes.
            with torch.no_grad():
                teacher_logits = sum(net(data) for net in nets) / len(nets)

            loss = divergence(output, teacher_logits)
            loss.backward()
            central_node.optimizer.step()

    return central_node

def feddyn(args, central_node, agg_weights, client_nodes, select_list):
    '''
    Server update for FedDyn.

    With ``w_k = agg_weights[k]``, ``theta_k`` the k-th selected client model,
    ``theta`` the current global model, ``h`` = ``central_node.server_state``
    and ``mu`` = ``args.mu``:

        delta     = sum_k w_k * (theta_k - theta)
        h        <- h - mu * delta                      (in place)
        theta_new = sum_k w_k * theta_k - h / mu

    Only ``parameters()`` are aggregated. Buffers of the new global model
    (e.g. BatchNorm running statistics, if any) are copied from the first
    selected client.

    ``central_node.model`` is replaced by a new module object. NOTE(review):
    an optimizer built over the previous model's parameters would not track
    the new one. The client models are only read.

    Returns ``central_node``.
    '''
    # Client models are only read, so no per-client deep copy is needed.
    uploaded_models = [client_nodes[i].model for i in select_list]

    # Pure tensor bookkeeping: disable autograd so mixing Parameters
    # (requires_grad=True) into these updates does not build graphs.
    with torch.no_grad():
        # 1) weighted delta of the clients relative to the current global model
        model_delta = [torch.zeros_like(p) for p in uploaded_models[0].parameters()]
        for idx, client_model in enumerate(uploaded_models):
            for server_param, client_param, delta in zip(central_node.model.parameters(), client_model.parameters(), model_delta):
                delta += (client_param - server_param) * agg_weights[idx]

        # 2) server state: h <- h - mu * delta
        for state_param, delta in zip(central_node.server_state.parameters(), model_delta):
            state_param -= args.mu * delta

        # 3) aggregation: weighted average, corrected by the server state
        new_model = copy.deepcopy(uploaded_models[0])
        for param in new_model.parameters():
            param.zero_()

        for idx, client_model in enumerate(uploaded_models):
            for server_param, client_param in zip(new_model.parameters(), client_model.parameters()):
                server_param += client_param * agg_weights[idx]

        for server_param, state_param in zip(new_model.parameters(), central_node.server_state.parameters()):
            server_param -= (1 / args.mu) * state_param

    central_node.model = new_model
    return central_node

def fedadam(args, central_node, avg_global_param, beta1=0.9, beta2=0.99, tau=1e-5):
    '''
    FedAdam server update.

    The FedAvg average is treated as a pseudo-gradient ``delta = avg - w_t``
    and applied with Adam-style moments stored in ``central_node.m`` and
    ``central_node.v``. These modules are zipped with the model's
    parameters, so they are expected to share its parameter layout.

        m <- beta1 * m + (1 - beta1) * delta
        v <- beta2 * v + (1 - beta2) * delta**2
        w <- w_t + lr_g * m / (sqrt(v) + tau)

    ``lr_g`` is ``float(args.server_lr)`` (0.01 was the value previously noted
    as the suggested CIFAR-10 setting). ``beta1``, ``beta2`` and ``tau``
    default to the values this function used to hard-code.

    Only ``parameters()`` receive the Adam step. Buffers of the new model
    (e.g. BatchNorm statistics, if any) are taken from ``avg_global_param``
    unchanged. ``central_node.m`` / ``central_node.v`` are updated in place,
    and ``central_node.model`` is replaced by a new module object; the old
    model is not modified.

    Returns ``central_node``.
    '''
    lr_g = float(args.server_lr)

    # New model carrying the FedAvg weights (buffers included).
    new_model = copy.deepcopy(central_node.model)
    new_model.load_state_dict(avg_global_param)

    # w_t (the current global model) is only read, so it is used directly.
    for w_param, w_t_param, m_param, v_param in zip(new_model.parameters(), central_node.model.parameters(), central_node.m.parameters(), central_node.v.parameters()):
        delta = w_param.data - w_t_param.data
        m_param.data = beta1 * m_param.data + (1 - beta1) * delta
        v_param.data = beta2 * v_param.data + (1 - beta2) * delta.pow(2)
        w_param.data = w_t_param.data + lr_g * m_param.data / (torch.sqrt(v_param.data) + tau)

    central_node.model = new_model
    return central_node