import torch
from models import *
from copy import deepcopy


# Small tolerance constant. Note: the literal 10e-7 evaluates to 1e-06.
EPSILON = 10e-7

# Experiment identifiers (collected in ``experiment_choices`` below).
DENSE = 'dense'
NST = 'nst'
PDST = 'pdst'
JMWST = 'jmwst'
SPDST = 'spdst'

# Model identifiers (collected in ``model_choices`` below).
RESNET18 = 'resnet18'
MNISTNET = 'mnistnet'
CNN = 'cnn'

# String spellings of the two boolean values.
TRUE = 'true'
FALSE = 'false'

# Dataset identifiers.
CIFAR10 = 'cifar10'
CIFAR100 = 'cifar100'
MNIST = 'mnist'
FEMNIST = 'femnist'

# Template with two positional ``str.format`` placeholders.
EXP_NAME = 'alpha:{}|frac:{}'

# Lists of accepted values built from the identifiers above.
experiment_choices = [DENSE, NST, PDST, SPDST, JMWST]
model_choices = [RESNET18, MNISTNET, CNN]
true_false = [TRUE, FALSE]
density_levels = [0.75, 0.5]

# Ten values: three at 0.5, three at 0.75 and four at 1.
# Presumably one density per client -- verify against the caller.
hetero = [0.5] * 3 + [0.75] * 3 + [1] * 4

# Twelve values at 0.5, twelve at 0.75 and thirteen at 1 (37 in total).
# NOTE(review): the "_62" suffix does not match the 37 entries -- confirm
# the intended length with the code that indexes this list.
hetero_62 = [0.5] * 12 + [0.75] * 12 + [1] * 13


def to_dict(my_dict, device):
    """Replace every value of ``my_dict`` with ``value.to(device)``.

    The dictionary is updated in place, one entry at a time, and nothing is
    returned.

    Args:
        my_dict: mapping whose values expose a ``.to(...)`` method
            (presumably tensors -- verify against the caller).
        device: argument forwarded unchanged to each value's ``.to``.
    """
    for name, value in my_dict.items():
        # Only existing keys are reassigned, so the dict size never changes
        # while it is being iterated.
        my_dict[name] = value.to(device)


def nan_to_zero(a):
    """Overwrite the NaN entries of ``a`` with 0, in place, and return ``a``.

    NaN is the only value that compares unequal to itself, so the positions
    to reset are found with an element-wise self-inequality test.
    """
    is_nan = a != a
    a[is_nan] = 0
    return a


def get_density(model):
    """Return the fraction of non-zero elements over all tensors in ``model``.

    Args:
        model: mapping of names to tensors (anything exposing ``.items()``).

    Returns:
        ``1 - zeros / total`` where ``zeros`` counts the entries equal to 0
        and ``total`` counts all entries across every tensor.
    """
    tensors = [tensor for _, tensor in model.items()]
    total = sum(torch.numel(tensor) for tensor in tensors)
    zeros = sum(torch.sum(tensor == 0).item() for tensor in tensors)
    return 1 - zeros / total


def get_model(args):
    """Build the model selected by ``args.model_type``.

    Args:
        args: object providing ``model_type`` and, depending on the selected
            model, ``num_classes`` and ``trackBN``.

    Returns:
        A ``ResNet18``, ``MNISTNet`` or ``leaf_cnn`` instance.

    Raises:
        NotImplementedError: if ``args.model_type`` is none of RESNET18,
            MNISTNET or CNN.
    """
    kind = args.model_type
    if kind == RESNET18:
        # NOTE(review): bool() of any non-empty string is True, so if
        # trackBN arrives as the string 'false' this still passes
        # track=True -- confirm the type of args.trackBN.
        return ResNet18(num_classes=args.num_classes, track=bool(args.trackBN))
    if kind == MNISTNET:
        return MNISTNet()
    if kind == CNN:
        return leaf_cnn(args.num_classes)
    raise NotImplementedError()


