import time
from typing import Tuple

import torch
from mlwiz.evaluation.util import return_class_and_args
from torch.nn import Parameter, Module
from torch.nn.functional import softplus, relu, sigmoid


def inv_softplus(bias: float | torch.Tensor) -> float | torch.Tensor:
    """Inverse softplus function.

    Args:
        bias (float or tensor): the value to be softplus-inverted.
    """
    is_tensor = True
    if not isinstance(bias, torch.Tensor):
        is_tensor = False
        bias = torch.tensor(bias)
    out = bias.expm1().clamp_min(1e-6).log()
    if not is_tensor and out.numel() == 1:
        return out.item()
    return out


class ContinuousDistribution(Module):
    """
    Base interface for the continuous distributions of this package.

    Subclasses implement ``log_prob``, ``cdf``, ``quantile`` and the
    ``parameter`` property. ``self.device`` records the device passed to the
    last :meth:`to` call (``"cpu"`` initially); subclasses use it to place
    the constant tensors they build on the fly.
    """

    def __init__(self):
        super().__init__()
        self.device = "cpu"

    def to(self, device):
        """
        Moves the module to ``device`` and remembers it in ``self.device``.

        :param device: a torch device or device string

        :return: ``self``, so that ``d = d.to(device)`` works as for any
            ``torch.nn.Module``
        """
        super().to(device)
        self.device = device
        return self

    def _validate_args(self, value):
        """
        Checks that ``value`` is a tensor of shape (N, 1).

        :param value: the tensor to check

        :raises AssertionError: if ``value`` is not a tensor or has the
            wrong shape
        """
        assert isinstance(
            value, torch.Tensor
        ), f"expected torch tensor, found {type(value)}"

        assert (
            len(value.shape) == 2
        ), f"expected shape: (N,1), found {value.shape}"

        assert (
            value.shape[1] == 1
        ), f"expected one-dimensional values, found {value.shape}"

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Computes the log pdf of the distribution

        :param value: a tensor of shape Nx1, where N is the number of samples

        :return: a tensor of shape Nx1
        """
        raise NotImplementedError(
            "You should subclass Distribution and " "implement this method."
        )

    def cdf(self, value):
        """
        Computes the cdf of the distribution

        :param value: a tensor of shape Nx1, where N is the number of samples

        :return: a tensor of shape Nx1
        """
        raise NotImplementedError(
            "You should subclass Distribution and " "implement this method."
        )

    def quantile(self, p: float = 0.95) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the p-quantile for the distribution.

        :param p: the parameter p of the quantile

        :return: lower and upper bounds for the p-quantile. If the p-quantile
            can be computed exactly then they are the same
        """
        raise NotImplementedError(
            "You should subclass Distribution and " "implement this method."
        )

    @property
    def parameter(self) -> torch.Tensor:
        """The distribution's main parameter (defined by subclasses)."""
        raise NotImplementedError(
            "You should subclass Distribution and " "implement this method."
        )


