#! /usr/bin/env python3


from __future__ import annotations

from typing import Callable

import time

import numpy as np
import random
import torch
from torch import Tensor
from pymoo.core.problem import Problem
from pymoo.algorithms.moo.nsga2 import NSGA2
from pymoo.optimize import minimize
from pymoo.termination.max_gen import MaximumGenerationTermination   
from botorch.test_functions.multi_objective import (
    MultiObjectiveTestProblem,
    ConstrainedBaseTestProblem
)
from botorch.acquisition.utils import project_to_target_fidelity
from botorch.utils.transforms import unnormalize

from rescue.metrics.optimization_pref import MultiObjectiveOptimizationPref

from rich import print

def find_true_pareto_nsga2(
    problem: MultiObjectiveTestProblem | ConstrainedBaseTestProblem,
    input_dim_without_fid: int,
    num_objectives: int,
    ref_point: torch.Tensor,
    is_mf_problem: bool,
    max_hv: float | None = None,
    population_size: int = 250,
    max_gen: int = 100,
    num_constraints: None | int = None,
    seed: None | int = None,
    verbose: bool = False,
    project_to_target: None | Callable[[Tensor], Tensor] = None,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None
) -> tuple[float, float | None]:
    r"""
    Approximate the true Pareto front with NSGA-II and score it.

    NSGA-II (pymoo) searches the unit cube ``[0, 1]^input_dim_without_fid``;
    every candidate is unnormalized to the problem bounds (and, for
    multi-fidelity problems, projected to the target fidelity) before the
    problem is evaluated. The final population returned by pymoo is
    re-evaluated and passed to ``MultiObjectiveOptimizationPref`` to
    obtain the hypervolume and regret.

    Args:
        problem (MultiObjectiveTestProblem |
            ConstrainedBaseTestProblem): The optimization problem.
        input_dim_without_fid (int): Input dimensionality without
            fidelity.
        num_objectives (int): Number of objectives.
        ref_point (torch.Tensor): Reference point for hypervolume.
        is_mf_problem (bool): Whether problem is multi-fidelity. When
            True the last column of ``problem.bounds`` is dropped.
        max_hv (float | None): Maximum hypervolume for regret
            computation.
        population_size (int): NSGA-II population size.
        max_gen (int): Maximum number of generations.
        num_constraints (None | int): Number of constraints. ``None``
            means the problem is treated as unconstrained.
        seed (None | int): Random seed forwarded to pymoo.
        verbose (bool): Whether to show verbose output.
        project_to_target (None | Callable[[Tensor], Tensor]):
            Function to project to target fidelity. Only used when
            ``is_mf_problem`` is True; if None, a default projection is
            built from ``project_to_target_fidelity`` with
            ``target_fidelities={problem.dim - 1: 1.0}``.
        device (torch.device | None): Compute device (default: CPU).
        dtype (torch.dtype | None): Data type (default: torch.double).

    Returns:
        tuple[float, float | None]: Hypervolume and regret, i.e. the
            first two values returned by
            ``MultiObjectiveOptimizationPref.computed_observed_hv_regret``.

    Raises:
        RuntimeError: If NSGA-II returns no solutions (``res.X`` is
            None).

    Note:
        Assumes problem(negate=True) and -problem.evaluate_slack(X).
        When is_mf_problem is True and project_to_target is None,
        assumes last dim of the input is the fidelity dimension and
        target fidelity is 1.0.
    """
    if is_mf_problem:
        if project_to_target is None:
            # Build a *callable* default projection. It is invoked later on
            # every candidate batch, so it must not be evaluated eagerly here
            # (no candidate tensor exists yet at this point).
            def _default_projection(X: Tensor) -> Tensor:
                return project_to_target_fidelity(
                    X=X,
                    d=problem.dim,
                    target_fidelities={problem.dim - 1: 1.0},
                )

            project_to_target = _default_projection
        # Drop the fidelity column: NSGA-II only searches the design space.
        bounds = problem.bounds[:, :-1]
    else:
        bounds = problem.bounds
    has_constraints = num_constraints is not None
    tkwargs = {
        "dtype": dtype if dtype is not None else torch.double,
        "device": device if device is not None else torch.device("cpu"),
    }

    class PymooProblem(Problem):
        """pymoo adapter that evaluates ``problem`` on unit-cube inputs."""

        def __init__(self):
            super().__init__(
                n_var=input_dim_without_fid,
                n_obj=num_objectives,
                n_ieq_constr=num_constraints if has_constraints else 0,
                type_var=np.double,
            )
            # Search in the unit cube; inputs are mapped to the real bounds
            # inside `_evaluate`.
            self.xl = np.zeros(input_dim_without_fid)
            self.xu = np.ones(input_dim_without_fid)

        def _evaluate(self, x, out, *args, **kwargs):
            # This is important to avoid bound validation errors: pymoo works
            # on [0, 1] and the problem receives unnormalized inputs.
            X = unnormalize(
                torch.from_numpy(x).to(**tkwargs), bounds
            )
            if is_mf_problem:
                X = project_to_target(X)
            # pymoo minimizes, so the objectives are sign-flipped here (see
            # the Note in the enclosing docstring).
            out["F"] = -problem(X).detach().cpu().numpy()
            if has_constraints:
                # Sign-flipped slack is handed to pymoo as the inequality
                # constraint values.
                out["G"] = (
                    -problem.evaluate_slack(X).detach().cpu().numpy()
                )

    pymoo_problem = PymooProblem()
    algorithm = NSGA2(
        pop_size=population_size,
        eliminate_duplicates=True,
    )
    res = minimize(
        pymoo_problem,
        algorithm,
        termination=MaximumGenerationTermination(max_gen),
        seed=seed,
        verbose=verbose,
    )
    if res.X is None:
        # Fail with an explicit message instead of an obscure tensor
        # construction error further down.
        raise RuntimeError(
            "NSGA-II did not return any solution (res.X is None); "
            "cannot compute the hypervolume."
        )
    X = torch.tensor(
        res.X,
        **tkwargs,
    )
    X = unnormalize(X, bounds)
    if is_mf_problem:
        X = project_to_target(X)
    # Re-evaluate the final solutions with the problem's own sign convention.
    true_obj = problem(X)
    true_constraints = None
    if has_constraints:
        true_constraints = -problem.evaluate_slack(X)
    pref = MultiObjectiveOptimizationPref(
        ref_point=ref_point,
        max_hv=max_hv,
        **tkwargs
    )
    volume, regret, _ = pref.computed_observed_hv_regret(
        train_objectives=true_obj,
        train_constraints=true_constraints
    )
    return volume, regret