def get_avg_mask(masks, target_density, mask_path):
    """Draw a random mask whose per-layer density follows the client average.

    For every layer named in the file at ``mask_path`` the fraction of
    non-zero entries is averaged over all client masks. These per-layer
    densities are then rescaled by a common factor so that the expected
    overall density equals ``target_density`` (each layer capped at 1), and
    each layer is sampled independently with ``torch.rand``.

    Args:
        masks: sequence of per-client dicts mapping layer name to mask tensor.
        target_density: desired overall density of the returned mask.
        mask_path: path given to ``torch.load``; the loaded object maps layer
            names to the shape used for the sampled mask.

    Returns:
        Dict mapping each layer name to a float tensor of zeros and ones.
    """
    layer_shapes = torch.load(mask_path)
    client_count = len(masks)
    mean_density = {}
    expected_ones = 0
    element_total = 0
    for layer in layer_shapes:
        # The first client's mask defines the layer size.
        layer_size = torch.numel(masks[0][layer])
        element_total += layer_size

        fraction = 0
        for client_mask in masks:
            nonzero = torch.sum(client_mask[layer] != 0.0).item()
            fraction += nonzero / layer_size
        fraction = fraction / client_count

        expected_ones += fraction * layer_size
        mean_density[layer] = fraction

    # Common factor that moves the averaged density onto the target.
    scale = element_total / expected_ones * target_density
    return {
        layer: (torch.rand(shape) < min(mean_density[layer] * scale, 1)).float().data
        for layer, shape in layer_shapes.items()
    }


def get_aggregation_based_sensitivity(weights, mask_path):
    """Count the non-zero entries of ``weights`` for each listed layer.

    Args:
        weights: dict mapping layer name to tensor.
        mask_path: path given to ``torch.load``; only the layers named in the
            loaded mapping are inspected.

    Returns:
        Tuple ``(total_elem, total_ones, mask_body, layer_density)`` where
        ``total_elem`` is the number of elements over the listed layers,
        ``total_ones`` the number of non-zero elements, ``mask_body`` the
        loaded mapping, and ``layer_density`` maps each layer to its
        non-zero *count* (not a fraction).
    """
    mask_body = torch.load(mask_path)
    nonzero_per_layer = {}
    element_total = 0
    nonzero_total = 0
    for layer in mask_body:
        tensor = weights[layer]
        element_total += torch.numel(tensor)
        nonzero = torch.sum(tensor != 0.0).item()
        nonzero_total += nonzero
        nonzero_per_layer[layer] = nonzero
    return element_total, nonzero_total, mask_body, nonzero_per_layer


def get_avg_based_sensivity(clients_mask, mask_path):
    """Average the per-layer non-zero counts of the client masks.

    Args:
        clients_mask: sequence of per-client dicts mapping layer name to
            mask tensor.
        mask_path: path given to ``torch.load``; only the layers named in the
            loaded mapping are inspected.

    Returns:
        Tuple ``(total_elem, total_ones, mask_body, layer_density)`` where
        ``total_elem`` is the number of elements over the listed layers,
        ``total_ones`` the sum of the averaged non-zero counts,
        ``mask_body`` the loaded mapping, and ``layer_density`` maps each
        layer to its averaged non-zero count (mean fraction times size).
    """
    averaged_counts = {}
    client_count = len(clients_mask)
    mask_body = torch.load(mask_path)
    element_total = 0
    averaged_total = 0
    for layer in mask_body:
        # The first client's mask defines the layer size.
        layer_size = torch.numel(clients_mask[0][layer])
        element_total += layer_size

        fraction = 0
        for client in clients_mask:
            fraction += torch.sum(client[layer] != 0.0).item() / layer_size
        fraction = fraction / client_count

        layer_expected = fraction * layer_size
        averaged_total += layer_expected
        averaged_counts[layer] = layer_expected
    return element_total, averaged_total, mask_body, averaged_counts


def next_mask(clients_mask, server_weights, target, mask_path, method):
    """Build a magnitude-based mask with a per-layer budget of ones.

    The per-layer budget comes from either the averaged client masks
    (``method == 'avg'``) or the non-zero pattern of the server weights
    (``method == 'aggregate'``). Budgets are rescaled by a common factor
    derived from ``target``, and in every layer the entries of
    ``server_weights`` with the largest absolute value are set to 1.

    Args:
        clients_mask: per-client mask dicts, used when ``method == 'avg'``.
        server_weights: dict mapping layer name to tensor; always used for
            the magnitude ranking.
        target: factor applied to ``total_elem / total_ones`` when rescaling.
        mask_path: path forwarded to the selected statistics helper.
        method: ``'avg'`` or ``'aggregate'``.

    Returns:
        Dict mapping each layer name to a tensor of zeros and ones.

    Raises:
        NotImplementedError: for any other ``method``.
    """
    if method == 'avg':
        stats = get_avg_based_sensivity(clients_mask, mask_path)
    elif method == 'aggregate':
        stats = get_aggregation_based_sensitivity(server_weights, mask_path)
    else:
        raise NotImplementedError()
    total_elem, total_ones, mask_body, layer_density = stats

    scale = (total_elem / total_ones) * target
    new_mask = {}
    for layer, shape in mask_body.items():
        keep = int(layer_density[layer] * scale)
        layer_mask = torch.zeros(shape)
        magnitudes = torch.abs(server_weights[layer].view(-1))
        _, order = torch.sort(magnitudes, descending=True)
        # Slicing clips automatically if ``keep`` exceeds the layer size.
        layer_mask.view(-1)[order[:keep]] = 1.0
        new_mask[layer] = layer_mask
    return new_mask


