"""GaussianMLPModule."""

import abc

import torch
from garage.torch.distributions import TanhNormal
from garage.torch.modules.mlp_module import MLPModule
from garage.torch.modules.multi_headed_mlp_module import MultiHeadedMLPModule
from torch import nn
from torch.distributions import Normal
from torch.distributions.independent import Independent

from src.GumbelSoftmax import GumbelSoftmax


class GaussianMLPBaseModuleDoCausal(nn.Module):
    """Base of GaussianMLPModel.

    Subclasses implement ``_get_mean_and_log_std`` and must also set a
    ``batch_size`` attribute: ``forward`` reads ``self.batch_size`` but this
    base class never assigns it.

    Args:
        input_dim (int): Input dimension of the model.
        output_dim (int): Output dimension of the model.
        hidden_sizes (list[int]): Output dimension of dense layer(s) for
            the MLP for mean. For example, (32, 32) means the MLP consists
            of two hidden layers, each with 32 hidden units.
        hidden_nonlinearity (callable): Activation function for intermediate
            dense layer(s). It should return a torch.Tensor. Set it to
            None to maintain a linear activation.
        hidden_w_init (callable): Initializer function for the weight
            of intermediate dense layer(s). The function should return a
            torch.Tensor.
        hidden_b_init (callable): Initializer function for the bias
            of intermediate dense layer(s). The function should return a
            torch.Tensor.
        output_nonlinearity (callable): Activation function for output dense
            layer. It should return a torch.Tensor. Set it to None to
            maintain a linear activation.
        output_w_init (callable): Initializer function for the weight
            of output dense layer(s). The function should return a
            torch.Tensor.
        output_b_init (callable): Initializer function for the bias
            of output dense layer(s). The function should return a
            torch.Tensor.
        learn_std (bool): If True, the stored log ``init_std`` is an
            ``nn.Parameter``; otherwise it is registered as a buffer.
        init_std (float): Initial value for std.
            (plain value - not log or exponentiated).
        min_std (float): If not None, the log std is clamped to be at least
            log(min_std) (plain value - not log or exponentiated).
        max_std (float): If not None, the log std is clamped to be at most
            log(max_std) (plain value - not log or exponentiated).
        std_hidden_sizes (list[int]): Output dimension of dense layer(s) for
            the MLP for std. For example, (32, 32) means the MLP consists
            of two hidden layers, each with 32 hidden units.
        std_hidden_nonlinearity (callable): Nonlinearity for each hidden layer
            in the std network.
        std_hidden_w_init (callable):  Initializer function for the weight
            of hidden layer (s).
        std_hidden_b_init (callable): Initializer function for the bias
            of intermediate dense layer(s).
        std_output_nonlinearity (callable): Activation function for output
            dense layer in the std network. It should return a torch.Tensor.
            Set it to None to maintain a linear activation.
        std_output_w_init (callable): Initializer function for the weight
            of output dense layer(s) in the std network.
        std_parameterization (str): How the std should be parametrized. There
            are two options:
            - exp: the logarithm of the std will be stored, and applied a
               exponential transformation.
            - softplus: see the note in ``forward`` about the exact formula.
        layer_normalization (bool): Bool for using layer normalization or not.
        normal_distribution_cls (torch.distribution): normal distribution class
            to be constructed and returned by a call to forward. By default, is
            `torch.distributions.Normal`.
        no_value (bool): If True, ``forward`` regroups the head outputs with a
            trailing group size of 1 instead of 2.

    Raises:
        NotImplementedError: If ``std_parameterization`` is not "exp" or
            "softplus".

    """

    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_sizes=(32, 32),
        hidden_nonlinearity=torch.tanh,
        hidden_w_init=nn.init.xavier_uniform_,
        hidden_b_init=nn.init.zeros_,
        output_nonlinearity=None,
        output_w_init=nn.init.xavier_uniform_,
        output_b_init=nn.init.zeros_,
        learn_std=True,
        init_std=1.0,
        min_std=1e-6,
        max_std=None,
        std_hidden_sizes=(32, 32),
        std_hidden_nonlinearity=torch.tanh,
        std_hidden_w_init=nn.init.xavier_uniform_,
        std_hidden_b_init=nn.init.zeros_,
        std_output_nonlinearity=None,
        std_output_w_init=nn.init.xavier_uniform_,
        std_parameterization="exp",
        layer_normalization=False,
        normal_distribution_cls=Normal,
        no_value=False,
    ):
        super().__init__()

        self._input_dim = input_dim
        self._hidden_sizes = hidden_sizes
        self._action_dim = output_dim
        self._learn_std = learn_std
        self._std_hidden_sizes = std_hidden_sizes
        self._min_std = min_std
        self._max_std = max_std
        self._std_hidden_nonlinearity = std_hidden_nonlinearity
        self._std_hidden_w_init = std_hidden_w_init
        self._std_hidden_b_init = std_hidden_b_init
        self._std_output_nonlinearity = std_output_nonlinearity
        self._std_output_w_init = std_output_w_init
        self._std_parameterization = std_parameterization
        self._hidden_nonlinearity = hidden_nonlinearity
        self._hidden_w_init = hidden_w_init
        self._hidden_b_init = hidden_b_init
        self._output_nonlinearity = output_nonlinearity
        self._output_w_init = output_w_init
        self._output_b_init = output_b_init
        self._layer_normalization = layer_normalization
        self._norm_dist_class = normal_distribution_cls
        self.no_value = no_value
        if self._std_parameterization not in ("exp", "softplus"):
            raise NotImplementedError

        # All std-related values are stored in log space.
        init_std_param = torch.Tensor([init_std]).log()
        if self._learn_std:
            self._init_std = torch.nn.Parameter(init_std_param)
        else:
            self._init_std = init_std_param
            self.register_buffer("init_std", self._init_std)

        # Bound buffers are only registered when the bound is given.
        self._min_std_param = self._max_std_param = None
        if min_std is not None:
            self._min_std_param = torch.Tensor([min_std]).log()
            self.register_buffer("min_std_param", self._min_std_param)
        if max_std is not None:
            self._max_std_param = torch.Tensor([max_std]).log()
            self.register_buffer("max_std_param", self._max_std_param)

    def to(self, *args, **kwargs):
        """Move the module to the specified device and/or dtype.

        ``nn.Module.to`` replaces buffer tensors, so the private attributes
        that alias those buffers are re-bound afterwards.

        Args:
            *args: args to pytorch to function.
            **kwargs: keyword args to pytorch to function.

        Returns:
            GaussianMLPBaseModuleDoCausal: ``self``, like ``nn.Module.to``.

        """
        super().to(*args, **kwargs)
        buffers = dict(self.named_buffers())
        if not isinstance(self._init_std, torch.nn.Parameter):
            self._init_std = buffers["init_std"]
        # These buffers are absent when min_std / max_std was None.
        self._min_std_param = buffers.get("min_std_param")
        self._max_std_param = buffers.get("max_std_param")
        return self

    @abc.abstractmethod
    def _get_mean_and_log_std(self, *inputs):
        """Return the raw ``(mean, log_std_uncentered)`` head outputs."""

    def _regroup_per_batch(self, tensor, batch_dims, group_size):
        """Regroup a raw head output to shape ``(*batch_dims, batch_size, -1)``.

        The trailing two axes of ``tensor`` are reinterpreted as
        ``(-1, batch_size, group_size)`` and reordered to
        ``(batch_size, group_size, -1)`` before being flattened per batch
        element.

        Args:
            tensor (torch.Tensor): Raw output of one head.
            batch_dims (list[int]): Leading dimensions to keep untouched.
            group_size (int): Size of the innermost group (1 or 2).

        Returns:
            torch.Tensor: The regrouped tensor.

        """
        grouped = tensor.reshape(*batch_dims, -1, self.batch_size, group_size)
        # (..., n, batch_size, group) -> (..., batch_size, group, n)
        grouped = grouped.transpose(-2, -3).transpose(-1, -2)
        return grouped.reshape(*batch_dims, self.batch_size, -1)

    def forward(self, *inputs):
        """Forward method.

        Args:
            *inputs: Input to the module.

        Returns:
            torch.distributions.Distribution: ``normal_distribution_cls(mean,
                std)`` wrapped in ``Independent(..., 1)``, or the bare
                distribution when it is a ``TanhNormal``.

        """
        mean, log_std_uncentered = self._get_mean_and_log_std(*inputs)
        *batch_dims, _, _ = mean.shape
        # getattr presumably guards instances (e.g. unpickled) lacking no_value.
        group_size = 1 if getattr(self, "no_value", False) else 2
        mean = self._regroup_per_batch(mean, batch_dims, group_size)
        log_std_uncentered = self._regroup_per_batch(
            log_std_uncentered, batch_dims, group_size
        )

        min_param, max_param = self._min_std_param, self._max_std_param
        # Compare against None explicitly: a bound of 1.0 has log value 0,
        # which is falsy as a tensor and would otherwise disable clamping.
        if min_param is not None or max_param is not None:
            log_std_uncentered = log_std_uncentered.clamp(
                min=None if min_param is None else min_param.item(),
                max=None if max_param is None else max_param.item(),
            )

        if self._std_parameterization == "exp":
            std = log_std_uncentered.exp()
        else:
            # NOTE(review): this is log(1 + exp(exp(x))), i.e. softplus of
            # exp(x), not log(1 + exp(x)); kept to preserve behavior - confirm.
            std = log_std_uncentered.exp().exp().add(1.0).log()
        dist = self._norm_dist_class(mean, std)
        # This control flow is needed because if a TanhNormal distribution is
        # wrapped by torch.distributions.Independent, then custom functions
        # such as rsample_with_pretanh_value of the TanhNormal distribution
        # are not accessible.
        if not isinstance(dist, TanhNormal):
            # Makes it so that a sample from the distribution is treated as a
            # single sample and not dist.batch_shape samples.
            dist = Independent(dist, 1)

        return dist


