import haiku as hk
import numpy as np
import jax.numpy as jnp


class LinearStochastic(hk.Module):
    """
    Stochastic linear layer.

    The layer registers mean (``w_mu``, ``b_mu``) and pre-activated variance
    (``w_rho``, ``b_rho``) parameters. The forward pass implemented here only
    applies the mean parameters; the ``*_rho`` parameters are created but not
    read by this class (presumably they are consumed by code outside this
    module, e.g. for sampling or a KL term -- TODO confirm).
    """
    def __init__(
        self,
        output_size,
        idx,
        init_rho_minval,
        init_rho_maxval,
        bias=True,
    ):
        """
        Build a stochastic linear layer.

        params:
        - output_size (int): output size of the layer.
        - idx (int): index of the layer. Stored on the instance; it is not
            used for the Haiku module name, which is always "LinearStochastic".
        - init_rho_minval (float): minimum value of range from which to uniformly
            sample the initial value of pre-activated variational variance parameters.
        - init_rho_maxval (float): maximum value of range from which to uniformly
            sample the initial value of pre-activated variational variance parameters.
        - bias (bool): if True, include bias parameters.
        """
        super().__init__("LinearStochastic")

        self.idx = idx
        # Never assigned elsewhere in this class; the input size is read from
        # the input's last axis at call time.
        self.input_size = None
        self.bias = bias
        self.output_size = output_size
        self.uniform_init_minval = init_rho_minval
        self.uniform_init_maxval = init_rho_maxval

    def __call__(
        self,
        x
    ):
        """
        Forward pass on the stochastic linear layer.

        Only the mean parameters are applied: ``x @ w_mu (+ b_mu)``.

        params:
        - x (jnp.ndarray): input features of shape (..., input_size). Any
            number of leading dimensions is accepted.

        returns:
        - out (jnp.ndarray): output of the stochastic linear layer, of shape
            (..., output_size).
        """
        in_features, out_features = x.shape[-1], self.output_size

        # Half-width of the symmetric uniform range used to initialise the
        # mean parameters: U(-1/sqrt(in_features), 1/sqrt(in_features)).
        bound = 1.0 / np.sqrt(in_features)

        # One initializer per parameter family, shared by weights and bias.
        mu_init = hk.initializers.RandomUniform(minval=-bound, maxval=bound)
        rho_init = hk.initializers.RandomUniform(
            minval=self.uniform_init_minval, maxval=self.uniform_init_maxval
        )

        # Parameters are created in the order w_mu, w_rho, b_mu, b_rho;
        # reordering these calls would likely change seeded initial values.
        w_mu = hk.get_parameter(
            "w_mu",
            shape=[in_features, out_features],
            dtype=x.dtype,
            init=mu_init,
        )
        # Registered so it is part of the parameter tree; not used below.
        hk.get_parameter(
            "w_rho",
            shape=[in_features, out_features],
            dtype=x.dtype,
            init=rho_init,
        )

        b_mu = None
        if self.bias:
            b_mu = hk.get_parameter(
                "b_mu",
                shape=[out_features],
                dtype=x.dtype,
                init=mu_init,
            )
            # Registered so it is part of the parameter tree; not used below.
            hk.get_parameter(
                "b_rho",
                shape=[out_features],
                dtype=x.dtype,
                init=rho_init,
            )

        # Contract the last axis of x with the first axis of w_mu; the
        # ellipsis keeps every leading dimension of x untouched.
        logits = jnp.einsum("...i,io->...o", x, w_mu)
        if b_mu is not None:
            logits = logits + b_mu

        return logits


class Linear(hk.Module):
    """
    Linear layer.

    Applies ``x @ w (+ b)`` along the last axis of the input.
    """
    def __init__(
        self,
        output_size,
        idx,
        bias=True
    ):
        """
        Build a linear layer.

        params:
        - output_size (int): output size.
        - idx (int): index of the layer. Stored on the instance; it is not
            used for the Haiku module name, which is always "Linear".
        - bias (bool): if True, include bias parameters.
        """
        super().__init__("Linear")

        self.idx = idx
        # Never assigned elsewhere in this class; the input size is read from
        # the input's last axis at call time.
        self.input_size = None
        self.bias = bias
        self.output_size = output_size

    def __call__(
        self,
        x
    ):
        """
        Forward pass on the linear layer.

        params:
        - x (jnp.ndarray): input of shape (..., input_size). Any number of
            leading dimensions is accepted.

        returns:
        - out (jnp.ndarray): output of the linear layer, of shape
            (..., output_size).
        """
        in_features = x.shape[-1]

        # Half-width of the symmetric uniform range used for initialisation:
        # U(-1/sqrt(in_features), 1/sqrt(in_features)).
        bound = 1.0 / np.sqrt(in_features)
        uniform_init = hk.initializers.RandomUniform(minval=-bound, maxval=bound)

        # Weight is created before the bias; reordering these calls would
        # likely change seeded initial values.
        w = hk.get_parameter(
            "w",
            shape=[in_features, self.output_size],
            dtype=x.dtype,
            init=uniform_init,
        )

        b = None
        if self.bias:
            b = hk.get_parameter(
                "b",
                shape=[self.output_size],
                dtype=x.dtype,
                init=uniform_init,
            )

        # Contract the last axis of x with the first axis of w; the ellipsis
        # keeps every leading dimension of x untouched.
        logits = jnp.einsum("...i,io->...o", x, w)
        if b is not None:
            logits = logits + b

        return logits