import fire
import torch
from lenet import load_lenet, get_layer_output_lenet, mnist_dataset, get_device
from vgg import load_vgg, get_layer_output_vgg, cifar10_dataset
import numpy as np


def get_epsilons(model, data_loc, model_type, epsilon_type):
    """Compute one robustness epsilon per correctly classified validation sample.

    Runs the loaded model over the validation split. For every sample whose
    argmax at layer ``output_layer`` equals its label, it compares the largest
    activation (top-1) with the second-largest one (top-2).

    Args:
        model: Identifier passed to the loader (``load_lenet`` / ``load_vgg``);
            presumably a checkpoint path -- confirm against the loaders.
        data_loc: Dataset location forwarded to ``mnist_dataset`` /
            ``cifar10_dataset``. Only the second loader they return is used.
        model_type: ``"lenet"`` (MNIST, layer 3, flattened inputs) or
            ``"vgg"`` (CIFAR-10, layer 4).
        epsilon_type: ``"additive"`` -> ``(top1 - top2) / 2``;
            ``"multiplicative"`` -> ``(top1 - top2) / (|top1| + |top2|)``.

    Returns:
        1-D ``np.ndarray`` with one epsilon per correctly classified sample.

    Raises:
        ValueError: for an unknown ``model_type`` or ``epsilon_type``.
    """
    # Fail fast, before loading data and running a full validation pass.
    if epsilon_type not in ("additive", "multiplicative"):
        raise ValueError(f"Unknown epsilon type {epsilon_type}")

    if model_type == "lenet":
        dataset, load, get_layer_output, output_layer = mnist_dataset, load_lenet, get_layer_output_lenet, 3
    elif model_type == "vgg":
        dataset, load, get_layer_output, output_layer = cifar10_dataset, load_vgg, get_layer_output_vgg, 4
    else:
        raise ValueError(f"Unknown model type {model_type}")

    _, valloader = dataset(data_loc)

    device = get_device()
    net = load(model).to(device)
    # Disable dropout and use running batch-norm statistics so activations
    # are deterministic. This is a no-op if the loader already called eval().
    net.eval()

    top1_chunks = []
    top2_chunks = []
    for images, labels in valloader:
        if model_type != "vgg":
            # The LeNet path consumes flattened inputs.
            images = images.view(images.shape[0], -1)
        images, labels = images.to(device), labels.to(device)

        with torch.no_grad():
            # The original author describes these as pre-ReLU activations of
            # layer `output_layer`; see get_layer_output_* to confirm.
            activations = get_layer_output(net, images, output_layer)

        # Keep only correctly classified samples. For those, the top-1 value
        # is the activation of the true/predicted class.
        correct = activations.argmax(dim=1) == labels
        best_two = activations[correct].topk(2, dim=1).values
        top1_chunks.append(best_two[:, 0].cpu().numpy())
        top2_chunks.append(best_two[:, 1].cpu().numpy())

    max_activations = np.concatenate(top1_chunks)
    second_largest = np.concatenate(top2_chunks)
    gap = max_activations - second_largest
    if epsilon_type == "additive":
        # Shifting every activation by less than gap/2 cannot let the
        # runner-up overtake the top class.
        return gap / 2
    # Multiplicative: scaling every activation by a factor in (1 - eps, 1 + eps)
    # cannot flip the argmax. Note that 0/0 yields NaN.
    return gap / (np.abs(max_activations) + np.abs(second_largest))


def acc_to_eps(model, data_loc, model_type, acc: float, epsilon_type='additive'):
    """Print the epsilon that keeps a fraction ``acc`` of correct predictions unchanged.

    Args:
        model, data_loc, model_type, epsilon_type: Forwarded to ``get_epsilons``.
        acc: Fraction in [0, 1] of the *correctly classified* samples whose
            prediction must survive a perturbation smaller than the result.

    Raises:
        ValueError: if ``acc`` is outside [0, 1], or if no validation sample
            was classified correctly.
    """
    # Validate before running the (expensive) validation pass.
    if not 0.0 <= acc <= 1.0:
        raise ValueError(f"acc must be in [0, 1], got {acc}")

    eps = get_epsilons(model, data_loc, model_type, epsilon_type)
    correct_count = len(eps)
    if correct_count == 0:
        raise ValueError("No correctly classified validation samples; cannot pick an epsilon.")

    # Sort the margins in descending order. Take the element at index
    # ceil(n * acc): all ceil(n * acc) samples before it have a margin >= the
    # chosen epsilon, so a perturbation strictly smaller than it preserves at
    # least that many labels. Clamp to the last index when acc == 1.
    threshold = int(np.ceil(correct_count * acc))
    eps_descending = -np.sort(-eps)
    print(eps_descending[min(threshold, correct_count - 1)])


if __name__ == "__main__":
    # Command-line entry point: python-fire maps acc_to_eps's parameters
    # to command-line arguments.
    fire.Fire(component=acc_to_eps)