class GaussianMLPTwoHeadedModuleDoCausal(GaussianMLPBaseModuleDoCausal):
    """GaussianMLPModule with one shared two-headed network for mean and log std.

    A single ``MultiHeadedMLPModule`` with two output heads (wrapped in
    ``nn.DataParallel``) produces both the mean and the uncentered log std.
    The base class ``forward`` regroups them per batch element and builds
    the distribution.

    Args:
        input_dim (int): Input dimension of the model.
        output_dim (int): Output dimension of each head.
        hidden_sizes (list[int]): Output dimension of the shared dense
            layer(s). For example, (32, 32) means the MLP consists
            of two hidden layers, each with 32 hidden units.
        hidden_nonlinearity (callable): Activation function for intermediate
            dense layer(s). It should return a torch.Tensor. Set it to
            None to maintain a linear activation.
        hidden_w_init (callable): Initializer function for the weight
            of intermediate dense layer(s). The function should return a
            torch.Tensor.
        hidden_b_init (callable): Initializer function for the bias
            of intermediate dense layer(s). The function should return a
            torch.Tensor.
        output_nonlinearity (callable): Activation function for output dense
            layer. It should return a torch.Tensor. Set it to None to
            maintain a linear activation.
        output_w_init (callable): Initializer function for the weight
            of output dense layer(s). The function should return a
            torch.Tensor.
        output_b_init (callable): Accepted for interface compatibility; the
            head biases are initialized from ``init_mean`` and ``init_std``
            instead.
        learn_std (bool): Whether the stored log ``init_std`` is an
            ``nn.Parameter`` (True) or a buffer (False). In this module the
            std itself always comes from the second head.
        init_mean (float): Constant initial bias of the mean head.
        init_std (float): Initial value for std.
            (plain value - not log or exponentiated). Its logarithm is the
            constant initial bias of the log-std head.
        min_std (float): If not None, the std is at least the value of min_std,
            to avoid numerical issues (plain value - not log or exponentiated).
        max_std (float): If not None, the std is at most the value of max_std,
            to avoid numerical issues (plain value - not log or exponentiated).
        std_parameterization (str): How the std should be parametrized. There
            are two options: "exp" and "softplus" (see the base class).
        layer_normalization (bool): Bool for using layer normalization or not.
        normal_distribution_cls (torch.distribution): normal distribution class
            to be constructed and returned by a call to forward. By default, is
            `torch.distributions.Normal`.
        batch_size (int): Number of batch elements the base ``forward``
            regroups the head outputs into.
        no_value (bool): Forwarded to the base class; controls the group size
            used when regrouping head outputs.

    """

    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_sizes=(32, 32),
        hidden_nonlinearity=torch.tanh,
        hidden_w_init=nn.init.xavier_uniform_,
        hidden_b_init=nn.init.zeros_,
        output_nonlinearity=None,
        output_w_init=nn.init.xavier_uniform_,
        output_b_init=nn.init.zeros_,
        learn_std=True,
        init_mean=0.0,
        init_std=1.0,
        min_std=1e-6,
        max_std=None,
        std_parameterization="exp",
        layer_normalization=False,
        normal_distribution_cls=Normal,
        batch_size=1,
        no_value=False,
    ):
        super().__init__(
            input_dim=input_dim,
            output_dim=output_dim,
            hidden_sizes=hidden_sizes,
            hidden_nonlinearity=hidden_nonlinearity,
            hidden_w_init=hidden_w_init,
            hidden_b_init=hidden_b_init,
            output_nonlinearity=output_nonlinearity,
            output_w_init=output_w_init,
            output_b_init=output_b_init,
            learn_std=learn_std,
            init_std=init_std,
            min_std=min_std,
            max_std=max_std,
            std_parameterization=std_parameterization,
            layer_normalization=layer_normalization,
            normal_distribution_cls=normal_distribution_cls,
            no_value=no_value,
        )
        # Read by the base class forward when regrouping head outputs.
        self.batch_size = batch_size
        self._shared_mean_log_std_network = nn.DataParallel(
            MultiHeadedMLPModule(
                n_heads=2,
                input_dim=self._input_dim,
                output_dims=self._action_dim,
                hidden_sizes=self._hidden_sizes,
                hidden_nonlinearity=self._hidden_nonlinearity,
                hidden_w_init=self._hidden_w_init,
                hidden_b_init=self._hidden_b_init,
                output_nonlinearities=self._output_nonlinearity,
                output_w_inits=self._output_w_init,
                output_b_inits=[
                    # Head 0 (mean): constant bias init_mean.
                    lambda x: nn.init.constant_(x, init_mean),
                    # Head 1 (log std): constant bias log(init_std), since the
                    # base class stores init_std in log space.
                    lambda x: nn.init.constant_(x, self._init_std.item()),
                ],
                layer_normalization=self._layer_normalization,
            )
        )

    def _get_mean_and_log_std(self, *inputs):
        """Get the raw mean and log std of the Gaussian given inputs.

        Args:
            *inputs: Input to the module.

        Returns:
            torch.Tensor: Output of the mean head.
            torch.Tensor: Output of the (uncentered) log-std head.

        """
        return self._shared_mean_log_std_network(*inputs)


