import jax 

import haiku as hk
import numpy as np
import jax.numpy as jnp


class LinearStochastic(hk.Module):
    """
    Stochastic linear layer.

    The layer keeps a fully-factorized Gaussian over its weights (and,
    optionally, biases). Each Gaussian is described by a mean parameter
    (`w_mu`, `b_mu`) and a pre-activation parameter (`w_rho`, `b_rho`) whose
    softplus is used as the standard deviation.
    """
    def __init__(
        self,
        output_size,
        init_rho_minval=-10, 
        init_rho_maxval=-8,
        bias=True,
        name="LinearStochastic"
    ):
        """
        Build a stochastic linear layer.

        params:
        - output_size (int): output size of the layer.
        - init_rho_minval (float): lower bound of the range from which to
            uniformly sample the initial value of the pre-softplus standard
            deviation parameters.
        - init_rho_maxval (float): upper bound of the range from which to
            uniformly sample the initial value of the pre-softplus standard
            deviation parameters.
        - bias (bool): if True, include bias parameters.
        - name (str): Haiku module name, used as prefix of the parameter names.
        """
        super().__init__(name=name)
        
        self.input_size = None
        self.output_size = output_size
        self.bias = bias
        self.uniform_init_minval = init_rho_minval
        self.uniform_init_maxval = init_rho_maxval


    def __call__(
        self, 
        x, 
        key, 
        is_training, 
        stochastic
    ):
        """
        Forward pass on the stochastic linear layer.

        params:
        - x (jnp.array): input array; the last axis is the feature axis, any
            number of leading axes is accepted.
        - key (jax.random.PRNGKey): JAX random key. Only used when `stochastic`
            is True.
        - is_training (bool): if True (and `stochastic` is True), use the local
            reparameterization trick; otherwise sample the parameters directly.
        - stochastic (bool): if True, use sampled parameters, otherwise, use mean
            parameters.

        `is_training` and `stochastic` drive Python control flow, so they must
        be plain Python booleans (static values when the call is traced).
        
        returns:
        - out (jnp.array): output of the stochastic linear layer, with the last
            axis of `x` replaced by `output_size`.
        """
        j, k = x.shape[-1], self.output_size

        # Initializers. They are only called by Haiku when a parameter is
        # created, so building them up-front does not consume randomness.
        stddev = 1.0 / np.sqrt(j)
        mu_init = hk.initializers.RandomUniform(minval=-stddev, maxval=stddev)
        rho_init = hk.initializers.RandomUniform(
            minval=self.uniform_init_minval, maxval=self.uniform_init_maxval
        )

        # Get parameters (creation order is kept: w_mu, w_rho, b_mu, b_rho)
        w_mu = hk.get_parameter("w_mu", shape=[j, k], dtype=x.dtype, init=mu_init)
        w_rho = hk.get_parameter("w_rho", shape=[j, k], dtype=x.dtype, init=rho_init)
        if self.bias:
            b_mu = hk.get_parameter("b_mu", shape=[k], dtype=x.dtype, init=mu_init)
            b_rho = hk.get_parameter("b_rho", shape=[k], dtype=x.dtype, init=rho_init)

        # The ellipsis lets the contraction broadcast over any leading axes;
        # for 2-D inputs it is the same contraction as "bi,io->bo".
        contraction = "...i,io->...o"

        # Forward
        if not stochastic:
            # Use mean parameters
            logits = jnp.einsum(contraction, x, w_mu)
            if self.bias:
                logits = logits + b_mu
        elif is_training:
            # Use local reparameterization: sample the pre-activations directly
            # from their Gaussian instead of sampling the parameters.
            logits_mu = jnp.einsum(contraction, x, w_mu)
            logits_var = jnp.einsum(
                contraction, jnp.square(x), jnp.square(jax.nn.softplus(w_rho))
            )
            if self.bias:
                logits_mu = logits_mu + b_mu
                logits_var = logits_var + jnp.square(jax.nn.softplus(b_rho))
            # NOTE(review): the gradient of the square root is unbounded at 0,
            # which can happen when bias is disabled and an input row is all
            # zeros -- consider adding a small epsilon if NaNs show up.
            logits = self.gaussian_sample(logits_mu, logits_var**0.5, key)
        else:
            # Directly sample parameters (independent keys for weights / biases)
            key_w, key_b = jax.random.split(key)
            w = self.gaussian_sample(w_mu, jax.nn.softplus(w_rho), key_w)
            logits = jnp.einsum(contraction, x, w)
            if self.bias:
                b = self.gaussian_sample(b_mu, jax.nn.softplus(b_rho), key_b)
                logits = logits + b
            
        return logits


    def gaussian_sample(
        self, 
        mu, 
        sig, 
        key
    ):
        """
        Return a sample from a fully-factorized Gaussian.

        params:
        - mu (jnp.array): mean of Gaussian.
        - sig (jnp.array): standard deviation of Gaussian.
        - key (jax.random.PRNGKey): JAX random key.

        returns:
        - z (jnp.array): sample from Gaussian.
        """
        # Draw the noise in the dtype of `sig` so that the sample is not
        # silently promoted to the default float dtype (e.g. for half-precision
        # parameters), matching the dtype of the deterministic forward pass.
        eps = jax.random.normal(key, shape=sig.shape, dtype=sig.dtype)
        z = mu + sig * eps
        
        return z    
    


class Linear(hk.Module):
    """
    Linear layer.

    Deterministic layer whose call signature mirrors the stochastic layer
    (`key`, `is_training`, `stochastic`), so the two can be used
    interchangeably; the extra arguments are ignored here.
    """
    def __init__(
        self, 
        output_size, 
        bias=True,
        name="Linear"
    ):
        """
        Build a linear layer.

        params:
        - output_size (int): output size of the layer.
        - bias (bool): if True, include bias parameters.
        - name (str): Haiku module name, used as prefix of the parameter names.
        """
        super().__init__(name=name)
        self.input_size = None
        self.output_size = output_size
        self.bias = bias


    def __call__(self, 
        x, 
        key=None, 
        is_training=False, 
        stochastic=False
    ):
        """
        Forward pass on the linear layer.

        params:
        - x (jnp.array): input array; the last axis is the feature axis, any
            number of leading axes is accepted.
        - key (jax.random.PRNGKey): dummy argument, ignored.
        - is_training (bool): dummy argument, ignored. 
        - stochastic (bool): dummy argument, ignored. 

        returns:
        - out (jnp.array): output of the linear layer, with the last axis of
            `x` replaced by `output_size`.
        """
        j, k = x.shape[-1], self.output_size

        # Weights and biases share the same uniform initializer. It is only
        # called by Haiku when a parameter is created.
        stddev = 1.0 / np.sqrt(j)
        uniform_init = hk.initializers.RandomUniform(minval=-stddev, maxval=stddev)

        # Get weights
        w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=uniform_init)
        if self.bias:
            b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=uniform_init)

        # Forward. The ellipsis broadcasts over any leading axes; for 2-D
        # inputs it is the same contraction as "bi,io->bo".
        logits = jnp.einsum("...i,io->...o", x, w)
        if self.bias:
            logits = logits + b
        
        return logits