# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Dict, List, Tuple

import mxnet as mx

from gluonts.core.component import validated
from gluonts.mx import Tensor

from .bijection import Bijection, InverseBijection
from .bijection_output import BijectionOutput
from .distribution import getF, softplus


class BoxCoxTransform(Bijection):
    r"""
    Implements Box-Cox transformation of a uni-variate random variable.
    The Box-Cox transformation of an observation :math:`z` is given by

    .. math::

        BoxCox(z; \lambda_1, \lambda_2) = \begin{cases}
                                              ((z + \lambda_2)^{\lambda_1} - 1) / \lambda_1, \quad & \text{if }
                                              \lambda_1 \neq 0, \\
                                              \log (z + \lambda_2), \quad & \text{otherwise.}
                                          \end{cases}

    Here, :math:`\lambda_1` and :math:`\lambda_2` are learnable parameters. Note that the domain
    of the transformation is not restricted.

    For numerical stability, instead of checking :math:`\lambda_1` is exactly zero, we use the condition

    .. math::

        |\lambda_1| < tol\_lambda\_1

    for a pre-specified tolerance `tol_lambda_1`.

    Inverse of the Box-Cox Transform is given by

    .. math::

        BoxCox^{-1}(y; \lambda_1, \lambda_2) = \begin{cases}
                                                (y \lambda_1 + 1)^{(1/\lambda_1)} - \lambda_2, \quad & \text{if }
                                                \lambda_1 \neq 0, \\
                                                \exp (y) - \lambda_2, \quad & \text{otherwise.}
                                               \end{cases}

    **Notes on numerical stability:**

    1.  For the forward transformation, :math:`\lambda_2` must always be chosen such that

        .. math::

            z + \lambda_2 > 0.

        To achieve this one needs to know a priori the lower bound on the observations.
        This is set in `BoxCoxTransformOutput`, since :math:`\lambda_2` is learnable.

    2.  Similarly for the inverse transformation to work reliably, a sufficient condition is

        .. math::

            y \lambda_1 + 1 \geq 0,

        where :math:`y` is the input to the inverse transformation.

        This cannot always be guaranteed especially when :math:`y` is a sample from a transformed distribution.
        Hence we always truncate :math:`y \lambda_1 + 1` at zero.

        An example showing why this could happen in our case:
        consider transforming observations from the unit interval (0, 1) with parameters

        .. math::

            \begin{align}
                \lambda_1 = &\ 1.1, \\
                \lambda_2 = &\ 0.
            \end{align}

        Then the range of the transformation is (-0.9090, 0.0).
        If Gaussian is fit to the transformed observations and a sample is drawn from it,
        then it is likely that the sample is outside this range, e.g., when the mean is close to -0.9.
        The subsequent inverse transformation of the sample is not a real number anymore.

        >>> y = -0.91
        >>> lambda_1 = 1.1
        >>> lambda_2 = 0.0
        >>> (y * lambda_1 + 1) ** (1 / lambda_1) + lambda_2
        (-0.0017979146510711471+0.0005279153735965289j)

    Parameters
    ----------
    lambda_1
    lambda_2
    tol_lambda_1
        For numerical stability, treat `lambda_1` as zero if it is less than
        `tol_lambda_1`
    F
    """
    arg_names = ["box_cox.lambda_1", "box_cox.lambda_2"]

    @validated()
    def __init__(
        self,
        lambda_1: Tensor,
        lambda_2: Tensor,
        tol_lambda_1: float = 1e-2,
        F=None,
    ) -> None:
        self.lambda_1 = lambda_1
        self.lambda_2 = lambda_2
        self.tol_lambda_1 = tol_lambda_1
        self.F = F if F else getF(lambda_1)

        # Addressing mxnet madness
        self._power = self.F.power if self.F == mx.nd else self.F.pow

    @property
    def args(self) -> List:
        r"""
        List: current values of the parameters
        """
        return [self.lambda_1, self.lambda_2]

    @property
    def event_dim(self) -> int:
        return 0

    @property
    def sign(self) -> Tensor:
        return 1.0

    def f(self, z: Tensor) -> Tensor:
        r"""
        Forward transformation of observations `z`

        Parameters
        ----------
        z
            observations

        Returns
        -------
        Tensor
            Transformed observations
        """
        F = self.F
        lambda_1 = self.lambda_1
        lambda_2 = self.lambda_2
        tol_lambda_1 = self.tol_lambda_1
        _power = self._power

        return F.where(
            condition=(F.abs(lambda_1).__ge__(tol_lambda_1).broadcast_like(z)),
            x=(_power(z + lambda_2, lambda_1) - 1.0) / lambda_1,
            y=F.log(z + lambda_2),
            name="Box_Cox_trans",
        )

    def f_inv(self, y: Tensor) -> Tensor:
        r"""Inverse of the Box-Cox Transform

        Parameters
        ----------
        y
            Transformed observations

        Returns
        -------
        Tensor
            Observations

        """
        F = self.F
        lambda_1 = self.lambda_1
        lambda_2 = self.lambda_2
        tol_lambda_1 = self.tol_lambda_1
        _power = self._power

        # For numerical stability we truncate :math:`y * \lambda_1 + 1.0` at zero.
        base = F.relu(y * lambda_1 + 1.0)

        return F.where(
            condition=(F.abs(lambda_1).__ge__(tol_lambda_1)).broadcast_like(y),
            x=_power(base, 1.0 / lambda_1) - lambda_2,
            y=F.exp(y) - lambda_2,
            name="Box_Cox_inverse_trans",
        )

    def log_abs_det_jac(self, z: Tensor, y: Tensor = None) -> Tensor:
        r"""
        Logarithm of the absolute value of the Jacobian determinant corresponding to the Box-Cox Transform
        is given by

        .. math::
            \log \frac{d}{dz} BoxCox(z; \lambda_1, \lambda_2) = \begin{cases}
                                \log (z + \lambda_2) (\lambda_1 - 1), \quad & \text{if } \lambda_1 \neq 0, \\
                                -\log (z + \lambda_2), \quad & \text{otherwise.}
                                \end{cases}

        Note that the derivative of the transformation is always non-negative.

        Parameters
        ----------
        z
            observations
        y
            not used

        Returns
        -------
        Tensor

        """
        F = self.F
        lambda_1 = self.lambda_1
        lambda_2 = self.lambda_2
        tol_lambda_1 = self.tol_lambda_1

        return F.where(
            condition=F.abs(lambda_1).__ge__(tol_lambda_1),
            x=F.log(z + lambda_2) * (lambda_1 - 1.0),
            y=-F.log(z + lambda_2),
            name="Box_Cox_trans_log_det_jac",
        )