class MLPSingleTargetBaseModuleDoCausal(nn.Module):
    """Base module that turns network outputs into Gumbel-softmax distributions.

    Subclasses implement ``_get_mean_and_log_std``, which must return six
    tensors: ``(logits, log_temp, logits_value, log_temp_value, logits_obs,
    log_temp_obs)``. ``forward`` builds three ``GumbelSoftmax`` distributions
    (target, value, observation) from them.

    Args:
        input_dim (int): Input dimension of the model.
        output_dim (int): Output dimension of the model.
        hidden_sizes (list[int]): Output dimension of dense layer(s) for
            the MLP. For example, (32, 32) means the MLP consists
            of two hidden layers, each with 32 hidden units.
        hidden_nonlinearity (callable): Activation function for intermediate
            dense layer(s). It should return a torch.Tensor. Set it to
            None to maintain a linear activation.
        hidden_w_init (callable): Initializer function for the weight
            of intermediate dense layer(s). The function should return a
            torch.Tensor.
        hidden_b_init (callable): Initializer function for the bias
            of intermediate dense layer(s). The function should return a
            torch.Tensor.
        output_nonlinearity (callable): Activation function for output dense
            layer. It should return a torch.Tensor. Set it to None to
            maintain a linear activation.
        output_w_init (callable): Initializer function for the weight
            of output dense layer(s). The function should return a
            torch.Tensor.
        output_b_init (callable): Initializer function for the bias
            of output dense layer(s). The function should return a
            torch.Tensor.
        layer_normalization (bool): Bool for using layer normalization or not.
        init_temp (float): Initial temperature (plain value - not log or
            exponentiated); its logarithm is stored.
        min_temp (float): If not None, log temperatures are clamped to be at
            least log(min_temp) (plain value - not log or exponentiated).
        max_temp (float): If not None, log temperatures are clamped to be at
            most log(max_temp) (plain value - not log or exponentiated).
        learn_temp (bool): If True, temperatures are derived from the log
            temperatures returned by ``_get_mean_and_log_std`` and the stored
            log ``init_temp`` is an ``nn.Parameter``. If False, every
            temperature equals ``init_temp`` and the stored value is a buffer.
        temp_parameterization (str): How a log temperature ``x`` is mapped to
            a temperature when ``learn_temp`` is True:
            - exp: ``exp(x)``.
            - softplus: ``log(1 + exp(exp(x)))``.

    Raises:
        NotImplementedError: If ``temp_parameterization`` is not "exp" or
            "softplus".

    """

    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_sizes=(32, 32),
        hidden_nonlinearity=torch.tanh,
        hidden_w_init=nn.init.xavier_uniform_,
        hidden_b_init=nn.init.zeros_,
        output_nonlinearity=None,
        output_w_init=nn.init.xavier_uniform_,
        output_b_init=nn.init.zeros_,
        layer_normalization=False,
        init_temp=1.0,
        min_temp=1e-6,
        max_temp=None,
        learn_temp=True,
        temp_parameterization="exp",
    ):
        super().__init__()

        self._input_dim = input_dim
        self._hidden_sizes = hidden_sizes
        self._action_dim = output_dim
        self._hidden_nonlinearity = hidden_nonlinearity
        self._hidden_w_init = hidden_w_init
        self._hidden_b_init = hidden_b_init
        self._output_nonlinearity = output_nonlinearity
        self._output_w_init = output_w_init
        self._output_b_init = output_b_init
        self._layer_normalization = layer_normalization
        self._init_temp = init_temp
        self._min_temp = min_temp
        self._max_temp = max_temp
        self._learn_temp = learn_temp
        self._temp_parameterization = temp_parameterization

        if self._temp_parameterization not in ("exp", "softplus"):
            raise NotImplementedError

        # All temperature-related values are stored in log space.
        init_temp_param = torch.Tensor([init_temp]).log()
        if self._learn_temp:
            self._init_temp = torch.nn.Parameter(init_temp_param)
        else:
            self._init_temp = init_temp_param
            self.register_buffer("init_temp", self._init_temp)

        # Bound buffers are only registered when the bound is given.
        self._min_temp_param = self._max_temp_param = None
        if min_temp is not None:
            self._min_temp_param = torch.Tensor([min_temp]).log()
            self.register_buffer("min_temp_param", self._min_temp_param)
        if max_temp is not None:
            self._max_temp_param = torch.Tensor([max_temp]).log()
            self.register_buffer("max_temp_param", self._max_temp_param)

    def to(self, *args, **kwargs):
        """Move the module to the specified device and/or dtype.

        ``nn.Module.to`` replaces buffer tensors, so the private attributes
        that alias those buffers are re-bound afterwards.

        Args:
            *args: args to pytorch to function.
            **kwargs: keyword args to pytorch to function.

        Returns:
            MLPSingleTargetBaseModuleDoCausal: ``self``, like ``nn.Module.to``.

        """
        super().to(*args, **kwargs)
        buffers = dict(self.named_buffers())
        if not isinstance(self._init_temp, torch.nn.Parameter):
            self._init_temp = buffers["init_temp"]
        # These buffers are absent when min_temp / max_temp was None.
        self._min_temp_param = buffers.get("min_temp_param")
        self._max_temp_param = buffers.get("max_temp_param")
        return self

    @abc.abstractmethod
    def _get_mean_and_log_std(self, *inputs):
        """Return the six raw logit / log-temperature tensors (see class doc)."""

    @staticmethod
    def _two_way_probs(logits):
        """Sum ``logits`` over dim -2 and stack ``[p, 1 - p]`` along dim -2.

        ``p`` is the sigmoid of the summed logits.
        """
        prob = torch.sigmoid(torch.sum(logits, dim=-2, keepdim=True))
        return torch.cat([prob, 1 - prob], dim=-2)

    def _clamp_log_temperature(self, log_temp):
        """Clamp ``log_temp`` to the configured log bounds, if any are set."""
        min_param, max_param = self._min_temp_param, self._max_temp_param
        # Compare against None explicitly: a bound of 1.0 has log value 0,
        # which is falsy as a tensor and would otherwise disable clamping.
        if min_param is None and max_param is None:
            return log_temp
        return log_temp.clamp(
            min=None if min_param is None else min_param.item(),
            max=None if max_param is None else max_param.item(),
        )

    def _log_temp_to_temperature(self, log_temp):
        """Map a log temperature to a temperature per ``learn_temp``/params."""
        if not self._learn_temp:
            # Fixed temperature; log_temp only provides the output shape.
            return self._init_temp.exp() * torch.ones_like(log_temp)
        if self._temp_parameterization == "exp":
            return log_temp.exp()
        return log_temp.exp().exp().add(1.0).log()

    def forward(self, *inputs):
        """Forward method.

        Args:
            *inputs: Input to the module.

        Returns:
            tuple: ``(dist_target, dist_value, dist_obs)``, three
                ``GumbelSoftmax`` distributions.

        """
        (
            logits,
            log_temp,
            logits_value,
            log_temp_value,
            logits_obs,
            log_temp_obs,
        ) = self._get_mean_and_log_std(*inputs)

        probs_value = self._two_way_probs(logits_value)
        probs_obs = self._two_way_probs(logits_obs)

        # Keep the smallest log temperature along dim -2, clamp, then map.
        temp, temp_value, temp_obs = (
            self._log_temp_to_temperature(
                self._clamp_log_temperature(
                    torch.min(raw, dim=-2, keepdim=True).values
                )
            )
            for raw in (log_temp, log_temp_value, log_temp_obs)
        )

        # NOTE(review): squeeze() without a dim also drops any size-1 batch
        # axis; confirm this is intended.
        dist_target = GumbelSoftmax(
            temperature=temp.squeeze(-1), logits=logits.squeeze()
        )
        dist_value = GumbelSoftmax(
            temperature=temp_value.squeeze(-1), probs=probs_value.squeeze()
        )
        dist_obs = GumbelSoftmax(
            temperature=temp_obs.squeeze(-1), probs=probs_obs.squeeze()
        )

        return dist_target, dist_value, dist_obs


