#!/usr/bin/env python3

r"""
    Extension of hypervolume knowledge gradient to causal hypervolume knowledge gradient
    This is similar to: 
    https://github.com/pytorch/botorch/blob/main/botorch/acquisition/multi_objective/multi_fidelity.py

"""

from __future__ import annotations

from collections.abc import Callable
from typing import Any

import torch
from copy import deepcopy

import warnings

from botorch.acquisition.multi_objective.objective import MCMultiOutputObjective
from botorch.sampling.base import MCSampler
from botorch.utils.transforms import (
    average_over_ensemble_models,
    t_batch_mode_transform,
)
from torch import Tensor
from botorch.models.deterministic import PosteriorMeanModel
from botorch.sampling.stochastic_samplers import StochasticSampler
from botorch.exceptions.warnings import NumericsWarning
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
    FastNondominatedPartitioning,
)
from botorch.sampling.base import MCSampler
from botorch.acquisition.acquisition import (
    AcquisitionFunction,
    OneShotAcquisitionFunction,
)
from botorch.acquisition.knowledge_gradient import ProjectedAcquisitionFunction
from botorch.sampling.list_sampler import ListSampler
from botorch.acquisition.multi_objective.base import MultiObjectiveMCAcquisitionFunction
from botorch.acquisition.cost_aware import CostAwareUtility
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.exceptions.errors import UnsupportedError
from botorch import settings

from rescue.models.causal_gp.multitask import CausalMultitaskGP
from rescue.models.causal_gp.multitask_multifidelity import CausalMultitaskMultifidelityGP
from rescue.models.causal_gp.fantasy_model import fantasize
from rescue.acquisition.causal_ehvi import (
    CausalqExpectedHypervolumeImprovement
)


