#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

try:
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
except Exception:  # pragma: no cover
    matplotlib = None
    plt = None

try:
    from torch import nn
except Exception:  # pragma: no cover
    nn = None
import copy
import numpy as np
from torchvision import datasets, transforms
import torch
from models.vggmodule import vgg_unlrean
from models.Resnet import ResNet18_Pruning, ResNet50_Pruning
from models.Nets import DigitModel_Pruning
import torch.nn.functional as F
import gc

def print_model_weights(model):
    """Print the name and size of every trainable parameter of ``model``.

    Parameters whose ``requires_grad`` is False are skipped. Each line is
    followed by an extra blank line (the message itself ends in ``\\n``).
    """
    trainable = (
        (layer_name, tensor)
        for layer_name, tensor in model.named_parameters()
        if tensor.requires_grad
    )
    for layer_name, tensor in trainable:
        print(f'Layer: {layer_name} | Size: {tensor.size()}  \n')


def test_un(test_loader,net_glob):
    net_glob.eval()
    test_loss = 0
    correct = 0
    total = 0
    targets = []
    loss_fun = nn.CrossEntropyLoss()

    if len(test_loader) == 0:
        return 0, 0

    for data, target in test_loader:
        data = data.to(args.device).float()
        target = target.to(args.device).long()

        for i in range(len(data)):
            sample_target = target[i].item()

            # if sample_target in delete_class:

            targets.append(sample_target)
            output = net_glob(data[i].unsqueeze(0))

            test_loss += loss_fun(output, target[i].unsqueeze(0)).item()

            pred = output.data.max(1)[1]

            if pred.item() == sample_target:
                correct += 1

            total += 1

    acc_test = correct / total
    return test_loss / total, acc_test


#
def channel_flatten_avg(label_to_features_convN):
    """Average the flattened feature arrays collected for each label.

    Parameters
    ----------
    label_to_features_convN : dict
        Maps a label to a list of array-like feature maps. All maps under one
        label must flatten to the same length. The input is not modified.

    Returns
    -------
    dict
        Maps each label to a list with the element-wise mean of its flattened
        features, or to ``None`` when the label has no features.
    """
    Ave_convN = {}

    for label, features in label_to_features_convN.items():
        if len(features) == 0:
            Ave_convN[label] = None
            continue

        # ``ravel`` returns a view when possible, so the caller's arrays are
        # neither copied needlessly nor replaced in their dict.
        flattened = [np.asarray(feature).ravel() for feature in features]
        Ave_convN[label] = list(np.mean(flattened, axis=0))

    return Ave_convN


def glob_channel_weight_average(data):
    """Aggregate per-client, per-layer, per-class channel averages.

    Parameters
    ----------
    data : dict
        Maps a client key to a list with one dict per layer ("kernel"). Each
        dict maps ``class_id`` to a list of channel values, or to ``None``
        when that client had no samples of the class. ``None`` entries are
        skipped.

    Returns
    -------
    tuple
        ``(avg_data, weight_data)``:

        * ``avg_data[i][class_id]`` is the mean over clients of the channel
          values for layer ``i``.
        * ``weight_data[client_pos][i][class_id]`` is each channel value of
          that client as a percentage of the summed (over clients and
          channels) values for ``(i, class_id)``. ``client_pos`` is the
          client's position in ``data``'s iteration order, not its key.
          All zeros when that sum is 0.
    """
    channel_sum = {}
    channel_count = {}

    for client in data.values():
        for i, kernel in enumerate(client):
            for class_id, channels in kernel.items():
                if channels is None:
                    continue  # this client never saw the class
                key = (i, class_id)
                if key not in channel_sum:
                    channel_sum[key] = [0] * len(channels)
                    channel_count[key] = 0
                channel_sum[key] = [x + y for x, y in zip(channel_sum[key], channels)]
                channel_count[key] += 1

    # Size both outputs by the deepest client so every layer index is valid.
    max_kernels = max((len(c) for c in data.values()), default=0)

    avg_data = [{} for _ in range(max_kernels)]
    for (i, class_id), total in channel_sum.items():
        count = channel_count[(i, class_id)]
        avg_data[i][class_id] = [x / count for x in total]

    weight_data = [[{} for _ in range(max_kernels)] for _ in range(len(data))]
    for client_pos, client in enumerate(data.values()):
        for i, kernel in enumerate(client):
            for class_id, channels in kernel.items():
                if channels is None:
                    continue
                total_sum = sum(channel_sum[(i, class_id)])
                if total_sum == 0:
                    weight_data[client_pos][i][class_id] = [0.0] * len(channels)
                else:
                    weight_data[client_pos][i][class_id] = [x / total_sum * 100 for x in channels]

    return avg_data, weight_data


