import jax
import jax.numpy as jnp
from typing import Callable

def linear_module(
    config: dict,
) -> tuple[
    Callable[..., list[tuple[jax.Array, jax.Array]]],
    Callable[..., jax.Array],
]:
    """Build the ``(init, apply)`` pair for a stack of affine layers.

    The stack is a plain composition of ``u @ W + b`` maps; no nonlinearity
    is applied between layers.

    Args:
        config: Mapping with the keys ``'input_dim'`` (int), ``'output_dim'``
            (int) and ``'layers'`` (sequence of hidden layer widths, may be
            empty).

    Returns:
        A tuple ``(init, apply)`` where ``init(rng, dtype)`` creates the
        parameters and ``apply(θ, u)`` runs the forward pass.
    """
    input_dim = config['input_dim']
    output_dim = config['output_dim']
    layers = config['layers']
    # Unpacking accepts any sequence (list, tuple, ...) of hidden widths.
    layer_sizes = [input_dim, *layers, output_dim]

    def init(rng: jax.Array, dtype: jnp.dtype) -> list[tuple[jax.Array, jax.Array]]:
        """Sample the parameters of every layer.

        Weights are drawn from a normal distribution scaled by
        ``sqrt(2 / fan_in)``; biases from a normal distribution scaled by
        ``0.01``.

        Args:
            rng: JAX PRNG key used as the source of randomness.
            dtype: Floating-point dtype of the created arrays.

        Returns:
            One ``(W, b)`` tuple per layer, with ``W`` of shape
            ``(fan_in, fan_out)`` and ``b`` of shape ``(fan_out,)``.
        """
        θ = []
        for fan_in, fan_out in zip(layer_sizes[:-1], layer_sizes[1:]):
            # A PRNG key must be consumed only once: derive independent
            # subkeys for the weights and the bias, otherwise both draws
            # come from the same random stream and are correlated.
            rng, w_key, b_key = jax.random.split(rng, 3)
            W = jax.random.normal(w_key, (fan_in, fan_out), dtype=dtype) * jnp.sqrt(2.0 / fan_in)
            b = jax.random.normal(b_key, (fan_out,), dtype=dtype) * 0.01
            θ.append((W, b))

        return θ  # list of (W, b) tuples

    def apply(θ: list[tuple[jax.Array, jax.Array]], u: jax.Array) -> jax.Array:
        """Apply every affine layer in order to the input.

        Args:
            θ (list[tuple[jax.Array, jax.Array]]): Emulator parameters, one
                ``(W, b)`` tuple per layer as returned by ``init``.
            u (jax.Array): Batch of input state [B, T, d]. The last axis is
                contracted with the first axis of each ``W``; leading axes
                are carried through unchanged.

        Returns:
            u_hat: Output state, with the last axis sized by the final
            layer. If ``θ`` is empty, ``u`` is returned unchanged.
        """
        for W, b in θ:
            u = jnp.dot(u, W) + b
        return u

    return init, apply