import jax

import haiku as hk
import numpy as np
import jax.numpy as jnp


class LinearStochastic(hk.Module):
    """
    Stochastic linear layer.

    On every call the weights (and, if enabled, the biases) are sampled from a
    fully-factorized Gaussian with mean `*_mu` and standard deviation
    `softplus(*_rho)`.
    """
    def __init__(
        self,
        output_size,
        init_rho_minval=-5,
        init_rho_maxval=-5,
        bias=True,
    ):
        """
        Build a stochastic linear layer.

        params:
        - output_size (int): output size of the layer.
        - init_rho_minval (float): minimum value of range from which to uniformly
            sample the initial value of pre-activated variational variance parameters.
        - init_rho_maxval (float): maximum value of range from which to uniformly
            sample the initial value of pre-activated variational variance parameters.
        - bias (bool): if True, include bias parameters.
        """
        super().__init__("LinearStochastic")

        self.input_size = None
        self.bias = bias
        self.output_size = output_size
        self.uniform_init_minval = init_rho_minval
        self.uniform_init_maxval = init_rho_maxval

    def __call__(
        self,
        x,
        key
    ):
        """
        Forward pass on the stochastic linear layer.

        params:
        - x (jnp.ndarray): input features; the last axis is the feature axis,
            any leading axes are preserved in the output.
        - key (jax.random.PRNGKey): JAX random key.

        returns:
        - out (jnp.ndarray): output of the stochastic linear layer, with the
            last axis of size `output_size`.
        """
        # Input / output feature sizes
        j, k = x.shape[-1], self.output_size

        # Half-width of the uniform range used to initialize the means
        stddev = 1.0 / np.sqrt(j)

        # Initializers shared by the weight and bias parameters
        mu_init = hk.initializers.RandomUniform(minval=-stddev, maxval=stddev)
        rho_init = hk.initializers.RandomUniform(
            minval=self.uniform_init_minval, maxval=self.uniform_init_maxval
        )

        # Get parameters (creation order is kept stable: w_mu, w_rho, b_mu, b_rho)
        w_mu = hk.get_parameter("w_mu", shape=[j, k], dtype=x.dtype, init=mu_init)
        w_rho = hk.get_parameter("w_rho", shape=[j, k], dtype=x.dtype, init=rho_init)
        if self.bias:
            b_mu = hk.get_parameter("b_mu", shape=[k], dtype=x.dtype, init=mu_init)
            b_rho = hk.get_parameter("b_rho", shape=[k], dtype=x.dtype, init=rho_init)

        # Forward pass: independent keys for the weight and bias noise
        key_1, key_2 = jax.random.split(key)
        w = self.gaussian_sample(w_mu, jax.nn.softplus(w_rho), key_1)
        # The ellipsis lets x carry any number of leading (batch) axes
        logits = jnp.einsum("...i,io->...o", x, w)
        if self.bias:
            b = self.gaussian_sample(b_mu, jax.nn.softplus(b_rho), key_2)
            logits = logits + b

        return logits

    def gaussian_sample(
        self,
        mu,
        sig,
        key
    ):
        """
        Return a sample from a fully-factorized Gaussian.

        params:
        - mu (jnp.array): mean of Gaussian.
        - sig (jnp.array): standard deviation of Gaussian.
        - key (jax.random.PRNGKey): JAX random key.

        returns:
        - z (jnp.array): sample from Gaussian.
        """
        # Draw the noise in the dtype of `sig` so that the sample is not
        # silently promoted to JAX's default float dtype.
        eps = jax.random.normal(key, shape=sig.shape, dtype=sig.dtype)

        # Reparameterization: shift and scale standard normal noise
        return mu + sig * eps


class Linear(hk.Module):
    """
    Linear layer.
    """
    def __init__(
        self,
        output_size,
        bias=True
    ):
        """
        Build a linear layer.

        params:
        - output_size (int): output size.
        - bias (bool): if True, include bias parameters.
        """
        super().__init__("Linear")

        self.input_size = None
        self.bias = bias
        self.output_size = output_size

    def __call__(
        self,
        x,
        key=None
    ):
        """
        Forward pass on the linear layer.

        params:
        - x (jnp.ndarray): input; the last axis is the feature axis, any
            leading axes are preserved in the output.
        - key (jax.random.PRNGKey): dummy argument for compatibility; it is
            never read, so it may be omitted.

        returns:
        - out (jnp.ndarray): output of the linear layer, with the last axis of
            size `output_size`.
        """
        # Get input size
        j = x.shape[-1]

        # Half-width of the uniform range used for initialization
        stddev = 1.0 / np.sqrt(j)
        init = hk.initializers.RandomUniform(minval=-stddev, maxval=stddev)

        # Get parameters
        w = hk.get_parameter(
            "w",
            shape=[j, self.output_size],
            dtype=x.dtype,
            init=init
        )
        if self.bias:
            b = hk.get_parameter(
                "b",
                shape=[self.output_size],
                dtype=x.dtype,
                init=init
            )

        # Forward: the ellipsis lets x carry any number of leading (batch) axes
        logits = jnp.einsum("...i,io->...o", x, w)
        if self.bias:
            logits = logits + b

        return logits