import torch
import torch.nn.functional as F
import fire
import numpy as np
from lenet import get_layer_params_lenet, mnist_dataset, get_device
from vgg import load_vgg, get_layer_params_vgg, cifar10_dataset


def datadep_sensitivity(model_dir, data_loc, model_type, out_file, delta=0.5):
    """Compute data-dependent edge sensitivities for a network and save them.

    For every fully connected layer the weight matrix (with the bias appended
    as an extra input column) is split into its positive and negative parts.
    Training examples are pushed through the layers one at a time and, for
    each example, every edge's share of the total positive (resp. negative)
    weighted input of its target neuron is computed.  The element-wise maximum
    of these shares over all processed examples is the edge's sensitivity.

    In addition, for every neuron the ratio
    ``(|positive input| + |negative input|) / |pre-activation|`` is recorded
    per example; for each neuron the smallest ``(1 - delta)`` fraction of these
    ratios is kept and the largest of them is stored as the layer delta.

    Args:
        model_dir: Path handed to ``torch.load`` (and to ``load_vgg`` for VGG).
        data_loc: Location handed to the dataset factory.
        model_type: Either ``"lenet"`` or ``"vgg"``.
        out_file: Destination handed to ``torch.save``.
        delta: Value in the open interval (0, 1); it sets both the number of
            examples processed and the fraction of per-neuron ratios that is
            discarded.

    Raises:
        ValueError: If ``model_type`` is unknown or ``delta`` is not in (0, 1).
        RuntimeError: If the training loader yields no examples.

    The saved dictionary has the keys ``delta_prob``, ``pos_sensitivities``,
    ``neg_sensitivities`` (lists of ``outputs x (inputs + 1)`` tensors),
    ``hidden_sizes``, ``sample_size`` and ``layer_deltas``.
    """
    # Fail fast on bad arguments, before the (potentially slow) checkpoint load.
    if model_type not in ("lenet", "vgg"):
        raise ValueError(f"Unknown model type {model_type}")
    if not 0 < delta < 1:
        raise ValueError(f"delta must lie in the open interval (0, 1), got {delta}")

    device = get_device()

    params = torch.load(model_dir, map_location=device)

    if model_type == "lenet":
        dataset, get_params, output_layer = mnist_dataset, get_layer_params_lenet, 3
        model = None
    else:
        dataset, get_params, output_layer = cifar10_dataset, get_layer_params_vgg, 4
        model = load_vgg(model_dir, device=device).to(device)

    trainloader, _ = dataset(data_loc)

    weights = []
    pos_weights = []
    neg_weights = []
    pos_sensitivities = []
    neg_sensitivities = []
    hidden_sizes = []
    layer_deltas = []
    for layer in range(1, output_layer + 1):
        # output x input
        weight, bias, mask = get_params(params, layer)
        # Only an all-ones mask is supported (presumably: an unpruned network).
        assert torch.equal(torch.ones(mask.shape).to(device), mask), (
            f"layer {layer} has a mask that is not all ones"
        )
        # Fold the bias in as one extra input column.  torch.cat already
        # allocates new storage, so detaching is enough; wrapping the result in
        # torch.tensor() would copy again and emit a UserWarning.
        weight = torch.cat((weight, bias.reshape(-1, 1)), dim=1).detach()
        weights.append(weight)
        hidden_sizes.append(weight.shape[0])

        pos_weight = weight.detach().clone()
        pos_weight[weight < 0] = 0
        pos_weights.append(pos_weight)

        # Stored as magnitudes, i.e. non-negative values.
        neg_weight = weight.detach().clone()
        neg_weight[weight > 0] = 0
        neg_weights.append(-neg_weight)

        pos_sensitivities.append(torch.zeros(weight.shape).to(device))
        neg_sensitivities.append(torch.zeros(weight.shape).to(device))

        layer_deltas.append([])

    # Products of consecutive layer widths; the largest one drives sample_size.
    matrix_sizes = [
        rows * cols for rows, cols in zip(hidden_sizes, hidden_sizes[1:])
    ]
    K_p = 10
    sample_size = K_p * np.log(2 * max(matrix_sizes) / delta)

    # Constant activation appended to every layer input so the bias column of
    # the weight matrix is used; allocated once instead of once per layer.
    bias_input = torch.ones((1, 1)).to(device)

    def calculate_sensitivities():
        """Update the sensitivities and per-example deltas in place."""
        count = 0
        for images, _ in trainloader:
            images = images.to(device)
            if model_type == "vgg":
                # Run the convolutional part, then flatten to one row per example.
                images = model.avgpool(model.features(images)).flatten(1)
            else:
                images = images.view(images.shape[0], -1)

            for example in images:
                hidden = example.unsqueeze(0)
                for i, weight in enumerate(weights):
                    # Add a 1 to the activations so the biases are used appropriately
                    hidden = torch.cat((hidden, bias_input), dim=1)

                    # inputs x outputs
                    pos_unnormalized = hidden.t() * pos_weights[i].t()
                    neg_unnormalized = hidden.t() * neg_weights[i].t()

                    pos_denom = torch.sum(pos_unnormalized, dim=0)
                    pos_sensitivity = pos_unnormalized / pos_denom
                    neg_denom = torch.sum(neg_unnormalized, dim=0)
                    neg_sensitivity = neg_unnormalized / neg_denom

                    # If all the values are 0, just let the sensitivities be 0 rather than NaN.
                    pos_sensitivity[:, pos_denom == 0] = 0
                    neg_sensitivity[:, neg_denom == 0] = 0

                    pos_sensitivities[i] = torch.max(pos_sensitivities[i], pos_sensitivity.t())
                    neg_sensitivities[i] = torch.max(neg_sensitivities[i], neg_sensitivity.t())

                    pos_hidden = torch.matmul(hidden, pos_weights[i].t())
                    neg_hidden = torch.matmul(hidden, neg_weights[i].t())
                    hidden = torch.matmul(hidden, weight.t())

                    # NOTE(review): a pre-activation of exactly 0 yields inf/NaN
                    # here; confirm whether downstream consumers tolerate that.
                    deltas = (torch.abs(pos_hidden) + torch.abs(neg_hidden)) / torch.abs(hidden)
                    layer_deltas[i].append(deltas.squeeze())

                    hidden = F.relu(hidden)

                count += 1
                if count >= sample_size:
                    return

    print(f"Calculating data dependent values for {sample_size} examples...")
    with torch.no_grad():
        calculate_sensitivities()

    if not layer_deltas[0]:
        raise RuntimeError("The training loader produced no examples")

    filtered_layer_deltas = []
    for per_example in layer_deltas:
        deltas = torch.stack(per_example)
        # Keep the smallest (1 - delta) fraction per neuron, but never fewer
        # than one value: reducing over an empty selection would raise.
        keep = max(1, int(deltas.shape[0] * (1 - delta)))
        new_vals, _ = torch.topk(deltas, keep, largest=False, dim=0)
        filtered_layer_deltas.append(new_vals.max(dim=0)[0])

    torch.save({'delta_prob': delta,
                'pos_sensitivities': pos_sensitivities,
                'neg_sensitivities': neg_sensitivities,
                'hidden_sizes': hidden_sizes,
                'sample_size': sample_size,
                'layer_deltas': filtered_layer_deltas,
                }, out_file)


if __name__ == "__main__":
    # Expose datadep_sensitivity as a command-line tool: python-fire maps the
    # function's parameters to positional arguments / --flags.
    fire.Fire(datadep_sensitivity)
