#!/usr/bin/env python3

from __future__ import annotations

from typing import Callable
import warnings

import torch
from torch import Tensor
import numpy as np
from botorch.models.model import Model
from botorch.utils.multi_objective.pareto import (
    is_non_dominated,
) 
from botorch.utils.multi_objective.hypervolume import Hypervolume
from botorch.utils.transforms import unnormalize
from botorch.test_functions.base import (
    ConstrainedBaseTestProblem,
    MultiObjectiveTestProblem,
)
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
    FastNondominatedPartitioning,
)
from pymoo.algorithms.moo.nsga2 import NSGA2
from pymoo.core.problem import Problem
from pymoo.optimize import minimize
from pymoo.termination.max_gen import MaximumGenerationTermination   
from pymoo.core.result import Result

def nsga2_posterior_pareto(
    model: Model,
    input_dim_without_fid: int,
    num_objectives: int,
    objective_indices: None | list[int],
    is_mf_model: bool,
    device: torch.device,
    dtype: torch.dtype,
    population_size: int = 250,
    max_gen: int = 100,
    project_to_target_fidelity: None | Callable[[Tensor], Tensor] = None,
    constraints_indices: None | list[int] = None,
    seed: None | int = None,
    verbose: bool = False,
) -> Result:
    r"""
    Optimize the posterior mean using NSGA-II.

    The search runs over the unit cube `[0, 1]^input_dim_without_fid`.
    The model's posterior mean is negated before being handed to `pymoo`
    (which minimizes), so the model outputs are treated as quantities to
    maximize. Constraint outputs are passed through unchanged as `pymoo`
    inequality constraints.

    Args:
        model (Model): The trained model.
        input_dim_without_fid (int): The dimensionality of the input
            space without fidelity dimensions.
        num_objectives (int): The number of objectives to optimize.
        objective_indices (None | list[int]): The indices of the model
            outputs to optimize. If `None`, every model output is used as
            an objective, which requires the model to have exactly
            `num_objectives` outputs.
        is_mf_model (bool): Whether the model is a multi-fidelity model.
        device (torch.device): The device to use for computations.
        dtype (torch.dtype): The data type to use for computations.
        population_size (int): The size of the population for NSGA-II.
        max_gen (int): The maximum number of generations for NSGA-II.
        project_to_target_fidelity (None | Callable[[Tensor], Tensor]):
            A function to project inputs to the target fidelity. It is
            applied to every candidate batch before querying the model.
        constraints_indices (None | list[int]): The indices of the
            model outputs to use as inequality constraints.
        seed (None | int): Seed used in `pymoo.optimize.minimize`.
        verbose (bool): Whether to show `pymoo` verbose output.

    Returns:
        res: `class:~pymoo.core.result.Result` The optimization result
            represented as an object.

    Raises:
        ValueError: If `constraints_indices` are specified without
            `objective_indices`.
        ValueError: If `is_mf_model` is True,
            `project_to_target_fidelity` must be specified.
        ValueError: During evaluation, if `objective_indices` is `None`
            and the number of model outputs differs from
            `num_objectives`.
    """
    # Validate the argument combination once, before any pymoo object is built.
    if is_mf_model and project_to_target_fidelity is None:
        raise ValueError(
            "If `is_mf_model` is True, "
            "`project_to_target_fidelity` must be specified."
        )
    has_constraints = constraints_indices is not None
    if has_constraints and objective_indices is None:
        raise ValueError(
            "If `constraints_indices` are specified, "
            "`objective_indices` must be specified."
        )

    num_constraints = len(constraints_indices) if has_constraints else 0
    tkwargs = {
        "dtype": dtype,
        "device": device,
    }

    class PosteriorMeanPymooProblem(Problem):
        """`pymoo` problem whose objectives are the negated posterior mean."""

        def __init__(self):
            super().__init__(
                n_var=input_dim_without_fid,
                n_obj=num_objectives,
                n_ieq_constr=num_constraints,
                type_var=np.double,
            )
            # Search over the unit cube.
            self.xl = np.zeros(input_dim_without_fid)
            self.xu = np.ones(input_dim_without_fid)

        def _evaluate(self, x, out, *args, **kwargs):
            X = torch.from_numpy(x).to(**tkwargs)
            if is_mf_model:
                X = project_to_target_fidelity(X)
            with torch.no_grad():
                # eval in batch mode: one q=1 batch per candidate
                y = model.posterior(X.unsqueeze(-2)).mean.squeeze(-2)
            if objective_indices is None:
                # Without explicit indices the outputs must map one-to-one
                # onto the objectives. Indexing with `None` would insert a
                # new axis instead of selecting columns.
                if y.shape[-1] != num_objectives:
                    raise ValueError("Objective indices must be specified.")
                objectives = y
            else:
                objectives = y[:, objective_indices]
            # pymoo minimizes, so negate the values to be maximized.
            out["F"] = -objectives.cpu().numpy()
            if has_constraints:
                out["G"] = y[:, constraints_indices].cpu().numpy()

    pymoo_problem = PosteriorMeanPymooProblem()
    algorithm = NSGA2(
        pop_size=population_size,
        eliminate_duplicates=True,
    )
    return minimize(
        pymoo_problem,
        algorithm,
        termination=MaximumGenerationTermination(max_gen),
        seed=seed,
        verbose=verbose,
    )


