#!/usr/bin/env python3

from __future__ import annotations

from typing import Callable, Any

import pandas as pd
import torch
from torch import Tensor
from wandb.sdk.wandb_run import Run

from botorch.test_functions.base import (
    ConstrainedBaseTestProblem,
    MultiObjectiveTestProblem,
)
from botorch.models.model import Model
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch import fit_gpytorch_mll
from botorch.optim import optimize_acqf, optimize_acqf_mixed
from botorch.utils.transforms import unnormalize, normalize
from botorch.acquisition.multi_objective.objective import (
    IdentityMCMultiOutputObjective
)
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
    FastNondominatedPartitioning,
)
from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
from botorch.models.deterministic import GenericDeterministicModel
from botorch.acquisition.cost_aware import InverseCostWeightedUtility
from botorch.acquisition.multi_objective.objective import MCMultiOutputObjective

from gpytorch.kernels import Kernel

from rescue.algorithms.base import (
    BaseRescueAlgorithmMultifidelity, 
    CausalGPFn, 
    GenInitDataFn
)
from rescue.acquisition.causal_knowledge_gradient import (
    qMultiFidelityCausalHypervolumeKnowledgeGradient,
    causal_hv_value_function
)
from rescue.acquisition.causal_momf import CausalMOMF
from rescue.models import fit_causal_gp
from rescue.acquisition.optim import gen_one_shot_hvkg_initial_conditions
from rescue.metrics.optimization_pref import (
    MultiObjectiveOptimizationPref,
    nsga2_posterior_pareto
)
from rescue.metrics.utils import term_print
from rescue.algorithms.state import RescueStateMultifidelity
from rescue.utils.utils import status


