import copy
import csv
import logging
import random
import sys
from scipy.stats import entropy
import torch
from torch.distributions import Normal
from collections import defaultdict
from pathlib import Path
from tqdm import trange
import numpy as np
from Models.CNNs import *
from Models.resnets import *
from Models.Googlenet import *
from Models.Mobilenet import *
from node import BaseNodes
from test import eval_acc

def train(data_name: str, data_path: str, classes_per_node: int, num_nodes: int, fraction: float,
          steps: int, epochs: int, optim: str, lr: float, inner_lr: float,
          embed_lr: float, wd: float, inner_wd: float, embed_dim: int, hyper_hid: int,
          n_hidden: int, n_kernels: int, bs: int, device, eval_every: int, save_path: Path,
          seed: int, total_classes: int) -> None:
    """Run federated training with layer-wise knowledge-gain-entropy (KGE) exchange.

    Every round performs three phases:

    1. Local training: a random sample of clients trains its own CNN variant
       with SGD and records the two layers whose weights changed the most.
    2. Aggregation: a freshly built ``CNN_1`` global model is fitted, layer by
       layer, to reproduce the uploaded client layers' outputs on Gaussian
       noise input.
    3. Downlink: for each participating client, global layers whose weight
       entropy exceeds that of the uploaded client layer (KGE > 0) are
       candidates; the best ones overwrite the stored client parameters.

    Side effects: ``print`` output produced while training is captured in
    ``kge_layer_hete_client{num_nodes}_{fraction}_result_{data_name}.out`` in the
    working directory, and one CSV row per round
    (``mean_global_loss, mean_global_acc, mean_trained_loss, mean_trained_acc``
    followed by every client's latest accuracy) is written to a file with the
    same stem inside ``save_path``.

    Args:
        data_name: One of ``'cifar10'``, ``'cifar100'``, ``'mnist'``, ``'tinyimagenet'``.
        data_path, classes_per_node, bs: Forwarded to ``BaseNodes``.
        num_nodes: Number of clients.
        fraction: Share of clients sampled each round (``int(fraction * num_nodes)``).
        steps: Number of communication rounds.
        epochs: Local epochs per selected client per round.
        lr, wd: Learning rate and weight decay of the SGD optimizers.
        n_kernels: Forwarded to the CNN constructors.
        device: Torch device the models and batches are moved to.
        save_path: Directory for the CSV results (created if missing).
        total_classes: Output dimension of the global model.
        optim, inner_lr, embed_lr, inner_wd, embed_dim, hyper_hid, n_hidden,
        eval_every, seed: Accepted for interface compatibility; not read here.

    Raises:
        ValueError: If ``data_name`` is not supported, or if a layer's weight
            cannot be reshaped between client and global model.
    """
    import contextlib  # local import: only needed for the scoped stdout redirect below

    ###############################
    # init nodes and local nets   #
    ###############################
    nodes = BaseNodes(data_name, data_path, num_nodes, classes_per_node=classes_per_node,
                      batch_size=bs)

    print(data_name)
    # Output dimension per dataset. 'mnist' used to build a single unused `net`
    # and then crash with NameError on `model_1`; it now gets the same five CNNs.
    # NOTE(review): confirm the CNN_* models accept MNIST-shaped input.
    out_dims = {"cifar10": 10, "cifar100": 100, "mnist": 10, "tinyimagenet": 200}
    if data_name not in out_dims:
        raise ValueError("choose data_name from ['cifar10', 'cifar100', 'mnist', 'tinyimagenet']")
    out_dim = out_dims[data_name]

    # Heterogeneous setting: client i uses architecture i % len(net_set).
    # Only the architectures actually assigned to clients are instantiated; the
    # ResNet/GoogleNet/MobileNet variants were built before but never used.
    net_set = [arch(n_kernels=n_kernels, out_dim=out_dim).to(device)
               for arch in (CNN_1, CNN_2, CNN_3, CNN_4, CNN_5)]
    num_archs = len(net_set)
    criteria = torch.nn.CrossEntropyLoss()
    step_iter = trange(steps)

    # Latest accuracy and parameters of every client, keyed by client id.
    client_acc = {i: 0 for i in range(num_nodes)}
    client_para = {i: copy.deepcopy(net_set[i % num_archs].state_dict()) for i in range(num_nodes)}

    result_stem = f'kge_layer_hete_client{num_nodes}_{fraction}_result_{data_name}'
    save_path = Path(save_path)
    save_path.mkdir(parents=True, exist_ok=True)

    # redirect_stdout restores sys.stdout on exit, so prints after train()
    # returns no longer hit a closed file.
    with open(f'{result_stem}.out', 'w', encoding='utf-8') as f, \
            contextlib.redirect_stdout(f), \
            open(str(save_path / f"{result_stem}.csv"), 'w', newline='') as file:
        mywriter = csv.writer(file, delimiter=',')
        for step in step_iter:  # one step == one communication round
            select_nodes = random.sample(range(num_nodes), int(fraction * num_nodes))
            all_local_trained_loss = []
            all_local_trained_acc = []
            # Global-model evaluation is not implemented; the (empty) lists are
            # kept so the CSV column layout stays unchanged.
            all_global_loss = []
            all_global_acc = []

            logging.info(f'#----Round:{step}----#')

            # ====== Local training + uplink ======
            client_updates = {}
            for node_id in select_nodes:
                print(f'client id: {node_id}')

                net = net_set[node_id % num_archs]  # client model structure
                net.load_state_dict(client_para[node_id])  # download state
                optimizer = torch.optim.SGD(params=net.parameters(), lr=lr, momentum=0.9, weight_decay=wd)
                pre_training_state = copy.deepcopy(net.state_dict())

                net.train()
                for _ in range(epochs):
                    for batch in nodes.train_loaders[node_id]:
                        img, label = tuple(t.to(device) for t in batch)
                        optimizer.zero_grad()
                        pred, rep = net(img)
                        loss = criteria(pred, label)
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(net.parameters(), 50)
                        optimizer.step()

                post_training_state = copy.deepcopy(net.state_dict())
                # Upload only the two layers with the largest change,
                # e.g. {'fc3': {'weight': tensor, 'bias': (name, tensor)}}.
                client_updates[node_id] = compute_largest_weight_changes(pre_training_state, post_training_state)
                # Independent copy, so the downlink can edit it without touching the upload.
                client_para[node_id] = copy.deepcopy(post_training_state)

                # evaluate trained local model
                trained_loss, trained_acc = eval_acc(net, nodes.test_loaders[node_id], criteria, device)
                all_local_trained_loss.append(trained_loss.cpu().item())
                all_local_trained_acc.append(trained_acc)
                client_acc[node_id] = trained_acc
                logging.info(f'Round {step} | client {node_id} acc: {client_acc[node_id]}')

            mean_trained_loss = _rounded_mean(all_local_trained_loss)
            mean_trained_acc = _rounded_mean(all_local_trained_acc)
            mean_global_loss = _rounded_mean(all_global_loss)
            mean_global_acc = _rounded_mean(all_global_acc)
            mywriter.writerow(
                [mean_global_loss, mean_global_acc, mean_trained_loss, mean_trained_acc]
                + [round(i, 4) for i in client_acc.values()])
            file.flush()
            logging.info(
                f'Round:{step} | mean_global_loss:{mean_global_loss} | mean_global_acc:{mean_global_acc} | mean_trained_loss:{mean_trained_loss} | mean_trained_acc:{mean_trained_acc}')

            # =========== aggregate ============
            # NOTE(review): the global model is re-initialised every round, so
            # nothing fitted in an earlier round carries over — confirm intended.
            global_model = CNN_1(n_kernels=n_kernels, out_dim=total_classes).to(device)
            # Dedicated optimizer for the global model. The previous code reused
            # the last client's optimizer here, which never touched the global
            # parameters (and was undefined when no client was selected).
            global_optimizer = torch.optim.SGD(params=global_model.parameters(), lr=lr, momentum=0.9,
                                               weight_decay=wd)

            # Regroup uploads by layer name: {layer_name: [(client_id, params), ...]}.
            client_layers = defaultdict(list)
            for client_id, layers in client_updates.items():
                for layer_name, paras in layers.items():
                    client_layers[layer_name].append((client_id, paras))

            for layer_name, updates in client_layers.items():
                print(f"Aggregating layer {layer_name}...")
                # Generate the probe input according to the matching global layer.
                try:
                    global_layer = getattr(global_model, layer_name)
                    if isinstance(global_layer, torch.nn.Conv2d):
                        input_data = gaussian_noise((1, global_layer.in_channels, 32, 32)).to(device)
                    elif isinstance(global_layer, torch.nn.Linear):
                        input_data = gaussian_noise((1, global_layer.in_features)).to(device)
                    else:
                        input_data = gaussian_noise((1, 3, 32, 32)).to(device)  # default data
                except Exception as e:
                    # Best effort: layers the global model does not have are skipped.
                    print(f"Error in input_data generation: {e}")
                    continue

                global_state = global_model.state_dict()
                global_layer_weight = global_state[f"{layer_name}.weight"]
                global_layer_bias = global_state.get(f"{layer_name}.bias", None)

                client_outputs = []
                for client_id, client_layer_paras in updates:
                    # The global layer's structure is the container for the
                    # client's (resized) weights, so all outputs are comparable.
                    layer_c = copy.deepcopy(global_layer)
                    try:
                        client_layer_weight = client_layer_paras['weight']
                        client_layer_bias = _unwrap_bias(client_layer_paras['bias'])

                        if client_layer_weight.shape != global_layer_weight.shape:
                            client_layer_weight = _fit_client_weight_to_global(
                                client_layer_weight, global_layer_weight, layer_name)

                        if global_layer_bias is not None and client_layer_bias is not None:
                            if isinstance(global_layer_bias, torch.Tensor) and isinstance(client_layer_bias,
                                                                                          torch.Tensor):
                                # if shapes differ, crop or zero-pad the client's bias
                                if client_layer_bias.shape != global_layer_bias.shape:
                                    adjusted_bias = torch.zeros_like(global_layer_bias,
                                                                     device=client_layer_bias.device)
                                    min_len = min(client_layer_bias.numel(), global_layer_bias.numel())
                                    adjusted_bias[:min_len] = client_layer_bias[:min_len]
                                    client_layer_bias = adjusted_bias
                            else:
                                print(
                                    f"Bias is not a tensor. Client: {type(client_layer_bias)}, Global: {type(global_layer_bias)}")

                        layer_c.weight.data.copy_(client_layer_weight)
                        # Guard on the client bias too: a bias-free client layer
                        # used to reach copy_(None) here.
                        if global_layer_bias is not None and client_layer_bias is not None:
                            layer_c.bias.data.copy_(client_layer_bias)
                    except AttributeError:
                        print(f"Error: Layer {layer_name} not found in the model.")
                    client_outputs.append(layer_c(input_data))

                global_layer_output = global_layer(input_data)

                # Sum over clients of the MSE between global and client layer outputs.
                loss = torch.stack(
                    [torch.mean((global_layer_output - client_output) ** 2) for client_output in client_outputs]).sum()

                global_optimizer.zero_grad()
                loss.backward()
                global_optimizer.step()

            # ====== Downlink: global weight and bias to client ======
            global_model.eval()
            global_state = global_model.state_dict()

            for node_id, layers in client_updates.items():  # layers uploaded by each client
                all_matches = []  # (kge, client layer, global layer, fitted weight, fitted bias)

                for client_layer_name, client_params in layers.items():
                    client_layer_weight = client_params['weight']
                    client_layer_bias = _unwrap_bias(client_params['bias'])
                    client_is_fc = 'fc' in client_layer_name
                    client_is_conv = 'conv' in client_layer_name

                    for global_layer_name, global_layer_weight in global_state.items():
                        # Only compare weights of the same kind (fc with fc, conv with conv).
                        same_kind = ((client_is_fc and 'fc' in global_layer_name) or
                                     (client_is_conv and 'conv' in global_layer_name))
                        if not same_kind or '.weight' not in global_layer_name:
                            continue

                        global_layer_bias = global_state.get(global_layer_name.replace('.weight', '.bias'), None)

                        if global_layer_weight.shape != client_layer_weight.shape:
                            global_layer_weight = _fit_global_weight_to_client(global_layer_weight,
                                                                               client_layer_weight)

                        if global_layer_bias is not None and client_layer_bias is not None:
                            if global_layer_bias.shape != client_layer_bias.shape:
                                adjusted_bias = torch.zeros_like(client_layer_bias)
                                min_len = min(global_layer_bias.shape[0], client_layer_bias.shape[0])
                                adjusted_bias[:min_len] = global_layer_bias[:min_len]
                                global_layer_bias = adjusted_bias

                        # KGE
                        knowledge_gain_entropy = calculate_knowledge_gain_entropy(global_layer_weight,
                                                                                  client_layer_weight)
                        print(
                            f"client layer {client_layer_name} and global layer {global_layer_name} knowledge gain entropy: {knowledge_gain_entropy}")

                        if knowledge_gain_entropy > 0:
                            all_matches.append((knowledge_gain_entropy, client_layer_name, global_layer_name,
                                                global_layer_weight, global_layer_bias))

                # Keep the two best matches overall; if both target the same
                # client layer, only the better one is applied.
                best_matches = sorted(all_matches, key=lambda x: x[0], reverse=True)[:2]
                seen_client_layers = set()
                for _, client_layer_name, global_layer_name, global_layer_weight, global_layer_bias in best_matches:
                    if client_layer_name in seen_client_layers:
                        continue
                    seen_client_layers.add(client_layer_name)

                    print(f"Updating client {node_id} layer {client_layer_name} with global layer {global_layer_name}")

                    client_layer_weight = layers[client_layer_name]['weight']
                    client_layer_bias = _unwrap_bias(layers[client_layer_name]['bias'])
                    # A bias-free client layer used to crash on `.flatten()` of None.
                    update_bias = global_layer_bias is not None and client_layer_bias is not None

                    with torch.no_grad():
                        # Flatten, crop or zero-pad to the client's element count, reshape back.
                        global_weight_tensor = global_layer_weight.flatten().to(device)
                        client_weight_tensor = client_layer_weight.flatten().to(device)
                        adjusted_weights = torch.zeros_like(client_weight_tensor)
                        min_len = min(global_weight_tensor.numel(), client_weight_tensor.numel())
                        adjusted_weights[:min_len] = global_weight_tensor[:min_len]
                        adjusted_weights = adjusted_weights.view_as(client_layer_weight)

                        if update_bias:
                            global_bias_tensor = global_layer_bias.flatten().to(device)
                            client_bias_tensor = client_layer_bias.flatten().to(device)
                            adjusted_bias = torch.zeros_like(client_bias_tensor)
                            min_len_bias = min(global_bias_tensor.numel(), client_bias_tensor.numel())
                            adjusted_bias[:min_len_bias] = global_bias_tensor[:min_len_bias]
                            adjusted_bias = adjusted_bias.view_as(client_layer_bias)

                        # Overwrite the stored client parameters in place.
                        client_para[node_id][client_layer_name + '.weight'].copy_(adjusted_weights)
                        if update_bias:
                            client_para[node_id][client_layer_name + '.bias'].copy_(adjusted_bias)


