import numpy as np
from absl import flags, logging
from jax import numpy as jnp

# absl's global flag registry; get_restoration_values() reads
# FLAGS.attn_scale_init from it (the flag is defined outside this view).
FLAGS = flags.FLAGS
# NOTE: this module is a work in progress.


def get_restoration_values():
    """Return the fill values used for keys missing from a loaded tree.

    Maps key name -> fill value handed to `restore_alike`. Currently the only
    entry is "attention_scale", taken from `FLAGS.attn_scale_init`. The flag
    is re-read on every call, and a fresh dict is returned each time.
    """
    defaults = dict(attention_scale=FLAGS.attn_scale_init)
    return defaults


def is_dict_like(elem):
    """Duck-typed check for mapping-like containers.

    Returns True iff `elem` exposes both a `keys` and an `items` attribute.
    No isinstance check is performed, so any object with both attributes
    qualifies.
    """
    required_attrs = ("keys", "items")
    return all(hasattr(elem, attr) for attr in required_attrs)


def restore_alike(array_spec, value_preference):
    """Build a constant array matching `array_spec`'s shape and dtype.

    Args:
      array_spec: Any object exposing `.shape` and `.dtype` (e.g. an array).
      value_preference: Value every element of the result is set to.

    Returns:
      A `jnp` array with `array_spec.shape` and `array_spec.dtype`, filled
      with `value_preference`.
    """
    target_shape = array_spec.shape
    target_dtype = array_spec.dtype
    return jnp.full(target_shape, value_preference, dtype=target_dtype)


def fix_loaded(pattern, loaded):
    """Reconcile a loaded tree with the structure of `pattern`.

    Walks `pattern` recursively. Wherever `pattern` has a dict-like node,
    `loaded` must have one too, and wherever `pattern` has a leaf, `loaded`
    must have a leaf. Leaf values are taken from `loaded`.

    Keys present in `pattern` but missing from `loaded` are filled with
    `restore_alike(pattern_leaf, default)`, where `default` comes from
    `get_restoration_values()`. This is only allowed for leaf keys that
    have such a default. Keys present in `loaded` but absent from `pattern`
    are dropped, and a warning is logged.

    Args:
      pattern: Reference tree. Its dict-like nodes define the output
        structure. Its leaves supply `shape`/`dtype` when a value has to be
        restored.
      loaded: Tree to reconcile against `pattern`.

    Returns:
      A new nested plain `dict` mirroring the dict-like structure of
      `pattern`, or `loaded` itself when `pattern` is not dict-like.

    Raises:
      ValueError: if the dict-like structure of `pattern` and `loaded`
        disagrees, or if a key missing from `loaded` cannot be restored.
    """
    return _fix_loaded_at(pattern, loaded, ())


def _format_path(path):
    # Render a tuple of keys as "a/b/c" (or "<root>") for log/error messages.
    return "/".join(str(part) for part in path) or "<root>"


def _fix_loaded_at(pattern, loaded, path):
    # Recursive worker for fix_loaded; `path` is the key path of this node.
    if not is_dict_like(pattern):
        if is_dict_like(loaded):
            raise ValueError(
                f"Structure mismatch at {_format_path(path)}: pattern has a "
                "leaf but the loaded value is dict-like"
            )
        return loaded
    if not is_dict_like(loaded):
        raise ValueError(
            f"Structure mismatch at {_format_path(path)}: pattern is "
            "dict-like but the loaded value is not"
        )

    # Keys that exist only in `loaded` are not copied into the result;
    # surface that instead of discarding them silently.
    dropped = [key for key in loaded.keys() if key not in pattern]
    if dropped:
        logging.warning(
            "Dropping loaded keys absent from pattern at %s: %s",
            _format_path(path),
            dropped,
        )

    res = {}
    # Fetched lazily (and only once per node), so FLAGS are read only when
    # something actually has to be restored.
    restoration_values = None
    for key, value in pattern.items():
        child_path = path + (key,)
        if key in loaded:
            res[key] = _fix_loaded_at(value, loaded[key], child_path)
            continue
        # Only leaves can be synthesized: restore_alike needs a shape/dtype.
        if is_dict_like(value):
            raise ValueError(
                f"Subtree {_format_path(child_path)} is missing from the "
                "loaded tree and cannot be restored"
            )
        if restoration_values is None:
            restoration_values = get_restoration_values()
        if key not in restoration_values:
            raise ValueError(
                f"Key {_format_path(child_path)} is missing from the loaded "
                "tree and has no restoration value (known: "
                f"{list(restoration_values)})"
            )
        fill_value = restoration_values[key]
        res[key] = restore_alike(value, fill_value)
        # Log the fill value and spec rather than the whole restored array.
        logging.info(
            "Performing restoration of %s to %s (shape=%s, dtype=%s)",
            _format_path(child_path),
            fill_value,
            value.shape,
            value.dtype,
        )
    return res
