from argparse import ArgumentParser, REMAINDER
from typing import List
from omegaconf import DictConfig, OmegaConf

# Public names re-exported by ``from <this module> import *``.
__all__ = ["default_parser", "load_config", "override_config_by_cli"]

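"""Common usage:

if __name__ == "__main__":
    parser = default_parser()
    args = parser.parse_args()
    d_config = load_config(args.config)
    d_config = override_config_by_cli(d_config, args.script_args)
    run(d_config)  # your script's own entry point

Example invocation (overrides must come after --config, because every
argument following the first override is collected into script_args):

    python script.py --config run.yaml A.B=C D=E
"""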


def default_parser() -> ArgumentParser:
    """Create the argument parser shared by config-driven scripts.

    The parser accepts:
      * ``--config`` (required): path to the run configuration.
      * ``script_args``: every remaining command-line token, collected
        verbatim as a list (``argparse.REMAINDER``) so it can be passed to
        ``override_config_by_cli``.
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument(
        "--config",
        required=True,
        type=str,
        help="Run configuration",
    )
    arg_parser.add_argument(
        "script_args",
        help="Override config by CLI",
        nargs=REMAINDER,
    )
    return arg_parser


def load_config(yaml_path: str) -> DictConfig:
    """Load a YAML configuration file into an OmegaConf ``DictConfig``.

    Args:
        yaml_path: Path to the YAML configuration file.

    Returns:
        The parsed configuration.

    Raises:
        TypeError: If the file's top-level YAML value is not a mapping.
            ``OmegaConf.load`` returns a ``ListConfig`` for a top-level list,
            which does not satisfy this function's ``DictConfig`` contract.
    """
    cfg = OmegaConf.load(yaml_path)
    # OmegaConf.load may return DictConfig or ListConfig; enforce the annotated
    # return type here so a malformed file fails with a clear message instead
    # of a confusing error later (e.g. when merging CLI overrides onto it).
    if not isinstance(cfg, DictConfig):
        raise TypeError(
            f"Config file {yaml_path!r} must contain a mapping at the top level, "
            f"got {type(cfg).__name__}"
        )
    return cfg


def override_config_by_cli(base_cfg: DictConfig, script_args: List[str]) -> DictConfig:
    """Apply dotted ``key.path=value`` overrides from the command line.

    Each element of *script_args* is one override. For example, the shell
    arguments ``A.B=C D=E F.G.H=K`` arrive as ``["A.B=C", "D=E", "F.G.H=K"]``.
    They are parsed with ``OmegaConf.from_dotlist`` and merged on top of
    *base_cfg*.

    Args:
        base_cfg: Configuration to override. It is not modified in place;
            ``OmegaConf.merge`` builds and returns a new config.
        script_args: Dotlist overrides, typically ``args.script_args`` from
            ``default_parser``. ``None`` is treated as "no overrides". A single
            leading ``"--"`` separator is ignored.

    Returns:
        A new config with the CLI overrides merged on top of *base_cfg*.
    """
    overrides = [] if script_args is None else script_args
    # argparse keeps a "--" separator inside a REMAINDER list
    # (e.g. ``--config x.yaml -- a.b=1``). It is not a valid dotlist override,
    # so drop it rather than passing it to from_dotlist as a key.
    if overrides and overrides[0] == "--":
        overrides = overrides[1:]
    cli_cfg = OmegaConf.from_dotlist(overrides)
    cfg = OmegaConf.merge(base_cfg, cli_cfg)
    return cfg
