import jax 

import haiku as hk
import numpy as np
import jax.numpy as jnp


class LinearStochastic(hk.Module):
    """
    Stochastic linear layer.

    Weights (and optionally biases) have a fully factorized Gaussian
    variational posterior with mean ``*_mu`` and standard deviation
    ``softplus(*_rho)``. The prior is an isotropic zero-mean Gaussian with
    standard deviation ``prior_scale``. Each call returns the layer output
    together with the KL divergence between posterior and prior.
    """
    def __init__(
        self,
        output_size,
        init_rho_minval=-10,
        init_rho_maxval=-8,
        bias=True,
        prior_scale=1.0,
        name=None,
    ):
        """
        Build a stochastic linear layer.

        params:
        - output_size (int): output size of the layer.
        - init_rho_minval (float): lower bound of the range from which to
            uniformly sample the initial value of pre-activated variational
            standard-deviation parameters (rho, with sigma = softplus(rho)).
        - init_rho_maxval (float): upper bound of the range from which to
            uniformly sample the initial value of pre-activated variational
            standard-deviation parameters.
        - bias (bool): if True, include bias parameters.
        - prior_scale (float): standard deviation of the zero-mean isotropic
            Gaussian prior.
        - name (str, optional): Haiku module name; defaults to
            "LinearStochastic".
        """
        super().__init__(name if name is not None else "LinearStochastic")

        # Not set or read anywhere else in this class; kept so external code
        # that inspects the attribute keeps working.
        self.input_size = None
        self.output_size = output_size
        self.bias = bias
        self.uniform_init_minval = init_rho_minval
        self.uniform_init_maxval = init_rho_maxval
        self.prior_scale = prior_scale

    def __call__(
        self,
        x,
        key,
        is_training,
        stochastic
    ):
        """
        Forward pass on the stochastic linear layer.

        params:
        - x (jnp.array): input array of shape (..., input_size); any number
            of leading axes is allowed.
        - key (jax.random.PRNGKey): JAX random key (only used when
            `stochastic` is True).
        - is_training (bool): if True (and `stochastic`), use the local
            reparameterization trick: sample the pre-activations instead of
            the weights.
        - stochastic (bool): if True, use sampled parameters, otherwise, use
            mean parameters.

        returns:
        - out (jnp.array): output of shape (..., output_size).
        - kl (float): KL divergence between the variational posterior and
            the prior, summed over all parameters of the layer.
        """
        j, k = x.shape[-1], self.output_size

        # Get parameters. Creation order (w_mu, w_rho, b_mu, b_rho) is kept
        # as in earlier versions so initialization stays reproducible.
        stddev = 1.0 / np.sqrt(j)
        mu_init = hk.initializers.RandomUniform(minval=-stddev, maxval=stddev)
        rho_init = hk.initializers.RandomUniform(
            minval=self.uniform_init_minval, maxval=self.uniform_init_maxval
        )
        w_mu = hk.get_parameter("w_mu", shape=[j, k], dtype=x.dtype, init=mu_init)
        w_rho = hk.get_parameter("w_rho", shape=[j, k], dtype=x.dtype, init=rho_init)
        w_sig = jax.nn.softplus(w_rho)
        if self.bias:
            b_mu = hk.get_parameter("b_mu", shape=[k], dtype=x.dtype, init=mu_init)
            b_rho = hk.get_parameter("b_rho", shape=[k], dtype=x.dtype, init=rho_init)
            b_sig = jax.nn.softplus(b_rho)

        # Forward
        if not stochastic:
            # Use mean parameters
            logits = self._matmul(x, w_mu)
            if self.bias:
                logits = logits + b_mu
        elif is_training:
            # Local reparameterization: the pre-activations are Gaussian with
            # mean x @ w_mu and variance x**2 @ w_sig**2 (+ bias terms).
            logits_mu = self._matmul(x, w_mu)
            logits_var = self._matmul(x**2, w_sig**2)
            if self.bias:
                logits_mu = logits_mu + b_mu
                logits_var = logits_var + b_sig**2
            logits = self.gaussian_sample(
                logits_mu, self._safe_sqrt(logits_var), key
            )
        else:
            # Directly sample parameters
            key_1, key_2 = jax.random.split(key)
            w = self.gaussian_sample(w_mu, w_sig, key_1)
            logits = self._matmul(x, w)
            if self.bias:
                logits = logits + self.gaussian_sample(b_mu, b_sig, key_2)

        # Compute KL divergence
        prior_var = self.prior_scale**2
        kl_div = self._KL_divergence(
            w_mu.flatten(),
            jnp.square(w_sig).flatten(),
            prior_var
        )
        if self.bias:
            kl_div = kl_div + self._KL_divergence(
                b_mu.flatten(),
                jnp.square(b_sig).flatten(),
                prior_var
            )

        return logits, kl_div

    def _matmul(self, x, w):
        """
        Contract the last axis of `x` with the first axis of `w`.

        Leading axes of `x` are preserved, so both (batch, in) and
        (..., in) inputs are supported.
        """
        return jnp.einsum("...i,io->...o", x, w)

    def _safe_sqrt(self, v):
        """
        Element-wise sqrt of a non-negative array with finite gradients at 0.

        Plain sqrt has an infinite derivative at 0. In the local
        reparameterization branch this turns into NaN gradients (inf * 0)
        whenever a pre-activation variance is exactly zero, e.g. an all-zero
        input row with no bias. The forward value is unchanged; the
        double `where` keeps the untaken branch's gradient finite.
        """
        positive = v > 0
        safe_v = jnp.where(positive, v, jnp.ones_like(v))
        return jnp.where(positive, jnp.sqrt(safe_v), jnp.zeros_like(v))

    def gaussian_sample(
        self,
        mu,
        sig,
        key
    ):
        """
        Return a sample from a fully-factorized Gaussian.

        params:
        - mu (jnp.array): mean of Gaussian.
        - sig (jnp.array): standard deviation of Gaussian (broadcastable
            against `mu`).
        - key (jax.random.PRNGKey): JAX random key.

        returns:
        - z (jnp.array): sample from Gaussian, with the broadcast shape of
            `mu` and `sig` and their promoted dtype.
        """
        # Draw the noise in the parameters' dtype so low-precision layers
        # (e.g. bfloat16) are not silently upcast to float32.
        shape = jnp.broadcast_shapes(jnp.shape(mu), jnp.shape(sig))
        dtype = jnp.result_type(mu, sig)
        eps = jax.random.normal(key, shape=shape, dtype=dtype)
        return mu + sig * eps

    def _KL_divergence(
        self,
        q_mean,
        q_cov,
        p_cov
    ):
        """
        Compute KL(q || p) between a fully factorized Gaussian approximate
        posterior q = N(q_mean, diag(q_cov)) and an isotropic zero-centered
        Gaussian prior p = N(0, p_cov * I).

        KL = 0.5 * [sum((q_cov + q_mean**2) / p_cov - log(q_cov))
                    - d + d * log(p_cov)]

        params:
        - q_mean (jax.numpy.array): mean vector of the posterior.
        - q_cov (jax.numpy.array): vector with the diagonal elements
            of the covariance matrix.
        - p_cov (float): scalar variance of the isotropic Gaussian prior.

        returns:
        - kl (float): KL divergence.
        """
        eps = 1e-15  # guards log(0) if a posterior variance underflows
        d = q_mean.size

        kl = 0.5 * (
            jnp.sum(
                (q_cov + q_mean**2) / p_cov - jnp.log(q_cov + eps)
            )
            - d + d * jnp.log(p_cov)
        )

        return kl