class Exponential(ContinuousDistribution):
    """
    Exponential distribution with a learnable rate.

    The rate is stored softplus-inverted in ``_rate`` so that the effective
    rate ``softplus(_rate)`` stays positive while training. If ``boundary``
    is given, reading :attr:`rate` keeps the rate from dropping below it by
    overwriting ``_rate`` in place.

    :param rate: initial rate
    :param boundary: optional (non-trainable) lower bound for the rate
    """

    def __init__(self, rate: float, boundary: float | None = None):
        super().__init__()
        if boundary is not None:
            self.boundary = Parameter(
                torch.tensor([boundary]).to(torch.get_default_dtype()),
                requires_grad=False,
            )
        else:
            self.boundary = None

        self._rate = Parameter(
            inv_softplus(torch.tensor([rate]).to(torch.get_default_dtype())),
            requires_grad=True,
        )

    def _validate_args(self, value):
        """Checks the (N, 1) shape and that all values are non-negative."""
        super()._validate_args(value)

        assert torch.all(value >= 0), (
            f"Input values cannot be smaller"
            f" than 0. Rate is {self.rate} and values are {value}"
        )

    @property
    def rate(self) -> torch.Tensor:
        """
        Effective rate ``softplus(_rate) + 1e-32``.

        The tiny epsilon keeps the rate strictly positive even if softplus
        underflows to 0. Side effect: when a boundary is set and the rate is
        below it, ``_rate.data`` is overwritten (outside autograd) so that
        the rate equals the boundary.
        """
        r = softplus(self._rate) + 1e-32

        if self.boundary is None or not self.boundary > r:
            return r

        self._rate.data = inv_softplus(self.boundary.data)
        return softplus(self._rate) + 1e-32

    @property
    def parameter(self) -> torch.Tensor:
        return self.rate

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Log density ``log(rate) - rate * value``."""
        return torch.log(self.rate) - self.rate * value

    def cdf(self, value):
        """Cdf ``1 - exp(-rate * value)``."""
        one = torch.tensor([1.0]).to(self.device)
        return one - torch.exp(-self.rate * value)

    def quantile(self, p: float = 0.95) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Exact p-quantile ``-log(1 - p) / rate``.

        :return: the pair ``(q, q)`` since the quantile is exact
        """
        one = torch.tensor([1.0]).to(self.device)
        q = -torch.log(one - p) / self.rate
        return q, q


class Uniform(ContinuousDistribution):
    """
    Uniform distribution on ``[1, b]`` with a fixed (non-trainable) ``b``.

    ``cdf`` is the linear map ``(value - 1) / (b - 1)`` without clamping to
    ``[0, 1]``, and ``quantile`` ignores ``p`` and always reports ``b``.
    """

    def __init__(
        self,
        b: float,
    ):
        super().__init__()
        initial_b = torch.tensor([b]).to(torch.get_default_dtype())
        # DO NOT LEARN THIS: b stays fixed during training
        self.b = Parameter(initial_b, requires_grad=False)

    def _validate_args(self, value):
        """Checks the (N, 1) shape and that all values are non-negative."""
        super()._validate_args(value)
        all_non_negative = torch.all(value >= 0)
        assert all_non_negative, "Input values cannot be smaller than 0."

    @property
    def parameter(self) -> torch.Tensor:
        return self.b

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Constant log density ``-log(b - 1)``; ``value`` is not used."""
        support_width = self.b - 1.0
        return -torch.log(support_width)

    def cdf(self, value):
        """Linear cdf ``(value - 1) / (b - 1)`` (not clamped)."""
        lower_end = torch.tensor([1.0]).to(self.device)
        return (value - lower_end) / (self.b - lower_end)

    def quantile(self, p: float = 0.95) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns ``(b, b)`` whatever the value of ``p``."""
        upper_end = self.b
        return upper_end, upper_end


