import torch
import copy
import cvxpy as cp
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
import seaborn as sns
import os

class FedGraphServer:
    """Server for graph-based personalized federated aggregation.

    Each round the server re-learns a collaboration graph (``graph_matrix``).
    Row ``i`` holds the non-negative weights, summing to one, that client ``i``
    uses to average the parameters of all clients (itself included). Each row
    is the solution of a small quadratic program. That program trades off
    weights proportional to each client's sample count against the cosine
    similarity of the clients' parameter updates.

    Every entry of ``clients`` must expose ``model`` (an object with
    ``state_dict()`` / ``load_state_dict()``, e.g. a ``torch.nn.Module``) and
    ``sample_num`` (the client's number of training samples).
    """

    def __init__(self, clients, glob_dict):
        """
        Args:
            clients: sequence of client objects (see class docstring).
            glob_dict: reference state dict that client updates are measured
                against in :meth:`aggregate`; must have the same keys as the
                client models' state dicts.
        """
        self.client_num = len(clients)
        self.clients = clients
        self.glob_dict = glob_dict
        # Collaboration graph, initialised uniform over the *other* clients.
        # ``max(..., 1)`` avoids a 1/0 for a single client; that client's only
        # entry is on the diagonal and is zeroed below anyway.
        self.graph_matrix = torch.ones(self.client_num, self.client_num) / max(self.client_num - 1, 1)
        diag = torch.arange(self.client_num)
        self.graph_matrix[diag, diag] = 0
        # One list of per-client parameter deltas per round. Only the newest
        # entry is read in this class. NOTE(review): every entry holds a full
        # copy of every client's delta, so up to 10 rounds of them stay in memory.
        self.dw = deque(maxlen=10)
        # Snapshot of ``graph_matrix`` appended by every ``aggregate()`` call.
        self.collaboration_graph = []

    def plot_collaboration_matrix(self, matrix, round_num=None, save_path="collab_plots"):
        """Save an annotated heatmap of ``matrix`` as a PNG in ``save_path``.

        The file is ``collab_matrix_round_{round_num}.png``, or
        ``collab_matrix.png`` when ``round_num`` is None. ``save_path`` is
        created if it does not exist.
        """
        os.makedirs(save_path, exist_ok=True)

        fig = plt.figure(figsize=(8, 6))
        try:
            # detach/cpu so tensors that require grad or live on a GPU also plot.
            sns.heatmap(matrix.detach().cpu().numpy(), annot=True, cmap='viridis', fmt=".2f")
            title = "Collaboration Matrix"
            if round_num is not None:
                title += f" (Round {round_num})"
            plt.title(title)
            plt.xlabel("Client Index")
            plt.ylabel("Client Index")
            plt.tight_layout()

            filename = f"collab_matrix_round_{round_num}.png" if round_num is not None else "collab_matrix.png"
            plt.savefig(os.path.join(save_path, filename))
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)

    def weight_flatten(self, model):
        """Concatenate the flattened tensors of ``model`` whose key contains ``'fc'``.

        Args:
            model: mapping of name -> tensor (a state dict or a delta dict).

        Returns:
            1-D tensor of the matching entries, in the mapping's iteration order.
        """
        return torch.cat([model[k].reshape(-1) for k in model if 'fc' in k])

    def weight_flatten_all(self, model):
        """Concatenate every tensor of ``model``, flattened, into one 1-D tensor.

        Args:
            model: mapping of name -> tensor (a state dict or a delta dict).

        Returns:
            1-D tensor of all entries, in the mapping's iteration order.
        """
        return torch.cat([model[k].reshape(-1) for k in model])

    def cal_model_cosine_difference(self, ckpt, similarity_matric):
        """Return the pairwise negative cosine similarity of the clients' updates.

        Each client's update is ``client_state[key] - ckpt[key]`` for every key
        of the client's state dict. The list of updates is appended to
        ``self.dw``.

        Args:
            ckpt: reference state dict the updates are measured against.
            similarity_matric: ``"all"`` to compare all entries, or ``"fc"`` to
                compare only entries whose key contains ``"fc"``.

        Returns:
            ``(client_num, client_num)`` symmetric tensor. Entry ``[i, j]`` is
            ``-cos(update_i, update_j)``; values below -0.9 are snapped to -1.0.

        Raises:
            ValueError: if ``similarity_matric`` is not ``"all"`` or ``"fc"``.
        """
        flatteners = {"all": self.weight_flatten_all, "fc": self.weight_flatten}
        if similarity_matric not in flatteners:
            # Previously an unknown value silently produced an all-zero matrix.
            raise ValueError(f"similarity_matric must be 'all' or 'fc', got {similarity_matric!r}")
        flatten = flatteners[similarity_matric]

        # Update of every client relative to the reference checkpoint.
        current_dw = []
        for client in self.clients:
            model_i = client.model.state_dict()
            current_dw.append({key: model_i[key] - ckpt[key] for key in model_i.keys()})
        self.dw.append(current_dw)

        # Flatten each client's update once instead of once per pair.
        vectors = [flatten(delta).unsqueeze(0) for delta in current_dw]

        model_similarity_matrix = torch.zeros((self.client_num, self.client_num))
        for i in range(self.client_num):
            for j in range(i, self.client_num):
                diff = -torch.nn.functional.cosine_similarity(vectors[i], vectors[j]).item()
                # Updates with cosine similarity above 0.9 are treated as identical.
                if diff < -0.9:
                    diff = -1.0
                model_similarity_matrix[i, j] = diff
                model_similarity_matrix[j, i] = diff

        return model_similarity_matrix

    def update_graph_matrix_neighbor(self, ckpt, similarity_matric, lamba=0.8):
        """Re-solve every row of ``self.graph_matrix`` and return it.

        Row ``i`` is the solution ``x`` of::

            minimize    lamba * ||x||^2 + (d_i - 2 * lamba * p)^T x
            subject to  x >= 0,  sum(x) == 1

        Here ``d_i`` is row ``i`` of :meth:`cal_model_cosine_difference` and
        ``p`` holds each client's share of the total sample count. Up to a
        constant, the objective equals ``lamba * ||x - p||^2 + d_i^T x``.
        ``lamba`` therefore sets how strongly the weights are pulled towards
        ``p``, as opposed to towards clients whose updates point the same way.

        The matrix is updated in place and the same tensor object is returned.

        Args:
            ckpt: reference state dict passed to :meth:`cal_model_cosine_difference`.
            similarity_matric: ``"all"`` or ``"fc"`` (see that method).
            lamba: weight of the quadratic pull towards ``p``.

        Raises:
            ValueError: for an unknown ``similarity_matric``.
            RuntimeError: if the solver returns no solution for some row.
        """
        model_difference_matrix = self.cal_model_cosine_difference(ckpt, similarity_matric)

        total_data_points = sum(client.sample_num for client in self.clients)
        p = np.array([client.sample_num / total_data_points for client in self.clients])

        n = model_difference_matrix.shape[0]
        # psd_wrap declares the quadratic term PSD so cvxpy's quad_form skips its own check.
        P = cp.atoms.affine.wraps.psd_wrap(lamba * np.identity(n))
        G = -np.identity(n)   # G @ x <= 0  <=>  x >= 0
        h = np.zeros(n)
        A = np.ones((1, n))   # A @ x == 1  <=>  weights sum to one
        b = np.ones(1)

        for i in range(n):
            d = model_difference_matrix[i].numpy()
            q = d - 2 * lamba * p
            x = cp.Variable(n)
            prob = cp.Problem(cp.Minimize(cp.quad_form(x, P) + q.T @ x),
                              [G @ x <= h, A @ x == b])
            prob.solve()
            if x.value is None:
                # Without this check the failure surfaces as an opaque TypeError below.
                raise RuntimeError(
                    f"collaboration-graph QP for client {i} returned no solution "
                    f"(solver status: {prob.status})")
            self.graph_matrix[i, :] = torch.as_tensor(x.value, dtype=self.graph_matrix.dtype)

        return self.graph_matrix

    def aggregate(self, plot=True, round_num=None):
        """Run one round of graph-weighted personalized aggregation.

        Re-learns the collaboration graph against ``self.glob_dict`` and appends
        a snapshot of it to ``self.collaboration_graph``. It then optionally
        saves a heatmap of the graph. Finally it loads into client ``i`` the
        average of all clients' parameters, weighted by row ``i`` of the graph.
        ``num_batches_tracked`` entries are not averaged; each client keeps its
        own value.

        Args:
            plot: whether to save a heatmap of the new graph.
            round_num: round label used in the plot title and file name.
        """
        self.graph_matrix = self.update_graph_matrix_neighbor(self.glob_dict, similarity_matric='all')
        # The graph tensor is updated in place every round. Appending the
        # tensor itself would make every history entry show the latest graph.
        self.collaboration_graph.append(self.graph_matrix.clone())

        if plot:
            self.plot_collaboration_matrix(self.graph_matrix, round_num)

        # Read every client's state once. Nothing is loaded back until all new
        # states are computed, so every client mixes the same snapshot.
        client_states = [client.model.state_dict() for client in self.clients]

        new_states = []
        for cidx, own_state in enumerate(client_states):
            aggregation_weight_vector = self.graph_matrix[cidx]
            # Shallow copy: same keys, type and attributes (torch attaches
            # version metadata to state dicts); values are replaced below.
            mixed_state = copy.copy(own_state)
            for key, own_value in own_state.items():
                if 'num_batches_tracked' in key:
                    # Integer step counter: it cannot be averaged in place with
                    # float weights, and zeroing it would reset BatchNorm's
                    # cumulative averaging. Keep the client's own value.
                    continue
                acc = torch.zeros_like(own_value)
                for cidx1, net_para in enumerate(client_states):
                    acc += net_para[key] * aggregation_weight_vector[cidx1]
                mixed_state[key] = acc
            new_states.append(mixed_state)

        for client, state in zip(self.clients, new_states):
            client.model.load_state_dict(state)