import jax 

import haiku as hk
import jax.numpy as jnp

from functools import partial

from models.FVI.model_utils.mlp import MLP


# Activation functions selectable by name; `BNN.__init__` looks up its
# `activation_fn` argument in this mapping.
ACTIVATION_DICT = dict(
    tanh=jnp.tanh,
    relu=jax.nn.relu,
    lrelu=jax.nn.leaky_relu,
    elu=jax.nn.elu,
)


class BNN:
    """
    Bayesian neural network (BNN) model.

    Wraps a Haiku-transformed `MLP` (built from `architecture`,
    `stochastic_layers` and `activation_fn`) and exposes Monte Carlo sampling
    of network outputs (`predict_f`) and of noisy predictions (`predict_y`).
    """
    def __init__(
        self, 
        architecture, 
        stochastic_layers,
        activation_fn, 
        ll_scale
    ):
        """
        Initialize BNN.

        params:
        - architecture (List[int]): number of layers and hidden units for MLP.
            For example, `[100, 100, 1]` means an MLP of two layers of 100 hidden
            units each with output dim 1.
        - stochastic_layers (List[bool]): list mapping layers to booleans 
            indicating the use of stochastic or dense layers.
        - activation_fn (str or callable): name of an activation function in
            `ACTIVATION_DICT` ("tanh", "relu", "lrelu", "elu"), or an
            activation callable to use directly.
        - ll_scale (float): scale (standard deviation) of the Gaussian
            likelihood; `predict_y` multiplies standard-normal noise by it.

        raises:
        - KeyError: if `activation_fn` is a string not in `ACTIVATION_DICT`.
        """
        self.activation_fn = self._resolve_activation(activation_fn)
        self.architecture = architecture
        self.stochastic_layers = stochastic_layers
        self.ll_scale = ll_scale
        self.forward = hk.transform(self.make_forward_fn())
        self.training_steps = 0

    @staticmethod
    def _resolve_activation(activation_fn):
        """
        Return the activation callable for `activation_fn`.

        params:
        - activation_fn (str or callable): activation name or callable.

        returns:
        - activation (callable): `activation_fn` itself if it is callable,
            otherwise `ACTIVATION_DICT[activation_fn]`.

        raises:
        - KeyError: if the name is not in `ACTIVATION_DICT`.
        """
        if callable(activation_fn):
            return activation_fn
        try:
            return ACTIVATION_DICT[activation_fn]
        except KeyError:
            # Keep KeyError for existing callers, but say what is valid.
            raise KeyError(
                f"Unknown activation_fn {activation_fn!r}; expected one of "
                f"{sorted(ACTIVATION_DICT)} or a callable."
            ) from None

    @property
    def apply_fn(self):
        """
        Build vectorized apply function.

        The returned function vmaps `self.forward.apply` over axis 0 of its
        2nd argument (Haiku rng) and 4th argument (the explicit `key` passed
        to the forward function). The 1st (params) and 3rd (x) arguments are
        broadcast.

        returns:
        - apply (callable): vectorized apply function.
        """
        return jax.vmap(
            self.forward.apply, in_axes=(None, 0, None, 0)
        )

    def make_forward_fn(self):
        """
        Build forward function.

        returns:
        - forward (callable): `forward_fn(x, key)`, which instantiates an
            `MLP` with this model's activation, architecture and
            stochastic-layer mask, then calls it on `(x, key)`. The function is
            meant to be wrapped by `hk.transform`, as done in `__init__`.
        """
        def forward_fn(x, key):
            _forward_fn = MLP(
                activation_fn=self.activation_fn,
                architecture=self.architecture,
                stochastic_layers=self.stochastic_layers
            )
            return _forward_fn(x, key)

        return forward_fn

    # `self` (arg 0) and `mc_samples` (arg 4) are static. jit caches traces per
    # BNN instance; with no __hash__/__eq__ defined, instances hash by
    # identity. Attributes read during tracing (e.g. `forward`) are therefore
    # baked in at first call, and mutating them later is not picked up.
    @partial(jax.jit, static_argnums=(0, 4))
    def predict_f(
        self,
        params, 
        x, 
        key, 
        mc_samples
    ):
        """
        Draw Monte Carlo samples of the network output at `x`.

        params:
        - params (jax.tree_util.pytree): parameters of the BNN.
        - x (jnp.ndarray): input data (shared across all samples).
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples (static under jit).

        returns:
        - f (jnp.ndarray): function samples, stacked along a new leading axis
            of size `mc_samples`.
        """
        keys = jax.random.split(key, num=mc_samples)
        # NOTE(review): the same per-sample key is passed both as Haiku's `rng`
        # and as the explicit `key` argument of the forward function. If `MLP`
        # draws randomness from both, those draws are correlated. Confirm this
        # is intended.
        f = self.apply_fn(params, keys, x, keys)

        return f

    # Static `self` as in `predict_f`: `ll_scale` is read at trace time, so
    # changing it on an instance after the first call has no effect.
    @partial(jax.jit, static_argnums=(0, 4))
    def predict_y(
        self, 
        params, 
        x, 
        key, 
        mc_samples
    ):
        """
        Sample from the predictive distribution: network output plus
        Gaussian observation noise.

        params:
        - params (jax.tree_util.pytree): parameters of the BNN.
        - x (jnp.ndarray): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples (static under jit).

        returns:
        - y (jnp.ndarray): predictive samples, i.e. `predict_f` samples plus
            i.i.d. `N(0, ll_scale**2)` noise, with the same shape as the
            `predict_f` output.
        """
        # Independent keys for the function samples and for the noise.
        key1, key2 = jax.random.split(key)

        f_lin = self.predict_f(params, x, key1, mc_samples)

        y = f_lin + self.ll_scale * jax.random.normal(key2, shape=f_lin.shape)

        return y
