# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


"""
The hypervolume knowledge gradient acquisition function (HVKG).

References:

.. [Daulton2023hvkg]
    S. Daulton, M. Balandat, E. Bakshy. Hypervolume Knowledge Gradient: A
    Lookahead Approach for Multi-Objective Bayesian Optimization with Partial
    Information. Proceedings of the 40th International Conference on Machine
    Learning, 2023.
"""

import warnings
from collections.abc import Callable
from copy import deepcopy
from typing import Any

import torch
from torch import Tensor
from botorch import settings
from botorch.acquisition.acquisition import (
    AcquisitionFunction,
    OneShotAcquisitionFunction,
)

from botorch.acquisition.cost_aware import CostAwareUtility
from botorch.acquisition.knowledge_gradient import ProjectedAcquisitionFunction
from botorch.acquisition.multi_objective.base import MultiObjectiveMCAcquisitionFunction
from botorch.acquisition.multi_objective.monte_carlo import (
    qExpectedHypervolumeImprovement,
)
from botorch.acquisition.multi_objective.objective import MCMultiOutputObjective
from botorch.exceptions.errors import UnsupportedError
from botorch.exceptions.warnings import NumericsWarning
from botorch.models.deterministic import PosteriorMeanModel
from botorch.sampling.base import MCSampler
from botorch.sampling.list_sampler import ListSampler
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.sampling.stochastic_samplers import StochasticSampler
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
    FastNondominatedPartitioning,
)
from botorch.utils.transforms import (
    average_over_ensemble_models,
    t_batch_mode_transform,
)

from rescue.models.causal_gp.multitask import CausalMultitaskGP
from rescue.models.causal_gp.multitask_multifidelity import CausalMultitaskMultifidelityGP
from rescue.models.causal_gp.fantasy_model import fantasize

