import torch
import fire
import numpy as np
from lenet import get_layer_params_lenet
from vgg import get_layer_params_vgg


def randomized_compare(pruned_model, orig_model, model_type, layer_norms, sample_size=50000, eps_type="additive"):
    """Compare a pruned and an original model on random inputs and print error statistics.

    Random Gaussian directions are rescaled to a fixed Euclidean norm, pushed
    through two consecutive layers (``layer - 1`` then ``layer``; each applied
    as ``(weights * mask) @ x + biases`` with no nonlinearity in between) of
    both parameter sets, and the element-wise difference of the outputs is
    summarised.

    Args:
        pruned_model: Path (or file-like) passed to ``torch.load`` for the
            pruned parameters.
        orig_model: Path (or file-like) passed to ``torch.load`` for the
            original parameters.
        model_type: ``"lenet"`` (compares up to layer 3) or ``"vgg"``
            (compares up to layer 4).
        layer_norms: Path (or file-like) passed to ``torch.load``; the loaded
            object is indexed with ``layer - 2`` and the largest entry is used
            as the norm of every sampled input.
        sample_size: Number of random input vectors to draw; must be >= 1.
        eps_type: ``'multiplicative'`` reports the error relative to the
            magnitude of the original output (entries whose original output
            has magnitude below 0.01 count as zero error). Any other value
            reports the absolute error.

    Returns:
        None. The maximum, mean and standard deviation of the error are
        printed, one per line.

    Raises:
        ValueError: If ``model_type`` is not recognised or ``sample_size`` is
            smaller than 1.
    """
    if model_type == "lenet":
        get_params, layer = get_layer_params_lenet, 3
    elif model_type == "vgg":
        get_params, layer = get_layer_params_vgg, 4
    else:
        raise ValueError(f'Unknown model type {model_type}')

    # Fail fast: with no samples the max/mean/std reductions below are undefined.
    if sample_size < 1:
        raise ValueError(f'sample_size must be at least 1, got {sample_size}')

    # NOTE(security): torch.load unpickles its input, which can execute
    # arbitrary code. Only point this at trusted files.
    # map_location="cpu" lets checkpoints saved on a GPU be used here; the
    # sampled inputs and the .numpy() conversion below require CPU tensors.
    pruned_params = torch.load(pruned_model, map_location="cpu")
    orig_params = torch.load(orig_model, map_location="cpu")
    norms = torch.load(layer_norms, map_location="cpu")

    def build_forward(params):
        """Return a function mapping an input vector to the output of `layer`."""
        weights1, biases1, mask1 = get_params(params, layer - 1)
        weights2, biases2, mask2 = get_params(params, layer)
        # The masked weights do not depend on the input, so compute them once
        # instead of once per sample.
        masked1 = weights1 * mask1
        masked2 = weights2 * mask2
        input_dtype = mask1.dtype  # Sometimes mask1 is float64, sometimes it's float32

        def forward(vec):
            vec = vec.type(input_dtype)
            hidden = torch.matmul(masked1, vec) + biases1
            return torch.matmul(masked2, hidden) + biases2

        return forward

    input_dim = get_params(orig_params, layer - 1)[0].shape[1]
    # Every sample is scaled to the largest norm recorded for this layer's input.
    radius = max(norms[layer - 2]).item()

    orig_hiddens = []
    pruned_hiddens = []
    # no_grad: no autograd graph is needed, and .numpy() would raise on
    # tensors that require grad.
    with torch.no_grad():
        forward_pruned = build_forward(pruned_params)
        forward_orig = build_forward(orig_params)
        for _ in range(sample_size):
            direction = np.random.randn(input_dim)
            scaled = direction / np.linalg.norm(direction) * radius
            vec = torch.tensor(scaled, dtype=torch.float64)

            pruned_hiddens.append(forward_pruned(vec).numpy())
            orig_hiddens.append(forward_orig(vec).numpy())

    orig_hiddens = np.array(orig_hiddens)
    pruned_hiddens = np.array(pruned_hiddens)

    diff = np.abs(orig_hiddens - pruned_hiddens)
    if eps_type == 'multiplicative':
        magnitude = np.abs(orig_hiddens)
        negligible = magnitude < 0.01
        # Divide only where the reference output is non-negligible; the other
        # entries keep an error of 0. Dividing everywhere would evaluate 0/0
        # for outputs that are exactly zero and turn every statistic into NaN.
        relative = np.zeros_like(diff)
        np.divide(diff, magnitude, out=relative, where=~negligible)
        diff = relative

    print(f'{diff.max()}')
    print(f'{diff.mean()}')
    print(f'{diff.std()}')


# When run as a script, hand randomized_compare to python-fire so it can be
# invoked from the command line.
if __name__ == '__main__':
    fire.Fire(randomized_compare)
