import torch
import copy
import itertools
import numpy as np
from src.models.model import get_model


def calculate_diff_norms_maxclip(client_updates, local_clip_values, rho):
    """Return each local clip value's absolute distance from the largest one.

    Args:
        client_updates: Unused; kept for signature compatibility.
        local_clip_values: Non-empty sequence of per-client clip values.
        rho: Unused; kept for signature compatibility.

    Returns:
        list: ``|c - max(local_clip_values)|`` for each clip value ``c``.
    """
    max_clip_value = max(local_clip_values)
    # abs() instead of sqrt(d**2): squaring a float scalar overflows
    # (OverflowError) for large gaps and underflows to 0.0 for tiny ones.
    return [abs(local_clip - max_clip_value) for local_clip in local_clip_values]

def calculate_diff_norms_avgclip(local_clip_values, global_clip_value):
    """Return each local clip value's absolute distance from the global one.

    Args:
        local_clip_values: Sequence of per-client clip values.
        global_clip_value: Reference clip value.

    Returns:
        list: ``|c - global_clip_value|`` for each clip value ``c``.
    """
    # abs() instead of sqrt(d**2): squaring a float scalar overflows
    # (OverflowError) for large gaps and underflows to 0.0 for tiny ones.
    return [abs(local_clip - global_clip_value) for local_clip in local_clip_values]

def calculate_diff_norms_refclient(client_updates):
    """Return each client's L2 distance from client 0's update.

    An update (layer name -> tensor) is treated as one flattened vector: the
    distance is the square root of the summed squared per-layer L2 norms.

    Args:
        client_updates: Sequence of per-client update dicts; entry 0 is the
            reference. Each client's layer keys must exist in the reference.

    Returns:
        list[float]: One distance per client, in input order; ``[]`` when
        ``client_updates`` is empty.
    """
    if not client_updates:
        return []
    reference = client_updates[0]
    diffs_norm2 = []
    for client_update in client_updates:
        squared = sum(pow(torch.norm(layer_update - reference[layer_id], p=2), 2)
                      for layer_id, layer_update in client_update.items())
        # float() also copes with an update that has no layers (sum() == 0).
        diffs_norm2.append(pow(float(squared), 0.5))
    return diffs_norm2

def calculate_diff_norms(sampled_clients, client_updates, aggregated_updates_dict, global_sigma,
                         global_clip_value, num_clients, learning_rate, device):
    """Return each client's L2 distance from the aggregated update.

    An update (layer name -> tensor) is treated as one flattened vector: the
    distance is the square root of the summed squared per-layer L2 norms.

    Args:
        sampled_clients, global_sigma, global_clip_value, num_clients,
        learning_rate: Unused; kept for signature compatibility.
        client_updates: Sequence of per-client update dicts. They are not
            moved, so presumably they already live on ``device``.
        aggregated_updates_dict: Aggregated update; must contain every layer
            key the clients have. It is moved to ``device`` before comparison.
        device: Torch device the comparison is done on.

    Returns:
        list[float]: One distance per client, in input order.
    """
    # Transfer the aggregate once instead of once per client per layer.
    aggregated_on_device = {k: v.to(device) for k, v in aggregated_updates_dict.items()}
    diffs_norm2 = []
    for client_update in client_updates:
        squared = sum(pow(torch.norm(layer_update - aggregated_on_device[layer_id], p=2), 2)
                      for layer_id, layer_update in client_update.items())
        # float() also copes with an update that has no layers (sum() == 0).
        diffs_norm2.append(pow(float(squared), 0.5))
    return diffs_norm2