class qHypervolumeKnowledgeGradient(
    MultiObjectiveMCAcquisitionFunction,
    OneShotAcquisitionFunction,
):
    """Batch Hypervolume Knowledge Gradient using one-shot optimization.

    The hypervolume knowledge gradient seeks to maximize the difference in
    hypervolume of the hypervolume-maximizing set of a fixed size after
    conditioning on the unknown observation(s) that would be received if `X`
    were evaluated. See [Daulton2023hvkg]_ for details.

    This computes the batch Hypervolume Knowledge Gradient using fantasies for
    the outer expectation and MC-sampling for the inner expectation.

    In addition to the design variables, the input `X` also includes variables
    for the optimal designs of each of the fantasy models (note this is
    `num_fantasies * num_pareto` "pseudo" points). For a fixed number of
    fantasies, all points in `X` can be optimized in a "one-shot" fashion.
    """

    def __init__(
        self,
        model: CausalMultitaskGP | CausalMultitaskMultifidelityGP,
        ref_point: Tensor,
        num_fantasies: int = 8,
        num_pareto: int = 10,
        sampler: MCSampler | None = None,
        objective: MCMultiOutputObjective | None = None,
        constraints: list[Callable[[Tensor], Tensor]] | None = None,
        inner_sampler: MCSampler | None = None,
        current_value: Tensor | None = None,
        use_posterior_mean: bool = True,
        cost_aware_utility: CostAwareUtility | None = None,
    ) -> None:
        r"""
        Initialize qHypervolumeKnowledgeGradient.

        Args:
            model (CausalMultitaskGP | CausalMultitaskMultifidelityGP):
                The GP model.
            ref_point (Tensor): Reference point for hypervolume.
            num_fantasies (int): Number of fantasy samples.
            num_pareto (int): Number of points in the hypervolume-maximizing
                set of each fantasy model.
            sampler (MCSampler | None): Sampler used to draw the fantasy
                observations. May be a single `MCSampler` or a `ListSampler`;
                its sample shape must be `torch.Size([num_fantasies])`.
                Defaults to a `SobolQMCNormalSampler` with that sample shape.
            objective (MCMultiOutputObjective | None): Objective.
            constraints (list[Callable[[Tensor], Tensor]] | None):
                Constraints.
            inner_sampler (MCSampler | None): Sampler for the inner
                acquisition. Defaults to a 32-sample `SobolQMCNormalSampler`.
            current_value (Tensor | None): Current value; subtracted from the
                fantasy values when provided.
            use_posterior_mean (bool): Whether the inner value function is
                evaluated on the posterior mean of the fantasy model.
            cost_aware_utility (CostAwareUtility | None): Cost-aware
                utility. Requires `current_value`.

        Raises:
            ValueError: If the provided `sampler`'s sample shape does not
                equal `torch.Size([num_fantasies])`.
            UnsupportedError: If `cost_aware_utility` is given without
                `current_value`.
        """
        expected_shape = torch.Size([num_fantasies])
        if sampler is None:
            sampler = SobolQMCNormalSampler(sample_shape=expected_shape)
        else:
            # A ListSampler holds one sub-sampler per output, so validate its
            # first member; any other MCSampler is validated directly (the
            # default sampler above is such a plain MCSampler).
            base_sampler = (
                sampler.samplers[0] if isinstance(sampler, ListSampler) else sampler
            )
            if base_sampler.sample_shape != expected_shape:
                raise ValueError(
                    f"The sampler shape must match num_fantasies={num_fantasies}."
                )
        super().__init__(model=model)

        if inner_sampler is None:
            inner_sampler = SobolQMCNormalSampler(sample_shape=torch.Size([32]))
        if current_value is None and cost_aware_utility is not None:
            raise UnsupportedError(
                "Cost-aware HVKG requires current_value to be specified."
            )
        self.register_buffer("ref_point", ref_point)
        self.sampler = sampler
        self.objective = objective
        self.constraints = constraints
        self.inner_sampler = inner_sampler
        self.num_fantasies = num_fantasies
        self.num_pareto = num_pareto
        # Total number of trailing one-shot points appended to each q-batch.
        self.num_pseudo_points = num_fantasies * num_pareto
        self.current_value = current_value
        self.use_posterior_mean = use_posterior_mean
        self.cost_aware_utility = cost_aware_utility
        self._cost_sampler = None

    @property
    def cost_sampler(self):
        """Lazily created copy of `self.sampler` passed to the cost utility."""
        if self._cost_sampler is None:
            # Note: Using the deepcopy here is essential. Removing this poses a
            # problem if the base model and the cost model have a different number
            # of outputs or test points (this would be caused by expand), as this
            # would trigger re-sampling the base samples in the fantasy sampler.
            # By cloning the sampler here, the right thing will happen if the
            # the sizes are compatible, if they are not this will result in
            # samples being drawn using different base samples, but it will at
            # least avoid changing state of the fantasy sampler.
            self._cost_sampler = deepcopy(self.sampler)
        return self._cost_sampler

    @t_batch_mode_transform()
    @average_over_ensemble_models
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate qHypervolumeKnowledgeGradient on the candidate set `X`.

        Args:
            X: A `b x (q + num_pseudo_points) x d` Tensor with `b` t-batches of
                `q + num_pseudo_points` design points each, where
                `num_pseudo_points = num_fantasies * num_pareto`. The first `q`
                points are the q-batch of design points; the trailing
                `num_pseudo_points` are the current solutions of the inner
                optimization problem (one set of `num_pareto` points per
                fantasy model).

                `X_actual = X[..., :-num_pseudo_points, :]`
                `X_actual.shape = b x q x d`

                `X_fantasies` is `X[..., -num_pseudo_points:, :]` rearranged
                (by `_split_hvkg_fantasy_points`) to
                `num_fantasies x b x num_pareto x d`.

        Returns:
            A Tensor of shape `b`. For t-batch b, the q-HVKG value of the design
                `X_actual[b]` is averaged across the fantasy models, where
                `X_fantasies[i, b]` is chosen as the final set for the `i`-th
                fantasy model.
                NOTE: If `current_value` is not provided, then this is not the
                true HVKG value of `X_actual[b]`, and `X_fantasies[:, b]` must
                be maximized at fixed `X_actual[b]`.
        """
        X_actual, X_fantasies = _split_hvkg_fantasy_points(
            X=X, n_f=self.num_fantasies, num_pareto=self.num_pareto
        )

        fantasy_model = fantasize(
            model=self.model,
            X=X_actual,
            sampler=self.sampler,
        )

        # get the value function
        value_function = _get_hv_value_function(
            model=fantasy_model,
            ref_point=self.ref_point,
            objective=self.objective,
            constraints=self.constraints,
            sampler=self.inner_sampler,
            use_posterior_mean=self.use_posterior_mean,
        )

        # make sure to propagate gradients to the fantasy model train inputs
        with settings.propagate_grads(True):
            # X_fantasies is already num_fantasies x batch_shape x num_pareto x d
            values = value_function(X=X_fantasies)  # num_fantasies x b

        if self.current_value is not None:
            values = values - self.current_value

        if self.cost_aware_utility is not None:
            values = self.cost_aware_utility(
                X=X_actual,
                deltas=values,
                sampler=self.cost_sampler,
            )

        # return average over the fantasy samples
        return values.mean(dim=0)

    def get_augmented_q_batch_size(self, q: int) -> int:
        r"""Get augmented q batch size for one-shot optimization.

        Args:
            q: The number of candidates to consider jointly.

        Returns:
            The augmented size for one-shot optimization (including the
            `num_fantasies * num_pareto` variables parameterizing the fantasy
            solutions).
        """
        return q + self.num_pseudo_points

    def extract_candidates(self, X_full: Tensor) -> Tensor:
        r"""We only return X as the set of candidates post-optimization.

        Args:
            X_full: A `b x (q + num_pseudo_points) x d`-dim Tensor with `b`
                t-batches of `q + num_pseudo_points` design points each.

        Returns:
            A `b x q x d`-dim Tensor with `b` t-batches of `q` design points each.
        """
        return X_full[..., : -self.num_pseudo_points, :]


class qMultiFidelityHypervolumeKnowledgeGradient(qHypervolumeKnowledgeGradient):
    r"""Batch Hypervolume Knowledge Gradient for multi-fidelity optimization.

    See [Daulton2023hvkg]_ for details.

    A version of `qHypervolumeKnowledgeGradient` that supports multi-fidelity
    optimization via a `CostAwareUtility` and the `project` operator. If neither
    is set, this acquisition function reduces to
    `qHypervolumeKnowledgeGradient`. The `expand` operator (trace observations)
    is not supported.
    """

    def __init__(
        self,
        model: CausalMultitaskGP | CausalMultitaskMultifidelityGP,
        ref_point: Tensor,
        target_fidelities: dict[int, float],
        num_fantasies: int = 8,
        num_pareto: int = 10,
        sampler: MCSampler | None = None,
        objective: MCMultiOutputObjective | None = None,
        constraints: list[Callable[[Tensor], Tensor]] | None = None,
        inner_sampler: MCSampler | None = None,
        current_value: Tensor | None = None,
        cost_aware_utility: CostAwareUtility | None = None,
        project: Callable[[Tensor], Tensor] = lambda X: X,
        use_posterior_mean: bool = True,
        **kwargs: Any,
    ) -> None:
        r"""
        Initialize qMultiFidelityHypervolumeKnowledgeGradient.

        Args:
            model (CausalMultitaskGP | CausalMultitaskMultifidelityGP):
                The GP model.
            ref_point (Tensor): Reference point for hypervolume.
            target_fidelities (dict[int, float]): Target fidelity
                values. Stored on the instance; `forward` does not read it.
            num_fantasies (int): Number of fantasy samples.
            num_pareto (int): Number of points in the hypervolume-maximizing
                set of each fantasy model.
            sampler (MCSampler | None): MC sampler for fantasies.
            objective (MCMultiOutputObjective | None): Objective.
            constraints (list[Callable[[Tensor], Tensor]] | None):
                Constraints.
            inner_sampler (MCSampler | None): Sampler for inner
                acquisition.
            current_value (Tensor | None): Current value.
            cost_aware_utility (CostAwareUtility | None): Cost-aware
                utility.
            project (Callable[[Tensor], Tensor]): Projection applied to the
                pseudo points before the inner value function is evaluated.
                Defaults to the identity.
            use_posterior_mean (bool): Whether to use posterior mean.
            **kwargs: Additional keyword arguments. Only `expand` is inspected
                (a non-None value raises); all others are ignored.

        Raises:
            NotImplementedError: If a non-None `expand` is passed.
        """

        super().__init__(
            model=model,
            ref_point=ref_point,
            num_fantasies=num_fantasies,
            num_pareto=num_pareto,
            sampler=sampler,
            objective=objective,
            constraints=constraints,
            inner_sampler=inner_sampler,
            current_value=current_value,
            use_posterior_mean=use_posterior_mean,
            cost_aware_utility=cost_aware_utility,
        )
        self.project = project
        if kwargs.get("expand") is not None:
            raise NotImplementedError(
                "Trace observations are not currently supported "
                "by `qMultiFidelityHypervolumeKnowledgeGradient`."
            )
        # Trace observations are unsupported, so `expand` is always identity.
        self.expand = lambda X: X
        self.target_fidelities = target_fidelities

    @t_batch_mode_transform()
    @average_over_ensemble_models
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate qMultiFidelityHypervolumeKnowledgeGradient on `X`.

        Args:
            X: A `b x (q + num_pseudo_points) x d` Tensor with `b` t-batches of
                `q + num_pseudo_points` design points each, where
                `num_pseudo_points = num_fantasies * num_pareto`. The first `q`
                points are the q-batch of design points; the trailing
                `num_pseudo_points` are the current solutions of the inner
                optimization problem (one set of `num_pareto` points per
                fantasy model).

                `X_actual = X[..., :-num_pseudo_points, :]`
                `X_actual.shape = b x q x d`

                `X_fantasies` is `X[..., -num_pseudo_points:, :]` rearranged
                (by `_split_hvkg_fantasy_points`) to
                `num_fantasies x b x num_pareto x d`.

                In addition, `X` may be augmented with fidelity parameters as
                part of the `d`-dimension. Projecting fidelities to the target
                fidelity is handled by `project`.

        Returns:
            A Tensor of shape `b`. For t-batch b, the q-HVKG value of the design
                `X_actual[b]` is averaged across the fantasy models, where
                `X_fantasies[i, b]` is chosen as the final set for the `i`-th
                fantasy model.
                NOTE: If `current_value` is not provided, then this is not the
                true HVKG value of `X_actual[b]`, and `X_fantasies[:, b]` must
                be maximized at fixed `X_actual[b]`.
        """
        X_actual, X_fantasies = _split_hvkg_fantasy_points(
            X=X, n_f=self.num_fantasies, num_pareto=self.num_pareto
        )

        fantasy_model = fantasize(
            model=self.model,
            X=X_actual,
            sampler=self.sampler,
        )
        # get the value function; pseudo points are passed through `project`
        value_function = _get_hv_value_function(
            model=fantasy_model,
            ref_point=self.ref_point,
            objective=self.objective,
            constraints=self.constraints,
            sampler=self.inner_sampler,
            project=self.project,
            use_posterior_mean=self.use_posterior_mean,
        )

        # make sure to propagate gradients to the fantasy model train inputs
        with settings.propagate_grads(True):
            # X_fantasies is already num_fantasies x batch_shape x num_pareto x d
            values = value_function(X=X_fantasies)  # num_fantasies x b
        if self.current_value is not None:
            values = values - self.current_value
        if self.cost_aware_utility is not None:
            values = self.cost_aware_utility(
                X=X_actual,
                deltas=values,
                sampler=self.cost_sampler,
            )
        # return average over the fantasy samples
        return values.mean(dim=0)