def calculate_TF_values(class_labels, Avg_conv_global):
    """Compute term frequency values for each class across feature groups.

    Parameters
    ----------
    class_labels : Iterable[int]
        Identifiers for the classes to evaluate. Each must be a key of every
        feature group's dict.
    Avg_conv_global : Dict[str, Dict[int, List[float]]]
        Mapping of feature group name to class-wise channel averages.

    Returns
    -------
    Dict[str, List[float]]
        Dictionary keyed by ``TF_<class>_<feature>``. Each value holds the
        channel values divided by their sum. It is all zeros when the sum is
        0, instead of raising ZeroDivisionError.
    """

    TF_all = {}
    for label in class_labels:
        for name, class_dict in Avg_conv_global.items():
            weights = class_dict[label]
            total = sum(weights)
            key = f"TF_{label}_{name}"
            if total == 0:
                TF_all[key] = [0.0] * len(weights)
            else:
                TF_all[key] = [w / total for w in weights]

    return TF_all


def calculate_IDF_values(input_dict):
    """Compute an IDF-style weight for each value position (channel).

    Each entry of ``input_dict`` is treated as a "document" and its values as
    term frequencies. A position counts as present in a document when its
    value is strictly greater than that document's mean. With ``N`` documents
    and ``n_j`` documents where position ``j`` is present, the weight is
    ``log((1 + N) / (1 + n_j))``.

    Rows shorter than the longest one are zero-padded before means are taken.
    Rows are placed by iteration order rather than by key, so keys need not be
    the contiguous integers ``0..N-1``; the result does not depend on row
    order.

    Parameters
    ----------
    input_dict : dict
        Maps a document id to a sequence of numbers.

    Returns
    -------
    list[float]
        One weight per position of the longest row; ``[]`` for empty input.
    """
    if not input_dict:
        return []

    rows = list(input_dict.values())
    num_docs = len(rows)
    num_values = max(len(v) for v in rows)

    frequencies = np.zeros((num_docs, num_values))
    for row_idx, freqs in enumerate(rows):
        frequencies[row_idx, :len(freqs)] = freqs

    doc_means = frequencies.mean(axis=1, keepdims=True)
    # Number of documents in which each position exceeds that document's mean.
    n_values = (frequencies > doc_means).sum(axis=0)

    IDF_values = np.log((1 + num_docs) / (1 + n_values))
    return IDF_values.tolist()


def prune_weight_vgg(R, net, weight, TF_IDF):
    """Zero the highest-scoring output channels of VGG convolution weights.

    Parameters
    ----------
    R : float
        Percentage used per score vector: ``int(len(scores) * R / 100)``
        channels with the largest scores are zeroed.
    net : object
        Unused; accepted so all ``prune_weight_*`` helpers share a signature.
    weight : dict
        State dict mapping parameter names to tensors. Pruned entries are
        replaced by new tensors, and the same dict is returned.
    TF_IDF : dict
        Maps keys like ``TF_<class>_conv<k>_IDF_conv<k>`` to per-channel
        scores. ``conv<k>`` (1-based) selects the k-th entry of the conv
        position list below. Keys whose ``k`` is out of range, or whose layer
        is absent from ``weight``, are skipped.

    Returns
    -------
    dict
        ``weight`` with the selected slices ``weight[layer][c]`` zeroed.
    """
    # Positions of the convolution modules inside ``feature``.
    vgg_conv_layer = ['0', '3', '7', '10', '14', '17', '20', '24', '27', '30', '34', '37', '40']
    prune_all = {}

    for key, value in TF_IDF.items():
        conv_idx = int(key.split('_')[2][4:]) - 1
        # Reject unknown layers rather than raising IndexError or letting a
        # negative index wrap around to the last layer.
        if not 0 <= conv_idx < len(vgg_conv_layer):
            continue
        layer = 'feature.' + vgg_conv_layer[conv_idx] + '.weight'
        if layer not in weight:
            continue

        num_channels_to_prune = int(len(value) * R / 100)
        # Highest scores first; the leading ``num_channels_to_prune`` are pruned.
        sorted_channels = sorted(enumerate(value), key=lambda x: x[1], reverse=True)
        channels_to_prune = [idx for idx, _ in sorted_channels[:num_channels_to_prune]]

        if layer in prune_all:
            prune_all[layer].extend(channels_to_prune)
            prune_all[layer] = list(set(prune_all[layer]))
        else:
            prune_all[layer] = channels_to_prune

        if channels_to_prune:
            # One mask for all selected channels instead of cloning the whole
            # tensor once per channel; the resulting product is identical.
            mask = torch.ones_like(weight[layer])
            mask[channels_to_prune] = 0
            weight[layer] = weight[layer] * mask

    for key, value in prune_all.items():
        if key in weight:
            print(key, len(value), weight[key].shape, value)

    return weight

