import argparse
import json
from pprint import pprint

from numerical_mz.model.property_prediction import model_factory

def main(args):
    """Load the JSON config at ``args.config`` and train the model it describes.

    The parsed config is pretty-printed first. Each optional command-line
    setting whose value is not ``None`` is then forwarded to
    ``model_factory`` as a keyword argument, and ``train()`` is called on
    the object it returns.

    Args:
        args: Parsed command-line namespace. Must expose ``config`` plus
            every attribute listed in ``option_names`` below.
    """
    with open(args.config, 'rb') as config_file:
        config = json.load(config_file)

    pprint(config)

    # Settings forwarded to the factory only when given on the command line.
    option_names = (
        'num_workers',
        'prefetch_factor',
        'batch_size',
        'num_batch_per_update',
        'devices',
        'accelerator',
        'strategy',
        'gpus',
        'auto_select_gpus',
    )
    supplied = ((name, getattr(args, name)) for name in option_names)
    replace_params = {
        name: value for name, value in supplied if value is not None
    }

    trainer = model_factory(config, **replace_params)
    trainer.train()

def str_to_bool(value):
    """Convert a command-line string such as ``true``/``False``/``1`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value -- including ``"False"`` -- becomes ``True``. This
    converter interprets the text instead.

    Args:
        value: The raw option value (a string, or an actual ``bool``).

    Returns:
        The boolean the text stands for.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse turns this into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got {!r}'.format(value)
    )


def build_parser():
    """Create the argument parser for the training script.

    Returns:
        argparse.ArgumentParser: Parser whose optional settings all default
        to ``None``, so that "not given" can be told apart from any real
        value.
    """
    parser = argparse.ArgumentParser(description='Train a model')
    # Required: with the old default of '' the script always failed later
    # with FileNotFoundError when --config was omitted.
    parser.add_argument('--config', required=True, help='path to config file')
    parser.add_argument('--num-workers', type=int, default=None)
    parser.add_argument('--prefetch-factor', type=int, default=None)
    parser.add_argument('--batch-size', type=int, default=None)
    parser.add_argument('--num-batch-per-update', type=int, default=None)
    parser.add_argument('--devices', default=None)
    parser.add_argument('--accelerator', default=None)
    parser.add_argument('--strategy', default=None)
    parser.add_argument('--gpus', default=None)
    # type=bool would map "False" to True; parse the text explicitly.
    parser.add_argument('--auto-select-gpus', type=str_to_bool, default=None,
                        help='boolean value, e.g. true or false')
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()

    main(args)
