import fire
import torch
import torch.nn.functional as F
from lenet import load_lenet, get_layer_output_lenet, mnist_dataset, get_device
from vgg import load_vgg, get_layer_output_vgg, cifar10_dataset
import numpy as np


def compare_hidden(pruned_model, orig_model, data_loc, model_type, layer, apply_softmax=False, eps_type="additive",
                   zero_threshold=0.01):
    """Compare one layer's activations of a pruned model against the original model.

    Runs both models over the validation loader returned by the dataset helper,
    collects the chosen layer's output for every batch, and prints three lines:
    the max, mean and standard deviation of the element-wise discrepancy.

    Args:
        pruned_model: Forwarded to the model loader (``load_lenet`` / ``load_vgg``);
            presumably a checkpoint path -- confirm against those helpers.
        orig_model: Same as ``pruned_model``, for the reference (unpruned) model.
        data_loc: Forwarded to ``mnist_dataset`` / ``cifar10_dataset``.
        model_type: ``"lenet"`` or ``"vgg"``.
        layer: Layer index forwarded to ``get_layer_output_lenet`` / ``get_layer_output_vgg``.
        apply_softmax: Apply softmax over dim 1 to both outputs before comparing.
            Only allowed for the output layer (3 for lenet, 4 for vgg).
        eps_type: ``"additive"`` compares ``|orig - pruned|``;
            ``"multiplicative"`` compares ``|orig - pruned| / |orig|``.
        zero_threshold: Multiplicative mode only. Elements with ``|orig| < zero_threshold``
            get a discrepancy of 0 rather than being divided, because the relative
            error is ill-defined there. They still count toward the mean and std.

    Raises:
        ValueError: Unknown ``model_type`` or ``eps_type``; ``apply_softmax`` requested
            for a non-output layer; or the validation loader yields no batches.
    """
    if eps_type not in ("additive", "multiplicative"):
        raise ValueError(f"Unknown eps_type {eps_type}")

    if model_type == "lenet":
        dataset, load, get_layer_output, output_layer = mnist_dataset, load_lenet, get_layer_output_lenet, 3
    elif model_type == "vgg":
        dataset, load, get_layer_output, output_layer = cifar10_dataset, load_vgg, get_layer_output_vgg, 4
    else:
        raise ValueError(f"Unknown model type {model_type}")

    # Validate before loading any data or checkpoints.
    # An explicit raise (unlike `assert`) also survives `python -O`.
    if apply_softmax and layer != output_layer:
        raise ValueError(f"apply_softmax requires layer == {output_layer} for {model_type}, got {layer}")

    _, valloader = dataset(data_loc)

    device = get_device()

    # .eval() is idempotent. It guarantees inference-mode behaviour for any
    # dropout/batch-norm layers in case the loaders leave the model in train mode
    # (NOTE(review): they may already do this -- not visible here).
    orig_net = load(orig_model).to(device).eval()
    pruned_net = load(pruned_model).to(device).eval()

    orig_chunks = []
    pruned_chunks = []
    with torch.no_grad():
        for images, _ in valloader:
            # The lenet path is fed flattened inputs; VGG takes the image tensor as-is.
            if model_type != "vgg":
                images = images.view(images.shape[0], -1)
            images = images.to(device)

            orig_out = get_layer_output(orig_net, images, layer)
            pruned_out = get_layer_output(pruned_net, images, layer)

            if apply_softmax:
                orig_out = F.softmax(orig_out, dim=1)
                pruned_out = F.softmax(pruned_out, dim=1)

            orig_chunks.append(orig_out.cpu().numpy().ravel())
            pruned_chunks.append(pruned_out.cpu().numpy().ravel())

    if not orig_chunks:
        raise ValueError("validation loader produced no batches")

    orig_all = np.concatenate(orig_chunks)
    pruned_all = np.concatenate(pruned_chunks)

    diff = np.abs(orig_all - pruned_all)
    if eps_type == "multiplicative":
        denom = np.abs(orig_all)
        # Divide only where the reference is far enough from zero; everywhere else
        # the entry stays 0. Dividing unconditionally yields 0/0 = NaN wherever
        # orig is exactly 0, which would turn max/mean/std into NaN.
        diff = np.divide(diff, denom, out=np.zeros_like(diff), where=denom >= zero_threshold)

    print(f'{diff.max()}')
    print(f'{diff.mean()}')
    print(f'{diff.std()}')


if __name__ == "__main__":
    # Expose compare_hidden's parameters as command-line flags via python-fire.
    fire.Fire(component=compare_hidden)
