import flax.nnx as nnx
import optax


def create_opt(
    gfn: nnx.Module,
    epochs: int,
    logz_lr: float = 1e-1,
    lr: float = 1e-2,
    use_scheduler: bool = True,
    should_clip: bool = False,
):
    """Build an ``nnx.Optimizer`` over the ``nnx.Param`` state of ``gfn``.

    Parameters are split into two labelled groups, each with its own Optax
    transformation:

    * ``"logz"`` -- parameters for which ``"logz" in path`` holds; updated
      with ``optax.adam`` at the constant rate ``logz_lr``.
    * ``"default"`` -- everything else; updated with ``optax.contrib.muon``,
      optionally preceded by global-norm clipping.

    Args:
        gfn: Module whose ``nnx.Param`` variables are optimized.
        epochs: Passed to the schedule as ``decay_steps``; ``epochs // 10``
            is used as ``warmup_steps``. Only read when ``use_scheduler`` is
            true. NOTE(review): presumably this counts optimizer updates
            rather than dataset passes -- confirm against the training loop.
        logz_lr: Learning rate for the ``"logz"`` group.
        lr: Peak learning rate of the schedule, or the constant learning rate
            of the ``"default"`` group when ``use_scheduler`` is false.
        use_scheduler: If true, the ``"default"`` group uses a warm-up +
            cosine-decay schedule starting at 0, peaking at ``lr`` and ending
            at ``lr * 1e-2``.
        should_clip: If true, apply ``optax.clip_by_global_norm(1.0)`` to the
            ``"default"`` group before the Muon update; otherwise an identity
            transform takes its place.

    Returns:
        An ``nnx.Optimizer`` wrapping ``gfn`` with the combined transform,
        differentiating with respect to ``nnx.Param``.
    """

    def _group_for(path, _):
        # `path` is presumably a tuple of keys, so this is an element-wise
        # membership test rather than a substring match -- TODO confirm.
        return "logz" if "logz" in path else "default"

    # State with the same layout as the parameters, holding group labels.
    labels = nnx.map_state(_group_for, nnx.state(gfn, nnx.Param))

    # Start from the constant rate and swap in the schedule when requested.
    main_lr = lr
    if use_scheduler:
        main_lr = optax.schedules.warmup_cosine_decay_schedule(
            init_value=0,
            peak_value=lr,
            warmup_steps=epochs // 10,
            decay_steps=epochs,
            end_value=lr * 1e-2,
        )

    if should_clip:
        pre_step = optax.clip_by_global_norm(1.0)
    else:
        pre_step = optax.identity()

    # The single-element chain is kept on purpose so the optimizer-state
    # structure of the "logz" group stays exactly as before.
    logz_tx = optax.chain(
        optax.adam(learning_rate=logz_lr),
    )
    default_tx = optax.chain(
        pre_step,
        optax.contrib.muon(learning_rate=main_lr),
    )

    combined = optax.multi_transform(
        {"logz": logz_tx, "default": default_tx},
        labels,
    )
    return nnx.Optimizer(gfn, tx=combined, wrt=nnx.Param)