class RescueAlgorithmMultifidelity(BaseRescueAlgorithmMultifidelity):
    def __init__(
        self,
        problem: MultiObjectiveTestProblem | ConstrainedBaseTestProblem,
        design_variables: list[str],
        objective_variables: list[str],
        bounds: Tensor,
        is_discrete_fidelities: bool,
        target_fidelities: dict[int, float],
        fidelity_param_name: str,
        cost_fn: Callable[[Tensor], Tensor],
        gen_initial_data: GenInitDataFn | None = None,
        custom_model: None | CausalGPFn = None,
        verbose: bool = False,
        status_spinner: bool = True,
        fidelity_levels: None | Tensor = None,
        kpi_variables: None | list[str] = None,
        constraint_variables: None | list[str] = None,
        device: None | torch.device = None,
        dtype: None | torch.dtype = None,
        wandb_run: None | Run = None,
        rescue_state: None | RescueStateMultifidelity = None,
    ) -> None:
        r"""
        Set up the multi-fidelity rescue algorithm.

        All arguments are forwarded unchanged, by keyword, to the base-class
        constructor; `status_spinner` is additionally stored on the instance.

        Assumptions:
            - The fidelity parameter is one of the design variables.
            - The fidelity parameter occupies the last column of the design
              variable tensor.

        Args:
            problem (MultiObjectiveTestProblem | ConstrainedBaseTestProblem): The
                optimization problem to solve.
            design_variables (list[str]): Names of the design variables to optimize.
            objective_variables (list[str]): Names of the objective variables.
            bounds (Tensor): Bounds of the design variables.
            is_discrete_fidelities (bool): Whether the fidelities are discrete.
            target_fidelities (dict[int, float]): Target fidelity value for each
                fidelity dimension.
            fidelity_param_name (str): Name of the fidelity parameter.
            cost_fn (Callable[[Tensor], Tensor]): Cost function for evaluations.
            gen_initial_data (GenInitDataFn | None): Function generating the
                initial data.
            custom_model (None | CausalGPFn): Custom model function.
            verbose (bool): Whether to print verbose output.
            status_spinner (bool): Whether to show a status spinner.
            fidelity_levels (None | Tensor): Fidelity levels to use.
            kpi_variables (None | list[str]): KPI variables to track.
            constraint_variables (None | list[str]): Constraint variables to track.
            device (None | torch.device): Device to use.
            dtype (None | torch.dtype): Dtype to use.
            wandb_run (None | Run): wandb run to log to.
            rescue_state (None | RescueStateMultifidelity): Rescue state to use.
        """
        # Gather everything the base class expects, then forward it in one go.
        base_kwargs = dict(
            problem=problem,
            design_variables=design_variables,
            objective_variables=objective_variables,
            bounds=bounds,
            is_discrete_fidelities=is_discrete_fidelities,
            target_fidelities=target_fidelities,
            fidelity_param_name=fidelity_param_name,
            cost_fn=cost_fn,
            gen_initial_data=gen_initial_data,
            custom_model=custom_model,
            verbose=verbose,
            status_spinner=status_spinner,
            fidelity_levels=fidelity_levels,
            kpi_variables=kpi_variables,
            constraint_variables=constraint_variables,
            device=device,
            dtype=dtype,
            wandb_run=wandb_run,
            rescue_state=rescue_state,
        )
        super().__init__(**base_kwargs)

        # *** Debugging only ***
        # Stored on the instance for the `status` decorator.
        self.status_spinner = status_spinner
        # *********************

    @status(show_func_name=True)
    def _generate_initial_data(
        self,
        n_full_fidelity_equiv: float | int | None = None,
    ) -> tuple[Tensor, Tensor, Tensor | None]:
        r"""
        Produce the initial training set by delegating to
        `self.generate_initial_data_mf`.

        Args:
            n_full_fidelity_equiv: Number of initial samples (or budget for
                multi-fidelity); forwarded as-is to `generate_initial_data_mf`.

        Returns:
            tuple[Tensor, Tensor, Tensor | None]:
                - train_x: Input samples with dim `n x d`, where the last
                    column of `d` holds the fidelity levels.
                - train_obj: Objective values with dim `n x m`.
                - train_constraints: Constraint values with dim `n x c`
                    (None if no constraints).
        """
        initial_data = self.generate_initial_data_mf(
            n_full_fidelity_equiv=n_full_fidelity_equiv,
        )
        return initial_data
    
    @status(show_func_name=True) 
    def update_causal_gp(self, mll, use_rescue_fit=False) -> None:
        r"""Fit the causal GP model through its marginal likelihood.

        Args:
            mll: The marginal likelihood to optimize.
            use_rescue_fit (bool): When True the fit is done with
                `fit_causal_gp` (passing `self.verbose`); otherwise BoTorch's
                `fit_gpytorch_mll` is used. Defaults to False.
        """
        # Default path first: the stock BoTorch fitting routine.
        if not use_rescue_fit:
            fit_gpytorch_mll(mll)
            return
        fit_causal_gp(mll, verbose=self.verbose)
        
    def get_current_value(
        self,
        model: Model,
        causal_model: torch.nn.Module,
        causal_weight: float,
        ref_point: torch.Tensor,
        raw_samples: int,
        num_restarts: int,
        q: int,
        mc_samples: int | None = None,
        use_posterior_mean: bool = True,
        objective: MCMultiOutputObjective | None = None,
        constraints: list[Callable[[Tensor], Tensor]] | None = None,
        options: dict[str, bool | float | int | str] | None = None,
    ) -> Tensor:
        r"""Helper to get the hypervolume of the current hypervolume
        maximizing set.

        Builds `causal_hv_value_function`, pins the fidelity columns to
        `self.normalized_target_fidelities` through a
        `FixedFeatureAcquisitionFunction`, and maximizes the result over the
        remaining (non-fidelity) dimensions with `optimize_acqf`.

        Args:
            model (Model): The surrogate model.
            causal_model (torch.nn.Module): The causal model.
            causal_weight (float): The weight for the causal component.
            ref_point (Tensor): The reference point.
            raw_samples (int): Number of raw samples passed to `optimize_acqf`.
            num_restarts (int): Number of restarts passed to `optimize_acqf`.
            q (int): Batch size passed to `optimize_acqf`.
            mc_samples (int | None): Size of the `SobolQMCNormalSampler`. A
                sampler is only created when `use_posterior_mean` is False and
                this is not None; otherwise no sampler is passed.
            use_posterior_mean (bool): Forwarded to the value function.
                Defaults to True.
            objective (MCMultiOutputObjective | None): Forwarded to the value
                function.
            constraints (list[Callable[[Tensor], Tensor]] | None): Forwarded to
                the value function.
            options (dict[str, bool | float | int | str] | None): Options for
                `optimize_acqf`. `"nonnegative"` is always forced to True. The
                dict passed in is not modified.

        Returns:
            Tensor: The acquisition value of the best solution found by
            `optimize_acqf`.
        """
        # Copy rather than write into the caller's dict: the same options
        # object may be reused by the caller across iterations.
        optim_options = {**(options or {}), "nonnegative": True}

        dim_x = self.problem.dim
        fidelity_dims, fidelity_targets = zip(
            *self.normalized_target_fidelities.items()
        )

        sampler = None
        if not use_posterior_mean and mc_samples is not None:
            sampler = SobolQMCNormalSampler(sample_shape=torch.Size([mc_samples]))

        # Explicitly ascending order: the bound columns selected below must
        # line up with the free (non-fixed) columns of the fixed-feature
        # wrapper, and set iteration order is not guaranteed to be sorted.
        fixed_dims = set(fidelity_dims)
        non_fidelity_dims = [i for i in range(dim_x) if i not in fixed_dims]

        curr_val_acqf = FixedFeatureAcquisitionFunction(
            acq_function=causal_hv_value_function(
                model=model,
                causal_model=causal_model,
                causal_weight=causal_weight,
                ref_point=ref_point,
                sampler=sampler,
                use_posterior_mean=use_posterior_mean,
                objective=objective,
                constraints=constraints,
            ),
            d=dim_x,
            columns=fidelity_dims,
            values=fidelity_targets,
        )
        # Optimize over the non-fidelity dimensions only; the candidate itself
        # is not needed, just the value it attains.
        _, current_value = optimize_acqf(
            acq_function=curr_val_acqf,
            bounds=self.standard_bounds[:, non_fidelity_dims],
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            return_best_only=True,
            options=optim_options,
        )
        return current_value

    @status(show_func_name=True) 
    def optimize_causal_hvkg(
        self,
        q: int,
        model: Model,
        causal_model: torch.nn.Module,
        causal_weight: float,
        ref_point: Tensor,
        curr_val_raw_samples: int,
        curr_val_num_restarts: int,
        num_pareto: int,
        num_fantasies: int,
        num_restarts: int,
        raw_samples: int,
        use_posterior_mean: bool = True,
        curr_val_mc_samples: int | None = None,
        inner_mc_samples: int | None = None,
        objective_indices: None | list[int] = None,
        fixed_features_list: None | list[dict[int, float]] = None,
        options: dict[str, bool | float | int | str] | None = None,
        optim_valuef_options: dict[str, bool | float | int | str] | None = None
    ) -> tuple[Tensor, Tensor]:
        r"""
        Optimize the acquisition function to get next candidate and fidelity.
        NOTE: Assumes last `c` dim of `train_y` corresponds to constraints.

        Steps:
            1. Compute the current value with `self.get_current_value`.
            2. Build the cost-aware
               `qMultiFidelityCausalHypervolumeKnowledgeGradient`.
            3. Optimize it with `optimize_acqf_mixed` (discrete fidelities)
               or `optimize_acqf` (continuous fidelities).
            4. Map the candidate back to the original design space.

        Args:
            q (int): The batch size for acquisition. The zero-value check
                below calls `.item()` on the returned acquisition value, so it
                must hold a single element.
            model (Model): The surrogate model to optimize.
            causal_model (torch.nn.Module): The causal model (used for the
                current-value computation).
            causal_weight (float): The weight for the causal component (used
                for the current-value computation).
            ref_point (Tensor): The reference point for the optimization.
            curr_val_raw_samples (int): The number of raw samples for
                current value optimization.
            curr_val_num_restarts (int): The number of restarts for
                current value optimization.
            num_pareto (int): The number of Pareto points to consider.
            num_fantasies (int): The number of fantasy points for the
                knowledge gradient.
            num_restarts (int): The number of restarts for the optimization.
            raw_samples (int): The number of raw samples to generate.
            use_posterior_mean (bool): Whether to use the posterior mean.
                Defaults to True.
            curr_val_mc_samples (int | None): The number of MC samples for
                current value. Defaults to None.
            inner_mc_samples (int | None): The number of inner MC samples.
                An inner sampler is only created when `use_posterior_mean` is
                False and this is not None. Defaults to None.
            objective_indices (None | list[int]): The indices of the
                objectives to optimize. Required when the problem has
                constraints. Defaults to None.
            fixed_features_list (None | list[dict[int, float]]): The fixed
                features for discrete fidelity. Defaults to None.
            options (dict[str, bool | float | int | str] | None): Additional
                optimization options. Defaults to None.
            optim_valuef_options (dict[str, bool | float | int | str] | None):
                Options for value function optimization. Defaults to None.

        Returns:
            Tuple (Tensor, Tensor): (new_x, acq_value). `new_x` is unnormalized
            with `self.bounds`; with discrete fidelities only the non-fidelity
            columns are unnormalized and the last (fidelity) column is kept
            as produced by the optimizer.

        Raises:
            TypeError: If the problem has constraints and `objective_indices`
                is None.
        """
        if self.has_constraints and objective_indices is None:
            # The constraint outputs are addressed relative to the number of
            # objectives, so that count must be known up front.
            raise TypeError(
                "`objective_indices` must be provided when the problem has "
                "constraints: constraint outputs are located after the "
                "objective outputs of the model."
            )

        inner_sampler = None
        if not use_posterior_mean and inner_mc_samples is not None:
            inner_sampler = SobolQMCNormalSampler(
                sample_shape=torch.Size([inner_mc_samples])
            )

        constraints = None
        if self.has_constraints:
            # Constraints are in the model after all objectives
            # Model structure: [obj0, obj1, ..., objN, cons0, cons1, ...]
            num_objectives = len(objective_indices)
            # `idx` is bound as a default argument so every lambda keeps its
            # own output index instead of sharing the loop variable.
            constraints = [
                lambda Z, idx=num_objectives + j: Z[..., idx]
                for j in range(self.num_constraints)
            ]

        objective = IdentityMCMultiOutputObjective(outcomes=objective_indices)
        cost_model = GenericDeterministicModel(self.cost_fn)
        cost_aware_utility = InverseCostWeightedUtility(cost_model=cost_model)
        current_value = self.get_current_value(
            model=model,
            causal_model=causal_model,
            causal_weight=causal_weight,
            ref_point=ref_point,
            objective=objective,
            constraints=constraints,
            raw_samples=curr_val_raw_samples,
            mc_samples=curr_val_mc_samples,
            q=num_pareto,
            use_posterior_mean=use_posterior_mean,
            num_restarts=curr_val_num_restarts,
            options=optim_valuef_options,
        )
        acq_func = qMultiFidelityCausalHypervolumeKnowledgeGradient(
            model=model,
            ref_point=ref_point,  # use known reference point
            num_fantasies=num_fantasies,
            num_pareto=num_pareto,
            inner_sampler=inner_sampler,
            current_value=current_value,
            use_posterior_mean=use_posterior_mean,
            cost_aware_utility=cost_aware_utility,
            target_fidelities=self.normalized_target_fidelities,
            project=self._project_to_target_fidelity,
            objective=objective,
            constraints=constraints,
        )
        # Optimization
        if self.is_discrete_fidelities:
            candidates, vals = optimize_acqf_mixed(
                acq_function=acq_func,
                bounds=self.standard_bounds,
                q=q,
                fixed_features_list=fixed_features_list,
                num_restarts=num_restarts,
                raw_samples=raw_samples,  # used for initialization heuristic
                ic_generator=gen_one_shot_hvkg_initial_conditions,
                options=options,
            )
        else:
            candidates, vals = optimize_acqf(
                acq_function=acq_func,
                bounds=self.standard_bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,  # used for initialization heuristic
                ic_generator=gen_one_shot_hvkg_initial_conditions,
                options=options,
                sequential=True,
            )
        # If the acquisition value is exactly zero, fall back to the lowest
        # fidelity: the minimum discrete level, or 0.0 in normalized units
        # (which `unnormalize` maps to the lower fidelity bound).
        if vals.item() == 0.0:
            if self.is_discrete_fidelities:
                min_fidelity = self.fidelity_levels.min().item()
                candidates[:, -1] = min_fidelity
            else:
                candidates[:, -1] = 0.0
        # Map the candidate back to the original design space.
        if self.is_discrete_fidelities:
            # Only the non-fidelity columns are unnormalized; the fidelity
            # column is left untouched.
            new_x = candidates.detach()
            new_x[:, :-1] = unnormalize(
                candidates[:, :-1].detach(), bounds=self.bounds[:, :-1]
            )
        else:
            new_x = unnormalize(candidates.detach(), bounds=self.bounds)
        return new_x, vals

    @status(show_func_name=True)
    def optimize_CausalMOMF(
        self,
        model: Model,
        causal_model: torch.nn.Module,
        causal_weight: float,
        train_obj: torch.Tensor,
        sampler: SobolQMCNormalSampler,
        ref_point: torch.Tensor,
        q: int,
        num_restarts: int,
        raw_samples: int,
        objective_indices: None | list[int] = None,
        constraints_indices: None | list = None,
        fixed_features_list: None | list[dict[int, float]] = None,
        options: dict[str, bool | float | int | str] | None = None
    ) -> tuple[Tensor, Tensor]:
        r"""
        Wrapper to call MOMF and optimizes it in a sequential greedy
        fashion returning a new candidate and evaluation.

        Assumptions:
            - last `c` dim corresponds to constraints

        Args:
            model (Model): The surrogate model.
            causal_model (torch.nn.Module): The causal model.
            causal_weight (float): The weight for the causal component.
            train_obj (Tensor): Objective values used to build the
                non-dominated partitioning.
            sampler (SobolQMCNormalSampler): Sampler passed to `CausalMOMF`.
            ref_point (Tensor): The reference point.
            q (int): The batch size for acquisition. The zero-value check
                below calls `.item()` on the returned acquisition value, so it
                must hold a single element.
            num_restarts (int): The number of restarts for the optimization.
            raw_samples (int): The number of raw samples to generate.
            objective_indices (None | list[int]): The indices of the
                objectives. Required when the problem has constraints.
            constraints_indices (None | list): The indices of the constraints;
                only its length is used. When None and the problem has
                constraints, `self.num_constraints` is used instead.
            fixed_features_list (None | list[dict[int, float]]): The fixed
                features for discrete fidelity.
            options (dict[str, bool | float | int | str] | None): Options for
                the optimizer. `"nonnegative"` is always forced to True. The
                dict passed in is not modified.

        Returns:
            Tuple (Tensor, Tensor): (new_x, acq_value). `new_x` is unnormalized
            with `self.bounds`; with discrete fidelities only the non-fidelity
            columns are unnormalized and the last (fidelity) column is kept
            as produced by the optimizer.

        Raises:
            TypeError: If the problem has constraints and `objective_indices`
                is None.
        """
        # Copy rather than write into the caller's dict: the same options
        # object is typically reused on every BO iteration.
        optim_options = {**(options or {}), "nonnegative": True}

        constraints = None
        if self.has_constraints:
            if objective_indices is None:
                # Constraint outputs are addressed relative to the number of
                # objectives, so that count must be known.
                raise TypeError(
                    "`objective_indices` must be provided when the problem has "
                    "constraints: constraint outputs are located after the "
                    "objective outputs of the model."
                )
            # Constraints are in the model after all objectives
            # Model structure: [obj0, obj1, ..., objN, cons0, cons1, ...]
            num_objectives = len(objective_indices)
            num_constraints = (
                self.num_constraints
                if constraints_indices is None
                else len(constraints_indices)
            )
            # `idx` is bound as a default argument so every lambda keeps its
            # own output index instead of sharing the loop variable.
            constraints = [
                lambda Z, idx=num_objectives + j: Z[..., idx]
                for j in range(num_constraints)
            ]

        partitioning = FastNondominatedPartitioning(
            ref_point=ref_point, Y=train_obj
        )
        acq_func = CausalMOMF(
            model=model,
            causal_model=causal_model,
            causal_weight=causal_weight,
            ref_point=ref_point,  # use known reference point
            partitioning=partitioning,
            sampler=sampler,
            cost_call=self.cost_fn,
            objective=IdentityMCMultiOutputObjective(
                outcomes=objective_indices
            ),
            constraints=constraints,
        )
        # Optimization
        if self.is_discrete_fidelities:
            candidates, vals = optimize_acqf_mixed(
                acq_function=acq_func,
                bounds=self.standard_bounds,
                q=q,
                fixed_features_list=fixed_features_list,
                num_restarts=num_restarts,
                raw_samples=raw_samples,  # used for initialization heuristic
                options=optim_options,
            )
        else:
            candidates, vals = optimize_acqf(
                acq_function=acq_func,
                bounds=self.standard_bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,  # used for initialization heuristic
                sequential=True,
                options=optim_options,
            )
        # If the acquisition value is exactly zero, fall back to the lowest
        # fidelity: the minimum discrete level, or 0.0 in normalized units
        # (which `unnormalize` maps to the lower fidelity bound).
        if vals.item() == 0.0:
            if self.is_discrete_fidelities:
                min_fidelity = self.fidelity_levels.min().item()
                candidates[:, -1] = min_fidelity
            else:
                candidates[:, -1] = 0.0
        # Map the candidate back to the original design space.
        if self.is_discrete_fidelities:
            # Only the non-fidelity columns are unnormalized; the fidelity
            # column is left untouched.
            new_x = candidates.detach()
            new_x[:, :-1] = unnormalize(
                candidates[:, :-1].detach(), bounds=self.bounds[:, :-1]
            )
        else:
            new_x = unnormalize(candidates.detach(), bounds=self.bounds)
        return new_x, vals

    def run(
        self,
        budget: int | float,
        ref_point: Tensor,
        causal_observational_data: pd.DataFrame,
        objective_indices: list[int],
        use_rescue_fit: bool = False,
        init_budget: int | float | None = None,
        include_initcost_to_budget: bool = True,
        causal_intervention_samples: int = 1000,
        causal_discovery: str = 'PC',
        causal_num_interventions: int = 200,
        causal_net_epochs: int = 100,
        causal_inf_backend: str = "loky",
        causal_inf_batch_size: int = 10,
        causal_inf_n_jobs: int = 1,  
        causal_discovery_alpha: None | float = None,    
        causal_is_design_var_independent: bool = True,  
        cgp_data_covar_module: None | Kernel = None,
        cgp_fidelity_covar_module: None | Kernel = None,
        fit_causal_model_after_each: None | int = None,
        causal_save_graph: bool = False,  
        acqfn_q: int = 1,  
        acqfn_use_causalmomf: bool = False,
        acqfn_causalmomf_mc_samples: int = 128,
        acqfn_num_restarts: int = 10,
        acqfn_raw_samples: int = 512,
        acqfn_num_pareto: int = 10,
        acqfn_num_fantasies: int = 8,
        acqfn_curr_val_causal_weight: float = 1.0,
        acqfn_curr_val_mc_samples: int = 32,
        acqfn_curr_val_raw_samples: int = 2 * 512,
        acqfn_curr_val_num_restarts: int = 1,
        acqfn_inner_mc_samples: int = 32,
        acqfn_use_posterior_mean: bool = True,
        optim_acqfn_options: dict[str, bool | float | int | str] | None = None,
        optim_inner_acqfn_options: dict[str, bool | float | int | str] | None = None,
        max_hv: float | Tensor | None = None,
        constraints_indices: None | list[int] = None,             
        seed: int | None = None,
        compute_metrics: bool = True,
        show_status: bool = True        
    ) -> dict[str, Any]:
        r""" 
        NOTE: Assumes last dim of `train_y` corresponds to constraints.

        Args:
            budget (int | float): The budget for the optimization run.
            ref_point (Tensor): The reference point for the optimization.
            causal_observational_data (pd.DataFrame): Observational data for causal discovery.
            objective_indices (None | list[int]): The indices of the objective variables.
            use_rescue_fit (bool): Whether to use the rescue fit method. Defaults to False.
            init_budget (int | float | None): Budget in terms of number of 
                full-fidelity evaluations.  Defined as `n x cost(target_fidelity)`.
            include_initcost_to_budget (bool): Whether to include the initial cost in the budget.
            causal_intervention_samples (int): The number of causal intervention samples to generate.
            causal_discovery (str): The causal discovery method to use.
            causal_num_interventions (int): The number of interventions to perform.
            causal_net_epochs (int): The number of epochs to train the causal network.
            causal_inf_backend (str): The backend for causal inference.
            causal_inf_batch_size (int): The batch size for causal inference.
            causal_inf_n_jobs (int): The number of jobs for causal inference.
            causal_discovery_alpha (None | float): The significance level for causal discovery.
            causal_is_design_var_independent (bool): Whether design variables are independent.
            cgp_data_covar_module (None | Kernel): The data covariance module for the causal GP.
            cgp_fidelity_covar_module (None | Kernel): The fidelity covariance module for the causal GP.
            fit_causal_model_after_each (None | int): The frequency of fitting the causal model.
            causal_save_graph (bool): Whether to save the causal graph.
            acqfn_q (int): The batch size for acquisition.
            acqfn_use_causalmomf (bool): Whether to use CausalMOMF acquisition function.
            acqfn_causalmomf_mc_samples (int): The number of MC samples for CausalMOMF.
            acqfn_num_restarts (int): The number of restarts for the acquisition function.
            acqfn_raw_samples (int): The number of raw samples for the acquisition function.
            acqfn_num_pareto (int): The number of Pareto points.
            acqfn_num_fantasies (int): The number of fantasies for knowledge gradient.
            acqfn_curr_val_causal_weight (float): The causal weight for current value computation.
            acqfn_curr_val_mc_samples (int): The number of MC samples for current value.
            acqfn_curr_val_raw_samples (int): The number of raw samples for current value.
            acqfn_curr_val_num_restarts (int): The number of restarts for current value.
            acqfn_inner_mc_samples (int): The number of inner MC samples.
            acqfn_use_posterior_mean (bool): Whether to use posterior mean.
            optim_acqfn_options (dict[str, bool | float | int | str] | None): Options for 
                acquisition optimization.
            optim_inner_acqfn_options (dict[str, bool | float | int | str] | None): Options for 
                inner acquisition optimization.
            max_hv (float | Tensor | None): The maximum hypervolume to achieve.
                If provided, the optimization method will be evaluated at runtime.
                Which can be extremely costly as evaluation is done at the
                target fidelity. NOTE: Should be only used for research purposes.
            constraints_indices (None | list[int]): The indices of the constraint variables.
            seed (int | None): The `torch.manual_seed` seed for reproducibility.
            compute_metrics (bool): Whether to evaluate the optimization preferences at runtime.
            show_status (bool): Whether to show status updates.

        Returns:
            dict[str, Any]: A dictionary containing the results of the optimization run.
                - "pareto_X": The Pareto points.
                - "pareto_Y": The objective values of the Pareto points.
                - "constraints": The constraint values of the Pareto points.
                - "pymoo_res": The pymoo Result object.
                - "cgp_model": The causal GP model.
                - "cpm": The causal performance model.

        """
        if causal_discovery == "FCI":
            raise NotImplementedError(
                "FCI produces Partial Ancestral Graphs (PAGs) "
                "which is resolved using information theoretic approach to "
                "result in a Acyclic Directed Mixed Graph (ADMG)."
                "However, `dowhy` does not support ADMGs"
            )
        
        self.validate_run_inputs(compute_metrics, objective_indices, constraints_indices)
        has_maxhv = max_hv is not None      
        # === Helper Functions for run ===
        def _get_total_cost(train_x: Tensor) -> float:
            return self.cost_fn(train_x).sum().item() 
        
        def _get_new_cost(new_x: Tensor) -> float:
            return self.cost_fn(new_x).item()
        # === End Helper Functions for run ===

        # +===========================================+
        # |           RESCUE Initializations          |
        # +===========================================+
        if self.rescue_state is not None:
            (
                iteration,
                get_seed,    
                initial_cost,
                current_cost,  
                ref_point,
                max_hv,           
                met_best_nsga2_regret,
                met_best_nsga2_hv,            
                x_intervention,
                train_x,
                train_obj,
                train_constraints,
                objective_indices,
                constraints_indices,
                fixed_features_list,
                target_train_x,
                target_fid_obj,
                target_fid_constraints,
            )   = self.load_state_mf()          
            self.validated_budget(budget, current_cost)
            
            (causal_graph, causal_net, 
                causal_net_loss, scm) = self.load_causal_model()         

            mll, model = self.update_causal_GP(
                is_multifidelity=True,
                state_dict=self.state['model'],
                train_x=normalize(train_x, self.bounds),
                train_objectives=train_obj,
                train_constraints=train_constraints,
                causal_net=causal_net,
                data_covar_module=cgp_data_covar_module,
                fidelity_covar_module=cgp_fidelity_covar_module,
            )
            if compute_metrics:
                optimization_pref = MultiObjectiveOptimizationPref(
                    ref_point=ref_point,
                    max_hv=max_hv,
                    **self.tkwargs
                )              
        else:
            target_train_x = None
            target_fid_obj = None
            target_fid_constraints = None
            fixed_features_list = None
            met_curr_nsga2_regret = None
            met_curr_nsga2_hv = None
            met_curr_nsga2_violation = None
            met_observed_regret = None
            met_observed_hv = None
            met_observed_violation = None
            met_best_nsga2_regret = None
            met_best_nsga2_hv = None  
            iteration = 0 
            current_cost = 0  

            if seed is not None:
                torch.manual_seed(seed)
            get_seed = torch.initial_seed()

            # Generate initial data
            train_x, train_obj, train_constraints = self._generate_initial_data(init_budget) 
            
            # Track optimization
            initial_cost = _get_total_cost(train_x)
            print("Initial cumulative cost:", initial_cost) if show_status else None
            if include_initcost_to_budget:
                current_cost = initial_cost
            self.validated_budget(budget, current_cost)

            # Generate fixed features for each fidelity level
            # when is_discrete_fidelities is True
            if self.is_discrete_fidelities:
                fixed_features_list = self.fidelity_to_fixed_features_list()

            # Generate intervention values
            x_intervention = self.generate_x_intervention_val(causal_intervention_samples)
            
            # Initialize causal model
            causal_graph, causal_net, causal_net_loss, scm = self.update_causal_model(
                is_multifidelity=True,
                fidelity_param_name=self.fidelity_param_name,
                observational_data=causal_observational_data,  # normalized inside base.py
                x_intervention_val=x_intervention, # normalized inside base.py
                causal_discovery=causal_discovery,
                is_design_var_independent=causal_is_design_var_independent,
                alpha=causal_discovery_alpha,
                num_interventions=causal_num_interventions,
                epochs=causal_net_epochs,
                save_graph=causal_save_graph,
                causal_inf_backend=causal_inf_backend,
                causal_inf_batch_size=causal_inf_batch_size,
                causal_inf_n_jobs=causal_inf_n_jobs,
            )
                
            # Initialize causal GP        
            mll, model = self.update_causal_GP(
                is_multifidelity=True,
                train_x=normalize(train_x, self.bounds),
                train_objectives=train_obj,
                train_constraints=train_constraints,
                causal_net=causal_net,
                data_covar_module=cgp_data_covar_module,
                fidelity_covar_module=cgp_fidelity_covar_module,
            )
            self.update_causal_gp(
                mll=mll, 
                use_rescue_fit=use_rescue_fit
            ) 
            # +===========================================+
            # |       END RESCUE Initializations          |
            # +===========================================+    

            # Do not contribute to the optimization performance
            # For computing exp_stats          
            if compute_metrics:
                met_best_nsga2_regret = float("inf")
                met_best_nsga2_hv = float("-inf")                             
                target_train_x, target_fid_obj, target_fid_constraints = (
                    self.get_target_fid_observations(train_x=train_x)
                )  
                optimization_pref = MultiObjectiveOptimizationPref(
                    ref_point=ref_point,
                    max_hv=max_hv,
                    **self.tkwargs
                )    
            # END Do not contribute to the optimization performance 

            # Rescue state
            self.update_state(
                is_multifidelity=True,
                get_seed=get_seed,
                budget=budget,
                ref_point=ref_point,
                max_hv=max_hv,
                iteration=iteration,
                met_best_nsga2_regret=met_best_nsga2_regret,
                met_best_nsga2_hv=met_best_nsga2_hv,
                acquisition_value=None,
                new_fidelity=None,
                initial_cost=initial_cost,
                current_cost=current_cost,
                train_x=train_x,
                train_obj=train_obj,
                train_constraints=train_constraints,
                fixed_features_list=fixed_features_list,
                target_train_x=target_train_x,
                target_fid_obj=target_fid_obj,
                target_fid_constraints=target_fid_constraints,
                objective_indices=objective_indices,
                constraints_indices=constraints_indices,
                x_intervention=x_intervention,
                causal_graph=causal_graph,
                scm=scm,
                causal_net=causal_net,
                model=model,
                causal_net_loss=causal_net_loss,
            )
            self.log_artifacts_to_wandb()

        # +=================================+
        # |            BO loop              |
        # +=================================+   
        if acqfn_use_causalmomf:
            causal_momf_sampler = SobolQMCNormalSampler(
                sample_shape=torch.Size([acqfn_causalmomf_mc_samples])
            )   
        while current_cost < budget:
            if acqfn_use_causalmomf:
                new_x, acq_value = self.optimize_CausalMOMF(
                    model=model,
                    causal_model=causal_net,
                    causal_weight=acqfn_curr_val_causal_weight,
                    train_obj=train_obj,
                    sampler=causal_momf_sampler,
                    ref_point=ref_point,
                    q=acqfn_q,
                    num_restarts=acqfn_num_restarts,
                    raw_samples=acqfn_raw_samples,
                    objective_indices=objective_indices,
                    constraints_indices=constraints_indices,
                    fixed_features_list=fixed_features_list,
                    options=optim_acqfn_options
                )
            else:
                new_x, acq_value = self.optimize_causal_hvkg(
                    q=acqfn_q,
                    model=model,
                    causal_model=causal_net,
                    causal_weight=acqfn_curr_val_causal_weight,
                    ref_point=ref_point,
                    num_restarts=acqfn_num_restarts,
                    curr_val_mc_samples=acqfn_curr_val_mc_samples,
                    curr_val_raw_samples=acqfn_curr_val_raw_samples,
                    curr_val_num_restarts=acqfn_curr_val_num_restarts,
                    inner_mc_samples=acqfn_inner_mc_samples,
                    num_fantasies=acqfn_num_fantasies,
                    num_pareto=acqfn_num_pareto,
                    use_posterior_mean=acqfn_use_posterior_mean,
                    raw_samples=acqfn_raw_samples,
                    objective_indices=objective_indices,
                    fixed_features_list=fixed_features_list,
                    options=optim_acqfn_options,
                    optim_valuef_options=optim_inner_acqfn_options
                )
            new_fidelity = new_x.detach()[..., -1].item()
            # Real intervention
            new_obj, new_constraints = self.obtain_new_y(new_x)

            self.validate_new_constraints(new_constraints)
            # Update training data
            train_x = torch.cat([train_x, new_x], dim=0)
            train_obj = torch.cat([train_obj, new_obj], dim=0)
            if new_constraints is not None:
                train_constraints = torch.cat([train_constraints, new_constraints])

            # Update causal model periodically
            if fit_causal_model_after_each is not None:
                if iteration % fit_causal_model_after_each == 0:
                    causal_graph, causal_net, causal_net_loss, scm = self.update_causal_model(
                        is_multifidelity=True,
                        fidelity_param_name=self.fidelity_param_name,
                        train_x=train_x,
                        train_objectives=train_obj,
                        observational_data=causal_observational_data,
                        train_constraints=train_constraints,
                        x_intervention_val=x_intervention,
                        causal_discovery=causal_discovery,
                        alpha=causal_discovery_alpha,
                        num_interventions=causal_num_interventions,
                        epochs=causal_net_epochs,
                        is_design_var_independent=causal_is_design_var_independent,
                        save_graph=causal_save_graph,
                        causal_inf_backend=causal_inf_backend,
                        causal_inf_batch_size=causal_inf_batch_size,
                        causal_inf_n_jobs=causal_inf_n_jobs,
                    )

            # Update causal GP
            mll, model = self.update_causal_GP(
                is_multifidelity=True,
                train_x=normalize(train_x, self.bounds),
                train_objectives=train_obj,
                train_constraints=train_constraints,
                causal_net=causal_net,
                data_covar_module=cgp_data_covar_module,
                fidelity_covar_module=cgp_fidelity_covar_module,
            )
            self.update_causal_gp(
                mll=mll,
                use_rescue_fit=use_rescue_fit
            )

            # Update budget
            current_cost += _get_new_cost(new_x)
            iteration += 1
            # +=================================+
            # |          End BO loop            |
            # +=================================+

            # +===========================================+
            # |           Do not contribute               |
            # |     to the optimization performance       |
            # +===========================================+
            # This can be time-consuming if the target fidelity is costly,
            # so it should only be used for research purposes.
            # Logging and computing metrics can instead be done offline or
            # in parallel to the optimization; in that case, compute_metrics
            # should be False.
            if compute_metrics:
                target_new_x = self._project_to_target_fidelity(new_x)
                target_train_x = torch.cat([target_train_x, target_new_x], dim=0)
                # We need to obtain the new observations at the target fidelity
                new_target_fid_obj = self.problem(target_new_x)
                if self.has_constraints:
                    new_target_fid_constraints = -self.problem.evaluate_slack(target_new_x)
                    target_fid_constraints = torch.cat(
                        [target_fid_constraints, new_target_fid_constraints], 
                        dim=0
                    )
                target_fid_obj = torch.cat([target_fid_obj, new_target_fid_obj], dim=0)
                ## Evaluate optimization performance
                # Compute HV and regret using NSGA-II on the posterior 
                (met_curr_nsga2_hv, met_curr_nsga2_regret, met_curr_nsga2_violation,
                    met_observed_hv, met_observed_regret, met_observed_violation) = (
                        self.get_metrics(
                            is_mf_model=True,
                            # future me, we need to drop the fidelity dim for
                            # evaluating the problem
                            input_dim_without_fid=self.input_dim - 1,
                            project=self._project_to_target_fidelity,
                            compute_metrics=compute_metrics,
                            model=model,
                            train_objectives=target_fid_obj,
                            train_constraints=target_fid_constraints,
                            optimization_pref=optimization_pref,
                            objective_indices=objective_indices,
                            constraints_indices=constraints_indices,
                            seed=get_seed,
                        )
                    )
                if has_maxhv:
                    if met_curr_nsga2_regret is not None:
                        if met_curr_nsga2_regret < met_best_nsga2_regret:
                            met_best_nsga2_regret = met_curr_nsga2_regret
                if met_curr_nsga2_hv is not None:
                    if met_curr_nsga2_hv > met_best_nsga2_hv:
                        met_best_nsga2_hv = met_curr_nsga2_hv 

            stats = self.exp_stats(
                budget=budget,
                initial_cost=initial_cost,
                seed=get_seed,
                compute_metrics=compute_metrics,
                has_maxhv=has_maxhv,
                iteration=iteration,
                current_cost=current_cost,
                acq_value=acq_value,
                causal_net_loss=causal_net_loss,
                best_nsga2_regret=met_best_nsga2_regret,
                curr_nsga2_regret=met_curr_nsga2_regret,
                curr_nsga2_violation=met_curr_nsga2_violation,
                observed_regret=met_observed_regret,
                best_nsga2_hv=met_best_nsga2_hv,
                curr_nsga2_hv=met_curr_nsga2_hv,
                observed_hv=met_observed_hv,
                observed_violation=met_observed_violation,
                new_fidelity=new_fidelity
            )           
            self.log_exp_stats_to_wandb(
                iter=iteration, 
                exp_stats=stats,
                compute_metrics=compute_metrics,
            )
            self.save_stats_to_json(stats)
            # +===========================================+
            # |          End Do not contribute            |
            # |     to the optimization performance       |
            # +===========================================+    
               
            # Update state
            self.update_state(
                is_multifidelity=True,
                get_seed=get_seed,
                budget=budget,
                ref_point=ref_point,
                max_hv=max_hv,
                iteration=iteration,
                met_best_nsga2_regret=met_best_nsga2_regret,
                met_best_nsga2_hv=met_best_nsga2_hv,
                acquisition_value=acq_value,
                new_fidelity=new_fidelity,
                initial_cost=initial_cost,
                current_cost=current_cost,
                train_x=train_x,
                train_obj=train_obj,
                train_constraints=train_constraints,
                fixed_features_list=fixed_features_list,
                target_train_x=target_train_x,
                target_fid_obj=target_fid_obj,
                target_fid_constraints=target_fid_constraints,
                objective_indices=objective_indices,
                constraints_indices=constraints_indices,
                x_intervention=x_intervention,
                causal_graph=causal_graph,
                scm=scm,
                causal_net=causal_net,
                model=model,
                causal_net_loss=causal_net_loss,
            )
            self.log_artifacts_to_wandb()

            # === Terminal log ===
            term_print(
                show_stats=show_status, 
                exp_stats=stats,
                budget=budget,
                has_constraints=self.has_constraints
            )
            torch.cuda.empty_cache()
        if self.wandb_run is not None:
            self.wandb_run.finish()
        res_pymoo = nsga2_posterior_pareto(
            model=model,
            input_dim_without_fid=self.input_dim - 1,
            num_objectives=len(objective_indices),
            objective_indices=objective_indices,
            is_mf_model=True,
            device=self.tkwargs['device'],
            dtype=self.tkwargs['dtype'],
            project_to_target_fidelity=self._project_to_target_fidelity,
            constraints_indices=constraints_indices,
        )
        res = {
            "pareto_X": res_pymoo.X,
            "pareto_Y": res_pymoo.F,
            "constraints": res_pymoo.G,
            "pymoo_res": res_pymoo,
            "cgp_model": model,
            "cpm": causal_graph,
        }
        return res
   