def get_best_hypervolume(
    problem: MultiObjectiveTestProblem | ConstrainedBaseTestProblem,
    is_mf_problem: bool,
    population_size: int = 250,
    max_gen: int = 100,
    seeds: list[int] | None = None,
) -> list[tuple[int, float]]:
    r"""
    Finds the best hypervolume for the given problem.
    Runs NSGA-II (via ``find_true_pareto_nsga2``) once per seed, printing
    progress, timing and the best hypervolume seen so far.

    Args:
        problem (MultiObjectiveTestProblem |
            ConstrainedBaseTestProblem): The multi-objective test
            problem.
        is_mf_problem (bool): Whether the problem is multi-fidelity.
        population_size (int): NSGA-II population size.
        max_gen (int): Maximum number of generations.
        seeds (list[int] | None): List of seeds to use. If None, 100
            distinct seeds are sampled from ``range(1000)``.

    Returns:
        list[tuple[int, float]]: List of (seed, hypervolume) tuples, in
            the order the seeds were run. Empty if ``seeds`` is empty.

    Raises:
        ValueError: If ``problem`` is not a MultiObjectiveTestProblem.

    Note:
        Assumes fidelity dimension is the last dimension and
        target fidelity is 1.0.
        Tensors are created on CUDA when available, otherwise on CPU.
    """
    # Validate before touching any problem attribute so that a wrong type
    # produces the intended ValueError rather than an AttributeError.
    if not isinstance(problem, MultiObjectiveTestProblem):
        raise ValueError("Problem must be a MultiObjectiveTestProblem.")
    num_objectives = problem.num_objectives
    if isinstance(problem, ConstrainedBaseTestProblem):
        num_constraints = problem.num_constraints
    else:
        num_constraints = None
    print("="*40)
    print("Problem:", problem.__class__.__name__)
    print("Input dim", problem.dim)
    print("Num objectives:", num_objectives)
    print("Num constraints:", num_constraints if num_constraints is not None else 0)
    print("Is multi-fidelity problem:", is_mf_problem)
    print("Population size:", population_size)
    print("Max generations:", max_gen)
    print("="*40)
    tkwargs = {
        "dtype": torch.double,
        "device": torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        ),
    }
    input_dim_without_fid = problem.dim
    project = None
    if is_mf_problem:
        # The last input dimension is the fidelity; NSGA-II does not search
        # over it.
        input_dim_without_fid = problem.dim - 1

        def _project_to_target(X: Tensor) -> Tensor:
            return project_to_target_fidelity(
                X=X,
                d=problem.dim,
                target_fidelities={problem.dim - 1: 1.0},
            )

        project = _project_to_target
    if seeds is None:
        seeds = random.sample(range(1000), 100)

    num_seeds = len(seeds)
    hypervolumes = []
    best_hv = -float("inf")
    # Stays None until some run beats `best_hv`, so the progress line never
    # reads an unbound name (e.g. when a hypervolume is NaN).
    best_seed = None
    start_time = time.time()
    iteration_times = []

    print(f"Finding best hypervolume over {num_seeds} seeds...")
    for itr, seed in enumerate(seeds, start=1):
        iter_start = time.time()

        hv, _ = find_true_pareto_nsga2(
            problem=problem,
            is_mf_problem=is_mf_problem,
            project_to_target=project,
            input_dim_without_fid=input_dim_without_fid,
            num_objectives=num_objectives,
            num_constraints=num_constraints,
            ref_point=problem.ref_point,
            population_size=population_size,
            max_gen=max_gen,
            seed=seed,
            **tkwargs
        )

        iter_time = time.time() - iter_start
        iteration_times.append(iter_time)
        avg_time = sum(iteration_times) / len(iteration_times)
        # Remaining-time estimate from the running mean iteration time.
        eta = avg_time * (num_seeds - itr)

        hypervolumes.append(hv)
        if hv > best_hv:
            best_hv = hv
            best_seed = seed

        print(
            f"[{itr}/{num_seeds}] Seed: {seed}, HV: {hv},"
            f" Best HV so far: {best_hv} (seed: {best_seed})"
            f" | Iter time: {iter_time:.2f}s, Avg: {avg_time:.2f}s, ETA: {eta/60:.1f}m"
        )

    total_time = time.time() - start_time
    if hypervolumes:
        # Index of the first maximal hypervolume (same choice as
        # `list.index(max(...))`).
        top = max(range(len(hypervolumes)), key=hypervolumes.__getitem__)
        print(
            f"Max HV over {num_seeds} seeds: {hypervolumes[top]} "
            f"(seed: {seeds[top]})"
            f" | Total time: {total_time/60:.1f}m"
        )
    return list(zip(seeds, hypervolumes))


if __name__ == "__main__":
    from botorch.test_functions.multi_objective_multi_fidelity import MOMFPark
    from rescue.problems.hpo.multi_objective import HPOXGBoost

    # Run in double precision, on the GPU when torch can see one.
    tkwargs = dict(
        dtype=torch.double,
        device=torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        ),
    )
    # PARK = MOMFPark(negate=True).to(**tkwargs)
    # get_best_hypervolume(problem=PARK, is_mf_problem=True)

    # Reference hypervolume search for the XGBoost HPO benchmark, treated
    # as a multi-fidelity problem.
    HPO = HPOXGBoost(negate=True).to(**tkwargs)
    get_best_hypervolume(problem=HPO, is_mf_problem=True)