def _rounded_mean(values):
    """Return the mean of *values* rounded to 4 decimals, or NaN for an empty list."""
    if not values:
        # np.mean([]) also yields NaN but emits a RuntimeWarning on every call.
        return float("nan")
    return round(np.mean(values), 4)


def _unwrap_bias(bias):
    """Return the tensor from a ``(name, tensor)`` bias entry; pass anything else through."""
    if isinstance(bias, tuple):
        return bias[1]
    return bias


def _fit_client_weight_to_global(client_weight, global_weight, layer_name):
    """Resize a client weight to the global layer's weight shape (crop / zero-pad / interpolate)."""
    if len(global_weight.shape) == 4:  # Conv2d weight
        if len(client_weight.shape) == 3:  # (C, H, W) -> (1, C, H, W)
            client_weight = client_weight.unsqueeze(0)
        if client_weight.ndim != 4:
            raise ValueError(
                f"Expected 4D tensor for Conv2D, got {client_weight.ndim}D tensor")
        # Match the kernel size (H, W) first, then crop / zero-pad the channels.
        resized = torch.nn.functional.interpolate(
            client_weight,
            size=global_weight.shape[-2:],
            mode='bilinear',
            align_corners=False
        )
        fitted = torch.zeros_like(global_weight)
        out_channels = min(resized.shape[0], global_weight.shape[0])
        in_channels = min(resized.shape[1], global_weight.shape[1])
        fitted[:out_channels, :in_channels, :, :] = resized[:out_channels, :in_channels, :, :]
        return fitted

    if len(global_weight.shape) == 2:  # Linear weight
        fitted = torch.zeros_like(global_weight)
        rows = min(client_weight.shape[0], global_weight.shape[0])
        cols = min(client_weight.shape[1], global_weight.shape[1])
        fitted[:rows, :cols] = client_weight[:rows, :cols]
        return fitted

    raise ValueError(
        f"Unsupported weight shape for layer {layer_name}: {global_weight.shape}")


