import jax 

import haiku as hk

from models.TFSVI.model_utils.linear import LinearStochastic, Linear


class MLP(hk.Module):
    """
    Feed-forward network built from a stack of linear layers, each of which
    is either a `LinearStochastic` or a plain `Linear` layer.
    """
    def __init__(
        self,
        activation_fn,
        architecture,
        stochastic_layers
    ):
        """
        Build the stack of linear layers.

        params:
        - activation_fn (callable): non-linearity applied after every layer
            except the last one.
        - architecture (List[int]): output size of each layer, in order.
            For example, `[100, 100, 1]` creates two layers of 100 units
            followed by an output layer of size 1.
        - stochastic_layers (List[bool]): one flag per layer; a truthy entry
            creates a `LinearStochastic` layer, a falsy entry a `Linear` one.
            For example, `[True, False, False]` makes only the first layer
            stochastic.

        note: the two lists are paired element-wise with `zip`, so surplus
        entries in the longer list are silently ignored.
        """
        super().__init__()
        self.activation_fn = activation_fn

        # Layers are created in architecture order; every layer has a bias.
        self.layers = [
            (LinearStochastic if is_stochastic else Linear)(width, bias=True)
            for width, is_stochastic in zip(architecture, stochastic_layers)
        ]


    def __call__(
        self,
        x,
        key,
        is_training,
        stochastic
    ):
        """
        Run the input through every layer in sequence.

        params:
        - x (jnp.array): input.
        - key (jax.random.PRNGKey): JAX random key; it is split once per
            hidden layer and the remaining key is given to the output layer.
        - is_training (bool): forwarded unchanged to every layer (intended to
            toggle the local reparameterization trick; the actual behavior
            is defined by the layer classes).
        - stochastic (bool): forwarded unchanged to every layer (intended to
            select sampled weights vs. mean parameters; the actual behavior
            is defined by the layer classes).

        returns:
        - out (jnp.array): result of the last layer, with no activation
            applied to it.
        """
        hidden = x
        # All layers but the last: linear transform followed by activation.
        for layer in self.layers[:-1]:
            key, layer_key = jax.random.split(key)
            pre_activation = layer(hidden, layer_key, is_training, stochastic)
            hidden = self.activation_fn(pre_activation)

        # Output layer: consumes the remaining key, no activation afterwards.
        return self.layers[-1](hidden, key, is_training, stochastic)