class MultiObjectiveOptimizationPref:
    r"""
    Performance metrics (hypervolume and log10 hypervolume regret) for
    multi-objective optimization runs, with optional constraints.
    """

    def __init__(
        self,
        ref_point: Tensor,
        device: torch.device,
        dtype: torch.dtype,
        max_hv: None | float | Tensor = None,
    ) -> None:
        r"""
        Measures the performance of a multi-objective optimization method
        in terms of hypervolume and log_regret.

        Args:
            ref_point (Tensor): The reference point for hypervolume
                calculation.
            device (torch.device): The device used for tensors created
                by this object.
            dtype (torch.dtype): The dtype used for tensors created by
                this object.
            max_hv (None | float | Tensor): The maximum hypervolume
                observed. NOTE: This could be approximated using NSGA-II.
                When `None`, every regret is reported as `None`.
        """

        # TODO: Make it possible to compute regret when `max_hv`
        # is not known.

        self.ref_point = ref_point
        self.max_hv = max_hv

        self.tkwargs = {
            "dtype": dtype,
            "device": device,
        }
        self.hv = Hypervolume(ref_point=self.ref_point)
        self.has_max_hv = self.max_hv is not None

    def _log_regret(self, volume: float) -> float | None:
        r"""
        Return `log10(max_hv - volume)` as a plain float.

        Args:
            volume (float): The achieved hypervolume.

        Returns:
            float | None: `None` when `max_hv` is unknown. Otherwise the
                base-10 log of the hypervolume gap. A non-positive gap
                yields `-inf`/`nan` (with a NumPy `RuntimeWarning`).
        """
        if not self.has_max_hv:
            return None
        # `float(...)` accepts both Python numbers and one-element tensors
        # (on any device), so the result is always a plain float.
        gap = float(self.max_hv) - float(volume)
        return float(np.log10(gap))

    def compute_nsga2_posterior_hv_regret(
        self,
        input_dim_without_fid: int,
        num_objectives: int,
        model: Model,
        problem: MultiObjectiveTestProblem | ConstrainedBaseTestProblem,
        bounds: Tensor,
        is_mf_model: bool,
        population_size: int = 250,
        max_gen: int = 100,
        verbose: bool = False,
        seed: None | int = None,
        objective_indices: None | list[int] = None,
        constraints_indices: None | list[int] = None,
        project_to_target_fidelity: None | Callable[[Tensor], Tensor] = None,
    ) -> tuple[float, float | None, int | None]:
        r"""
        Compute the NSGA-II posterior hypervolume regret.

        NSGA-II is run on the model's posterior mean, the Pareto set under
        the model is selected, and the true `problem` is evaluated at those
        designs to obtain the hypervolume.

        Args:
            input_dim_without_fid (int): The dimensionality of the input
                space without fidelity dimensions.
            num_objectives (int): The number of objectives.
            model (Model): Trained model.
            problem (MultiObjectiveTestProblem |
                ConstrainedBaseTestProblem): The optimization problem to
                solve. NOTE: Must be in `botorch` `test_functions` format.
            bounds (Tensor): The bounds of input, used to unnormalize the
                NSGA-II solutions before evaluating `problem`.
            is_mf_model (bool): Whether the model is a multi-fidelity
                model.
            population_size (int): The size of the population for NSGA-II.
            max_gen (int): The maximum number of generations for NSGA-II.
            verbose (bool): Whether to show `pymoo` verbose output.
            seed (None | int): Seed used in `pymoo.optimize.minimize`.
            objective_indices (None | list): The indices of the objective
                functions among the model outputs. If `None`, all model
                outputs are treated as objectives.
            constraints_indices (None | list): The indices of the constraint
                functions among the model outputs.
            project_to_target_fidelity (None | Callable[[Tensor], Tensor]):
                A function to project X onto the target set of fidelities.
                NOTE: similar to
                `botorch.acquisition.utils.project_to_target_fidelity`

        Returns:
            tuple[float, float | None, int | None]: The (hypervolume,
                log_regret, violation).
                - When `max_hv` is `None`, regret is `None`.
                - `violation` is `None` without constraints. With
                    constraints it is `0` when a truly feasible point was
                    found and `1` otherwise.
                - When no (feasible) solution is found, hypervolume is
                    `0.0` and regret is `log10(max_hv)`.

        Raises:
            ValueError: If `constraints_indices` are specified without
                `objective_indices`.
            ValueError: If `is_mf_model` is True without
                `project_to_target_fidelity`.
        """
        if constraints_indices is not None and objective_indices is None:
            raise ValueError(
                "If `constraints_indices` are specified, "
                "`objective_indices` must be specified."
            )
        if is_mf_model and project_to_target_fidelity is None:
            raise ValueError(
                "If `is_mf_model` is True, "
                "`project_to_target_fidelity` must be specified."
            )
        has_constraints = constraints_indices is not None

        res = nsga2_posterior_pareto(
            model=model,
            input_dim_without_fid=input_dim_without_fid,
            num_objectives=num_objectives,
            is_mf_model=is_mf_model,
            project_to_target_fidelity=project_to_target_fidelity,
            objective_indices=objective_indices,
            constraints_indices=constraints_indices,
            population_size=population_size,
            max_gen=max_gen,
            seed=seed,
            verbose=verbose,
            **self.tkwargs,
        )
        if res.X is None:
            # No solution found
            warnings.warn("NSGA-II did not find a solution.", RuntimeWarning)
            violation = 1 if has_constraints else None
            return 0.0, self._log_regret(0.0), violation

        # `atleast_2d` keeps a single returned solution as a 1 x d batch.
        X = torch.tensor(np.atleast_2d(res.X), **self.tkwargs)
        if is_mf_model:
            X = project_to_target_fidelity(X)
        # determine Pareto set of designs under model
        with torch.no_grad():
            preds = model.posterior(X.unsqueeze(-2)).mean.squeeze(-2)
        # NOTE(review): for multi-fidelity models X now carries whatever
        # columns `project_to_target_fidelity` added, so `bounds` presumably
        # has to cover them as well -- confirm against callers.
        X = unnormalize(X, bounds)
        # Only the objective columns may take part in the dominance check.
        if objective_indices is None:
            preds_obj = preds
        else:
            preds_obj = preds[:, objective_indices]

        if not has_constraints:
            pareto_mask = is_non_dominated(preds_obj)
            Y = problem(X[pareto_mask])
            partitioning = FastNondominatedPartitioning(
                ref_point=self.ref_point, Y=Y
            )
            volume = partitioning.compute_hypervolume().item()
            return volume, self._log_regret(volume), None

        violation = 0
        volume = 0.0
        preds_constr = preds[:, constraints_indices]
        # Filter feasible points according to model's predicted constraints
        is_feas_pred = (preds_constr <= 0).all(dim=-1)
        feas_pred_obj = preds_obj[is_feas_pred]
        feas_X = X[is_feas_pred]

        if feas_pred_obj.shape[0] == 0:
            # No feasible points predicted by model
            warnings.warn(
                "No feasible points predicted by the model.",
                RuntimeWarning,
            )
            violation = 1
        else:
            # Find Pareto front among feasible predicted points
            pareto_X = feas_X[is_non_dominated(feas_pred_obj)]

            # Evaluate the true objectives and constraints at the Pareto
            # points. The negated slack is compared against zero, i.e. a
            # non-negative slack counts as feasible.
            true_obj = problem(pareto_X)
            true_constr = -problem.evaluate_slack(pareto_X)
            is_feas_true = (true_constr <= 0).all(dim=-1)
            feas_true_obj = true_obj[is_feas_true]

            if feas_true_obj.shape[0] > 0:
                # Find final Pareto front among truly feasible points
                pareto_y = feas_true_obj[is_non_dominated(feas_true_obj)]
                volume = self.hv.compute(pareto_y)
            else:
                # No true feasible points found
                warnings.warn(
                    "No true feasible points found among predicted feasible points.",
                    RuntimeWarning,
                )
                violation = 1
        return volume, self._log_regret(volume), violation

    def computed_observed_hv_regret(
        self,
        train_objectives: Tensor,
        train_constraints: None | Tensor = None,
    ) -> tuple[float, float | None, int | None]:
        r"""
        Compute the observed hypervolume and log_regret based on the
        training data. This function evaluates an acquisition function
        performance.

        Args:
            train_objectives (Tensor): The training objective values.
            train_constraints (None | Tensor): The training constraint
                values. A row is feasible when all of its entries are
                `<= 0`.

        Returns:
            tuple: A tuple containing the observed (hypervolume,
                log_regret, violation).
                - When `max_hv` is `None`, log_regret is `None`.
                - `violation` is `None` without constraints. With
                    constraints it is `0` when a feasible training point
                    exists and `1` otherwise (hypervolume is then `0.0`).
        """
        if train_constraints is None:
            partitioning = FastNondominatedPartitioning(
                ref_point=self.ref_point,
                Y=train_objectives,
            )
            observed_volume = partitioning.compute_hypervolume().item()
            return observed_volume, self._log_regret(observed_volume), None

        violation = 0
        is_feas = (train_constraints <= 0).all(dim=-1)
        feas_train_obj = train_objectives[is_feas]
        if feas_train_obj.shape[0] > 0:
            pareto_y = feas_train_obj[is_non_dominated(feas_train_obj)]
            # compute hypervolume with the instance created in `__init__`
            observed_volume = self.hv.compute(pareto_y)
        else:
            warnings.warn(
                "No feasible points found in training data for observed hypervolume computation.",
                RuntimeWarning,
            )
            violation = 1
            observed_volume = 0.0
        return observed_volume, self._log_regret(observed_volume), violation

    # Alias matching the naming of `compute_nsga2_posterior_hv_regret`;
    # the original (misspelled) name is kept for existing callers.
    compute_observed_hv_regret = computed_observed_hv_regret