import jax 

import haiku as hk
import jax.numpy as jnp

from functools import partial

from models.MFVI.model_utils.mlp import MLP


# Lookup table from an activation name (string) to the JAX callable that
# implements it. `BNN.__init__` indexes this table with its `activation_fn`
# argument.
ACTIVATION_DICT = dict(
    tanh=jnp.tanh,
    relu=jax.nn.relu,
    lrelu=jax.nn.leaky_relu,
    elu=jax.nn.elu,
)

class BNN:
    """
    Bayesian neural network (BNN) model.

    Wraps a Haiku-transformed `MLP` and exposes jit-compiled Monte Carlo
    samplers for the function values (`predict_f`) and for the noisy
    predictive distribution (`predict_y`).

    NOTE: `predict_f` and `predict_y` are jitted with `self` as a static
    argument. This class defines no `__hash__`/`__eq__`, so instances are
    compared by identity, and attributes read while tracing (e.g. `ll_scale`)
    are baked into the compiled function: mutating them on the same instance
    after the first call does not trigger a re-trace.
    """
    def __init__(
        self, 
        architecture, 
        stochastic_layers,
        activation_fn, 
        ll_scale,
        prior_scale
    ):
        """
        Initialize the BNN model.

        params:
        - architecture (List[int]): number of layers and hidden units for MLP.
            For example, `[100, 100, 1]` means an MLP of two layers of 100 hidden
            units each with output dim 1.
        - stochastic_layers (List[bool]): list mapping layers to booleans 
            indicating the use of stochastic or dense layers.
        - activation_fn (Str): type of activation function; must be one of
            the keys of `ACTIVATION_DICT`.
        - ll_scale (Float): scale of the Gaussian likelihood.
        - prior_scale (Float): scale of the Prior distribution.

        raises:
        - KeyError: if `activation_fn` is not a key of `ACTIVATION_DICT`.
        """
        try:
            self.activation_fn = ACTIVATION_DICT[activation_fn]
        except KeyError:
            # Same exception type as the plain lookup, but the message now
            # tells the caller which names are accepted.
            supported = ", ".join(sorted(ACTIVATION_DICT))
            raise KeyError(
                f"Unknown activation_fn {activation_fn!r}; "
                f"expected one of: {supported}."
            ) from None
        self.architecture = architecture
        self.stochastic_layers = stochastic_layers
        self.ll_scale = ll_scale
        self.prior_scale = prior_scale
        self.forward = hk.transform(self.make_forward_fn())
        self.training_steps = 0


    @property
    def apply_fn(self):
        """
        Build vectorized apply function.

        The returned callable takes, positionally,
        `(params, rng, x, key, is_training, stochastic)`: `rng` and `key` are
        mapped over their leading axis (one entry per Monte Carlo sample),
        every other argument is shared across samples. The first output is
        stacked along a new leading axis; the second output is declared
        unmapped (`out_axes=None`), which JAX only accepts when that output
        does not vary across the mapped axis.

        returns:
        - apply (callable): vectorized apply function.
        """
        return jax.vmap(
            self.forward.apply, 
            in_axes=(None, 0, None, 0, None, None), 
            out_axes=(0,None)
        )


    def make_forward_fn(self):
        """
        Build forward function.

        The `MLP` is constructed inside the returned function rather than
        here, so that construction happens when the function is run by
        `hk.transform` (presumably because `MLP` is a Haiku module -- its
        definition is not visible in this file).

        returns:
        - forward (callable): forward function taking
            `(x, key, is_training, stochastic)` and returning whatever the
            `MLP` call returns.
        """
        def forward_fn(x, key, is_training, stochastic):
            _forward_fn = MLP(
                activation_fn=self.activation_fn,
                architecture=self.architecture,
                stochastic_layers=self.stochastic_layers,
                prior_scale=self.prior_scale
            )
            return _forward_fn(x, key, is_training, stochastic)

        return forward_fn


    # Static arguments: self (0), mc_samples (4), is_training (5) and
    # stochastic (6). A new value for any of them triggers a re-compilation.
    @partial(jax.jit, static_argnums=(0,4,5,6))
    def predict_f(
        self, 
        params, 
        x, 
        key, 
        mc_samples, 
        is_training, 
        stochastic
    ):
        """
        Sample from function distribution.

        params:
        - params (jax.tree_util.pytree): parameters of the BNN.
        - x (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples.
        - is_training (bool): whether the model is in training mode.
        - stochastic (bool): whether to use stochastic layers.

        returns:
        - f (jnp.array): function samples, stacked along a leading axis of
            size `mc_samples`.
        - kl (float): kl divergence between weight distributions.
        """
        # One key per Monte Carlo sample.
        keys = jax.random.split(key, num=mc_samples)
        # The same per-sample keys are passed twice: once as the rng of the
        # Haiku-transformed apply and once as the explicit `key` argument of
        # the forward function.
        # NOTE(review): if `MLP` draws from both sources the draws would be
        # correlated -- confirm against the `MLP` implementation.
        f, kl_div = self.apply_fn(params, keys, x, keys, is_training, stochastic)
        
        return f, kl_div


    # Same static arguments as `predict_f`.
    @partial(jax.jit, static_argnums=(0,4,5,6))
    def predict_y(
        self, 
        params, 
        x, 
        key, 
        mc_samples, 
        is_training, 
        stochastic
    ):
        """
        Sample from the predictive distribution.

        Draws function samples with `predict_f` and adds independent
        zero-mean Gaussian noise with standard deviation `ll_scale`.

        params:
        - params (jax.tree_util.pytree): parameters of the BNN.
        - x (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples.
        - is_training (bool): whether the model is in training mode.
        - stochastic (bool): whether to use stochastic layers.

        returns:
        - y (jnp.array): predictive samples (function samples plus
            likelihood noise), same shape as the output of `predict_f`.
        - kl (float): kl divergence between weight distributions.
        """
        # Independent keys for the function samples and the observation noise.
        key1, key2 = jax.random.split(key)
        
        f, kl_div = self.predict_f(params, x, key1, mc_samples, is_training, stochastic)

        y = f + self.ll_scale*jax.random.normal(key2, shape=f.shape) 

        return y, kl_div