import math
import warnings

import torch
import torch.nn.functional as F
from torch import distributions as pyd
from torch.distributions.utils import _standard_normal

from lambda_ac.rl_types import Distribution


class Gaussian(Distribution):
    """Diagonal Gaussian parameterized by a mean and a log standard deviation.

    Args:
        mean: Location of the Gaussian.
        log_std: Log standard deviation. It is exponentiated and ``1e-8`` is
            added so the scale given to ``torch.distributions.Normal`` stays
            strictly positive.
    """

    def __init__(self, mean, log_std):
        # The 1e-8 floor keeps the scale positive even if exp() underflows.
        std = log_std.exp() + 1e-8

        self._pdf = pyd.Normal(mean, std)

    def rsample(self, n: int = 1) -> torch.Tensor:
        """Draw reparameterized sample(s).

        Args:
            n: Number of samples. With ``n == 1`` the result has the
                distribution's batch shape. Otherwise ``n`` samples are stacked
                along a new leading dimension, the same convention
                ``DiracDistribution.rsample`` uses.

        Returns:
            Sample tensor that is differentiable w.r.t. the mean and scale.
        """
        sample_shape = torch.Size() if n == 1 else torch.Size((n,))
        sample: torch.Tensor = self._pdf.rsample(sample_shape)  # type: ignore
        return sample

    def log_prob(self, actions: torch.Tensor) -> torch.Tensor:
        """Joint log-density of ``actions``.

        Components along the last dimension are treated as independent. Their
        log-densities are summed, and that dimension is kept with size 1.
        """
        log_prob = self._pdf.log_prob(actions)
        return log_prob.sum(dim=-1, keepdim=True)

    @property
    def mean(self):
        """Mean of the Gaussian."""
        return self._pdf.mean

    @property
    def log_std(self):
        """Log of the effective scale, i.e. ``log(exp(log_std) + 1e-8)``."""
        return torch.log(self._pdf.scale)

    @property
    def std(self):
        """Effective scale, ``exp(log_std) + 1e-8``."""
        return self._pdf.scale

    @property
    def entropy(self):
        """Elementwise entropy; not summed over the last dimension, unlike ``log_prob``."""
        return self._pdf.entropy()

    def get_pdf(self):
        """Return the underlying ``torch.distributions.Normal``."""
        return self._pdf


class BoundedGaussian(Distribution):
    """Diagonal Gaussian whose mean and scale are smoothly clamped.

    Args:
        mean: Unbounded mean. It is squashed to ``(-30, 30)`` by
            ``30 * tanh(mean / 30)``, which is close to the identity near zero.
        std: Raw, pre-activation scale. Two softplus steps soft-clamp it:
            ``2 - softplus(2 - std)`` bounds it above by 2, then
            ``1e-5 + softplus(. - 1e-5)`` keeps it above ``1e-5``. The final
            scale therefore lies in roughly ``(1e-5, 2.13)``, since
            ``softplus(2)`` is about 2.127.
            NOTE(review): the name suggests a standard deviation, but the value
            is transformed non-linearly. Confirm callers pass an unconstrained
            network output.
    """

    def __init__(self, mean, std):
        mean = 30 * torch.tanh(mean / 30)
        # Soft upper bound at 2 (softplus > 0, so the result is < 2).
        std = 2.0 - F.softplus(2.0 - std)
        # Soft lower bound at 1e-5 so the Normal scale stays strictly positive.
        std = 1e-5 + F.softplus(std - 1e-5)

        self._pdf = pyd.Normal(mean, std)

    def rsample(self, n: int = 1) -> torch.Tensor:
        """Draw reparameterized sample(s).

        Args:
            n: Number of samples. With ``n == 1`` the result has the batch
                shape. Otherwise ``n`` samples are stacked along a new leading
                dimension, matching ``DiracDistribution.rsample``.
        """
        sample_shape = torch.Size() if n == 1 else torch.Size((n,))
        sample: torch.Tensor = self._pdf.rsample(sample_shape)
        return sample

    def log_prob(self, actions: torch.Tensor) -> torch.Tensor:
        """Log-density summed over the last dimension, which is kept with size 1."""
        log_prob = self._pdf.log_prob(actions)
        return log_prob.sum(dim=-1, keepdim=True)

    @property
    def mean(self):
        """The soft-clamped mean."""
        return self._pdf.mean

    @property
    def log_std(self):
        """Log of the soft-clamped scale."""
        return torch.log(self._pdf.scale)

    @property
    def std(self):
        """The soft-clamped scale."""
        return self._pdf.scale

    def get_pdf(self):
        """Return the underlying ``torch.distributions.Normal``."""
        return self._pdf

    @property
    def entropy(self):
        """Elementwise entropy of the Gaussian (not summed)."""
        return self._pdf.entropy()


