import jax
from typing import Callable
from modules.architectures.mlp import mlp_module

def make_mlp_timestepper(config: dict) -> tuple[Callable[..., list[jax.Array]], Callable[..., jax.Array]]:
    """Build the MLP timestepper by delegating to ``mlp_module``.

    This is a thin factory: ``config`` is forwarded untouched and whatever
    ``mlp_module`` produces is handed back to the caller as-is.

    Args:
        config: Configuration dictionary passed straight through to
            ``mlp_module``. The expected keys are defined by that function,
            not here.

    Returns:
        Per the annotation, a pair of callables: the first yields a list of
        ``jax.Array`` and the second yields a single ``jax.Array``.
        NOTE(review): presumably a parameter-initialiser and a forward/apply
        function respectively -- confirm against ``mlp_module``.
    """
    timestepper_fns = mlp_module(config)
    return timestepper_fns

