#!/usr/bin/env python3

from __future__ import annotations

from typing import Callable, Protocol, Any

import os
import json
import uuid
from pathlib import Path
from datetime import datetime
from abc import ABC
import torch
from torch import Tensor
from dowhy.gcm import StructuralCausalModel

import networkx as nx
import pandas as pd
from wandb.sdk.wandb_run import Run

from botorch.test_functions.base import (
    ConstrainedBaseTestProblem,
    MultiObjectiveTestProblem,
)
from botorch.utils.sampling import draw_sobol_samples
from botorch.acquisition.utils import project_to_target_fidelity
from botorch.models.model import Model
from botorch.models.transforms import Standardize
from botorch.utils.transforms import normalize

from gpytorch.likelihoods import Likelihood
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.kernels import Kernel, RBFKernel
from gpytorch.likelihoods import MultitaskGaussianLikelihood
from gpytorch.models import ExactGP

from rescue.models.causal_model.causal_model import CausalPerformanceModel
from rescue.models.causal_model.map_to_NN import CausalMeanVarSurrogateNN
from rescue.models.causal_gp.multitask import CausalMultitaskGP
from rescue.models.causal_gp.multitask_multifidelity import (
    CausalMultitaskMultifidelityGP
)
from rescue.models import fit_causal_model
from rescue.utils.load_GraphandData import make_dataframe
from rescue.metrics.optimization_pref import MultiObjectiveOptimizationPref
from rescue.algorithms.state import RescueState, RescueStateMultifidelity
from rescue.algorithms.initial_sampling import GenerateInitialSample
from rescue.utils.utils import status

import warnings


class CausalGPFn(Protocol):
    """Callable contract for a user-supplied causal GP factory.

    An object satisfying this protocol can be passed as ``custom_model`` to
    ``BaseRescueAlgorithm``. Every argument is keyword-only.

    NOTE(review): ``BaseRescueAlgorithm.update_causal_GP`` currently calls the
    factory with only ``train_x``, ``train_objectives``, ``causal_net``,
    ``train_constraints`` and ``state_dict``, not ``device``, ``dtype`` or
    ``extra``. It also returns the factory's result from a method annotated
    as ``(ExactMarginalLogLikelihood, model)``. Confirm which signature and
    return order are intended.
    """

    def __call__(
        self,
        *,
        train_x: torch.Tensor,
        train_objectives: torch.Tensor,
        train_constraints: torch.Tensor,
        causal_net: CausalMeanVarSurrogateNN,
        state_dict: dict,
        device: torch.device,
        dtype: torch.dtype,
        extra: object,
    ) -> tuple[ExactGP, Likelihood]:
        """Build a causal GP and its likelihood from the training data."""

class GenInitDataFn(Protocol):
    """Callable contract for a user-supplied initial-data generator.

    ``BaseRescueAlgorithm.generate_initial_data`` invokes the generator
    positionally as ``fn(n)``. ``n`` is therefore declared
    positional-or-keyword: a keyword-only ``n`` would let implementations that
    satisfy the protocol fail with ``TypeError`` at that call site.
    """

    def __call__(
        self,
        n: float | int,
    ) -> tuple[Tensor, Tensor, Tensor | None]:
        """Return ``(train_x, train_obj, train_constraints)`` for budget ``n``.

        ``train_constraints`` may be ``None`` when the problem has no
        constraints.
        """


