"""
Add custom distributions in addition to the existing ones
"""
import torch
from torch.distributions import Normal
from torch.distributions import Distribution
from torch.distributions import Independent
import numpy as np


class TorchDistributionWrapper(Distribution):
    """Adapter that forwards the ``Distribution`` API to a wrapped instance.

    Every property and method defined here reads from, or calls into,
    ``self.distribution``. ``Distribution.__init__`` is not invoked; shape
    and constraint information comes from the wrapped object instead.
    """

    def __init__(self, distribution: Distribution):
        """Store the distribution that all calls are delegated to.

        :param distribution: The distribution instance to wrap.
        """
        self.distribution = distribution

    @property
    def batch_shape(self):
        """``batch_shape`` of the wrapped distribution."""
        inner = self.distribution
        return inner.batch_shape

    @property
    def event_shape(self):
        """``event_shape`` of the wrapped distribution."""
        inner = self.distribution
        return inner.event_shape

    @property
    def arg_constraints(self):
        """``arg_constraints`` of the wrapped distribution."""
        inner = self.distribution
        return inner.arg_constraints

    @property
    def support(self):
        """``support`` of the wrapped distribution."""
        inner = self.distribution
        return inner.support

    @property
    def mean(self):
        """``mean`` of the wrapped distribution."""
        inner = self.distribution
        return inner.mean

    @property
    def variance(self):
        """``variance`` of the wrapped distribution."""
        inner = self.distribution
        return inner.variance

    @property
    def stddev(self):
        """``stddev`` of the wrapped distribution."""
        inner = self.distribution
        return inner.stddev

    def sample(self, sample_size=torch.Size()):
        """Call the wrapped ``sample``, passing ``sample_size`` as ``sample_shape``."""
        draw = self.distribution.sample
        return draw(sample_shape=sample_size)

    def rsample(self, sample_size=torch.Size()):
        """Call the wrapped ``rsample``, passing ``sample_size`` as ``sample_shape``."""
        draw = self.distribution.rsample
        return draw(sample_shape=sample_size)

    def log_prob(self, value):
        """Return the wrapped distribution's ``log_prob`` of ``value``."""
        inner = self.distribution
        return inner.log_prob(value)

    def cdf(self, value):
        """Return the wrapped distribution's ``cdf`` at ``value``."""
        inner = self.distribution
        return inner.cdf(value)

    def icdf(self, value):
        """Return the wrapped distribution's ``icdf`` at ``value``."""
        inner = self.distribution
        return inner.icdf(value)

    def enumerate_support(self, expand=True):
        """Return the wrapped distribution's ``enumerate_support(expand=expand)``."""
        inner = self.distribution
        return inner.enumerate_support(expand=expand)

    def entropy(self):
        """Return the wrapped distribution's ``entropy()``."""
        inner = self.distribution
        return inner.entropy()

    def perplexity(self):
        """Return the wrapped distribution's ``perplexity()``."""
        inner = self.distribution
        return inner.perplexity()

    def __repr__(self):
        """Return the wrapped distribution's repr prefixed with ``"Wrapped "``."""
        inner_repr = repr(self.distribution)
        return "Wrapped " + inner_repr


class MultivariateDiagonalNormal(TorchDistributionWrapper):
    """Wrapper around ``Independent(Normal(loc, scale_diag), ...)``.

    The element-wise ``Normal`` is passed through ``Independent`` so that
    ``reinterpreted_batch_ndims`` of its batch dimensions are treated as
    event dimensions.
    """

    # Imported in the class body, so ``constraints`` is also a class attribute.
    from torch.distributions import constraints

    arg_constraints = dict(loc=constraints.real, scale=constraints.positive)

    def __init__(self, loc, scale_diag, reinterpreted_batch_ndims=1):
        """
        :param loc: Location passed to ``Normal`` as ``loc``.
        :param scale_diag: Per-element scale passed to ``Normal`` as ``scale``.
        :param reinterpreted_batch_ndims: Forwarded unchanged to ``Independent``.
        """
        elementwise_normal = Normal(loc=loc, scale=scale_diag)
        diagonal_normal = Independent(
            elementwise_normal,
            reinterpreted_batch_ndims=reinterpreted_batch_ndims,
        )
        super().__init__(diagonal_normal)


