#! /usr/bin/env python3

from __future__ import annotations

from typing import Protocol

import torch
from torch import Tensor
from wandb.sdk.wandb_run import Run

from gpytorch.likelihoods import Likelihood
from gpytorch.models import ExactGP

from botorch.test_functions.base import (
    ConstrainedBaseTestProblem,
    MultiObjectiveTestProblem,
)
from botorch.acquisition.utils import project_to_target_fidelity
from botorch.models.model import Model


from rescue.algorithms.state import (
    RescueStateMultifidelity, 
    RescueState
)
from rescue.metrics.optimization_pref import (
    MultiObjectiveOptimizationPref
)
from rescue.metrics.utils import term_print, load_causal_model


class GPFn(Protocol):
    r"""
    Structural type of the callable accepted as ``gp_model`` by
    ``EvaluateOptimPref``.

    Any callable with a matching keyword-only signature satisfies this
    protocol; no inheritance is required.
    """

    def __call__(
        self,
        *,
        train_x: Tensor,
        train_objectives: Tensor,
        train_constraints: Tensor,
        state_dict: dict,
    ) -> tuple[Model, Likelihood]:
        r"""
        Signature of the model callable (all arguments keyword-only).

        Args:
            train_x (Tensor): Training inputs.
            train_objectives (Tensor): Training objective values.
            train_constraints (Tensor): Training constraint values.
            state_dict (dict): State dict forwarded to the callable;
                presumably used to restore model parameters -- confirm
                against the implementations.

        Returns:
            tuple[Model, Likelihood]: A two-element tuple.
                NOTE(review): ``EvaluateOptimPref`` uses the *second*
                element as the model; confirm the element order that
                implementations actually return.
        """

class CausalGPFn(Protocol):
    r"""
    Structural type of the callable accepted as ``causal_gp_model`` by
    ``EvaluateOptimPref``.

    Compared with ``GPFn`` the signature additionally takes a causal
    network, an explicit device / dtype pair and an ``extra`` argument.
    """

    def __call__(
        self,
        *,
        train_x: torch.Tensor,
        train_objectives: torch.Tensor,
        train_constraints: torch.Tensor,
        causal_net: torch.nn.Module,
        state_dict: dict,
        device: torch.device,
        dtype: torch.dtype,
        extra: object,
    ) -> tuple[ExactGP, Likelihood]:
        r"""
        Signature of the causal model callable (all arguments
        keyword-only).

        Args:
            train_x (torch.Tensor): Training inputs.
            train_objectives (torch.Tensor): Training objective values.
            train_constraints (torch.Tensor): Training constraint values.
            causal_net (torch.nn.Module): Causal network module.
            state_dict (dict): State dict forwarded to the callable;
                presumably used to restore model parameters -- confirm
                against the implementations.
            device (torch.device): Device passed to the callable.
            dtype (torch.dtype): Data type passed to the callable.
            extra (object): Additional argument whose meaning is not
                defined in this file. It is declared without a default,
                so it is required by this protocol.

        Returns:
            tuple[ExactGP, Likelihood]: A two-element tuple.
                NOTE(review): ``EvaluateOptimPref`` uses the *second*
                element as the model; confirm the element order that
                implementations actually return.
        """


