import argparse
import os
import sys

import hienet._keys as KEY
from hienet._const import HIENET_VERSION
from hienet.parse_input import read_config_yaml
from hienet.scripts.train_lightning import train
from hienet.hienet_logger import Logger

# Text shown by ``--help`` for the training command.
description = (
    'hienet version='
    f'{HIENET_VERSION}'
    ', train model based on the input.yaml'
)

# Help strings for the individual command-line options.
input_yaml_help = 'input.yaml for training'
working_dir_help = 'path to write output. Default is cwd.'
screen_help = 'print log to stdout'
distributed_help = 'set this flag if it is distributed training'
number_of_nodes_help = 'number of nodes for distributed training'

# Base configuration; main() merges the parsed model/train/data sections
# into this dict before calling train().
# TODO: handle the model type; it is not printed in the log yet.
global_config = dict(version=HIENET_VERSION)
global_config[KEY.MODEL_TYPE] = 'E3_equivariant_model'


def main(args=None):
    """Entry point of the hienet training command.

    Parses the command line, sets up the logger, reads the input YAML,
    merges the model/train/data sections into the module-level
    ``global_config`` and starts training via ``train``.

    Args:
        args: Optional list of command-line argument strings forwarded to
            ``cmd_parse_main``. ``None`` means ``sys.argv`` is parsed.

    Raises:
        SystemExit: with code 1 if the input YAML cannot be read.
        NotImplementedError: if the config requests ``double`` precision.
    """
    (
        input_yaml,
        working_dir,
        screen,
        distributed,
        experiment_name,
        n_nodes,
        jobid,
    ) = cmd_parse_main(args)

    if working_dir is None:
        working_dir = os.getcwd()

    # Without an explicit job id, use the process id so that separate runs
    # get distinct log-file and experiment names.
    jobid = jobid if jobid else os.getpid()
    if experiment_name is None:
        # ``-e`` is optional. Without this fallback, the concatenation
        # below raised TypeError. Use the input file's stem instead.
        experiment_name = os.path.splitext(os.path.basename(input_yaml))[0]
    experiment_name = f'{experiment_name}.{jobid}'

    # os.environ values are always strings, while the fallback is the int 0.
    # Normalize to int so the logger receives the same type in both cases.
    rank = int(os.environ.get('RANK', 0))

    Logger(
        filename=os.path.join(
            os.path.abspath(working_dir), f'log.hienet.{jobid}'
        ),
        screen=screen,
        rank=rank,
    )
    Logger().greeting()

    if distributed:
        Logger().writeline('Distributed training enabled')

    try:
        model_config, train_config, data_config = read_config_yaml(input_yaml)
    except Exception as e:
        # Top-level boundary: report any config problem and exit cleanly.
        Logger().error(e)
        sys.exit(1)

    train_config[KEY.IS_DDP] = distributed
    train_config[KEY.N_NODES] = n_nodes

    Logger().print_config(model_config, data_config, train_config)
    # Flatten all sections into one dict; downstream code does not need to
    # know which section a key came from.
    global_config.update(model_config)
    global_config.update(train_config)
    global_config.update(data_config)

    if global_config[KEY.DTYPE] == 'double':
        # Supporting this would need torch.set_default_dtype(torch.double).
        raise NotImplementedError('double precision is not implemented yet')

    train(global_config, working_dir, experiment_name)


def cmd_parse_main(args=None):
    """Parse the hienet training command line.

    Args:
        args: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, as ``argparse`` does.

    Returns:
        Tuple ``(input_yaml, working_dir, screen, distributed, experiment,
        n_nodes, jobid)``. ``working_dir`` is ``None`` when ``-w`` is
        omitted, and the current directory when ``-w`` is given without a
        value. ``experiment`` is ``None`` when ``-e`` is omitted.
        ``jobid`` is ``''`` when ``-j`` is omitted.
    """
    ag = argparse.ArgumentParser(description=description)
    ag.add_argument('input_yaml', help=input_yaml_help, type=str)
    ag.add_argument(
        '-w',
        '--working_dir',
        nargs='?',
        const=os.getcwd(),
        help=working_dir_help,
        type=str,
    )
    ag.add_argument('-s', '--screen', help=screen_help, action='store_true')
    ag.add_argument(
        '-d', '--distributed', help=distributed_help, action='store_true'
    )
    ag.add_argument('-e', '--experiment', type=str)
    ag.add_argument(
        '-n', '--n_nodes', help=number_of_nodes_help, type=int, default=1
    )
    ag.add_argument('-j', '--jobid', type=str, default='')

    # Forward ``args`` so that callers such as ``main(args=[...])`` actually
    # get their arguments parsed instead of the process's sys.argv.
    parsed = ag.parse_args(args)
    return (
        parsed.input_yaml,
        parsed.working_dir,
        parsed.screen,
        parsed.distributed,
        parsed.experiment,
        parsed.n_nodes,
        parsed.jobid,
    )


if __name__ == '__main__':
    # Run the training command when this module is executed as a script.
    main()