class TanhNormal(TorchDistributionWrapper):
    """
    Represent distribution of X where
        X ~ tanh(Z)
        Z ~ N(mean, std)
    Note: this is not very numerically stable.

    The underlying diagonal normal is stored in ``self.normal``. The parent
    wrapper's ``distribution`` attribute is never set here, so inherited
    delegating members that this class does not override (``variance``,
    ``entropy``, ``cdf``, ...) are not usable on a ``TanhNormal``.
    """

    def __init__(self, normal_mean, normal_std, epsilon=1e-6):
        """
        :param normal_mean: Mean of the normal distribution
        :param normal_std: Std of the normal distribution
        :param epsilon: Numerical stability epsilon when computing log-prob;
            values are clamped to ``[-(1 - epsilon), 1 - epsilon]`` before
            the inverse tanh is taken.
        """
        self.normal_mean = normal_mean
        self.normal_std = normal_std
        self.normal = MultivariateDiagonalNormal(normal_mean, normal_std)
        self.epsilon = epsilon

    @property
    def batch_shape(self):
        """Batch shape of the underlying normal (tanh is element-wise)."""
        return self.normal.batch_shape

    @property
    def event_shape(self):
        """Event shape of the underlying normal (tanh is element-wise)."""
        return self.normal.event_shape

    def __repr__(self):
        # The inherited __repr__ reads ``self.distribution``, which this class
        # never sets, so build the representation from the stored parameters.
        return "TanhNormal(normal_mean: {}, normal_std: {})".format(
            self.normal_mean, self.normal_std
        )

    def sample_n(self, n):
        """
        Draw ``n`` samples; the result has one extra leading dimension of
        size ``n``.
        :param n: Number of samples to draw.
        :return: ``tanh`` of ``n`` draws from the underlying normal.
        """
        # Equivalent to ``Distribution.sample_n(n)`` without going through it.
        z = self.normal.sample(torch.Size((n,)))

        return torch.tanh(z)

    def _log_prob_from_pre_tanh(self, pre_tanh_value: torch.Tensor):
        """
        Log-density of ``tanh(pre_tanh_value)`` under this distribution.

        The change-of-variables term uses a formula that is mathematically
        equivalent to log(1 - tanh(x)^2).
        Derivation:
        log(1 - tanh(x)^2)
         = log(sech(x)^2)
         = 2 * log(sech(x))
         = 2 * log(2e^-x / (e^-2x + 1))
         = 2 * (log(2) - x - log(e^-2x + 1))
         = 2 * (log(2) - x - softplus(-2x))
        :param pre_tanh_value: arctanh of the squashed value
        :return: Normal log-prob minus the summed log-Jacobian of tanh.
        """
        log_prob = self.normal.log_prob(pre_tanh_value)
        # A Python scalar keeps the dtype/device of ``pre_tanh_value`` instead
        # of injecting a float32 constant.
        log_two = float(np.log(2.0))
        log_det_jacobian = 2.0 * (
            log_two
            - pre_tanh_value
            - torch.nn.functional.softplus(-2.0 * pre_tanh_value)
        )
        # Reduce over the last dimension: that is the dimension the
        # underlying ``Independent`` normal sums over in ``log_prob`` above.
        return log_prob - log_det_jacobian.sum(dim=-1)

    def log_prob(self, value):
        """
        Log-density of an already squashed ``value``.
        :param value: Values in (-1, 1); clamped away from the boundaries.
        :return: Log-probability with the last dimension reduced.
        """
        # errors or instability at values near 1
        bound = 1.0 - self.epsilon
        value = torch.clamp(value, -bound, bound)
        # Inverse tanh: 0.5 * log((1 + v) / (1 - v))
        pre_tanh_value = 0.5 * (torch.log1p(value) - torch.log1p(-value))
        return self._log_prob_from_pre_tanh(pre_tanh_value)

    def rsample_with_pretanh(self, sample_size=torch.Size()):
        """
        Draw a reparameterized sample and also return its pre-tanh value.

        Standard-normal noise is sampled and then scaled by ``normal_std``
        and shifted by ``normal_mean``.
        :param sample_size: Extra leading sample dimensions (none by default).
        :return: Tuple ``(tanh(z), z)``.
        """
        device = self.normal_mean.device
        dtype = self.normal_mean.dtype
        noise = MultivariateDiagonalNormal(
            torch.zeros_like(self.normal_mean, dtype=dtype, device=device),
            torch.ones_like(self.normal_std, dtype=dtype, device=device),
        ).sample(sample_size)
        z = self.normal_mean + self.normal_std * noise
        return torch.tanh(z), z

    def sample(self, sample_size=torch.Size()):
        """
        Gradients will and should *not* pass through this operation.
        :param sample_size: Extra leading sample dimensions (none by default).
        """
        value, _ = self.rsample_with_pretanh(sample_size)
        return value.detach()

    def rsample(self, sample_size=torch.Size()):
        """
        Sampling in the reparameterization case.
        :param sample_size: Extra leading sample dimensions (none by default).
        """
        value, _ = self.rsample_with_pretanh(sample_size)
        return value

    @property
    def mean(self):
        """``tanh`` of the underlying normal's mean."""
        return torch.tanh(self.normal_mean)

    @property
    def stddev(self):
        """Standard deviation of the underlying (pre-tanh) normal."""
        return self.normal_std


