import torch
from torch.distributions import Distribution, Normal
from torch.distributions.transforms import TanhTransform
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions import constraints


class StableTanhNormal(Distribution):
    """
    Gaussian distribution squashed through tanh into [-1, 1].

    Samples are ``tanh(z)`` with ``z ~ Normal(loc, scale)``.
    ``log_prob`` sums the per-dimension log-densities over dim 1 with
    ``keepdim=True``; this presumably expects a ``(batch, dim)`` layout
    (TODO confirm with callers).

    :param loc: mean of the underlying Gaussian.
    :param scale: standard deviation of the underlying Gaussian (must be > 0).
    """

    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    # tanh maps the real line into [-1, 1].
    support = constraints.interval(-1.0, 1.0)
    # rsample() below is implemented via the reparameterized Normal.rsample.
    has_rsample = True

    def __init__(
        self,
        loc: torch.Tensor,
        scale: torch.Tensor,
    ):
        # loc/scale must exist before Distribution.__init__, which may
        # validate them against arg_constraints.
        self.loc = loc
        self.scale = scale
        super().__init__()
        self.normal = Normal(loc=loc, scale=scale)

    def sample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        """
        Draw ``tanh(z)`` without gradients flowing to ``loc``/``scale``.

        :param sample_shape: extra leading sample dimensions (default: none).
        :return: squashed samples of shape ``sample_shape + loc.shape``.
        """
        return torch.tanh(self.normal.sample(sample_shape))

    def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
        """
        Draw reparameterized ``tanh(z)`` samples (differentiable w.r.t. params).

        :param sample_shape: extra leading sample dimensions (default: none).
        :return: squashed samples of shape ``sample_shape + loc.shape``.
        """
        return torch.tanh(self.normal.rsample(sample_shape))

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Log-density of squashed values, summed over dim 1.

        :param value: points in [-1, 1]; they are clamped slightly inside the
            interval before the inverse tanh so the result stays finite.
        :return: ``log N(atanh(value)) - log(1 - value^2 + 1e-6)``, each term
            summed over dim 1 with ``keepdim=True``.
        """
        inv_value = TanhBijector.inverse(value)
        normal_log_prob = self.normal.log_prob(inv_value).sum(dim=1, keepdim=True)
        # Change-of-variables term log|d tanh/dz| = log(1 - tanh(z)^2).
        # clamp(min=0) is a no-op for |value| <= 1, but prevents NaN (log of a
        # negative number) for out-of-range inputs, which the inverse above
        # already clamps to the boundary.
        correction = torch.log((1 - value**2).clamp(min=0) + 1e-6).sum(
            dim=1, keepdim=True
        )
        log_prob: torch.Tensor = normal_log_prob - correction
        return log_prob

    @property
    def mean(self) -> torch.Tensor:
        """
        Return ``tanh(loc)``, the Gaussian mean pushed through tanh.

        Note: this is not the true expectation ``E[tanh(z)]``, which differs
        from ``tanh(E[z])`` whenever ``scale > 0``.
        """
        return torch.tanh(self.normal.mean)


class TanhBijector:
    """
    Bijective transformation of a probability distribution
    using a squashing function (tanh).

    :param epsilon: small value added inside the log in
        ``log_prob_correction`` to avoid NaN/-inf due to numerical imprecision.
    """

    def __init__(self, epsilon: float = 1e-6):
        super().__init__()
        self.epsilon = epsilon

    @staticmethod
    def forward(x: torch.Tensor) -> torch.Tensor:
        """
        Apply the squashing function.

        :param x: pre-squash values.
        :return: ``tanh(x)``.
        """
        return torch.tanh(x)

    @staticmethod
    def atanh(x: torch.Tensor) -> torch.Tensor:
        """
        Inverse of tanh, computed as ``0.5 * (log1p(x) - log1p(-x))``.

        Algebraically equal to ``0.5 * log((1 + x) / (1 - x))``. No clamping is
        done: the result is +/-inf at x = +/-1 and NaN outside [-1, 1]; use
        ``inverse`` for a clamped version.
        """
        return 0.5 * (x.log1p() - (-x).log1p())

    @staticmethod
    def inverse(y: torch.Tensor) -> torch.Tensor:
        """
        Inverse tanh with clamping.

        :param y: squashed values, nominally in [-1, 1]. Must be a
            floating-point tensor, since ``torch.finfo`` is applied to its dtype.
        :return: ``atanh(y)`` after clamping ``y`` to ``[-1 + eps, 1 - eps]``,
            where ``eps`` is the machine epsilon of ``y.dtype``; finite for
            finite inputs.
        """
        eps = torch.finfo(y.dtype).eps
        # Clip to the open interval so atanh never returns +/-inf (or NaN for
        # slightly out-of-range inputs).
        return TanhBijector.atanh(y.clamp(min=-1.0 + eps, max=1.0 - eps))

    def log_prob_correction(self, x: torch.Tensor) -> torch.Tensor:
        """
        Elementwise ``log(1 - tanh(x)^2 + epsilon)``.

        This is the log-derivative of tanh at the pre-squash value ``x``, i.e.
        the squash correction from the original SAC implementation;
        ``epsilon`` keeps the log finite when tanh(x) saturates at +/-1.

        :param x: pre-squash values.
        :return: correction term with the same shape as ``x``.
        """
        return torch.log(1.0 - torch.tanh(x) ** 2 + self.epsilon)


class SquashedTanhGaussian(TransformedDistribution):
    """
    Normal(loc, scale) pushed through a tanh transform.

    Built on ``TransformedDistribution`` with a single
    ``TanhTransform(cache_size=1)``, so ``sample``/``rsample`` come from the
    base class. ``log_prob`` clamps its input and sums over dim 1.

    :param loc: mean of the underlying Gaussian.
    :param scale: standard deviation of the underlying Gaussian (must be > 0).
    """

    def __init__(self, loc: torch.Tensor, scale: torch.Tensor):
        base_dist = Normal(loc=loc, scale=scale)
        # cache_size=1 lets TanhTransform return the cached pre-tanh input when
        # asked to invert the exact tensor object it last produced.
        transforms = [TanhTransform(cache_size=1)]
        # TransformedDistribution.__init__ stores the base as self.base_dist.
        super().__init__(base_dist, transforms)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Log-density of squashed values, summed over dim 1 (``keepdim=True``).

        :param value: points in [-1, 1]; clamped to ``[-1 + 1e-6, 1 - 1e-6]``
            so boundary values map to a finite pre-tanh point.
        :return: summed log-density with dim 1 kept as size 1.
        """
        # Fixed (dtype-independent) margin; in float16 1 - 1e-6 rounds to 1.0,
        # so presumably only float32/float64 are expected — TODO confirm.
        eps = 1e-6
        # NOTE(review): clamp returns a new tensor, so TanhTransform's
        # identity-based cache never hits here and atanh is always recomputed,
        # even for values just produced by rsample(). Confirm this is intended.
        value = value.clamp(min=-1.0 + eps, max=1.0 - eps)
        return super().log_prob(value).sum(dim=1, keepdim=True)

    @property
    def mean(self) -> torch.Tensor:
        """
        Push the base Gaussian's mean through the transforms, giving ``tanh(loc)``.

        Note: this is not the true expectation ``E[tanh(z)]``, which differs
        from ``tanh(E[z])`` whenever ``scale > 0``.
        """
        # TransformedDistribution has no ``loc`` attribute; read it from the
        # base Normal (whose mean is its loc).
        mu = self.base_dist.mean
        for transform in self.transforms:
            mu = transform(mu)
        return mu
