
#!/usr/bin/env python3

from __future__ import annotations

from typing import Callable

from dataclasses import dataclass
import torch
import pandas as pd
from torch import Tensor

from gpytorch.mlls import ExactMarginalLogLikelihood

from botorch.test_functions.base import MultiObjectiveTestProblem

from baselines.multitask_gp import MultitaskGP

from rescue.models.causal_gp.multitask_multifidelity import (
    CausalMultitaskMultifidelityGP
)
from rescue.models.causal_gp.multitask import CausalMultitaskGP

# Identifiers of the benchmark problems, in a fixed order. Every name is a
# single whitespace-free token, so they are written one per line and split
# into a plain ``list[str]``.
PROBLEM_NAMES = """
    BraninCurrin
    Park4D
    HPOXGBoost
    HPOXGBoostConstrained
    HPORanger
    HPORangerConstrained
    Health
    AGVNavigation
""".split()

@dataclass
class ExpConfig:
    """Configuration bundle for a single benchmark experiment.

    Groups the multi-fidelity (``*_MF``) and single-fidelity (``*_SF``)
    problem definitions with their budgets, bounds, cost functions and
    initial-sampling callables, the model factories for each method, the
    RESCUE-specific settings, and the acquisition-function optimisation
    settings.

    Every field is required (none has a default). Because the generated
    ``__init__`` is positional in declaration order, prefer keyword
    arguments when constructing, and do not reorder fields.

    NOTE(review): ``@dataclass`` generates ``__eq__`` by comparing the tuples
    of all fields. For fields holding distinct multi-element tensors that
    comparison calls ``bool()`` on an element-wise tensor result, which
    PyTorch rejects, so ``cfg_a == cfg_b`` may raise. Confirm configs are
    never compared, or consider ``@dataclass(eq=False)``.
    """

    # --- Multi-fidelity problem ---------------------------------------------
    PROBLEM_MF: MultiObjectiveTestProblem
    OPTIMIZATION_BUDGET_MF: float

    HAS_CONSTRAINTS: bool
    IS_DISCRETE_FIDELITIES: bool
    # Presumably the admissible fidelity values when IS_DISCRETE_FIDELITIES
    # is set, else None -- TODO confirm with the config builders.
    FIDELITY_LEVELS: Tensor | None

    # Variable names.
    DESIGN_VAR_MF: list[str]
    OBJECTIVES_VAR: list[str]
    CONSTRAINTS_VAR: list[str] | None
    KPI_VAR: list[str] | None
    FIDELITY_PARAM_NAME: str

    OBJECTIVES_INDICES: list[int] | None
    CONSTRAINTS_INDICES: list[int] | None
    # Presumably maps a fidelity-parameter column index to its target
    # fidelity value -- TODO confirm.
    TARGET_FIDELITIES: dict[int, float]

    BOUNDS_MF: Tensor

    # Presumably hypervolume reference points (plain and MOMF variants).
    REF_POINT: Tensor
    REF_POINT_MOMF: Tensor

    MAX_HV: float

    CostFn_MF: Callable[[Tensor], Tensor]
    # Takes a number (presumably a sample count or budget) and returns two
    # tensors plus an optional third -- presumably (X, Y, constraints).
    init_sample_MF: Callable[
        [float | int], tuple[Tensor, Tensor, Tensor | None]
    ]

    # --- Single-fidelity problem --------------------------------------------
    DESIGN_VAR_SF: list[str]
    OPTIMIZATION_BUDGET_SF: float
    INCLUDE_INITCOST_TO_BUDGET: bool

    BOUNDS_SF: Tensor

    PROBLEM_SF: MultiObjectiveTestProblem
    CostFn_SF: Callable[[Tensor], Tensor]
    # Same calling convention as init_sample_MF.
    init_sample_SF: Callable[
        [float | int], tuple[Tensor, Tensor, Tensor | None]
    ]

    # --- Method configs: model factories ------------------------------------
    # Each factory takes (Tensor, Tensor, optional Tensor, optional dict) --
    # presumably training inputs, targets, an extra tensor and model
    # options; confirm with callers -- and returns an
    # (ExactMarginalLogLikelihood, model) pair.
    RESCUE_model: Callable[
        [Tensor, Tensor, Tensor | None, dict | None],
        tuple[ExactMarginalLogLikelihood, CausalMultitaskMultifidelityGP],
    ]
    MULTITASK_CGP_model: Callable[
        [Tensor, Tensor, Tensor | None, dict | None],
        tuple[ExactMarginalLogLikelihood, CausalMultitaskGP],
    ]
    MULTITASK_GP_model: Callable[
        [Tensor, Tensor, Tensor | None, dict | None],
        tuple[ExactMarginalLogLikelihood, MultitaskGP],
    ]

    # --- RESCUE settings ----------------------------------------------------
    # Presumably observational-data providers: take an int, return a
    # DataFrame.
    RESCUE_obs_data_MF: Callable[[int], pd.DataFrame]
    RESCUE_obs_data_SF: Callable[[int], pd.DataFrame]
    RESCUE_causal_net_epochs: int
    RESCUE_causal_intervention_samples: int
    RESCUE_causal_discovery_method: str
    RESCUE_causal_discovery_alpha: float | None
    RESCUE_fit_causal_model_after_each: int | None
    RESCUE_is_design_var_independent: bool
    RESCUE_acqfn_use_causalmomf: bool
    RESCUE_use_rescue_fit: bool

    # --- Shared by the KG-based `optimize_acqf` calls -----------------------
    OPTIM_INNER_ACQFN_OPTIONS: dict[str, bool | float | int | str] | None
    HVKG_num_fantasies: int
    HVKG_num_pareto: int
    HVKG_acqfn_curr_val_mc_samples: int
    HVKG_acqfn_curr_val_num_restarts: int
    HVKG_num_restarts: int
    HVKG_inner_mc_samples: int
    HVKG_use_posterior_mean: bool

    # --- Shared by all methods ----------------------------------------------
    COMPUTE_METRICS: bool

    ACQFN_RAW_SAMPLES: int
    ACQFN_MC_SAMPLES: int
    ACQFN_NUM_RESTARTS: int
    OPTIM_ACQFN_OPTIONS: dict[str, bool | float | int | str] | None