import jax
import jax.numpy as jnp
from typing import Any, Callable
from modules.emulator.mlp_timestepper import make_mlp_timestepper

def make_mlp_summary(cfg: dict[str, Any]) -> dict[str, bool | Callable[..., list[jax.Array]] | Callable[..., jax.Array]]:
    """Wrap an MLP timestepper as a parameterised summary component.

    The MLP is built by ``make_mlp_timestepper(cfg)``, which yields an
    ``(init, apply)`` pair. The returned dict exposes that pair behind
    fixed call signatures.

    Args:
        cfg: Configuration mapping, forwarded unchanged to
            ``make_mlp_timestepper``.

    Returns:
        A dict with three entries:
        - ``"has_params"``: always ``True``; this component carries
          parameters.
        - ``"init"``: ``init(rng, dtype)``, which delegates to the
          timestepper's initialiser.
        - ``"apply"``: ``apply(ψ, u)``, which delegates to the
          timestepper's apply function.
    """
    mlp_init_fn, mlp_apply_fn = make_mlp_timestepper(cfg)

    def init(rng: jax.Array, dtype: jnp.dtype) -> list[jax.Array]:
        """Create parameters by forwarding ``rng`` and ``dtype`` to the MLP initialiser."""
        params = mlp_init_fn(rng, dtype)
        return params

    def apply(ψ: Any, u: jax.Array) -> jax.Array:
        """Evaluate the MLP on ``u`` with parameters ``ψ``.

        NOTE(review): ``ψ`` is presumably the object produced by ``init``;
        confirm against ``make_mlp_timestepper``.
        """
        result = mlp_apply_fn(ψ, u)
        return result

    summary = dict(has_params=True, init=init, apply=apply)
    return summary
