from __future__ import annotations

import math
import warnings
from copy import deepcopy

from typing import Any, Callable, List, Optional, Tuple, Union

import torch
from botorch.acquisition.cached_cholesky import CachedCholeskyMCAcquisitionFunction
from botorch.acquisition.multi_objective import MultiObjectiveMCAcquisitionFunction
from botorch.acquisition.multi_objective.objective import MCMultiOutputObjective
from botorch.acquisition.multi_objective.utils import (
    prune_inferior_points_multi_objective,
)
from botorch.exceptions.errors import UnsupportedError
from botorch.exceptions.warnings import BotorchWarning
from botorch.models.model import Model
from botorch.sampling.base import MCSampler
from botorch.utils.multi_objective.box_decompositions.box_decomposition_list import (
    BoxDecompositionList,
)
from botorch.utils.multi_objective.box_decompositions.dominated import (
    DominatedPartitioning,
)
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
    FastNondominatedPartitioning,
    NondominatedPartitioning,
)
from botorch.utils.multi_objective.box_decompositions.utils import (
    _pad_batch_pareto_frontier,
)
from botorch.utils.safe_math import logdiffexp
from botorch.utils.torch import BufferDict
from botorch.utils.transforms import (
    concatenate_pending_points,
    is_fully_bayesian,
    match_batch_shape,
    t_batch_mode_transform,
)
from .qlogei import (
    cauchy,
    compute_log_constraint_indicator,
    fatmax,
    log1pexp,
    log_fatplus,
    logmeanexp,
)
from torch import Tensor

# Default temperature of the smoothed ReLU (`log_fatplus`) applied to improvements.
TAU_RELU: float = 0.001
# Default temperature of the smooth min/max approximations (`fatmin`, `fatminimum`).
TAU_MAX: float = 0.1
# Default for whether the temperatures above are rescaled by a data-dependent spread.
STANDARDIZE_TAU: bool = False


