import jax 
import tree

import haiku as hk
import jax.numpy as jnp

from jax import eval_shape
from functools import partial

from models.TFSVI.model_utils.mlp import MLP


# Lookup table from the activation names accepted by `BNN(activation_fn=...)`
# to the JAX callables that implement them.
ACTIVATION_DICT = dict(
    tanh=jnp.tanh,
    relu=jax.nn.relu,
    lrelu=jax.nn.leaky_relu,
    elu=jax.nn.elu,
)


class BNN:
    """
    Bayesian neural network (BNN) model.
    """
    
    def __init__(
        self,
        architecture,
        stochastic_layers,
        activation_fn,
        ll_scale,
        prior_scale
    ):
        """
        Initialize BNN.

        params:
        - architecture (List[int]): number of layers and hidden units for MLP.
            For example, `[100, 100, 1]` means an MLP of two layers of 100 hidden
            units each with output dim 1.
        - stochastic_layers (List[bool]): list mapping layers to booleans
            indicating the use of stochastic or dense layers.
        - activation_fn (str): name of the activation function; must be one of
            the keys of `ACTIVATION_DICT`.
        - ll_scale (float): scale of the Gaussian likelihood.
        - prior_scale (float): scale of the prior; its square multiplies the
            kernel returned by `prior_f_distribution`.

        raises:
        - KeyError: if `activation_fn` is not a key of `ACTIVATION_DICT`.
        """
        # Fail early with the list of valid names instead of a bare KeyError.
        if activation_fn not in ACTIVATION_DICT:
            raise KeyError(
                f"Unknown activation_fn {activation_fn!r}; "
                f"expected one of {sorted(ACTIVATION_DICT)}."
            )

        self.activation_fn = ACTIVATION_DICT[activation_fn]
        self.architecture = architecture
        self.stochastic_layers = stochastic_layers
        self.ll_scale = ll_scale
        self.prior_scale = prior_scale
        # Step counter, initialised to zero; it is not updated inside this class.
        self.training_steps = 0

        # Haiku-transformed forward pass. The closure only reads the attributes
        # above when it is actually called, not when it is built.
        self.forward = hk.transform(self.make_forward_fn())


    @property
    def apply_fn(self):
        """
        Build vectorized apply function.

        returns:
        - apply (callable): vectorized apply function.
        """
        return jax.vmap(
            self.forward.apply, 
            in_axes=(None, 0, None, 0, None, None)
        )


    def make_forward_fn(self):
        """
        Build the (untransformed) forward function of the network.

        The returned function constructs the `MLP` from this instance's
        activation, architecture and stochastic-layer settings every time it is
        called, then evaluates it on the given arguments.

        returns:
        - forward (callable): function `(x, key, is_training, stochastic)`
            returning the MLP output.
        """
        def forward_fn(x, key, is_training, stochastic):
            # The network is instantiated inside the function body, i.e. at
            # call time rather than when this closure is created.
            network = MLP(
                activation_fn=self.activation_fn,
                architecture=self.architecture,
                stochastic_layers=self.stochastic_layers,
            )
            outputs = network(x, key, is_training, stochastic)
            return outputs

        return forward_fn


    @partial(jax.jit, static_argnums=(0,4,5,6))
    def predict_f(
        self,
        params, 
        x, 
        key, 
        mc_samples, 
        is_training, 
        stochastic
    ):
        """
        Sample from function distribution.

        params:
        - params (jax.tree_util.pytree): parameters of the BNN.
        - x (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples.
        - is_training (bool): whether the model is in training mode.
        - stochastic (bool): whether to use stochastic layers.

        returns:
        - f (jnp.array): function samples.
        """
        keys = jax.random.split(key, mc_samples)
        f = self.apply_fn(params, keys, x, keys, is_training, stochastic)
        
        return f
        
        
    @partial(jax.jit, static_argnums=(0,4,5,6))
    def predict_y(
        self, 
        params, 
        x, 
        key, 
        mc_samples, 
        is_training, 
        stochastic
    ):
        """
        Sample from the predictive distribution.

        params:
        - params (jax.tree_util.pytree): parameters of the BNN.
        - x (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples.
        - is_training (bool): whether the model is in training mode.
        - stochastic (bool): whether to use stochastic layers.

        returns:
        - y (jnp.array): function samples.
        """
        key1, key2 = jax.random.split(key)
        
        f = self.predict_f(params, x, key1, mc_samples, is_training, stochastic)

        y = f + self.ll_scale*jax.random.normal(key2, shape=f.shape) 

        return y
    

    @partial(jax.jit, static_argnums=(0,))
    def f_distribution(
        self, 
        params, 
        x, 
        key
    ):
        """
        Return the mean and covariance the linearized functional distribution.

        params:
        - params (jax.tree_util.pytree): parameters of the BNN.
        - x (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.
        
        returns:
        - mean (jnp.array): mean of the function distribution.
        - cov (jnp.array): covariance of the function distribution.
        """
        key1, key2 = jax.random.split(key)

        # Compute the mean of the distribution 
        mean = self.forward.apply(
            params, 
            key1, 
            x, 
            key1, 
            is_training=False, 
            stochastic=False
        ).reshape(-1)

        # Compute the covariance of the distribution 
        cov = self.f_distribution_kernel(params, x, x, key2)

        return mean, cov        
    

    @partial(jax.jit, static_argnums=(0,))
    def f_diag_distribution(
        self, 
        params, 
        x, 
        key
    ):
        """
        Return the mean and diagonalized covariance the linearized functional distribution.

        params:
        - params (jax.tree_util.pytree): parameters of the BNN.
        - x (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.
        
        returns:
        - mean (jnp.array): mean of the distribution.
        - diag_cov (jnp.array): diagonal covariance of the distribution.
        """
        key1, key2 = jax.random.split(key)

        # Compute the mean of the distribution 
        mean = self.forward.apply(
            params, 
            key1, 
            x, 
            key1, 
            is_training=False, 
            stochastic=False
        ).reshape(-1)

        # Compute the covariance of the distribution 
        x = jnp.expand_dims(x, axis=1)
        kernel_fn = lambda z: self.f_distribution_kernel(params, z, z, key2)
        diag_cov = jax.vmap(kernel_fn, in_axes=0)(x).reshape(-1) 

        return mean, diag_cov


    @partial(jax.jit, static_argnums=(0,))
    def f_distribution_kernel(
        self,
        params,
        x1,
        x2,
        key
    ):
        """
        Compute the kernel induced by the linearized BNN.

        With `J1`, `J2` the Jacobians of the flattened network output at `x1`
        and `x2` with respect to the mean parameters, and `V` the per-parameter
        variances `softplus(rho)**2`, the returned matrix is built from the
        linear map `delta -> J1 (V * (J2^T delta))` without materializing the
        Jacobians.

        params:
        - params (jax.tree_util.pytree): parameters of the BNN.
        - x1 (jnp.array): input data.
        - x2 (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.

        returns
        - kernel (jnp.array): kernel matrix; one row per entry of the flattened
            output at `x1`, one column per entry of the flattened output at `x2`.
        """
        # Partition parameters
        params_stochastic, params_non_stochastic = self.partition_stochastic_params(params)
        params_mean, params_rho = self.partition_mu_rho_params(params_stochastic)
        params_var = tree.map_structure(lambda p: jax.nn.softplus(p)**2, params_rho)

        # Give every variance entry the name of its mean counterpart
        # ("<prefix>_mu", where <prefix> is the text before the first "_") so it
        # can be multiplied leaf-by-leaf with gradients w.r.t. `params_mean`.
        # NOTE(review): presumably rho variables are named "<prefix>_rho" -- confirm
        # against the MLP implementation.
        renamed_params_var = self.map_variable_name(
            params_var, lambda n: f"{n.split('_')[0]}_mu"
        )

        def predict_fn(p_mean, z):
            # Deterministic forward pass as a function of the mean parameters only.
            p = hk.data_structures.merge(p_mean, params_rho, params_non_stochastic)
            return self.forward.apply(p, key, z, key, is_training=False, stochastic=False).reshape(-1)

        f1 = lambda p: predict_fn(p, x1)
        f2 = lambda p: predict_fn(p, x2)

        def delta_vjp_jvp(delta):
            # J2^T delta: pull the output cotangent back to parameter space.
            grads = jax.vjp(f2, params_mean)[1](delta)[0]
            # Scale every parameter direction by its variance.
            vj_prod = tree.map_structure(
                lambda var, grad: var * grad, renamed_params_var, grads
            )
            # Rebuild the product with the exact tree structure of `params_mean`,
            # as required for a jvp tangent.
            leaves, _ = jax.tree_util.tree_flatten(vj_prod)
            tmp = jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(params_mean), leaves)

            # J1 (V * J2^T delta): push the scaled direction forward through f1.
            return jax.jvp(f1, (params_mean,), (tmp,))[1]

        fx1 = eval_shape(f1, params_mean)
        fx2 = eval_shape(f2, params_mean)
        # The transposed map consumes cotangents shaped like the output of f1, so
        # the identity must have one row per entry of that flattened output (not
        # one per input row of x1) and must share its dtype.
        eye = jnp.eye(fx1.shape[0], dtype=fx1.dtype)
        kernel = jax.vmap(jax.linear_transpose(delta_vjp_jvp, fx2))(eye)[0]

        return kernel
    

    @partial(jax.jit, static_argnums=(0,))
    def prior_f_distribution(
        self,
        params,
        x,
        key
    ):
        """
        Compute the prior functional distribution.

        With `f` the flattened deterministic forward pass as a function of the
        mean parameters `m` and `J` its Jacobian at `m`, the returned mean is
        `f(m) - J m` and the returned covariance is `prior_scale**2 * J J^T`.

        params:
        - params (jax.tree_util.pytree): parameters of the BNN.
        - x (jnp.array): input data.
        - key (jax.random.PRNGKey): random key.

        returns
        - p_mean (jnp.array): prior function mean (flattened to 1-D).
        - p_cov (jnp.array): prior function cov (square, one row/column per
            entry of `p_mean`).
        """
        # Partition parameters
        params_stochastic, params_non_stochastic = self.partition_stochastic_params(params)
        params_mean, params_rho = self.partition_mu_rho_params(params_stochastic)

        def fwd(p_mean):
            p = hk.data_structures.merge(p_mean, params_rho, params_non_stochastic)
            return self.forward.apply(p, key, x, key, is_training=False, stochastic=False).reshape(-1)

        # Compute mean. `jax.jvp` returns both f(m) and J m, so a single call
        # replaces the separate forward pass.
        f_at_mean, jac_times_mean = jax.jvp(fwd, (params_mean,), (params_mean,))
        p_mean = f_at_mean - jac_times_mean

        def NTK(p):
            # Build the vjp once; it is reused for every identity row below.
            _, vjp_fn = jax.vjp(fwd, p)
            JJ_T = lambda v: jax.jvp(fwd, (p,), vjp_fn(v))[1]
            # One identity row per entry of the flattened output (not per input
            # row of x), in the output dtype so the cotangents match.
            eye = jnp.eye(f_at_mean.shape[0], dtype=f_at_mean.dtype)
            return jax.vmap(JJ_T)(eye).T

        # Compute covariance
        p_cov = self.prior_scale**2 * NTK(params_mean)

        return p_mean, p_cov
        

    def map_variable_name(
        self,
        params,
        fn
    ):
        """
        Rename the variables of every module in a parameter structure.

        Module names and stored arrays are left untouched; only the
        per-module variable names are passed through `fn`.

        params:
        - params (jax.tree_util.pytree): parameters of the BNN, organised as
            `{module_name: {variable_name: array}}`.
        - fn (callable): function mapping an old variable name to a new one.

        returns:
        - params (jax.tree_util.pytree): immutable copy of the parameters with
            renamed variables.
        """
        mutable = hk.data_structures.to_mutable_dict(params)
        renamed = {
            module: {fn(name): value for name, value in variables.items()}
            for module, variables in mutable.items()
        }

        return hk.data_structures.to_immutable_dict(renamed)

    
    @partial(jax.jit, static_argnums=(0,))
    def partition_mu_rho_params(
        self,
        params
    ):
        """
        Split parameters into mean variables and the remaining variables.

        A variable goes to the first group when its name contains "mu";
        everything else goes to the second group.

        params:
        - params (jax.tree_util.pytree): parameters of the BNN.

        returns:
        - mean_params (jax.tree_util.pytree): variables whose name contains "mu".
        - var_params (jax.tree_util.pytree): all other variables (the rho
            variables when called on stochastic parameters only).
        """
        def is_mean_variable(module_name, name, value):
            return "mu" in name

        mean_params, var_params = hk.data_structures.partition(
            is_mean_variable, params
        )
        return mean_params, var_params
    

    @partial(jax.jit, static_argnums=(0,))
    def partition_stochastic_params(
        self,
        params
    ):
        """
        Split parameters into stochastic and non-stochastic.

        A variable is treated as stochastic when its name contains "mu" or
        "rho"; every other variable is non-stochastic.

        params:
        - params (jax.tree_util.pytree): parameters of the BNN.

        returns:
        - stochastic_params (jax.tree_util.pytree): stochastic parameters of the BNN.
        - non_stochastic_params (jax.tree_util.pytree): non-stochastic parameters of the BNN.
        """
        def is_stochastic_variable(module_name, name, value):
            return any(tag in name for tag in ("mu", "rho"))

        stochastic_params, non_stochastic_params = hk.data_structures.partition(
            is_stochastic_variable, params
        )
        return stochastic_params, non_stochastic_params