def _get_hv_value_function(
    model: CausalMultitaskGP | CausalMultitaskMultifidelityGP,
    ref_point: Tensor,
    objective: MCMultiOutputObjective | None = None,
    constraints: list[Callable[[Tensor], Tensor]] | None = None,
    sampler: MCSampler | None = None,
    project: Callable[[Tensor], Tensor] | None = None,
    use_posterior_mean: bool = False,
) -> AcquisitionFunction:
    r"""Build the inner (value) acquisition function used by HVKG.

    The value function is a `qExpectedHypervolumeImprovement` built on an empty
    partitioning (no incumbent Pareto front), so the improvement of a set is
    measured against `ref_point` alone.

    Args:
        model (CausalMultitaskGP | CausalMultitaskMultifidelityGP):
            The (fantasy) GP model.
        ref_point (Tensor): Reference point for hypervolume.
        objective (MCMultiOutputObjective | None): Objective.
        constraints (list[Callable[[Tensor], Tensor]] | None):
            Constraints.
        sampler (MCSampler | None): MC sampler. Ignored (replaced by a
            one-sample `StochasticSampler`) when `use_posterior_mean` is True.
        project (Callable[[Tensor], Tensor] | None): Optional projection;
            when given, the result is wrapped in a
            `ProjectedAcquisitionFunction`.
        use_posterior_mean (bool): Whether to wrap `model` in a
            `PosteriorMeanModel`.

    Returns:
        AcquisitionFunction: The value function.
    """
    if use_posterior_mean:
        # The model is replaced by its posterior mean, so a single dummy
        # draw suffices for the MC machinery of qEHVI.
        model = PosteriorMeanModel(model=model)
        sampler = StochasticSampler(sample_shape=torch.Size([1]))

    no_observations = torch.empty(
        (0, ref_point.shape[0]),
        dtype=ref_point.dtype,
        device=ref_point.device,
    )
    with warnings.catch_warnings():
        # Suppress qEHVI's known-numerics NumericsWarning during construction.
        warnings.filterwarnings(
            "ignore",
            message="qExpectedHypervolumeImprovement has known",
            category=NumericsWarning,
        )
        empty_partitioning = FastNondominatedPartitioning(
            ref_point=ref_point, Y=no_observations
        )
        value_fn = qExpectedHypervolumeImprovement(
            model=model,
            ref_point=ref_point,
            partitioning=empty_partitioning,
            sampler=sampler,
            objective=objective,
            constraints=constraints,
        )
    # ProjectedAcquisitionFunction requires this attribute to be set.
    value_fn.posterior_transform = None

    if project is not None:
        value_fn = ProjectedAcquisitionFunction(
            base_value_function=value_fn,
            project=project,
        )
    return value_fn