def get_rand_mask(density, path):
    """Sample an independent random mask for every layer listed at ``path``.

    Args:
        density: threshold compared against uniform samples from
            ``torch.rand``; an entry becomes 1.0 when its sample is below it.
        path: path given to ``torch.load``; the loaded object maps layer
            names to the shape of the mask to sample.

    Returns:
        Dict mapping each layer name to a float tensor of zeros and ones.
    """
    layer_shapes = torch.load(path)
    return {
        layer: (torch.rand(shape) < density).float().data
        for layer, shape in layer_shapes.items()
    }


def apply_mask(model, masks):
    """Multiply each named parameter of ``model`` by its mask.

    Parameters whose name has no entry in ``masks`` are left untouched.

    Args:
        model: object exposing ``named_parameters()``.
        masks: dict mapping parameter name to a tensor multiplied
            element-wise into that parameter's data.

    Returns:
        The same ``model`` object, after its parameter data was replaced.
    """
    with torch.no_grad():
        for param_name, param in model.named_parameters():
            if param_name not in masks:
                continue
            param.data = param.data * masks[param_name]
    return model


def aggregate(weights, masks, data_size, update_mode, dataset):
    """Combine the client weights into a single set of weights.

    Args:
        weights: sequence of per-client dicts mapping name to tensor.
        masks: sequence of per-client mask dicts; only used when
            ``update_mode == 1``.
        data_size: accepted for interface compatibility; currently unused.
        update_mode: ``0`` for a plain average (``average_weights``),
            ``1`` for a mask-aware average (``masked_average_weights``).
        dataset: accepted for interface compatibility; currently unused.

    Returns:
        The dict produced by the selected averaging function.

    Raises:
        NotImplementedError: for any other ``update_mode``.
    """
    # NOTE(review): this function used to build a data-size-weighted sum of
    # the client weights when ``dataset == FEMNIST``, but that result was
    # never returned or read, so the value returned here was always decided
    # by ``update_mode`` alone. The dead computation (a deepcopy plus a full
    # pass over every tensor, and a ZeroDivisionError when the sizes summed
    # to zero) has been removed without changing any returned value. If
    # data-size weighting was intended for FEMNIST, it needs to be
    # reintroduced deliberately, including how it interacts with the masks.
    if update_mode == 0:
        return average_weights(weights)
    if update_mode == 1:
        return masked_average_weights(weights, masks)
    raise NotImplementedError


def average_weights(w):
    """Return the element-wise mean of a sequence of weight dicts.

    Args:
        w: sequence of dicts mapping name to tensor; the keys of the first
            dict determine which entries are averaged.

    Returns:
        A new dict (built from a deep copy of ``w[0]``) holding, for every
        key, the sum over all clients divided by ``len(w)``.
    """
    client_count = len(w)
    averaged = deepcopy(w[0])
    for key in averaged.keys():
        running = averaged[key]
        for other in w[1:]:
            # In-place tensor add: accumulates into the deep-copied tensor.
            running += other[key]
        averaged[key] = torch.div(running, client_count)
    return averaged


def masked_average_weights(w, masks):
    """Average client weights, dividing masked entries by their mask sum.

    For every key of the first weight dict the client tensors are summed.
    Keys present in the first mask dict are divided element-wise by the sum
    of the client masks, and NaN results are reset to 0 with
    ``nan_to_zero``; all other keys are divided by the number of clients.

    Args:
        w: sequence of per-client dicts mapping name to tensor.
        masks: sequence of per-client mask dicts, indexed like ``w``.

    Returns:
        A new dict built from a deep copy of ``w[0]``.
    """
    # NOTE(review): only NaN is cleaned up afterwards; a non-zero weight sum
    # over a zero mask sum would presumably stay as +/-inf -- confirm that
    # client weights are always zero wherever their mask is zero.
    averaged = deepcopy(w[0])
    mask_total = deepcopy(masks[0])
    client_count = len(w)
    with torch.no_grad():
        for key in averaged.keys():
            is_masked = key in mask_total
            for index in range(1, client_count):
                averaged[key] += w[index][key]
                if is_masked:
                    mask_total[key] += masks[index][key]
            if is_masked:
                averaged[key] = torch.div(averaged[key], mask_total[key])
                averaged[key] = nan_to_zero(averaged[key])
            else:
                averaged[key] = torch.div(averaged[key], client_count)
    return averaged
