import torch 
import copy

class FedAvgMServer:
    """Server that averages client model weights and applies server momentum.

    Each call to :meth:`aggregate` computes the unweighted mean of all client
    ``state_dict`` entries, folds the change relative to the current global
    weights into a momentum buffer, updates the global weights, and pushes
    the result back to every client.

    Attributes:
        epoch: Counter initialised to -1; not modified by this class.
        client_num: Number of clients captured at construction time.
        clients: Client objects; each must expose ``model`` with
            ``state_dict()`` and ``load_state_dict()``.
        glob_dict: Mapping of parameter name -> tensor holding the global
            weights. It is updated in place by :meth:`aggregate`.
        graph_matrix: ``client_num x client_num`` tensor with a zero diagonal
            and ``1 / (client_num - 1)`` everywhere else.
        dw: Empty list; not used by the methods of this class.
        momentum_weight: Momentum buffer (parameter name -> tensor), or
            ``None`` until the first aggregation.
    """

    def __init__(self, clients, glob_dict):
        """Store the clients and the initial global weights.

        Args:
            clients: Sequence of client objects (see class docstring).
            glob_dict: Initial global weights, keyed like the clients'
                ``state_dict``.
        """
        self.epoch = -1
        self.client_num = len(clients)
        self.clients = clients
        self.glob_dict = glob_dict
        # Collaboration graph: uniform weight on every other client, none on self.
        self.graph_matrix = torch.ones(self.client_num, self.client_num) / (self.client_num - 1)
        self.graph_matrix[range(self.client_num), range(self.client_num)] = 0
        self.dw = []
        self.momentum_weight = None

    def weight_flatten(self, layer):
        """Return ``layer`` reshaped to a one-dimensional tensor."""
        return layer.reshape(-1)

    def _average_client_weights(self):
        """Return the unweighted mean of the clients' state dicts.

        The result is a deep copy of the first client's state dict whose
        entries have been replaced by the per-key mean, so no client tensor
        is modified here.

        Raises:
            IndexError: If there are no clients.
        """
        # Build each client's state dict once instead of once per key.
        client_states = [
            client.model.state_dict()
            for client in self.clients[:self.client_num]
        ]
        w_avg = copy.deepcopy(client_states[0])
        for key in w_avg:
            for state in client_states[1:]:
                w_avg[key] += state[key]
            w_avg[key] = torch.div(w_avg[key], self.client_num)
        return w_avg

    def aggregate(self, train_round, beta=0.9):
        """Aggregate client weights into the global model and broadcast it.

        On the initialising call the global weights are set to the client
        average and the momentum buffer is set to ``average - old_global``.
        On later calls the buffer is updated as
        ``(1 - beta) * (average - global) + beta * momentum`` and the global
        weights become ``global + momentum``.

        Args:
            train_round: Index of the current communication round. Round 0
                (or any call made before a momentum buffer exists)
                initialises the momentum buffer.
            beta: Momentum coefficient used after initialisation.
        """
        w_avg = self._average_client_weights()

        # Also initialise when no buffer exists yet, so a first call with a
        # non-zero round does not fail on the missing momentum.
        if train_round == 0 or self.momentum_weight is None:
            self.momentum_weight = copy.deepcopy(self.glob_dict)
            for key in w_avg:
                self.momentum_weight[key] = w_avg[key] - self.glob_dict[key]
                self.glob_dict[key] = w_avg[key]
        else:
            for key in w_avg:
                delta = w_avg[key] - self.glob_dict[key]
                self.momentum_weight[key] = (1 - beta) * delta + beta * self.momentum_weight[key]
                w_avg[key] = self.glob_dict[key] + self.momentum_weight[key]
                self.glob_dict[key] = w_avg[key]

        for client in self.clients:
            client.model.load_state_dict(w_avg)