def _split_hvkg_fantasy_points(
    X: Tensor, n_f: int, num_pareto: int
) -> tuple[Tensor, Tensor]:
    r"""Split a one-shot HV-KG optimization input into actual and 
    fantasy points.

    Args:
        X (Tensor): A `batch_shape x (q + n_f*num_pareto) x d`-dim 
            tensor of actual and fantasy points.
        n_f (int): Number of fantasies.
        num_pareto (int): Number of Pareto optimal points.

    Returns:
        tuple[Tensor, Tensor]: 2-element tuple containing:
            - A `batch_shape x q x d`-dim tensor `X_actual` of input 
                candidates.
            - A `n_f x batch_shape x num_pareto x d`-dim tensor 
                `X_fantasies` of fantasy points, where 
                `X_fantasies[i, batch_idx]` is the i-th fantasy point 
                associated with the batch indexed by `batch_idx`.
    """
    if n_f * num_pareto > X.size(-2):
        raise ValueError(
            f"`n_f*num_pareto` ({n_f * num_pareto}) must be less than"
            f" the `q`-batch dimension of `X` ({X.size(-2)})."
        )
    split_sizes = [X.size(-2) - n_f * num_pareto, n_f * num_pareto]
    X_actual, X_fantasies = torch.split(X, split_sizes, dim=-2)
    # X_fantasies is b x n_f * num_pareto x d, needs to be n_f x b x num_pareto x d
    # reshape into num_fantasies x b x num_pareto x d
    new_shape = torch.Size(
        [n_f, *X_fantasies.shape[:-2], num_pareto, X_fantasies.shape[-1]]
    )
    X_fantasies = X_fantasies.reshape(new_shape)
    # n_f x b x num_pareto x d
    return X_actual, X_fantasies


