import argparse
import os
from core.trainer import Trainer
from core.utils.config_utils import merge_configs, make_dataclass_from_dict
from mpi4py import MPI
import json

# MPI context shared by the whole script: the world communicator, this
# process's rank within it, and the total number of processes (used below
# to build the run name).
comm = MPI.COMM_WORLD
rank, num_workers = comm.Get_rank(), comm.Get_size()


def _get_parser_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_config", type=str, default="base.json", help="config")
    parser.add_argument("--extension_config", type=str, default="", help="config")
    parser.add_argument("--seed", type=int, default=0, help="seed number")
    args = parser.parse_args()
    return args


if __name__ == "__main__":

    # Force single-threaded OpenMP/MKL in each MPI process and flag MPI mode.
    # NOTE(review): core.trainer is imported at the top of this file, before
    # these variables are set; if it (transitively) initialises numpy/torch
    # thread pools at import time, these settings may not take effect —
    # confirm, and consider setting them before the imports.
    os.environ['OMP_NUM_THREADS'] = '1'
    os.environ['MKL_NUM_THREADS'] = '1'
    os.environ['IN_MPI'] = '1'

    args = _get_parser_args()

    # Config files are resolved relative to the current working directory.
    config_dir = "experiments/configs/"

    # Read and merge configs; `with` ensures the file handles are closed.
    with open(os.path.join(config_dir, args.base_config), "r", encoding="utf-8") as f:
        base_config = json.load(f)
    if args.extension_config:
        with open(os.path.join(config_dir, args.extension_config), "r", encoding="utf-8") as f:
            extension_config = json.load(f)
    else:
        # No extension given: previously this opened the config directory
        # itself and crashed. An empty dict is assumed to mean "no overrides"
        # for merge_configs — TODO confirm.
        extension_config = {}
    dict_config = merge_configs(base_config, extension_config)
    # Config dict to dataclass
    config = make_dataclass_from_dict(dict_config)

    config.xp_params.seed = args.seed
    config.name = f"qd_{num_workers}_workers"

    trainer = Trainer(config)
    trainer.train()