def get_resnet_conv_layers(net):
    """List the names of a ResNet's main-path convolution weights.

    Names are read from ``net.state_dict()`` so they reflect the instantiated
    model. A key is kept when it ends with ``weight``, contains ``conv`` and
    does not contain ``shortcut``. The result follows ``state_dict``
    (registration) order, which is assumed to match the forward order of the
    conv outputs. TODO confirm for each ResNet variant.
    """
    conv_names = []
    for param_name in net.state_dict():
        is_weight = param_name.endswith("weight")
        is_conv = "conv" in param_name
        is_shortcut = "shortcut" in param_name
        if is_weight and is_conv and not is_shortcut:
            conv_names.append(param_name)
    return conv_names


def prune_weight_resnet(R, net, weight, TF_IDF):
    """Zero the highest-scoring output channels of ResNet convolution weights.

    Parameters
    ----------
    R : float
        Percentage used per score vector: ``int(len(scores) * R / 100)``
        channels with the largest scores are zeroed.
    net : torch.nn.Module
        Model whose ``state_dict`` names define the conv layer order (via
        ``get_resnet_conv_layers``).
    weight : dict
        State dict mapping parameter names to tensors. Pruned entries are
        replaced by new tensors, and the same dict is returned.
    TF_IDF : dict
        Maps keys like ``TF_<class>_conv<k>_IDF_conv<k>`` to per-channel
        scores. ``conv<k>`` (1-based) selects the k-th conv layer name. Keys
        whose ``k`` is out of range, or whose layer is absent from ``weight``,
        are skipped.

    Returns
    -------
    dict
        ``weight`` with the selected slices ``weight[layer][c]`` zeroed.
    """
    prune_all = {}
    resnet_layer = get_resnet_conv_layers(net)

    for key, value in TF_IDF.items():
        conv_idx = int(key.split('_')[2][4:]) - 1
        # Reject unknown layers; a negative index would otherwise wrap around.
        if not 0 <= conv_idx < len(resnet_layer):
            continue
        layer = resnet_layer[conv_idx]
        if layer not in weight:
            continue

        num_channels_to_prune = int(len(value) * R / 100)
        # Highest scores first; the leading ``num_channels_to_prune`` are pruned.
        sorted_channels = sorted(enumerate(value), key=lambda x: x[1], reverse=True)
        channels_to_prune = [idx for idx, _ in sorted_channels[:num_channels_to_prune]]

        if layer in prune_all:
            prune_all[layer].extend(channels_to_prune)
            prune_all[layer] = list(set(prune_all[layer]))
        else:
            prune_all[layer] = channels_to_prune

        if channels_to_prune:
            # One mask for all selected channels instead of cloning the whole
            # tensor once per channel; the resulting product is identical.
            mask = torch.ones_like(weight[layer])
            mask[channels_to_prune] = 0
            weight[layer] = weight[layer] * mask

    for key, value in prune_all.items():
        if key in weight:
            print(key, len(value), weight[key].shape, value)

    return weight

