#!/usr/bin/env python3

r"""
Multi-objective multi-fidelity optimization benchmark problems.

References
----------

.. [bhatija2025]
    Bhatija, S., Zuercher, P. D., Thumm, J., & Bohné, T. (2025).
    Multi-Objective Causal Bayesian Optimization. arXiv preprint arXiv:2502.14755.

.. [aglietti2020]
    Aglietti, V., Lu, X., Paleyes, A., & González, J. (2020, June).
    Causal Bayesian optimization. In International Conference on Artificial
    Intelligence and Statistics (pp. 3155-3164). PMLR.
"""

from __future__ import annotations

from torch import Tensor
import torch
from botorch.test_functions.base import (
    MultiObjectiveTestProblem,
    ConstrainedBaseTestProblem
)

class MFHealth(MultiObjectiveTestProblem, ConstrainedBaseTestProblem):
    r"""Multi-fidelity version of the multi-objective causal health model.

    Adapted from [bhatija2025]_. A fidelity parameter ``S`` multiplies every
    causal-effect term below; at ``S = 0`` those terms vanish, so every input
    gives statin = cancer = 0.5 and PSA = 6.8.

    Inputs (3D):

    - ``x[0]``: BMI (Body Mass Index) in [20, 30]
    - ``x[1]``: Aspirin (daily aspirin regimen) in [0, 1]
    - ``x[2]``: S (fidelity parameter) in [0, 1]

    Non-manipulable variables:

    - age = 65 (fixed at mean; see ``fixed_age``)

    Objectives (to minimize):

    - f1: statin (medication) =
      σ(S * (-13 + 0.10*age + 0.20*BMI)),
      where σ(X) = 1/(1+e^(-X)) is the sigmoid function.
    - f2: PSA (prostate specific antigen) =
      6.8 + S * (0.04*age - 0.15*BMI - 0.6*statin + 0.55*aspirin + cancer)

    Constraint (feasible when the slack is non-negative):

    - cancer = σ(S * (2.2 - 0.05*age + 0.01*BMI - 0.04*statin + 0.02*aspirin))
    - slack = 0.35 - cancer, i.e. the cancer risk must stay below 0.35
      (see ``cancer_threshold``).
    """

    dim = 3  # BMI, Aspirin, fidelity S
    num_objectives = 2
    num_constraints = 1  # cancer-risk constraint

    design_var_names = [
        'BMI',      # Body Mass Index
        'Aspirin',  # daily aspirin regimen
        'S',        # fidelity parameter
    ]
    objective_var_names = [
        'Statin',  # statin medication
        'PSA',     # prostate specific antigen
    ]
    constraint_var_names = [
        'Cancer',  # constraint: cancer risk < cancer_threshold (0.35)
    ]
    fidelity_param_name = 'S'  # name of the fidelity input (x[2])

    continuous_inds = list(range(dim))
    discrete_inds = []
    categorical_inds = []

    _bounds = [
        (20.0, 30.0),  # BMI
        (0.0, 1.0),    # Aspirin
        (0.0, 1.0),    # Fidelity S
    ]
    _ref_point = [0.3775, 7.1964]  # Reference point for hypervolume
    _max_hv = 0.45637065528355614  # Approximate using NSGA-II, seed 334

    # Non-manipulable covariate: age is held fixed for every evaluation.
    fixed_age: float = 65.0
    # Upper bound on the cancer risk; slack = cancer_threshold - cancer.
    # At S = 1 every design in the box has a cancer risk of roughly 0.30-0.32,
    # so the constraint mainly excludes low-fidelity inputs (S = 0 gives 0.5).
    cancer_threshold: float = 0.35

    def __init__(
        self,
        noise_std: None | float | list[float] = None,
        negate: bool = False,
        dtype: torch.dtype = torch.double,
    ) -> None:
        r"""
        Args:
            noise_std: Standard deviation of the observation noise.
            negate: If True, negate the objectives.
            dtype: The dtype that is used for the bounds of the function.
        """
        super().__init__(noise_std=noise_std, negate=negate, dtype=dtype)

    def _sigmoid(self, x: Tensor) -> Tensor:
        """Return the logistic sigmoid σ(x) = 1/(1+e^(-x)).

        Delegates to ``torch.sigmoid``, which is numerically stable: the
        naive ``1 / (1 + exp(-x))`` overflows ``exp`` for large negative
        inputs and then yields NaN gradients.
        """
        return torch.sigmoid(x)

    def _age_like(self, X: Tensor) -> Tensor:
        """Return ``fixed_age`` as a 0-d tensor matching X's dtype and device."""
        return torch.tensor(self.fixed_age, device=X.device, dtype=X.dtype)

    def _statin(self, BMI: Tensor, age: Tensor, S: Tensor) -> Tensor:
        """Compute the statin objective σ(S * (-13 + 0.10*age + 0.20*BMI))."""
        statin_logit = S * (-13.0 + 0.10 * age + 0.20 * BMI)
        return self._sigmoid(statin_logit)

    def _cancer(
        self,
        BMI: Tensor,
        age: Tensor,
        aspirin: Tensor,
        S: Tensor,
        statin: Tensor | None = None,
    ) -> Tensor:
        """Compute the cancer risk used by the constraint.

        Args:
            BMI: Body Mass Index values.
            age: Age (broadcast against the other inputs).
            aspirin: Aspirin regimen values.
            S: Fidelity values scaling the logit.
            statin: Optional precomputed ``_statin(BMI, age, S)``. Passing it
                avoids recomputing the statin term; when omitted it is
                computed here.

        Returns:
            σ(S * (2.2 - 0.05*age + 0.01*BMI - 0.04*statin + 0.02*aspirin)).
        """
        if statin is None:
            statin = self._statin(BMI, age, S)
        cancer_logit = S * (
            2.2 - 0.05 * age + 0.01 * BMI - 0.04 * statin + 0.02 * aspirin
        )
        return self._sigmoid(cancer_logit)

    def _PSA(
        self,
        BMI: Tensor,
        age: Tensor,
        aspirin: Tensor,
        S: Tensor,
        statin: Tensor | None = None,
    ) -> Tensor:
        """Compute the PSA objective.

        PSA = 6.8 + S * (0.04*age - 0.15*BMI - 0.6*statin + 0.55*aspirin + cancer),
        i.e. a fixed baseline of 6.8 plus fidelity-scaled causal effects.

        Args:
            statin: Optional precomputed ``_statin(BMI, age, S)``; computed
                here when omitted. It is shared with the cancer term so it is
                evaluated only once.
        """
        if statin is None:
            statin = self._statin(BMI=BMI, age=age, S=S)
        cancer = self._cancer(
            BMI=BMI, age=age, aspirin=aspirin, S=S, statin=statin
        )
        return 6.8 + S * (
            0.04 * age - 0.15 * BMI - 0.6 * statin + 0.55 * aspirin + cancer
        )

    def _evaluate_true(self, X: Tensor) -> Tensor:
        """
        Evaluate the multi-fidelity causal health model.

        Args:
            X: (batch_shape) x 3 tensor of inputs [BMI, Aspirin, S]

        Returns:
            (batch_shape) x 2 tensor of objectives [Statin, PSA]
        """
        BMI, aspirin, S = X[..., 0], X[..., 1], X[..., 2]
        age = self._age_like(X)

        # Objective 1: Statin (computed once and reused for PSA/cancer).
        statin = self._statin(BMI, age, S)
        # Objective 2: PSA (baseline + fidelity-scaled effects).
        PSA = self._PSA(BMI=BMI, age=age, aspirin=aspirin, S=S, statin=statin)
        return torch.stack([statin, PSA], dim=-1)

    def _evaluate_slack_true(self, X: Tensor) -> Tensor:
        r"""Evaluate the cancer-risk constraint slack.

        Args:
            X: (batch_shape) x 3 tensor of inputs [BMI, Aspirin, S]

        Returns:
            (batch_shape) x 1 tensor ``cancer_threshold - cancer``; the input
            is feasible when this is non-negative (cancer risk below 0.35).
        """
        BMI, aspirin, S = X[..., 0], X[..., 1], X[..., 2]
        cancer = self._cancer(
            BMI=BMI, age=self._age_like(X), aspirin=aspirin, S=S
        )
        slack = self.cancer_threshold - cancer
        return slack.unsqueeze(-1)