from omegaconf import DictConfig, OmegaConf

from smlm.utils.git import get_head_commit_hash


def initialize_config(cfg: DictConfig):
    """Make *cfg* mutable, register custom resolvers, and record the git hash.

    All changes are applied to *cfg* in place; the same object is returned.

    Steps:

    * Disable struct mode (``OmegaConf.set_struct(cfg, False)``) so that the
      helpers in this module can add keys that are not in the original
      schema, such as ``git_commit_hash``, ``eff_batch_size`` and
      ``total_steps``.
    * Register the ``pow2`` and ``eval`` interpolation resolvers. Each one
      is registered only if no resolver of that name exists yet, so calling
      this function more than once is safe. OmegaConf rejects duplicate
      resolver names unless ``replace=True`` is passed.
    * Store the current HEAD commit hash via ``add_git_commit_hash``.

    Args:
        cfg: The config to initialise.

    Returns:
        The same ``cfg`` object, mutated.
    """
    OmegaConf.set_struct(cfg, False)
    if not OmegaConf.has_resolver("pow2"):
        # ``${pow2:n}`` resolves to 2 ** int(n).
        OmegaConf.register_new_resolver("pow2", lambda n: 2 ** int(n))
    if not OmegaConf.has_resolver("eval"):
        # SECURITY: this exposes Python's builtin ``eval`` to interpolations,
        # so ``${eval:...}`` runs arbitrary code taken from config values.
        # Only load configs (including CLI overrides) from trusted sources.
        OmegaConf.register_new_resolver("eval", eval)
    # Delegate to the dedicated helper instead of duplicating its body, so
    # the two code paths cannot drift apart.
    return add_git_commit_hash(cfg)


def add_git_commit_hash(cfg: DictConfig):
    """Stamp *cfg* with the repository's current HEAD commit hash.

    The value returned by ``get_head_commit_hash()`` is stored under
    ``cfg.git_commit_hash``. *cfg* is modified in place and also returned,
    so the call can be chained.
    """
    head_hash = get_head_commit_hash()
    cfg.git_commit_hash = head_hash
    return cfg


def add_eff_batch_size(cfg: DictConfig, world_size: int):
    """Store the effective (global) batch size on ``cfg.eff_batch_size``.

    The effective batch size is
    ``cfg.batch_size * world_size * cfg.n_accum_steps``: the per-process
    batch size, scaled by the number of processes and by the number of
    gradient-accumulation steps.

    Args:
        cfg: Config that provides ``batch_size`` and ``n_accum_steps``.
        world_size: Multiplier for the per-process batch size, presumably
            the number of distributed workers (confirm with callers).

    Returns:
        The same ``cfg`` object, mutated in place.
    """
    samples_per_step = cfg.batch_size * world_size
    cfg.eff_batch_size = samples_per_step * cfg.n_accum_steps
    return cfg


def add_total_steps(cfg: DictConfig, step_per_epoch: int):
    """Store the total training-step count on ``cfg.total_steps``.

    If ``cfg.n_epochs`` is positive, the value is
    ``cfg.n_epochs * step_per_epoch // cfg.eff_batch_size``. Otherwise it is
    set to the sentinel ``-1`` and ``cfg.eff_batch_size`` is not read.
    Call ``add_eff_batch_size`` first so that ``eff_batch_size`` exists.

    NOTE(review): ``step_per_epoch`` is divided by the effective batch size,
    so it presumably counts *samples* per epoch rather than optimizer
    steps; confirm with callers. The floor division is applied to the
    product over all epochs, not to each epoch separately.

    Args:
        cfg: Config that provides ``n_epochs`` (and ``eff_batch_size`` when
            ``n_epochs > 0``).
        step_per_epoch: Per-epoch quantity multiplied by ``n_epochs``.

    Returns:
        The same ``cfg`` object, mutated in place.
    """
    cfg.total_steps = (
        cfg.n_epochs * step_per_epoch // cfg.eff_batch_size
        if cfg.n_epochs > 0
        else -1
    )
    return cfg
