"""RandomCausalDoPolicy."""
import akro
import numpy as np
import torch

from garage.torch import global_device
from garage.torch.distributions import TanhNormal
from garage.torch.policies.stochastic_policy import StochasticPolicy
from torch import nn


class RandomCausalDoPolicy(StochasticPolicy):
    """Policy that ignores observations and samples actions uniformly.

    Actions are drawn from a uniform distribution over the ``low``/``high``
    bounds of the environment's action space, independently for every
    observation in the batch.

    NOTE(review): the name passed to the base class is
    "AdaptiveTanhGaussianPolicy", which looks like a copy-paste leftover. It
    is kept unchanged in case anything keys off it; confirm and rename.

    Args:
        env_spec (EnvSpec): Environment specification. The action space is
            assumed to expose finite ``low`` and ``high`` bounds (e.g. an
            ``akro.Box``); this is checked when actions are requested.
        **kwargs: Accepted and ignored.

    """

    def __init__(
        self,
        env_spec,
        **kwargs
    ):
        super().__init__(env_spec, name="AdaptiveTanhGaussianPolicy")

        self._obs_dim = env_spec.observation_space.flat_dim
        self._action_dim = env_spec.action_space.flat_dim
        # Kept so forward() can read the action bounds on every call.
        self._action_space = env_spec.action_space

    def forward(self, observations):
        """Build a uniform action distribution for a batch of observations.

        The observation values themselves are not used; only the batch size
        (``len(observations)``) determines the distribution's batch shape.

        Args:
            observations (np.ndarray or torch.Tensor): Batch of observations
                whose first axis is the batch dimension.

        Returns:
            torch.distributions.Uniform: Distribution with batch shape
                ``(batch_dim,) + action_space.low.shape``.
            dict[str, torch.Tensor]: ``mean`` and ``log_std`` (log of the
                standard deviation) of that distribution.

        Raises:
            ValueError: If any action bound is not finite or any
                ``low >= high``.

        """
        device = global_device()
        low = torch.as_tensor(self._action_space.low,
                              dtype=torch.float32,
                              device=device)
        high = torch.as_tensor(self._action_space.high,
                               dtype=torch.float32,
                               device=device)
        # Uniform sampling is undefined for infinite or empty ranges.
        if not (torch.isfinite(low).all() and torch.isfinite(high).all()
                and (low < high).all()):
            raise ValueError(
                f'{type(self).__name__} requires an action space with finite '
                'bounds satisfying low < high.')

        batch_dim = len(observations)
        # Repeat the per-dimension bounds once per observation in the batch.
        low = low.expand(batch_dim, *low.shape)
        high = high.expand(batch_dim, *high.shape)
        dist = torch.distributions.Uniform(low, high)
        info = dict(mean=dist.mean, log_std=dist.stddev.log())
        return dist, info

    def get_actions(self, observations, mask=None):
        r"""Get actions given observations.

        Args:
            observations (np.ndarray or torch.Tensor): Observations from the
                environment. Shape is
                :math:`batch_dim \bullet env_spec.observation_space`.
            mask: Currently ignored.

        Returns:
            tuple:
                * torch.Tensor: Sampled actions, shape
                    :math:`batch_dim \bullet env_spec.action_space`.
                * dict:
                    * torch.Tensor: ``mean`` of the distribution.
                    * torch.Tensor: ``log_std``, the log of the
                        distribution's standard deviation.

        """
        # NOTE(review): `mask` is unused. If it was meant to restrict which
        # action dimensions are sampled/intervened on, that logic is missing.
        dist, info = self.forward(observations)
        return dist.sample().detach(), {k: v.detach() for (k, v) in info.items()}
