#!/usr/bin/env python3

from __future__ import annotations

from typing import Callable

import pandas as pd
import torch
from torch import Tensor
from wandb.sdk.wandb_run import Run

from botorch.test_functions.base import (
    ConstrainedBaseTestProblem,
    MultiObjectiveTestProblem,
)
from botorch.models.model import Model
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch import fit_gpytorch_mll
from botorch.optim import optimize_acqf
from botorch.utils.transforms import unnormalize, normalize
from botorch.acquisition.multi_objective.objective import (
    IdentityMCMultiOutputObjective
)
from botorch.acquisition.multi_objective.objective import MCMultiOutputObjective

from gpytorch.kernels import Kernel

from rescue.acquisition.causal_knowledge_gradient import (
    qCausalHypervolumeKnowledgeGradient,
    causal_hv_value_function
)
from botorch.acquisition.multi_objective import (
    qLogExpectedHypervolumeImprovement,
)
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
    FastNondominatedPartitioning,
)
from rescue.models import fit_causal_gp
from rescue.acquisition.optim import gen_one_shot_hvkg_initial_conditions
from rescue.algorithms.base import BaseRescueAlgorithm
from rescue.metrics.optimization_pref import (
    MultiObjectiveOptimizationPref,
    nsga2_posterior_pareto
)
from rescue.metrics.utils import term_print
from rescue.algorithms.state import RescueState
from rescue.utils.utils import status
from rescue.algorithms.base import CausalGPFn, GenInitDataFn


