import math
from typing import Callable, Iterable, List, Optional, Tuple, Union

import torch
from botorch.test_functions.base import BaseTestProblem
from botorch.test_functions.synthetic import SyntheticTestFunction
from torch import Tensor
from torch.distributions import StudentT


################################ Experimental utilities
class CorruptedTestProblem(BaseTestProblem):
    """Wrap a base test problem and replace a random subset of its noisy
    observations with outliers produced by a user-supplied generator."""

    def __init__(
        self,
        base_test_problem: BaseTestProblem,
        outlier_generator: Callable[[BaseTestProblem, Tensor, Tensor], Tensor],
        outlier_fraction: float,
        bounds: Optional[List[Tuple[float, float]]] = None,
        seeds: Optional[Union[Iterable[int], List[int]]] = None,
    ) -> None:
        """A problem with outliers.

        NOTE: Both noise_std and negate will be taken from the base test problem.

        Args:
            base_test_problem: The base function to be corrupted.
            outlier_generator: A function that generates outliers. It is called with
                the keyword arguments `f`, `X`, and `bounds`, where `f` is
                `base_test_problem`, `X` is the argument passed to the `forward`
                method, and `bounds` is this problem's `bounds`. It returns the
                values of the outliers.
            outlier_fraction: The probability with which each individual
                observation is replaced by an outlier (i.e. the expected fraction
                of outliers).
            bounds: The bounds of the function. Defaults to the bounds of
                `base_test_problem`.
            seeds: The seeds to use for the outlier generator. If seeds are provided,
                the problem will iterate through them, changing the seed with a call
                to `next(seeds)` on every `forward` call with `noise=True`. If a list
                is provided, it will first be converted to an iterator.
        """
        # `dim` and `_bounds` are set before `super().__init__`; presumably the
        # base class reads them there (e.g. to build `self.bounds`) — TODO confirm
        # against the installed botorch version.
        self.dim = base_test_problem.dim
        self._bounds = bounds if bounds is not None else base_test_problem._bounds
        super().__init__(
            noise_std=base_test_problem.noise_std,
            negate=base_test_problem.negate,
        )
        self.base_test_problem = base_test_problem
        self.outlier_generator = outlier_generator
        self.outlier_fraction = outlier_fraction
        self._current_seed: Optional[int] = None
        self._seeds = None if seeds is None else iter(seeds)

    @property
    def _optimal_value(self) -> float:
        """The optimal value of the wrapped base problem."""
        return self.base_test_problem._optimal_value

    def evaluate_true(self, X: Tensor) -> Tensor:
        """Noise-free, outlier-free evaluation of the wrapped base problem."""
        return self.base_test_problem.evaluate_true(X)

    @property
    def has_seeds(self) -> bool:
        """Whether a seed iterator was supplied at construction time."""
        return self._seeds is not None

    def increment_seed(self) -> int:
        """Advance to the next seed from `seeds`, store it, and return it.

        Raises:
            StopIteration: If the seed iterator is exhausted.
            TypeError: If no seeds were provided (`next(None)` is invalid).
        """
        self._current_seed = next(self._seeds)
        return self._current_seed  # pyre-ignore

    @property
    def seed(self) -> Optional[int]:
        """The most recently used seed, or None if no seed has been drawn yet."""
        return self._current_seed

    def forward(self, X: Tensor, noise: bool = True):
        """Evaluate the problem, corrupting observations when `noise` is True.

        Each element of the batch (indexed by `X.shape[:-1]`) is independently
        replaced by the corresponding output of `outlier_generator` with
        probability `outlier_fraction`. With `noise=False`, the uncorrupted
        result of the base-class `forward` is returned.
        """
        Y = super().forward(X, noise=noise)
        if noise:
            if self.has_seeds:
                self.increment_seed()
                # NOTE: this reseeds torch's *global* random number generator.
                torch.manual_seed(self.seed)
            # Sample the mask on X's device so `torch.where` does not mix devices
            # (previously the mask was always created on the default device).
            corrupt = (
                torch.rand(X.shape[:-1], device=X.device) < self.outlier_fraction
            )
            outliers = self.outlier_generator(  # pyre-ignore
                f=self.base_test_problem, X=X, bounds=self.bounds
            )
            Y = torch.where(corrupt, outliers, Y)
        return Y