class Pareto(ContinuousDistribution):
    """
    Pareto distribution with scale ``xm`` and learnable shape ``alpha``.

    Inputs are shifted by one before evaluation (``value = 0`` is evaluated
    at ``x = 1``), so with the default ``xm = 1`` the support starts at
    ``value = 0``. ``alpha`` is exposed as ``relu(_alpha)``.
    """

    def __init__(
        self, alpha: float, scale: float = 1.0
    ):  # DO NOT CHANGE, MUST STAY 1
        super().__init__()
        # Non-persistent buffer: .to(device) moves it together with the
        # parameters (a plain tensor attribute would stay on the CPU), and it
        # is kept out of state_dict.
        self.register_buffer(
            "xm",
            torch.tensor([scale]).to(torch.get_default_dtype()),
            persistent=False,
        )
        self._alpha = Parameter(
            torch.tensor([alpha]).to(torch.get_default_dtype()),
            requires_grad=True,
        )

    def _validate_args(self, value):
        """Checks the (N, 1) shape and that ``value + 1 >= xm``."""
        super()._validate_args(value)

        one = torch.tensor([1.0]).to(self.device)
        assert torch.all(value + one >= self.xm), (
            f"Input values cannot be smaller" f" than the scale {self.xm}."
        )

    @property
    def alpha(self) -> torch.Tensor:
        return relu(self._alpha)

    @property
    def parameter(self) -> torch.Tensor:
        return self.alpha

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Log density ``log(alpha) + alpha log(xm) - (alpha + 1) log(x)``
        with ``x = value + 1``."""
        one = torch.tensor([1.0]).to(self.device)

        # shift values so that everything starts from 1
        value = value + one

        return (
            torch.log(self.alpha)
            + self.alpha * torch.log(self.xm)
            - (self.alpha + one) * torch.log(value)
        )

    def cdf(self, value):
        """Cdf ``1 - (xm / x) ** alpha`` with ``x = value + 1``."""
        one = torch.tensor([1.0]).to(self.device)

        # shift values so that everything starts from 1
        value = value + one

        return one - torch.pow((self.xm / value), self.alpha)

    def quantile(self, p: float = 0.95) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Exact quantile ``q = xm / (1 - p) ** (1 / alpha)``, returned as
        ``(q, q)``.

        NOTE(review): q is expressed in the unshifted variable, while cdf()
        evaluates at ``value + 1``; hence ``cdf(q - 1) = p`` and q is one
        larger than the quantile in the caller's coordinates (conservative
        for truncation) -- confirm this is intended.
        """
        one = torch.tensor([1.0]).to(self.device)
        q = self.xm / torch.pow((one - p), one / self.alpha)
        return q, q