class RescueAlgorithmSinglefidelity(BaseRescueAlgorithm):
    def __init__(
        self,
        problem: MultiObjectiveTestProblem | ConstrainedBaseTestProblem,
        design_variables: list[str],
        objective_variables: list[str],
        bounds: Tensor,
        gen_initial_data: GenInitDataFn | None = None,
        custom_model: None | CausalGPFn = None,
        verbose: bool = False,
        status_spinner: bool = True,
        cost_fn: None | Callable[[Tensor], Tensor] = None,
        constraint_variables: None | list[str] = None,
        kpi_variables: None | list[str] = None,
        device: None | torch.device = None,
        dtype: None | torch.dtype = None,
        wandb_run: None | Run = None,
        rescue_state: None | RescueState = None,
    ) -> None:
        r"""
        Set up the single-fidelity RESCUE algorithm.

        All arguments are handed, by keyword and unchanged, to the parent
        initializer. On top of that, this subclass refuses to start from a
        state that is flagged as multi-fidelity.

        Args:
            problem (MultiObjectiveTestProblem | ConstrainedBaseTestProblem): The
                optimization problem to solve.
            design_variables (list[str]): Names of the design variables.
            objective_variables (list[str]): Names of the objective variables.
            bounds (Tensor): Bounds of the design variables.
            gen_initial_data (GenInitDataFn | None): Optional generator for the
                initial data.
            custom_model (None | CausalGPFn): Optional custom model function.
            verbose (bool): Whether to print verbose output.
            status_spinner (bool): Whether to show a status spinner.
            cost_fn (None | Callable[[Tensor], Tensor]): Optional evaluation
                cost function.
            constraint_variables (None | list[str]): Constraint variables to track.
            kpi_variables (None | list[str]): KPI variables to track.
            device (None | torch.device): Device to use.
            dtype (None | torch.dtype): Dtype to use.
            wandb_run (None | Run): wandb run to log to.
            rescue_state (None | RescueState): Previously saved state to use.

        Raises:
            ValueError: If ``self.rescue_state`` is set and
                ``self.state['is_multifidelity']`` is truthy.
        """
        parent_kwargs = dict(
            problem=problem,
            design_variables=design_variables,
            objective_variables=objective_variables,
            constraint_variables=constraint_variables,
            kpi_variables=kpi_variables,
            bounds=bounds,
            gen_initial_data=gen_initial_data,
            custom_model=custom_model,
            cost_fn=cost_fn,
            verbose=verbose,
            status_spinner=status_spinner,
            device=device,
            dtype=dtype,
            wandb_run=wandb_run,
            rescue_state=rescue_state,
        )
        super().__init__(**parent_kwargs)

        # Naive sanity check on the loaded state; it only inspects a flag and
        # can therefore be bypassed easily.
        has_saved_state = self.rescue_state is not None
        if has_saved_state and self.state['is_multifidelity']:
            raise ValueError("Expected `algorithms.state.RescueState`, "
                             "but found `algorithms.state.RescueStateMultifidelity`!"
                             )

        # Debugging aid: (re)bind the flag on the instance so the `status`
        # decorator can read it.
        self.status_spinner = status_spinner

    @status(show_func_name=True)
    def _generate_initial_data(
        self,
        n: float | int | None = None,
    ) -> tuple[Tensor, Tensor, Tensor | None]:
        r"""
        Produce the initial training set.

        Thin wrapper that delegates to ``self.generate_initial_data``.

        Args:
            n: Number of initial samples (or initial budget); forwarded
                unchanged as the ``n`` keyword.

        Returns:
            tuple[Tensor, Tensor, Tensor | None]: The value returned by
                ``self.generate_initial_data``, which ``run`` unpacks as
                ``(train_x, train_obj, train_constraints)``; the last entry
                is annotated as optional.
        """
        initial_data = self.generate_initial_data(n=n)
        return initial_data

    @status(show_func_name=True) 
    def update_causal_gp(self, mll, use_rescue_fit=True) -> None:
        """Fit the causal GP by optimizing its marginal log likelihood.

        Args:
            mll: The marginal likelihood object to fit.
            use_rescue_fit (bool): When truthy (the default) the project's
                ``fit_causal_gp`` routine is used (with ``self.verbose``);
                otherwise BoTorch's ``fit_gpytorch_mll`` is called.
        """
        if not use_rescue_fit:
            # Stock BoTorch fitting path.
            fit_gpytorch_mll(mll)
            return
        fit_causal_gp(mll, verbose=self.verbose)

    def get_current_value(
        self,
        model: Model,
        causal_model: torch.nn.Module,
        causal_weight: float,
        ref_point: torch.Tensor,
        raw_samples: int,
        mc_samples: int,
        num_restarts: int,
        q: int,
        use_posterior_mean: bool = True,
        objective: MCMultiOutputObjective | None = None,
        constraints: list[Callable[[Tensor], Tensor]] | None = None,
        options: dict[str, bool | float | int | str] | None = None,
    ) -> Tensor:
        """Return the hypervolume of the current hypervolume-maximizing set.

        Builds the causal hypervolume value function and maximizes it over
        ``self.standard_bounds`` with ``optimize_acqf``; only the best
        acquisition value is returned (the maximizing points are discarded).

        Args:
            model (Model): Surrogate model handed to the value function.
            causal_model (torch.nn.Module): Causal model handed to the value
                function.
            causal_weight (float): Causal weight handed to the value function.
            ref_point (torch.Tensor): Hypervolume reference point.
            raw_samples (int): ``raw_samples`` for ``optimize_acqf``.
            mc_samples (int): Sample count of the Sobol QMC sampler. Only used
                when ``use_posterior_mean`` is False and this is not None.
            num_restarts (int): ``num_restarts`` for ``optimize_acqf``.
            q (int): Batch size for ``optimize_acqf``.
            use_posterior_mean (bool): Forwarded to the value function; when
                True no sampler is created.
            objective (MCMultiOutputObjective | None): Forwarded to the value
                function.
            constraints (list[Callable[[Tensor], Tensor]] | None): Forwarded to
                the value function.
            options (dict | None): Extra ``optimize_acqf`` options. The
                ``"nonnegative"`` entry is always forced to True. The caller's
                dict is NOT modified.

        Returns:
            Tensor: The acquisition value reported by ``optimize_acqf``.
        """
        # Work on a copy: the same user-supplied dict is passed in on every
        # BO iteration, so it must not be mutated in place.
        options = {**(options or {}), "nonnegative": True}

        sampler = None
        if not use_posterior_mean and mc_samples is not None:
            sampler = SobolQMCNormalSampler(sample_shape=torch.Size([mc_samples]))

        curr_val_acqf = causal_hv_value_function(
            model=model,
            causal_model=causal_model,
            causal_weight=causal_weight,
            ref_point=ref_point,
            sampler=sampler,
            use_posterior_mean=use_posterior_mean,
            objective=objective,
            constraints=constraints,
        )
        # Only the optimal value is of interest, not where it is attained.
        _, current_value = optimize_acqf(
            acq_function=curr_val_acqf,
            bounds=self.standard_bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            return_best_only=True,
            options=options,
        )
        return current_value

    @status(show_func_name=True) 
    def optimize_causal_hvkg(
        self,
        q: int,
        model: Model,
        causal_model: torch.nn.Module,
        causal_weight: float,
        ref_point: Tensor,
        curr_val_raw_samples: int,
        curr_val_mc_samples: int,
        curr_val_num_restarts: int,
        inner_mc_samples: int,
        num_pareto: int,
        num_fantasies: int,
        num_restarts: int,
        raw_samples: int,
        use_posterior_mean: bool = True,
        objective_indices: None | list[int] = None,
        options: dict[str, bool | float | int | str] | None = None,
        optim_valuef_options: dict[str, bool | float | int | str] | None = None
    ) -> tuple[Tensor, Tensor]:
        r"""
        Optimize the causal hypervolume knowledge gradient to get the next
        candidate(s).

        NOTE: Assumes the model outputs are ordered
        ``[obj0, ..., objN, cons0, cons1, ...]``, i.e. the constraint outputs
        follow the ``len(objective_indices)`` objective outputs.

        Args:
            q (int): The batch size for acquisition.
            model (Model): The surrogate model to optimize.
            causal_model (torch.nn.Module): The causal model (used for the
                current-value computation).
            causal_weight (float): The weight for the causal component (used
                for the current-value computation).
            ref_point (Tensor): The reference point for the optimization.
            curr_val_raw_samples (int): The number of raw samples for current value optimization.
            curr_val_mc_samples (int): The number of MC samples for current value.
            curr_val_num_restarts (int): The number of restarts for current value optimization.
            inner_mc_samples (int): The number of inner MC samples. Only used
                when ``use_posterior_mean`` is False.
            num_pareto (int): The number of Pareto points to consider.
            num_fantasies (int): The number of fantasy points for the knowledge gradient.
            num_restarts (int): The number of restarts for the optimization.
            raw_samples (int): The number of raw samples to generate.
            use_posterior_mean (bool): Whether to use the posterior mean. Defaults to True.
            objective_indices (None | list[int]): The indices of the objectives to
                optimize. Defaults to None. Required when the problem has
                constraints, because their position is derived from it.
            options (dict[str, bool | float | int | str] | None): Additional optimization
                options. Defaults to None.
            optim_valuef_options (dict[str, bool | float | int | str] | None): Options for
                value function optimization. Defaults to None.

        Returns:
            Tuple (Tensor, Tensor): (new_x, acq_value), where ``new_x`` is the
                candidate set un-normalized with ``self.bounds``.

        Raises:
            ValueError: If the problem has constraints but
                ``objective_indices`` is None.
        """
        inner_sampler = None
        if not use_posterior_mean and inner_mc_samples is not None:
            inner_sampler = SobolQMCNormalSampler(
                sample_shape=torch.Size([inner_mc_samples])
            )

        constraints = None
        if self.has_constraints:
            if objective_indices is None:
                # Previously surfaced as an opaque `TypeError` from `len(None)`.
                raise ValueError(
                    "`objective_indices` must be provided when the problem has "
                    "constraints: constraint outputs are located right after "
                    "the objective outputs of the model."
                )
            # Constraints are in the model after all objectives
            # Model structure: [obj0, obj1, ..., objN, cons0, cons1, ...]
            num_objectives = len(objective_indices)
            # `idx` is bound as a default argument to avoid late binding of `j`.
            constraints = [
                lambda Z, idx=num_objectives + j: Z[..., idx]
                for j in range(self.num_constraints)
            ]

        objective = IdentityMCMultiOutputObjective(outcomes=objective_indices)
        current_value = self.get_current_value(
            model=model,
            causal_model=causal_model,
            causal_weight=causal_weight,
            ref_point=ref_point,
            objective=objective,
            constraints=constraints,
            raw_samples=curr_val_raw_samples,
            mc_samples=curr_val_mc_samples,
            q=num_pareto,
            use_posterior_mean=use_posterior_mean,
            num_restarts=curr_val_num_restarts,
            options=optim_valuef_options
        )
        # NOTE(review): `causal_model` / `causal_weight` only feed the
        # current-value computation above; confirm the acquisition function
        # itself is not expected to receive them.
        acq_func = qCausalHypervolumeKnowledgeGradient(
            model=model,
            ref_point=ref_point,  # use known reference point
            num_fantasies=num_fantasies,
            num_pareto=num_pareto,
            current_value=current_value,
            inner_sampler=inner_sampler,
            use_posterior_mean=use_posterior_mean,
            objective=objective,
            constraints=constraints,
        )
        # Optimization
        candidates, vals = optimize_acqf(
            acq_function=acq_func,
            bounds=self.standard_bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,  # used for initialization heuristic
            ic_generator=gen_one_shot_hvkg_initial_conditions,
            options=options,
        )
        # The multi-fidelity variant zeroes the trailing fidelity column when
        # the acquisition value is 0. This algorithm is single-fidelity: the
        # last column is a regular design variable and must not be overwritten.
        new_x = unnormalize(candidates.detach(), bounds=self.bounds)
        return new_x, vals

    @status(show_func_name=True)
    def optimize_qLogEHVI(
        self,
        model: Model,
        train_obj: torch.Tensor,
        ref_point: torch.Tensor,
        q: int,
        num_restarts: int,
        raw_samples: int,
        sampler: SobolQMCNormalSampler,
        objective_indices: None | list[int] = None,
        constraints_indices: None | list[int] = None,
        options: None | dict[str, bool | float | int | str] = None
    ) -> tuple[Tensor, Tensor]:
        r"""
        Optimize using qLogExpectedHypervolumeImprovement acquisition function.

        Assumptions:
            - The model outputs are ordered
              ``[obj0, ..., objN, cons0, cons1, ...]``: the
              ``len(constraints_indices)`` constraint outputs directly follow
              the ``len(objective_indices)`` objective outputs.

        Args:
            model (Model): The surrogate model to optimize.
            train_obj (torch.Tensor): The training objectives, used to build
                the non-dominated partitioning.
            ref_point (torch.Tensor): The reference point for hypervolume computation.
            q (int): The batch size for acquisition.
            num_restarts (int): The number of restarts for the optimization.
            raw_samples (int): The number of raw samples to generate.
            sampler (SobolQMCNormalSampler): The sampler to use for MC sampling.
            objective_indices (None | list[int]): The indices of the objectives.
                Defaults to None. Required when the problem has constraints.
            constraints_indices (None | list[int]): The indices of the constraints.
                Defaults to None. Required when the problem has constraints.
            options (None | dict[str, bool | float | int | str]): Additional
                optimization options. Defaults to None.

        Returns:
            Tuple (Tensor, Tensor): (new_x, acq_value), where ``new_x`` is the
                candidate set un-normalized with ``self.bounds``.

        Raises:
            ValueError: If the problem has constraints but
                ``objective_indices`` or ``constraints_indices`` is None.
        """
        constraints = None
        if self.has_constraints:
            if objective_indices is None or constraints_indices is None:
                # Previously surfaced as an opaque `TypeError` from `len(None)`.
                raise ValueError(
                    "`objective_indices` and `constraints_indices` must both be "
                    "provided when the problem has constraints."
                )
            # Constraints are in the model after all objectives
            # Model structure: [obj0, obj1, ..., objN, cons0, cons1, ...]
            num_objectives = len(objective_indices)
            # `idx` is bound as a default argument to avoid late binding of `j`.
            constraints = [
                lambda Z, idx=num_objectives + j: Z[..., idx]
                for j in range(len(constraints_indices))
            ]

        partitioning = FastNondominatedPartitioning(ref_point=ref_point, Y=train_obj)
        acq_func = qLogExpectedHypervolumeImprovement(
            model=model,
            ref_point=ref_point,  # use known reference point
            partitioning=partitioning,
            sampler=sampler,
            objective=IdentityMCMultiOutputObjective(
                outcomes=objective_indices
            ),
            constraints=constraints,
        )
        # Optimization
        candidates, vals = optimize_acqf(
            acq_function=acq_func,
            bounds=self.standard_bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,  # used for initialization heuristic
            sequential=True,
            options=options
        )
        new_x = unnormalize(candidates.detach(), bounds=self.bounds)
        return new_x, vals

    def run(
        self,
        budget: int | float,
        ref_point: Tensor,
        objective_indices: list[int],
        causal_observational_data: pd.DataFrame,
        init_budget: int | None = None,
        include_initcost_to_budget: bool = True,
        use_rescue_fit: bool = False,
        causal_intervention_samples: int = 1000,
        causal_discovery: str = 'PC',
        causal_num_interventions: int = 200,
        causal_net_epochs: int = 100,
        causal_inf_backend: str = "loky",
        causal_inf_batch_size: int = 10,
        causal_is_design_var_independent: bool = True,
        causal_inf_n_jobs: int = 1,  
        causal_discovery_alpha: None | float = None,      
        cgp_base_covar_module: None | Kernel = None,
        fit_causal_model_after_each: None | int = None,
        causal_save_graph: bool = False,  
        acqfn_q: int = 1,   
        acqfn_num_restarts: int = 10,
        acqfn_raw_samples: int = 512,
        acqfn_curr_val_causal_weight: float = 1.0,
        acqfn_curr_val_raw_samples: int = 2 * 512,
        acqfn_curr_val_mc_samples: int = 32,
        acqfn_curr_val_num_restarts: int = 1,
        acfn_inner_mc_samples: int = 64,
        acqfn_num_fantasies: int = 8,
        acqfn_num_pareto: int = 32,
        acqfn_use_qlogehvi: bool = True,
        acqfn_causalehvi_mc_samples: int = 128,
        acqfn_use_posterior_mean: bool = True,
        optim_acqfn_options: None | dict[str, bool | float | int | str] = None,
        optim_inner_acqfn_options: dict[str, bool | float | int | str] | None = None,
        max_hv: float | Tensor | None = None,
        constraints_indices: None | list[int] = None,             
        seed: int | None = None,
        compute_metrics: bool = True,
        show_status: bool = True
    ) -> tuple[RescueState, dict[str, float]]:
        r""" 
        Run the single-fidelity RESCUE optimization algorithm.

        NOTE: Assumes last dim of `train_y` corresponds to constraints.

        Args:
            budget (int | float): The budget for the optimization run.
            ref_point (Tensor): The reference point for the optimization.
            causal_observational_data (pd.DataFrame): Observational data for causal discovery.
            objective_indices (list[int]): The indices of the objective variables.
            init_budget (int | None): Initial budget for data generation. Defaults to None.
            include_initcost_to_budget (bool): Whether to include the initial cost in the budget.
            use_rescue_fit (bool): Whether to use the rescue fit method. Defaults to False.
            causal_intervention_samples (int): The number of causal intervention samples to generate.
            causal_discovery (str): The causal discovery method to use.
            causal_num_interventions (int): The number of interventions to perform.
            causal_net_epochs (int): The number of epochs to train the causal network.
            causal_inf_backend (str): The backend for causal inference.
            causal_inf_batch_size (int): The batch size for causal inference.
            causal_is_design_var_independent (bool): Whether design variables are independent.
            causal_inf_n_jobs (int): The number of jobs for causal inference.
            causal_discovery_alpha (None | float): The significance level for causal discovery.
            cgp_base_covar_module (None | Kernel): The base covariance module for the causal GP.
            fit_causal_model_after_each (None | int): The frequency of fitting the causal model.
            causal_save_graph (bool): Whether to save the causal graph.
            acqfn_q (int): The batch size for acquisition.
            acqfn_num_restarts (int): The number of restarts for the acquisition function.
            acqfn_raw_samples (int): The number of raw samples for the acquisition function.
            acqfn_curr_val_causal_weight (float): The causal weight for current value computation.
            acqfn_curr_val_raw_samples (int): The number of raw samples for current value.
            acqfn_curr_val_mc_samples (int): The number of MC samples for current value.
            acqfn_curr_val_num_restarts (int): The number of restarts for current value.
            acfn_inner_mc_samples (int): The number of inner MC samples.
            acqfn_num_fantasies (int): The number of fantasies for knowledge gradient.
            acqfn_num_pareto (int): The number of Pareto points.
            acqfn_use_qlogehvi (bool): Whether to use qLogEHVI acquisition function.
            acqfn_causalehvi_mc_samples (int): The number of MC samples for causal EHVI.
            acqfn_use_posterior_mean (bool): Whether to use posterior mean.
            optim_acqfn_options (None | dict[str, bool | float | int | str]): Options for 
                acquisition optimization.
            optim_inner_acqfn_options (dict[str, bool | float | int | str] | None): Options for 
                inner acquisition optimization.
            max_hv (float | Tensor | None): The maximum hypervolume to achieve.
                If provided, the optimization method will be evaluated at runtime.
                Which can be extremely costly as evaluation is done at the
                target fidelity. NOTE: Should be only used for research purposes.
            constraints_indices (None | list): The indices of the constraint variables.
            seed (int | None): The `torch.manual_seed` seed for reproducibility.
            compute_metrics (bool): Whether to evaluate the optimization preferences at runtime.
            show_status (bool): Whether to show status updates.

        Returns:
            dict[str, Any]: A dictionary containing the results of the optimization run.
                - "pareto_X": The Pareto points.
                - "pareto_Y": The objective values of the Pareto points.
                - "constraints": The constraint values of the Pareto points.
                - "pymoo_res": The pymoo Result object.
                - "cgp_model": The causal GP model.
                - "cpm": The causal performance model.
        """
        if causal_discovery == "FCI":
            raise NotImplementedError(
                "FCI produces Partial Ancestral Graphs (PAGs) "
                "which is resolved using information theoretic approach to "
                "result in a Acyclic Directed Mixed Graph (ADMG)."
                "However, `dowhy` does not support ADMGs"
            )        
        self.validate_run_inputs(compute_metrics, objective_indices, constraints_indices)
        has_maxhv = max_hv is not None
        # === Helper Functions for run ===
        def _get_total_cost(train_x: Tensor):
            if self.cost_fn is not None:
                return self.cost_fn(train_x).sum().item() 
            else:
                return train_x.shape[0]

        def _get_new_cost(new_x: Tensor):
            if self.cost_fn is not None:
                return self.cost_fn(new_x).item()
            else:
                return 1
        # === End Helper Functions for run ===     
    
        # +===========================================+
        # |           RESCUE Initializations          |
        # +===========================================+
        if self.rescue_state is not None:        
            (
                iteration,
                get_seed,
                initial_cost,
                current_cost,           
                ref_point,
                max_hv,
                met_best_nsga2_regret,
                met_best_nsga2_hv,
                x_intervention,
                train_x,
                train_obj,
                train_constraints,
                objective_indices,
                constraints_indices
            ) = self.load_state()
            self.validated_budget(budget, current_cost)

            causal_graph, causal_net, causal_net_loss, scm = self.load_causal_model()         

            mll, model = self.update_causal_GP(
                is_multifidelity=False,
                state_dict=self.state['model'],
                train_x=normalize(train_x, self.bounds),
                train_objectives=train_obj,
                train_constraints=train_constraints,
                causal_net=causal_net,
                data_covar_module=cgp_base_covar_module,
            )
            if compute_metrics:
                optimization_pref = MultiObjectiveOptimizationPref(
                    ref_point=ref_point,
                    max_hv=max_hv,
                    **self.tkwargs
                )              
        else:
            met_best_nsga2_regret = None
            met_best_nsga2_hv = None    
            met_curr_nsga2_regret = None
            met_curr_nsga2_hv = None
            met_curr_nsga2_violation = None
            met_observed_regret = None
            met_observed_hv = None
            met_observed_violation = None        
            iteration = 0
            current_cost = 0

            if seed is not None:
                torch.manual_seed(seed)
            get_seed = torch.initial_seed()

            # Generate initial data
            train_x, train_obj, train_constraints = self._generate_initial_data(init_budget)

            # Track optimization
            initial_cost = _get_total_cost(train_x)
            print("Initial cumulative cost:", initial_cost) if show_status else None
            if include_initcost_to_budget:
                current_cost = initial_cost
            self.validated_budget(budget, current_cost)

            # Generate intervention values
            x_intervention = self.generate_x_intervention_val(causal_intervention_samples)
            # Initialize causal model
            causal_graph, causal_net, causal_net_loss, scm = self.update_causal_model(
                is_multifidelity=False,
                observational_data=causal_observational_data,  # normalized inside base.py
                x_intervention_val=x_intervention, # normalized inside base.py
                causal_discovery=causal_discovery,
                alpha=causal_discovery_alpha,
                is_design_var_independent=causal_is_design_var_independent,
                num_interventions=causal_num_interventions,
                epochs=causal_net_epochs,
                save_graph=causal_save_graph,
                causal_inf_backend=causal_inf_backend,
                causal_inf_batch_size=causal_inf_batch_size,
                causal_inf_n_jobs=causal_inf_n_jobs,
            )
                
            # Initialize causal GP
            mll, model = self.update_causal_GP(
                is_multifidelity=False,
                train_x=normalize(train_x, self.bounds),
                train_objectives=train_obj,
                train_constraints=train_constraints,
                causal_net=causal_net,
                data_covar_module=cgp_base_covar_module
            )
            self.update_causal_gp(
                mll=mll,
                use_rescue_fit=use_rescue_fit
            )
            # +===========================================+
            # |       END RESCUE Initializations          |
            # +===========================================+ 
                        
            if compute_metrics:        
                met_best_nsga2_regret = float("inf")
                met_best_nsga2_hv = float("-inf")                       
                optimization_pref = MultiObjectiveOptimizationPref(
                    ref_point=ref_point,
                    max_hv=max_hv,
                    **self.tkwargs
                )             
            # Resuce state
            self.update_state(
                is_multifidelity=False,
                get_seed=get_seed,
                budget=budget,
                ref_point=ref_point,
                max_hv=max_hv,
                iteration=iteration,
                met_best_nsga2_regret=met_best_nsga2_regret,
                met_best_nsga2_hv=met_best_nsga2_hv,
                acquisition_value=None,
                new_fidelity=None,
                initial_cost=initial_cost,
                current_cost=current_cost,
                train_x=train_x,
                train_obj=train_obj,
                train_constraints=train_constraints,
                objective_indices=objective_indices,
                constraints_indices=constraints_indices,
                x_intervention=x_intervention,
                causal_graph=causal_graph,
                scm=scm,
                causal_net=causal_net,
                model=model,
                causal_net_loss=causal_net_loss,
            )
            self.log_artifacts_to_wandb()

        # +=================================+
        # |            BO loop              |
        # +=================================+
        # Generate Sampler
        sampler = SobolQMCNormalSampler(
            sample_shape=torch.Size([acqfn_causalehvi_mc_samples])
        )   
        while current_cost < budget:
            if acqfn_use_qlogehvi:
                new_x, acq_value = self.optimize_qLogEHVI(
                    model=model,
                    train_obj=train_obj,
                    ref_point=ref_point,
                    q=acqfn_q,
                    num_restarts=acqfn_num_restarts,
                    raw_samples=acqfn_raw_samples,
                    sampler=sampler,
                    objective_indices=objective_indices,
                    constraints_indices=constraints_indices,
                    options=optim_acqfn_options
                )
            # Causal HVKG
            else:
                new_x, acq_value = self.optimize_causal_hvkg(
                    q=acqfn_q,
                    model=model,
                    causal_model=causal_net,
                    causal_weight=acqfn_curr_val_causal_weight,
                    ref_point=ref_point,
                    num_restarts=acqfn_num_restarts,
                    curr_val_mc_samples=acqfn_curr_val_mc_samples,
                    curr_val_raw_samples=acqfn_curr_val_raw_samples,
                    curr_val_num_restarts=acqfn_curr_val_num_restarts,
                    inner_mc_samples=acfn_inner_mc_samples,
                    use_posterior_mean=acqfn_use_posterior_mean,
                    num_fantasies=acqfn_num_fantasies,
                    num_pareto=acqfn_num_pareto,
                    raw_samples=acqfn_raw_samples,
                    objective_indices=objective_indices,
                    options=optim_acqfn_options,
                    optim_valuef_options=optim_inner_acqfn_options
                )
            # Real intervention
            new_obj, new_constraints = self.obtain_new_y(new_x)

            self.validate_new_constraints(new_constraints)
            # Update training data
            train_x = torch.cat([train_x, new_x], dim=0)
            train_obj = torch.cat([train_obj, new_obj], dim=0)
            if new_constraints is not None:
                train_constraints = torch.cat([train_constraints, new_constraints])

            # Update causal model periodically
            if fit_causal_model_after_each is not None:
                if iteration % fit_causal_model_after_each == 0:
                    causal_graph, causal_net, causal_net_loss, scm = self.update_causal_model(
                        is_multifidelity=False,
                        train_x=train_x,
                        train_objectives=train_obj,
                        observational_data=causal_observational_data,
                        train_constraints=train_constraints,
                        x_intervention_val=x_intervention,
                        causal_discovery=causal_discovery,
                        alpha=causal_discovery_alpha,
                        num_interventions=causal_num_interventions,
                        is_design_var_independent=causal_is_design_var_independent,
                        epochs=causal_net_epochs,
                        save_graph=causal_save_graph,
                        causal_inf_backend=causal_inf_backend,
                        causal_inf_batch_size=causal_inf_batch_size,
                        causal_inf_n_jobs=causal_inf_n_jobs,                
                    )

            # Update causal GP
            mll, model = self.update_causal_GP(
                is_multifidelity=False,
                train_x=normalize(train_x, self.bounds),
                train_objectives=train_obj,
                train_constraints=train_constraints,
                causal_net=causal_net,
                data_covar_module=cgp_base_covar_module
            )
            self.update_causal_gp(
                mll=mll,
                use_rescue_fit=use_rescue_fit
            )

            # Update budget
            current_cost += _get_new_cost(new_x)
            iteration += 1
            # +=================================+
            # |          End BO loop            |
            # +=================================+

            # +=================================+
            # |       Compute HV and Regret     |
            # +=================================+
            if compute_metrics:
                (met_curr_nsga2_hv, met_curr_nsga2_regret, met_curr_nsga2_violation,
                met_observed_hv, met_observed_regret, met_observed_violation) = (
                    self.get_metrics(
                        is_mf_model=False,
                        input_dim_without_fid=self.input_dim,
                        compute_metrics=compute_metrics,
                        model=model,
                        train_objectives=train_obj,
                        train_constraints=train_constraints,
                        optimization_pref=optimization_pref,
                        objective_indices=objective_indices,
                        constraints_indices=constraints_indices,
                        seed=get_seed,
                    )
                )
                if has_maxhv:
                    if met_curr_nsga2_regret is not None:
                        if met_curr_nsga2_regret < met_best_nsga2_regret:
                            met_best_nsga2_regret = met_curr_nsga2_regret
                if met_curr_nsga2_hv is not None:
                    if met_curr_nsga2_hv > met_best_nsga2_hv:
                        met_best_nsga2_hv = met_curr_nsga2_hv     

            stats = self.exp_stats(
                budget=budget,
                initial_cost=initial_cost,
                seed=get_seed,
                compute_metrics=compute_metrics,
                has_maxhv=has_maxhv,
                iteration=iteration,
                current_cost=current_cost,
                acq_value=acq_value,
                causal_net_loss=causal_net_loss,
                best_nsga2_regret=met_best_nsga2_regret,
                curr_nsga2_regret=met_curr_nsga2_regret,
                curr_nsga2_violation=met_curr_nsga2_violation,
                observed_regret=met_observed_regret,
                best_nsga2_hv=met_best_nsga2_hv,
                curr_nsga2_hv=met_curr_nsga2_hv,
                observed_violation=met_observed_violation,
                observed_hv=met_observed_hv,
            )           
            self.log_exp_stats_to_wandb(
                iter=iteration, 
                exp_stats=stats,
                compute_metrics=compute_metrics,
            )
            self.save_stats_to_json(stats)
            # +=================================+
            # |    End Compute HV and Regret    |
            # +=================================+

            # Update state
            self.update_state(
                is_multifidelity=False,
                get_seed=get_seed,
                budget=budget,
                ref_point=ref_point,
                max_hv=max_hv,
                iteration=iteration,
                met_best_nsga2_regret=met_best_nsga2_regret,
                met_best_nsga2_hv=met_best_nsga2_hv,
                acquisition_value=acq_value,
                new_fidelity=1.0,   # Single fidelity
                initial_cost=initial_cost,
                current_cost=current_cost,
                train_x=train_x,
                train_obj=train_obj,
                train_constraints=train_constraints,
                objective_indices=objective_indices,
                constraints_indices=constraints_indices,
                x_intervention=x_intervention,
                causal_graph=causal_graph,
                scm=scm,
                causal_net=causal_net,
                model=model,
                causal_net_loss=causal_net_loss,
            )
            self.log_artifacts_to_wandb()

            # === Terminal log ===
            term_print(
                show_stats=show_status, 
                exp_stats=stats,
                budget=budget,
                has_constraints=self.has_constraints
            )  
            torch.cuda.empty_cache()
        if self.wandb_run is not None:
            self.wandb_run.finish()
        res_pymoo = nsga2_posterior_pareto(
            model=model,
            input_dim_without_fid=self.input_dim,
            num_objectives=len(objective_indices),
            objective_indices=objective_indices,
            is_mf_model=False,
            device=self.tkwargs['device'],
            dtype=self.tkwargs['dtype'],
            constraints_indices=constraints_indices,
        )
        res = {
            "pareto_X": res_pymoo.X,
            "pareto_Y": res_pymoo.F,
            "constraints": res_pymoo.G,
            "pymoo_res": res_pymoo,
            "cgp_model": model,
            "cpm": causal_graph,
        }
        return self.state, stats