class BoxCoxTransformOutput(BijectionOutput):
    bij_cls: type = BoxCoxTransform
    args_dim: Dict[str, int] = dict(zip(BoxCoxTransform.arg_names, [1, 1]))

    @validated()
    def __init__(self, lb_obs: float = 0.0, fix_lambda_2: bool = True) -> None:
        super().__init__()
        self.lb_obs = lb_obs
        self.fix_lambda_2 = fix_lambda_2

    def domain_map(self, F, *args: Tensor) -> Tuple[Tensor, ...]:
        lambda_1, lambda_2 = args
        if self.fix_lambda_2:
            lambda_2 = -self.lb_obs * F.ones_like(lambda_2)
        else:
            # This makes sure that :math:`z +  \lambda_2 > 0`, where :math:`z > lb_obs`
            lambda_2 = softplus(F, lambda_2) - self.lb_obs * F.ones_like(
                lambda_2
            )

        # we squeeze the output since event_shape is ()
        return lambda_1.squeeze(axis=-1), lambda_2.squeeze(axis=-1)

    @property
    def event_shape(self) -> Tuple:
        return ()


class InverseBoxCoxTransform(InverseBijection):
    """
    Implements the inverse of Box-Cox transformation as a bijection.
    """

    arg_names = ["box_cox.lambda_1", "box_cox.lambda_2"]

    @validated()
    def __init__(
        self,
        lambda_1: Tensor,
        lambda_2: Tensor,
        tol_lambda_1: float = 1e-2,
        F=None,
    ) -> None:
        super().__init__(BoxCoxTransform(lambda_1, lambda_2, tol_lambda_1, F))

    @property
    def event_dim(self) -> int:
        return 0


class InverseBoxCoxTransformOutput(BoxCoxTransformOutput):
    bij_cls: type = InverseBoxCoxTransform

    args_dim: Dict[str, int] = dict(
        zip(InverseBoxCoxTransform.arg_names, [1, 1])
    )

    @property
    def event_shape(self) -> Tuple:
        return ()