class qLogExpectedHypervolumeImprovement(MultiObjectiveMCAcquisitionFunction):
    r"""q-Log Expected Hypervolume Improvement (qLogEHVI) supporting m>=2 outcomes.

    Computes the logarithm of a Monte-Carlo estimate of the (optionally
    constraint-weighted) expected hypervolume improvement. Smooth "fat"
    approximations of the ReLU, min and max operations (`log_fatplus`, `fatmin`,
    `fatminimum`) are used so that the inclusion-exclusion computation can be
    carried out entirely in log-space.

    See [Daulton2020qehvi]_ for details on the underlying (non-log) qEHVI.
    """

    def __init__(
        self,
        model: Model,
        ref_point: Union[List[float], Tensor],
        partitioning: NondominatedPartitioning,
        sampler: Optional[MCSampler] = None,
        objective: Optional[MCMultiOutputObjective] = None,
        constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
        X_pending: Optional[Tensor] = None,
        eta: Optional[Union[Tensor, float]] = 1e-2,
        tau_relu: float = TAU_RELU,
        tau_max: float = TAU_MAX,
        standardize_tau_relu: bool = STANDARDIZE_TAU,
        standardize_tau_max: bool = STANDARDIZE_TAU,
    ) -> None:
        r"""q-Log Expected Hypervolume Improvement supporting m>=2 outcomes.

        See [Daulton2020qehvi]_ for details on the underlying (non-log) qEHVI.

        Example:
            >>> model = SingleTaskGP(train_X, train_Y)
            >>> ref_point = [0.0, 0.0]
            >>> qLogEHVI = qLogExpectedHypervolumeImprovement(
            ...     model, ref_point, partitioning
            ... )
            >>> log_qehvi = qLogEHVI(test_X)

        Args:
            model: A fitted model.
            ref_point: A list or tensor with `m` elements representing the reference
                point (in the outcome space) w.r.t. to which compute the hypervolume.
                This is a reference point for the objective values (i.e. after
                applying `objective` to the samples). It is converted to a tensor
                with the dtype and device of `partitioning.pareto_Y`.
            partitioning: A `NondominatedPartitioning` module that provides the non-
                dominated front and a partitioning of the non-dominated space in hyper-
                rectangles. If constraints are present, this partitioning must only
                include feasible points.
            sampler: The sampler used to draw base samples. If not given,
                a sampler is generated using `get_sampler`.
            objective: The MCMultiOutputObjective under which the samples are evaluated.
                Defaults to `IdentityMultiOutputObjective()`.
            constraints: A list of callables, each mapping a Tensor of dimension
                `sample_shape x batch-shape x q x m` to a Tensor of dimension
                `sample_shape x batch-shape x q`, where negative values imply
                feasibility. The acquisition function will compute the log of the
                expected feasible hypervolume improvement.
            X_pending: A `batch_shape x m x d`-dim Tensor of `m` design points that
                have been submitted for function evaluation but have not yet been
                evaluated. Concatenated into `X` upon forward call. Copied and set
                to have no gradient.
            eta: The temperature parameter of the smoothed constraint indicator
                (computed by `compute_log_constraint_indicator` with `fatten=True`).
                In case of a float the same eta is used for every constraint in
                constraints. In case of a tensor the length of the tensor must match
                the number of provided constraints. The i-th constraint is then
                estimated with the i-th eta value.
            tau_relu: Temperature of the smoothed ReLU (`log_fatplus`) applied to the
                improvement of the objective values over the cell lower bounds.
            tau_max: Temperature of the smooth minimum approximations (`fatmin` and
                `fatminimum`) used to intersect the hyperrectangles of a q-subset and
                to clip the improvements at the cells' side lengths.
            standardize_tau_relu: If True, `tau_relu` is multiplied by the standard
                deviation of the improvements over the MC sample dimension, clamped
                at a maximum of 1.
            standardize_tau_max: Passed as `standardize` to `fatmin` and
                `fatminimum`.

        Raises:
            ValueError: If the length of `ref_point` does not match
                `partitioning.num_outcomes`.
        """
        if len(ref_point) != partitioning.num_outcomes:
            raise ValueError(
                "The length of the reference point must match the number of outcomes. "
                f"Got ref_point with {len(ref_point)} elements, but expected "
                f"{partitioning.num_outcomes}."
            )
        ref_point = torch.as_tensor(
            ref_point,
            dtype=partitioning.pareto_Y.dtype,
            device=partitioning.pareto_Y.device,
        )
        super().__init__(
            model=model,
            sampler=sampler,
            objective=objective,
            constraints=constraints,
            eta=eta,
            X_pending=X_pending,
        )
        self.register_buffer("ref_point", ref_point)
        cell_bounds = partitioning.get_hypercell_bounds()
        self.register_buffer("cell_lower_bounds", cell_bounds[0])
        self.register_buffer("cell_upper_bounds", cell_bounds[1])
        # -1 means no subset indices have been cached yet (see
        # `_cache_q_subset_indices`, which recomputes them whenever q changes).
        self.q_out = -1
        self.q_subset_indices = BufferDict()
        # stored privately because `tau_relu` is a method that may rescale it
        self._tau_relu = tau_relu
        self.tau_max = tau_max
        self.standardize_tau_relu = standardize_tau_relu
        self.standardize_tau_max = standardize_tau_max

    def tau_relu(self, Y: Tensor) -> Union[float, Tensor]:
        r"""Return the temperature of the smoothed ReLU for the values `Y`.

        Args:
            Y: A Tensor whose first dimension is the MC sample dimension
                (in this class, the improvements computed in `_log_improvement`).

        Returns:
            `self._tau_relu` if `self.standardize_tau_relu` is False. Otherwise
            `self._tau_relu` times the standard deviation of `Y` over dim 0
            (kept as a singleton dim), clamped at a maximum of 1.
        """
        tau = self._tau_relu
        if self.standardize_tau_relu:  # scaling standard deviation over mc dimension
            tau = tau * Y.std(dim=0, keepdim=True).clamp_max(1)
        return tau

    def _cache_q_subset_indices(self, q_out: int) -> None:
        r"""Cache indices corresponding to all subsets of `q_out`.

        This means that consecutive calls to `forward` with the same
        `q_out` will not recompute the indices for all (2^q_out - 1) subsets.

        Note: this will use more memory than regenerating the indices
        for each i and then deleting them, but it will be faster for
        repeated evaluations (e.g. during optimization).

        Args:
            q_out: The batch size of the objectives. This is typically equal
                to the q-batch size of `X`. However, if using a set valued
                objective (e.g., MVaR) that produces `s` objective values for
                each point on the q-batch of `X`, we need to properly account
                for each objective while calculating the hypervolume contributions
                by using `q_out = q * s`.
        """
        if q_out != self.q_out:
            tkwargs = {"dtype": torch.long, "device": self.ref_point.device}
            indices = torch.arange(q_out, **tkwargs)
            # key "q_choose_{i}" holds a `(q_out choose i) x i` tensor of indices
            self.q_subset_indices = BufferDict(
                {
                    f"q_choose_{i}": torch.combinations(indices, i)
                    for i in range(1, q_out + 1)
                }
            )
            self.q_out = q_out

    def _compute_log_qehvi(self, samples: Tensor, X: Optional[Tensor] = None) -> Tensor:
        r"""Compute the log of the expected (feasible) hypervolume improvement given
        MC samples.

        Args:
            samples: A `n_samples x batch_shape x q' x m`-dim tensor of samples.
            X: A `batch_shape x q x d`-dim tensor of inputs.

        Returns:
            A `batch_shape x (model_batch_shape)`-dim tensor of the log expected
            hypervolume improvement for each batch.
        """
        # Note that the objective may subset the outcomes (e.g. this will usually happen
        # if there are constraints present).
        obj = self.objective(samples, X=X)  # mc_samples x batch_shape x q x m
        q = obj.shape[-2]
        if self.constraints is not None:
            # log of the smoothed feasibility indicator of each of the q points
            log_feas_weights = compute_log_constraint_indicator(
                constraints=self.constraints,
                samples=samples,
                eta=self.eta,
                fatten=True,
            )
        self._cache_q_subset_indices(q_out=q)
        batch_shape = obj.shape[:-2]  # mc_samples x batch_shape
        # areas tensor is `mc_samples x batch_shape x num_cells x 2`-dim
        log_areas_per_segment = torch.full(
            size=(
                *batch_shape,
                self.cell_lower_bounds.shape[-2],  # num_cells
                2,  # for even and odd terms
            ),
            fill_value=-torch.inf,
            dtype=obj.dtype,
            device=obj.device,
        )

        cell_batch_ndim = self.cell_lower_bounds.ndim - 2
        # conditionally adding mc_samples dim if cell_batch_ndim > 0
        # adding ones to shape equal in number to batch_shape_ndim - cell_batch_ndim
        # adding cell_bounds batch shape w/o 1st dimension
        sample_batch_view_shape = torch.Size(
            [
                batch_shape[0] if cell_batch_ndim > 0 else 1,
                *[1 for _ in range(len(batch_shape) - max(cell_batch_ndim, 1))],
                *self.cell_lower_bounds.shape[1:-2],
            ]
        )
        view_shape = (
            *sample_batch_view_shape,
            self.cell_upper_bounds.shape[-2],  # num_cells
            1,  # adding for q_choose_i dimension
            self.cell_upper_bounds.shape[-1],  # num_objectives
        )

        for i in range(1, self.q_out + 1):
            # TODO: we could use batches to compute (q choose i) and (q choose q-i)
            # simultaneously since subsets of size i and q-i have the same number of
            # elements. This would decrease the number of iterations, but increase
            # memory usage.
            q_choose_i = self.q_subset_indices[f"q_choose_{i}"]  # q_choose_i x i
            # this tensor is mc_samples x batch_shape x (q_choose_i * i) x m
            obj_subsets = obj.index_select(dim=-2, index=q_choose_i.view(-1))
            obj_subsets = obj_subsets.view(
                obj.shape[:-2] + q_choose_i.shape + obj.shape[-1:]
            )  # mc_samples x batch_shape x q_choose_i x i x m

            # NOTE: the order of operations in non-log _compute_qehvi is 3), 1), 2).
            # since 3) moved above 1), _log_improvement adds another Tensor dimension
            # that keeps track of num_cells.

            # 1) computes log smoothed improvement over the cell lower bounds.
            # mc_samples x batch_shape x num_cells x q_choose_i x i x m
            log_improvement_i = self._log_improvement(obj_subsets, view_shape)

            # 2) take the minimum log improvement over all i subsets.
            # since all hyperrectangles share one vertex, the opposite vertex of the
            # overlap is given by the component-wise minimum.
            # negative of maximum of negative log_improvement is approximation to min.
            log_improvement_i = fatmin(
                log_improvement_i,
                dim=-2,
                tau=self.tau_max,
                standardize=self.standardize_tau_max,
            )  # mc_samples x batch_shape x num_cells x q_choose_i x m

            # 3) compute the log lengths of the cells' sides.
            # mc_samples x batch_shape x num_cells x q_choose_i x m
            log_lengths_i = self._log_cell_lengths(log_improvement_i, view_shape)
            # 4) take product over hyperrectangle side lengths to compute area (m-dim).
            # after, log_areas_i is mc_samples x batch_shape x num_cells x q_choose_i
            log_areas_i = log_lengths_i.sum(dim=-1)  # = lengths_i.prod(dim=-1)

            # 5) if constraints are present, weight each subset's area by the product
            # of its points' smoothed feasibility indicators, i.e. add the sum of
            # their logs (broadcast over the num_cells dimension).
            if self.constraints is not None:
                log_feas_subsets = log_feas_weights.index_select(
                    dim=-1, index=q_choose_i.view(-1)
                ).view(log_feas_weights.shape[:-1] + q_choose_i.shape)
                log_areas_i = log_areas_i + log_feas_subsets.unsqueeze(-3).sum(dim=-1)

            # 6) sum over all subsets of size i, i.e. reduce over q_choose_i-dim
            # after, log_areas_i is mc_samples x batch_shape x num_cells
            log_areas_i = torch.logsumexp(log_areas_i, dim=-1)  # = areas_i.sum(dim=-1)

            # 7) Using the inclusion-exclusion principle, set the sign to be positive
            # for subsets of odd sizes and negative for subsets of even size
            # in non-log space: areas_per_segment += (-1) ** (i + 1) * areas_i,
            # but here in log space, we need to keep track of sign: odd-sized terms
            # accumulate in index 1, even-sized terms in index 0.
            log_areas_per_segment[..., i % 2] = logplusexp(
                log_areas_per_segment[..., i % 2],
                log_areas_i,
            )

        # 8) subtract even from odd log area terms
        log_areas_per_segment = logdiffexp(
            log_a=log_areas_per_segment[..., 0], log_b=log_areas_per_segment[..., 1]
        )

        # 9) sum over segments (n_cells-dim) and average over MC samples
        return logmeanexp(torch.logsumexp(log_areas_per_segment, dim=-1), dim=0)

    def _log_improvement(
        self, obj_subsets: Tensor, view_shape: Union[Tuple, torch.Size]
    ) -> Tensor:
        r"""Compute the log of the smoothed, non-negative improvement of the
        objective values over each cell's lower bound.

        Args:
            obj_subsets: A `mc_samples x batch_shape x q_choose_i x i x m`-dim tensor
                of objective values of all q-subsets of size i.
            view_shape: The shape used to make the cell bounds broadcastable with
                the samples (as constructed in `_compute_log_qehvi`).

        Returns:
            A `mc_samples x batch_shape x num_cells x q_choose_i x i x m`-dim tensor
            `log_fatplus(obj_subsets - cell_lower_bounds, tau=self.tau_relu(...))`.
        """
        # smooth out the clamp and take the log (previous step 3)
        # subtract cell lower bounds, clamp min at zero, but first
        # make obj_subsets broadcastable with cell bounds:
        # mc_samples x batch_shape x (num_cells = 1) x q_choose_i x i x m
        obj_subsets = obj_subsets.unsqueeze(-4)
        # making cell bounds broadcastable with obj_subsets:
        # (mc_samples = 1) x (batch_shape = 1) x num_cells x 1 x (i = 1) x m
        cell_lower_bounds = self.cell_lower_bounds.view(view_shape).unsqueeze(-3)
        Z = obj_subsets - cell_lower_bounds
        log_Zi = log_fatplus(Z, tau=self.tau_relu(Z))
        return log_Zi  # mc_samples x batch_shape x num_cells x q_choose_i x i x m

    def _log_cell_lengths(
        self, log_improvement_i: Tensor, view_shape: Union[Tuple, torch.Size]
    ) -> Tensor:
        r"""Smoothly clip the log improvements at the log side lengths of the cells.

        The cell upper bounds are clamped at 1e10 (double precision) or 1e8 (other
        dtypes) before subtracting the lower bounds (presumably to keep infinite
        upper bounds finite).

        Args:
            log_improvement_i: A `mc_samples x batch_shape x num_cells x q_choose_i
                x m`-dim tensor of log improvements.
            view_shape: The shape used to make the cell bounds broadcastable with
                the samples (as constructed in `_compute_log_qehvi`).

        Returns:
            A `mc_samples x batch_shape x num_cells x q_choose_i x m`-dim tensor of
            `fatminimum(log_improvement_i, log(upper - lower))`.
        """
        cell_upper_bounds = self.cell_upper_bounds.clamp_max(
            1e10 if log_improvement_i.dtype == torch.double else 1e8
        )  # num_cells x num_objectives
        # add batch-dim to compute area for each segment (pseudo-pareto-vertex)
        log_cell_lengths = (
            (cell_upper_bounds - self.cell_lower_bounds).log().view(view_shape)
        )  # (mc_samples = 1) x (batch_shape = 1) x n_cells x (q_choose_i = 1) x m
        # mc_samples x batch_shape x num_cells x q_choose_i x m
        return fatminimum(
            log_improvement_i,
            log_cell_lengths,
            tau=self.tau_max,
            standardize=self.standardize_tau_max,
        )

    @concatenate_pending_points
    @t_batch_mode_transform()
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate qLogEHVI on the candidate set `X`.

        Args:
            X: A `batch_shape x q x d`-dim Tensor of t-batches with `q` `d`-dim
                design points each. Pending points (if any) are concatenated to `X`
                by the `concatenate_pending_points` decorator.

        Returns:
            The log expected hypervolume improvement computed by
            `_compute_log_qehvi` (see its docstring for the output shape).
        """
        posterior = self.model.posterior(X)
        samples = self.get_posterior_samples(posterior)
        return self._compute_log_qehvi(samples=samples, X=X)


############################################### utils #################################################
def logplusexp(a: Tensor, b: Tensor) -> Tensor:
    """Computes log(exp(a) + exp(b)) similar to logsumexp.

    The inputs are reordered element-wise so that `hi >= lo`, which allows writing
    the result as `hi + log1pexp(lo - hi)` with `lo - hi <= 0`, so the exponential
    inside `log1pexp` never overflows.

    Args:
        a: A Tensor of log-values.
        b: A Tensor of log-values, broadcastable with `a`.

    Returns:
        A Tensor with the broadcasted shape of `a` and `b` containing
        `log(exp(a) + exp(b))`. Where `a` and `b` are the same infinity (e.g. both
        `-inf`, i.e. both terms are zero), that infinity is returned instead of
        the NaN produced by `inf - inf`.
    """
    rev_cond = b < a  # condition for reversal of inputs
    # torch.where is a no-op where rev_cond is False, so no `.any()` check (and
    # no host synchronization) is needed.
    lo = torch.where(rev_cond, b, a)
    hi = torch.where(rev_cond, a, b)
    # `inf - inf` is NaN. If both inputs are the same infinity, the answer is that
    # infinity, which `hi + log1pexp(0) = hi + log(2)` reproduces exactly.
    same_inf = hi.isinf() & (lo == hi)
    diff = (lo - hi).masked_fill(same_inf, 0.0)
    return hi + log1pexp(diff)  # reversal ensures we get highest relative accuracy


def fatmaximum(
    a: Tensor, b: Tensor, tau: Union[float, Tensor] = 1.0, standardize: bool = False
) -> Tensor:
    """Computes a smooth approximation to torch.maximum(a, b) with a fat tail.

    With `hi = max(a, b)` and `lo = min(a, b)` (element-wise), the result is
    `hi + tau * log(1 + cauchy((lo - hi) / tau))`. The larger input contributes
    `cauchy(0)` and the smaller one `cauchy((lo - hi) / tau)`, as in a two-element
    fat max. The result is therefore never below `max(a, b)`, and it approaches
    `max(a, b)` as the inputs separate (e.g. `fatmaximum(-inf, b) == b` for finite
    `b`).

    NOTE(review): assumes `cauchy` (imported from `.qlogei`) is a non-negative
    kernel with `cauchy(0) == 1` that decays to 0 at +-inf, e.g.
    `1 / (1 + x**2)` -- confirm against its definition.

    Args:
        a: A Tensor.
        b: A Tensor broadcastable with `a`.
        tau: Temperature; smaller values give a tighter approximation.
        standardize: If True, `tau` is multiplied by the standard deviation of the
            pair `(a, b)`, i.e. `|a - b| / sqrt(2)`, clamped at a maximum of 1.

    Returns:
        A Tensor with the broadcasted shape of `a` and `b`. Where `max(a, b)` is
        infinite, it is returned unchanged.
    """
    # Sort element-wise so that `hi` holds the larger and `lo` the smaller input,
    # which makes `lo - hi <= 0`. torch.where is a no-op where rev_cond is False.
    rev_cond = b < a  # condition for reversal of inputs
    lo = torch.where(rev_cond, b, a)
    hi = torch.where(rev_cond, a, b)
    # Where the max is infinite the result is `hi` itself. Zero the difference
    # there so no `inf - inf = NaN` enters the (unselected) smooth branch, which
    # would otherwise poison gradients flowing through torch.where.
    is_inf = hi.isinf()
    amb = (lo - hi).masked_fill(is_inf, 0)
    scale = tau
    if standardize:
        # the standard deviation of the two values (a, b) is |a - b| / sqrt(2)
        tau = tau * (amb.abs() / math.sqrt(2)).clamp_max(1)
        # tau is 0 where amb == 0 (e.g. a == b). Divide by 1 there instead of
        # computing 0 / 0 = NaN; the smooth term `tau * ...` vanishes anyway.
        scale = tau.masked_fill(tau == 0, 1.0)
    # log1p: the larger input contributes cauchy(0) == 1 to the sum inside the log.
    return torch.where(is_inf, hi, hi + tau * torch.log1p(cauchy(amb / scale)))


def fatminimum(
    a: Tensor, b: Tensor, tau: Union[float, Tensor] = 1.0, standardize: bool = False
) -> Tensor:
    """Computes a smooth approximation to torch.minimum(a, b) with a fat tail.

    Implemented through the identity `min(a, b) = -max(-a, -b)`, delegating the
    smooth maximum to `fatmaximum`.

    Args:
        a: A Tensor.
        b: A Tensor broadcastable with `a`.
        tau: Temperature, forwarded to `fatmaximum`.
        standardize: Forwarded to `fatmaximum`.

    Returns:
        The negated `fatmaximum` of the negated inputs.
    """
    neg_a = -a
    neg_b = -b
    smooth_max_of_negatives = fatmaximum(
        neg_a, neg_b, tau=tau, standardize=standardize
    )
    return -smooth_max_of_negatives


def fatmin(
    x: Tensor, dim: int, tau: Union[float, Tensor] = 1.0, standardize: bool = False
) -> Tensor:
    """Computes a smooth approximation to the minimum of `x` along `dim`.

    Implemented through the identity `min(x) = -max(-x)`, delegating the smooth
    maximum to `fatmax` (imported from `.qlogei`).

    Args:
        x: A Tensor.
        dim: The dimension to reduce over.
        tau: Temperature, forwarded to `fatmax`.
        standardize: Forwarded to `fatmax`.

    Returns:
        The negated `fatmax` of `-x` along `dim`.
    """
    negated = -x
    smooth_max = fatmax(x=negated, dim=dim, tau=tau, standardize=standardize)
    return -smooth_max
