"""AdaptiveTanhGaussianPolicy."""

import akro
import numpy as np
import torch
from garage.torch import global_device
from garage.torch.distributions import TanhNormal
from garage.torch.policies.stochastic_policy import StochasticPolicy
from torch import nn

from src.modules import (
    AlternateAttention,
    GaussianMLPTwoHeadedModuleDoCausal,
    SingleTargetSixHeadedModule,
    SingleTargetTwoHeadedModule,
)


class AdaptiveTransformerTanhGaussianPolicy(StochasticPolicy):
    """Attention encoder, optional pooling, and a multi-headed emitter.

    A policy that takes as input entire histories, encodes them with an
    ``AlternateAttention`` module, optionally pools the encoding over
    dimension ``-3``, and feeds the result to an emitter that produces the
    action distribution(s). Inputs to the network should be of the
    shape (batch_dim, history_length, obs_dim)

    Which emitter is built depends on ``is_single_target`` / ``no_value``:

    * ``is_single_target and no_value``: ``SingleTargetTwoHeadedModule``
    * ``is_single_target and not no_value``: ``SingleTargetSixHeadedModule``
    * otherwise: ``GaussianMLPTwoHeadedModuleDoCausal``

    Args:
        env_spec (garage.envs.env_spec.EnvSpec): Environment specification.
        dropout (float): Dropout value forwarded to the attention encoder.
        widening_factor (int): The encoder's feedforward dimension is
            ``widening_factor * embedding_dim``.
        pooling (str): Pooling applied to the encoder output in
            :meth:`forward`: ``"max"`` or ``"sum"``. A falsy value (e.g.
            ``None``) disables pooling; any other value makes
            :meth:`forward` raise ``NotImplementedError``.
        embedding_dim (int): Dimension passed to the encoder (``dim``) and
            used as the emitter's input dimension.
        n_attention_layers (int): Number of layers of the encoder.
        n_attention_heads (int): Number of attention heads of the encoder.
        emitter_sizes (list[int]): Output dimension of dense layer(s) for
            the MLP for emitter. For example, (32, 32) means the MLP consists
            of two hidden layers, each with 32 hidden units.
        emitter_nonlinearity (callable): Activation function for intermediate
            dense layer(s) of emitter.
        emitter_output_nonlinearity (callable): Activation function for emitter
            output dense layer.
        hidden_w_init (callable): Initializer function for the weight
            of intermediate dense layer(s). The function should return a
            torch.Tensor.
        hidden_b_init (callable): Initializer function for the bias
            of intermediate dense layer(s). The function should return a
            torch.Tensor.
        output_w_init (callable): Initializer function for the weight
            of output dense layer(s). The function should return a
            torch.Tensor.
        output_b_init (callable): Initializer function for the bias
            of output dense layer(s). The function should return a
            torch.Tensor.
        init_mean (float): Forwarded to the emitter as ``init_mean`` (not
            used by the six-headed emitter).
        init_std (float): Initial value for std.
            (plain value - not log or exponentiated).
        min_std (float): If not None, the std is at least the value of min_std,
            to avoid numerical issues (plain value - not log or exponentiated).
        max_std (float): If not None, the std is at most the value of max_std,
            to avoid numerical issues (plain value - not log or exponentiated).
        std_parameterization (str): How the std should be parametrized. There
            are two options:
            - exp: the logarithm of the std will be stored, and applied a
               exponential transformation
            - softplus: the std will be computed as log(1+exp(x))
        layer_normalization (bool): Bool for using layer normalization or not.
        batch_size (int): Forwarded to the emitter as ``batch_size`` and also
            used to size its ``output_dim`` (doubled for the Gaussian emitter
            when ``no_value`` is False).
        device (str or torch.device): Device that :meth:`get_actions` moves
            observations and mask to before calling :meth:`forward`.
        is_single_target (bool): Selects a single-target emitter (see above).
        no_value (bool): Selects the emitter variant (see above) and is
            forwarded to the Gaussian emitter.

    """

    def __init__(
        self,
        env_spec,
        dropout=0.1,
        widening_factor=4,
        pooling="max",
        embedding_dim=16,
        n_attention_layers=8,
        n_attention_heads=8,
        emitter_sizes=(32, 32),
        emitter_nonlinearity=nn.ReLU,
        emitter_output_nonlinearity=None,
        hidden_w_init=nn.init.xavier_uniform_,
        hidden_b_init=nn.init.zeros_,
        output_w_init=nn.init.xavier_uniform_,
        output_b_init=nn.init.zeros_,
        init_mean=0.0,
        init_std=1.0,
        min_std=np.exp(-20.0),
        max_std=np.exp(2.0),
        std_parameterization="exp",
        layer_normalization=False,
        batch_size=1,
        device="cpu",
        is_single_target=False,
        no_value=False,
    ):
        super().__init__(env_spec, name="AdaptiveTanhGaussianPolicy")

        self._pooling = pooling
        self.device = device
        self.batch_size = batch_size
        self.is_single_target = is_single_target
        self.no_value = no_value

        self._encoder = AlternateAttention(
            dim=embedding_dim,
            feedforward_dim=widening_factor * embedding_dim,
            dropout=dropout,
            n_layers=n_attention_layers,
            num_heads=n_attention_heads,
        )
        self._encoder = nn.DataParallel(self._encoder)
        if is_single_target and no_value:
            self._emitter = SingleTargetTwoHeadedModule(
                input_dim=embedding_dim,
                output_dim=self.batch_size,
                hidden_sizes=emitter_sizes,
                hidden_nonlinearity=emitter_nonlinearity,
                hidden_w_init=hidden_w_init,
                hidden_b_init=hidden_b_init,
                output_nonlinearity=emitter_output_nonlinearity,
                output_w_init=output_w_init,
                output_b_init=output_b_init,
                init_mean=init_mean,
                init_std=init_std,
                min_std=min_std,
                max_std=max_std,
                std_parameterization=std_parameterization,
                layer_normalization=layer_normalization,
                normal_distribution_cls=TanhNormal,
                batch_size=batch_size,
            )
        elif is_single_target:
            # Single target with value heads; this emitter takes no
            # mean/std configuration.
            self._emitter = SingleTargetSixHeadedModule(
                input_dim=embedding_dim,
                output_dim=self.batch_size,
                hidden_sizes=emitter_sizes,
                hidden_nonlinearity=emitter_nonlinearity,
                hidden_w_init=hidden_w_init,
                hidden_b_init=hidden_b_init,
                output_nonlinearity=emitter_output_nonlinearity,
                output_w_init=output_w_init,
                output_b_init=output_b_init,
                layer_normalization=layer_normalization,
                batch_size=batch_size,
            )
        else:
            self._emitter = GaussianMLPTwoHeadedModuleDoCausal(
                input_dim=embedding_dim,
                output_dim=2 * self.batch_size if not no_value else self.batch_size,
                hidden_sizes=emitter_sizes,
                hidden_nonlinearity=emitter_nonlinearity,
                hidden_w_init=hidden_w_init,
                hidden_b_init=hidden_b_init,
                output_nonlinearity=emitter_output_nonlinearity,
                output_w_init=output_w_init,
                output_b_init=output_b_init,
                init_mean=init_mean,
                init_std=init_std,
                min_std=min_std,
                max_std=max_std,
                std_parameterization=std_parameterization,
                layer_normalization=layer_normalization,
                normal_distribution_cls=TanhNormal,
                batch_size=batch_size,
                no_value=no_value,
            )

    def get_actions(self, observations, mask=None):
        r"""Get actions given observations.

        Args:
            observations (np.ndarray or torch.Tensor or list): Observations
                from the environment.
                Shape is :math:`batch_dim \bullet env_spec.observation_space`.
            mask (np.ndarray or torch.Tensor or None): Optional mask that is
                forwarded to :meth:`forward`. ``None`` means "no masking".

        Returns:
            tuple:
                * torch.Tensor: Sampled actions, on the CPU.
                * dict[str, torch.Tensor]: The ``info`` dict produced by
                    :meth:`forward`, detached and moved to the CPU.

        """
        if not isinstance(observations[0], (np.ndarray, torch.Tensor)):
            observations = self._env_spec.observation_space.flatten_n(observations)

        # frequently users like to pass lists of torch tensors or lists of
        # numpy arrays. This handles those conversions.
        if isinstance(observations, list):
            if isinstance(observations[0], np.ndarray):
                observations = np.stack(observations)
            elif isinstance(observations[0], torch.Tensor):
                observations = torch.stack(observations)

        if isinstance(self._env_spec.observation_space, akro.Image) and len(
            observations.shape
        ) < len(self._env_spec.observation_space.shape):
            observations = self._env_spec.observation_space.unflatten_n(observations)
        with torch.no_grad():
            if not isinstance(observations, torch.Tensor):
                observations = torch.as_tensor(observations).float().to(global_device())
            # Convert the mask independently of the observations' type: a
            # NumPy mask passed alongside tensor observations must still
            # become a tensor, and a missing mask must stay None.
            if mask is not None:
                mask = torch.as_tensor(mask).float().to(global_device())
                mask = mask.to(self.device)
            observations = observations.to(self.device)

            if self.is_single_target and self.no_value:
                dist, dist_value, info = self.forward(observations, mask)
                targets = dist.rsample().detach().cpu()
                # In this mode the second emitter output is used directly
                # as a tensor rather than sampled from.
                values = dist_value.cpu()
                actions = torch.cat([targets, values], dim=-1).unsqueeze(-2)
            elif self.is_single_target:
                dist, dist_value, dist_obs, info = self.forward(observations, mask)
                targets = dist.rsample().detach().cpu()
                values = dist_value.rsample().detach().cpu()
                obs = dist_obs.rsample().detach().cpu()
                # The (2, 1) matrices reduce a trailing dimension of size 2:
                # [-1, 1] yields (second - first), [0, 1] picks the second.
                # `values` must be computed before `targets` is overwritten,
                # because it uses the raw target sample.
                values = targets * torch.matmul(
                    values, torch.tensor([[-1.0], [1.0]]).to(values.device)
                )
                targets = targets * torch.matmul(
                    obs, torch.tensor([[0.0], [1.0]]).to(targets.device)
                )
                actions = torch.cat([targets, values], dim=-1).unsqueeze(-2)
            else:
                dist, info = self.forward(observations, mask)
                actions = dist.sample().detach().cpu()
        return actions, {k: v.detach().cpu() for (k, v) in info.items()}

    def forward(self, observations, mask=None):
        """Compute the action distributions from the observations.

        Args:
            observations (torch.Tensor): Batch of observations on default
                torch device.
            mask (torch.Tensor): a mask to account for 0-padded inputs. It
                must be broadcastable to the encoder output; non-zero / True
                entries mark positions that are kept. ``None`` disables
                masking.

        Returns:
            tuple: Depends on the policy mode.
                * ``is_single_target and not no_value``:
                  ``(dist_target, dist_value, dist_obs, info)`` where
                  ``info`` holds ``logits`` / ``log_temp`` for each of the
                  three distributions.
                * ``is_single_target and no_value``:
                  ``(dist, dist_value, info)`` where ``info`` holds
                  ``logits`` and ``log_temp`` of ``dist``.
                * otherwise: ``(dist, info)`` where ``info`` holds ``mean``
                  and ``log_std`` of ``dist``.

        Raises:
            NotImplementedError: If pooling is set to something other than
                ``"max"`` or ``"sum"`` (falsy values skip pooling).

        """
        encoding = self._encoder(observations)
        if self._pooling:
            if self._pooling == "max":
                if mask is not None:
                    # Masked-out positions are set to the global minimum so
                    # they can never exceed a kept position in the max.
                    # `.bool()` is required because `~` is undefined for
                    # floating-point masks (which get_actions produces).
                    min_value = torch.min(encoding).detach()
                    keep = mask.bool().expand(*encoding.shape)
                    encoding = encoding.masked_fill(~keep, min_value)
                encoding = torch.max(encoding, -3).values
            elif self._pooling == "sum":
                if mask is not None:
                    # Zero out masked positions so they do not contribute.
                    encoding = encoding * mask
                encoding = torch.sum(encoding, -3)
            else:
                raise NotImplementedError(f"{self._pooling} not implemented")
        if self.is_single_target and not self.no_value:
            dist_target, dist_value, dist_obs = self._emitter(encoding)
            info = dict(
                logits=dist_target.logits.clone(),
                log_temp=dist_target.temperature.log().clone(),
                logits_value=dist_value.logits.clone(),
                log_temp_value=dist_value.temperature.log().clone(),
                logits_obs=dist_obs.logits.clone(),
                log_temp_obs=dist_obs.temperature.log().clone(),
            )
            return dist_target, dist_value, dist_obs, info
        elif self.is_single_target and self.no_value:
            dist, dist_value = self._emitter(encoding)
            info = dict(
                logits=dist.logits.clone(),
                log_temp=dist.temperature.log().clone(),
            )
            return dist, dist_value, info
        else:
            dist = self._emitter(encoding)
            info = dict(
                mean=dist.mean.clone(),
                log_std=(dist.variance.sqrt()).log().clone(),
            )
            return dist, info