class Lomax(ContinuousDistribution):
    """
    Lomax (Pareto type II) distribution on ``[0, inf)`` with a fixed scale
    and a learnable shape ``alpha`` (exposed as ``relu(_alpha)``).
    """

    def __init__(self, alpha: float, scale: float):
        super().__init__()
        # Non-persistent buffer: .to(device) moves it together with the
        # parameters (a plain tensor attribute would stay on the CPU), and it
        # is kept out of state_dict.
        self.register_buffer(
            "_scale",
            torch.tensor([scale]).to(torch.get_default_dtype()),
            persistent=False,
        )
        self._alpha = Parameter(
            torch.tensor([alpha]).to(torch.get_default_dtype()),
            requires_grad=True,
        )

    def _validate_args(self, value):
        """Checks the (N, 1) shape and that all values are non-negative."""
        super()._validate_args(value)
        # The message must only use attributes Lomax has, otherwise a failed
        # check surfaces as AttributeError instead of AssertionError.
        assert torch.all(value >= 0.0), (
            "Input values cannot be smaller than 0."
        )

    @property
    def alpha(self) -> torch.Tensor:
        return relu(self._alpha)

    @property
    def parameter(self) -> torch.Tensor:
        return self.alpha

    @property
    def scale(self) -> torch.Tensor:
        return relu(self._scale)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Log density ``log(alpha) - log(scale)
        - (alpha + 1) log(1 + value / scale)``."""
        one = torch.tensor([1.0]).to(self.device)

        return (
            torch.log(self.alpha)
            - torch.log(self.scale)
            - (self.alpha + one) * torch.log(one + value / self.scale)
        )

    def cdf(self, value):
        """Cdf ``1 - (1 + value / scale) ** (-alpha)``."""
        one = torch.tensor([1.0]).to(self.device)
        return one - one / torch.pow(one + (value / self.scale), self.alpha)

    def quantile(self, p: float = 0.95) -> Tuple[torch.Tensor, torch.Tensor]:
        """Exact quantile ``scale * ((1 - p) ** (-1 / alpha) - 1)``,
        returned as ``(q, q)``."""
        one = torch.tensor([1.0]).to(self.device)
        q = self.scale * (one / torch.pow((one - p), one / self.alpha) - one)
        return q, q


class PowerLaw(ContinuousDistribution):
    """
    Continuous power law with a learnable exponent ``gamma``.

    Inputs are shifted by ``xmin`` (``value = 0`` is evaluated at ``xmin``)
    and the constant ``low_deg_sat`` (``s`` below) is added inside the
    logarithms:

        log p(v) = log(gamma - 1) + (gamma - 1) log(xmin + s)
                   - gamma log(v + xmin + s)
        cdf(v)   = 1 - ((xmin + s) / (v + xmin + s)) ** (gamma - 1)

    ``xmin``, ``low_deg_sat`` and the optional ``boundary`` (a lower bound
    on ``gamma``) are stored as non-trainable parameters.
    """

    def __init__(
        self,
        gamma: float,
        xmin: float = 1.0,
        low_deg_sat: float = 0.0,
        boundary: float | None = None,
    ):
        super().__init__()
        assert xmin >= 0

        if boundary is not None:
            self.boundary = Parameter(
                torch.tensor([boundary]).to(torch.get_default_dtype()),
                requires_grad=False,
            )
        else:
            self.boundary = None

        self._xmin = Parameter(
            torch.tensor([xmin]).to(torch.get_default_dtype()),
            requires_grad=False,
        )
        self._gamma = Parameter(
            torch.tensor([gamma]).to(torch.get_default_dtype()),
            requires_grad=True,
        )
        self._low_deg_sat = Parameter(
            torch.tensor([low_deg_sat]).to(torch.get_default_dtype()),
            requires_grad=False,
        )

    def _validate_args(self, value):
        """Checks the (N, 1) shape and that all values are non-negative."""
        super()._validate_args(value)
        # The message must only use attributes PowerLaw has, otherwise a
        # failed check surfaces as AttributeError instead of AssertionError.
        assert torch.all(value >= 0.0), (
            "Input values cannot be smaller than 0."
        )

    @property
    def gamma(self) -> torch.Tensor:
        """
        Non-negative exponent ``relu(_gamma)``.

        Side effect: when a boundary is set and the exponent is below it,
        ``_gamma.data`` is overwritten with the boundary (outside autograd)
        and ``relu(_gamma) + 1e-32`` is returned.
        """
        g = relu(self._gamma)
        if self.boundary is None:
            return g
        if self.boundary > g:
            self._gamma.data = self.boundary.data.clone()
            return relu(self._gamma) + 1e-32
        return g

    @property
    def low_deg_sat(self) -> torch.Tensor:
        return self._low_deg_sat

    @property
    def parameter(self) -> torch.Tensor:
        return self.gamma

    @property
    def xmin(self) -> torch.Tensor:
        return self._xmin

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Log density (see class docstring); requires ``value >= 0``."""
        # p(k) = C k^{-gamma}, C = (gamma-1)k_min^{gamma-1}
        assert torch.all(value >= 0.0)
        one = torch.tensor([1.0]).to(self.device)
        xmin = one * self.xmin
        # shift input of self.xmin so that passing value=0 corresponds to xmin
        value = value + xmin

        # read the (side-effecting) gamma property only once
        gamma = self.gamma
        gmo = gamma - one
        log_C = torch.log(gmo)

        return (
            log_C
            + gmo * torch.log(xmin + self.low_deg_sat)
            - gamma * torch.log(value + self.low_deg_sat)
        )

    def cdf(self, value):
        """Cdf (see class docstring); requires ``value >= 0``."""
        assert torch.all(value >= 0.0)
        one = torch.tensor([1.0]).to(self.device)
        xmin = one * self.xmin
        low_deg_sat = self.low_deg_sat
        # shift input of self.xmin so that passing value=0 corresponds to xmin
        value = value + xmin

        return one - torch.pow(
            (xmin + low_deg_sat) / (value + low_deg_sat), self.gamma - one
        )

    def quantile(self, p: float = 0.95) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns ``(q, q)`` with
        ``q = (xmin + s) * (1 - p) ** (-1 / (gamma - 1)) - s``.

        NOTE(review): q is expressed before the ``xmin`` shift applied by
        cdf(), i.e. ``cdf(q - xmin) = p``, so q exceeds the quantile in the
        caller's coordinates by ``xmin`` (conservative for truncation) --
        confirm this is intended.
        """
        one = torch.tensor([1.0]).to(self.device)
        exp_term = -(one / (self.gamma - one))
        q = (self.xmin + self.low_deg_sat) * torch.pow(
            (one - p), exp_term
        ) - self.low_deg_sat
        return q, q


class SigmoidDistribution(ContinuousDistribution):
    """
    Sigmoid-shaped, unnormalized density:

        density(x) = 1 - sigmoid(k * (x - b))

    Both raw parameters are trainable. The effective midpoint is
    ``b = softplus(_b)`` (positive) and the effective steepness is
    ``k = 2 * sigmoid(_k)`` (in ``(0, 2)``). The density is not normalized
    to one; ``cdf`` is its integral from 0.

    :param b: initial value of the raw parameter ``_b``
    :param k: initial value of the raw parameter ``_k``
    """

    def __init__(self, b: float, k: float = 1.0):
        super().__init__()
        # make both k and b trainable
        self._k = Parameter(
            torch.tensor([k]).to(torch.get_default_dtype()),
            requires_grad=True,
        )
        self._b = Parameter(
            torch.tensor([b]).to(torch.get_default_dtype()),
            requires_grad=True,
        )

    @property
    def b(self):
        return softplus(self._b)

    @property
    def k(self):
        return sigmoid(self._k) * 2  # keep k in (0, 2)

    def _validate_args(self, value):
        super()._validate_args(value)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Log density ``log(1 - sigmoid(z)) = -softplus(z)``, with
        ``z = k * (value - b)``.

        The softplus form stays finite where ``1 - sigmoid(z)`` rounds to 0
        and the direct formula would give ``log(0) = -inf``.
        """
        return -softplus(self.k * (value - self.b))

    def cdf(self, value):
        """
        Integral of the density from 0 to ``value``:

            value - [softplus(k (value - b)) - softplus(-k b)] / k

        ``softplus`` is used instead of ``log(1 + exp(.))``, which overflows
        to ``inf`` for large arguments.
        """
        k, b = self.k, self.b
        return (value - softplus(k * (value - b)) / k) + softplus(-k * b) / k

    def quantile(self, p: float = 0.9) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Point where the density has dropped to ``1 - p``, shifted by 2.

        Solves ``1 - sigmoid(k (x - b)) = 1 - p``, i.e.
        ``x = b + log(p / (1 - p)) / k``.

        :return: the pair ``(x + 2, x + 2)``
        """
        b, k = self.b, self.k
        # Build p on the parameters' device/dtype so the arithmetic below
        # never mixes CPU and GPU tensors.
        p = torch.as_tensor(p, dtype=b.dtype, device=b.device)
        x = (b * k + torch.log(1 / (1 - p) - 1)) / k
        return x + 2, x + 2

    @property
    def parameter(self) -> torch.Tensor:
        return self.b


class DiscretizedDistribution(Module):
    def __init__(self, **kwargs):
        """
        Creates a discretized version of a continuous distribution such that

            p(x) = phi(x+1) - phi(x)

        where phi is the cdf of the original distribution.

        :param kwargs: a dictionary with a key 'base_distribution' that
            allows us to instantiate a discretized distribution
        """
        super().__init__()
        base_d_cls, base_d_args = return_class_and_args(
            kwargs, "base_distribution"
        )
        self.base_distribution = base_d_cls(**base_d_args)
        self.device = "cpu"

    def to(self, device):
        """
        Moves this module and its base distribution to ``device``.

        :return: ``self``, so that ``d = d.to(device)`` works as for any
            ``torch.nn.Module``
        """
        super().to(device)
        self.device = device
        self.base_distribution.to(device)
        return self

    def get_q_named_parameters(self) -> dict:
        # NOTE(review): delegated to the base distribution; the
        # ContinuousDistribution classes in this file do not define this
        # method -- confirm the bases used here provide it.
        return self.base_distribution.get_q_named_parameters()

    def _validate_args(self, value):
        """Validates ``value`` for the base distribution and checks that it
        holds integer values."""
        self.base_distribution._validate_args(value)

        # check values are integers
        assert torch.allclose(
            value, value.int().to(torch.get_default_dtype())
        ), f"expected float tensor with integer values, got {value}."

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """
        Computes the log pdf of the distribution

        :param value: a tensor of shape Nx1, where N is the number of samples

        :return: a tensor of shape Nx1
        """
        # self._validate_args(value)

        one = torch.ones(1).to(value.device)
        # avoids a degenerate case where the base distribution has the
        # same cdf for both value and value+1
        # which leads to nan. Also, a too small value can cause some
        # distributions to have prob 1 for a single neuron, and the model
        # gets trapped in there
        tmp = torch.ones(1).to(value.device) * 1e-6
        return torch.log(
            self.base_distribution.cdf(value + one)
            - self.base_distribution.cdf(value)
            + tmp
        )

    def cdf(self, value: torch.Tensor) -> torch.Tensor:
        """
        Computes the cdf of the distribution, i.e. the base cdf at value + 1

        :param value: a tensor of shape Nx1, where N is the number of samples

        :return: a tensor of shape Nx1
        """
        # self._validate_args(value)
        one = torch.ones(1).to(value.device)

        return self.base_distribution.cdf(value + one)

    def quantile(self, p: float = 0.95) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the approximated p-quantile for the discrete distribution,
        i.e. the smallest integer x such that cdf(x) >= p, searching between
        the bounds returned by the base distribution's quantile().

        :param p: the parameter p

        :return: the pair (x, x + 2); the second element is padded by 2 to
            avoid degenerate cases. When the base distribution's quantile is
            exact (lower == upper bound q) no search is done and both
            elements are q + 2.
        """
        lower_bound, upper_bound = self.base_distribution.quantile(p)

        # corner case or case when quantile is known exactly
        if lower_bound == upper_bound:
            u = upper_bound.to(self.device)
            return u + 2, u + 2  # + 2 to avoid degenerate cases

        l = torch.floor(lower_bound).to(self.device)
        u = torch.ceil(upper_bound).to(self.device)

        # ------------------------------------------------------------------ #
        # code that saves the day in case you implemented wrong bounds
        if self.cdf((u).unsqueeze(1)) < p:
            print(
                "WARNING: your upper bound on the quantile is "
                f"not working as expected:{(u, self.cdf((u).unsqueeze(1)))}"
            )
            while self.cdf(u.unsqueeze(1)) < p:
                u = u + 1
            return u, u + 2  # + 2 to avoid degenerate cases
        # ------------------------------------------------------------------ #

        # if the lower bound is already sufficient, stop
        if self.cdf(l.unsqueeze(1)) >= p:
            return l, l + 2  # + 2 to avoid degenerate cases

        x = self._bisect_quantile(l, u, p)
        return x, x + 2  # + 2 to avoid degenerate cases

    def _bisect_quantile(self, l, u, p):
        """
        Returns the smallest integer x in (l, u] such that cdf(x) >= p.

        Requires integer-valued l < u with cdf(l) < p <= cdf(u). Every step
        keeps that invariant and strictly shrinks the interval, so the loop
        terminates (with u == l + 1) and u is the answer.
        """
        while u - l > 1:
            m = torch.floor((l + u) / 2.0)  # l < m < u since u - l >= 2
            if self.cdf(m.unsqueeze(1)) >= p:
                u = m
            else:
                l = m
        return u

    def compute_probability_vector(self, x) -> torch.Tensor:
        """
        Computes the **renormalized** vector of probabilities on the fly

        :param x: a tensor of shape Nx1 with the points to evaluate

        :return: a vector of length N with the probabilities
        """
        log_probs = self.log_prob(x).squeeze(1)
        probs = log_probs.exp()
        probs = probs / probs.sum()
        return probs

    @property
    def mean(self) -> torch.Tensor:
        # delegated; NOTE(review): the ContinuousDistribution classes in this
        # file do not define ``mean`` -- confirm the base provides it
        return self.base_distribution.mean

    @property
    def variance(self) -> torch.Tensor:
        # delegated; NOTE(review): the ContinuousDistribution classes in this
        # file do not define ``variance`` -- confirm the base provides it
        return self.base_distribution.variance


