import time
import copy
import torch
import numpy as np
import cvxpy as cp
from collections import deque
import matplotlib.pyplot as plt
import seaborn as sns
import os




class FedAMPServer:
    """Server that builds a client collaboration graph and performs personalized aggregation.

    Each call to :meth:`aggregate` measures pairwise L2 distances between the
    clients' models, turns them into a row-stochastic weight matrix
    (``self.graph_matrix``), and loads into client ``i`` the weighted sum of all
    clients' state dicts using row ``i`` of that matrix.

    Args:
        clients: sequence of client objects; each must expose a ``model``
            attribute supporting ``state_dict()`` / ``load_state_dict()``.
        glob_dict: reference state dict that client updates are measured against.
    """

    # State-dict entries whose key contains this substring are batch counters
    # (e.g. BatchNorm's num_batches_tracked), not learnable weights: they are
    # neither averaged nor used when measuring model distances.
    _COUNTER_KEY = 'num_batches_tracked'

    def __init__(self, clients, glob_dict):
        self.client_num = len(clients)
        self.clients = clients
        self.glob_dict = glob_dict
        # Initial collaboration graph: uniform weight over the other clients and
        # zero on the diagonal (aggregate() recomputes it before using it).
        # max(..., 1) avoids an inf intermediate for a single client; the
        # resulting matrix is [[0.]] either way.
        self.graph_matrix = torch.ones(self.client_num, self.client_num) / max(self.client_num - 1, 1)
        self.graph_matrix[range(self.client_num), range(self.client_num)] = 0
        self.dw = deque(maxlen=10)  # Keep only last 10 rounds of differences
        self.collaboration_graph = []  # graph_matrix of every aggregate() call

    def weight_flatten(self, model):
        """Concatenate every entry of ``model`` whose key contains 'fc' into one 1-D tensor.

        ``model`` is a mapping of name -> tensor (e.g. a state dict).
        ``torch.cat`` raises if no key matches.
        """
        params = [model[k].reshape(-1) for k in model if 'fc' in k]
        return torch.cat(params)

    def weight_flatten_all(self, model):
        """Concatenate every entry of ``model`` into one 1-D tensor (in key order).

        ``torch.cat`` raises if ``model`` is empty.
        """
        params = [model[k].reshape(-1) for k in model]
        return torch.cat(params)

    def _flatten_weights(self, state):
        """Flatten all non-counter entries of ``state`` into one 1-D tensor (empty if none)."""
        parts = [value.reshape(-1) for key, value in state.items() if self._COUNTER_KEY not in key]
        if not parts:
            return torch.zeros(0)
        return torch.cat(parts)

    def cal_model_cosine_difference(self, ckpt):
        """Return the pairwise L2 distance matrix between client model updates.

        Despite the name, no cosine similarity is computed. Entry ``[i, j]`` is
        ``||(w_i - ckpt) - (w_j - ckpt)||_2``, which equals ``||w_i - w_j||_2``.
        It is taken over all state-dict entries except batch counters. The
        diagonal is 0.

        This round's per-client differences ``w_i - ckpt`` (all keys) are
        appended to ``self.dw``.

        Args:
            ckpt: reference state dict with the same keys as the client models.

        Returns:
            A symmetric ``(client_num, client_num)`` CPU tensor of the default dtype.
        """
        n = self.client_num
        model_similarity_matrix = torch.zeros((n, n))

        current_dw = []
        for client in self.clients:
            state = client.model.state_dict()
            current_dw.append({key: state[key] - ckpt[key] for key in state})
        self.dw.append(current_dw)

        # Flatten each update once instead of once per pair. Counters are
        # skipped: they count training steps, and their differences can
        # dominate the weight distances.
        flat = [self._flatten_weights(diff) for diff in current_dw]

        for i in range(n):
            for j in range(i + 1, n):
                distance = torch.norm(flat[i] - flat[j], p=2)
                model_similarity_matrix[i, j] = distance
                model_similarity_matrix[j, i] = distance

        return model_similarity_matrix

    def update_graph_matrix_neighbor(self, ckpt):
        """Compute client distances relative to ``ckpt`` and return the resulting weight matrix."""
        model_difference_matrix = self.cal_model_cosine_difference(ckpt)
        graph_matrix = self.calculate_graph_matrix(model_difference_matrix)
        return graph_matrix

    def calculate_graph_matrix(self, model_difference_matrix, self_weight=0.3):
        """Convert a pairwise distance matrix into row-stochastic aggregation weights.

        Row ``i`` puts ``self_weight`` on client ``i`` itself. The remaining
        ``1 - self_weight`` is spread over the other clients in proportion to
        ``exp(-distance)``: a softmax over negated distances with the diagonal
        excluded. Every row therefore sums to 1.

        With a single client there are no neighbours, so it keeps weight 1.

        Args:
            model_difference_matrix: square ``(n, n)`` tensor of distances.
            self_weight: weight each client gives its own model.

        Returns:
            An ``(n, n)`` CPU tensor of the default dtype.
        """
        n = model_difference_matrix.shape[0]
        if n == 0:
            return torch.zeros((0, 0))
        if n == 1:
            # No neighbours: keep the whole model rather than shrinking it.
            return torch.ones((1, 1))

        # softmax subtracts the row maximum before exponentiating. Large
        # distances therefore no longer underflow exp(-d) to 0, which used to
        # make the normaliser 0 and the row NaN.
        logits = -model_difference_matrix.detach().cpu().to(torch.float64)
        logits.fill_diagonal_(float('-inf'))  # exclude self from the neighbour softmax
        neighbour_weight = torch.softmax(logits, dim=1)

        graph_matrix = (1 - self_weight) * neighbour_weight
        graph_matrix.fill_diagonal_(self_weight)
        return graph_matrix.to(torch.get_default_dtype())

    def plot_collaboration_matrix(self, matrix, round_num=None, save_path="collab_plots"):
        """Save an annotated heatmap of ``matrix`` as a PNG in ``save_path``.

        The file is named ``collab_matrix_round_{round_num}.png`` when
        ``round_num`` is given, otherwise ``collab_matrix.png``. The directory is
        created if it does not exist.
        """
        os.makedirs(save_path, exist_ok=True)

        title = "Collaboration Matrix"
        if round_num is not None:
            title += f" (Round {round_num})"
        filename = f"collab_matrix_round_{round_num}.png" if round_num is not None else "collab_matrix.png"

        plt.figure(figsize=(8, 6))
        try:
            sns.heatmap(matrix.detach().cpu().numpy(), annot=True, cmap='viridis', fmt=".2f")
            plt.title(title)
            plt.xlabel("Client Index")
            plt.ylabel("Client Index")
            plt.tight_layout()
            plt.savefig(os.path.join(save_path, filename))
        finally:
            plt.close()  # Release the figure even if plotting or saving fails

    def aggregate(self, plot=True, round_num=None):
        """Run one round of graph-weighted personalized aggregation.

        The steps are:

        1. Recompute ``self.graph_matrix`` from distances relative to
           ``self.glob_dict``.
        2. Append it to ``self.collaboration_graph``.
        3. Optionally save a heatmap of it.
        4. Load into client ``i`` the weighted sum
           ``sum_j graph_matrix[i, j] * w_j`` of all clients' state dicts.

        Batch counters (``num_batches_tracked``) are not averaged; each client
        keeps its own value.

        Args:
            plot: if True, save the collaboration-matrix heatmap.
            round_num: round label used in the plot title and filename.
        """
        self.graph_matrix = self.update_graph_matrix_neighbor(self.glob_dict)
        self.collaboration_graph.append(self.graph_matrix)

        if plot:
            self.plot_collaboration_matrix(self.graph_matrix, round_num)

        # Every new state is built before any is loaded, so each one is computed
        # from all clients' pre-aggregation weights.
        client_states = [client.model.state_dict() for client in self.clients]

        new_states = []
        for own_state, weights in zip(client_states, self.graph_matrix):
            new_state = {}
            for key, own_value in own_state.items():
                if self._COUNTER_KEY in key:
                    # Carry the client's own counter over; leaving it out of the
                    # loaded state would reset it to zero.
                    new_state[key] = own_value.clone()
                    continue
                accumulated = torch.zeros_like(own_value)
                for weight, state in zip(weights, client_states):
                    accumulated += state[key] * weight
                new_state[key] = accumulated
            new_states.append(new_state)

        for client, new_state in zip(self.clients, new_states):
            client.model.load_state_dict(new_state)