def calculate_influence(model_updates,
                        global_state_dict,
                        sampled_clients,
                        sample_rates,
                        learning_rate,
                        global_clip_value,
                        global_sigma,
                        num_clients,
                        testset,
                        rate=None,
                        dataset=None,
                        bs=None):
    """Compute each client's leave-one-out influence on test accuracy.

    The model aggregated from all updates is evaluated once. Then, for each
    client, the model aggregated from every *other* client's update is
    evaluated. The influence is ``accuracy_without_client -
    accuracy_with_all``, so a positive value means removing the client
    improved accuracy.

    Args:
        model_updates: Sequence of per-client update dicts (layer name ->
            tensor). Client ids are positions in this sequence, as in
            ``calculate_svalue``.
        global_state_dict: Current global model state. Each evaluation
            receives a shallow copy, so the caller's dict is not reassigned.
        sampled_clients, sample_rates, learning_rate, global_clip_value,
        global_sigma, num_clients: Forwarded unchanged to ``averaging_func``.
        testset: Dataset passed to ``model_evaluation_func``.
        rate, dataset, bs: Forwarded to ``averaging_func`` and required. The
            ``None`` defaults only keep the old 9-argument signature valid.

    Returns:
        list: ``influence[i]`` for each client id ``i``, in order.

    Raises:
        ValueError: If ``rate``, ``dataset`` or ``bs`` is not supplied.
    """
    if rate is None or dataset is None or bs is None:
        raise ValueError('calculate_influence requires rate, dataset and bs '
                         'to build models via averaging_func')

    updates_by_client = dict(enumerate(model_updates))

    def _evaluate(updates):
        """Score the model obtained by aggregating ``updates`` onto the global state."""
        # averaging_func writes aggregated tensors back into the state dict it
        # receives; a shallow copy keeps every evaluation anchored at the same
        # global model.
        state = copy.copy(global_state_dict)
        if updates:
            model = averaging_func(updates,
                                   state,
                                   sampled_clients,
                                   sample_rates,
                                   learning_rate,
                                   global_clip_value,
                                   global_sigma,
                                   num_clients,
                                   rate,
                                   dataset,
                                   bs)
        else:
            # Nothing left to aggregate (e.g. a single client was removed):
            # the global model is used unchanged.
            model = get_model(dataset, bs)
            model.load_state_dict(state)
        return model_evaluation_func(model, testset)

    total_value = _evaluate(updates_by_client)

    influence = []
    for client_id in updates_by_client:
        others = {cid: update for cid, update in updates_by_client.items() if cid != client_id}
        value = _evaluate(others) - total_value
        influence.append(value)
        print(f'CID: {client_id}, inf: {value:.3f}')
    return influence




def calculate_svalue(model_updates,
                     global_state_dict,
                     sampled_clients,
                     sample_rates,
                     learning_rate,
                     global_clip_value,
                     global_sigma,
                     num_clients,
                     testset,
                     rate,
                     dataset,
                     bs):
    """Compute permutation-based Shapley values for the clients' updates.

    Adapted from https://github.com/lamnd09/cds/blob/main/src/shapleyfl.py

    Every ordering of the clients is enumerated (n! orderings, so this is only
    feasible for a handful of clients). While walking an ordering, the updates
    of the growing coalition are aggregated with ``averaging_func`` and scored
    with ``model_evaluation_func``. Scores are cached per coalition, so each
    coalition is evaluated at most once.

    Marginal contributions follow the referenced implementation. A client is
    credited ``max(0, v(S + {i}) - sum of marginals already credited in this
    ordering)``. This differs from the textbook ``v(S + {i}) - v(S)`` whenever
    an earlier marginal was clipped to zero.

    Args:
        model_updates: Sequence of per-client update dicts; client ids are
            positions in this sequence.
        global_state_dict: Current global model state. Each aggregation
            receives a shallow copy, so the caller's dict is not reassigned.
        sampled_clients, sample_rates, learning_rate, global_clip_value,
        global_sigma, num_clients, rate, dataset, bs: Forwarded unchanged to
            ``averaging_func``.
        testset: Dataset passed to ``model_evaluation_func``.

    Returns:
        list: ``svalues[i]`` is the Shapley value of ``model_updates[i]``.
    """
    import math

    updates_by_client = dict(enumerate(model_updates))
    n_perms = math.factorial(len(updates_by_client))
    print(n_perms)

    svalues = [0] * len(updates_by_client)
    # Cache: sorted tuple of coalition client ids -> evaluated score.
    history = {}

    # Iterate lazily; materialising all n! orderings only wastes memory.
    for perm in itertools.permutations(updates_by_client):
        coalition = {}
        credited = 0
        for client_id in perm:
            # averaging_func only reads the updates, so no deep copy is needed.
            coalition[client_id] = updates_by_client[client_id]
            index = tuple(sorted(coalition))

            if index in history:
                current_value = history[index]
            else:
                # averaging_func writes into the state dict it receives; a
                # shallow copy keeps every coalition based on the same model.
                model = averaging_func(coalition,
                                       copy.copy(global_state_dict),
                                       sampled_clients,
                                       sample_rates,
                                       learning_rate,
                                       global_clip_value,
                                       global_sigma,
                                       num_clients,
                                       rate,
                                       dataset,
                                       bs)
                current_value = model_evaluation_func(model, testset)
                history[index] = current_value

            marginal = max(0, current_value - credited)
            credited += marginal
            svalues[client_id] += marginal / n_perms

    return svalues

