
from __future__ import annotations

from typing import Callable

import torch
from torch import Tensor

from botorch.models.model import Model
from botorch.acquisition.multi_objective import qExpectedHypervolumeImprovement
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
    NondominatedPartitioning,
)
from botorch.sampling.base import MCSampler
from botorch.acquisition.multi_objective.objective import (
    MCMultiOutputObjective
)
from botorch.utils.transforms import (
    average_over_ensemble_models,
    concatenate_pending_points,
    t_batch_mode_transform,
)

class CausalqExpectedHypervolumeImprovement(qExpectedHypervolumeImprovement):
    r"""qEHVI with an optional additive term computed from a causal model.

    Without a causal model, `forward` returns the hypervolume improvement
    computed from the surrogate model's posterior samples. With a causal model,
    `forward` additionally scores the samples returned by the causal model with
    the same `_compute_qehvi` routine and adds that score, scaled by
    `causal_weight`, to the posterior-based value.
    """

    def __init__(
        self,
        model: Model,
        ref_point: Tensor,
        partitioning: NondominatedPartitioning,
        causal_model: torch.nn.Module | None = None,
        causal_weight: float | None = None,
        sampler: MCSampler | None = None,
        objective: MCMultiOutputObjective | None = None,
        constraints: list[Callable[[Tensor], Tensor]] | None = None,
        eta: Tensor | float = 1e-3,
        X_pending: Tensor | None = None,
    ) -> None:
        r"""Validate the causal arguments and initialize the parent qEHVI.

        Args:
            model: A fitted model, forwarded unchanged to the parent class.
            ref_point: A list or tensor representing the reference point (in the
                outcome space) w.r.t. which the hypervolume is computed. This is
                a reference point for the objective values (i.e. after applying
                `objective` to the samples).
            partitioning: A `NondominatedPartitioning` module that provides the
                non-dominated front and a partitioning of the non-dominated
                space in hyper-rectangles. If constraints are present, this
                partitioning must only include feasible points.
            causal_model: An optional `torch.nn.Module`. When given, `forward`
                calls it on `X` broadcast along a new leading sample dimension
                and expects a 2-tuple back; the first element is used as the
                samples for a second hypervolume-improvement computation and
                the second element is discarded.
            causal_weight: The factor the causal-model hypervolume improvement
                is multiplied by before being added to the posterior-based one.
                Required, and checked to lie in `[0, 1]`, whenever
                `causal_model` is given; it is stored but neither validated nor
                used when `causal_model` is `None`.
            sampler: The sampler used to draw base samples. If not given,
                a sampler is generated using `get_sampler`.
            objective: The MCMultiOutputObjective under which the samples are
                evaluated. Defaults to `IdentityMCMultiOutputObjective()`.
            constraints: A list of callables, each mapping a Tensor of dimension
                `sample_shape x batch-shape x q x m` to a Tensor of dimension
                `sample_shape x batch-shape x q`, where negative values imply
                feasibility. The acquisition function will compute expected
                feasible hypervolume.
            eta: The temperature parameter for the sigmoid function used for the
                differentiable approximation of the constraints. In case of a
                float the same eta is used for every constraint in constraints.
                In case of a tensor the length of the tensor must match the
                number of provided constraints. The i-th constraint is then
                estimated with the i-th eta value.
            X_pending: A `batch_shape x q' x d`-dim tensor of pending points
                that have been submitted for evaluation but whose results are
                not yet available.

        Raises:
            ValueError: If `causal_model` is given and `causal_weight` is either
                missing or outside `[0, 1]`.
        """
        # The causal arguments are checked before the parent is initialized, so
        # an invalid configuration fails without building any parent state.
        has_causal_model = causal_model is not None
        if has_causal_model and causal_weight is None:
            raise ValueError(
                "If `causal_model` is provided, `causal_weight` must be given."
            )
        # Written as a negated chained comparison so that NaN is rejected too.
        if has_causal_model and not (0.0 <= causal_weight <= 1.0):
            raise ValueError("`causal_weight` must be in the interval [0, 1].")

        super().__init__(
            model=model,
            ref_point=ref_point,
            partitioning=partitioning,
            sampler=sampler,
            objective=objective,
            constraints=constraints,
            eta=eta,
            X_pending=X_pending,
        )
        # Assigned only after the parent initializer has run.
        self.causal_weight = causal_weight
        self.causal_model = causal_model

    @concatenate_pending_points
    @t_batch_mode_transform()
    @average_over_ensemble_models
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the (optionally causally augmented) qEHVI on `X`.

        Args:
            X: The candidate points; presumably a `batch_shape x q x d`-dim
                Tensor after the t-batch transform (confirm against the BoTorch
                decorator documentation).

        Returns:
            The hypervolume improvement computed from the model's posterior
            samples. If a causal model is set, the hypervolume improvement of
            the causal model's samples multiplied by `causal_weight` is added
            to it.

            NOTE(review): the two terms are summed (`gp + w * causal`), not
            blended as a convex combination -- confirm this is intended.
        """
        gp_samples = self.get_posterior_samples(self.model.posterior(X))
        gp_hvi = self._compute_qehvi(samples=gp_samples, X=X)
        if self.causal_model is None:
            return gp_hvi

        # Give the causal model one copy of X per Monte Carlo sample. `expand`
        # prepends the sample dimension as a view, so no data is copied.
        n_mc_samples = gp_samples.shape[0]
        X_per_sample = X.expand(n_mc_samples, *X.shape)
        # The causal model is called directly and must return exactly two
        # values; only the first one (the samples) is used.
        causal_samples, _ = self.causal_model(X_per_sample)
        causal_hvi = self._compute_qehvi(samples=causal_samples, X=X)
        return gp_hvi + self.causal_weight * causal_hvi