#!/usr/bin/env python3

from __future__ import annotations


from typing import Callable, Protocol, Any

import os
import json
import uuid
from pathlib import Path
from datetime import datetime
from abc import ABC
import warnings
import torch
from torch import Tensor
from wandb.sdk.wandb_run import Run

from gpytorch.likelihoods import Likelihood
from gpytorch.kernels import RBFKernel
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.likelihoods import MultitaskGaussianLikelihood

from botorch.models.model import Model
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.test_functions.base import (
    ConstrainedBaseTestProblem,
    MultiObjectiveTestProblem,
)
from botorch.acquisition.utils import project_to_target_fidelity

from rescue.algorithms.initial_sampling import GenerateInitialSample
from rescue.metrics.optimization_pref import MultiObjectiveOptimizationPref
from rescue.utils.utils import status

from baselines.multitask_gp import MultitaskGP

from rich import print
from tabulate import tabulate
import re

class GPFn(Protocol):
    r"""
    Structural type for a user-supplied model factory.

    Implementations are invoked with keyword-only arguments and are 
    declared to return a `(Model, Likelihood)` pair.

    NOTE(review): `BaseBaseline.initialize_model` calls the custom model 
    with `train_obj=...` rather than `train_objectives=...`; confirm 
    which keyword implementations are expected to accept.
    """

    def __call__(
        self, *,
        train_x: Tensor, train_objectives: Tensor,
        train_constraints: Tensor, state_dict: dict,
    ) -> tuple[Model, Likelihood]:
        r"""Build a model from the given training data and state dict."""

class GenInitDataFn(Protocol):
    r"""
    Structural type for a user-supplied initial-data generator.

    Implementations take a single keyword-only argument `n` and are 
    declared to return three values, the last of which may be `None`.
    """

    def __call__(self, *, n: float | int) -> tuple[Tensor, Tensor, Tensor | None]:
        r"""Generate the initial training data for the requested `n`."""

