import fire
import torch
import matplotlib.pyplot as plt
from lenet import load_lenet, activations_lenet, mnist_dataset, get_device
from vgg import load_vgg, activations_vgg, cifar10_dataset


def _per_sample_norms(tensor):
    """Return the L2 norm of each sample in ``tensor`` as a list of floats.

    Dimension 0 is treated as the batch dimension and every remaining
    dimension is flattened before the norm is taken, so the result has one
    entry per sample. NOTE(review): assumes every tensor yielded by the
    ``activations_*`` helpers is batch-first -- confirm against lenet/vgg.
    """
    return tensor.reshape(tensor.shape[0], -1).norm(dim=1).tolist()


def _save_histograms(layer_norms):
    """Write ``layer_{idx}_norms.png`` (current working directory) per layer."""
    for idx, norms in enumerate(layer_norms):
        fig, axes = plt.subplots()
        axes.hist(norms)
        axes.set_title(f"Layer {idx} Norms")
        fig.savefig(f'layer_{idx}_norms.png')
        # Close each figure; otherwise pyplot keeps every one of them alive.
        plt.close(fig)


def dataind_sensitivity(model_dir, data_loc, model_type, norms_out):
    """Collect per-sample activation norms over the training set and plot them.

    Args:
        model_dir: Passed to the model loader (``load_lenet`` / ``load_vgg``).
        data_loc: Passed to the dataset factory (``mnist_dataset`` /
            ``cifar10_dataset``); the first returned loader is iterated.
        model_type: Either ``"lenet"`` or ``"vgg"``.
        norms_out: Destination path for ``torch.save`` of the collected norms.

    The saved object is a list of lists of floats. Index 0 holds the norms of
    the model inputs: the flattened images for LeNet, or the flattened
    ``avgpool(features(...))`` output for VGG. Index k >= 1 holds the norms of
    the k-th tensor yielded by the model's ``activations`` helper. One
    histogram PNG per index is also written to the working directory.

    Raises:
        ValueError: if ``model_type`` is not recognised.
    """
    if model_type == "lenet":
        dataset, load, activations = mnist_dataset, load_lenet, activations_lenet
    elif model_type == "vgg":
        dataset, load, activations = cifar10_dataset, load_vgg, activations_vgg
    else:
        raise ValueError(f"Unknown model type {model_type}")

    trainloader, _ = dataset(data_loc)
    device = get_device()
    model = load(model_dir).to(device)
    # Measure deterministic activations: disable dropout / batch-norm updates.
    model.eval()

    layer_norms = [[]]
    with torch.no_grad():
        for images, _labels in trainloader:
            if model_type != "vgg":
                # LeNet consumes flattened images.
                images = images.view(images.shape[0], -1)
            images = images.to(device)

            if model_type == "vgg":
                # Input to the classifier head: pooled conv features, flattened per sample.
                inputs = torch.flatten(model.avgpool(model.features(images)), 1)
            else:
                inputs = images
            layer_norms[0] += _per_sample_norms(inputs)

            # Slot 0 is reserved for the inputs, so activations start at 1.
            for idx, act in enumerate(activations(model, images), start=1):
                if idx >= len(layer_norms):
                    layer_norms.append([])
                layer_norms[idx] += _per_sample_norms(act)

    torch.save(layer_norms, norms_out)
    _save_histograms(layer_norms)


if __name__ == "__main__":
    # Expose dataind_sensitivity's parameters as command-line arguments via python-fire.
    fire.Fire(component=dataind_sensitivity)