def _fit_global_weight_to_client(global_weight, client_weight):
    """Reshape a global weight so it can be compared with / copied into a client weight."""
    global_rank = len(global_weight.shape)
    client_rank = len(client_weight.shape)

    if global_rank == 4 and client_rank == 4:  # Conv2D -> Conv2D
        # Only the kernel size is matched here; channel counts are reconciled
        # later by the flatten-and-crop step in train().
        return torch.nn.functional.interpolate(
            global_weight,
            size=client_weight.shape[-2:],
            mode='bilinear',
            align_corners=False
        )

    if global_rank == 2 and client_rank == 2:  # Linear -> Linear
        fitted = torch.zeros_like(client_weight)
        rows = min(global_weight.shape[0], client_weight.shape[0])
        cols = min(global_weight.shape[1], client_weight.shape[1])
        fitted[:rows, :cols] = global_weight[:rows, :cols]
        return fitted

    if global_rank == 4 and client_rank == 2:  # Conv2D -> Linear
        # Flatten each output channel's kernel into one row.
        flat = global_weight.view(global_weight.shape[0], -1)
        fitted = torch.zeros_like(client_weight)
        rows = min(flat.shape[0], client_weight.shape[0])
        cols = min(flat.shape[1], client_weight.shape[1])
        fitted[:rows, :cols] = flat[:rows, :cols]
        return fitted

    if global_rank == 2 and client_rank == 4:  # Linear -> Conv2D
        target_size = client_weight.numel()
        global_flat = global_weight.flatten()
        if global_flat.numel() < target_size:  # zero-pad
            padded_flat = torch.zeros(target_size, device=global_weight.device)
            padded_flat[:global_flat.numel()] = global_flat
            global_flat = padded_flat
        elif global_flat.numel() > target_size:  # crop
            global_flat = global_flat[:target_size]
        return global_flat.view(client_weight.shape)

    raise ValueError(
        f"Unsupported global layer weight shape: {global_weight.shape}, "
        f"client layer weight shape: {client_weight.shape}")