class qCausalHypervolumeKnowledgeGradient(
    MultiObjectiveMCAcquisitionFunction,
    OneShotAcquisitionFunction,
):
    r"""Batch Causal Hypervolume Knowledge Gradient.

    This is a causal variant of `qHypervolumeKnowledgeGradient` whose inner
    value function is a `CausalqExpectedHypervolumeImprovement` (constructed
    by `causal_hv_value_function`).

    The acquisition function is meant for one-shot optimization. Each t-batch
    of the input holds the `q` actual candidates, followed by
    `num_fantasies * num_pareto` pseudo-points. The pseudo-points parameterize
    the inner (per-fantasy) Pareto sets. Use `get_augmented_q_batch_size` to
    size the input and `extract_candidates` to recover the actual candidates.
    """

    def __init__(
        self,
        model: CausalMultitaskGP | CausalMultitaskMultifidelityGP,
        ref_point: Tensor,
        num_fantasies: int = 8,
        num_pareto: int = 10,
        sampler: ListSampler | MCSampler | None = None,
        objective: MCMultiOutputObjective | None = None,
        constraints: list[Callable[[Tensor], Tensor]] | None = None,
        inner_sampler: MCSampler | None = None,
        current_value: Tensor | None = None,
        use_posterior_mean: bool = True,
        cost_aware_utility: CostAwareUtility | None = None,
    ) -> None:
        r"""Construct the acquisition function.

        Args:
            model: The causal model used to fantasize observations at the
                actual candidates.
            ref_point: The hypervolume reference point. It is registered as a
                buffer and forwarded to the inner value function.
            num_fantasies: The number of fantasy samples drawn at the actual
                candidates.
            num_pareto: The number of pseudo-points per fantasy that form the
                inner Pareto set.
            sampler: The sampler used to draw fantasy observations. It may be
                a plain `MCSampler` (the same kind as the default) or a
                `ListSampler`. Its sample shape must be
                `torch.Size([num_fantasies])`. Defaults to a
                `SobolQMCNormalSampler`.
            objective: Optional multi-output objective for the inner value
                function.
            constraints: Optional outcome constraints for the inner value
                function.
            inner_sampler: The sampler for the inner value function. Defaults
                to a `SobolQMCNormalSampler` with 32 samples.
            current_value: Optional baseline subtracted from the fantasy
                values. Required when `cost_aware_utility` is given.
            use_posterior_mean: Whether the inner value function is evaluated
                on the posterior mean of the fantasy model.
            cost_aware_utility: Optional utility applied to the value
                improvements (the `deltas`) to account for evaluation cost.

        Raises:
            ValueError: If the sample shape of `sampler` is not
                `torch.Size([num_fantasies])`.
            UnsupportedError: If `cost_aware_utility` is given without
                `current_value`.
        """
        if sampler is None:
            sampler = SobolQMCNormalSampler(sample_shape=torch.Size([num_fantasies]))
        else:
            # A ListSampler holds one sampler per output and is checked via its
            # first member. Any other MCSampler (e.g. the same kind as the
            # default above) exposes `sample_shape` directly; previously such a
            # sampler crashed here with an AttributeError on `.samplers`.
            shape_source = (
                sampler.samplers[0] if isinstance(sampler, ListSampler) else sampler
            )
            if shape_source.sample_shape != torch.Size([num_fantasies]):
                raise ValueError(
                    f"The sampler shape must match num_fantasies={num_fantasies}."
                )
        super().__init__(model=model)

        if inner_sampler is None:
            inner_sampler = SobolQMCNormalSampler(sample_shape=torch.Size([32]))
        if current_value is None and cost_aware_utility is not None:
            raise UnsupportedError(
                "Cost-aware HVKG requires current_value to be specified."
            )
        self.register_buffer("ref_point", ref_point)
        self.sampler = sampler
        self.objective = objective
        self.constraints = constraints
        self.inner_sampler = inner_sampler
        self.num_fantasies = num_fantasies
        self.num_pareto = num_pareto
        # Total number of pseudo-points appended to the q-batch for one-shot
        # optimization.
        self.num_pseudo_points = num_fantasies * num_pareto
        self.current_value = current_value
        self.use_posterior_mean = use_posterior_mean
        self.cost_aware_utility = cost_aware_utility
        self._cost_sampler = None

    @property
    def cost_sampler(self) -> MCSampler:
        r"""Return a lazily created deep copy of `self.sampler` for the cost utility."""
        if self._cost_sampler is None:
            # Note: Using the deepcopy here is essential. Without it, a base
            # model and a cost model with a different number of outputs or test
            # points (this would be caused by expand) would trigger re-sampling
            # of the base samples in the fantasy sampler. With the cloned
            # sampler, the right thing happens if the sizes are compatible. If
            # they are not, the samples are drawn from different base samples,
            # but at least the state of the fantasy sampler is not changed.
            self._cost_sampler = deepcopy(self.sampler)
        return self._cost_sampler

    @t_batch_mode_transform()
    @average_over_ensemble_models
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate qCausalHypervolumeKnowledgeGradient on the candidate set `X`.

        Args:
            X: A `b x (q + num_fantasies * num_pareto) x d` Tensor with `b`
                t-batches of `q + num_fantasies * num_pareto` design points
                each. It is split in the `q` dimension (`dim=-2`). The first
                `q` points are the actual candidates. The remaining
                `num_fantasies * num_pareto` points are the current solutions
                of the inner (per-fantasy Pareto set) optimization problem.

                `X_actual = X[..., :-num_fantasies * num_pareto, :]`
                `X_actual.shape = b x q x d`

                `X_fantasies` is a reshape of
                `X[..., -num_fantasies * num_pareto:, :]` with
                `X_fantasies.shape = num_fantasies x b x num_pareto x d`.

        Returns:
            A Tensor of shape `b`. For t-batch `b`, the value of the design
                `X_actual[b]` is averaged across the fantasy models.
                NOTE: If `current_value` is not provided, then this is not the
                true HVKG value of `X_actual[b]`. The pseudo-points must be
                maximized at fixed `X_actual[b]`.
        """
        X_actual, X_fantasies = _split_hvkg_fantasy_points(
            X=X, n_f=self.num_fantasies, num_pareto=self.num_pareto
        )

        fantasy_model = fantasize(
            model=self.model,
            X=X_actual,
            sampler=self.sampler,
        )

        # The inner value function is evaluated on the fantasy model.
        value_function = causal_hv_value_function(
            model=fantasy_model,
            ref_point=self.ref_point,
            objective=self.objective,
            constraints=self.constraints,
            sampler=self.inner_sampler,
            use_posterior_mean=self.use_posterior_mean,
        )

        # make sure to propagate gradients to the fantasy model train inputs
        with settings.propagate_grads(True):
            # `_split_hvkg_fantasy_points` already returns
            # `num_fantasies x batch_shape x num_pareto x d`. The reshape only
            # re-asserts that layout.
            shape = torch.Size(
                [
                    self.num_fantasies,
                    *X_fantasies.shape[1:-2],
                    self.num_pareto,
                    X_fantasies.shape[-1],
                ]
            )
            values = value_function(X=X_fantasies.reshape(shape))  # num_fantasies x b

        if self.current_value is not None:
            values = values - self.current_value

        if self.cost_aware_utility is not None:
            values = self.cost_aware_utility(
                X=X_actual,
                deltas=values,
                sampler=self.cost_sampler,
            )

        # return average over the fantasy samples
        return values.mean(dim=0)

    def get_augmented_q_batch_size(self, q: int) -> int:
        r"""Get augmented q batch size for one-shot optimization.

        Args:
            q: The number of candidates to consider jointly.

        Returns:
            The augmented size for one-shot optimization. This is
            `q + num_fantasies * num_pareto`, i.e. it includes the
            pseudo-points parameterizing the fantasy Pareto sets.
        """
        return q + self.num_pseudo_points

    def extract_candidates(self, X_full: Tensor) -> Tensor:
        r"""Return only the actual candidates after one-shot optimization.

        Args:
            X_full: A `b x (q + num_fantasies * num_pareto) x d`-dim Tensor
                with `b` t-batches of `q + num_fantasies * num_pareto` design
                points each.

        Returns:
            A `b x q x d`-dim Tensor with `b` t-batches of `q` design points each.
        """
        return X_full[..., : -self.num_pseudo_points, :]


class qMultiFidelityCausalHypervolumeKnowledgeGradient(qCausalHypervolumeKnowledgeGradient):
    r"""Batch Multi-Fidelity Causal Hypervolume Knowledge Gradient.

    This is a variant of `qCausalHypervolumeKnowledgeGradient` that supports
    multi-fidelity optimization. The inner value function is wrapped with
    `project`, which maps the fantasy points to the target fidelities before
    they are evaluated.
    """

    def __init__(
        self,
        model: CausalMultitaskGP | CausalMultitaskMultifidelityGP,
        ref_point: Tensor,
        target_fidelities: dict[int, float],
        num_fantasies: int = 8,
        num_pareto: int = 10,
        sampler: MCSampler | None = None,
        objective: MCMultiOutputObjective | None = None,
        constraints: list[Callable[[Tensor], Tensor]] | None = None,
        inner_sampler: MCSampler | None = None,
        current_value: Tensor | None = None,
        cost_aware_utility: CostAwareUtility | None = None,
        project: Callable[[Tensor], Tensor] = lambda X: X,
        use_posterior_mean: bool = True,
        **kwargs: Any,
    ) -> None:
        r"""Construct the multi-fidelity acquisition function.

        Args:
            model, ref_point, num_fantasies, num_pareto, sampler, objective,
            constraints, inner_sampler, current_value, cost_aware_utility,
            use_posterior_mean: Forwarded unchanged to
                `qCausalHypervolumeKnowledgeGradient.__init__`.
            target_fidelities: Mapping of (presumably) fidelity-parameter
                indices to their target values. It is stored as
                `self.target_fidelities` but not read by this class. The
                projection itself is performed by `project`.
            project: Callable applied to the inner-problem inputs (via
                `ProjectedAcquisitionFunction`), e.g. to set the fidelity
                parameters to their target values. Defaults to the identity.
            kwargs: Only `expand` is inspected. Any other keyword arguments
                are ignored.

        Raises:
            NotImplementedError: If a non-None `expand` is passed (trace
                observations are not supported).
        """
        super().__init__(
            model=model,
            ref_point=ref_point,
            num_fantasies=num_fantasies,
            num_pareto=num_pareto,
            sampler=sampler,
            objective=objective,
            constraints=constraints,
            inner_sampler=inner_sampler,
            current_value=current_value,
            use_posterior_mean=use_posterior_mean,
            cost_aware_utility=cost_aware_utility,
        )
        self.project = project
        if kwargs.get("expand") is not None:
            raise NotImplementedError(
                "Trace observations are not currently supported "
                "by `qMultiFidelityCausalHypervolumeKnowledgeGradient`."
            )
        # Trace observations are unsupported, so `expand` is always the identity.
        self.expand = lambda X: X
        self.target_fidelities = target_fidelities

    @t_batch_mode_transform()
    @average_over_ensemble_models
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate qMultiFidelityCausalHypervolumeKnowledgeGradient on `X`.

        Args:
            X: A `b x (q + num_fantasies * num_pareto) x d` Tensor with `b`
                t-batches of `q + num_fantasies * num_pareto` design points
                each. It is split in the `q` dimension (`dim=-2`). The first
                `q` points are the actual candidates. The remaining
                `num_fantasies * num_pareto` points are the current solutions
                of the inner (per-fantasy Pareto set) optimization problem.

                `X_actual = X[..., :-num_fantasies * num_pareto, :]`
                `X_actual.shape = b x q x d`

                `X_fantasies` is a reshape of
                `X[..., -num_fantasies * num_pareto:, :]` with
                `X_fantasies.shape = num_fantasies x b x num_pareto x d`.

                In addition, `X` may be augmented with fidelity parameters as
                part of the `d`-dimension. Projecting fidelities to the target
                fidelity is handled by `project`.

        Returns:
            A Tensor of shape `b`. For t-batch `b`, the value of the design
                `X_actual[b]` is averaged across the fantasy models.
                NOTE: If `current_value` is not provided, then this is not the
                true HVKG value of `X_actual[b]`. The pseudo-points must be
                maximized at fixed `X_actual[b]`.
        """
        X_actual, X_fantasies = _split_hvkg_fantasy_points(
            X=X, n_f=self.num_fantasies, num_pareto=self.num_pareto
        )

        fantasy_model = fantasize(
            model=self.model,
            X=X_actual,
            sampler=self.sampler,
        )
        # The inner value function is evaluated on the fantasy model. Unlike
        # the single-fidelity class, its inputs pass through `project` first.
        value_function = causal_hv_value_function(
            model=fantasy_model,
            ref_point=self.ref_point,
            objective=self.objective,
            constraints=self.constraints,
            sampler=self.inner_sampler,
            project=self.project,
            use_posterior_mean=self.use_posterior_mean,
        )

        # make sure to propagate gradients to the fantasy model train inputs
        with settings.propagate_grads(True):
            # `_split_hvkg_fantasy_points` already returns
            # `num_fantasies x batch_shape x num_pareto x d`. The reshape only
            # re-asserts that layout.
            shape = torch.Size(
                [
                    self.num_fantasies,
                    *X_fantasies.shape[1:-2],
                    self.num_pareto,
                    X_fantasies.shape[-1],
                ]
            )
            values = value_function(X=X_fantasies.reshape(shape))  # num_fantasies x b
        if self.current_value is not None:
            values = values - self.current_value
        if self.cost_aware_utility is not None:
            values = self.cost_aware_utility(
                X=X_actual,
                deltas=values,
                sampler=self.cost_sampler,
            )
        # return average over the fantasy samples
        return values.mean(dim=0)

