import math
from typing import Optional, Union

import torch
from gpytorch import settings
from gpytorch.constraints import Interval

from torch import Tensor

# Small positive constant.
# NOTE(review): presumably a numerical tolerance; it is not used by the constraint
# classes defined directly below -- confirm its intended consumers.
EPS = 1e-8


# IDEA: On a high level, the cleanest solution for this would be to separate out the
# 1) definition and book-keeping of parameter constraints on the one hand, and
# 2) the re-parameterization of the variables with some monotonic transformation,
# since the two steps are orthogonal.


class NonTransformedInterval(Interval):
    """GPyTorch ``Interval`` variant whose transforms are the identity.

    The stock GPyTorch Interval maps the parameter to the range [0, 1] before
    applying the inverse transform. This class skips every transformation: the raw
    and the constrained parameter values coincide, and the bounds are meant to be
    passed as bound constraints to the scipy L-BFGS optimizer instead. Crucially,
    this allows for the occurrence of exact zeros, which sparse optimization
    algorithms rely on.
    """

    def __init__(
        self,
        lower_bound: Union[float, Tensor],
        upper_bound: Union[float, Tensor],
        initial_value: Optional[Union[float, Tensor]] = None,
    ):
        """Set up the interval without any parameter transform.

        Args:
            lower_bound: Lower bound of the interval.
            upper_bound: Upper bound of the interval.
            initial_value: Optional initial value, forwarded unchanged to the
                parent constructor.
        """
        # Both transforms are explicitly disabled on the parent.
        super().__init__(
            transform=None,
            inv_transform=None,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            initial_value=initial_value,
        )

    def inverse_transform(self, transformed_tensor: Tensor) -> Tensor:
        """Return ``transformed_tensor`` as is (identity mapping)."""
        return transformed_tensor

    def transform(self, tensor: Tensor) -> Tensor:
        """Return ``tensor`` as is (identity mapping)."""
        return tensor


class NonTransformedGreaterThan(NonTransformedInterval):
    """Untransformed counterpart of the GPyTorch GreaterThan class.

    A NonTransformedInterval whose upper bound is fixed to positive infinity, so
    only the lower bound constrains the parameter. The reasoning for skipping the
    parameter transform is the same as for NonTransformedInterval.
    """

    def __init__(
        self,
        lower_bound: Union[float, Tensor],
        initial_value: Optional[Union[float, Tensor]] = None,
    ):
        """Set up a constraint that is bounded from below only.

        Args:
            lower_bound: Lower bound of the constraint.
            initial_value: Optional initial value, forwarded unchanged to the
                parent constructor.
        """
        # math.inf is the same float that torch exposes as torch.inf.
        unbounded_above = math.inf
        super().__init__(
            initial_value=initial_value,
            upper_bound=unbounded_above,
            lower_bound=lower_bound,
        )


class NonTransformedNonNegative(NonTransformedGreaterThan):
    """Untransformed analogue of the GPyTorch Positive class.

    A NonTransformedGreaterThan with the lower bound fixed to 0.0. As with
    NonTransformedInterval, no transformation is applied to the parameter.
    """

    def __init__(self, initial_value: Optional[Union[float, Tensor]] = None):
        """Set up the constraint with a lower bound of zero.

        Args:
            initial_value: Optional initial value, forwarded unchanged to the
                parent constructor.
        """
        super().__init__(initial_value=initial_value, lower_bound=0.0)


class LogTransformedInterval(Interval):
    """Modification of the GPyTorch interval class.

    The Interval class in GPyTorch will map the parameter to the range [0, 1] before
    applying the inverse transform. We don't want to do this when using log as an
    inverse transform. This class will skip this step and apply the log transform
    directly to the parameter values so we can optimize log(parameter) under the bound
    constraints log(lower) <= log(parameter) <= log(upper).
    """

    def __init__(
        self,
        lower_bound: Union[float, Tensor],
        upper_bound: Union[float, Tensor],
        initial_value: Optional[Union[float, Tensor]] = None,
    ) -> None:
        """Set up the interval with ``exp`` / ``log`` as transform pair.

        Args:
            lower_bound: Lower bound of the interval (on the untransformed scale).
            upper_bound: Upper bound of the interval (on the untransformed scale).
            initial_value: Optional initial value. Besides being forwarded to the
                parent constructor, a copy is stored untransformed in the
                ``initial_value_untransformed`` buffer.

        Raises:
            RuntimeError: If GPyTorch debug mode is on and the largest upper bound
                is ``inf`` or the smallest lower bound is ``-inf``.
        """
        super().__init__(
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            transform=torch.exp,
            inv_transform=torch.log,
            initial_value=initial_value,
        )

        # Save the untransformed initial value.
        if initial_value is None:
            untransformed = None
        else:
            # `as_tensor(...).detach().clone()` makes the same detached copy as
            # `torch.tensor(...)` but does not emit the copy-construct UserWarning
            # when `initial_value` is already a Tensor. `.to(self.lower_bound)`
            # then matches the dtype and device of the bounds.
            untransformed = (
                torch.as_tensor(initial_value).detach().clone().to(self.lower_bound)
            )
        self.register_buffer("initial_value_untransformed", untransformed)

        if settings.debug.on():
            max_bound = torch.max(self.upper_bound)
            min_bound = torch.min(self.lower_bound)
            if max_bound == math.inf or min_bound == -math.inf:
                raise RuntimeError(
                    "Cannot make an Interval directly with non-finite bounds. Use a "
                    "derived class like GreaterThan or LessThan instead."
                )

    def transform(self, tensor: Tensor) -> Tensor:
        """Map a raw (log-scale) tensor to the constrained parameter value.

        Args:
            tensor: The raw tensor.

        Returns:
            ``tensor`` unchanged when ``self.enforced`` is falsy; otherwise the
            result of ``self._transform`` (``torch.exp`` is what this class hands
            to the parent constructor), applied directly without any rescaling by
            the bounds.
        """
        if not self.enforced:
            return tensor
        return self._transform(tensor)

    def inverse_transform(self, transformed_tensor: Tensor) -> Tensor:
        """Map a constrained parameter value back to the raw (log-scale) tensor.

        Args:
            transformed_tensor: The constrained parameter value.

        Returns:
            ``transformed_tensor`` unchanged when ``self.enforced`` is falsy;
            otherwise the result of ``self._inv_transform`` (``torch.log`` is what
            this class hands to the parent constructor), applied directly without
            any rescaling by the bounds.
        """
        if not self.enforced:
            return transformed_tensor
        return self._inv_transform(transformed_tensor)
