import torch
import copy

class FedGradSimServer:
    """Server that builds personalized client models from gradient similarity.

    Each round, ``aggregate`` compares the clients' ``grad_direction`` tensors,
    turns the pairwise cosine similarities into a row-stochastic collaboration
    graph, and replaces every client's model with a graph-weighted average of
    all client models.

    Clients are accessed as ``clients[i]`` for ``i in range(len(clients))`` and
    must expose ``grad_direction`` (a tensor) and ``model`` (an object with
    ``state_dict()`` / ``load_state_dict()``).
    """

    def __init__(self, clients, glob_dict):
        """Store the clients and the global state dict.

        Args:
            clients: Indexable collection of client objects.
            glob_dict: Global state dict; stored as-is and not used by the
                methods of this class.
        """
        self.client_num = len(clients)
        self.clients = clients
        self.glob_dict = glob_dict
        # One graph matrix is appended per call to aggregate().
        self.collaboration_graph = []

    def compute_cosine_similarity_matrix(self):
        """Return the pairwise cosine similarity of the clients' gradients.

        Each ``self.clients[i].grad_direction`` is flattened and L2-normalized
        before the dot products are taken, so the result does not depend on
        gradient magnitude. An all-zero gradient yields similarity 0 with
        every client (including itself) instead of NaN.

        Returns:
            A ``(client_num, client_num)`` CPU tensor in the default float
            dtype; empty ``(0, 0)`` when there are no clients.
        """
        if self.client_num == 0:
            return torch.zeros((0, 0))

        # Index (rather than iterate) so any int-indexable container works.
        directions = torch.stack(
            [
                self.clients[i].grad_direction.detach().reshape(-1)
                for i in range(self.client_num)
            ]
        ).to(torch.get_default_dtype())

        # eps keeps zero-norm gradients from producing a division by zero.
        unit = torch.nn.functional.normalize(directions, p=2, dim=1, eps=1e-12)
        return (unit @ unit.T).cpu()

    def calculate_graph_matrix(self, sim_matrix, self_weight=0.3):
        """Convert a similarity matrix into a row-normalized graph matrix.

        For every row ``i`` the diagonal entry is fixed to ``self_weight`` and
        the remaining ``1 - self_weight`` is distributed over the other
        clients by a softmax of their similarity scores, so each row sums
        to 1.

        Args:
            sim_matrix: Square ``(n, n)`` floating-point similarity matrix.
            self_weight: Weight each client keeps for its own model.

        Returns:
            An ``(n, n)`` tensor like ``sim_matrix``. With a single client
            the only weight is 1 (there are no peers to share mass with).
        """
        n = sim_matrix.shape[0]
        if n == 0:
            return torch.zeros_like(sim_matrix)
        if n == 1:
            # No peers: the client must keep its whole model, otherwise
            # repeated aggregation would scale the parameters towards zero.
            return torch.ones_like(sim_matrix)

        diagonal = torch.eye(n, dtype=torch.bool, device=sim_matrix.device)
        # -inf (not 0) removes the self entry from the softmax: exp(0) == 1
        # would still claim probability mass and break the row sum.
        peer_scores = sim_matrix.masked_fill(diagonal, float("-inf"))
        peer_weights = torch.softmax(peer_scores, dim=1) * (1 - self_weight)
        return peer_weights.masked_fill(diagonal, self_weight)

    def aggregate(self):
        """Replace every client's model with its personalized aggregate.

        Computes the similarity and graph matrices, records the graph in
        ``self.collaboration_graph``, and loads into each client ``i`` the sum
        of all clients' floating-point parameters/buffers weighted by row
        ``i`` of the graph. Non-floating-point entries (for example
        BatchNorm's ``num_batches_tracked``) cannot be averaged meaningfully
        and are kept from the client's own model.
        """
        sim_matrix = self.compute_cosine_similarity_matrix()
        graph_matrix = self.calculate_graph_matrix(sim_matrix)
        self.collaboration_graph.append(graph_matrix)

        if self.client_num == 0:
            return

        # Snapshot every state dict once; the aggregates are built in fresh
        # tensors, so loading one client cannot affect another's aggregate.
        client_states = [
            self.clients[i].model.state_dict() for i in range(self.client_num)
        ]
        # Plain Python floats avoid device/dtype coupling with the models.
        weight_rows = graph_matrix.tolist()

        aggregated_states = []
        for cidx in range(self.client_num):
            own_state = client_states[cidx]
            weight_vector = weight_rows[cidx]
            new_state = {}
            for key, reference in client_states[0].items():
                if not torch.is_floating_point(reference):
                    # Counters/integer buffers: keep this client's own value.
                    new_state[key] = own_state[key].clone()
                    continue
                accumulator = torch.zeros_like(reference)
                for j, state_j in enumerate(client_states):
                    accumulator.add_(
                        state_j[key].to(accumulator.device),
                        alpha=weight_vector[j],
                    )
                new_state[key] = accumulator
            aggregated_states.append(new_state)

        # Same output as before: the last client's row of the graph.
        print(graph_matrix[self.client_num - 1])

        # Load new models
        for cidx in range(self.client_num):
            self.clients[cidx].model.load_state_dict(aggregated_states[cidx])
