import optax

def _build_adam(lr, **optimizer_kwargs):
    """Construct an ``optax.adam`` optimizer.

    Args:
        lr: Learning rate, forwarded to ``optax.adam`` as ``learning_rate``.
        **optimizer_kwargs: Any further keyword arguments, forwarded to
            ``optax.adam`` unchanged. Passing ``learning_rate`` here as well
            raises ``TypeError`` (duplicate keyword argument).

    Returns:
        Whatever ``optax.adam`` returns for these arguments.
    """
    return optax.adam(learning_rate=lr, **optimizer_kwargs)


# Name -> factory lookup. Each factory takes the learning rate as its first
# argument (``lr``) plus optional optimizer-specific keyword arguments.
OPTIMIZER_REGISTRY = {"adam": _build_adam}