class BaseBaseline(ABC):
    def __init__(
        self,
        problem: MultiObjectiveTestProblem | ConstrainedBaseTestProblem,
        bounds: Tensor,
        has_constraints: bool,
        cost_fn: Callable[[Tensor], Tensor] | None = None,
        gen_initial_data: GenInitDataFn | None = None,
        custom_model: None | GPFn = None,
        status_spinner: bool = True,
        device: None | torch.device = None,
        dtype: None | torch.dtype = None,  
        wandb_run: None | Run = None,
        algorithm_state: None | dict = None,
    ) -> None:
        r""" 
        Initialize the baseline algorithm.

        Args:
            problem (MultiObjectiveTestProblem | 
                ConstrainedBaseTestProblem): The optimization problem to 
                solve.
            bounds (Tensor): The bounds for the optimization problem.
            has_constraints (bool): Whether the problem has constraints.
            cost_fn (Callable[[Tensor], Tensor] | None): The cost 
                function for evaluations.
            gen_initial_data (GenInitDataFn | None): A function to 
                generate initial data.
            custom_model (None | GPFn): A custom model initialization 
                function.
            status_spinner (bool): Whether to show a status spinner 
                during optimization.
            device (None | torch.device): The device to run the 
                optimization on.
            dtype (None | torch.dtype): The data type for the tensors.
            wandb_run (None | Run): The Weights & Biases run object for 
                logging.
            algorithm_state (None | dict): The state of the algorithm (if 
                resuming).
        """
        if device is None:
            device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        dtype = torch.double if dtype is None else dtype        
        self.tkwargs = {
            "device": device,
            "dtype": dtype,
        }

        self.problem = problem
        self.cost_fn = cost_fn
        self.algorithm_state = algorithm_state
        if self.algorithm_state is not None:
            self.state = algorithm_state
            warnings.warn("`algorithm_state` is provided, "
                    "the algorithm will use `algorithm_state` to initialize itself.",
                    UserWarning
            )

            self.bounds = self.state['bounds'].to(**self.tkwargs)
            self.has_constraints = self.state['has_constraints']
        else:
            self.state = {}
            self.bounds = bounds.to(**self.tkwargs)
            self.has_constraints = has_constraints

            self.state['bounds'] = self.bounds
            self.state['has_constraints'] = self.has_constraints

        self.custom_model = custom_model
        self.gen_initial_data = gen_initial_data
        self.wandb_run = wandb_run
        self.algorithm_state = algorithm_state
        self.standard_bounds = torch.zeros(
            2, self.bounds.shape[1], **self.tkwargs)
        self.standard_bounds[1] = 1

        # Generate unique results filepath
        self.problem_name = self.problem.__class__.__name__
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.unique_id = str(uuid.uuid4())[:8]

        self.status_spinner = status_spinner
   
        
    def generate_initial_data(
        self, 
        n: int | float | None = None,
    ) -> tuple[Tensor, Tensor, Tensor | None]:
        r"""
        Generate initial training data.
        
        Args:
            n (int | float | None): Number of initial samples (or budget 
                for multi-fidelity).
            
        Returns:
            tuple[Tensor, Tensor, Tensor | None]: 
                - train_x: Input samples with dim `n x d` where last 
                    column of `d` corresponds to fidelity levels.
                - train_obj: Objective values with dim `n x m`.
                - train_constraints: Constraint values with dim `n x c` 
                    (None if no constraints).
        """
        if self.gen_initial_data is None:
            if n is None:
                raise ValueError("`initial_budget` must be specified.")
            sampling = GenerateInitialSample(
                problem=self.problem,
                bounds=self.bounds,
                has_constraints=self.has_constraints
            )
            train_x, train_obj, train_cons = sampling.generate_initial_data(n=n)
        else:
            train_x, train_obj, train_cons = self.gen_initial_data(n)
        self._validate_trainconstraints(train_cons)
        return train_x, train_obj, train_cons

    def initialize_model(
        self,
        train_x: Tensor, 
        train_obj: Tensor, 
        train_constraints: Tensor | None,
        state_dict: dict | None = None,
    ) -> tuple[ModelListGP, ExactMarginalLogLikelihood]:
        r"""Initializes a ModelList with Matern 5/2 Kernel and returns 
        the model and its MLL.

        Note: a batched model could also be used here.

        Args:
            train_x (Tensor): Input features for training.
            train_obj (Tensor): Objective values for training.
            train_constraints (Tensor | None): Constraint values for 
                training (if any).
            state_dict (dict | None): State dictionary for model 
                initialization (if any).

        Returns:
            tuple[ModelListGP, ExactMarginalLogLikelihood]: The model and 
                its marginal log likelihood.
        """
        if self.custom_model is not None:
            return self.custom_model(
                train_x=train_x,
                train_obj=train_obj,
                train_constraints=train_constraints,
                state_dict=state_dict,
            )
        else:
            has_constraints = train_constraints is not None
            train_y = train_obj.clone()
            if has_constraints:
                train_y = torch.cat([train_y, train_constraints], dim=-1)

            # Initialize likelihood
            likelihood = MultitaskGaussianLikelihood(
                num_tasks=train_y.shape[-1]
            )
            model = MultitaskGP( 
                    train_X=train_x,
                    train_Y=train_y,
                    likelihood=likelihood,
                    base_covar_module=RBFKernel(
                        ard_num_dims=train_x.shape[-1],
                    ),
                )
            # Set up marginal log likelihood
            mll = ExactMarginalLogLikelihood(likelihood, model) 
            if state_dict is not None:
                model.load_state_dict(state_dict)  
            return mll, model

    @status(show_func_name=True)
    def obtain_new_y(
        self,
        new_x: Tensor,
    ) -> tuple[Tensor, Tensor | None]:
        r"""
        Obtain new objective and constraint values for given inputs.

        Args:
            new_x (Tensor): New input points to evaluate.

        Returns:
            tuple[Tensor, Tensor | None]: New objectives and constraints 
                (None if no constraints).
        """
        new_obj = self.problem(new_x)
        # Handle constraints if present
        new_constraints = None
        if self.has_constraints:
            new_constraints = -self.problem.evaluate_slack(new_x)
        return new_obj, new_constraints
        
            
    @status(show_func_name=True)
    def get_metrics(
        self,
        compute_metrics: bool,
        model: Model,
        input_dim_without_fid: int,
        train_objectives: Tensor,
        train_constraints: Tensor,
        optimization_pref: MultiObjectiveOptimizationPref,
        objective_indices: list[int],
        constraints_indices: list[int],
        is_mf_model: bool,
        seed: int | None,
        project: None | Callable[[Tensor], Tensor] = None,
    ) -> tuple[float, float | None, float, float | None]:
        r"""
        Compute optimization metrics.

        Args:
            compute_metrics (bool): Whether to compute metrics.
            model (Model): The trained model.
            input_dim_without_fid (int): Input dimensionality without 
                fidelity.
            train_objectives (Tensor): Training objectives.
            train_constraints (Tensor): Training constraints.
            optimization_pref (MultiObjectiveOptimizationPref): 
                Optimization preference object.
            objective_indices (list[int]): Indices of objectives.
            constraints_indices (list[int]): Indices of constraints.
            is_mf_model (bool): Whether model is multi-fidelity.
            seed (int | None): Random seed.
            project (None | Callable[[Tensor], Tensor]): Projection 
                function for multi-fidelity.

        Returns:
            tuple[float, float | None, float, float | None]: Metrics 
                tuple.
        """
        
        if compute_metrics:
            # Compute HV and regret using NSGA-II on the posterior
            nsga2_hv, nsga2_regret, nsga2_violation = (
                optimization_pref.compute_nsga2_posterior_hv_regret(
                    input_dim_without_fid=input_dim_without_fid,
                    num_objectives=len(objective_indices),
                    objective_indices=objective_indices,
                    constraints_indices=constraints_indices if self.has_constraints else None,
                    bounds=self.bounds,
                    model=model,
                    problem=self.problem,
                    is_mf_model=is_mf_model,
                    project_to_target_fidelity=project if is_mf_model else None,
                    seed=seed if seed is not None else None,
                )
            )
            # Compute observed HV and regret
            observed_hv, observed_regret, observed_violation = (
                optimization_pref.computed_observed_hv_regret(
                    train_objectives=train_objectives,
                    train_constraints=train_constraints if self.has_constraints else None,
                )
            )
            return (
                nsga2_hv, nsga2_regret, nsga2_violation,
                observed_hv, observed_regret, observed_violation
            )
        return (None, None, None, None)
    
    def exp_stats(
        self,
        compute_metrics: bool,
        budget: int | float,
        has_maxhv: bool,
        iteration: int,
        current_cost: int | float,
        initial_cost: int | float,
        acq_value: Tensor,
        seed: int,
        best_nsga2_regret: None | float = None,
        curr_nsga2_regret: None | float = None,
        observed_regret: None | float = None,
        curr_nsga2_violation: None | float = None,
        best_nsga2_hv: None | float = None,
        curr_nsga2_hv: None | float = None,
        observed_hv: None | float = None,
        observed_violation: None | float = None,
        new_fidelity: None | float = None
    ) -> dict[str, float]:
        r"""
        Compile experiment statistics.

        Args:
            compute_metrics (bool): Whether metrics are computed.
            budget (int | float): Total budget.
            has_maxhv (bool): Whether maximum HV is known.
            iteration (int): Current iteration number.
            current_cost (int | float): Current cumulative cost.
            initial_cost (int | float): Initial cost.
            acq_value (Tensor): Acquisition function value.
            seed (int): Random seed.
            best_nsga2_regret (None | float): Best NSGA-II regret.
            curr_nsga2_regret (None | float): Current NSGA-II regret.
            observed_regret (None | float): Observed regret.
            curr_nsga2_violation (None | float): Current NSGA-II 
                violation.
            best_nsga2_hv (None | float): Best NSGA-II hypervolume.
            curr_nsga2_hv (None | float): Current NSGA-II hypervolume.
            observed_hv (None | float): Observed hypervolume.
            observed_violation (None | float): Observed violation.
            new_fidelity (None | float): New fidelity level.

        Returns:
            dict[str, float]: Dictionary of experiment statistics.
        """
        
        stats = {
                "budget": budget,
                "cost": current_cost,
                "acquisition_value": acq_value.item() if acq_value is not None else float('nan'),
                "new_fidelity": (
                        new_fidelity if new_fidelity is not None else 1.0
                    ), # for single fidelity, this will be 1.0
                "initial_cost": initial_cost,
                "iteration": iteration,
                "seed": seed,
            }      
            
        if compute_metrics:
            if has_maxhv:
                if best_nsga2_regret is None \
                    or curr_nsga2_regret is None \
                    or observed_regret is None:
                        raise ValueError(
                            "All regret metrics must be provided if `compute_metrics` is True."
                        )
            if best_nsga2_hv is None \
                or curr_nsga2_hv is None \
                or observed_hv is None:
                raise ValueError(
                    "All HV metrics must be provided if `compute_metrics` is True."
                )
    
            if has_maxhv:
                stats["best_nsga2_regret"] = best_nsga2_regret
                stats["curr_nsga2_regret"] = curr_nsga2_regret
                stats["best_nsga2_hv"] = best_nsga2_hv                     
                stats["curr_nsga2_hv"] = curr_nsga2_hv                     
                stats["observed_hv"] = observed_hv
                stats["observed_regret"] = observed_regret                   
            else:
                stats["best_nsga2_hv"] = best_nsga2_hv                     
                stats["curr_nsga2_hv"] = curr_nsga2_hv                     
                stats["observed_hv"] = observed_hv   
            if self.has_constraints:
                stats["curr_nsga2_violation"] = curr_nsga2_violation
                stats["observed_violation"] = observed_violation
        return stats
    
    def save_stats_to_json(
        self, 
        stats: dict[str, Any], 
    ) -> None:
        r"""Save iteration stats to a JSON file.
        
        Args:
            stats (dict[str, Any]): The statistics dictionary to save.
        """
        
        method = self.__class__.__name__
        results_filepath = (
            f"rescue_log/{method}_{self.problem_name}_{self.timestamp}_{self.unique_id}.json"
        )
        # Use absolute path from current working directory
        file_path = Path(os.getcwd()) / results_filepath
        file_path.parent.mkdir(parents=True, exist_ok=True)
        
        # Load existing results if file exists
        if file_path.exists():
            with open(file_path, 'r') as f:
                results = json.load(f)
        else:
            results = []
        
        # Append new stats (assuming they're already JSON-serializable)
        stats['method'] = method
        stats['problem'] = self.problem_name
        results.append(stats)
        
        # Save to file
        with open(file_path, 'w') as f:
            json.dump(results, f, indent=2)

    def log_exp_stats_to_wandb(
        self,
        iter: int,
        exp_stats: dict,
        compute_metrics: bool
    ) -> None:
        r"""
        Log experiment statistics to Weights & Biases.

        Args:
            iter (int): Current iteration.
            exp_stats (dict): Experiment statistics.
            compute_metrics (bool): Whether metrics are computed.
        """
        if self.wandb_run is not None:
            if compute_metrics:
                self.wandb_run.log(exp_stats, step=iter)   

    def log_artifacts_to_wandb(self) -> None:
        r"""
        Log optimization state artifacts to Weights & Biases.
        """
        if self.wandb_run is not None:
            iteration = self.state['iteration']
            save_dir = self.wandb_run.dir
            exp_name = self.wandb_run.name
            filename = f"{exp_name}_ckpt_{iteration}.pt"
            path = os.path.join(save_dir, filename)
            torch.save(self.state, path)
            self.wandb_run.log_artifact(
                artifact_or_path=path,
                name=filename,
            type="state"
            )

    def term_print(
        self, 
        show_stats: bool,
        exp_stats: dict[str, float],
        budget: int | float,
    ) -> None:
        r"""
        Print experiment statistics to terminal.

        Args:
            show_stats (bool): Whether to show statistics.
            exp_stats (dict[str, float]): Experiment statistics.
            budget (int | float): Total budget.
        """
        if not show_stats:
            return

        def keep(key, d):
            return {key: d[key]} if key in d else {}

        stats_to_show = {
            **keep("cost", exp_stats),
            **keep("best_nsga2_regret", exp_stats),
            **keep("curr_nsga2_regret", exp_stats),
            **keep("best_nsga2_hv", exp_stats),
            **keep("curr_nsga2_hv", exp_stats),
            **keep("curr_nsga2_violation", exp_stats),
            **keep("observed_hv", exp_stats),
            **keep("observed_regret", exp_stats),
            **keep("observed_violation", exp_stats),
            **keep("new_fidelity", exp_stats),
            **keep("acqu_value", exp_stats),
            **keep("iteration", exp_stats),
        }
        # Future me: This is over engineered, but it works.
        if "curr_nsga2_regret" in stats_to_show:
            stats_to_show = {
                k: v for k, v in stats_to_show.items()
                if "hv" not in k.lower()
            }
        if not self.has_constraints:
            stats_to_show = {
                k: v for k, v in stats_to_show.items()
                if "violation" not in k.lower()
            }
        row = []
        for k, v in stats_to_show.items():
            label = k
            if label.lower().startswith("observed_"):
                label = "obs_" + label.split("_", 1)[1]

            # replacements
            label = re.sub(r"(?i)acqu_value", "acqu_val", label)
            label = re.sub(r"(?i)new_fidelity", "new_fid", label)
            label = re.sub(r"(?i)regret", "rgrt", label)

            if k == "cost":
                row.append(f"Budget: {v:.1f}/{budget}")
            elif k == "iteration":
                row.append(f"Iter: {v}")
            elif isinstance(v, (float)):
                row.append(f"{label}: {v:.3f}")
            else:
                row.append(f"{label}: {v}")

        print(tabulate([row], tablefmt="plain"))

    def update_state(    
        self,
        is_multifidelity: bool,
        get_seed: int,
        budget: int | float,
        ref_point: Tensor,
        max_hv: float | None,
        initial_cost: int | float,
        current_cost: int | float,
        iteration: int,
        met_best_nsga2_regret: float | None,
        met_best_nsga2_hv: float | None,
        acquisition_value: float | None,
        new_fidelity: float | None,
        train_x: Tensor,
        train_obj: Tensor,
        train_constraints: Tensor | None,
        objective_indices: list[int],
        constraints_indices: list[int],
        model: Model,
        fixed_features_list: list[dict[int, float]] | None = None,
        target_train_x: Tensor | None = None,
        target_fid_obj: Tensor | None = None,
        target_fid_constraints: Tensor | None = None,
    ) -> None:
        r"""
        Update the algorithm state.

        Args:
            is_multifidelity (bool): Whether algorithm is multi-fidelity.
            get_seed (int): Random seed used.
            budget (int | float): Total budget.
            ref_point (Tensor): Reference point.
            max_hv (float | None): Maximum hypervolume.
            initial_cost (int | float): Initial cost.
            current_cost (int | float): Current cost.
            iteration (int): Current iteration.
            met_best_nsga2_regret (float | None): Best NSGA-II regret.
            met_best_nsga2_hv (float | None): Best NSGA-II hypervolume.
            acquisition_value (float | None): Acquisition value.
            new_fidelity (float | None): New fidelity level.
            train_x (Tensor): Training inputs.
            train_obj (Tensor): Training objectives.
            train_constraints (Tensor | None): Training constraints.
            objective_indices (list[int]): Objective indices.
            constraints_indices (list[int]): Constraint indices.
            model (Model): The trained model.
            fixed_features_list (list[dict[int, float]] | None): Fixed 
                features for discrete fidelities.
            target_train_x (Tensor | None): Target fidelity inputs.
            target_fid_obj (Tensor | None): Target fidelity objectives.
            target_fid_constraints (Tensor | None): Target fidelity 
                constraints.
        """
        
        self.state['get_seed'] = get_seed
        self.state['budget'] = budget
        self.state['initial_cost'] = initial_cost 
        self.state['current_cost'] = current_cost
        self.state['ref_point'] = ref_point.detach()
        self.state['max_hv'] = max_hv
        self.state['iteration'] = iteration
        self.state['met_best_nsga2_regret'] = met_best_nsga2_regret
        self.state['met_best_nsga2_hv'] = met_best_nsga2_hv    
        self.state['train_x'] = train_x.detach()
        self.state['train_obj'] = train_obj.detach()
        self.state['train_constraints'] = train_constraints.detach() if \
            train_constraints is not None else train_constraints
        self.state['objective_indices'] = objective_indices
        self.state['constraints_indices'] = constraints_indices
        self.state['model'] = model.state_dict()

        # only used for metrics computation
        # for later used
        # future me: this is not used to load algorithm state
        self.state['acquisition_value'] = acquisition_value
        self.state['new_fidelity'] = new_fidelity

        if is_multifidelity:
            self.state['fixed_features_list'] = fixed_features_list
            self.state['target_train_x'] = target_train_x.detach() if target_train_x \
                is not None else target_train_x
            self.state['target_fid_obj'] = target_fid_obj.detach() if target_fid_obj \
                is not None else target_fid_obj
            self.state['target_fid_constraints'] = target_fid_constraints.detach() if \
                target_fid_constraints is not None else target_fid_constraints
    
    def load_state(self) -> None:
        return (
            self.state['iteration'],
            self.state['get_seed'],
            self.state['initial_cost'],
            self.state['current_cost'],
            self.state['ref_point'],
            self.state['max_hv'],
            self.state['met_best_nsga2_regret'],
            self.state['met_best_nsga2_hv'],
            self.state['train_x'],
            self.state['train_obj'],
            self.state['train_constraints'],
            self.state['objective_indices'],
            self.state['constraints_indices'],
        )

    def _validate_trainconstraints(
        self, 
        train_constraints: Tensor
    ) -> None:
        r"""
        Validate training constraints.

        Args:
            train_constraints (Tensor): Training constraint values.

        Raises:
            ValueError: If constraints are expected but not provided.
        """
        if self.has_constraints and train_constraints is None:
            raise ValueError(
                f"`has_constraints={self.has_constraints} but training "
                f"data does not have constraints column."
            )

    def validate_run_inputs(
        self,
        compute_metrics: bool,
        objective_indices: list[int] | None,
        constraints_indices: list[int] | None
    ) -> None:
        r"""
        Validate inputs for the run method.

        Args:
            compute_metrics (bool): Whether metrics will be computed.
            objective_indices (list[int] | None): Objective indices.
            constraints_indices (list[int] | None): Constraint indices.

        Raises:
            ValueError: If required indices are missing.
        """
        if compute_metrics:
            if objective_indices is None:
                raise ValueError(
                    "`objective_indices` must be provided if `compute_metrics` is True."
                )
            if self.has_constraints and constraints_indices is None:
                raise ValueError(
                    "`constraints_indices` must be provided if `compute_metrics` is True "
                    "and the problem has constraints."
                )
        if self.has_constraints:
            if objective_indices is None:
                raise ValueError(
                    "`objective_indices` must be provided if the problem has constraints."
                )

    def validate_new_constraints(self, new_constraints: Tensor) -> None:
        r"""
        Validate new constraint values.

        Args:
            new_constraints (Tensor): New constraint values.

        Raises:
            ValueError: If constraint expectations are not met.
        """
        if self.has_constraints and new_constraints is None:
            raise ValueError("`new_constraints` is `None`" 
                            "make sure `optimize_acquisition_function` returns new_constraints"
                            )
        if not self.has_constraints and new_constraints is not None:
            raise ValueError("`new_constraints` is not `None`" 
                            "make sure `optimize_acquisition_function` does not return new_constraints"
                            )
    def validated_budget(
        self,
        budget: float | int,
        current_cost: float | int,
    ) -> None:
        r"""
        Validate that budget is sufficient.

        Args:
            budget (float | int): Total budget.
            current_cost (float | int): Current cumulative cost.

        Raises:
            ValueError: If budget is insufficient.
        """
        if budget <= current_cost:
            raise ValueError(
                f"budget={budget} should be greater than {current_cost}"
            )


