import haiku as hk
import jax

from models.FVI.model_utils.linear import LinearStochastic, Linear

class MLP(hk.Module):
    """
    Multilayer perceptron (MLP) with optionally stochastic layers.

    Each entry of `architecture` defines one layer; the matching entry of
    `stochastic_layers` selects `LinearStochastic` (True) or `Linear` (False)
    for that layer. `activation_fn` is applied after every layer except the
    last one.
    """
    def __init__(
        self,
        activation_fn,
        architecture,
        stochastic_layers
    ):
        """
        Initialize a MLP.

        params:
        - activation_fn (callable): activation function, applied after every
            layer except the output layer.
        - architecture (List[int]): output dimension of each layer, in order.
            For example, `[100, 100, 1]` means two hidden layers of 100 units
            each, followed by an output layer of dimension 1.
        - stochastic_layers (List[bool]): for each layer, whether it is
            stochastic (True) or deterministic (False). Must have the same
            length as `architecture`.
            For example, `[True, False, False]` means that the first layer
            (input -> first hidden layer) is stochastic and the remaining
            layers are deterministic.

        raises:
        - ValueError: if `architecture` is empty, or if its length differs
            from the length of `stochastic_layers`.
        """
        super().__init__()

        architecture = list(architecture)
        stochastic_layers = list(stochastic_layers)

        # Validate eagerly: zip() would otherwise silently drop trailing
        # layers on a length mismatch, and an empty MLP would only fail later
        # with an IndexError in __call__.
        if not architecture:
            raise ValueError("`architecture` must contain at least one layer.")
        if len(architecture) != len(stochastic_layers):
            raise ValueError(
                f"`architecture` has {len(architecture)} layers but "
                f"`stochastic_layers` has {len(stochastic_layers)} entries; "
                "they must have the same length."
            )

        self.activation_fn = activation_fn

        # Build MLP. Layers are constructed in order, so Haiku assigns the
        # same module names as before.
        self.layers = [
            LinearStochastic(unit, bias=True) if stochastic else Linear(unit, bias=True)
            for unit, stochastic in zip(architecture, stochastic_layers)
        ]

    def __call__(
        self,
        x,
        key
    ):
        """
        Forward pass on the MLP.

        params:
        - x (jnp.ndarray): input.
        - key (jax.random.PRNGKey): JAX random key. It is split into one
            subkey per layer, so stochastic layers do not share random draws.
            Splitting is deterministic, so the same `key` still produces the
            same output. `None` is passed through unchanged to every layer
            (NOTE(review): only valid if no layer consumes its key, e.g. an
            all-deterministic MLP; confirm that `Linear` ignores it).

        returns:
        - out (jnp.ndarray): output of the MLP.
        """
        n_layers = len(self.layers)

        # Reusing one key for every layer would hand identical randomness to
        # each stochastic layer (JAX keys must not be reused); derive an
        # independent subkey per layer instead.
        if key is None:
            layer_keys = [None] * n_layers
        else:
            layer_keys = jax.random.split(key, n_layers)

        out = x
        # Hidden layers: layer followed by the nonlinearity.
        for layer, layer_key in zip(self.layers[:-1], layer_keys[:-1]):
            out = layer(out, layer_key)
            out = self.activation_fn(out)

        # Output layer: no activation.
        return self.layers[-1](out, layer_keys[-1])


