import numpy as np
from omegaconf import OmegaConf
import torch

# Custom interpolation resolvers made available to every config loaded via
# OmegaConf in this process.
#
# SECURITY NOTE(review): the "eval" resolver executes arbitrary Python taken
# from config values -- including values passed on the command line. Only
# load configs / CLI overrides from trusted sources.
#
# register_new_resolver raises ValueError if a name is already registered, so
# guard each registration: this keeps the module safe to re-execute (e.g.
# importlib.reload or re-running a notebook cell) and never overwrites a
# resolver of the same name registered elsewhere.
if not OmegaConf.has_resolver("eval"):
    OmegaConf.register_new_resolver("eval", eval)
if not OmegaConf.has_resolver("np.pi"):
    OmegaConf.register_new_resolver("np.pi", lambda: np.pi)

def load_train_config(args=None, config_dir="configs"):
    """Build the training config by layering YAML files and CLI overrides.

    The command line must provide ``config=<task>/<experiment>``: a path
    relative to *config_dir*, without the ``.yml`` suffix. Layers are merged
    in increasing priority (later layers override earlier ones):

    1. ``<config_dir>/base.yml`` -- global defaults.
    2. ``<config_dir>/<task>/base.yml`` -- task defaults, where ``<task>`` is
       the first ``/``-separated component of ``config``.
    3. ``<config_dir>/<config>.yml`` -- the experiment file, with ``device``
       set to ``"cuda"`` when ``torch.cuda.is_available()``, else ``"cpu"``.
    4. The CLI dotlist itself (including the ``config`` key), so an explicit
       ``device=...`` on the command line still wins over auto-detection.

    Args:
        args: Optional dotlist such as ``["config=task/expt", "lr=0.1"]``.
            ``None`` (the default) makes ``OmegaConf.from_cli`` read
            ``sys.argv[1:]``, matching the original behavior.
        config_dir: Root directory of the config tree. Defaults to the
            original relative path ``"configs"``.

    Returns:
        The merged OmegaConf config.

    Raises:
        ValueError: If ``config`` is not given, or if it names a ``base``
            file (base files are merged automatically).
    """
    cli_config = OmegaConf.from_cli(args)

    config_name = cli_config.get("config")
    if config_name is None or str(config_name) == "":
        raise ValueError(
            "missing required CLI argument: config=<task>/<experiment>"
        )
    config_name = str(config_name)

    # Base files are merged automatically below, so naming one explicitly is
    # an error. Compare the last path component exactly: a substring test
    # would also reject legitimate names such as "baseline" or "database".
    if config_name.split("/")[-1] == "base":
        raise ValueError("must not specify a base config file")

    base_config = OmegaConf.load(f"{config_dir}/base.yml")

    task = config_name.split("/")[0]
    task_config = OmegaConf.load(f"{config_dir}/{task}/base.yml")

    expt_config = OmegaConf.load(f"{config_dir}/{config_name}.yml")
    # Set on the experiment layer (not the CLI layer) so a CLI ``device=...``
    # override still takes precedence in the merge below.
    expt_config.device = "cuda" if torch.cuda.is_available() else "cpu"

    return OmegaConf.merge(base_config, task_config, expt_config, cli_config)