class TruncatedDistribution(Module):
    def __init__(self, truncation_quantile: float, **kwargs):
        """
        Truncates a discretized distribution to a given quantile and
        renormalizes its probability.

        :param truncation_quantile: the quantile in [0,1] at which we want
            to truncate the discrete distribution.
        :param kwargs: a dictionary with a key 'discretized_distribution' that
            allows us to instantiate a discretized distribution
        """
        super().__init__()

        dist_d_cls, dist_d_args = return_class_and_args(
            kwargs, "discretized_distribution"
        )
        self.discretized_distribution = dist_d_cls(**dist_d_args)
        self.truncation_quantile = truncation_quantile
        # default device, so compute_probability_vector also works before
        # to() has been called; updated by to()
        self.device = "cpu"

    def to(self, device):
        """
        Moves this module and the wrapped distribution to ``device``.

        :return: ``self``, so that ``d = d.to(device)`` works as for any
            ``torch.nn.Module``
        """
        super().to(device)
        self.device = device
        self.discretized_distribution.to(device)
        return self

    def get_q_named_parameters(self) -> dict:
        return self.discretized_distribution.get_q_named_parameters()

    def compute_truncation_number(self) -> int:
        """
        Computes the truncation number at the specified quantile.

        Uses the second element of the pair returned by the discretized
        distribution's quantile() at ``self.truncation_quantile``.

        :return: the truncation number as a Python int (positivity is not
            checked)
        """
        _, truncation_number = self.discretized_distribution.quantile(
            p=self.truncation_quantile
        )

        # detach: this must not be part of the gradient computation in any way
        return int(truncation_number.detach())

    def compute_probability_vector(self) -> torch.Tensor:
        """
        Computes the **renormalized** vector of probabilities on the fly

        :return: a vector of length ``compute_truncation_number()`` with the
            probabilities of the points 0, 1, ..., truncation_number - 1

        :raises AssertionError: if the probabilities do not sum to one
        """
        truncation_number = self.compute_truncation_number()

        # no gradient so far, we detach on purpose
        x = torch.arange(
            truncation_number,
            dtype=torch.get_default_dtype(),
            device=self.device,
        ).unsqueeze(1)

        probs = self.discretized_distribution.compute_probability_vector(x)

        assert torch.allclose(
            probs.sum(), torch.ones(1, device=self.device)
        ), self._normalization_debug_info(probs)
        return probs

    def _normalization_debug_info(self, probs):
        """
        Assertion payload: ``probs`` plus the base distribution's ``k`` and
        ``b`` when it has them (None otherwise).

        Only evaluated when the normalization check fails; ``getattr`` with a
        default keeps it from raising AttributeError for other bases.
        """
        base = getattr(self.discretized_distribution, "base_distribution", None)
        return probs, getattr(base, "k", None), getattr(base, "b", None)

    @property
    def mean(self) -> torch.Tensor:
        """Mean of the truncated, renormalized distribution."""
        proba = self.compute_probability_vector()
        return (proba * torch.arange(len(proba)).to(proba.device)).sum()

    def quantile(self, p: float = 0.95) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns the p-quantile of the wrapped discretized distribution.

        Note that this delegates directly: truncation is not taken into
        account.

        :param p: the parameter p

        :return: the pair returned by the discretized distribution's
            quantile()
        """
        return self.discretized_distribution.quantile(p)
