from coreset import compress_fc_layer
import torch
import fire
from lenet import get_layer_params_lenet, set_layer_params_lenet
from vgg import get_layer_params_vgg, set_layer_params_vgg


def prune_dataind(model_dir, out_dir, epsilon: float, sparsity: float, layer_norms, model_type, layers=(1,), delta=0.5):
    """Compress pairs of consecutive layers of a saved model and save the result.

    For every index in ``layers`` the parameters of that layer and of the
    following one (``layer + 1``) are fetched, handed to
    ``compress_fc_layer``, and written back: the first layer keeps its
    weights but gets the returned mask (applied to its biases and expanded
    across its weight columns), the second layer gets the returned
    replacement weights.

    Three values are printed per processed layer, one per line: the
    ``samples_needed`` value, the fraction of mask entries equal to zero,
    and the ``implied_epsilon`` value returned by ``compress_fc_layer``.

    Args:
        model_dir: Path of the saved parameters, read with ``torch.load``
            and mapped onto the CPU.
        out_dir: Path the updated parameters are written to with
            ``torch.save``.
        epsilon: Forwarded unchanged to ``compress_fc_layer``.
        sparsity: Forwarded unchanged to ``compress_fc_layer``.
        layer_norms: Path of a ``torch.load``-able object that is indexed
            with ``layer - 1``; the maximum of that entry is forwarded to
            ``compress_fc_layer``.
        model_type: ``"lenet"`` or ``"vgg"``; selects the parameter
            getter/setter pair.
        layers: Iterable of layer indices to compress. A single int is
            also accepted (convenient on the command line, where
            ``--layers=2`` arrives as a bare int).
        delta: Forwarded unchanged to ``compress_fc_layer``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
        AssertionError: If the current mask of a layer or of its successor
            is not all ones.
    """
    if model_type == "lenet":
        get_params, set_params = get_layer_params_lenet, set_layer_params_lenet
    elif model_type == "vgg":
        get_params, set_params = get_layer_params_vgg, set_layer_params_vgg
    else:
        raise ValueError(f'Unknown model type {model_type}')

    # A lone index would otherwise fail in the loop below with
    # "'int' object is not iterable".
    if isinstance(layers, int):
        layers = (layers,)

    # NOTE: torch.load unpickles its input; only point this at trusted files.
    cpu = torch.device('cpu')
    params = torch.load(model_dir, map_location=cpu)
    # Load the norms onto the CPU as well, so a file saved from a GPU run can
    # be read on a CPU-only host and lives on the same device as the weights.
    norms = torch.load(layer_norms, map_location=cpu)

    for layer in layers:
        weights1, biases1, mask1 = get_params(params, layer)
        weights2, biases2, mask2 = get_params(params, layer+1)
        # Both masks must still be all ones before this pass overwrites them.
        assert torch.equal(torch.ones(mask1.shape), mask1), \
            f'layer {layer} already has a non-trivial mask'
        assert torch.equal(torch.ones(mask2.shape), mask2), \
            f'layer {layer + 1} already has a non-trivial mask'

        mask, new_weights2, samples_needed, implied_epsilon = compress_fc_layer(
            (weights1, biases1),
            weights2,
            epsilon,
            sparsity,
            delta,
            torch.nn.ReLU(),
            max(norms[layer - 1]))

        # Turn the flat mask into a column and repeat it weights1.shape[1]
        # times so that it can be stored as the weight mask of the layer.
        # NOTE(review): `mask` is presumably a NumPy array (its `.size` is
        # used as a number below) -- confirm against compress_fc_layer.
        expanded_mask = torch.tensor(mask).reshape(-1, 1).repeat([1, weights1.shape[1]])
        set_params(params, layer, weights1, biases1 * mask, expanded_mask)
        set_params(params, layer+1, new_weights2, biases2)

        print(f'{samples_needed}')
        print(f'{(mask == 0).sum() / mask.size}')
        print(f'{implied_epsilon}')

    torch.save(params, out_dir)


if __name__ == "__main__":
    # Script entry point: hand prune_dataind to python-fire, which builds the
    # command-line interface from the function's parameters.
    fire.Fire(component=prune_dataind)