def constant_outlier_generator(f, X, bounds, constant: float):
    """Outlier generator whose outliers all equal `constant`.

    `f` and `bounds` are accepted only to match the outlier-generator calling
    convention and are not used.

    Returns:
        A tensor of shape `X.shape[:-1]`, with `X`'s dtype and device, filled
        with `constant`.
    """
    batch_shape = X.shape[:-1]
    return X.new_full(batch_shape, constant)


def uniform_input_corruption(
    f: BaseTestProblem,
    X: Tensor,
    bounds: Tensor,
) -> Tensor:
    """An outlier generator that evaluates `f` at inputs sampled uniformly at
    random from the box given by `bounds`. This ensures that the outliers are
    embedded within the range of the uncorrupted function, thereby making it
    difficult to detect them.

    Args:
        f: The function to be corrupted.
        X: A tensor whose shape (and device) the uniformly sampled inputs match.
        bounds: A `2 x d` tensor (or nested sequence) whose first row holds the
            lower bounds and whose second row holds the upper bounds.

    Returns:
        `f` evaluated at the random inputs, i.e. a tensor with the shape `f(X)`
        would have (`X.shape[:-1]` for a single-output problem).
    """
    # Move bounds onto X's device (and accept nested lists); the dtype is left
    # to normal type promotion with X, as before.
    lower, upper = torch.as_tensor(bounds, device=X.device)
    R = (upper - lower) * torch.rand_like(X) + lower
    return f(R)


def normal_outlier_corruption(f, X, bounds, noise_std: float) -> Tensor:
    """Outlier generator that adds zero-mean Gaussian noise to `f(X)`.

    `bounds` is unused; it is accepted to match the outlier-generator calling
    convention.

    Args:
        f: Callable evaluated at `X` to obtain the clean values.
        X: The inputs.
        bounds: Unused.
        noise_std: Standard deviation of the added noise; must be positive.

    Returns:
        `f(X)` plus i.i.d. `N(0, noise_std**2)` noise, elementwise.
    """
    assert noise_std > 0
    clean = f(X)
    gaussian = torch.randn_like(clean)
    return clean + noise_std * gaussian


def uniform_corruption(
    f: BaseTestProblem, X: Tensor, bounds: Tensor, lower: float, upper: float
) -> Tensor:
    """Outlier generator drawing values uniformly from `[lower, upper)`.

    `f(X)` is evaluated only to determine the shape, dtype and device of the
    output; its values are discarded. `bounds` is unused.

    Args:
        f: Callable whose output at `X` serves as a template for the result.
        X: The inputs.
        bounds: Unused.
        lower: Lower end of the sampling interval.
        upper: Upper end of the sampling interval; must exceed `lower`.

    Returns:
        A tensor like `f(X)` with entries sampled uniformly from `[lower, upper)`.
    """
    assert lower < upper
    template = f(X)
    width = upper - lower
    return lower + width * torch.rand_like(template)


def student_t_corruption(
    f: BaseTestProblem, X: Tensor, bounds: Tensor, df: float, scale: float
) -> Tensor:
    """Outlier generator that adds zero-centered Student-t noise to `f(X)`.

    `bounds` is unused; it is accepted to match the outlier-generator calling
    convention.

    Args:
        f: Callable evaluated at `X` to obtain the clean values.
        X: The inputs.
        bounds: Unused.
        df: Degrees of freedom of the Student-t distribution.
        scale: Scale of the Student-t distribution.

    Returns:
        `f(X)` plus i.i.d. Student-t(df, loc=0, scale) samples, elementwise.
    """
    clean = f(X)
    tkw = {"dtype": clean.dtype, "device": clean.device}
    heavy_tail = StudentT(
        df=torch.as_tensor(df, **tkw),
        loc=torch.tensor(0.0, **tkw),
        scale=torch.as_tensor(scale, **tkw),
    )
    perturbation = heavy_tail.sample(clean.shape)
    return clean + perturbation.to(clean)