class MLPSingleTargetNoValueBaseModuleDoCausal(nn.Module):
    """Base module producing a single Gumbel-softmax target distribution.

    Subclasses implement ``_get_mean_and_log_std``, which must return
    ``(logits, log_temp)``. ``forward`` builds one ``GumbelSoftmax``
    distribution from them.

    Args:
        input_dim (int): Input dimension of the model.
        output_dim (int): Output dimension of the model.
        hidden_sizes (list[int]): Output dimension of dense layer(s) for
            the MLP. For example, (32, 32) means the MLP consists
            of two hidden layers, each with 32 hidden units.
        hidden_nonlinearity (callable): Activation function for intermediate
            dense layer(s). It should return a torch.Tensor. Set it to
            None to maintain a linear activation.
        hidden_w_init (callable): Initializer function for the weight
            of intermediate dense layer(s). The function should return a
            torch.Tensor.
        hidden_b_init (callable): Initializer function for the bias
            of intermediate dense layer(s). The function should return a
            torch.Tensor.
        output_nonlinearity (callable): Activation function for output dense
            layer. It should return a torch.Tensor. Set it to None to
            maintain a linear activation.
        output_w_init (callable): Initializer function for the weight
            of output dense layer(s). The function should return a
            torch.Tensor.
        output_b_init (callable): Initializer function for the bias
            of output dense layer(s). The function should return a
            torch.Tensor.
        learn_std, init_std, min_std, max_std, std_hidden_sizes,
        std_hidden_nonlinearity, std_hidden_w_init, std_hidden_b_init,
        std_output_nonlinearity, std_output_w_init, normal_distribution_cls:
            Stored on the instance; not used by this base class's
            ``forward``.
        std_parameterization (str): Stored; must be "exp" or "softplus".
        layer_normalization (bool): Bool for using layer normalization or not.
        init_temp (float): Initial temperature (plain value - not log or
            exponentiated); its logarithm is stored.
        min_temp (float): If not None, log temperatures are clamped to be at
            least log(min_temp) (plain value - not log or exponentiated).
        max_temp (float): If not None, log temperatures are clamped to be at
            most log(max_temp) (plain value - not log or exponentiated).
        learn_temp (bool): If True, temperatures are derived from the log
            temperatures returned by ``_get_mean_and_log_std`` and the stored
            log ``init_temp`` is an ``nn.Parameter``. If False, every
            temperature equals ``init_temp`` and the stored value is a buffer.
        temp_parameterization (str): How a log temperature ``x`` is mapped to
            a temperature when ``learn_temp`` is True:
            - exp: ``exp(x)``.
            - softplus: ``log(1 + exp(exp(x)))``.

    Raises:
        NotImplementedError: If ``std_parameterization`` or
            ``temp_parameterization`` is not "exp" or "softplus".

    """

    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_sizes=(32, 32),
        hidden_nonlinearity=torch.tanh,
        hidden_w_init=nn.init.xavier_uniform_,
        hidden_b_init=nn.init.zeros_,
        output_nonlinearity=None,
        output_w_init=nn.init.xavier_uniform_,
        output_b_init=nn.init.zeros_,
        learn_std=True,
        init_std=1.0,
        min_std=1e-6,
        max_std=None,
        std_hidden_sizes=(32, 32),
        std_hidden_nonlinearity=torch.tanh,
        std_hidden_w_init=nn.init.xavier_uniform_,
        std_hidden_b_init=nn.init.zeros_,
        std_output_nonlinearity=None,
        std_output_w_init=nn.init.xavier_uniform_,
        std_parameterization="exp",
        layer_normalization=False,
        normal_distribution_cls=Normal,
        init_temp=1.0,
        min_temp=1e-6,
        max_temp=None,
        learn_temp=True,
        temp_parameterization="exp",
    ):
        super().__init__()

        self._input_dim = input_dim
        self._hidden_sizes = hidden_sizes
        self._action_dim = output_dim
        self._learn_std = learn_std
        self._std_hidden_sizes = std_hidden_sizes
        self._min_std = min_std
        self._max_std = max_std
        self._std_hidden_nonlinearity = std_hidden_nonlinearity
        self._std_hidden_w_init = std_hidden_w_init
        self._std_hidden_b_init = std_hidden_b_init
        self._std_output_nonlinearity = std_output_nonlinearity
        self._std_output_w_init = std_output_w_init
        self._std_parameterization = std_parameterization
        self._hidden_nonlinearity = hidden_nonlinearity
        self._hidden_w_init = hidden_w_init
        self._hidden_b_init = hidden_b_init
        self._output_nonlinearity = output_nonlinearity
        self._output_w_init = output_w_init
        self._output_b_init = output_b_init
        self._layer_normalization = layer_normalization
        self._norm_dist_class = normal_distribution_cls
        self._init_temp = init_temp
        self._min_temp = min_temp
        self._max_temp = max_temp
        self._learn_temp = learn_temp
        self._temp_parameterization = temp_parameterization

        if self._std_parameterization not in ("exp", "softplus"):
            raise NotImplementedError
        if self._temp_parameterization not in ("exp", "softplus"):
            raise NotImplementedError

        # All temperature-related values are stored in log space.
        init_temp_param = torch.Tensor([init_temp]).log()
        if self._learn_temp:
            self._init_temp = torch.nn.Parameter(init_temp_param)
        else:
            self._init_temp = init_temp_param
            self.register_buffer("init_temp", self._init_temp)

        # Bound buffers are only registered when the bound is given.
        self._min_temp_param = self._max_temp_param = None
        if min_temp is not None:
            self._min_temp_param = torch.Tensor([min_temp]).log()
            self.register_buffer("min_temp_param", self._min_temp_param)
        if max_temp is not None:
            self._max_temp_param = torch.Tensor([max_temp]).log()
            self.register_buffer("max_temp_param", self._max_temp_param)

    def to(self, *args, **kwargs):
        """Move the module to the specified device and/or dtype.

        ``nn.Module.to`` replaces buffer tensors, so the private attributes
        that alias those buffers are re-bound afterwards.

        Args:
            *args: args to pytorch to function.
            **kwargs: keyword args to pytorch to function.

        Returns:
            MLPSingleTargetNoValueBaseModuleDoCausal: ``self``, like
                ``nn.Module.to``.

        """
        super().to(*args, **kwargs)
        buffers = dict(self.named_buffers())
        if not isinstance(self._init_temp, torch.nn.Parameter):
            self._init_temp = buffers["init_temp"]
        # These buffers are absent when min_temp / max_temp was None.
        self._min_temp_param = buffers.get("min_temp_param")
        self._max_temp_param = buffers.get("max_temp_param")
        return self

    @abc.abstractmethod
    def _get_mean_and_log_std(self, *inputs):
        """Return the raw ``(logits, log_temp)`` tensors."""

    def _clamp_log_temperature(self, log_temp):
        """Clamp ``log_temp`` to the configured log bounds, if any are set."""
        min_param, max_param = self._min_temp_param, self._max_temp_param
        # Compare against None explicitly: a bound of 1.0 has log value 0,
        # which is falsy as a tensor and would otherwise disable clamping.
        if min_param is None and max_param is None:
            return log_temp
        return log_temp.clamp(
            min=None if min_param is None else min_param.item(),
            max=None if max_param is None else max_param.item(),
        )

    def _log_temp_to_temperature(self, log_temp):
        """Map a log temperature to a temperature per ``learn_temp``/params."""
        if not self._learn_temp:
            # Fixed temperature; log_temp only provides the output shape.
            return self._init_temp.exp() * torch.ones_like(log_temp)
        if self._temp_parameterization == "exp":
            return log_temp.exp()
        return log_temp.exp().exp().add(1.0).log()

    def forward(self, *inputs):
        """Forward method.

        Args:
            *inputs: Input to the module.

        Returns:
            tuple: ``(dist_target, ones)`` where ``dist_target`` is a
                ``GumbelSoftmax`` distribution and ``ones`` is a tensor of
                ones shaped like ``dist_target.logits``.

        """
        logits, log_temp = self._get_mean_and_log_std(*inputs)
        # Keep the smallest log temperature along dim -2.
        log_temp = torch.min(log_temp, dim=-2, keepdim=True).values
        logits, log_temp = logits.transpose(-1, -2), log_temp.transpose(-1, -2)
        temp = self._log_temp_to_temperature(self._clamp_log_temperature(log_temp))

        dist_target = GumbelSoftmax(temperature=temp.squeeze(-2), logits=logits)
        return dist_target, torch.ones_like(dist_target.logits)


