import jax 

import haiku as hk

from models.MFVI.model_utils.linear import LinearStochastic, Linear


class MLP(hk.Module):
    """
    Multilayer perceptron (MLP) whose layers are individually chosen to be
    stochastic (`LinearStochastic`) or deterministic (`Linear`).
    """
    def __init__(
        self, 
        activation_fn, 
        architecture, 
        stochastic_layers, 
        prior_scale
    ):
        """
        Build an MLP with stochastic layers.

        params:
        - activation_fn (callable): activation function, applied after every
            layer except the last one.
        - architecture (List[int]): number of layers and hidden units for MLP.
            For example, `[100, 100, 1]` means an MLP of two layers of 100 hidden
            units each with output dim 1.
        - stochastic_layers (List[bool]): list indicating if a layer is stochastic (True)
            or deterministic (False). Must have the same length as `architecture`.
            For example, `[True, False, False]` means that the input layer of the MLP is 
            stochastic and the hidden and output layers are deterministic.
        - prior_scale (float): scale of the prior distribution; only passed to
            the stochastic layers.

        raises:
        - ValueError: if `architecture` is empty, or if `architecture` and
            `stochastic_layers` have different lengths.
        """
        super().__init__()

        # Fail early with a clear message: an empty network would only crash
        # later, inside `__call__`, with an opaque IndexError.
        if len(architecture) == 0:
            raise ValueError("`architecture` must contain at least one layer.")

        # `zip` silently drops trailing entries on a length mismatch, which
        # would build a shallower network than requested without any warning.
        if len(architecture) != len(stochastic_layers):
            raise ValueError(
                "`architecture` and `stochastic_layers` must have the same "
                f"length, got {len(architecture)} and {len(stochastic_layers)}."
            )

        self.activation_fn = activation_fn

        self.layers = []
        for unit, is_stochastic in zip(architecture, stochastic_layers):
            if is_stochastic:
                layer = LinearStochastic(
                    output_size=unit,
                    prior_scale=prior_scale, 
                    bias=True
                )
            else:
                layer = Linear(unit, bias=True)
            self.layers.append(layer)


    def __call__(
        self, 
        x, 
        key, 
        is_training, 
        stochastic
    ):
        """
        Forward pass on the MLP.

        params:
        - x (jnp.array): input.
        - key (jax.random.PRNGKey): JAX random key. It is split once per
            non-final layer; the final layer receives the remaining key.
        - is_training (bool): forwarded unchanged to every layer (intended to
            enable the local reparameterization trick when True).
        - stochastic (bool): forwarded unchanged to every layer (intended to
            sample weights when True, otherwise use mean parameters).
        
        returns:
        - out (jnp.array): output of the MLP (no activation on the last layer).
        - kl_div: sum over all layers of the second value returned by each layer.
        """
        out = x

        kl_div = 0
        # Every layer but the last is followed by the activation function.
        for layer in self.layers[:-1]:
            key, sub_key = jax.random.split(key)
            out, kl_div_layer = layer(out, sub_key, is_training, stochastic)
            out = self.activation_fn(out)
            kl_div += kl_div_layer

        # Final layer: no activation, and it consumes the remaining key.
        out, kl_div_layer = self.layers[-1](out, key, is_training, stochastic)
        kl_div += kl_div_layer

        return out, kl_div


