#!/usr/bin/env python3

r"""
Multi-objective multi-fidelity Bayesian optimization using qNEHVI.

.. [Daulton2021]
    S. Daulton, M. Balandat, and E. Bakshy. Parallel Bayesian Optimization of
    Multiple Noisy Objectives with Expected Hypervolume Improvement. Advances in
    Neural Information Processing Systems 34, 2021.
"""

from __future__ import annotations

from typing import Callable
import warnings

import torch
from torch import Tensor
from wandb.sdk.wandb_run import Run

from botorch.exceptions import NumericsWarning
from botorch.models.gp_regression import SingleTaskGP
from botorch.utils.transforms import normalize
from botorch.test_functions.base import (
    ConstrainedBaseTestProblem,
    MultiObjectiveTestProblem,
)
from botorch.acquisition.multi_objective import (
    qNoisyExpectedHypervolumeImprovement
)
from botorch.optim.optimize import optimize_acqf, optimize_acqf_mixed
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.acquisition.multi_objective.objective import IdentityMCMultiOutputObjective
from botorch.utils.transforms import unnormalize
from botorch import fit_gpytorch_mll

from rescue.metrics.optimization_pref import (
    MultiObjectiveOptimizationPref,
    nsga2_posterior_pareto
)
from rescue.utils.utils import status

from baselines.base import BaseBaselineMultifidelity, GPFn, GenInitDataFn

# Silence BoTorch's NumericsWarning. Warning filters are interpreter-global,
# so importing this module mutes these warnings for the whole process.
warnings.simplefilter("ignore", category=NumericsWarning)

