import jax 
import tree

import haiku as hk
import jax.numpy as jnp

from jax import eval_shape
from functools import partial

from models.GFSVI.model_utils.mlp import MLP


# Lookup table from the activation name accepted by `BNN(activation_fn=...)`
# to the JAX callable implementing it.
ACTIVATION_DICT = {
    # Hyperbolic tangent.
    "tanh": jnp.tanh,
    # Rectified linear unit.
    "relu": jax.nn.relu,
    # Leaky ReLU (JAX default negative slope).
    "lrelu": jax.nn.leaky_relu,
    # Exponential linear unit.
    "elu": jax.nn.elu,
}


class BNN:
    """
    Bayesian neural network (BNN) model.

    Wraps a Haiku-transformed MLP and exposes samplers and moments of the
    function distribution obtained by linearizing the network around its
    variational mean parameters.

    Parameter naming convention relied upon by the methods below:
    - a parameter whose name contains "mu" is a variational mean,
    - a parameter whose name contains "rho" is a pre-softplus variational
      scale, and `<prefix>_rho` is the scale of `<prefix>_mu`,
    - every other parameter is treated as non-stochastic.

    Several methods are jitted with `self` as a static argument, so
    instances must be hashable and are treated as compile-time constants
    by `jax.jit`.
    """
    def __init__(
        self,
        architecture,
        stochastic_layers,
        activation_fn,
        ll_scale,
        init_rho_minval,
        init_rho_maxval
    ):
        """
        Initialize BNN.

        params:
        - architecture (List[int]): number of layers and hidden units for MLP.
            For example, `[100, 100, 1]` means an MLP of two layers of 100 hidden
            units each with output dim 1.
        - stochastic_layers (List[bool]): list mapping layers to booleans
            indicating the use of stochastic or dense layers.
        - activation_fn (str): type of activation function; must be a key of
            `ACTIVATION_DICT` (a `KeyError` is raised otherwise).
        - ll_scale (float): scale of the Gaussian likelihood.
        - init_rho_minval (float): minimum value of range from which to uniformly
            sample the initial value of pre-activated variational variance parameters.
        - init_rho_maxval (float): maximum value of range from which to uniformly
            sample the initial value of pre-activated variational variance parameters.
        """
        self.activation_fn = ACTIVATION_DICT[activation_fn]
        self.architecture = architecture
        self.stochastic_layers = stochastic_layers
        self.init_rho_minval = init_rho_minval
        self.init_rho_maxval = init_rho_maxval
        self.ll_scale = ll_scale
        # `forward.init` / `forward.apply` are the pure Haiku functions.
        self.forward = hk.transform(self.make_forward_fn())
        self.training_steps = 0

    @property
    def apply_fn(self):
        """
        Build vectorized apply function.

        The returned function is `self.forward.apply` vmapped over its second
        positional argument only (the RNG key in Haiku's
        `apply(params, rng, x)` convention); the first and third arguments
        are shared across the batch.

        returns:
        - apply (callable): vectorized apply function.
        """
        return jax.vmap(
            self.forward.apply, in_axes=(None, 0, None)
        )

    def make_forward_fn(self):
        """
        Build forward function.

        returns:
        - forward (callable): function of the input `x` that instantiates the
            MLP with this model's configuration and applies it to `x`. It is
            meant to be wrapped by `hk.transform`.
        """
        def forward_fn(x):
            _forward_fn = MLP(
                activation_fn=self.activation_fn,
                architecture=self.architecture,
                stochastic_layers=self.stochastic_layers,
                init_rho_minval=self.init_rho_minval,
                init_rho_maxval=self.init_rho_maxval
            )
            return _forward_fn(x)

        return forward_fn

    @partial(jax.jit, static_argnums=(0, 4))
    def predict_f(
        self,
        params,
        x,
        key,
        mc_samples
    ):
        """
        Sample from the linearized function distribution.

        params:
        - params (jax.tree_util.pytree): parameters of the BNN.
        - x (jnp.ndarray): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples (static under jit).

        returns:
        - f_lin_nn (jnp.ndarray): function samples, with a leading axis of
            size `mc_samples` followed by the shape of the network output.
        """
        # The second sub-key is intentionally unused; the 4-way split is kept
        # so the keys used by the two forward passes below do not change.
        key1, _, key3, key4 = jax.random.split(key, num=4)

        # Partition parameters
        params_stochastic, params_non_stochastic = self.partition_stochastic_params(params)
        params_mean, params_rho = self.partition_mu_rho_params(params_stochastic)
        params_sig = jax.tree_util.tree_map(jax.nn.softplus, params_rho)

        # Sample zero-mean parameter perturbations: sigma * eps, eps ~ N(0, I).
        # Every parameter gets its own PRNG key so that the noise of different
        # parameters in the same module is not drawn from a reused key.
        theta = hk.data_structures.to_mutable_dict(params_sig)
        for module in theta:
            key1, module_key = jax.random.split(key1)
            names = list(theta[module])
            param_keys = jax.random.split(module_key, num=len(names))
            module_sig = theta[module]
            theta[module] = {
                # `<prefix>_rho` -> `<prefix>_mu` so names line up with the means.
                f"{n.split('_')[0]}_mu": module_sig[n] * jax.random.normal(
                    k, shape=(mc_samples,) + module_sig[n].shape
                )
                for n, k in zip(names, param_keys)
            }
        # Re-wrap the perturbations in exactly the container structure of
        # `params_mean`, as required for them to be used as JVP tangents.
        leaves, _ = jax.tree_util.tree_flatten(theta)
        theta = jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(params_mean), leaves)

        # Neural network predict function (as a function of the means only)
        def f(p_mean):
            p = hk.data_structures.merge(p_mean, params_rho, params_non_stochastic)
            return self.forward.apply(p, key3, x)

        # Compute the linearized neural network function: one JVP per sampled
        # perturbation (vmapped over the leading `mc_samples` axis of `theta`).
        batched_jvp = jax.vmap(jax.jvp, in_axes=(None, None, 0), out_axes=(None, 0))
        f_lin_nn = batched_jvp(f, (params_mean,), (theta,))[1]
        # Add the prediction at the current parameters.
        f_lin_nn += self.forward.apply(params, key4, x)

        return f_lin_nn

    @partial(jax.jit, static_argnums=(0, 4))
    def predict_y(
        self,
        params,
        x,
        key,
        mc_samples
    ):
        """
        Sample from the linearized predictive distribution.

        Draws function samples with `predict_f` and adds independent Gaussian
        observation noise of scale `self.ll_scale`.

        params:
        - params (jax.tree_util.pytree): parameters of the BNN.
        - x (jnp.ndarray): input data.
        - key (jax.random.PRNGKey): random key.
        - mc_samples (int): number of Monte Carlo samples (static under jit).

        returns:
        - y (jnp.ndarray): predictive samples, same shape as the output of
            `predict_f`.
        """
        key1, key2 = jax.random.split(key)

        # NOTE: the class defines no `predict_f_linearized`; the linearized
        # sampler is `predict_f`.
        f_lin = self.predict_f(params, x, key1, mc_samples)

        y = f_lin + self.ll_scale * jax.random.normal(key2, shape=f_lin.shape)

        return y

    @partial(jax.jit, static_argnums=(0,))
    def f_distribution(
        self,
        params,
        x,
        key
    ):
        """
        Return the mean and covariance of the linearized functional distribution.

        params:
        - params (jax.tree_util.pytree): parameters of the BNN.
        - x (jnp.ndarray): input data.
        - key (jax.random.PRNGKey): random key.

        returns:
        - mean (jnp.ndarray): mean of the distribution (network output at
            `params`, flattened to 1-D).
        - cov (jnp.ndarray): covariance of the distribution, i.e. the kernel
            of `f_distribution_kernel` evaluated at `(x, x)`.
        """
        key1, key2 = jax.random.split(key)

        # Compute the mean of the function distribution
        mean = self.forward.apply(params, key1, x).reshape(-1)

        # Compute the covariance of the function distribution
        cov = self.f_distribution_kernel(params, x, x, key2)

        return mean, cov

    @partial(jax.jit, static_argnums=(0,))
    def f_diag_distribution(
        self,
        params,
        x,
        key
    ):
        """
        Return the mean and diagonalized covariance of the linearized functional distribution.

        params:
        - params (jax.tree_util.pytree): parameters of the BNN.
        - x (jnp.ndarray): input data.
        - key (jax.random.PRNGKey): random key.

        returns:
        - mean (jnp.ndarray): mean of the distribution (flattened to 1-D).
        - diag_cov (jnp.ndarray): diagonal covariance of the distribution
            (flattened to 1-D).
        """
        key1, key2 = jax.random.split(key)

        # Compute the mean of the distribution
        mean = self.forward.apply(params, key1, x).reshape(-1)

        # Compute the covariance of the distribution: evaluate the kernel on
        # each input separately (as a batch of size one) instead of building
        # the full kernel matrix.
        x = jnp.expand_dims(x, axis=1)
        kernel_fn = lambda z: self.f_distribution_kernel(params, z, z, key2)
        diag_cov = jax.vmap(kernel_fn, in_axes=0)(x).reshape(-1)

        return mean, diag_cov

    @partial(jax.jit, static_argnums=(0,))
    def f_distribution_kernel(
        self,
        params,
        x1,
        x2,
        key
    ):
        """
        Compute the kernel induced by the linearized BNN.

        With J(x) the Jacobian of the flattened network output with respect
        to the mean parameters and S the diagonal matrix of parameter
        variances `softplus(rho)**2`, the kernel is assembled from the linear
        map `delta -> J(x1) S J(x2)^T delta` without materializing J.

        params:
        - params (jax.tree_util.pytree): parameters of the BNN.
        - x1 (jnp.ndarray): input data.
        - x2 (jnp.ndarray): input data.
        - key (jax.random.PRNGKey): random key.

        returns:
        - kernel (jnp.ndarray): kernel matrix with one row per entry of the
            flattened network output at `x1` and one column per entry of the
            flattened network output at `x2`.
        """
        # Partition parameters
        params_stochastic, params_non_stochastic = self.partition_stochastic_params(params)
        params_mean, params_rho = self.partition_mu_rho_params(params_stochastic)
        params_var = tree.map_structure(lambda p: jax.nn.softplus(p)**2, params_rho)

        # Variances renamed `<prefix>_rho` -> `<prefix>_mu` so they can be
        # multiplied leaf-by-leaf with gradients w.r.t. the means. This does
        # not depend on `delta`, so it is computed once here.
        renamed_params_var = self.map_variable_name(
            params_var, lambda n: f"{n.split('_')[0]}_mu"
        )

        # Predict function of the neural network
        def predict_fn(p_mean, z):
            p = hk.data_structures.merge(p_mean, params_rho, params_non_stochastic)
            return self.forward.apply(p, key, z).reshape(-1)

        f1 = lambda p: predict_fn(p, x1)
        f2 = lambda p: predict_fn(p, x2)

        def delta_vjp_jvp(delta):
            # J(x2)^T delta
            _, f2_vjp = jax.vjp(f2, params_mean)
            grad_mean = f2_vjp(delta)[0]
            # S J(x2)^T delta
            vj_prod = tree.map_structure(
                lambda var, g: var * g, renamed_params_var, grad_mean
            )
            # Re-wrap in the container structure of `params_mean` for the JVP.
            leaves, _ = jax.tree_util.tree_flatten(vj_prod)
            tmp = jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(params_mean), leaves)

            # J(x1) S J(x2)^T delta
            return jax.jvp(f1, (params_mean,), (tmp,))[1]

        # Compute the kernel: apply the transposed linear map to every basis
        # vector of its input space, which is the output space of `f1`.
        fx1 = eval_shape(f1, params_mean)
        fx2 = eval_shape(f2, params_mean)
        # Size/dtype taken from f1's output rather than `x1.shape[0]`, so the
        # basis matches the cotangent expected by `linear_transpose` even if
        # the network has more than one output per input.
        eye = jnp.eye(fx1.shape[0], dtype=fx1.dtype)
        kernel = jax.vmap(jax.linear_transpose(delta_vjp_jvp, fx2))(eye)[0]

        return kernel

    def map_variable_name(
        self,
        params,
        fn
    ):
        """
        Rename parameters.

        params:
        - params (jax.tree_util.pytree): parameters of the BNN, as a two-level
            `{module_name: {variable_name: array}}` mapping.
        - fn (callable): function to apply to the parameters names.

        returns:
        - params (jax.tree_util.pytree): parameters of the BNN with every
            variable name replaced by `fn(variable_name)`; arrays and module
            names are left untouched.
        """
        params = hk.data_structures.to_mutable_dict(params)
        for module in params:
            params[module] = {
                fn(var_name): array for var_name, array in params[module].items()
            }

        return hk.data_structures.to_immutable_dict(params)

    @partial(jax.jit, static_argnums=(0,))
    def partition_mu_rho_params(
        self,
        params
    ):
        """
        Split parameters into variational mean and variance.

        A parameter goes to the first group when its name contains "mu" and
        to the second group otherwise.

        params:
        - params (jax.tree_util.pytree): parameters of the BNN.

        returns:
        - mean_params (jax.tree_util.pytree): mean parameters of the BNN.
        - var_params (jax.tree_util.pytree): variance parameters of the BNN.
        """
        return hk.data_structures.partition(
            lambda m, n, p: "mu" in n, params
        )

    @partial(jax.jit, static_argnums=(0,))
    def partition_stochastic_params(
        self,
        params
    ):
        """
        Split parameters into stochastic and non-stochastic.

        A parameter is considered stochastic when its name contains "mu" or
        "rho".

        params:
        - params (jax.tree_util.pytree): parameters of the BNN.

        returns:
        - stochastic_params (jax.tree_util.pytree): stochastic parameters of the BNN.
        - non_stochastic_params (jax.tree_util.pytree): non-stochastic parameters of the BNN.
        """
        return hk.data_structures.partition(
            lambda m, n, p: ("mu" in n) or ("rho" in n), params
        )