def prune_weight_digital(R, net, weight, TF_IDF):
    """Zero the highest-scoring output channels of the digit model's conv weights.

    Parameters
    ----------
    R : float
        Percentage used per score vector: ``int(len(scores) * R / 100)``
        channels with the largest scores are zeroed.
    net : object
        Unused; accepted so all ``prune_weight_*`` helpers share a signature.
    weight : dict
        State dict mapping parameter names to tensors. Pruned entries are
        replaced by new tensors, and the same dict is returned.
    TF_IDF : dict
        Maps keys like ``TF_<class>_conv<k>_IDF_conv<k>`` to per-channel
        scores. ``k`` in 1..3 selects ``conv<k>.weight``. Other ``k`` values,
        and layers absent from ``weight``, are skipped.

    Returns
    -------
    dict
        ``weight`` with the selected slices ``weight[layer][c]`` zeroed.
    """
    digital_layer = ['conv1.weight', 'conv2.weight', 'conv3.weight']
    prune_all = {}

    for key, value in TF_IDF.items():
        conv_idx = int(key.split('_')[2][4:]) - 1
        # Reject unknown layers rather than raising IndexError or letting a
        # negative index wrap around to the last layer.
        if not 0 <= conv_idx < len(digital_layer):
            continue
        layer = digital_layer[conv_idx]
        if layer not in weight:
            continue

        num_channels_to_prune = int(len(value) * R / 100)
        # Highest scores first; the leading ``num_channels_to_prune`` are pruned.
        sorted_channels = sorted(enumerate(value), key=lambda x: x[1], reverse=True)
        channels_to_prune = [idx for idx, _ in sorted_channels[:num_channels_to_prune]]

        if layer in prune_all:
            prune_all[layer].extend(channels_to_prune)
            prune_all[layer] = list(set(prune_all[layer]))
        else:
            prune_all[layer] = channels_to_prune

        if channels_to_prune:
            # One mask for all selected channels instead of cloning the whole
            # tensor once per channel; the resulting product is identical.
            mask = torch.ones_like(weight[layer])
            mask[channels_to_prune] = 0
            weight[layer] = weight[layer] * mask

    for key, value in prune_all.items():
        if key in weight:
            print(key, len(value), weight[key].shape, value)

    return weight


def prune_weight_vit(R, net, weight, TF_IDF):
    """Zero the highest-scoring output rows of ViT projection weights.

    Parameters
    ----------
    R : float
        Percentage used per score vector: ``int(len(scores) * R / 100)``
        entries with the largest scores are zeroed.
    net : object
        Read only for ``net.config.num_hidden_layers``. It is treated as 0 when
        ``config`` or the attribute is missing.
    weight : dict
        State dict mapping parameter names to tensors. Pruned entries are
        replaced by new tensors, and the same dict is returned.
    TF_IDF : dict
        Maps keys like ``TF_<class>_layer_<i>_IDF_layer_<i>`` (or
        ``..._conv<k>_...``) to per-row scores. ``layer_0`` / ``conv1``
        select the patch-embedding projection. ``layer_<i>`` / ``conv<i+1>``
        for ``i >= 1`` select encoder layer ``i-1``'s attention query weight.
        Unknown keys and layers absent from ``weight`` are skipped.

    Returns
    -------
    dict
        ``weight`` with the selected slices ``weight[layer][c]`` zeroed.
    """
    prune_all = {}
    vit_layers = ['vit.embeddings.patch_embeddings.projection.weight']
    num_layers = getattr(getattr(net, 'config', object()), 'num_hidden_layers', 0)
    for i in range(num_layers):
        vit_layers.append(f'vit.encoder.layer.{i}.attention.attention.query.weight')

    # Accept both naming schemes for the same layer list.
    layer_map = {
        **{f'layer_{i}': layer for i, layer in enumerate(vit_layers)},
        **{f'conv{i + 1}': layer for i, layer in enumerate(vit_layers)},
    }

    for key, value in TF_IDF.items():
        key_parts = key.split('_')
        layer_key = key_parts[2]
        # ``layer_<i>`` itself contains '_', so re-join the split pieces.
        if layer_key == 'layer' and len(key_parts) > 3:
            layer_key = f"{key_parts[2]}_{key_parts[3]}"
        layer = layer_map.get(layer_key)
        if layer is None or layer not in weight:
            continue

        num_channels_to_prune = int(len(value) * R / 100)
        # Highest scores first; the leading ``num_channels_to_prune`` are pruned.
        sorted_channels = sorted(enumerate(value), key=lambda x: x[1], reverse=True)
        channels_to_prune = [idx for idx, _ in sorted_channels[:num_channels_to_prune]]

        if layer in prune_all:
            prune_all[layer].extend(channels_to_prune)
            prune_all[layer] = list(set(prune_all[layer]))
        else:
            prune_all[layer] = channels_to_prune

        if channels_to_prune:
            # One mask for all selected rows instead of cloning the whole
            # tensor once per row; the resulting product is identical.
            mask = torch.ones_like(weight[layer])
            mask[channels_to_prune] = 0
            weight[layer] = weight[layer] * mask

    for key, value in prune_all.items():
        if key in weight:
            print(key, len(value), weight[key].shape, value)

    return weight