class TanhGaussian(Distribution):
    """Gaussian squashed through ``tanh`` so samples lie in ``(-1, 1)``.

    Args:
        mean: Mean of the underlying (pre-``tanh``) Gaussian.
        log_std: Log standard deviation of the underlying Gaussian. It is
            clamped to ``[-20, 2]`` before being exponentiated.
    """

    def __init__(self, mean, log_std):
        # Keep the scale within [exp(-20), exp(2)] for numerical stability.
        log_std = torch.clamp(log_std, min=-20, max=2)
        self._log_std = log_std
        std = log_std.exp()

        self._pdf = SquashedNormal(mean, std)

    def rsample(self, n: int = 1) -> torch.Tensor:
        """Draw reparameterized, ``tanh``-squashed sample(s).

        Args:
            n: Number of samples. With ``n == 1`` the result has the batch
                shape. Otherwise ``n`` samples are stacked along a new leading
                dimension, matching ``DiracDistribution.rsample``.
        """
        sample_shape = torch.Size() if n == 1 else torch.Size((n,))
        sample: torch.Tensor = self._pdf.rsample(sample_shape)  # type: ignore
        return sample

    def log_prob(self, actions: torch.Tensor) -> torch.Tensor:
        """Log-density of squashed ``actions``, summed over the last dimension.

        NOTE(review): ``atanh`` of exactly +/-1 is infinite, so actions on the
        boundary can yield a non-finite log-prob.
        """
        log_prob = self._pdf.log_prob(actions)
        return log_prob.sum(dim=-1, keepdim=True)

    @property
    def mean(self):
        """``tanh`` of the Gaussian mean.

        This is a squashed location and in general differs from ``E[tanh(X)]``.
        """
        return self._pdf.mean

    @property
    def log_std(self):
        """Log scale of the underlying (pre-``tanh``) Gaussian."""
        return torch.log(self._pdf.scale)

    @property
    def std(self):
        """Scale of the underlying (pre-``tanh``) Gaussian."""
        return self._pdf.scale

    def get_pdf(self):
        """Return the underlying ``SquashedNormal``."""
        return self._pdf

    @property
    def entropy(self):
        """Approximate entropy: the elementwise entropy of the pre-``tanh`` Gaussian.

        This ignores the ``E[log|d tanh(x)/dx|]`` term. That term is <= 0, so the
        value over-estimates the true entropy of the squashed distribution.
        """
        warnings.warn("Entropy not accurate for TanhGaussian", stacklevel=2)
        # TransformedDistribution does not implement entropy() (it raises
        # NotImplementedError), so use the base Normal's closed form instead.
        return self._pdf.base_dist.entropy()


class TanhTransform(pyd.transforms.Transform):
    """Bijective ``y = tanh(x)`` mapping the real line onto ``(-1, 1)``.

    Args:
        cache_size: Forwarded to ``Transform``. With the default of 1, the base
            class caches the most recent ``(x, y)`` pair, so inverting a value
            just produced by this transform returns its exact pre-image.
    """

    domain = pyd.constraints.real  # type: ignore
    codomain = pyd.constraints.interval(-1.0, 1.0)  # type: ignore
    bijective = True
    sign = +1  # type: ignore

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        """Inverse hyperbolic tangent, ``0.5 * (log1p(x) - log1p(-x))``."""
        return 0.5 * (torch.log1p(x) - torch.log1p(-x))

    def __eq__(self, other):
        # Any two TanhTransform instances are interchangeable.
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return torch.tanh(x)

    def _inverse(self, y):
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # log(1 - tanh(x)^2) rewritten as 2 * (log 2 - x - softplus(-2x)). This
        # avoids the cancellation in 1 - tanh(x)^2 for large |x|. Reference:
        # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
        log_two = math.log(2.0)
        return 2.0 * (log_two - x - F.softplus(-2.0 * x))