class BaseRescueAlgorithm(ABC):
    r""" 
    Base class for rescue algorithms.
    """
    def __init__(
        self,
        problem: MultiObjectiveTestProblem | ConstrainedBaseTestProblem,
        design_variables: list[str],
        objective_variables: list[str],
        bounds: Tensor,
        gen_initial_data: GenInitDataFn | None = None,
        custom_model: None | CausalGPFn = None,
        verbose: bool = False,
        status_spinner: bool = True,
        cost_fn: None | Callable[[Tensor], Tensor] = None,
        kpi_variables: None | list[str] = None,
        constraint_variables: None | list[str] = None,
        device: None | torch.device = None,
        dtype: None | torch.dtype = None,
        wandb_run: None | Run = None,
        rescue_state: None | RescueState = None,
    ) -> None:
        r"""
        Initialize the base rescue algorithm.

        If ``rescue_state`` is given, the variable names and bounds come from
        it, and the matching constructor arguments are ignored. All derived
        quantities (``outcome_variables``, ``input_dim``, ``num_objectives``,
        ...) are computed from the resolved ``self.*`` attributes. They
        therefore always agree with whichever source was used.

        Args:
            problem (MultiObjectiveTestProblem | ConstrainedBaseTestProblem): The
                optimization problem to solve.
            design_variables (list[str]): Design variable names (ignored when
                ``rescue_state`` is given).
            objective_variables (list[str]): Objective variable names (ignored
                when ``rescue_state`` is given).
            bounds (Tensor): Bounds of the design space, one row of lower and
                one row of upper bounds (ignored when ``rescue_state`` is
                given).
            gen_initial_data (GenInitDataFn | None): Optional function that
                generates the initial training data.
            custom_model (CausalGPFn | None): Optional factory used instead of
                the default causal GP.
            verbose (bool): Whether to print verbose output.
            status_spinner (bool): Whether to show a status spinner.
            cost_fn (Callable[[Tensor], Tensor] | None): Optional cost function
                for evaluated designs.
            kpi_variables (list[str] | None): Optional KPI variable names.
            constraint_variables (list[str] | None): Optional constraint
                variable names.
            device (torch.device | None): Tensor device; defaults to CUDA when
                available, otherwise CPU.
            dtype (torch.dtype | None): Tensor dtype; defaults to
                ``torch.double``.
            wandb_run (Run | None): Optional Weights & Biases run.
            rescue_state (RescueState | None): Optional saved state to resume
                from.
        """
        if device is None:
            device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        dtype = torch.double if dtype is None else dtype
        # dtype/device keyword arguments shared by every tensor this class
        # creates or moves.
        self.tkwargs = {
            "dtype": dtype,
            "device": device,
        }
        self.problem = problem
        self.cost_fn = cost_fn
        self.rescue_state = rescue_state
        if self.rescue_state is not None:
            # Resuming: the saved state is the source of truth for the
            # variable names and bounds.
            self.state = rescue_state
            if self.state['is_multifidelity']:
                warnings.warn(
                    "`rescue_state` is provided, "
                    "the algorithm will use `algorithms.state.RescueStateMultifidelity` "
                    "to initialize itself.",
                    UserWarning,
                )
            else:
                warnings.warn(
                    "`rescue_state` is provided, "
                    "the algorithm will use `algorithms.state.RescueState` to initialize itself.",
                    UserWarning,
                )

            self.design_variables = self.state['design_variables']
            self.objective_variables = self.state['objective_variables']
            self.constraint_variables = self.state['constraint_variables'] or []
            self.kpi_variables = self.state['kpi_variables'] or []
            self.bounds = self.state['bounds'].to(**self.tkwargs)
        else:
            self.design_variables = design_variables
            self.objective_variables = objective_variables
            self.constraint_variables = constraint_variables or []
            self.kpi_variables = kpi_variables or []
            self.bounds = bounds.to(**self.tkwargs)

            self.state = {}
            self.state['is_multifidelity'] = False
            self.state['design_variables'] = design_variables
            self.state['objective_variables'] = objective_variables
            self.state['kpi_variables'] = kpi_variables
            self.state['constraint_variables'] = constraint_variables
            self.state['bounds'] = self.bounds

        # Derive everything from the resolved attributes, not the raw
        # constructor arguments: when resuming from `rescue_state`, the
        # arguments may not match the restored state.
        self.outcome_variables = (
            self.objective_variables
            + self.kpi_variables
            + self.constraint_variables
        )
        self.outcome_dim = len(self.outcome_variables)

        # Problem dimensions
        self.input_dim = len(self.design_variables)
        self.num_objectives = len(self.objective_variables)
        self.num_constraints = len(self.constraint_variables)
        # Unit-hypercube bounds: row 0 all zeros, row 1 all ones.
        self.standard_bounds = torch.zeros(
            2, self.bounds.shape[1], **self.tkwargs)
        self.standard_bounds[1] = 1

        self.verbose = verbose

        self.has_constraints = self.num_constraints > 0
        self._validate_inputs()
        self.wandb_run = wandb_run

        self.custom_model = custom_model
        self.gen_initial_data = gen_initial_data

        # Components of the unique results filename used by
        # `save_stats_to_json`.
        self.problem_name = self.problem.__class__.__name__
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.unique_id = str(uuid.uuid4())[:8]

        # Debug-only flag.
        self.status_spinner = status_spinner


    def generate_initial_data(
        self, 
        n: int | float | None = None,
    ) -> tuple[Tensor, Tensor, Tensor | None]:
        r"""
        Generate initial training data.
        
        Args:
            n: Number of initial samples (or budget for multi-fidelity)
            
        Returns:
            tuple[Tensor, Tensor, Tensor | None]: 
                - train_x: Input samples with dim `n x d` where last 
                    column of `d` corresponds to fidelity levels.
                - train_obj: Objective values with dim `n x m`.
                - train_constraints: Constraint values with dim `n x c`
                    (None if no constraints).
        """
        if self.gen_initial_data is None:
            if n is None:
                raise ValueError("`initial_budget` must be specified.")
            sampling = GenerateInitialSample(
                problem=self.problem,
                bounds=self.bounds,
                has_constraints=self.has_constraints
            )
            train_x, train_obj, train_cons = sampling.generate_initial_data(n=n)
        else:
            train_x, train_obj, train_cons = self.gen_initial_data(n)
        self._validate_traindata(train_x, train_obj, train_cons)
        return train_x, train_obj, train_cons


    def generate_x_intervention_val(self, causal_intervention_samples: int):
        r"""
        Generate initial training data.
        
        Args:
            n: Number of initial samples (or budget for multi-fidelity)
            
        Returns:
            Tensor: An `n x d` tensor representing the design variables.
        """
        inter_x = draw_sobol_samples(
            bounds=self.bounds, 
            n=causal_intervention_samples, 
            q=1).squeeze(1)
        return inter_x
    
    @status(show_func_name=True)    
    def update_causal_model(
        self,
        observational_data: pd.DataFrame,
        x_intervention_val: Tensor,
        causal_discovery: str,
        num_interventions: int,
        epochs: int,
        save_graph: bool,
        is_multifidelity: bool,
        is_design_var_independent: bool,
        causal_inf_backend: str = "loky",
        causal_inf_batch_size: int = 10,
        causal_inf_n_jobs: int = 4,  
        alpha: None | float = None,      
        fidelity_param_name: None | str = None,
        train_x: Tensor | None = None,
        train_objectives: Tensor | None = None,
        train_constraints: None | Tensor = None,
    ) -> tuple[nx.DiGraph, torch.nn.Module, float, StructuralCausalModel]:
        r"""
        Update the causal model with new training data.

        Args:
            x_intervention_val (Tensor): Intervention values.
            causal_discovery (str): Causal discovery method.
            alpha (float): Significance level.
            num_interventions (int): Number of interventions.
            epochs (int): Number of training epochs.
            save_graph (bool): Whether to save the causal graph.
            observational_data (None | pd.DataFrame): Observational data.
            save_graph (bool): Whether to save the causal graph.
            causal_inf_backend (str): Backend for causal inference.
            causal_inf_batch_size (int): Batch size for causal inference.
            causal_inf_n_jobs (int): Number of jobs for causal inference.
            fidelity_param_name (None | str): Fidelity parameter name.
            is_multifidelity (bool): Whether the model is multi-fidelity.
            is_design_var_independent (bool): Whether design variables are independent.
            train_x (Tensor | None): Online Training input features.
            train_objectives (Tensor | None): Online training objectives.
            train_constraints (Tensor | None): Online training constraints.

        Returns:
            tuple[nx.DiGraph, torch.nn.Module, float, StructuralCausalModel]:
                - Causal graph (nx.DiGraph)
                - Causal net (torch.nn.Module)
                - Causal net loss (float)
                - Structural causal model (StructuralCausalModel)
        """     
        df_train_runtime = None
        train_y = None
        if train_x is not None and train_objectives is not None:
            train_y = train_objectives.clone()
            # Combine objectives and constraints
            if self.has_constraints and train_constraints is not None:
                train_y = torch.cat([train_y, train_constraints], dim=-1)
            # Create dataframe
            df_train_runtime = make_dataframe(
                design_variables=self.design_variables,
                outcome_variables=self.outcome_variables,
                train_x=normalize(train_x, self.bounds).detach(),
                train_y=train_y.detach()
            )
        # convert observational data to tensors 
        train_x_obs = torch.tensor(
            observational_data[self.design_variables].values).to(**self.tkwargs)
        # normalize train_x to bounds
        train_x_obs = normalize(train_x_obs, self.bounds).detach().cpu().numpy()
        # add the train_x to original observational_data
        observational_data[self.design_variables] = train_x_obs

        if df_train_runtime is not None:
            # concatenate the two dataframes
            observational_data = pd.concat(
                [observational_data, df_train_runtime], 
                ignore_index=True
            )

        num_datapoints = len(observational_data) 
        # Create causal performance model
        CPM = CausalPerformanceModel(
            data=observational_data,
            design_variables=self.design_variables,
            objective_variables=self.objective_variables,
            is_design_var_independent=is_design_var_independent,
            kpi_and_constraints_variables=(
                self.kpi_variables + self.constraint_variables
            ),
            is_multifidelity=is_multifidelity,
            fidelity_param_name=fidelity_param_name,
            use_default_bk=True,
        )
        if alpha is None:
            alpha = self.CausaDis_alpha_selection_policy(
                        num_datapoints=num_datapoints,
                    )        
        causal_graph = CPM.causal_model(
            alpha=alpha,
            causal_discovery=causal_discovery,
            show_progress=self.verbose,
            save_graph=save_graph,
        )
        model = CausalMeanVarSurrogateNN(
                    input_dim=self.input_dim, 
                    output_dim = self.outcome_dim
                ).to(**self.tkwargs)
                
        causal_net, scm, causal_net_loss = fit_causal_model(
            model=model,
            observational_data=observational_data,
            train_x=train_x,
            train_y=train_y,
            causal_graph=causal_graph, 
            design_variables=self.design_variables, 
            outcome_variables=self.outcome_variables, 
            num_intervention=num_interventions,
            epochs=epochs,
            x_intervention_val=normalize(x_intervention_val, self.bounds),
            verbose=self.verbose,
            backend=causal_inf_backend,
            batch_processing=causal_inf_batch_size,
            n_jobs=causal_inf_n_jobs,                    
            **self.tkwargs
        ) 
        return causal_graph, causal_net, causal_net_loss, scm

    def update_causal_GP(
        self,
        train_x: Tensor,
        train_objectives: Tensor,
        causal_net: torch.nn.Module,
        is_multifidelity: bool,
        train_constraints: None | Tensor = None,
        data_covar_module: None | Kernel = None,
        fidelity_covar_module: None | Kernel = None,
        state_dict: None | dict = None
    ) -> tuple[ExactMarginalLogLikelihood, 
               CausalMultitaskGP | CausalMultitaskMultifidelityGP]:
        r""" 
        Args:
            train_x (Tensor): Training input features.
                NOTE: inputs are normalized to the bounds.
            train_objectives (Tensor): Training output features.
            causal_net (CausalMeanVarSurrogateNN): Trained Causal neural network.
            is_multifidelity (bool): Whether the model is multi-fidelity.
            train_constraints (None | Tensor): Training constraints (if any).
            data_covar_module (None | Kernel): Base covariance module for the CGP.
            state_dict (None | dict): State dictionary to load into the model.
            is_multifidelity (bool): Whether the model is multi-fidelity.

        Returns:
            tuple (ExactMarginalLogLikelihood, Model): Marginal log 
                likelihood and the causal GP model.
        """
        if self.custom_model is not None:
            return self.custom_model(
                train_x=train_x,
                train_objectives=train_objectives,
                causal_net=causal_net,
                train_constraints=train_constraints,
                state_dict=state_dict,
            )
        else:
            train_y = train_objectives.clone()
            if self.has_constraints:
                train_y = torch.cat([train_y, train_constraints], dim=-1)
            # Initialize likelihood
            likelihood = MultitaskGaussianLikelihood(
                num_tasks=train_y.shape[-1]
            )
            if is_multifidelity:
                if data_covar_module is None:
                    data_covar_module = RBFKernel(
                        # Exclude fidelity dimension
                        ard_num_dims=train_x.shape[-1] - 1,
                    )
                if fidelity_covar_module is None:
                    fidelity_covar_module = RBFKernel(
                        ard_num_dims=1,
                    )
                model = CausalMultitaskMultifidelityGP(
                    train_X=train_x,
                    train_Y=train_y,
                    causal_net=causal_net,
                    likelihood=likelihood,
                    data_covar_module=data_covar_module,
                    fidelity_covar_module=fidelity_covar_module,
                )

            else:    
                if data_covar_module is None:
                    data_covar_module = RBFKernel(
                        ard_num_dims=train_x.shape[-1],
                    )
                model = CausalMultitaskGP(
                    train_X=train_x,
                    train_Y=train_y,
                    causal_net=causal_net,
                    likelihood=likelihood,
                    data_covar_module=data_covar_module,
                    outcome_transform=Standardize(m=train_y.shape[-1])
                )
            # Set up marginal log likelihood
            mll = ExactMarginalLogLikelihood(likelihood, model)
            if state_dict is not None:
                model.load_state_dict(state_dict)          
            # Freeze causal net parameters
            # Future me, this is important
            # otherwise, the causal net will be trained along with the GP
            # which we don't want.
            for param in model.mean_module.causal_net.parameters():
                param.requires_grad = False
            for param in model.covar_module.causal_net.parameters():
                param.requires_grad = False
            return mll, model

    def load_causal_model(self):
        causal_net = CausalMeanVarSurrogateNN(
                    input_dim=self.input_dim, 
                    output_dim = self.outcome_dim
                ).to(**self.tkwargs)     
        causal_net.load_state_dict(self.state['causal_net'])   
        causal_graph = self.state['causal_graph']
        causal_net_loss = self.state['causal_net_loss']
        scm = self.state['scm']
        return causal_graph, causal_net, causal_net_loss, scm
    
    
    # TODO: move this policy into a more appropriate module
    # under `algorithms`.
    def CausaDis_alpha_selection_policy(
        self,
        num_datapoints: int,
        decay_factor: float = 0.6,
        max_alpha: float = 0.85,
        min_alpha: float = 0.05,
        data_scale: float = 1000.0
    ) -> float:
        r"""
        Adaptive alpha selection policy for causal discovery in PC algorithm.
        
        Starts with a higher alpha (less strict) and gradually decreases it 
        (more strict) as more data is collected.
        
        Mathematical formulation:
        α = max_α - (max_α - min_α) × (1 - exp(-λ × log(n)/log(data_scale)))
        
        Where:
        - λ = decay_factor (controls aggressiveness: higher λ = faster transition)
        - n = num_datapoints
        - data_scale = reference point for full progress (default: 100)
        
        Args:
            num_datapoints: Number of data points collected so far
            decay_factor: Controls transition aggressiveness (default: 0.5)
                         Higher values = more aggressive alpha reduction
                         Lower values = more gradual alpha reduction
            max_alpha: Maximum alpha value (default: 0.85)
            min_alpha: Minimum alpha value (default: 0.05)
            data_scale: Data points at which progress reaches 1.0 (default: 1000.0)
                       Higher values = keep alpha high longer
                       Lower values = reduce alpha faster
            
        Returns:
            float: Alpha value for PC algorithm (significance level)
        """
        # Calculate data-based progress using logarithmic scaling
        # With few data points, progress is low (alpha stays high)
        # With many data points, progress approaches 1 (alpha decreases)
        import math
        data_progress = min(
            math.log(max(num_datapoints, 1)) / math.log(data_scale), 
            1.0
        )
        
        # Apply exponential decay for smoother transition
        # Lower decay_factor = more gradual alpha reduction (stays high longer)
        # Higher decay_factor = more aggressive alpha reduction
        progress_scaled = 1.0 - math.exp(-decay_factor * data_progress)
        
        # Linear interpolation between max_alpha and min_alpha
        alpha = max_alpha - (max_alpha - min_alpha) * progress_scaled
        
        # Ensure alpha stays within bounds
        alpha = max(min(alpha, max_alpha), min_alpha)
        # print(f"Alpha: {round(alpha, 3)} (datapoints: {num_datapoints})")
        return round(alpha, 3)


    @status(show_func_name=True)
    def obtain_new_y(self, new_x: Tensor) -> tuple[Tensor, Tensor | None]:
        r"""
        Obtain new objective and constraint values for the given input.

        Args:
            new_x (Tensor): Input tensor of shape `n x d` where `d` is 
                the dimensionality of the design space (including fidelity).

        Returns:
            tuple[Tensor, Tensor | None]: A tuple containing:
                - new_obj: Objective values for the new input (shape `n x m`).
                - new_constraints: Constraint values for the new input (shape `n x c`)
                  or None if no constraints are present.
        """

        new_obj = self.problem(new_x)
        # Handle constraints if present
        new_constraints = None
        if self.has_constraints:
            new_constraints = -self.problem.evaluate_slack(new_x)
        return new_obj, new_constraints


    # TODO: redundant. A similar function exists in `metrics.evaluate`;
    # consolidate the two.
    @status(show_func_name=True)
    def get_metrics(
        self,
        compute_metrics: bool,
        model: Model,
        input_dim_without_fid: int,
        train_objectives: Tensor,
        train_constraints: Tensor,
        optimization_pref: MultiObjectiveOptimizationPref,
        objective_indices: list[int],
        constraints_indices: list[int],
        is_mf_model: bool,
        seed: int | None,
        project: None | Callable[[Tensor], Tensor] = None,
    ) -> tuple[float, float | None, float, float | None]:
        r"""
        Compute metrics for the current iteration.

        Args:
            compute_metrics (bool): Whether to compute metrics.
            model (Model): The model to use for predictions.
            input_dim_without_fid (int): The input dimension without fidelity.
            train_objectives (Tensor): The training objectives.
            train_constraints (Tensor): The training constraints.
            optimization_pref (MultiObjectiveOptimizationPref): The optimization preferences.
            objective_indices (list[int]): The indices of the objectives.
            constraints_indices (list[int]): The indices of the constraints.
            is_mf_model (bool): Whether the model is a multi-fidelity model.
            project (None | Callable[[Tensor], Tensor], optional): A function to 
                project the input to the target fidelity.

        Returns:
            tuple[float, float | None, float, float | None]: A tuple containing the 
                computed metrics.

                - nsga2_hv: Hypervolume computed on the posterior using NSGA-II.
                - nsga2_regret: Regret computed on the posterior using NSGA-II.
                - observed_hv: Observed hypervolume on the training set.
                - observed_regret: Observed regret on the training set.
        """
        
        if compute_metrics:
            # Compute HV and regret using NSGA-II on the posterior
            nsga2_hv, nsga2_regret, nsga2_violation = (
                optimization_pref.compute_nsga2_posterior_hv_regret(
                    input_dim_without_fid=input_dim_without_fid,
                    num_objectives=self.num_objectives,
                    objective_indices=objective_indices,
                    constraints_indices=constraints_indices if self.has_constraints else None,
                    bounds=self.bounds,
                    model=model,
                    problem=self.problem,
                    is_mf_model=is_mf_model,
                    project_to_target_fidelity=project if is_mf_model else None,
                    seed=seed if seed is not None else None,
                )
            )
            # Compute observed HV and regret
            observed_hv, observed_regret, observed_violation = (
                optimization_pref.computed_observed_hv_regret(
                    train_objectives=train_objectives,
                    train_constraints=train_constraints if self.has_constraints else None,
                )
            )
            return (
                nsga2_hv, nsga2_regret, nsga2_violation,
                observed_hv, observed_regret, observed_violation
            )
        return (None, None, None, None, None, None)
    
    # TODO: move this helper to `algorithms.utils`.
    def log_artifacts_to_wandb(self) -> None:
        """ Log artifacts to Weights & Biases. """
        if self.wandb_run is not None:
            iteration = self.state['iteration']
            save_dir = self.wandb_run.dir
            exp_name = self.wandb_run.name
            filename = f"{exp_name}_ckpt_{iteration}.pt"
            path = os.path.join(save_dir, filename)
            torch.save(self.state, path)
            self.wandb_run.log_artifact(
                artifact_or_path=path,
                name=filename,
                type="state"
            )

    # TODO: move this helper to `algorithms.utils`.
    def log_exp_stats_to_wandb(
        self,
        iter: int,
        exp_stats: dict,
        compute_metrics: bool
    ) -> None:
        r""" 
        Log experiment statistics to Weights & Biases. 

        Args:
            iter (int): The current iteration.
            exp_stats (dict): The experiment statistics to log.
            compute_metrics (bool): Whether to compute metrics.
        """
        if self.wandb_run is not None:
            if compute_metrics:
                self.wandb_run.log(exp_stats, step=iter)    


    # TODO: redundant. A similar function exists in `metrics.evaluate`;
    # consolidate the two.
    def exp_stats(
        self,
        budget: int | float,
        compute_metrics: bool,
        has_maxhv: bool,
        iteration: int,
        current_cost: int | float,
        initial_cost: int | float,
        acq_value: Tensor,
        causal_net_loss: float,
        seed: int,
        best_nsga2_regret: None | float = None,
        curr_nsga2_regret: None | float = None,
        observed_regret: None | float = None,
        curr_nsga2_violation: None | float = None,
        best_nsga2_hv: None | float = None,
        curr_nsga2_hv: None | float = None,
        observed_hv: None | float = None,
        observed_violation: None | float = None,
        new_fidelity: None | float = None
    ) -> dict[str, float]:
        r""" 
        Get the experiment statistics dict.

        Args:
            budget (int | float): The budget for the experiment.
            compute_metrics (bool): Whether to compute metrics.
            has_maxhv (bool): Whether the experiment has a maximum hypervolume.
            iteration (int): The current iteration.
            current_cost (int | float): The current cost.
            initial_cost (int | float): The initial cost.
            acq_value (Tensor): The acquisition value.
            causal_net_loss (float): The causal network loss.
            seed (int): The random seed.
            best_nsga2_regret (None | float): The best NSGA-II regret.
            curr_nsga2_regret (None | float): The current NSGA-II regret.
            observed_regret (None | float): The observed regret.
            best_nsga2_hv (None | float): The best NSGA-II hypervolume.
            curr_nsga2_hv (None | float): The current NSGA-II hypervolume.
            observed_hv (None | float): The observed hypervolume.
            new_fidelity (None | float): The new fidelity.

        Returns:
            dict[str, float]: The experiment statistics.
        """
        
        stats = {
                "budget": budget,
                "cost": current_cost,
                "causal_net_loss": causal_net_loss,
                "acquisition_value": acq_value.item(),
                "new_fidelity": (
                        new_fidelity if new_fidelity is not None else 1.0
                    ), # for single fidelity, this will be 1.0
                "initial_cost": initial_cost,
                "iteration": iteration,
                "seed": seed,
            }      
            
        if compute_metrics:
            if has_maxhv:
                if best_nsga2_regret is None \
                    or curr_nsga2_regret is None \
                    or observed_regret is None:
                        raise ValueError(
                            "All regret metrics must be provided if `compute_metrics` is True."
                        )
            if best_nsga2_hv is None \
                or curr_nsga2_hv is None \
                or observed_hv is None:
                raise ValueError(
                    "All HV metrics must be provided if `compute_metrics` is True."
                )
    
            if has_maxhv:
                stats["best_nsga2_regret"] = best_nsga2_regret
                stats["curr_nsga2_regret"] = curr_nsga2_regret
                stats["best_nsga2_hv"] = best_nsga2_hv                     
                stats["curr_nsga2_hv"] = curr_nsga2_hv                     
                stats["observed_hv"] = observed_hv
                stats["observed_regret"] = observed_regret                   
            else:
                stats["best_nsga2_hv"] = best_nsga2_hv                     
                stats["curr_nsga2_hv"] = curr_nsga2_hv                     
                stats["observed_hv"] = observed_hv   
            if self.has_constraints:
                stats["curr_nsga2_violation"] = curr_nsga2_violation
                stats["observed_violation"] = observed_violation
        return stats
    
    def save_stats_to_json(
        self,
        stats: dict[str, Any],
    ) -> None:
        r"""Append one iteration's stats to this run's JSON log file.

        The log lives at
        ``<cwd>/rescue_log/{ClassName}_{problem_name}_{timestamp}_{unique_id}.json``
        and contains a JSON list with one entry per call. On every call the
        whole list is read back, extended and written out again.

        ``stats`` is modified in place: the ``"method"`` (class name) and
        ``"problem"`` keys are added before it is written.

        Args:
            stats (dict[str, Any]): The statistics dictionary to save. Its
                values must already be JSON-serializable.

        Raises:
            TypeError: If ``stats`` contains a value ``json`` cannot serialize.
                In that case the previously logged entries are left intact.
        """
        method = self.__class__.__name__
        file_path = (
            Path.cwd()
            / "rescue_log"
            / f"{method}_{self.problem_name}_{self.timestamp}_{self.unique_id}.json"
        )
        file_path.parent.mkdir(parents=True, exist_ok=True)

        # Load the entries logged so far for this run (if any).
        if file_path.exists():
            with open(file_path, "r", encoding="utf-8") as f:
                results = json.load(f)
        else:
            results = []

        # Tag the entry so logs from different methods/problems can be merged.
        stats["method"] = method
        stats["problem"] = self.problem_name
        results.append(stats)

        # Write to a sibling temp file and atomically swap it in. Opening the
        # log itself with "w" would truncate it first, so a failed or
        # interrupted dump would wipe all earlier iterations and make every
        # later ``json.load`` above fail.
        tmp_path = file_path.with_name(file_path.name + ".tmp")
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(results, f, indent=2)
            os.replace(tmp_path, file_path)
        finally:
            # No-op after a successful replace; removes a partial temp file
            # if the dump failed.
            tmp_path.unlink(missing_ok=True)
    
    def update_state(
        self,
        is_multifidelity: bool,
        get_seed: int,
        budget: int | float,
        ref_point: Tensor,
        max_hv: float | None,
        initial_cost: int | float,
        current_cost: int | float,
        iteration: int,
        met_best_nsga2_regret: float | None,
        met_best_nsga2_hv: float | None,
        acquisition_value: float | None,
        new_fidelity: float | None,
        train_x: Tensor,
        train_obj: Tensor,
        train_constraints: Tensor | None,
        objective_indices: list[int],
        constraints_indices: list[int],
        x_intervention: Tensor,
        causal_graph: nx.DiGraph,
        scm: StructuralCausalModel,
        causal_net: torch.nn.Module,
        model: CausalMultitaskGP | CausalMultitaskMultifidelityGP,
        causal_net_loss: float,
        fixed_features_list: list[dict[int, float]] | None = None,
        target_train_x: Tensor | None = None,
        target_fid_obj: Tensor | None = None,
        target_fid_constraints: Tensor | None = None,
    ) -> None:
        r"""Record the current optimization state in ``self.state``.

        Tensor arguments are detached and copied to CPU. ``causal_net`` and
        ``model`` are stored as CPU copies of their ``state_dict``, and
        ``causal_graph`` / ``scm`` are stored by reference. The multi-fidelity
        entries (``fixed_features_list``, ``target_train_x``,
        ``target_fid_obj``, ``target_fid_constraints``) are only written when
        ``is_multifidelity`` is True.

        After the call, ``model`` and ``causal_net`` are on ``self.tkwargs``
        (device/dtype).

        Args:
            is_multifidelity: Whether to also store the multi-fidelity entries.
            get_seed, budget, ref_point, max_hv, initial_cost, current_cost,
            iteration, met_best_nsga2_regret, met_best_nsga2_hv: Loop
                bookkeeping values, stored as given (``ref_point`` on CPU).
            acquisition_value, new_fidelity: Stored for metrics only; they are
                not returned by ``load_state``.
            train_x, train_obj, train_constraints, x_intervention: Training
                data tensors (``train_constraints`` may be ``None``).
            objective_indices, constraints_indices: Stored as given.
            causal_graph, scm: Stored by reference.
            causal_net, model: Modules whose weights are snapshotted.
            causal_net_loss: Stored as given.
            fixed_features_list, target_train_x, target_fid_obj,
            target_fid_constraints: Multi-fidelity extras (may be ``None``).
        """

        def _to_cpu(tensor: Tensor | None) -> Tensor | None:
            # Optional inputs may be ``None``; pass them through unchanged.
            return None if tensor is None else tensor.detach().cpu()

        def _cpu_state_dict(module: torch.nn.Module) -> dict[str, Tensor]:
            # Take an independent CPU copy of every tensor entry. The previous
            # ``module.cpu().state_dict()`` was a no-op move for modules
            # already on CPU, so the "snapshot" aliased the live parameters
            # and later in-place training silently rewrote it. It also moved
            # the live module across devices on every call.
            state_dict = module.state_dict()
            for key, value in list(state_dict.items()):
                if torch.is_tensor(value):
                    state_dict[key] = value.detach().to("cpu", copy=True)
            return state_dict

        state = self.state
        state['get_seed'] = get_seed
        state['budget'] = budget
        state['initial_cost'] = initial_cost
        state['current_cost'] = current_cost
        state['ref_point'] = _to_cpu(ref_point)
        state['max_hv'] = max_hv
        state['iteration'] = iteration
        state['met_best_nsga2_regret'] = met_best_nsga2_regret
        state['met_best_nsga2_hv'] = met_best_nsga2_hv
        state['train_x'] = _to_cpu(train_x)
        state['train_obj'] = _to_cpu(train_obj)
        state['train_constraints'] = _to_cpu(train_constraints)
        state['objective_indices'] = objective_indices
        state['constraints_indices'] = constraints_indices
        state['x_intervention'] = _to_cpu(x_intervention)
        state['causal_graph'] = causal_graph
        state['scm'] = scm
        state['causal_net'] = _cpu_state_dict(causal_net)
        state['model'] = _cpu_state_dict(model)
        state['causal_net_loss'] = causal_net_loss

        # Only used for metrics computation (kept for later use);
        # these are not used when reloading the algorithm state.
        state['acquisition_value'] = acquisition_value
        state['new_fidelity'] = new_fidelity

        if is_multifidelity:
            state['fixed_features_list'] = fixed_features_list
            state['target_train_x'] = _to_cpu(target_train_x)
            state['target_fid_obj'] = _to_cpu(target_fid_obj)
            state['target_fid_constraints'] = _to_cpu(target_fid_constraints)

        # The snapshots above never move the live modules; this keeps the
        # post-condition that both end up on ``self.tkwargs``.
        model.to(**self.tkwargs)
        causal_net.to(**self.tkwargs)
            
    def load_state(self) -> tuple[Any, ...]:
        r"""Return the loop variables saved by ``update_state``.

        Tensors are moved back to ``self.tkwargs`` (device/dtype).

        Returns:
            tuple[Any, ...]: ``(iteration, get_seed, initial_cost,
            current_cost, ref_point, max_hv, met_best_nsga2_regret,
            met_best_nsga2_hv, x_intervention, train_x, train_obj,
            train_constraints, objective_indices, constraints_indices)``.
            ``train_constraints`` is ``None`` when the problem has no
            constraints.

        Raises:
            ValueError: If the problem has constraints but the saved state
                holds no constraint values.
        """
        state = self.state
        if self.has_constraints:
            saved_constraints = state['train_constraints']
            if saved_constraints is None:
                # Previously this surfaced as an opaque AttributeError.
                raise ValueError(
                    "The problem has constraints but the saved state has no "
                    "`train_constraints`."
                )
            train_constraints = saved_constraints.to(**self.tkwargs)
        else:
            train_constraints = None
        return (
            state['iteration'],
            state['get_seed'],
            state['initial_cost'],
            state['current_cost'],
            state['ref_point'].to(**self.tkwargs),
            state['max_hv'],
            state['met_best_nsga2_regret'],
            state['met_best_nsga2_hv'],
            state['x_intervention'].to(**self.tkwargs),
            state['train_x'].to(**self.tkwargs),
            state['train_obj'].to(**self.tkwargs),
            train_constraints,
            state['objective_indices'],
            state['constraints_indices'],
        )
    
    def _validate_inputs(self) -> None:
        r"""Check the problem definition against the configured inputs.

        Raises:
            ValueError: If constraints are configured but ``self.problem`` has
                no ``evaluate_slack`` method, or if the last dimension of
                ``self.bounds`` differs from ``self.input_dim``.
        """
        slack_missing = self.has_constraints and not hasattr(
            self.problem, 'evaluate_slack'
        )
        if slack_missing:
            raise ValueError("Constraints are defined but the problem does not "
                             "have an `evaluate_slack` method.")

        # Bounds are indexed as (lower/upper, design dim).
        bound_dim = self.bounds.shape[-1]
        if bound_dim != self.input_dim:
            raise ValueError(
                f"Bound dimension {bound_dim} does not match "
                f"input dimension {self.input_dim}."
            )
        
    def _validate_traindata(
        self,
        train_x: Tensor,
        train_objectives: Tensor,
        train_constraints: Tensor
    ) -> None:
        r"""Check the trailing dimensions of the training data.

        Args:
            train_x (Tensor): Inputs; last dim must equal ``self.input_dim``.
            train_objectives (Tensor): Objectives; last dim must equal
                ``self.num_objectives``.
            train_constraints (Tensor): Constraint values, or ``None`` when
                the problem has no constraints. When the problem has
                constraints, the last dim must equal ``self.num_constraints``.

        Raises:
            ValueError: On any dimension mismatch, or when the problem has
                constraints but ``train_constraints`` is ``None``.
        """
        x_dim = train_x.shape[-1]
        if x_dim != self.input_dim:
            raise ValueError(
                f"train_x dim {x_dim} does not match with design "
                f"variables dim {self.input_dim}."
            )

        obj_dim = train_objectives.shape[-1]
        if obj_dim != self.num_objectives:
            raise ValueError(
                f"train_objectives dim {obj_dim} does not match with objectives "
                f"dim {self.num_objectives}."
            )

        # Nothing more to check for unconstrained problems.
        if not self.has_constraints:
            return

        if train_constraints is None:
            raise ValueError(
                f"Found constraints {self.constraint_variables} but training "
                "data does not have constraints column."
            )

        con_dim = train_constraints.shape[-1]
        if con_dim != self.num_constraints:
            raise ValueError(
                f"train_constraints dim {con_dim} does not "
                f"match with constraints dim {self.num_constraints}."
            )

    def validate_new_constraints(self, new_constraints: Tensor | None) -> None:
        r"""Check that constraint values are returned iff the problem has constraints.

        Args:
            new_constraints (Tensor | None): Constraint values for the newly
                evaluated candidate(s), or ``None``.

        Raises:
            ValueError: If the problem has constraints and ``new_constraints``
                is ``None``, or the problem has none and ``new_constraints``
                is not ``None``.
        """
        # The message fragments were previously concatenated with no
        # separator ("...`None`make sure ..."); a "; " now joins them.
        if self.has_constraints and new_constraints is None:
            raise ValueError(
                "`new_constraints` is `None`; "
                "make sure `optimize_acquisition_function` returns new_constraints"
            )
        if not self.has_constraints and new_constraints is not None:
            raise ValueError(
                "`new_constraints` is not `None`; "
                "make sure `optimize_acquisition_function` does not return "
                "new_constraints"
            )
        
        
    def validate_run_inputs(
        self,
        compute_metrics: bool,
        objective_indices: list[int] | None,
        constraints_indices: list[int] | None
    ) -> None:
        r"""Check that the index lists required by the run configuration exist.

        Args:
            compute_metrics (bool): Whether metrics will be computed.
            objective_indices (list[int] | None): Objective column indices.
            constraints_indices (list[int] | None): Constraint column indices.

        Raises:
            ValueError: If ``objective_indices`` is missing while metrics are
                requested or the problem has constraints, or if
                ``constraints_indices`` is missing while metrics are requested
                for a constrained problem.
        """
        no_objectives = objective_indices is None
        no_constraints = constraints_indices is None

        if compute_metrics and no_objectives:
            raise ValueError(
                "`objective_indices` must be provided if `compute_metrics` is True."
            )
        if compute_metrics and self.has_constraints and no_constraints:
            raise ValueError(
                "`constraints_indices` must be provided if `compute_metrics` is True "
                "and the problem has constraints."
            )
        if self.has_constraints and no_objectives:
            raise ValueError(
                "`objective_indices` must be provided if the problem has constraints."
            )
            
    def validated_budget(
        self,
        budget: float | int,
        current_cost: float | int,
    ) -> None:
        r"""Ensure the total budget exceeds the cost already spent.

        Args:
            budget (float | int): Total optimization budget.
            current_cost (float | int): Cost consumed so far.

        Raises:
            ValueError: If ``budget <= current_cost``.
        """
        exhausted = budget <= current_cost
        if exhausted:
            raise ValueError(
                f"budget={budget} should be greater than {current_cost}"
            )