def _unlearn_build_model(args):
    """Instantiate the feature-extracting model selected by ``args.model``."""
    if args.model == 'vgg16':
        return vgg_unlrean(dataset=args.dataset, depth=16, init_weights=True, cfg=None).to(args.device)
    if args.model == 'resnet18':
        return ResNet18_Pruning(dataset=args.dataset).to(args.device)
    if args.model == 'resnet50':
        return ResNet50_Pruning(dataset=args.dataset).to(args.device)
    if args.model == 'vit':
        from transformers import ViTForImageClassification
        model_name = 'google/vit-base-patch16-224'
        return ViTForImageClassification.from_pretrained(
            model_name,
            num_labels=args.num_classes,
            output_hidden_states=True,
            ignore_mismatched_sizes=True,
        ).to(args.device)
    return DigitModel_Pruning().to(args.device)


def _unlearn_num_feature_layers(args, model):
    """Return how many feature tensors the model's forward pass is expected to yield."""
    if args.model == 'vgg16':
        return 13
    if args.model == 'resnet18':
        return 17
    if args.model == 'resnet50':
        return 49
    if args.model == 'vit':
        # hidden_states: embedding output plus one entry per encoder layer.
        return model.config.num_hidden_layers + 1
    return 3


def _unlearn_client_feature_averages(args, model, loader):
    """Return, per feature layer, each class's averaged flattened feature for one client."""
    per_layer = [
        {label: [] for label in range(args.num_classes)}
        for _ in range(_unlearn_num_feature_layers(args, model))
    ]
    model.eval()
    # Features are only read, so skip building the autograd graph.
    with torch.no_grad():
        for data, labels, _ in loader:
            data, labels = data.to(args.device), labels.to(args.device)
            if args.model == 'vit':
                outputs = model(pixel_values=data, output_hidden_states=True, return_dict=True)
                # Keep the first token of every hidden state.
                conv_outs = [hs[:, 0, :] for hs in outputs.hidden_states]
            else:
                _, conv_outs = model(data)

            for sample_idx in range(len(labels)):
                label = labels[sample_idx].item()
                for n, conv_out in enumerate(conv_outs):
                    per_layer[n][label].append(conv_out[sample_idx].detach().cpu().numpy())

    return [channel_flatten_avg(label_to_features) for label_to_features in per_layer]


def _unlearn_tf_idf(classes, Avg_conv_global):
    """Return TF-IDF scores keyed ``TF_<class>_<layer>_IDF_<layer>``."""
    print("计算TF")  # "computing TF"
    TF = calculate_TF_values(classes, Avg_conv_global)
    print("计算IDF")  # "computing IDF"
    IDF = {name: calculate_IDF_values(glob) for name, glob in Avg_conv_global.items()}

    # Pair TF and IDF by exact layer name: layer names may themselves contain
    # '_' (``layer_<i>`` for ViT), so they must not be recovered by splitting keys.
    TF_IDF = {}
    for label in classes:
        for name, idf in IDF.items():
            tf = TF[f"TF_{label}_{name}"]
            TF_IDF[f"TF_{label}_{name}_IDF_{name}"] = [a * b for a, b in zip(tf, idf)]
    return TF_IDF


