import haiku as hk

from models.GFSVI.model_utils.linear import LinearStochastic, Linear

class MLP(hk.Module):
    """
    Multilayer perceptron (MLP) whose layers can individually be stochastic
    or deterministic.
    """
    def __init__(
        self, 
        activation_fn, 
        architecture, 
        stochastic_layers, 
        init_rho_minval, 
        init_rho_maxval
    ):
        """
        Initialize an MLP.

        params:
        - activation_fn (callable): activation function, applied after every
            layer except the last one.
        - architecture (List[int]): number of output units of each layer, in order.
            For example, `[100, 100, 1]` means an MLP of two layers of 100 hidden
            units each with output dim 1.
        - stochastic_layers (List[bool]): one flag per entry of `architecture`,
            indicating if the corresponding layer is stochastic (True) or
            deterministic (False).
            For example, `[True, False, False]` means that the first layer of the
            MLP is stochastic and the remaining hidden and output layers are
            deterministic.
        - init_rho_minval (float): minimum value of range from which to uniformly
            sample the initial value of pre-activated variational variance parameters.
        - init_rho_maxval (float): maximum value of range from which to uniformly
            sample the initial value of pre-activated variational variance parameters.

        raises:
        - ValueError: if `architecture` is empty, or if `stochastic_layers` does
            not have exactly one entry per layer in `architecture`.
        """
        super().__init__()

        # Materialize both so their lengths can be checked (zip alone would
        # silently drop trailing layers when the lengths disagree).
        architecture = list(architecture)
        stochastic_layers = list(stochastic_layers)
        if not architecture:
            raise ValueError("`architecture` must contain at least one layer.")
        if len(stochastic_layers) != len(architecture):
            raise ValueError(
                "`stochastic_layers` must have one entry per layer in "
                f"`architecture` (got {len(stochastic_layers)} flags for "
                f"{len(architecture)} layers)."
            )

        self.activation_fn = activation_fn

        # Build MLP. The layer index `i` is forwarded to each layer constructor;
        # NOTE(review): presumably used to name/distinguish layer parameters —
        # confirm in model_utils/linear.py.
        self.layers = []
        for i, (unit, stochastic) in enumerate(zip(architecture, stochastic_layers)):
            if stochastic:
                layer = LinearStochastic(
                    unit,
                    i,
                    init_rho_minval,
                    init_rho_maxval,
                    bias=True
                )
            else:
                layer = Linear(
                    unit,
                    i,
                    bias=True
                )
            self.layers.append(layer)


    def __call__(
        self, 
        x
    ):
        """
        Forward pass on the MLP.

        Every layer but the last is followed by `activation_fn`; the output
        layer is returned without an activation.

        params:
        - x (jnp.ndarray): input.
        
        returns: 
        - out (jnp.ndarray): output of the MLP.
        """
        *hidden_layers, output_layer = self.layers

        out = x
        for layer in hidden_layers:
            out = self.activation_fn(layer(out))

        # Output layer (no activation)
        return output_layer(out)


