import mlp
import argparse
import torch

def default_output_path(input_path: str) -> str:
    """Derive the TorchScript output path from an input model path.

    A trailing ``.npz`` extension is replaced with ``.pt``. Any other path
    gets ``.pt`` appended, so the output can never coincide with (and
    overwrite) the input file. Only the final suffix is touched; ``.npz``
    appearing earlier in the path (e.g. in a directory name) is preserved.

    Args:
        input_path: Path of the source model file, as given on the CLI.

    Returns:
        The path the scripted model should be saved to.
    """
    suffix = '.npz'
    if input_path.endswith(suffix):
        return input_path[:-len(suffix)] + '.pt'
    # No .npz extension: append instead of returning input_path unchanged,
    # which would make scripted.save() clobber the input file.
    return input_path + '.pt'


def main(argv=None) -> None:
    """Convert a saved MLP into a TorchScript file.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``
            (argparse's default).
    """
    parser = argparse.ArgumentParser(
        description='Convert an MLP model file into a TorchScript .pt file.')
    parser.add_argument('input', type=str,
                        help='path of the model file (typically .npz)')
    parser.add_argument('--output', type=str, default=None,
                        help='destination path (default: input with .npz '
                             'replaced by .pt, or .pt appended)')
    args = parser.parse_args(argv)

    input_path = args.input
    output_path = args.output if args.output is not None \
        else default_output_path(input_path)

    # NOTE(review): presumably mlp.load reads the model from disk and
    # func_as_torch wraps it as something torch.jit.script accepts
    # (an nn.Module or function) — confirm against the mlp module.
    model = mlp.func_as_torch(mlp.load(input_path))
    scripted = torch.jit.script(model)
    scripted.save(output_path)


if __name__ == '__main__':
    main()