class SingleTargetSixHeadedModule(MLPSingleTargetBaseModuleDoCausal):
    """Six-headed MLP that outputs three interleaved head pairs.

    The network is one ``MultiHeadedMLPModule`` with a shared trunk, wrapped
    in ``nn.DataParallel``. Its six heads have output sizes
    ``[A, 1, A, 1, A, 1]``, where ``A`` is ``self._action_dim`` as set by the
    base class (presumably ``output_dim``; confirm in the base class).

    Even-indexed heads get zero-initialised biases. Odd-indexed heads get
    biases initialised to the base class's stored ``_init_temp`` value.
    Judging by that initialisation, the odd heads are presumably
    log-temperature heads.

    Args:
        input_dim (int): Input dimension of the model.
        output_dim (int): Output dimension of the model, forwarded to the
            base class.
        hidden_sizes (list[int]): Sizes of the shared hidden layers. For
            example, (32, 32) means two hidden layers of 32 units each.
        hidden_nonlinearity (callable or type): Activation for the hidden
            layers. Either a torch.nn.Module class or a callable.
        hidden_w_init (callable): Initializer for hidden-layer weights.
        hidden_b_init (callable): Initializer for hidden-layer biases.
        output_nonlinearity (callable): Activation for the output layers.
            None keeps them linear.
        output_w_init (callable): Initializer for output-layer weights.
        output_b_init (callable): Forwarded to the base class only. This
            network's output biases are set explicitly (see above).
        layer_normalization (bool): Whether to use layer normalization.
        init_temp (float): Initial temperature, forwarded to the base class.
        min_temp (float): Lower temperature bound, forwarded to the base
            class. None disables the bound.
        max_temp (float): Upper temperature bound, forwarded to the base
            class. None disables the bound.
        learn_temp (bool): Whether the temperature is learned, forwarded to
            the base class.
        temp_parameterization (str): Temperature parameterization ("exp" or
            another value handled by the base class), forwarded to the base
            class.
        batch_size (int): Stored as ``self.batch_size``. It is not otherwise
            used in this class.

    """

    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_sizes=(32, 32),
        hidden_nonlinearity=torch.nn.ReLU,
        hidden_w_init=nn.init.xavier_uniform_,
        hidden_b_init=nn.init.zeros_,
        output_nonlinearity=None,
        output_w_init=nn.init.xavier_uniform_,
        output_b_init=nn.init.zeros_,
        layer_normalization=False,
        init_temp=1.0,
        min_temp=1e-6,
        max_temp=None,
        learn_temp=True,
        temp_parameterization="exp",
        batch_size=1,
    ):
        super().__init__(
            input_dim=input_dim,
            output_dim=output_dim,
            hidden_sizes=hidden_sizes,
            hidden_nonlinearity=hidden_nonlinearity,
            hidden_w_init=hidden_w_init,
            hidden_b_init=hidden_b_init,
            output_nonlinearity=output_nonlinearity,
            output_w_init=output_w_init,
            output_b_init=output_b_init,
            layer_normalization=layer_normalization,
            init_temp=init_temp,
            min_temp=min_temp,
            max_temp=max_temp,
            learn_temp=learn_temp,
            temp_parameterization=temp_parameterization,
        )
        self.batch_size = batch_size

        def _temp_bias_init(bias):
            # Temperature-head biases start at the base class's stored
            # _init_temp value. This runs once, during construction.
            return nn.init.constant_(bias, self._init_temp.item())

        # NOTE: nn.DataParallel prefixes state_dict keys with "module.".
        # The wrapper is kept so existing checkpoints still load.
        self._shared_mean_log_std_network = nn.DataParallel(
            MultiHeadedMLPModule(
                n_heads=6,
                input_dim=self._input_dim,
                # Three (A-sized head, 1-sized temperature head) pairs.
                output_dims=[self._action_dim, 1] * 3,
                hidden_sizes=self._hidden_sizes,
                hidden_nonlinearity=self._hidden_nonlinearity,
                hidden_w_init=self._hidden_w_init,
                hidden_b_init=self._hidden_b_init,
                output_nonlinearities=self._output_nonlinearity,
                output_w_inits=self._output_w_init,
                output_b_inits=[nn.init.zeros_, _temp_bias_init] * 3,
                layer_normalization=self._layer_normalization,
            )
        )

    def _get_mean_and_log_std(self, *inputs):
        """Return the outputs of all six heads of the shared network."""
        return self._shared_mean_log_std_network(*inputs)


