import copy

class FedAvgServer:
    """Federated-averaging server over a fixed set of clients.

    On :meth:`aggregate`, the server computes a weighted average of all client
    model parameters and loads that average back into every client's model.

    Each client must expose a ``model`` attribute with ``state_dict()`` and
    ``load_state_dict()`` methods (e.g. a ``torch.nn.Module``).

    :param clients: sequence of client objects to aggregate over.
    """

    def __init__(self, clients):
        self.client_num = len(clients)
        self.clients = clients

    def aggregate(self, weight_avg=None, device="cuda"):
        """Average all client models and broadcast the result to every client.

        :param weight_avg: one weight per client, in ``self.clients`` order
            (any sequence or 1-D array of numbers). ``None`` (default) uses
            uniform weights ``1 / client_num``. Weights are applied as given
            and are not renormalised.
        :param device: device the averaged tensors are placed on. The default
            ``"cuda"`` matches the previous hard-coded ``.cuda()`` behaviour.
            Pass ``"cpu"`` for CPU-only runs, or ``None`` to leave tensors where
            they are (all clients must then share a device).
        :return: the averaged state dict (parameter name -> tensor). Its keys
            come from the first client's ``state_dict()``.
        :raises ValueError: if there are no clients, or if ``weight_avg`` does
            not have exactly one entry per client.
        """
        if self.client_num == 0:
            raise ValueError("cannot aggregate: server has no clients")

        if weight_avg is None:
            weights = [1 / self.client_num] * self.client_num
        else:
            # Normalise to plain floats so NumPy arrays / tensors of weights
            # behave like a list.
            weights = [float(x) for x in weight_avg]
            if len(weights) != self.client_num:
                raise ValueError(
                    f"expected {self.client_num} weights, got {len(weights)}"
                )

        def _place(tensor):
            return tensor if device is None else tensor.to(device)

        # Snapshot every client's parameters before any client is overwritten.
        # The arithmetic below always allocates new tensors, so the inputs are
        # never mutated and no deep copies are needed.
        states = [client.model.state_dict() for client in self.clients]

        # A shallow copy keeps the first client's key order and any attached
        # attributes (e.g. state_dict ``_metadata``); its values get replaced.
        w_avg = copy.copy(states[0])
        for key in states[0]:
            acc = _place(states[0][key]) * weights[0]
            for state, weight in zip(states[1:], weights[1:]):
                acc = acc + _place(state[key]) * weight
            w_avg[key] = acc

        # load_state_dict copies values into each model's own tensors, so the
        # clients do not end up sharing storage with each other or with w_avg.
        for client in self.clients:
            client.model.load_state_dict(w_avg)

        return w_avg
