#!/usr/bin/env python3

r"""
Single-fidelity multi-objective Bayesian optimization baseline based on
qLogExpectedHypervolumeImprovement, adapted from the BoTorch tutorial:
https://botorch.org/docs/v0.15.0/tutorials/multi_objective_bo/

.. [Daulton2020]
    Daulton, S., Balandat, M., & Bakshy, E. (2020).
    Differentiable expected hypervolume improvement
    for parallel multi-objective Bayesian optimization.
    Advances in Neural Information Processing Systems,
    33, 9851-9864.

"""

from __future__ import annotations

from typing import Callable

import torch
from torch import Tensor
from wandb.sdk.wandb_run import Run

from botorch.models.gp_regression import SingleTaskGP
from botorch.utils.transforms import normalize
from botorch.test_functions.base import (
    ConstrainedBaseTestProblem,
    MultiObjectiveTestProblem,
)
from botorch.acquisition.multi_objective import (
    qLogExpectedHypervolumeImprovement,
)
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
    FastNondominatedPartitioning,
)
from botorch.sampling.normal import SobolQMCNormalSampler
from botorch.optim.optimize import optimize_acqf
from botorch.acquisition.multi_objective.objective import IdentityMCMultiOutputObjective
from botorch.utils.transforms import unnormalize
from botorch import fit_gpytorch_mll

from rescue.metrics.optimization_pref import (
    MultiObjectiveOptimizationPref,
    nsga2_posterior_pareto
)
from rescue.utils.utils import status

from baselines.base import BaseBaseline, GPFn, GenInitDataFn