class SingleTargetTwoHeadedModule(MLPSingleTargetNoValueBaseModuleDoCausal):
    """Two-headed MLP that outputs logits and log-temperatures.

    The network is one ``MultiHeadedMLPModule`` with a shared trunk and two
    heads, both of size ``self._action_dim``. That attribute is set by the
    base class (presumably ``output_dim``; confirm in the base class).

    The first head's bias is initialised to ``init_mean``. The second head's
    bias is initialised to the base class's stored ``_init_temp`` value.

    Args:
        input_dim (int): Input dimension of the model.
        output_dim (int): Output dimension of the model, forwarded to the
            base class.
        hidden_sizes (list[int]): Sizes of the shared hidden layers. For
            example, (32, 32) means two hidden layers of 32 units each.
        hidden_nonlinearity (callable): Activation for the hidden layers.
        hidden_w_init (callable): Initializer for hidden-layer weights.
        hidden_b_init (callable): Initializer for hidden-layer biases.
        output_nonlinearity (callable): Activation for the output layers.
            None keeps them linear.
        output_w_init (callable): Initializer for output-layer weights.
        output_b_init (callable): Forwarded to the base class only. This
            network's output biases are set explicitly (see above).
        learn_std (bool): Forwarded to the base class.
        init_std (float): Forwarded to the base class.
        min_std (float): Forwarded to the base class.
        max_std (float): Forwarded to the base class.
        std_hidden_sizes (list[int]): Forwarded to the base class.
        std_hidden_nonlinearity (callable): Forwarded to the base class.
        std_hidden_w_init (callable): Forwarded to the base class.
        std_hidden_b_init (callable): Forwarded to the base class.
        std_output_nonlinearity (callable): Forwarded to the base class.
        std_output_w_init (callable): Forwarded to the base class.
        std_parameterization (str): Forwarded to the base class.
        layer_normalization (bool): Whether to use layer normalization.
        normal_distribution_cls (type): Forwarded to the base class. Defaults
            to ``torch.distributions.Normal``.
        init_mean (float): Constant used to initialise the first head's bias.
            The default of 0.0 is equivalent to zero initialisation.
        init_temp (float): Initial temperature, forwarded to the base class.
        min_temp (float): Lower temperature bound, forwarded to the base
            class. None disables the bound.
        max_temp (float): Upper temperature bound, forwarded to the base
            class. None disables the bound.
        learn_temp (bool): Whether the temperature is learned, forwarded to
            the base class.
        temp_parameterization (str): Temperature parameterization, forwarded
            to the base class.
        batch_size (int): Stored as ``self.batch_size``. It is not otherwise
            used in this class.

    """

    def __init__(
        self,
        input_dim,
        output_dim,
        hidden_sizes=(32, 32),
        hidden_nonlinearity=torch.tanh,
        hidden_w_init=nn.init.xavier_uniform_,
        hidden_b_init=nn.init.zeros_,
        output_nonlinearity=None,
        output_w_init=nn.init.xavier_uniform_,
        output_b_init=nn.init.zeros_,
        learn_std=True,
        init_std=1.0,
        min_std=1e-6,
        max_std=None,
        std_hidden_sizes=(32, 32),
        std_hidden_nonlinearity=torch.tanh,
        std_hidden_w_init=nn.init.xavier_uniform_,
        std_hidden_b_init=nn.init.zeros_,
        std_output_nonlinearity=None,
        std_output_w_init=nn.init.xavier_uniform_,
        std_parameterization="exp",
        layer_normalization=False,
        normal_distribution_cls=Normal,
        init_mean=0.0,
        init_temp=1.0,
        min_temp=1e-6,
        max_temp=None,
        learn_temp=True,
        temp_parameterization="exp",
        batch_size=1,
    ):
        super().__init__(
            input_dim=input_dim,
            output_dim=output_dim,
            hidden_sizes=hidden_sizes,
            hidden_nonlinearity=hidden_nonlinearity,
            hidden_w_init=hidden_w_init,
            hidden_b_init=hidden_b_init,
            output_nonlinearity=output_nonlinearity,
            output_w_init=output_w_init,
            output_b_init=output_b_init,
            learn_std=learn_std,
            init_std=init_std,
            min_std=min_std,
            max_std=max_std,
            std_hidden_sizes=std_hidden_sizes,
            std_hidden_nonlinearity=std_hidden_nonlinearity,
            std_hidden_w_init=std_hidden_w_init,
            std_hidden_b_init=std_hidden_b_init,
            std_output_nonlinearity=std_output_nonlinearity,
            std_output_w_init=std_output_w_init,
            std_parameterization=std_parameterization,
            layer_normalization=layer_normalization,
            normal_distribution_cls=normal_distribution_cls,
            init_temp=init_temp,
            min_temp=min_temp,
            max_temp=max_temp,
            learn_temp=learn_temp,
            temp_parameterization=temp_parameterization,
        )
        self.batch_size = batch_size

        def _logits_bias_init(bias):
            # Honour init_mean, which was previously accepted but ignored.
            # With the default of 0.0 this matches nn.init.zeros_.
            return nn.init.constant_(bias, init_mean)

        def _temp_bias_init(bias):
            # Temperature-head bias starts at the base class's stored
            # _init_temp value. This runs once, during construction.
            return nn.init.constant_(bias, self._init_temp.item())

        self._shared_mean_log_std_network = MultiHeadedMLPModule(
            n_heads=2,
            input_dim=self._input_dim,
            output_dims=[self._action_dim, self._action_dim],
            hidden_sizes=self._hidden_sizes,
            hidden_nonlinearity=self._hidden_nonlinearity,
            hidden_w_init=self._hidden_w_init,
            hidden_b_init=self._hidden_b_init,
            output_nonlinearities=self._output_nonlinearity,
            output_w_inits=self._output_w_init,
            output_b_inits=[_logits_bias_init, _temp_bias_init],
            layer_normalization=self._layer_normalization,
        )

    def _get_mean_and_log_std(self, *inputs):
        """Return ``(logits, log_temp)`` from the two heads of the network."""
        return self._shared_mean_log_std_network(*inputs)