def unlearn_prune(args, delete_usr, w_locals, w_glob, net_glob, train_loaders):
    """Prune channels of ``w_glob`` scored via TF-IDF and the deleted users' weights.

    Steps:

    1. For every client not in ``delete_usr``, load its local weights (minus
       ``classifier.*`` keys) into a feature-extracting model. Then average
       features per class and per layer over ``train_loaders[client_idx]``.
    2. Aggregate across clients with ``glob_channel_weight_average``.
    3. Score each (class, layer) channel with TF-IDF.
    4. For each deleted user, multiply the TF-IDF scores by that user's
       weights and call the architecture-specific ``prune_weight_*`` helper
       with ``R = 45``.

    Parameters
    ----------
    args : object
        Reads ``model``, ``dataset``, ``device``, ``num_classes`` and ``num_users``.
    delete_usr : container of int
        Client indices to unlearn.
    w_locals : sequence of dict
        Per-client state dicts, indexed by client index.
    w_glob : dict
        Global state dict. The prune helpers replace its pruned entries in place.
    net_glob : object
        Passed through to the prune helper.
    train_loaders : sequence
        Per-client loaders yielding ``(data, labels, _)`` triples.

    Returns
    -------
    dict
        The pruned global state dict. This is ``w_glob`` itself when
        ``delete_usr`` is empty.
    """
    print('-' * 50, 'UNLEARNING', '-' * 50)
    print('UNLEARNING')
    print("UNLEARNING OF CLIENTS ")

    unlearning_model = _unlearn_build_model(args)

    Avg_conv_clients = {}
    for client_idx in range(args.num_users):
        if client_idx in delete_usr:
            continue
        print("client ", client_idx)

        # Load everything except the classification head; strict=False
        # tolerates the keys that were filtered out.
        filtered_weights = {
            k: v for k, v in w_locals[client_idx].items()
            if not k.startswith('classifier.')
        }
        unlearning_model.load_state_dict(filtered_weights, strict=False)

        Avg_conv_clients[f"Avg_conv_client_{client_idx}"] = _unlearn_client_feature_averages(
            args, unlearning_model, train_loaders[client_idx]
        )

    print("DONE")
    print("UNLEARNING OF SERVER ")

    glob_channel, clients_weight = glob_channel_weight_average(Avg_conv_clients)
    # The per-client averages are no longer needed; free them before scoring.
    del Avg_conv_clients
    gc.collect()
    print(" 清理CPU内存 ")  # "CPU memory cleaned up"

    if args.model == 'vit':
        Avg_conv_global = {f'layer_{i}': v for i, v in enumerate(glob_channel)}
    else:
        Avg_conv_global = {f'conv{i + 1}': v for i, v in enumerate(glob_channel)}
    layer_names = list(Avg_conv_global.keys())
    classes = list(next(iter(Avg_conv_global.values())).keys())

    TF_IDF = _unlearn_tf_idf(classes, Avg_conv_global)

    R = 45  # percentage of each score vector that gets pruned
    print("权重修剪")  # "weight pruning"

    if args.model == 'vgg16':
        prune_fn = prune_weight_vgg
    elif args.model in ('resnet18', 'resnet50'):
        prune_fn = prune_weight_resnet
    elif args.model == 'vit':
        prune_fn = prune_weight_vit
    else:
        prune_fn = prune_weight_digital

    # Defined up front so an empty ``delete_usr`` returns w_glob unchanged.
    weight_pruned = w_glob
    for usrs in delete_usr:
        # NOTE(review): clients_weight is indexed by position among the clients
        # kept above (deleted users are skipped), so indexing with a deleted
        # user's id selects a different client's weights (or raises
        # IndexError). Confirm the intended lookup.
        dic_usr = clients_weight[usrs]
        outputdic = {}
        for layer_idx, class_dict in enumerate(dic_usr):
            layer_name = layer_names[layer_idx]
            for class_id in classes:
                tfidf_key = f"TF_{class_id}_{layer_name}_IDF_{layer_name}"
                if class_id in class_dict and tfidf_key in TF_IDF:
                    outputdic[tfidf_key] = [
                        u * c for u, c in zip(class_dict[class_id], TF_IDF[tfidf_key])
                    ]

        weight_pruned = prune_fn(R, net_glob, w_glob, outputdic)

    print("DONE")
    return weight_pruned

def check_pruned_weights(weights):
    """Print one line per output channel of every convolution entry.

    An entry counts as a convolution when its key contains ``conv`` or
    ``Conv``. The channel count is taken from the first dimension of its
    shape. The printed text is "conv layer: <key>, channel: <index>" (in
    Chinese).
    """
    conv_keys = [name for name in weights if 'conv' in name or 'Conv' in name]
    for name in conv_keys:
        channel_total = weights[name].shape[0]
        for idx in range(channel_total):
            print(f"卷积层: {name}, 通道: {idx}")
