from coreset import compress_fc_layer
import torch
import fire
import sys
from lenet import get_layer_params_lenet, set_layer_params_lenet
from vgg import get_layer_params_vgg, set_layer_params_vgg


def _broadcast_along(vec, ndim, dim):
    """Reshape 1-D ``vec`` so it broadcasts along axis ``dim`` of an ``ndim``-rank tensor."""
    shape = [1] * ndim
    shape[dim] = -1
    return vec.reshape(shape)


def prune_L0(model_dir, out_dir, sparsity: float, model_type, layers=None):
    """Prune the neurons with the smallest gate values in the given layers and save the model.

    For every layer index ``l`` in ``layers``, the gate parameters (``qz_loga``)
    for layer ``l``'s neurons are read from layer ``l + 1``. The
    ``int(sparsity * n)`` neurons with the smallest ``qz_loga`` are pruned by
    zeroing both their incoming weights (dim 0 of layer ``l``'s mask) and
    their outgoing weights (dim 1 of layer ``l + 1``'s mask). The new masks
    are multiplied into the masks already stored in the checkpoint, so
    already-masked layers and consecutive layers (e.g. ``layers=[1, 2]``)
    can be pruned.

    NOTE(review): assumes weight/mask layouts of ``(out, in, ...)`` as in
    ``torch.nn.Linear`` / ``Conv2d``; confirm against the lenet/vgg accessors.

    Args:
        model_dir: Path of the checkpoint loaded with ``torch.load``.
        out_dir: Path the pruned checkpoint is written to with ``torch.save``.
        sparsity: Fraction of each layer's neurons to prune, in ``[0, 1]``.
        model_type: ``"lenet"`` or ``"vgg"``; selects the parameter accessors.
        layers: Layer index, or iterable of layer indices, to prune.
            Defaults to ``[1]``.

    Raises:
        ValueError: If ``sparsity`` is outside ``[0, 1]`` or the number of
            gates does not match the layer's number of output neurons.
        Exception: If ``model_type`` is not recognised.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError(f'sparsity must be in [0, 1], got {sparsity}')

    if layers is None:
        layers = [1]
    elif isinstance(layers, int):
        # fire parses `--layers=2` as a bare int rather than a list.
        layers = [layers]

    if model_type == "lenet":
        get_params, set_params = get_layer_params_lenet, set_layer_params_lenet
    elif model_type == "vgg":
        get_params, set_params = get_layer_params_vgg, set_layer_params_vgg
    else:
        raise Exception(f'Unknown model type {model_type}')

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    params = torch.load(model_dir, map_location=torch.device('cpu'))

    for layer in layers:
        weights, biases, mask = get_params(params, layer)
        weights2, biases2, mask2, qz_loga = get_params(params, layer + 1, include_qzloga=True)

        # The gate values for this layer's neurons are stored with the next layer's weights.
        num_neurons = qz_loga.shape[0]
        if num_neurons != weights.shape[0]:
            raise ValueError(
                f'layer {layer}: {num_neurons} gates for {weights.shape[0]} output neurons'
            )

        # The epsilon guards against float error, e.g. 0.29 * 100 == 28.999999999999996.
        prune_size = int(sparsity * num_neurons + 1e-9)
        neuron_mask = torch.ones(num_neurons, dtype=mask.dtype, device=mask.device)
        _, prune_indices = torch.topk(qz_loga, prune_size, largest=False)
        neuron_mask[prune_indices] = 0

        # We have to mask out both the incoming and outgoing weights of the pruned
        # neurons. This is a cheap workaround for not being able to mask out biases.
        # Composing with the existing masks keeps earlier pruning (e.g. of the previous
        # layer in `layers`) intact.
        new_mask = mask * _broadcast_along(neuron_mask, mask.dim(), 0)
        new_mask2 = mask2 * _broadcast_along(neuron_mask.to(mask2.dtype), mask2.dim(), 1)

        set_params(params, layer, weights * new_mask, biases, mask=new_mask)
        set_params(params, layer + 1, weights2 * new_mask2, biases2, mask=new_mask2)

    torch.save(params, out_dir)


if __name__ == "__main__":
    # Expose prune_L0 as a command-line interface via Python Fire.
    fire.Fire(component=prune_L0)