def causal_hv_value_function(
    model: CausalMultitaskGP | CausalMultitaskMultifidelityGP,
    ref_point: Tensor,
    causal_model: torch.nn.Module | None = None,
    causal_weight: float | None = None,
    objective: MCMultiOutputObjective | None = None,
    constraints: list[Callable[[Tensor], Tensor]] | None = None,
    sampler: MCSampler | None = None,
    project: Callable[[Tensor], Tensor] | None = None,
    use_posterior_mean: bool = False,
) -> AcquisitionFunction:
    r"""Construct the inner value function used by causal HVKG.

    The value function is a `CausalqExpectedHypervolumeImprovement` built on a
    `FastNondominatedPartitioning` of an empty outcome set. Presumably (if it
    follows qEHVI semantics) it therefore measures the hypervolume of the
    evaluated points w.r.t. `ref_point`.

    Args:
        model: The (typically fantasy) model on which the value function is
            evaluated.
        ref_point: The hypervolume reference point. Its length
            (`ref_point.shape[0]`) fixes the outcome dimension of the empty
            partitioning.
        causal_model: Optional causal model forwarded to
            `CausalqExpectedHypervolumeImprovement`.
        causal_weight: Optional weight forwarded to
            `CausalqExpectedHypervolumeImprovement`.
        objective: Optional multi-output objective.
        constraints: Optional outcome constraints.
        sampler: The MC sampler for the value function. It is ignored when
            `use_posterior_mean` is True.
        project: Optional input projection. When given, the value function is
            wrapped in a `ProjectedAcquisitionFunction`.
        use_posterior_mean: If True, wrap `model` in a `PosteriorMeanModel`.
            A single-sample `StochasticSampler` is then used as a placeholder
            sampler.

    Returns:
        The value function, possibly wrapped in `ProjectedAcquisitionFunction`.
    """
    if use_posterior_mean:
        eval_model = PosteriorMeanModel(model=model)
        eval_sampler = StochasticSampler(sample_shape=torch.Size([1]))  # dummy sampler
    else:
        eval_model, eval_sampler = model, sampler

    num_outcomes = ref_point.shape[0]
    with warnings.catch_warnings():
        # Ignore NumericsWarnings whose message starts with
        # "qExpectedHypervolumeImprovement has known" while constructing.
        warnings.filterwarnings(
            action="ignore",
            message="qExpectedHypervolumeImprovement has known",
            category=NumericsWarning,
        )
        no_observations = torch.empty(
            (0, num_outcomes),
            dtype=ref_point.dtype,
            device=ref_point.device,
        )
        empty_partitioning = FastNondominatedPartitioning(
            ref_point=ref_point,
            Y=no_observations,
        )
        value_function = CausalqExpectedHypervolumeImprovement(
            model=eval_model,
            causal_model=causal_model,
            ref_point=ref_point,
            partitioning=empty_partitioning,
            sampler=eval_sampler,
            objective=objective,
            constraints=constraints,
            causal_weight=causal_weight,
        )

    # ProjectedAcquisitionFunction requires this attribute to exist.
    value_function.posterior_transform = None

    if project is not None:
        return ProjectedAcquisitionFunction(
            base_value_function=value_function,
            project=project,
        )
    return value_function