class Linear(hk.Module):
    """
    Deterministic linear layer.

    Shares the call signature of `LinearStochastic` (key, is_training,
    stochastic are accepted but ignored) and returns a zero KL term, so the
    two layers can be swapped in the same model code.
    """
    def __init__(
        self,
        output_size,
        bias=True,
        name=None,
    ):
        """
        Build a linear layer.

        params:
        - output_size (int): output size of the layer.
        - bias (bool): if True, include bias parameters.
        - name (str, optional): Haiku module name; defaults to "Linear".
        """
        super().__init__(name if name is not None else "Linear")

        # Not set or read anywhere else in this class; kept so external code
        # that inspects the attribute keeps working.
        self.input_size = None
        self.output_size = output_size
        self.bias = bias

    def __call__(self,
        x,
        key,
        is_training,
        stochastic
    ):
        """
        Forward pass on the linear layer.

        params:
        - x (jnp.array): input array of shape (..., input_size); any number
            of leading axes is allowed.
        - key (jax.random.PRNGKey): dummy argument.
        - is_training (bool): dummy argument.
        - stochastic (bool): dummy argument.

        returns:
        - out (jnp.array): output of shape (..., output_size).
        - kl (float): always 0.0 (no variational parameters).
        """
        j, k = x.shape[-1], self.output_size

        # Get weights (uniform init in [-1/sqrt(j), 1/sqrt(j)]); w is created
        # before b, as in earlier versions, so initialization is unchanged.
        stddev = 1.0 / np.sqrt(j)
        init = hk.initializers.RandomUniform(minval=-stddev, maxval=stddev)
        w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=init)
        if self.bias:
            b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=init)

        # Forward: contract the last axis of x with w, keeping leading axes.
        logits = jnp.einsum("...i,io->...o", x, w)
        if self.bias:
            logits = logits + b

        return logits, 0.