class BaseRescueAlgorithmMultifidelity(BaseRescueAlgorithm):
    r"""Base class for multi-fidelity rescue algorithms.

    Extends :class:`BaseRescueAlgorithm` with fidelity bookkeeping (discrete
    vs. continuous fidelities, the fidelity design variable and the target
    fidelities). It also adds helpers for multi-fidelity initial data,
    fixed-feature lists and target-fidelity evaluation.
    """

    def __init__(
        self,
        problem: MultiObjectiveTestProblem | ConstrainedBaseTestProblem,
        design_variables: list[str],
        objective_variables: list[str],
        bounds: Tensor,
        is_discrete_fidelities: bool,
        fidelity_param_name: str,
        target_fidelities: dict[int, float],
        cost_fn: Callable[[Tensor], Tensor],
        gen_initial_data: GenInitDataFn | None = None,
        custom_model: None | CausalGPFn = None,
        verbose: bool = False,
        status_spinner: bool = True,
        fidelity_levels: None | Tensor = None,
        kpi_variables: None | list[str] = None,
        constraint_variables: None | list[str] = None,
        device: None | torch.device = None,
        dtype: None | torch.dtype = None,
        wandb_run: None | Run = None,
        rescue_state: None | RescueStateMultifidelity = None,
    ) -> None:
        r"""
        Base class for the multi-fidelity rescue algorithm.

        Args:
            problem (MultiObjectiveTestProblem | ConstrainedBaseTestProblem): The
                optimization problem.
            design_variables (list[str]): The design variables.
            objective_variables (list[str]): The objective variables.
            bounds (Tensor): The bounds for the design variables.
            is_discrete_fidelities (bool): Whether the fidelities are discrete.
            fidelity_param_name (str): The name of the fidelity parameter; must
                be one of ``design_variables``.
            target_fidelities (dict[int, float]): Maps an input column index to
                its target fidelity, expressed in the same units as ``bounds``.
                It is normalized with ``bounds`` on construction.
            cost_fn (Callable[[Tensor], Tensor]): The cost function.
            gen_initial_data (Callable[[float | int], tuple[Tensor, Tensor, Tensor | None]]):
                Function to generate initial data.
            custom_model (None | CausalGPFn, optional): Custom model function.
            verbose (bool, optional): Whether to print verbose output. Defaults to False.
            status_spinner (bool, optional): Whether to show a status spinner. Defaults to True.
            fidelity_levels (None | Tensor, optional): The fidelity levels. Required
                when ``is_discrete_fidelities`` is True. Defaults to None.
            kpi_variables (None | list[str], optional): The KPI variables. Defaults to None.
            constraint_variables (None | list[str], optional): The constraint variables.
                Defaults to None.
            device (None | torch.device, optional): The device to use. Defaults to None.
            dtype (None | torch.dtype, optional): The data type to use. Defaults to None.
            wandb_run (None | Run, optional): The Weights & Biases run. Defaults to None.
            rescue_state (None | RescueStateMultifidelity, optional): The rescue state.
                When given, ``is_discrete_fidelities``, ``fidelity_levels`` and
                ``fidelity_param_name`` are read from the saved state and the
                corresponding arguments are ignored. Defaults to None.

        Raises:
            ValueError: If ``rescue_state`` is not a multi-fidelity state, if
                the fidelity configuration is invalid (see
                ``_validate_fidelity_info``), or if a target-fidelity column
                has equal lower and upper bounds.
        """
        super().__init__(
            problem=problem,
            custom_model=custom_model,
            design_variables=design_variables,
            objective_variables=objective_variables,
            bounds=bounds,
            gen_initial_data=gen_initial_data,
            cost_fn=cost_fn,
            constraint_variables=constraint_variables,
            kpi_variables=kpi_variables,
            device=device,
            dtype=dtype,
            verbose=verbose,
            status_spinner=status_spinner,
            wandb_run=wandb_run,
            rescue_state=rescue_state,
        )
        # Future me: Naive check for rescue state
        # can be easily bypassed
        if self.rescue_state is not None:
            if not self.state['is_multifidelity']:
                raise ValueError("Expected `algorithms.state.RescueStateMultifidelity`, "
                                 "but found `algorithms.state.RescueState`!"
                                 )
            self.is_discrete_fidelities = self.state['is_discrete_fidelities']
            self.fidelity_levels = self.state['fidelity_levels']
            self.fidelity_param_name = self.state['fidelity_param_name']
        else:
            self.is_discrete_fidelities = is_discrete_fidelities
            self.fidelity_levels = fidelity_levels
            self.fidelity_param_name = fidelity_param_name

            # Persist the fidelity configuration so a resumed run reuses it.
            self.state['is_multifidelity'] = True
            self.state['is_discrete_fidelities'] = self.is_discrete_fidelities
            self.state['fidelity_levels'] = self.fidelity_levels
            self.state['fidelity_param_name'] = self.fidelity_param_name

        self._validate_fidelity_info()

        # NOTE: target fidelities always come from the argument; they are not
        # part of the saved state.
        self.target_fidelities = target_fidelities
        self.normalized_target_fidelities = self._normalize_target_fidelities(
            self.target_fidelities
        )

    def generate_initial_data_mf(
        self,
        n_full_fidelity_equiv: int | float | None = None,
    ) -> tuple[Tensor, Tensor, Tensor | None]:
        r"""
        Generate initial training data.

        Uses ``self.gen_initial_data`` when it was provided. Otherwise it falls
        back to ``GenerateInitialSample.generate_initial_data_multifidelity``.
        The result is checked with ``_validate_traindata``.

        Args:
            n_full_fidelity_equiv: Number of initial samples
                (or budget for multi-fidelity).

        Returns:
            tuple[Tensor, Tensor, Tensor | None]:
                - train_x: Input samples with dim `n x d` where last
                    column of `d` corresponds to fidelity levels.
                - train_obj: Objective values with dim `n x m`.
                - train_constraints: Constraint values with dim `n x c`
                    (None if no constraints).

        Raises:
            ValueError: If the built-in sampler is used without
                ``n_full_fidelity_equiv``, or if the generated data fails
                validation.
        """
        if self.gen_initial_data is None:
            if n_full_fidelity_equiv is None:
                raise ValueError("`initial_budget` must be specified.")
            sampling = GenerateInitialSample(
                problem=self.problem,
                bounds=self.bounds,
                has_constraints=self.has_constraints
            )
            train_x, train_obj, train_cons = (
                sampling.generate_initial_data_multifidelity(
                    n_full_fidelity_equiv=n_full_fidelity_equiv,
                    cost_fn=self.cost_fn,
                    is_discrete=self.is_discrete_fidelities,
                    fidelity_levels=self.fidelity_levels,
                )
            )
        else:
            train_x, train_obj, train_cons = self.gen_initial_data(n_full_fidelity_equiv)
        self._validate_traindata(train_x, train_obj, train_cons)
        return train_x, train_obj, train_cons

    def _normalize_target_fidelities(
        self,
        target_fidelities: dict[int, float]
    ) -> dict[int, float]:
        r"""Map each target fidelity into [0, 1] using ``self.bounds``.

        Args:
            target_fidelities (dict[int, float]): Column index -> target value.

        Returns:
            dict[int, float]: Column index -> ``(value - lb) / (ub - lb)``.

        Raises:
            ValueError: If a column's lower and upper bounds are equal, where
                the previous code failed with a bare ``ZeroDivisionError``.
        """
        normalized = {}
        for idx, fidelity in target_fidelities.items():
            lb = self.bounds[0, idx].item()
            ub = self.bounds[1, idx].item()
            if ub == lb:
                raise ValueError(
                    f"Cannot normalize the target fidelity of column {idx}: "
                    f"lower and upper bounds are both {lb}."
                )
            normalized[idx] = (fidelity - lb) / (ub - lb)
        return normalized

    def fidelity_to_fixed_features_list(self) -> list[dict[int, float]] | None:
        r"""Build one fixed-feature mapping per discrete fidelity level.

        The fidelity is taken to be the last input column
        (index ``self.input_dim - 1``).

        Returns:
            list[dict[int, float]] | None: ``[{input_dim - 1: level}, ...]``
                with one entry per value in ``self.fidelity_levels``, or
                ``None`` when no fidelity levels are configured (continuous
                fidelities). Previously this case raised ``TypeError``.
        """
        if self.fidelity_levels is None:
            return None
        # NOTE(review): levels are used as given (not normalized with
        # ``self.bounds``), whereas target fidelities are normalized; confirm
        # the acquisition optimizer works in the same space as
        # ``fidelity_levels``.
        return [{self.input_dim - 1: float(v)} for v in self.fidelity_levels]

    def _project_to_target_fidelity(self, X: Tensor) -> Tensor:
        r"""Project input to target fidelity.

        Delegates to BoTorch's ``project_to_target_fidelity`` using the
        normalized target fidelities.
        """
        return project_to_target_fidelity(
            X=X,
            d=self.input_dim,
            target_fidelities=self.normalized_target_fidelities,
        )

    def get_target_fid_observations(
        self,
        train_x: Tensor,
    ) -> tuple[Tensor, Tensor, Tensor | None]:
        r"""Evaluate the problem at ``train_x`` projected to the target fidelity.

        Args:
            train_x (Tensor): Training inputs.

        Returns:
            tuple[Tensor, Tensor, Tensor | None]:
                - Projected inputs at target fidelity.
                - Objectives at target fidelity.
                - Constraints at target fidelity (negated
                  ``problem.evaluate_slack``), or ``None`` without constraints.
        """
        # This will be costly if evaluation is done at runtime.
        # Do it after the optimization is done.
        # NOTE(review): the projection uses *normalized* target fidelities
        # while the projected points are passed straight to ``self.problem``;
        # confirm ``train_x`` and the problem share that (normalized) space.
        target_fid_x = self._project_to_target_fidelity(train_x)
        target_fid_obj = self.problem(target_fid_x)
        target_fid_constraints = None
        if self.has_constraints:
            # Constraint values are stored as the negated slack.
            target_fid_constraints = -self.problem.evaluate_slack(target_fid_x)
        return target_fid_x, target_fid_obj, target_fid_constraints

    def load_state_mf(self) -> tuple[Any, ...]:
        r"""Return the loop variables saved by ``update_state`` (multi-fidelity).

        Returns:
            tuple[Any, ...]: The 14-tuple returned by ``load_state``, followed
                by ``fixed_features_list``, ``target_train_x``,
                ``target_fid_obj`` and ``target_fid_constraints``. Target
                tensors are moved to ``self.tkwargs`` when present. They are
                saved on CPU, and all other restored tensors already go to
                ``self.tkwargs``; without the move they would mismatch on GPU.
        """
        def _to_tkwargs(tensor: Tensor | None) -> Tensor | None:
            return None if tensor is None else tensor.to(**self.tkwargs)

        state = self.state
        return (
            # ``super()`` pins the shared 14 entries to the base implementation.
            *super().load_state(),
            state['fixed_features_list'],
            _to_tkwargs(state['target_train_x']),
            _to_tkwargs(state['target_fid_obj']),
            _to_tkwargs(state['target_fid_constraints']),
        )

    def _validate_fidelity_info(self) -> None:
        r"""Check the fidelity configuration.

        Raises:
            ValueError: If discrete fidelities are requested without
                ``fidelity_levels``, or if ``fidelity_param_name`` is not one
                of ``design_variables``.
        """
        if self.is_discrete_fidelities and self.fidelity_levels is None:
            raise ValueError(
                "`fidelity_levels` must be provided "
                "if `is_discrete_fidelities` is True."
            )
        if self.fidelity_param_name not in self.design_variables:
            raise ValueError(
                f"fidelity_param_name {self.fidelity_param_name} "
                f"not found in design variables {self.design_variables}."
            )