class BaselineSFMO(BaseBaseline):
    r"""Single-fidelity multi-objective Bayesian optimization baseline.

    Each iteration:

    1. Maximizes qLogExpectedHypervolumeImprovement on the current
       surrogate to propose ``q`` candidates (in normalized space, then
       mapped back to ``self.bounds``).
    2. Evaluates the candidates on the true problem.
    3. Refits the surrogate on normalized inputs.
    4. Optionally tracks hypervolume / regret metrics.

    When the problem has constraints, the constraint outputs are modeled
    after all objective outputs in the surrogate.
    """

    def __init__(
        self,
        problem: MultiObjectiveTestProblem | ConstrainedBaseTestProblem,
        bounds: Tensor,
        has_constraints: bool,
        cost_fn: Callable[[Tensor], Tensor] | None = None,
        gen_initial_data: GenInitDataFn | None = None,
        status_spinner: bool = True,
        custom_model: None | GPFn = None,
        device: None | torch.device = None,
        dtype: None | torch.dtype = None,
        wandb_run: None | Run = None,
        algorithm_state: None | dict = None
    ):
        r"""
        Initialize the baseline algorithm.

        All arguments are forwarded unchanged to ``BaseBaseline``.

        Args:
            problem (MultiObjectiveTestProblem |
                ConstrainedBaseTestProblem): The optimization problem to
                solve.
            bounds (Tensor): The bounds for the optimization problem.
            has_constraints (bool): Whether the problem has constraints.
            cost_fn (Callable[[Tensor], Tensor] | None): The cost
                function for evaluations.
            gen_initial_data (GenInitDataFn | None): A function to
                generate initial data.
            status_spinner (bool): Whether to show a status spinner
                during optimization.
            custom_model (None | GPFn): A custom model initialization
                function.
            device (None | torch.device): The device to run the
                optimization on.
            dtype (None | torch.dtype): The data type for the tensors.
            wandb_run (None | Run): The Weights & Biases run object for
                logging.
            algorithm_state (None | dict): The state of the algorithm (if
                resuming).
        """
        super().__init__(
            problem=problem,
            cost_fn=cost_fn,
            custom_model=custom_model,
            gen_initial_data=gen_initial_data,
            bounds=bounds,
            has_constraints=has_constraints,
            status_spinner=status_spinner,
            device=device,
            dtype=dtype,
            wandb_run=wandb_run,
            algorithm_state=algorithm_state,
        )

    @status(show_func_name=True)
    def _generate_initial_data(
        self,
        n: float | int | None = None,
    ) -> tuple[Tensor, Tensor, Tensor | None]:
        r"""
        Generate initial training data (status-wrapped delegate of
        ``generate_initial_data``).

        Args:
            n (float | int | None): Number of initial samples.

        Returns:
            tuple[Tensor, Tensor, Tensor | None]:
                - train_x: Input samples with dim `n x d`.
                - train_obj: Objective values with dim `n x m`.
                - train_constraints: Constraint values with dim `n x c`
                    (None if no constraints).
        """
        return self.generate_initial_data(n=n)

    @status(show_func_name=True)
    def fit_gp(self, mll) -> None:
        r"""Fit the GP model in place via ``fit_gpytorch_mll``.

        Args:
            mll: The marginal likelihood to optimize.
        """
        fit_gpytorch_mll(mll)

    @status(show_func_name=True)
    def optimize_acqufn(
        self,
        model: SingleTaskGP,
        train_obj: torch.Tensor,
        ref_point: torch.Tensor,
        q: int,
        num_restarts: int,
        raw_samples: int,
        sampler: SobolQMCNormalSampler,
        objective_indices: None | list[int] = None,
        constraints_indices: None | list[int] = None,
        options: None | dict[str, bool | float | int | str] = None
    ) -> tuple[Tensor, Tensor]:
        r"""
        Optimize qLogExpectedHypervolumeImprovement and return candidates.

        Args:
            model (SingleTaskGP): The GP model (fit on normalized inputs).
            train_obj (torch.Tensor): Observed objectives, used to build
                the non-dominated partitioning.
            ref_point (torch.Tensor): Reference point for hypervolume.
            q (int): Batch size.
            num_restarts (int): Number of optimization restarts.
            raw_samples (int): Number of raw samples used for the
                initialization heuristic.
            sampler (SobolQMCNormalSampler): MC sampler.
            objective_indices (None | list[int]): Model outputs treated
                as objectives.
            constraints_indices (None | list[int]): Constraint indices;
                only their count is used here.
            options (None | dict[str, bool | float | int | str]):
                Options forwarded to ``optimize_acqf``.

        Returns:
            tuple[Tensor, Tensor]: New candidates, unnormalized to
                ``self.bounds``, and the acquisition values returned by
                ``optimize_acqf``.

        Note:
            Last `c` model outputs correspond to constraints.
        """
        partitioning = FastNondominatedPartitioning(
            ref_point=ref_point, Y=train_obj
        )

        constraints = None
        if self.has_constraints:
            # Constraints are in the model after all objectives:
            # [obj0, obj1, ..., objN, cons0, cons1, ...].
            # ``idx`` is bound as a default argument so each lambda keeps its
            # own index (avoids the late-binding closure pitfall).
            # NOTE(review): BoTorch treats non-positive constraint values as
            # feasible; confirm stored constraint values use that sign.
            num_objectives = len(objective_indices)
            constraints = [
                lambda Z, idx=num_objectives + j: Z[..., idx]
                for j in range(len(constraints_indices))
            ]

        acq_func = qLogExpectedHypervolumeImprovement(
            model=model,
            ref_point=ref_point,  # use known reference point
            partitioning=partitioning,
            sampler=sampler,
            objective=IdentityMCMultiOutputObjective(
                outcomes=objective_indices
            ),
            constraints=constraints,
        )
        # Optimize in the normalized (standard) space, then map back.
        candidates, vals = optimize_acqf(
            acq_function=acq_func,
            bounds=self.standard_bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,  # used for initialization heuristic
            sequential=True,
            options=options
        )
        new_x = unnormalize(candidates.detach(), bounds=self.bounds)
        return new_x, vals

    def run(
        self,
        budget: int | float,
        ref_point: Tensor,
        init_budget: int | None = None,
        include_initcost_to_budget: bool = True,
        q: int = 1,
        acqfn_raw_samples: int = 512,
        acqfn_mc_samples: int = 128,
        acqfn_num_restarts: int = 10,
        optim_acqfn_options: None | dict[str, bool | float | int | str] = None,
        objective_indices: None | list[int] = None,
        constraints_indices: None | list = None,
        compute_metrics: bool = True,
        max_hv: float | Tensor | None = None,
        seed: int | None = None,
        show_status: bool = True
    ) -> dict:
        r"""
        Run the optimization algorithm until the budget is exhausted.

        When ``self.algorithm_state`` is set, the run resumes from the saved
        state. In that case the saved ``ref_point``, ``max_hv``,
        ``objective_indices`` and ``constraints_indices`` replace the
        arguments. Otherwise a fresh initial design is drawn.

        Args:
            budget (int | float): Total optimization budget.
            ref_point (Tensor): Reference point for hypervolume.
            init_budget (int | None): Initial sampling budget.
            include_initcost_to_budget (bool): Whether to include
                initial cost in budget.
            q (int): Batch size.
            acqfn_raw_samples (int): Raw samples for acquisition.
            acqfn_mc_samples (int): MC samples for acquisition.
            acqfn_num_restarts (int): Optimization restarts.
            optim_acqfn_options (None | dict[str, bool | float | int |
                str]): Acquisition optimization options.
            objective_indices (None | list[int]): Objective indices.
            constraints_indices (None | list): Constraint indices.
            compute_metrics (bool): Whether to compute metrics.
            max_hv (float | Tensor | None): Maximum hypervolume.
            seed (int | None): Random seed.
            show_status (bool): Whether to show status messages.

        Returns:
            dict: ``pareto_X``, ``pareto_Y`` and ``constraints`` from
                NSGA-II run on the final posterior, plus the raw pymoo
                result (``pymoo_res``) and the final ``gp_model``.
        """
        self.validate_run_inputs(
            compute_metrics, objective_indices, constraints_indices
        )

        def _get_total_cost(x: Tensor) -> float | int:
            # Cost of evaluating every row of ``x``. With a cost_fn this is
            # the summed cost; otherwise each point costs one unit. Summing
            # keeps this valid for q > 1 batches.
            if self.cost_fn is not None:
                return self.cost_fn(x).sum().item()
            return x.shape[0]

        # Per-iteration metrics are passed to ``exp_stats`` every iteration,
        # even with compute_metrics=False. Bind them for both the fresh and
        # the resumed path.
        met_curr_nsga2_regret = None
        met_curr_nsga2_hv = None
        met_curr_nsga2_violation = None
        met_observed_regret = None
        met_observed_hv = None
        met_observed_violation = None

        if self.algorithm_state is not None:
            (
                iteration,
                get_seed,
                initial_cost,
                current_cost,
                ref_point,
                max_hv,
                met_best_nsga2_regret,
                met_best_nsga2_hv,
                train_x,
                train_obj,
                train_constraints,
                objective_indices,
                constraints_indices,
            ) = self.load_state()

            self.validated_budget(budget, current_cost)

            # Rebuild the surrogate from the saved model state (not refit).
            mll, model = self.initialize_model(
                train_x=normalize(train_x, self.bounds),
                train_obj=train_obj,
                train_constraints=train_constraints,
                state_dict=self.state['model'],
            )

            if compute_metrics:
                # The saved run may have been started without metrics, which
                # leaves these as None; comparing against None would raise.
                if met_best_nsga2_regret is None:
                    met_best_nsga2_regret = float("inf")
                if met_best_nsga2_hv is None:
                    met_best_nsga2_hv = float("-inf")

        else:
            met_best_nsga2_regret = None
            met_best_nsga2_hv = None
            if compute_metrics:
                met_best_nsga2_regret = float("inf")
                met_best_nsga2_hv = float("-inf")
            iteration = 0
            current_cost = 0

            if seed is not None:
                torch.manual_seed(seed)
            get_seed = torch.initial_seed()

            # Draw the initial design.
            train_x, train_obj, train_constraints = (
                self._generate_initial_data(init_budget)
            )

            initial_cost = _get_total_cost(train_x)
            if show_status:
                print("Initial cumulative cost:", initial_cost)
            # Only charge the initial design against the budget if requested.
            if include_initcost_to_budget:
                current_cost = initial_cost
            self.validated_budget(budget, current_cost)

            mll, model = self.initialize_model(
                train_x=normalize(train_x, self.bounds),
                train_obj=train_obj,
                train_constraints=train_constraints,
            )
            self.fit_gp(mll=mll)

            self.update_state(
                is_multifidelity=False,
                get_seed=get_seed,
                budget=budget,
                ref_point=ref_point,
                max_hv=max_hv,
                initial_cost=initial_cost,
                current_cost=current_cost,
                iteration=iteration,
                met_best_nsga2_regret=met_best_nsga2_regret,
                met_best_nsga2_hv=met_best_nsga2_hv,
                train_x=train_x,
                train_obj=train_obj,
                train_constraints=train_constraints,
                objective_indices=objective_indices,
                constraints_indices=constraints_indices,
                model=model,
                acquisition_value=None,
                new_fidelity=None,
            )
            self.log_artifacts_to_wandb()

        # Computed after a possible resume, so it reflects the ``max_hv``
        # actually used for the metrics.
        has_maxhv = max_hv is not None
        if compute_metrics:
            optimization_pref = MultiObjectiveOptimizationPref(
                ref_point=ref_point,
                max_hv=max_hv,
                **self.tkwargs
            )

        # +=================================+
        # |            BO loop              |
        # +=================================+
        sampler = SobolQMCNormalSampler(
            sample_shape=torch.Size([acqfn_mc_samples])
        )
        while current_cost < budget:
            # Optimize the acquisition function to get new candidates.
            new_x, acq_value = self.optimize_acqufn(
                model=model,
                train_obj=train_obj,
                ref_point=ref_point,
                q=q,
                raw_samples=acqfn_raw_samples,
                num_restarts=acqfn_num_restarts,
                objective_indices=objective_indices,
                constraints_indices=constraints_indices,
                sampler=sampler,
                options=optim_acqfn_options
            )

            # Evaluate the candidates on the true problem.
            new_obj, new_constraints = self.obtain_new_y(new_x)
            self.validate_new_constraints(new_constraints)

            train_x = torch.cat([train_x, new_x], dim=0)
            train_obj = torch.cat([train_obj, new_obj], dim=0)
            if new_constraints is not None:
                train_constraints = torch.cat(
                    [train_constraints, new_constraints], dim=0
                )

            # Rebuild and refit the surrogate on all data.
            mll, model = self.initialize_model(
                train_x=normalize(train_x, self.bounds),
                train_obj=train_obj,
                train_constraints=train_constraints,
            )
            self.fit_gp(mll=mll)

            iteration += 1
            current_cost += _get_total_cost(new_x)

            if compute_metrics:
                # HV and regret from NSGA-II on the posterior, plus the
                # same metrics on the observed data.
                (met_curr_nsga2_hv, met_curr_nsga2_regret,
                    met_curr_nsga2_violation, met_observed_hv,
                    met_observed_regret, met_observed_violation) = (
                        self.get_metrics(
                            is_mf_model=False,
                            input_dim_without_fid=self.bounds.size(-1),
                            compute_metrics=compute_metrics,
                            model=model,
                            train_objectives=train_obj,
                            train_constraints=train_constraints,
                            optimization_pref=optimization_pref,
                            objective_indices=objective_indices,
                            constraints_indices=constraints_indices,
                            seed=get_seed
                        )
                    )
                # Regret is only meaningful when the maximum HV is known.
                if (
                    has_maxhv
                    and met_curr_nsga2_regret is not None
                    and met_curr_nsga2_regret < met_best_nsga2_regret
                ):
                    met_best_nsga2_regret = met_curr_nsga2_regret
                if (
                    met_curr_nsga2_hv is not None
                    and met_curr_nsga2_hv > met_best_nsga2_hv
                ):
                    met_best_nsga2_hv = met_curr_nsga2_hv

            stats = self.exp_stats(
                budget=budget,
                initial_cost=initial_cost,
                seed=get_seed,
                compute_metrics=compute_metrics,
                has_maxhv=has_maxhv,
                iteration=iteration,
                current_cost=current_cost,
                acq_value=acq_value,
                best_nsga2_regret=met_best_nsga2_regret,
                curr_nsga2_regret=met_curr_nsga2_regret,
                curr_nsga2_violation=met_curr_nsga2_violation,
                observed_regret=met_observed_regret,
                best_nsga2_hv=met_best_nsga2_hv,
                curr_nsga2_hv=met_curr_nsga2_hv,
                observed_violation=met_observed_violation,
                observed_hv=met_observed_hv,
            )
            self.log_exp_stats_to_wandb(
                iter=iteration,
                exp_stats=stats,
                compute_metrics=compute_metrics,
            )
            self.save_stats_to_json(stats)

            # Persist state so the run can be resumed.
            self.update_state(
                is_multifidelity=False,
                get_seed=get_seed,
                budget=budget,
                ref_point=ref_point,
                max_hv=max_hv,
                initial_cost=initial_cost,
                current_cost=current_cost,
                iteration=iteration,
                met_best_nsga2_regret=met_best_nsga2_regret,
                met_best_nsga2_hv=met_best_nsga2_hv,
                train_x=train_x,
                train_obj=train_obj,
                train_constraints=train_constraints,
                objective_indices=objective_indices,
                constraints_indices=constraints_indices,
                model=model,
                acquisition_value=acq_value,
                new_fidelity=1.0,  # Single fidelity
            )
            self.log_artifacts_to_wandb()

            # === Terminal log ===
            self.term_print(
                show_stats=show_status,
                exp_stats=stats,
                budget=budget
            )
            torch.cuda.empty_cache()
        # +=================================+
        # |          End BO loop            |
        # +=================================+

        if self.wandb_run is not None:
            self.wandb_run.finish()

        # Final Pareto set estimate from NSGA-II on the last posterior.
        res_pymoo = nsga2_posterior_pareto(
            model=model,
            input_dim_without_fid=self.bounds.size(-1),
            num_objectives=len(objective_indices),
            objective_indices=objective_indices,
            is_mf_model=False,
            device=self.tkwargs['device'],
            dtype=self.tkwargs['dtype'],
            constraints_indices=constraints_indices,
        )
        res = {
            "pareto_X": res_pymoo.X,
            "pareto_Y": res_pymoo.F,
            "constraints": res_pymoo.G,
            "pymoo_res": res_pymoo,
            "gp_model": model,
        }
        return res