class BaseBaselineMultifidelity(BaseBaseline):
    def __init__(
        self,
        problem: MultiObjectiveTestProblem | ConstrainedBaseTestProblem,
        bounds: Tensor,
        has_constraints: bool,
        target_fidelities: dict[int, float],
        is_discrete_fidelities: bool,
        cost_fn: Callable[[Tensor], Tensor],
        gen_initial_data: GenInitDataFn | None = None,
        custom_model: None | GPFn = None,
        fidelity_levels: None | Tensor = None,
        status_spinner: bool = True,
        device: None | torch.device = None,
        dtype: None | torch.dtype = None,
        wandb_run: None | Run = None,
        algorithm_state: None | dict = None,
    ) -> None:
        super().__init__(
            problem=problem,
            custom_model=custom_model,
            gen_initial_data=gen_initial_data,
            bounds=bounds,
            has_constraints=has_constraints,
            status_spinner=status_spinner,
            device=device,
            dtype=dtype,
            wandb_run=wandb_run,
            algorithm_state=algorithm_state,
        )
        if self.algorithm_state is not None:
            self.is_discrete_fidelities = self.state['is_discrete_fidelities']
            self.fidelity_levels = self.state['fidelity_levels']
        else:
            self.is_discrete_fidelities = is_discrete_fidelities
            self.fidelity_levels = fidelity_levels

            self.state['is_discrete_fidelities'] = is_discrete_fidelities
            self.state['fidelity_levels'] = fidelity_levels

        self._validate_fidelity_info()
        self.cost_fn = cost_fn
        self.target_fidelities = target_fidelities
        self.normalized_target_fidelities = self._normalize_target_fidelities(
                                                self.target_fidelities
                                            ) 

    def generate_initial_data_mf(
        self, 
        n_full_fidelity_equiv: int | float | None = None,
    ) -> tuple[Tensor, Tensor, Tensor | None]:
        r"""
        Generate initial training data for multi-fidelity.
        
        Args:
            n_full_fidelity_equiv (int | float | None): Number of 
                initial samples (or budget for multi-fidelity).
            
        Returns:
            tuple[Tensor, Tensor, Tensor | None]: 
                - train_x: Input samples with dim `n x d` where last 
                    column of `d` corresponds to fidelity levels.
                - train_obj: Objective values with dim `n x m`.
                - train_constraints: Constraint values with dim `n x c` 
                    (None if no constraints).
        """
        if self.gen_initial_data is None:
            if n_full_fidelity_equiv is None:
                raise ValueError("`initial_budget` must be specified.")
            sampling = GenerateInitialSample(
                problem=self.problem,
                bounds=self.bounds,
                has_constraints=self.has_constraints
            )
            train_x, train_obj, train_cons = (
                sampling.generate_initial_data_multifidelity(
                    n_full_fidelity_equiv=n_full_fidelity_equiv,
                    cost_fn=self.cost_fn,
                    is_discrete=self.is_discrete_fidelities,
                    fidelity_levels=self.fidelity_levels,
                )
            )
        else:
            train_x, train_obj, train_cons = self.gen_initial_data(n_full_fidelity_equiv)
        self._validate_trainconstraints(train_cons)
        return train_x, train_obj, train_cons

    def _normalize_target_fidelities(
        self,
        target_fidelities: dict[int, float]
    ) -> dict[int, float]:
        r"""Normalize target fidelities based on bounds.

        Args:
            target_fidelities (dict[int, float]): Target fidelity values.

        Returns:
            dict[int, float]: Normalized target fidelities.
        """
        normalized = {}
        for idx, fidelity in target_fidelities.items():
            lb = self.bounds[0, idx].item()
            ub = self.bounds[1, idx].item()
            normalized[idx] = (fidelity - lb) / (ub - lb)
        return normalized
    
    def project_to_target(self, X: Tensor) -> Tensor:
        r""" 
        Project inputs to target fidelity.

        Args:
            X (Tensor): Input tensor to project.

        Returns:
            Tensor: The projected tensor.

        Note:
            Assumes the last dimension of X is the fidelity dimension.
        """
        return project_to_target_fidelity(
            X=X,
            d=self.bounds.size(-1),
            target_fidelities=self.normalized_target_fidelities,
        )
    
    def get_target_fid_observations(
        self,
        train_x: Tensor,
    ) -> tuple[Tensor, Tensor, Tensor | None]:
        r""" 
        Get observations at target fidelity.

        Args:
            train_x (Tensor): Training inputs.

        Returns:
            tuple[Tensor, Tensor, Tensor | None]: 
                - Projected inputs at target fidelity.
                - Objectives at target fidelity.
                - Constraints at target fidelity (if applicable).
        """
        # This will be costly if evaluation is done at runtime
        # Do it after the optimization is done
        target_fid_x = self.project_to_target(train_x)
        target_fid_obj = self.problem(target_fid_x)
        if self.has_constraints:
            target_fid_constraints = -self.problem.evaluate_slack(target_fid_x) 
        return (
            target_fid_x,
            target_fid_obj,
            target_fid_constraints if self.has_constraints else None
        ) 
    
    def fidelity_to_fixed_features_list(self, train_x: Tensor) -> list[dict[int, float]]:
        r""" 
        Get fixed features for discrete fidelities.

        Args:
            train_x (Tensor): Input tensor with design variables.

        Returns:
            list[dict[int, float]]: A list of fixed features for each 
                fidelity level.

        Note:
            Only applicable when discrete fidelities are used.
        """
        return [{train_x.shape[-1] - 1: float(v)} for v in self.fidelity_levels]
    
    def load_state_mf(self) -> None:
        return (
            self.state['iteration'],
            self.state['get_seed'],
            self.state['initial_cost'],
            self.state['current_cost'],
            self.state['ref_point'],
            self.state['max_hv'],
            self.state['met_best_nsga2_regret'],
            self.state['met_best_nsga2_hv'],
            self.state['train_x'],
            self.state['train_obj'],
            self.state['train_constraints'],
            self.state['fixed_features_list'],
            self.state['objective_indices'],
            self.state['constraints_indices'],
            self.state['target_train_x'],
            self.state['target_fid_obj'],
            self.state['target_fid_constraints']
        )

    def _validate_fidelity_info(self):
        if self.is_discrete_fidelities and self.fidelity_levels is None:
            raise ValueError("`fidelity_levels` must be provided "
            "if `is_discrete_fidelities` is True."
            )