class EvaluateOptimPref:
    r"""
    Compute and log multi-objective optimization metrics from saved
    optimization states.

    For every state, the configured model callable (``gp_model`` or
    ``causal_gp_model``) is invoked with the training data and state dict
    stored in that state. Hypervolume / regret metrics are then obtained
    from ``MultiObjectiveOptimizationPref`` and optionally logged to a
    Weights & Biases run.
    """

    def __init__(
        self,
        problem: MultiObjectiveTestProblem | ConstrainedBaseTestProblem,
        ref_point: Tensor,
        is_mf_model: bool,
        device: torch.device,
        dtype: torch.dtype,
        gp_model: GPFn | None = None,
        causal_gp_model: CausalGPFn | None = None,
        wandb_run: None | Run = None,
        max_hv: None | float = None,
    ) -> None:
        r"""
        Initialize the optimization preference evaluator.

        Exactly one of ``gp_model`` and ``causal_gp_model`` must be given.

        Args:
            problem (MultiObjectiveTestProblem |
                ConstrainedBaseTestProblem): The optimization problem to
                evaluate.
            ref_point (Tensor): The reference point for hypervolume
                computation.
            is_mf_model (bool): Whether the model is multi-fidelity.
            device (torch.device): The device to use for computations.
            dtype (torch.dtype): The data type to use for computations.
            gp_model (GPFn | None): The GP model function.
                Defaults to None.
            causal_gp_model (CausalGPFn | None): The causal GP model
                function. Defaults to None.
            wandb_run (None | Run): The wandb run for logging.
                Defaults to None.
            max_hv (None | float): The maximum hypervolume for regret
                computation. Defaults to None.

        Raises:
            ValueError: If neither ``gp_model`` nor ``causal_gp_model``
                is provided.
            NotImplementedError: If both ``gp_model`` and
                ``causal_gp_model`` are provided.
        """
        if gp_model is None and causal_gp_model is None:
            raise ValueError(
                "At least one of GP model or Causal GP model must be provided."
            )
        if gp_model is not None and causal_gp_model is not None:
            raise NotImplementedError(
                "Both GP model and Causal GP model are not supported."
            )

        self.gp_model = gp_model
        self.causal_gp_model = causal_gp_model
        self.problem = problem
        self.ref_point = ref_point
        self.is_mf_model = is_mf_model
        self.device = device
        self.dtype = dtype
        self.max_hv = max_hv
        # Regret statistics are only tracked when a maximum HV is known.
        self.has_maxhv = max_hv is not None
        self.wandb_run = wandb_run

    def get_metrics_from_state(
        self,
        state: RescueState | RescueStateMultifidelity,
    ) -> tuple[float, float | None, float, float | None]:
        r"""
        Compute metrics for the current iteration from the state.

        Args:
            state (RescueState | RescueStateMultifidelity):
                The current state of the optimization.

        Returns:
            tuple[float, float | None, float, float | None]:
                A tuple containing the computed metrics:
                - nsga2_hv: Hypervolume computed on the posterior
                    using NSGA-II.
                - nsga2_regret: Regret computed on the posterior
                    using NSGA-II.
                - observed_hv: Observed hypervolume on the training set.
                - observed_regret: Observed regret on the training set.
                The regret entries are annotated as optional; presumably
                they are None when ``max_hv`` was not provided -- confirm
                against ``MultiObjectiveOptimizationPref``.

        Raises:
            ValueError: If the fidelity type of ``state`` does not match
                ``is_mf_model``.
        """
        if state['is_multifidelity'] and not self.is_mf_model:
            raise ValueError(
                "state is multi-fidelity but `is_mf_model=False`! "
                "set `is_mf_model=True` to match the state."
            )
        if not state['is_multifidelity'] and self.is_mf_model:
            raise ValueError(
                "state is single-fidelity but `is_mf_model=True`! "
                "set `is_mf_model=False` to match the state."
            )

        train_x = state['train_x']
        train_obj = state['train_obj']
        train_constraints = state['train_constraints']
        state_dict = state['model']
        bounds = state['bounds']

        objective_indices = state['objective_indices']
        constraints_indices = state['constraints_indices']
        num_objectives = len(objective_indices)

        # For multi-fidelity models the last input column is treated as
        # the fidelity parameter, so the design space has one dimension
        # less than the training inputs.
        input_dim = train_x.size(-1)
        input_dim_without_fid = (
            input_dim - 1 if self.is_mf_model else input_dim
        )

        project = None
        if self.is_mf_model:
            def project(X: Tensor) -> Tensor:
                # Fix the fidelity column (last index) to 1.0.
                return project_to_target_fidelity(
                    X=X,
                    target_fidelities={input_dim - 1: 1.0},
                    d=input_dim,
                )

        # NOTE(review): both protocols annotate the return value as
        # ``(model, likelihood)``, yet the second element is used as the
        # model here. Left unchanged; confirm the order returned by the
        # actual implementations.
        if self.gp_model is not None:
            # Keyword names must match the ``GPFn`` protocol, which
            # declares ``train_objectives`` (not ``train_obj``).
            _, model = self.gp_model(
                train_x=train_x,
                train_objectives=train_obj,
                train_constraints=train_constraints,
                state_dict=state_dict,
            )
        else:
            # ``__init__`` guarantees that exactly one model callable is
            # set, so the causal callable is available on this branch.
            causal_net = load_causal_model(
                input_dim=input_dim,
                output_dim=num_objectives,
                device=self.device,
                dtype=self.dtype,
                state_dict=state['causal_net'],
            )
            # NOTE(review): ``CausalGPFn`` declares a required ``extra``
            # keyword that is not passed here -- confirm whether the
            # implementations give it a default.
            _, model = self.causal_gp_model(
                train_x=train_x,
                train_objectives=train_obj,
                train_constraints=train_constraints,
                causal_net=causal_net,
                state_dict=state_dict,
                device=self.device,
                dtype=self.dtype,
            )

        optimization_pref = MultiObjectiveOptimizationPref(
            ref_point=self.ref_point,
            max_hv=self.max_hv,
            device=self.device,
            dtype=self.dtype,
        )
        # Compute HV and regret using NSGA-II on the posterior
        nsga2_hv, nsga2_regret = (
            optimization_pref.compute_nsga2_posterior_hv_regret(
                input_dim_without_fid=input_dim_without_fid,
                num_objectives=num_objectives,
                objective_indices=objective_indices,
                constraints_indices=constraints_indices,
                bounds=bounds,
                model=model,
                problem=self.problem,
                is_mf_model=self.is_mf_model,
                project_to_target_fidelity=project,
                seed=None,
            )
        )
        # Compute observed HV and regret
        observed_hv, observed_regret = (
            optimization_pref.computed_observed_hv_regret(
                train_objectives=train_obj,
                train_constraints=train_constraints,
            )
        )
        return nsga2_hv, nsga2_regret, observed_hv, observed_regret

    def exp_stats_from_state(
        self,
        state_list: list[RescueStateMultifidelity | RescueState],
        show_stats: bool = True
    ) -> dict[int, dict[str, float]]:
        r"""
        Compute and log experiment statistics from a list of states.

        States are processed in list order. The ``best_*`` entries are
        running bests over the states seen so far. After the last state,
        ``finish()`` is called on the wandb run (if one was provided).

        Args:
            state_list (list[RescueStateMultifidelity | RescueState]):
                List of optimization states to evaluate.
            show_stats (bool): Whether to print statistics to terminal.
                Defaults to True.

        Returns:
            dict[int, dict[str, float]]: Dictionary mapping iteration
                number to statistics dictionary. Regret entries are only
                included when ``max_hv`` was provided.
        """
        best_nsga2_hv = float('-inf')
        best_nsga2_regret = float('inf') if self.has_maxhv else None

        history = {}
        for state in state_list:
            budget = state['budget']
            iteration = state['iteration']
            stats = {
                "budget": budget,
                "cost": state['current_cost'],
                "acquisition_value": state['acquisition_value'].item(),
                "new_fidelity": state['new_fidelity'],
                "initial_cost": state['initial_cost'],
                "iteration": iteration,
                "seed": state['get_seed'],
            }
            (
                curr_nsga2_hv, curr_nsga2_regret,
                observed_hv, observed_regret,
            ) = self.get_metrics_from_state(state)

            # Update the running bests; a None metric leaves them as is.
            if (
                self.has_maxhv
                and curr_nsga2_regret is not None
                and curr_nsga2_regret < best_nsga2_regret
            ):
                best_nsga2_regret = curr_nsga2_regret
            if curr_nsga2_hv is not None and curr_nsga2_hv > best_nsga2_hv:
                best_nsga2_hv = curr_nsga2_hv

            # Key insertion order is kept as before, since ``term_print``
            # may display the entries in dictionary order.
            if self.has_maxhv:
                stats["best_nsga2_regret"] = best_nsga2_regret
                stats["curr_nsga2_regret"] = curr_nsga2_regret
            stats["best_nsga2_hv"] = best_nsga2_hv
            stats["curr_nsga2_hv"] = curr_nsga2_hv
            stats["observed_hv"] = observed_hv
            if self.has_maxhv:
                stats["observed_regret"] = observed_regret

            term_print(
                show_stats=show_stats,
                exp_stats=stats,
                budget=budget
            )
            self._log_exp_stats_to_wandb(
                iter=iteration,
                exp_stats=stats
            )
            history[iteration] = stats

        if self.wandb_run is not None:
            self.wandb_run.finish()
        return history

    def _log_exp_stats_to_wandb(
        self,
        iter: int,
        exp_stats: dict,
    ) -> None:
        r"""
        Log experiment statistics to Weights & Biases.

        Does nothing when no wandb run was provided.

        Args:
            iter (int): The current iteration number, used as the
                logging step.
            exp_stats (dict): Dictionary containing experiment statistics
                to log.
        """
        if self.wandb_run is None:
            return
        self.wandb_run.log(exp_stats, step=iter)