class BaselineMOMFqNEHVI(BaseBaselineMultifidelity):
    r"""Multi-objective, multi-fidelity BO baseline driven by qNEHVI.

    The fidelity parameter is the last input column. Each BO iteration:

    1. Refits a GP (``initialize_model`` + ``fit_gp``) on all observations.
    2. Maximizes :class:`qNoisyExpectedHypervolumeImprovement`
       [Daulton2021] jointly over the design and fidelity dimensions.
    3. Evaluates the selected point.

    This repeats until the cumulative cost reported by ``cost_fn`` reaches
    the budget.

    The acquisition function itself is not cost-weighted. Cost only enters
    through the budget loop condition and the iteration-0 fallback in
    :meth:`optimize_qNEHVI`.
    """

    def __init__(
        self,
        problem: MultiObjectiveTestProblem | ConstrainedBaseTestProblem,
        bounds: Tensor,
        has_constraints: bool,
        is_discrete_fidelities: bool,
        target_fidelities: dict[int, float],
        cost_fn: Callable[[Tensor], Tensor],
        gen_initial_data: GenInitDataFn | None = None,
        custom_model: None | GPFn = None,
        fidelity_levels: None | Tensor = None,
        status_spinner: bool = True,
        device: None | torch.device = None,
        dtype: None | torch.dtype = None,
        wandb_run: None | Run = None,
        algorithm_state: None | dict = None
    ):
        r"""
        Initialize the baseline multi-fidelity qNEHVI algorithm.

        All arguments are forwarded unchanged to
        ``BaseBaselineMultifidelity.__init__``.

        Args:
            problem (MultiObjectiveTestProblem |
                ConstrainedBaseTestProblem): The optimization problem to
                solve.
            bounds (Tensor): The bounds for the optimization problem.
            has_constraints (bool): Whether the problem has constraints.
            is_discrete_fidelities (bool): Whether fidelities are
                discrete.
            target_fidelities (dict[int, float]): Target fidelity values.
            cost_fn (Callable[[Tensor], Tensor]): The cost function for
                evaluations.
            gen_initial_data (GenInitDataFn | None): A function to
                generate initial data.
            custom_model (None | GPFn): A custom model initialization
                function.
            fidelity_levels (None | Tensor): Discrete fidelity levels.
            status_spinner (bool): Whether to show a status spinner
                during optimization.
            device (None | torch.device): The device to run the
                optimization on.
            dtype (None | torch.dtype): The data type for the tensors.
            wandb_run (None | Run): The Weights & Biases run object for
                logging.
            algorithm_state (None | dict): The state of the algorithm (if
                resuming).
        """
        super().__init__(
            problem=problem,
            cost_fn=cost_fn,
            custom_model=custom_model,
            gen_initial_data=gen_initial_data,
            bounds=bounds,
            has_constraints=has_constraints,
            status_spinner=status_spinner,
            device=device,
            dtype=dtype,
            wandb_run=wandb_run,
            algorithm_state=algorithm_state,
            is_discrete_fidelities=is_discrete_fidelities,
            fidelity_levels=fidelity_levels,
            target_fidelities=target_fidelities
        )

    @status(show_func_name=True)
    def _generate_initial_data(
        self,
        n_full_fidelity_equiv: float | int | None = None,
    ) -> tuple[Tensor, Tensor, Tensor | None]:
        r"""
        Generate initial training data.

        Thin wrapper around ``generate_initial_data_mf``. It exists so the
        call is shown by the ``status`` spinner.

        Args:
            n_full_fidelity_equiv (float | int | None): Number of
                initial samples (or budget for multi-fidelity).

        Returns:
            tuple[Tensor, Tensor, Tensor | None]:
                - train_x: Input samples with dim `n x d`, where the last
                    column of `d` holds the fidelity levels.
                - train_obj: Objective values with dim `n x m`.
                - train_constraints: Constraint values with dim `n x c`
                    (None if no constraints).
        """
        return self.generate_initial_data_mf(
            n_full_fidelity_equiv=n_full_fidelity_equiv
        )

    @status(show_func_name=True)
    def fit_gp(self, mll) -> None:
        r"""Fit the GP model.

        Args:
            mll: The marginal likelihood to optimize; its model is fitted
                in place by ``fit_gpytorch_mll``.
        """
        fit_gpytorch_mll(mll)

    def _set_min_fidelity(self, candidates: Tensor) -> None:
        r"""Overwrite the fidelity column (last) of ``candidates`` in place.

        For discrete fidelities, the smallest entry of
        ``self.fidelity_levels`` is used. Otherwise 0.0 is used, which is
        the lower end of the normalized fidelity range when
        ``self.standard_bounds`` is the unit cube (assumption, TODO
        confirm).

        Args:
            candidates (Tensor): `q x d` candidates, modified in place.
        """
        if self.is_discrete_fidelities:
            candidates[:, -1] = self.fidelity_levels.min().item()
        else:
            candidates[:, -1] = 0.0

    @status(show_func_name=True)
    def optimize_qNEHVI(
        self,
        model: SingleTaskGP,
        train_x: torch.Tensor,
        sampler: SobolQMCNormalSampler,
        ref_point: torch.Tensor,
        q: int,
        num_restarts: int,
        raw_samples: int,
        budget: float,
        iteration: int,
        objective_indices: None | list[int] = None,
        constraints_indices: None | list = None,
        fixed_features_list: None | list[dict[int, float]] = None,
        options: dict[str, bool | float | int | str] | None = None
    ) -> tuple[Tensor, Tensor]:
        r"""
        Optimize the qNEHVI acquisition function.

        The model is expected to be trained on inputs normalized with
        ``self.bounds``. ``train_x`` is normalized the same way before being
        used as the qNEHVI baseline.

        Two fallbacks move the fidelity of the returned candidates to the
        lowest level:

        - When the maximized acquisition value is exactly 0.
        - On iteration 0, when the candidate's cost exceeds ``budget``
          (so at least one evaluation fits in the budget).

        Args:
            model (SingleTaskGP): The GP model.
            train_x (torch.Tensor): Training inputs (unnormalized).
            sampler (SobolQMCNormalSampler): MC sampler.
            ref_point (torch.Tensor): Reference point.
            q (int): Batch size. The ``.item()`` calls here and in ``run``
                only support q == 1.
            num_restarts (int): Number of optimization restarts.
            raw_samples (int): Number of raw samples.
            budget (float): Total budget.
            iteration (int): Current iteration.
            objective_indices (None | list[int]): Model outputs used as
                objectives.
            constraints_indices (None | list): Constraint indices. Only
                their count is used (see Note).
            fixed_features_list (None | list[dict[int, float]]): Fixed
                features for discrete fidelities.
            options (dict[str, bool | float | int | str] | None):
                Optimization options.

        Returns:
            tuple[Tensor, Tensor]: New candidates in the original input
                space, and the acquisition values. For discrete
                fidelities, the fidelity column is returned as produced by
                the optimizer and is not unnormalized.

        Note:
            The last `c` model outputs are the constraints. BoTorch treats
            a constraint as satisfied when its callable returns <= 0.
        """
        constraints = None
        if self.has_constraints:
            # Model output layout: [obj0, ..., objN, cons0, cons1, ...].
            # ``idx`` is bound as a default argument to avoid the
            # late-binding closure pitfall.
            num_objectives = len(objective_indices)
            constraints = [
                lambda Z, idx=num_objectives + j: Z[..., idx]
                for j in range(len(constraints_indices))
            ]

        acq_func = qNoisyExpectedHypervolumeImprovement(
            model=model,
            ref_point=ref_point,  # use known reference point
            X_baseline=normalize(train_x, self.bounds),
            sampler=sampler,
            objective=IdentityMCMultiOutputObjective(
                outcomes=objective_indices
            ),
            constraints=constraints,
            cache_pending=False,
            cache_root=False,
            prune_baseline=True
        )

        if self.is_discrete_fidelities:
            # Optimize separately over each fixed fidelity level.
            candidates, vals = optimize_acqf_mixed(
                acq_function=acq_func,
                bounds=self.standard_bounds,
                q=q,
                fixed_features_list=fixed_features_list,
                num_restarts=num_restarts,
                raw_samples=raw_samples,  # used for initialization heuristic
                options=options
            )
        else:
            candidates, vals = optimize_acqf(
                acq_function=acq_func,
                bounds=self.standard_bounds,
                q=q,
                num_restarts=num_restarts,
                raw_samples=raw_samples,  # used for initialization heuristic
                sequential=True,
                options=options
            )

        # No expected improvement anywhere: evaluate at the lowest fidelity.
        if vals.item() == 0.0:
            self._set_min_fidelity(candidates)

        # Ensure at least one sample fits within the budget.
        # NOTE(review): cost_fn sees the *normalized* candidate here but
        # unnormalized points in run(). This only agrees if the cost depends
        # on a fidelity column whose scale normalization does not change.
        # It also compares against the total budget, not the remaining
        # budget. Confirm both are intended.
        if iteration == 0 and self.cost_fn(candidates) > budget:
            self._set_min_fidelity(candidates)

        candidates = candidates.detach()
        if self.is_discrete_fidelities:
            # Only the design columns are unnormalized. The fidelity column
            # already holds a value taken from fixed_features_list.
            new_x = candidates.clone()
            new_x[:, :-1] = unnormalize(
                candidates[:, :-1], bounds=self.bounds[:, :-1]
            )
        else:
            new_x = unnormalize(candidates, bounds=self.bounds)
        return new_x, vals

    def run(
        self,
        budget: int | float,
        ref_point: Tensor,
        init_budget: float | None = None,
        include_initcost_to_budget: bool = True,
        q: int = 1,
        acqfn_mc_samples: int = 128,
        acqfn_raw_samples: int = 512,
        acqfn_num_restarts: int = 10,
        optim_acqfn_options: dict[str, bool | float | int | str] | None = None,
        objective_indices: None | list[int] = None,
        constraints_indices: None | list = None,
        compute_metrics: bool = False,
        max_hv: float | Tensor | None = None,
        seed: int | None = None,
        show_status: bool = True
    ) -> dict:
        r"""
        Run the optimization algorithm.

        If ``self.algorithm_state`` is set, the loop state is restored with
        ``load_state_mf`` (including ``ref_point``, ``max_hv`` and the
        objective/constraint indices) and the loop resumes. Otherwise
        initial data are generated and the GP is fitted. The state is saved
        with ``update_state`` after initialization and after every
        iteration.

        Args:
            budget (int | float): Total optimization budget.
            ref_point (Tensor): Reference point for hypervolume.
            init_budget (float | None): Initial sampling budget.
            include_initcost_to_budget (bool): Whether to include
                initial cost in budget.
            q (int): Batch size (only q == 1 is supported by the
                ``.item()`` calls in the loop).
            acqfn_mc_samples (int): MC samples for acquisition.
            acqfn_raw_samples (int): Raw samples for acquisition.
            acqfn_num_restarts (int): Optimization restarts.
            optim_acqfn_options (dict[str, bool | float | int | str] |
                None): Acquisition optimization options.
            objective_indices (None | list[int]): Objective indices.
            constraints_indices (None | list): Constraint indices.
            compute_metrics (bool): Whether to compute metrics. This
                evaluates the problem at the target fidelity and is for
                research use only; it does not affect the optimization.
            max_hv (float | Tensor | None): Maximum hypervolume.
            seed (int | None): Random seed.
            show_status (bool): Whether to show status messages.

        Returns:
            dict: Built from the ``nsga2_posterior_pareto`` result on the
                final model:

                - ``"pareto_X"``: the pymoo result's ``X``.
                - ``"pareto_Y"``: its ``F``.
                - ``"constraints"``: its ``G``.
                - ``"pymoo_res"``: the full pymoo result.
                - ``"gp_model"``: the final GP.
        """
        self.validate_run_inputs(compute_metrics, objective_indices, constraints_indices)

        # === Helper Functions for run ===
        def _get_total_cost(train_x: Tensor) -> float:
            """Summed cost of every row of ``train_x``."""
            return self.cost_fn(train_x).sum().item()

        def _get_new_cost(new_x: Tensor) -> float:
            """Cost of one new evaluation (``.item()`` requires q == 1)."""
            return self.cost_fn(new_x).item()
        # === End Helper Functions for run ===

        # Per-iteration metrics passed to exp_stats on every iteration but
        # only recomputed when compute_metrics is True. They must exist on
        # both the fresh and the resumed path. Previously a resumed run with
        # compute_metrics=False raised UnboundLocalError here.
        met_curr_nsga2_regret = None
        met_curr_nsga2_hv = None
        met_curr_nsga2_violation = None
        met_observed_regret = None
        met_observed_hv = None
        met_observed_violation = None

        if self.algorithm_state is not None:
            # Resume: restore everything saved by update_state.
            (
                iteration,
                get_seed,
                initial_cost,
                current_cost,
                ref_point,
                max_hv,
                met_best_nsga2_regret,
                met_best_nsga2_hv,
                train_x,
                train_obj,
                train_constraints,
                fixed_features_list,
                objective_indices,
                constraints_indices,
                target_train_x,
                target_fid_obj,
                target_fid_constraints,
            ) = self.load_state_mf()

            self.validated_budget(budget, current_cost)

            # Rebuild the GP from the saved hyperparameters (no refit).
            mll, model = self.initialize_model(
                train_x=normalize(train_x, self.bounds),
                train_obj=train_obj,
                train_constraints=train_constraints,
                state_dict=self.state['model'],
            )
            if compute_metrics:
                optimization_pref = MultiObjectiveOptimizationPref(
                    ref_point=ref_point,
                    max_hv=max_hv,
                    **self.tkwargs
                )

        else:
            target_train_x = None
            target_fid_obj = None
            target_fid_constraints = None
            fixed_features_list = None
            met_best_nsga2_regret = None
            met_best_nsga2_hv = None
            iteration = 0
            current_cost = 0

            if seed is not None:
                torch.manual_seed(seed)
            get_seed = torch.initial_seed()

            # Generate the initial design (the last column is the fidelity).
            train_x, train_obj, train_constraints = self._generate_initial_data(init_budget)
            # With discrete fidelities, build one fixed-feature dict per
            # fidelity level for optimize_acqf_mixed.
            if self.is_discrete_fidelities:
                fixed_features_list = self.fidelity_to_fixed_features_list(train_x)

            initial_cost = _get_total_cost(train_x)
            if show_status:
                print("Initial cumulative cost:", initial_cost)
            if include_initcost_to_budget:
                current_cost = initial_cost
            self.validated_budget(budget, current_cost)

            mll, model = self.initialize_model(
                train_x=normalize(train_x, self.bounds),
                train_obj=train_obj,
                train_constraints=train_constraints,
            )

            self.fit_gp(mll=mll)

            if compute_metrics:
                met_best_nsga2_regret = float("inf")
                met_best_nsga2_hv = float("-inf")
                target_train_x, target_fid_obj, target_fid_constraints = (
                    self.get_target_fid_observations(train_x=train_x)
                )
                optimization_pref = MultiObjectiveOptimizationPref(
                    ref_point=ref_point,
                    max_hv=max_hv,
                    **self.tkwargs
                )

            self.update_state(
                is_multifidelity=True,
                get_seed=get_seed,
                budget=budget,
                ref_point=ref_point,
                max_hv=max_hv,
                initial_cost=initial_cost,
                current_cost=current_cost,
                iteration=iteration,
                met_best_nsga2_regret=met_best_nsga2_regret,
                met_best_nsga2_hv=met_best_nsga2_hv,
                acquisition_value=None,
                new_fidelity=None,
                train_x=train_x,
                train_obj=train_obj,
                train_constraints=train_constraints,
                objective_indices=objective_indices,
                constraints_indices=constraints_indices,
                model=model,
                fixed_features_list=fixed_features_list,
                target_train_x=target_train_x,
                target_fid_obj=target_fid_obj,
                target_fid_constraints=target_fid_constraints
            )
            self.log_artifacts_to_wandb()

        # Derived from the *effective* max_hv. On resume this is the value
        # restored from state, which is the one optimization_pref uses.
        has_maxhv = max_hv is not None

        # +=================================+
        # |            BO loop              |
        # +=================================+
        sampler = SobolQMCNormalSampler(
            sample_shape=torch.Size([acqfn_mc_samples])
        )
        while current_cost < budget:
            # Optimize the acquisition function and pick the next point.
            new_x, acq_value = self.optimize_qNEHVI(
                model=model,
                train_x=train_x,
                sampler=sampler,
                ref_point=ref_point,
                q=q,
                raw_samples=acqfn_raw_samples,
                num_restarts=acqfn_num_restarts,
                fixed_features_list=fixed_features_list,
                constraints_indices=constraints_indices,
                objective_indices=objective_indices,
                options=optim_acqfn_options,
                budget=budget,
                iteration=iteration,
            )

            new_fidelity = new_x.detach()[..., -1].item()

            # Real evaluation of the selected point.
            new_obj, new_constraints = self.obtain_new_y(new_x)
            self.validate_new_constraints(new_constraints)

            train_x = torch.cat([train_x, new_x], dim=0)
            train_obj = torch.cat([train_obj, new_obj], dim=0)
            if new_constraints is not None:
                train_constraints = torch.cat([train_constraints, new_constraints], dim=0)

            # Reinitialize and refit the model on all data.
            mll, model = self.initialize_model(
                train_x=normalize(train_x, self.bounds),
                train_obj=train_obj,
                train_constraints=train_constraints,
            )

            self.fit_gp(mll=mll)

            iteration += 1
            current_cost += _get_new_cost(new_x)

            # +===========================================+
            # |   Metrics: do not contribute to the       |
            # |        optimization performance           |
            # +===========================================+
            # This can be time-consuming if the target fidelity is costly
            # and is intended for research only. Metrics can instead be
            # computed offline or in parallel, with compute_metrics=False.
            if compute_metrics:
                target_new_x = self.project_to_target(new_x)
                target_train_x = torch.cat([target_train_x, target_new_x], dim=0)
                # Observe the new point at the target fidelity.
                new_target_fid_obj = self.problem(target_new_x)
                if self.has_constraints:
                    new_target_fid_constraints = -self.problem.evaluate_slack(target_new_x)
                    target_fid_constraints = torch.cat(
                        [target_fid_constraints, new_target_fid_constraints],
                        dim=0
                    )
                target_fid_obj = torch.cat([target_fid_obj, new_target_fid_obj], dim=0)
                # HV and regret from NSGA-II on the posterior, plus the
                # observed values.
                (met_curr_nsga2_hv, met_curr_nsga2_regret, met_curr_nsga2_violation,
                    met_observed_hv, met_observed_regret, met_observed_violation) = (
                        self.get_metrics(
                            is_mf_model=True,
                            # The fidelity dim is dropped when evaluating
                            # the problem.
                            input_dim_without_fid=self.bounds.size(-1) - 1,
                            project=self.project_to_target,
                            compute_metrics=compute_metrics,
                            model=model,
                            train_objectives=target_fid_obj,
                            train_constraints=target_fid_constraints,
                            optimization_pref=optimization_pref,
                            objective_indices=objective_indices,
                            constraints_indices=constraints_indices,
                            seed=get_seed
                        )
                    )
                if (
                    has_maxhv
                    and met_curr_nsga2_regret is not None
                    and met_curr_nsga2_regret < met_best_nsga2_regret
                ):
                    met_best_nsga2_regret = met_curr_nsga2_regret
                if (
                    met_curr_nsga2_hv is not None
                    and met_curr_nsga2_hv > met_best_nsga2_hv
                ):
                    met_best_nsga2_hv = met_curr_nsga2_hv

            stats = self.exp_stats(
                budget=budget,
                initial_cost=initial_cost,
                seed=get_seed,
                compute_metrics=compute_metrics,
                has_maxhv=has_maxhv,
                iteration=iteration,
                current_cost=current_cost,
                acq_value=acq_value,
                best_nsga2_regret=met_best_nsga2_regret,
                curr_nsga2_regret=met_curr_nsga2_regret,
                curr_nsga2_violation=met_curr_nsga2_violation,
                observed_regret=met_observed_regret,
                best_nsga2_hv=met_best_nsga2_hv,
                curr_nsga2_hv=met_curr_nsga2_hv,
                observed_hv=met_observed_hv,
                observed_violation=met_observed_violation,
                new_fidelity=new_fidelity
            )
            self.log_exp_stats_to_wandb(
                iter=iteration,
                exp_stats=stats,
                compute_metrics=compute_metrics,
            )
            self.save_stats_to_json(stats)
            # +===========================================+
            # |              End metrics                  |
            # +===========================================+

            # Save the loop state so the run can be resumed.
            self.update_state(
                is_multifidelity=True,
                get_seed=get_seed,
                budget=budget,
                ref_point=ref_point,
                max_hv=max_hv,
                initial_cost=initial_cost,
                current_cost=current_cost,
                iteration=iteration,
                met_best_nsga2_regret=met_best_nsga2_regret,
                met_best_nsga2_hv=met_best_nsga2_hv,
                acquisition_value=acq_value,
                new_fidelity=new_fidelity,
                train_x=train_x,
                train_obj=train_obj,
                train_constraints=train_constraints,
                objective_indices=objective_indices,
                constraints_indices=constraints_indices,
                model=model,
                fixed_features_list=fixed_features_list,
                target_train_x=target_train_x,
                target_fid_obj=target_fid_obj,
                target_fid_constraints=target_fid_constraints
            )
            self.log_artifacts_to_wandb()

            # === Terminal log ===
            self.term_print(
                show_stats=show_status,
                exp_stats=stats,
                budget=budget
            )
            torch.cuda.empty_cache()
        # +=================================+
        # |          End BO loop            |
        # +=================================+

        if self.wandb_run is not None:
            self.wandb_run.finish()

        # Final Pareto front estimated with NSGA-II on the final GP posterior.
        res_pymoo = nsga2_posterior_pareto(
            model=model,
            input_dim_without_fid=self.bounds.size(-1) - 1,
            num_objectives=len(objective_indices),
            objective_indices=objective_indices,
            is_mf_model=True,
            device=self.tkwargs['device'],
            dtype=self.tkwargs['dtype'],
            project_to_target_fidelity=self.project_to_target,
            constraints_indices=constraints_indices,
        )
        res = {
            "pareto_X": res_pymoo.X,
            "pareto_Y": res_pymoo.F,
            "constraints": res_pymoo.G,
            "pymoo_res": res_pymoo,
            "gp_model": model,
        }
        return res