def _split_hvkg_fantasy_points(
    X: Tensor, n_f: int, num_pareto: int
) -> tuple[Tensor, Tensor]:
    r"""Split a one-shot HV-KGoptimization input into actual and fantasy points

    Args:
        X: A `batch_shape x (q + n_f*num_pareto) x d`-dim tensor of actual
            and fantasy points

    Returns:
        2-element tuple containing

        - A `batch_shape x q x d`-dim tensor `X_actual` of input candidates.
        - A `n_f x batch_shape x num_pareto x d`-dim tensor `X_fantasies` of
            fantasy points, where `X_fantasies[i, batch_idx]` is the i-th
            fantasy point associated with the batch indexed by `batch_idx`.
    """
    if n_f * num_pareto > X.size(-2):
        raise ValueError(
            f"`n_f*num_pareto` ({n_f * num_pareto}) must be less than"
            f" the `q`-batch dimension of `X` ({X.size(-2)})."
        )
    split_sizes = [X.size(-2) - n_f * num_pareto, n_f * num_pareto]
    X_actual, X_fantasies = torch.split(X, split_sizes, dim=-2)
    # X_fantasies is b x n_f * num_pareto x d, needs to be n_f x b x num_pareto x d
    # reshape into num_fantasies x b x num_pareto x d
    new_shape = torch.Size(
        [n_f, *X_fantasies.shape[:-2], num_pareto, X_fantasies.shape[-1]]
    )
    X_fantasies = X_fantasies.reshape(new_shape)
    # n_f x b x num_pareto x d
    return X_actual, X_fantasies