import argparse

import torch

import hienet._const as _const
import hienet.util
from hienet.scripts.deploy import deploy, deploy_parallel

# Description shown in the argparse help output; the version is read from
# hienet._const.
description_get_model = (
    f'hienet version={_const.HIENET_VERSION}, hienet_get_model.'
    ' Deploy model for LAMMPS from the checkpoint'
)

# Help strings for the individual command-line arguments.
checkpoint_help = 'checkpoint path'  # positional `checkpoint`
output_name_help = 'filename prefix'  # -o / --output_prefix
get_parallel_help = 'deploy parallel model'  # -p / --get_parallel


def main(args=None):
    """Deploy a model from a checkpoint as a serial or parallel build.

    Command-line options are obtained through ``cmd_parse_get_model``. The
    model and its config are rebuilt with
    ``hienet.util.model_from_checkpoint`` and the model's state dict is
    handed to ``deploy`` (default) or ``deploy_parallel`` (when
    ``-p/--get_parallel`` is given).

    Args:
        args: Optional argument list, forwarded to ``cmd_parse_get_model``.
    """
    checkpoint, output_prefix, get_parallel = cmd_parse_get_model(args)

    # Fall back to a prefix that names the kind of deployment being produced.
    if output_prefix is None:
        output_prefix = (
            'deployed_parallel' if get_parallel else 'deployed_serial'
        )

    # The checkpoint is handed over by path, so no separate torch.load of the
    # same file is needed here.
    model, config = hienet.util.model_from_checkpoint(checkpoint)
    state_dict = model.state_dict()

    if get_parallel:
        deploy_parallel(state_dict, config, output_prefix)
    else:
        deploy(state_dict, config, output_prefix)


def cmd_parse_get_model(args=None):
    """Parse the command-line options of the model deployment script.

    Args:
        args: Sequence of argument strings to parse. When ``None``, argparse
            falls back to ``sys.argv[1:]``.

    Returns:
        A ``(checkpoint, output_prefix, get_parallel)`` tuple: the checkpoint
        path, the output filename prefix (``None`` when not supplied) and
        whether ``-p/--get_parallel`` was set.
    """
    ag = argparse.ArgumentParser(description=description_get_model)
    ag.add_argument('checkpoint', help=checkpoint_help, type=str)
    ag.add_argument(
        '-o', '--output_prefix', nargs='?', help=output_name_help, type=str
    )
    ag.add_argument(
        '-p', '--get_parallel', help=get_parallel_help, action='store_true'
    )
    # Hand ``args`` to argparse so an explicitly supplied argument list is
    # honored instead of always reading sys.argv.
    parsed = ag.parse_args(args)
    return parsed.checkpoint, parsed.output_prefix, parsed.get_parallel
