import jax
import jax.numpy as jnp
from typing import Callable

def mlp_module(config: dict) -> tuple[
    Callable[..., list[tuple[jax.Array, jax.Array]]],
    Callable[..., jax.Array],
]:
    """Build the ``init`` and ``apply`` functions of a fully connected network.

    Args:
        config: Mapping with the keys
            ``'input_dim'`` (int): size of the last axis of the input,
            ``'output_dim'`` (int): size of the last axis of the output,
            ``'hidden_layers'`` (sequence of int): widths of the hidden layers
            (may be empty; any sequence type is accepted),
            ``'activation'`` (str): one of ``'relu'``, ``'tanh'``, ``'gelu'``,
            ``'sigmoid'`` or ``'identity'``.

    Returns:
        A pair ``(init, apply)``: ``init(rng, dtype)`` creates the parameters
        and ``apply(θ, u)`` evaluates the network with them.

    Raises:
        KeyError: If a required key is missing from ``config`` or the
            activation name is not one of the supported names.
    """
    input_dim = config['input_dim']
    output_dim = config['output_dim']
    hidden_layers = config['hidden_layers']
    activation_name = config['activation']
    # Unpacking (rather than list concatenation) accepts tuples and other
    # sequences of hidden widths as well as lists.
    layer_sizes = [input_dim, *hidden_layers, output_dim]

    activations = {
        'relu': jax.nn.relu,
        'tanh': jax.nn.tanh,
        'gelu': jax.nn.gelu,
        'sigmoid': jax.nn.sigmoid,
        'identity': lambda x: x,
    }
    try:
        activation = activations[activation_name]
    except KeyError:
        raise KeyError(
            f"unknown activation {activation_name!r}; "
            f"expected one of {sorted(activations)}"
        ) from None

    def init(rng: jax.Array, dtype: jnp.dtype) -> list[tuple[jax.Array, jax.Array]]:
        """Create randomly initialised parameters for every layer.

        Hidden-layer weights are normal draws scaled by ``sqrt(2 / fan_in)``
        and hidden-layer biases are normal draws scaled by ``0.01``. The
        output layer uses the same scheme shrunk by a further ``1e-4``
        (weights) and a bias scale of ``1e-5``, so the network output starts
        close to, but not exactly at, zero.

        Args:
            rng: JAX PRNG key.
            dtype: Floating-point dtype of the created arrays.

        Returns:
            A list with one ``(W, b)`` tuple per layer, where ``W`` has shape
            ``(fan_in, fan_out)`` and ``b`` has shape ``(fan_out,)``.
        """
        last = len(layer_sizes) - 2  # index of the output layer
        params = []
        for i, (fan_in, fan_out) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            # Use a distinct key for every draw: sampling W and b from the
            # same key would make the bias a copy of part of the weight noise.
            rng, w_key, b_key = jax.random.split(rng, 3)

            # Plain Python floats keep the requested dtype under JAX's
            # promotion rules (no accidental upcast of low-precision dtypes).
            w_scale = float(2.0 / fan_in) ** 0.5
            b_scale = 0.01
            if i == last:
                # Output layer starts near zero (small, but non-zero, values).
                w_scale *= 1e-4
                b_scale = 1e-5

            W = jax.random.normal(w_key, (fan_in, fan_out), dtype=dtype) * w_scale
            b = jax.random.normal(b_key, (fan_out,), dtype=dtype) * b_scale
            params.append((W, b))

        return params  # list of (W, b) tuples

    def apply(θ: list[tuple[jax.Array, jax.Array]], u: jax.Array) -> jax.Array:
        """Evaluate the network on ``u``.

        Every layer except the last is followed by the configured activation;
        the last layer is purely affine.

        Args:
            θ: Emulator parameters, a list of ``(W, b)`` tuples as returned
                by ``init``.
            u: Batch of input state, documented as ``[B, T, d]``; the last
                axis is contracted with the first axis of each ``W``.

        Returns:
            u_hat: Output state, with the last axis replaced by the width of
            the final layer.
        """
        for (W, b) in θ[:-1]:
            u = activation(jnp.dot(u, W) + b)
        W, b = θ[-1]
        u_hat = jnp.dot(u, W) + b
        return u_hat

    return init, apply