class Friedman(SyntheticTestFunction):
    r"""Friedman regression benchmark function.

    .. math::
        f(x) = 10 \sin(\pi x_1 x_2) + 20 (x_3 - 0.5)^2 + 10 x_4 + 5 x_5

    Only the first five input dimensions are used.
    """

    def __init__(
        self,
        dim: int = 5,
        noise_std: Optional[float] = None,
        negate: bool = False,
        bounds: Optional[List[Tuple[float, float]]] = None,
    ) -> None:
        r"""
        Args:
            dim: The (input) dimension. Must be at least 5. If more than 5, the last
                dim - 5 dimensions will be ignored, similar to the 10d experiments in
                "Robust Regression with Twinned Gaussian Processes".
            noise_std: Standard deviation of the observation noise.
            negate: If True, negate the function.
            bounds: Custom bounds for the function specified as (lower, upper) pairs.
                Defaults to the unit hypercube.

        Raises:
            ValueError: If `dim` is less than 5.
        """
        if dim < 5:
            # evaluate_true reads the first five columns; fail early instead of
            # with an IndexError at evaluation time.
            raise ValueError(f"Friedman requires dim >= 5, got dim={dim}.")
        self.dim = dim
        if bounds is None:
            bounds = [(0, 1) for _ in range(self.dim)]
        super().__init__(noise_std=noise_std, negate=negate, bounds=bounds)

    def evaluate_true(self, X: Tensor) -> Tensor:
        """Evaluate the noise-free function on a `batch_shape x dim` tensor.

        Returns:
            A tensor of shape `batch_shape`.
        """
        # Ellipsis indexing supports arbitrary batch shapes, not just `n x dim`.
        return (
            torch.sin(torch.pi * X[..., 0] * X[..., 1]) * 10.0
            + 20.0 * torch.square(X[..., 2] - 0.5)
            + 10.0 * X[..., 3]
            + 5.0 * X[..., 4]
        )


class Bow(SyntheticTestFunction):
    """One-dimensional "bow" test function adapted from Andrade & Takeda.

    On the original input scale this evaluates (up to floating-point rounding)
    to `3.235 * sin(pi * x) - 2.058`.
    """

    def __init__(
        self,
        dim: int = 1,
        noise_std: Optional[float] = None,
        negate: bool = False,
        bounds: Optional[List[Tuple[float, float]]] = None,
    ) -> None:
        """Adapted from Andrade & Takeda (same function).

        Args:
            dim: The (input) dimension. Note that `evaluate_true` only supports
                inputs whose last dimension is 1.
            noise_std: Standard deviation of the observation noise.
            negate: If True, negate the function.
            bounds: Custom bounds specified as (lower, upper) pairs. Defaults to
                the unit interval in each dimension.
        """
        self.dim = dim
        if bounds is None:
            bounds = [(0, 1) for _ in range(self.dim)]
        super().__init__(noise_std=noise_std, negate=negate, bounds=bounds)

    def evaluate_true(self, X: Tensor) -> Tensor:
        """Evaluate the noise-free function on a `... x 1` tensor.

        Returns:
            A tensor of shape `...` (the trailing singleton dimension is squeezed).
        """
        assert X.shape[-1] == 1
        sqrt_12 = math.sqrt(12.0)
        # Reference implementation (Andrade & Takeda):
        # https://github.com/andrade-stats/TrimmedMarginalLikelihoodGP/blob/f810d4384454682a758fb79fd7352c4b3d4db7f4/simDataGeneration.py#L226
        # Standardize: for X ~ Uniform(0, 1), (X - 0.5) * sqrt(12) has zero mean
        # and unit variance.
        X = (X - 0.5) * sqrt_12
        # Undo the standardization inside the sine, so the result equals
        # 3.235 * sin(pi * X) - 2.058 on the original X (up to rounding).
        Y = 3.235 * (torch.sin(((X / sqrt_12) + 0.5) * torch.pi)) - 2.058
        return Y.squeeze(-1)


class SumOfSines(SyntheticTestFunction):
    """Sum of per-dimension sinusoids: `f(x) = sum_i sin(2 * pi * frequency * x_i)`."""

    def __init__(
        self,
        dim: int = 1,
        noise_std: Optional[float] = None,
        negate: bool = False,
        bounds: Optional[List[Tuple[float, float]]] = None,
        frequency: float = 2.0,
    ) -> None:
        """
        Args:
            dim: The (input) dimension.
            noise_std: Standard deviation of the observation noise.
            negate: If True, negate the function.
            bounds: Custom bounds specified as (lower, upper) pairs. Defaults to
                the unit hypercube.
            frequency: Frequency of each sinusoid (cycles per unit input).
        """
        self.dim = dim
        self.frequency = frequency
        if bounds is None:
            bounds = [(0, 1)] * dim
        super().__init__(noise_std=noise_std, negate=negate, bounds=bounds)

    def evaluate_true(self, X: Tensor) -> Tensor:
        """Evaluate the noise-free function; sums over the last dimension."""
        angular = 2 * torch.pi * self.frequency
        return torch.sin(angular * X).sum(dim=-1)