class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    """Normal distribution pushed through a single ``TanhTransform``.

    Args:
        loc: Mean of the underlying (pre-squash) Normal.
        scale: Standard deviation of the underlying Normal.
    """

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

        self.base_dist = pyd.Normal(loc, scale)
        super().__init__(self.base_dist, [TanhTransform()])

    @property
    def mean(self):
        """``loc`` mapped through every transform in order, i.e. ``tanh(loc)``.

        This is the squashed location, which in general differs from the true
        expectation ``E[tanh(X)]`` of the squashed distribution.
        """
        value = self.loc
        for transform in self.transforms:
            value = transform(value)
        return value


class DiracDistribution(Distribution):
    """Deterministic point-mass "distribution" located at ``mean``.

    Args:
        mean: The single value every sample returns.
    """

    def __init__(self, mean: torch.Tensor):
        self._mean = mean

    def rsample(self, n: int = 1) -> torch.Tensor:
        """Return ``mean``; for ``n > 1``, an expanded view with a leading ``n`` dim."""
        if n == 1:
            return self._mean
        # expand() returns a view, so the n copies share mean's storage.
        return self._mean.expand(n, *self._mean.shape)

    def log_prob(self, actions: torch.Tensor) -> torch.Tensor:
        """Placeholder zero log-prob shaped like ``actions`` with the last dim reduced to 1."""
        return torch.sum(torch.zeros_like(actions), dim=-1, keepdim=True)

    @property
    def mean(self):
        """The point-mass location."""
        return self._mean

    def get_pdf(self):
        """No underlying torch distribution exists; returns ``None``."""
        return None

    @property
    def entropy(self):
        """Zeros shaped like ``mean``."""
        return torch.zeros_like(self._mean)

    @property
    def log_std(self):
        """Placeholder zero tensor of shape ``(1,)``.

        It is allocated on ``mean``'s device and dtype so it can be combined
        with GPU or non-float32 tensors.
        """
        return self._mean.new_zeros((1,))

    @property
    def std(self):
        """Zero tensor of shape ``(1,)`` on ``mean``'s device and dtype."""
        return self._mean.new_zeros((1,))


class TruncatedNormal(pyd.Normal):
    """Normal distribution whose samples are clamped into ``[low + eps, high - eps]``.

    The clamp uses a straight-through estimator: the forward value is clamped,
    but gradients flow as if no clamp were applied. Unlike ``Normal.sample``,
    this ``sample`` is differentiable w.r.t. ``loc`` and ``scale``.

    Args:
        loc: Mean of the underlying Normal (argument validation disabled).
        scale: Standard deviation of the underlying Normal.
        low: Lower bound of the clamp range.
        high: Upper bound of the clamp range.
        eps: Margin kept inside each bound.
    """

    def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6):
        super().__init__(loc, scale, validate_args=False)
        self.low = low
        self.high = high
        self.eps = eps

    def _clamp(self, x):
        """Clamp ``x`` in the forward pass while passing its gradient straight through."""
        lower, upper = self.low + self.eps, self.high - self.eps
        hard = torch.clamp(x, lower, upper).detach()
        # (x - x.detach()) is zero-valued but carries x's gradient.
        return x - x.detach() + hard

    def sample(self, clip=None, sample_shape=torch.Size()):
        """Draw ``clamp(loc + clip(scale * N(0, 1)))``.

        Args:
            clip: Optional symmetric bound on the scaled noise before it is
                added to ``loc``.
            sample_shape: Extra leading sample dimensions. NOTE: this is the
                second positional argument, unlike ``Normal.sample``, so pass it
                by keyword.
        """
        target_shape = self._extended_shape(sample_shape)
        noise = _standard_normal(
            target_shape, dtype=self.loc.dtype, device=self.loc.device
        )
        noise = noise * self.scale
        if clip is not None:
            noise = noise.clamp(-clip, clip)
        return self._clamp(self.loc + noise)
