import hydra
from omegaconf import DictConfig, OmegaConf
from transformers import AutoConfig


def get_num_layers(model_id: str):
    """Return the ``num_hidden_layers`` attribute of a model's configuration.

    Args:
        model_id: Identifier forwarded unchanged to
            ``AutoConfig.from_pretrained``.

    Returns:
        The ``num_hidden_layers`` value of the loaded configuration object.
    """
    return AutoConfig.from_pretrained(model_id).num_hidden_layers

# Register at import time so configs can use `${get_num_layers:<model_id>}`.
OmegaConf.register_new_resolver("get_num_layers", get_num_layers)

def get_log_range(start, end, num):
    """Return ``num`` values evenly spaced on a log scale from ``start`` to ``end``.

    Both endpoints are included (``np.logspace`` default).

    Args:
        start: First value of the range. ``np.log10`` is applied to it, so it
            is expected to be positive.
        end: Last value of the range; same expectation as ``start``.
        num: Number of samples to generate.

    Returns:
        A list of built-in ``float`` values.
    """
    import numpy as np

    samples = np.logspace(np.log10(start), np.log10(end), num)
    # tolist() yields built-in floats; list(ndarray) would leave np.float64
    # scalars, which exact-type checks (e.g. config containers, YAML dumpers)
    # may reject.
    return samples.tolist()

def resolve_freeze_layers(coeffs_tuple_list, model_id):
    """Scale pairs of layer coefficients by a model's layer count.

    Args:
        coeffs_tuple_list: Iterable of indexable items; positions 0 and 1 of
            each item are converted with ``float`` and used as coefficients.
        model_id: Identifier passed to ``get_num_layers`` to obtain the
            layer count.

    Returns:
        A list of ``(int, int)`` tuples, one per input item and in input
        order. Each coefficient is multiplied by the layer count and then
        truncated with ``int``.
    """
    layer_count = get_num_layers(model_id)
    return [
        (int(float(pair[0]) * layer_count), int(float(pair[1]) * layer_count))
        for pair in coeffs_tuple_list
    ]

# Register at import time so configs can use
# `${resolve_freeze_layers:<coeffs_tuple_list>,<model_id>}`.
OmegaConf.register_new_resolver("resolve_freeze_layers", resolve_freeze_layers)

@hydra.main(config_path=".", config_name="pipeline_default", version_base=None)
def test(cfg: DictConfig):
    """Print the composed config and two values read from it.

    Hydra supplies ``cfg`` from the ``pipeline_default`` config found via
    ``config_path="."``. The function prints the config as YAML, then reads
    ``cfg.num_layers`` and ``cfg.unlearn.freeze_layers`` and prints both.
    NOTE(review): these keys are presumably backed by the resolvers
    registered in this module; confirm against the YAML config.
    """
    rendered = OmegaConf.to_yaml(cfg)
    print(rendered)
    # The local names `nm` and `fz` are part of the printed text because of
    # the f-string `=` specifier, so they must not be renamed.
    nm, fz = cfg.num_layers, cfg.unlearn.freeze_layers
    print(f"{nm=}\n{fz=}")

# Invoke the Hydra-decorated entry point only when run as a script.
if __name__ == "__main__":
    test()