def compute_largest_weight_changes(pre_train_state, post_train_state):
    """Pick the two weight entries that moved the most during local training.

    Change is measured as the L2 norm of ``post - pre`` for every state-dict
    entry whose name contains ``'weight'``.

    Returns:
        A dict keyed by layer name (the entry name without ``'.weight'``).
        Each value holds the post-training ``'weight'`` tensor and a
        ``'bias'`` that is either ``(bias_name, bias_tensor)`` or ``None``
        when the state dict has no matching bias entry.
    """
    scored = []
    for key, after in post_train_state.items():
        # Bias entries (and anything else without 'weight' in its name) are not ranked.
        if 'weight' not in key:
            continue
        before = pre_train_state.get(key)
        if before is None:
            continue
        if not (isinstance(after, torch.Tensor) and isinstance(before, torch.Tensor)):
            continue
        scored.append((key, torch.norm(after - before, p=2).item()))

    # Largest change first; the sort is stable, so ties keep state-dict order.
    scored.sort(key=lambda item: item[1], reverse=True)

    selected = {}
    for key, _ in scored[:2]:
        bias_key = key.replace('weight', 'bias')
        bias_tensor = post_train_state.get(bias_key, None)
        # e.g. 'fc2.weight' -> 'fc2'
        selected[key.replace('.weight', '')] = {
            'weight': post_train_state[key],
            'bias': None if bias_tensor is None else (bias_key, bias_tensor),
        }

    return selected