def model_evaluation_func(model, testset, batch_size=32):
    ''' Computes the overall accuracy (in percent) of ``model`` on ``testset``.

    The model is switched to eval mode and left there.

    Args:
        model: Torch module whose output has class scores along dim 1
            (predictions are ``argmax(outputs, dim=1)``).
        testset: Dataset yielding ``(input, label)`` pairs.
        batch_size: Evaluation batch size (default 32).

    Returns:
        float: ``100 * correct / total``.

    Raises:
        ValueError: If ``testset`` yields no samples.
    '''
    model.eval()
    # Sample order does not change accuracy, so skip shuffling for a
    # deterministic, cheaper pass.
    test_dataloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                                  drop_last=False, shuffle=False)
    # Run on the model's device, falling back to CPU for parameterless models.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_dataloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            predicted = torch.argmax(outputs, dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError('testset is empty; accuracy is undefined')
    return 100 * correct / total


def averaging_func(client_updates,
                   global_state_dict,
                   sampled_clients,
                   sample_rates,
                   learning_rate,
                   global_clip_value,
                   global_sigma,
                   num_clients,
                   rate,
                   dataset,
                   bs):
    ''' Aggregating updates for sampled clients and adding noise for unsampled clients.

    Clients in ``sampled_clients`` contribute their update divided by
    ``sum(sample_rates)``. Every other client instead contributes, per layer,
    Gaussian noise with std ``global_sigma * global_clip_value /
    sqrt(num_clients)``, scaled by ``learning_rate * rate / sum(sample_rates)``.
    The aggregate is added to ``global_state_dict`` and loaded into a fresh
    ``get_model(dataset, bs)``.

    Args:
        client_updates: Dict of client id -> update dict (layer name ->
            tensor). Layer keys are taken from the first entry.
        global_state_dict: Current global model state. Every key must appear
            in the aggregate. The caller's dict is not modified.
        sampled_clients: Ids whose real update is used (membership via ``in``).
        sample_rates: Summed to obtain the averaging coefficient.
        learning_rate, global_clip_value, global_sigma, num_clients, rate:
            Noise parameters for unsampled clients (see above).
        dataset, bs: Passed to ``get_model`` to build the returned model.

    Returns:
        The newly built model with the updated state loaded.

    Raises:
        ValueError: If ``client_updates`` is empty.
    '''
    if not client_updates:
        raise ValueError('averaging_func needs at least one client update')
    first_update = client_updates[next(iter(client_updates))]
    aggregated_updates_dict = {k: torch.zeros_like(v) for k, v in first_update.items()}
    averaging_coefficient = sum(sample_rates)

    for client_id, client_update in client_updates.items():
        if client_id in sampled_clients:
            for k in aggregated_updates_dict:
                aggregated_updates_dict[k] += client_update[k] / averaging_coefficient
        else:
            # The noise scale does not depend on the layer; compute it once.
            noise_std = global_sigma * global_clip_value / np.sqrt(num_clients)
            for k in aggregated_updates_dict:
                target = aggregated_updates_dict[k]
                # Draw on the accumulator's device so non-CPU aggregates work.
                noise = torch.normal(0, noise_std, size=target.shape, device=target.device)
                aggregated_updates_dict[k] += noise * learning_rate * rate / averaging_coefficient

    # Build a shallow copy rather than reassigning entries of the caller's dict;
    # otherwise repeated calls (one per coalition in calculate_svalue) would
    # compound updates onto the shared global state.
    new_state_dict = copy.copy(global_state_dict)
    for k in global_state_dict.keys():
        new_state_dict[k] = global_state_dict[k] + aggregated_updates_dict[k]
    model = get_model(dataset, bs)
    model.load_state_dict(new_state_dict)
    return model