import torch
import fire
from lenet import get_layer_params_lenet, set_layer_params_lenet
from vgg import get_layer_params_vgg, set_layer_params_vgg


def scale_weights(model_dir, out_dir, scale: float, model_type, layers=None):
    """Multiply the weights and biases of selected layers by a constant and save the result.

    Loads a checkpoint with ``torch.load`` and rescales the requested layers
    through the model-specific get/set helpers. It then writes the modified
    object with ``torch.save``.

    Args:
        model_dir: Path of the checkpoint to load (passed straight to ``torch.load``).
        out_dir: Path the modified checkpoint is written to (passed to ``torch.save``).
        scale: Factor applied to both the weights and the biases of each layer.
        model_type: ``"lenet"`` or ``"vgg"``; selects the parameter helpers.
        layers: Layer identifiers handed to the helpers. Defaults to ``[1]``.
            A single ``int``/``str`` is treated as a one-element list, since
            Fire parses ``--layers=2`` as a bare ``int``.

    Raises:
        ValueError: If ``model_type`` is not recognised. It is a subclass of
            ``Exception``, so existing broad handlers still catch it.
    """
    if model_type == "lenet":
        get_params, set_params = get_layer_params_lenet, set_layer_params_lenet
    elif model_type == "vgg":
        get_params, set_params = get_layer_params_vgg, set_layer_params_vgg
    else:
        raise ValueError(f'Unknown model type {model_type}')

    # Avoid a shared mutable default; [1] preserves the original default value.
    if layers is None:
        layers = [1]
    elif isinstance(layers, (int, str)):
        # A scalar would otherwise raise TypeError (int) or be iterated
        # character by character (str).
        layers = [layers]

    params = torch.load(model_dir)

    for layer in layers:
        weights, biases, mask = get_params(params, layer)
        # NOTE(review): assumes get_params returns None for a bias-less layer — confirm.
        scaled_biases = biases * scale if biases is not None else None
        # The mask is forwarded unchanged. set_params's return value is ignored,
        # so it is expected to update `params` in place before the save below.
        set_params(params, layer, weights * scale, scaled_biases, mask)

    torch.save(params, out_dir)


if __name__ == "__main__":
    # Expose scale_weights as a command-line interface: Fire maps CLI
    # arguments/flags onto the function's parameters.
    fire.Fire(component=scale_weights)