def gaussian_noise(size, dtype=torch.float32):
    """Sample a tensor of shape ``size`` from the standard normal N(0, 1)."""
    # A torch.Size is converted to a plain tuple of ints before use.
    shape = tuple(size) if isinstance(size, torch.Size) else size
    mean = torch.zeros(shape, dtype=dtype)
    std = torch.ones(shape, dtype=dtype)
    standard_normal = Normal(mean, std)
    return standard_normal.sample()


def calculate_entropy(weight_tensor):
    """Return the Shannon entropy of a weight tensor's value histogram.

    The flattened values are binned into a 50-bin histogram; a small constant
    keeps every bin strictly positive before ``scipy.stats.entropy`` normalises
    the bins and computes the entropy (natural logarithm).

    Args:
        weight_tensor: Tensor on any device; it may require grad.

    Returns:
        The entropy as a float scalar.
    """
    # detach() so parameters that still require grad can be converted to NumPy
    # (Tensor.numpy() raises on tensors that require grad).
    weight_flat = weight_tensor.detach().flatten().cpu().numpy()
    weight_hist, _ = np.histogram(weight_flat, bins=50, density=True)
    weight_hist = weight_hist + 1e-8  # avoid zero-probability bins
    return entropy(weight_hist)


def calculate_knowledge_gain_entropy(global_weight, client_weight):
    """Return the knowledge gain entropy (KGE) of a global layer over a client layer.

    KGE is the entropy of the global weight minus the entropy of the client
    weight; a positive value means the global weight has the higher entropy.
    """
    entropy_of_global = calculate_entropy(global_weight)
    entropy_of_client = calculate_entropy(client_weight)
    return entropy_of_global - entropy_of_client