import torch
import fire
from lenet import get_layer_params_lenet, set_layer_params_lenet
from vgg import get_layer_params_vgg, set_layer_params_vgg


def reset_mask(model_dir, out_dir, model_type):
    """Multiply every layer's weights by its mask and save the result.

    The checkpoint at ``model_dir`` is loaded onto the CPU. Starting at
    layer 1, each layer's ``(weights, biases, mask)`` triple is fetched with
    the accessor for ``model_type`` and the layer is written back with
    ``weights * mask``, the unchanged biases and ``mask=None``. Iteration
    stops at the first layer index for which the accessor raises
    ``KeyError``; the updated object is then saved to ``out_dir``.

    Args:
        model_dir: Passed straight to ``torch.load``, so despite the name it
            identifies the checkpoint file to read.
        out_dir: Passed straight to ``torch.save``; the file to write.
        model_type: ``"lenet"`` or ``"vgg"``; selects the accessor pair.

    Raises:
        ValueError: If ``model_type`` is not a supported value. Raised
            before any file is read.
    """
    if model_type == "lenet":
        get_params, set_params = get_layer_params_lenet, set_layer_params_lenet
    elif model_type == "vgg":
        get_params, set_params = get_layer_params_vgg, set_layer_params_vgg
    else:
        # ValueError subclasses Exception, so callers that caught the former
        # bare Exception keep working.
        raise ValueError(f'Unknown model type {model_type}')

    # NOTE(security): torch.load can unpickle arbitrary objects; only point
    # this tool at checkpoints from a trusted source.
    params = torch.load(model_dir, map_location=torch.device('cpu'))

    layer = 1
    while True:
        # Guard only the lookup: a KeyError here is the end-of-layers signal.
        # A KeyError raised while writing the layer back is a real failure
        # and must not be mistaken for normal termination.
        try:
            weights, biases, mask = get_params(params, layer)
        except KeyError:
            break
        # NOTE(review): the effect of mask=None is defined by set_params,
        # which is not visible in this file -- confirm it clears/resets the
        # stored mask as the function name suggests.
        set_params(params, layer, weights * mask, biases, mask=None)
        layer += 1

    torch.save(params, out_dir)


# Script entry point: hand reset_mask to python-fire so its parameters
# (model_dir, out_dir, model_type) are taken from the command line.
if __name__ == "__main__":
    fire.Fire(reset_mask)
