import time
import copy
import torch
import numpy as np
import os
import seaborn as sns
import matplotlib.pyplot as plt


def extract_bn_stats(model):
    """Collect per-module ``global_mean`` / ``global_var`` statistics from a model.

    Every entry yielded by ``model.named_modules()`` that exposes both a
    ``global_mean`` and a ``global_var`` attribute contributes one record;
    modules lacking either attribute are ignored.

    Args:
        model: Object providing ``named_modules()`` (e.g. a ``torch.nn.Module``).

    Returns:
        dict: Maps the module name to ``{'mean': tensor, 'var': tensor}``, where
        each tensor is a detached, flattened copy of the module's statistic, so
        later changes to the model do not affect the returned values.
    """

    def _snapshot(tensor):
        # Copy first so the result never aliases the module's own storage.
        return tensor.clone().detach().flatten()

    return {
        layer_name: {
            'mean': _snapshot(layer.global_mean),
            'var': _snapshot(layer.global_var),
        }
        for layer_name, layer in model.named_modules()
        if hasattr(layer, 'global_mean') and hasattr(layer, 'global_var')
    }


class FedBNStatServer:
    """Mix client models with weights derived from their per-layer statistics.

    Each round (see :meth:`aggregate`) the server:

    1. gathers the ``global_mean`` / ``global_var`` statistics of every client
       model (presumably from batch-norm style layers, given the naming),
    2. turns pairwise distances between those statistics into a row-wise
       collaboration (weight) matrix, and
    3. replaces every client's model state with the weighted combination of
       all clients' states given by that client's row of the matrix.

    Attributes:
        clients: Sequence of client objects; each must expose a ``model``
            attribute supporting ``state_dict()`` / ``load_state_dict()``.
        client_num: Number of clients.
        graph_matrix: Current ``(client_num, client_num)`` weight matrix; row
            ``i`` holds the weights client ``i`` applies to every client.
        collaboration_graph: History of the weight matrices, one per call to
            :meth:`aggregate`.
    """

    def __init__(self, clients, glob_dict):
        """Store the clients and start from a uniform, self-excluding graph.

        Args:
            clients: Sequence of client objects exposing ``.model``.
            glob_dict: Accepted for interface compatibility; not used here.
        """
        self.clients = clients
        self.client_num = len(clients)
        # Uniform weight over the *other* clients and zero on the diagonal.
        # max(..., 1) avoids a transient division by zero for a single client;
        # the resulting matrix is unchanged because the diagonal is zeroed.
        peer_count = max(self.client_num - 1, 1)
        self.graph_matrix = torch.ones(self.client_num, self.client_num) / peer_count
        self.graph_matrix[range(self.client_num), range(self.client_num)] = 0
        self.collaboration_graph = []

    def extract_bn_stats_all(self):
        """Return the ``extract_bn_stats`` result for each client, in client order."""
        return [extract_bn_stats(client.model) for client in self.clients]

    def compute_bn_distance_matrix(self, client_stats_list):
        """Compute the pairwise distance between clients' layer statistics.

        For every pair of clients, and every layer name present in both, the
        layer distance is the mean of the L2 distance between the ``'mean'``
        tensors and the L2 distance between the ``'var'`` tensors. The pair's
        distance is the average of the layer distances over the shared layers
        (``0`` when the two clients share no layer).

        Args:
            client_stats_list: One dict per client, mapping a layer name to
                ``{'mean': tensor, 'var': tensor}``. Must contain at least
                ``client_num`` entries.

        Returns:
            torch.Tensor: Symmetric ``(client_num, client_num)`` matrix with a
            zero diagonal.
        """
        dist_matrix = torch.zeros((self.client_num, self.client_num))
        for i in range(self.client_num):
            stats_i = client_stats_list[i]
            # The distance is symmetric, so compute only the upper triangle
            # and mirror it; the diagonal keeps its initial zero.
            for j in range(i + 1, self.client_num):
                stats_j = client_stats_list[j]
                dist = 0.0
                layer_count = 0
                for layer, entry_i in stats_i.items():
                    if layer not in stats_j:
                        continue
                    entry_j = stats_j[layer]
                    mean_dist = torch.norm(entry_i['mean'] - entry_j['mean'], p=2)
                    var_dist = torch.norm(entry_i['var'] - entry_j['var'], p=2)
                    dist += (mean_dist + var_dist) / 2
                    layer_count += 1
                # max(..., 1) keeps the no-shared-layer case at 0 instead of 0/0.
                pair_dist = dist / max(layer_count, 1)
                dist_matrix[i, j] = pair_dist
                dist_matrix[j, i] = pair_dist
        return dist_matrix

    def compute_collaboration_matrix(self, dist_matrix, temperature=0.3, eps=1e-8):
        """Convert a distance matrix into row-wise aggregation weights.

        Distances are min-max normalised over the whole matrix, mapped to
        similarities with ``exp(-d / temperature)``, and each row is rescaled
        so that it sums to approximately 1: the diagonal receives
        ``1 / (1 + sum of peer similarities)`` and the remaining mass is split
        among the peers proportionally to their similarity.

        Args:
            dist_matrix: Square ``(client_num, client_num)`` distance tensor.
            temperature: Softness of the similarity kernel; smaller values
                concentrate weight on the closest peers.
            eps: Small constant guarding the divisions against zero.

        Returns:
            torch.Tensor: Weight matrix with the same shape as ``dist_matrix``.
        """
        # Normalize distances to [0, 1]
        dist_min = dist_matrix.min()
        dist_max = dist_matrix.max()
        normalized_dist = (dist_matrix - dist_min) / (dist_max - dist_min + eps)

        graph_matrix = torch.zeros_like(dist_matrix)

        for i in range(self.client_num):
            similarity = torch.exp(-normalized_dist[i] / temperature)
            similarity[i] = 0  # Exclude self-similarity

            total_similarity = similarity.sum()
            # More similar peers => less weight kept for the client itself.
            self_weight = 1 / (1 + total_similarity + eps)

            similarity = (1 - self_weight) * similarity / (total_similarity + eps)
            similarity[i] = self_weight

            graph_matrix[i] = similarity

        return graph_matrix

    def plot_collaboration_matrix(self, matrix, round_num=None, save_path="collab_plots"):
        """Save an annotated heatmap of ``matrix`` as a PNG file.

        Args:
            matrix: 2-D tensor to draw.
            round_num: Optional round index; when given it is appended to the
                title and used in the file name.
            save_path: Directory for the image; created if it does not exist.
        """
        os.makedirs(save_path, exist_ok=True)
        fig = plt.figure(figsize=(8, 6))
        try:
            # detach()/cpu() make the conversion safe for tensors that track
            # gradients or live on an accelerator; both are no-ops otherwise.
            sns.heatmap(matrix.detach().cpu().numpy(), annot=True, cmap='viridis', fmt=".2f")
            title = "BN-Based Collaboration Matrix"
            if round_num is not None:
                title += f" (Round {round_num})"
            plt.title(title)
            plt.xlabel("Client Index")
            plt.ylabel("Client Index")
            plt.tight_layout()
            filename = f"collab_matrix_round_{round_num}.png" if round_num is not None else "collab_matrix.png"
            filepath = os.path.join(save_path, filename)
            plt.savefig(filepath)
        finally:
            # Always release the figure, even if drawing or saving failed.
            plt.close(fig)

    def aggregate(self, round_num=None, plot=True):
        """Recompute the collaboration graph and mix all client models with it.

        Updates ``self.graph_matrix``, appends it to
        ``self.collaboration_graph`` and loads the mixed state into every
        client's model in place.

        Args:
            round_num: Optional round index forwarded to the plot.
            plot: When true, print the distance matrix and save a heatmap of
                the new collaboration matrix.
        """
        client_bn_stats = self.extract_bn_stats_all()
        dist_matrix = self.compute_bn_distance_matrix(client_bn_stats)
        self.graph_matrix = self.compute_collaboration_matrix(dist_matrix)
        self.collaboration_graph.append(self.graph_matrix)

        if plot:
            print(dist_matrix)
            self.plot_collaboration_matrix(self.graph_matrix, round_num)

        self._apply_collaboration_weights()

    def _apply_collaboration_weights(self):
        """Load into each client the ``graph_matrix``-weighted mix of all client states."""
        # Snapshot once per client instead of once per (client, peer) pair.
        client_states = [client.model.state_dict() for client in self.clients]

        # Build every mixed state before loading any of them: state_dict()
        # tensors alias the live parameters, so loading early would leak an
        # already-mixed model into the mixes computed afterwards.
        mixed_states = []
        for cidx, own_state in enumerate(client_states):
            weights = self.graph_matrix[cidx]
            # Start from the client's own state so entries that are not
            # averaged keep their current value rather than being reset to
            # zero, and the state dict's metadata is preserved.
            mixed_state = copy.deepcopy(own_state)
            for key, own_value in own_state.items():
                averaged = (
                    'num_batches_tracked' not in key
                    and (own_value.is_floating_point() or own_value.is_complex())
                )
                if not averaged:
                    # Counters and other integer/bool buffers cannot hold a
                    # weighted average; each client keeps its own.
                    continue
                accumulator = torch.zeros_like(own_value)
                for peer_idx, peer_state in enumerate(client_states):
                    accumulator += peer_state[key] * weights[peer_idx]
                mixed_state[key] = accumulator
            mixed_states.append(mixed_state)

        for client, mixed_state in zip(self.clients, mixed_states):
            client.model.load_state_dict(mixed_state)