from easy_tpp.config_factory import Config
from easy_tpp.runner import Runner
import argparse

def print_result(res):
    """Print each collected model result under a ``Model <name>:`` header.

    Args:
        res: mapping from model name to the result object produced for it;
            entries are printed in the mapping's iteration order.
    """
    for model_name, model_res in res.items():
        header = f'Model {model_name}:'
        print(header)
        print(model_res)

def main(args, model_ids=('RMTPP', 'NHP', 'SAHP', 'THP', 'DLHP',
                          'IntensityFree', 'FullyNN', 'AttNHP', 'ODETPP')):
    """Train and evaluate several TPP models from a single experiment config.

    For every id in ``model_ids`` the experiment ``'<model_id>_train'`` is
    loaded from the YAML file at ``args.config_dir``, a runner is built from
    it and ``run()`` is called. The dataset used is whatever that config
    specifies.

    After each model finishes, all results gathered so far are printed, so
    earlier results have already been emitted if a later model raises.

    Args:
        args: parsed CLI namespace; only ``args.config_dir`` (path to the
            experiment YAML file) is read.
        model_ids: model names to run, in order. The default reproduces the
            original run list. Notes carried over from the original script:
            IntensityFree has a TODO ("have mark_ll=time_ll=0 for now"),
            AttNHP converges slowly, and ODETPP is kept last because it
            produced NaNs.

    Returns:
        dict mapping each model id to the value returned by its runner's
        ``run()``.
    """
    results = {}
    file_path = args.config_dir

    for model_id in model_ids:
        config = Config.build_from_yaml_file(file_path, experiment_id=f'{model_id}_train')
        model_runner = Runner.build_from_config(config)
        results[model_id] = model_runner.run()
        # Print the cumulative results after every run (as the original did)
        # so progress is visible before slower/unstable models start.
        print_result(results)

    return results


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # The flag name is kept for CLI compatibility, but its value is a single
    # YAML file path (it is passed directly to Config.build_from_yaml_file).
    parser.add_argument('--config_dir', type=str, default='configs/exp_config_test.yaml',
                        help='Path to the experiment configuration YAML file used to train and evaluate the models.')
    args = parser.parse_args()
    results = main(args)