class TanhNormalCherry(Distribution):
    """
    Distribution of ``tanh(Z)`` where ``Z ~ Normal(loc, scale)`` element-wise.

    ``log_prob`` reduces over the last dimension.
    """

    # No constrained parameters are stored on this object itself; declaring an
    # empty mapping overrides the base class's unimplemented property.
    arg_constraints = {}

    def __init__(self, loc, scale):
        """
        :param loc: Location of the underlying normal.
        :param scale: Scale of the underlying normal.
        """
        super().__init__()
        self.normal = Normal(loc, scale)
        self.bijector = TanhBijector()

    def sample(self, sample_shape=torch.Size()):
        """
        Draw a squashed sample without gradient flow (uses ``Normal.sample``).
        :param sample_shape: Extra leading sample dimensions (none by default).
        """
        return torch.tanh(self.normal.sample(sample_shape))

    def rsample(self, sample_shape=torch.Size()):
        """
        Draw a reparameterized squashed sample (uses ``Normal.rsample``).
        :param sample_shape: Extra leading sample dimensions (none by default).
        """
        return torch.tanh(self.normal.rsample(sample_shape))

    def log_prob(self, value):
        """
        Log-density of an already squashed ``value``.
        :param value: Values in [-1, 1]; the bijector clips them before
            inverting tanh.
        :return: Log-probability summed over the last dimension.
        """
        # Recover the pre-tanh value z with tanh(z) == value (after clipping).
        inv_value = self.bijector.inverse(value)
        # The correction is log(1 - tanh(z)^2), so it must be evaluated at the
        # pre-tanh value z; passing the squashed value would apply tanh twice.
        log_probs_nb_action = self.normal.log_prob(
            inv_value
        ) - self.bijector.log_prob_correction(inv_value)
        # Reduce over the last dimension so inputs of any rank are handled.
        return torch.sum(log_probs_nb_action, dim=-1)

    @property
    def mean(self):
        """``tanh`` of the underlying normal's mean."""
        return torch.tanh(self.normal.mean)


class TanhBijector:
    """
    Tanh squashing transform together with its inverse and the
    log-probability correction term that goes with it.
    :param epsilon: small value to avoid NaN due to numerical imprecision.
    """

    def __init__(self, epsilon: float = 1e-6):
        super().__init__()
        self.epsilon = epsilon

    @staticmethod
    def forward(x: torch.Tensor) -> torch.Tensor:
        """Squash ``x`` with tanh."""
        squashed = torch.tanh(x)
        return squashed

    @staticmethod
    def atanh(x: torch.Tensor) -> torch.Tensor:
        """
        Inverse of tanh, computed with ``log1p`` as
        0.5 * (log(1 + x) - log(1 - x)).
        No clipping is applied here.
        """
        log_one_plus = torch.log1p(x)
        log_one_minus = torch.log1p(-x)
        return 0.5 * (log_one_plus - log_one_minus)

    @staticmethod
    def inverse(y: torch.Tensor) -> torch.Tensor:
        """
        Inverse tanh with the input clipped away from -1 and 1.
        :param y: squashed values
        :return: ``atanh`` of the clipped input
        """
        # Machine epsilon of the input dtype keeps the clipped value
        # strictly inside (-1, 1), avoiding NaN/inf from atanh.
        margin = torch.finfo(y.dtype).eps
        clipped = torch.clamp(y, min=-1.0 + margin, max=1.0 - margin)
        return TanhBijector.atanh(clipped)

    def log_prob_correction(self, x: torch.Tensor) -> torch.Tensor:
        """
        Squash correction (from original SAC implementation):
        log(1 - tanh(x)^2 + epsilon).
        """
        squashed = torch.tanh(x)
        return torch.log(1.0 - squashed ** 2 + self.epsilon)
