import torch
import fire
from lenet import get_layer_params_lenet, set_layer_params_lenet
from vgg import get_layer_params_vgg, set_layer_params_vgg


def prune_magnitude(model_dir, out_dir, sparsity: float, model_type) -> None:
    """Magnitude-prune every layer of a saved model and save the result.

    The checkpoint is loaded onto the CPU, and layers are visited starting
    at index 1 until the model-specific getter raises ``KeyError``. For each
    layer the ``int(numel * sparsity)`` weights with the smallest absolute
    value are zeroed, and a 0/1 mask marking the surviving weights is stored
    next to them. Biases are written back unchanged.

    Args:
        model_dir: Checkpoint to prune; passed directly to ``torch.load``
            (so, despite the name, it is the checkpoint file itself).
        out_dir: Destination of the pruned checkpoint; passed directly to
            ``torch.save``.
        sparsity: Fraction of each layer's weights to prune, in ``[0, 1]``.
        model_type: Which parameter accessors to use: ``"lenet"`` or
            ``"vgg"``.

    Raises:
        ValueError: If ``sparsity`` is outside ``[0, 1]`` or ``model_type``
            is not recognised.
        AssertionError: If a layer's existing mask is not all ones, i.e. the
            checkpoint appears to have been pruned already.
    """
    # Validate before any I/O so a bad value fails fast with a clear message
    # instead of surfacing as an opaque error from torch.topk mid-way through.
    if not 0 <= sparsity <= 1:
        raise ValueError(f'sparsity must be in [0, 1], got {sparsity}')

    if model_type == "lenet":
        get_params, set_params = get_layer_params_lenet, set_layer_params_lenet
    elif model_type == "vgg":
        get_params, set_params = get_layer_params_vgg, set_layer_params_vgg
    else:
        raise ValueError(f'Unknown model type {model_type}')

    # NOTE: torch.load unpickles its input; only use checkpoints from a
    # trusted source.
    params = torch.load(model_dir, map_location=torch.device('cpu'))

    layer = 1
    while True:
        try:
            weights, biases, mask = get_params(params, layer)
        except KeyError:
            # The getter signals "no such layer" with KeyError: we are done.
            break

        # Raised explicitly (rather than via `assert`) so the check still
        # runs under `python -O`; the exception type is kept for
        # compatibility with the previous behaviour.
        if not torch.equal(torch.ones(mask.shape), mask):
            raise AssertionError(
                f'layer {layer} has a non-trivial mask; expected an unpruned model'
            )

        prune_size = int(weights.numel() * sparsity)

        # Pick the prune_size smallest-magnitude entries on a flat view, then
        # bring the mask back to the layer's shape.
        flat_magnitudes = torch.abs(weights).flatten()
        flat_mask = torch.ones(flat_magnitudes.shape)
        _, prune_indices = torch.topk(flat_magnitudes, prune_size, largest=False)
        flat_mask[prune_indices] = 0
        new_mask = flat_mask.reshape(weights.shape)

        set_params(params, layer, weights * new_mask, biases, mask=new_mask)
        layer += 1

    torch.save(params, out_dir)


# Script entry point: hand prune_magnitude to Python Fire, which builds the
# command-line interface from the function's parameters.
if __name__ == "__main__":
    fire.Fire(component=prune_magnitude)
