
#!/usr/bin/env python3

from __future__ import annotations

import os
import pandas as pd
import torch
from torch import Tensor
from botorch.test_functions.multi_objective_multi_fidelity import MOMFBraninCurrin
from experiments.config import ExpConfig
from experiments.problems.branincurrin.models import (
    causal_multitask_multifidelity_gp_model,
    causal_multitask_gp_model,
    gp_model
)

# Tensor factory keyword arguments shared across this module. The device name
# is read from EXP_DEVICE (default "cpu") and the dtype name from EXP_DTYPE
# (default "double", looked up as an attribute of the torch module).
tkwargs = dict(
    device=torch.device(os.getenv("EXP_DEVICE", "cpu")),
    dtype=getattr(torch, os.getenv("EXP_DTYPE", "double")),
)

# Benchmark instance, created with negate=True and moved to the configured
# device/dtype.
PROBLEM = MOMFBraninCurrin(negate=True).to(**tkwargs)

# Column names. The multi-fidelity design space is the single-fidelity one
# with the fidelity parameter appended as its last entry.
FIDELITY_PARAM_NAME = "S"
DESIGN_VAR_SF = ["X1", "X2"]
DESIGN_VAR_MF = [*DESIGN_VAR_SF, FIDELITY_PARAM_NAME]
OBJECTIVES_VAR = ["Y1", "Y2"]
CONSTRAINTS_VAR = None
KPI_VAR = None

# Positions of the outputs; this problem declares no constraints.
OBJECTIVES_INDICES = [0, 1]
CONSTRAINTS_INDICES = None
HAS_CONSTRAINTS = False

# Fidelity settings: the index of the last multi-fidelity design variable is
# mapped to the value 1.0, and no discrete fidelity levels are declared.
TARGET_FIDELITIES = {len(DESIGN_VAR_MF) - 1: 1.0}
IS_DISCRETE_FIDELITIES = False
FIDELITY_LEVELS = None

# Bounds as reported by the problem; the single-fidelity bounds drop the last
# (fidelity) column.
BOUNDS_MF = PROBLEM.bounds
BOUNDS_SF = BOUNDS_MF[:, :-1]

# All-zero reference points; the MOMF variant has one entry more than the
# number of objectives.
REF_POINT = torch.zeros(PROBLEM.num_objectives, **tkwargs)
REF_POINT_MOMF = torch.zeros(PROBLEM.num_objectives + 1, **tkwargs)

# Files holding the stored initial trials (multi-fidelity, single-fidelity).
initdata_dir_mf, initdata_dir_sf = (
    "data/init_trails/BraninCurrin_init_trails_mf.pt",
    "data/init_trails/BraninCurrin_init_trails_sf.pt",
)

# Files holding the stored observational data (multi-fidelity, single-fidelity).
obsdata_dir_mf, obsdata_dir_sf = (
    "data/observational_data/BraninCurrin_obs_data_mf.pt",
    "data/observational_data/BraninCurrin_obs_data_sf.pt",
)

def CostFn_MF(X: Tensor) -> Tensor:
    """Exponential evaluation cost ``exp(4.8 * s)``.

    ``s`` is taken from the last entry along the final dimension of ``X``.

    Args:
        X: Input tensor whose last column along the final dimension is used
            as the cost argument.

    Returns:
        A tensor with the leading dimensions of ``X`` and a trailing
        dimension of size 1.
    """
    fidelity = X[..., -1:]  # slice (not index) so the trailing dim is kept
    rate = torch.tensor(4.8, dtype=X.dtype, device=X.device)
    return (rate * fidelity).exp()

# For single-fidelity
def PROBLEM_SF(X: Tensor) -> Tensor:
    """Evaluate the multi-fidelity problem at the fixed fidelity ``s = 1.0``.

    A column filled with 1.0 is appended to ``X`` along the last dimension
    and the augmented tensor is passed to ``PROBLEM``.

    Args:
        X: Design points without the fidelity column. Any number of leading
            dimensions is accepted; the last dimension holds the design
            variables.

    Returns:
        Whatever ``PROBLEM`` returns for the augmented input.
    """
    fidelity_level = 1.0
    # Size the column from all leading dims rather than only ``X.shape[0]``,
    # so 1-D and batched inputs concatenate correctly. This matches the
    # shape handling in CostFn_SF.
    s = torch.full(
        (*X.shape[:-1], 1), fidelity_level, dtype=X.dtype, device=X.device
    )
    # ``s`` already shares dtype/device with ``X``; no extra cast is needed.
    return PROBLEM(torch.cat([X, s], dim=-1))

def CostFn_SF(X: Tensor) -> Tensor:
    """Constant evaluation cost ``exp(4.8 * s)`` with ``s`` fixed to 1.0.

    Only the shape, dtype and device of ``X`` are used; its values are not.

    Args:
        X: Input tensor.

    Returns:
        A tensor with the leading dimensions of ``X`` and a trailing
        dimension of size 1, every entry equal to ``exp(4.8)``.
    """
    unit_fidelity = torch.ones(
        (*X.shape[:-1], 1), dtype=X.dtype, device=X.device
    )
    rate = torch.tensor(4.8, dtype=X.dtype, device=X.device)
    return (rate * unit_fidelity).exp()

def _load_tensors(f: str) -> dict[int, tuple[Tensor, Tensor, Tensor | None]]:
    """Load a stored data file with ``torch.load`` onto the configured device.

    Args:
        f: Path of the file to load.

    Returns:
        The deserialized object; callers in this module index it by seed.
    """
    # NOTE: weights_only=False lets torch.load unpickle arbitrary Python
    # objects, so this must only be pointed at trusted files.
    target_device = tkwargs["device"]
    return torch.load(f, map_location=target_device, weights_only=False)

def observational_data_mf(*, seed: int) -> pd.DataFrame:
    """Return the stored multi-fidelity observational data for ``seed``.

    The entry for ``seed`` is read from the file at ``obsdata_dir_mf``. Its
    non-``None`` tensors are moved to the CPU and concatenated along the last
    dimension. Columns are labelled with the multi-fidelity design
    variables, then the objectives, then any constraint names.

    Args:
        seed: Key of the entry to read from the loaded file.

    Returns:
        A DataFrame with one column per design/objective/constraint name.
    """
    parts = _load_tensors(obsdata_dir_mf)[seed]
    present = [part.cpu() for part in parts if part is not None]
    table = torch.cat(present, dim=-1).numpy()
    column_names = DESIGN_VAR_MF + OBJECTIVES_VAR + (CONSTRAINTS_VAR or [])
    return pd.DataFrame(table, columns=column_names)

def observational_data_sf(*, seed: int) -> pd.DataFrame:
    """Return the stored single-fidelity observational data for ``seed``.

    The entry for ``seed`` is read from the file at ``obsdata_dir_sf``. Its
    non-``None`` tensors are moved to the CPU and concatenated along the last
    dimension. Columns are labelled with the single-fidelity design
    variables, then the objectives, then any constraint names.

    Args:
        seed: Key of the entry to read from the loaded file.

    Returns:
        A DataFrame with one column per design/objective/constraint name.
    """
    parts = _load_tensors(obsdata_dir_sf)[seed]
    present = [part.cpu() for part in parts if part is not None]
    table = torch.cat(present, dim=-1).numpy()
    column_names = DESIGN_VAR_SF + OBJECTIVES_VAR + (CONSTRAINTS_VAR or [])
    return pd.DataFrame(table, columns=column_names)

def init_sample_SF(
    n: int | None = None,
    *,
    seed: int,
) -> tuple[Tensor, Tensor, Tensor | None]:
    """Return the stored single-fidelity initial data for ``seed``.

    Nothing is sampled here: the entry is read from the file at
    ``initdata_dir_sf``.

    Args:
        n: Accepted but not used by this implementation.
        seed: Key of the entry to read from the loaded file.

    Returns:
        The tuple stored under ``seed``.
    """
    trials_by_seed = _load_tensors(initdata_dir_sf)
    return trials_by_seed[seed]

def init_sample_MF(
    n: int | None = None,
    *,
    seed: int,
) -> tuple[Tensor, Tensor, Tensor | None]:
    """Return the stored multi-fidelity initial data for ``seed``.

    Nothing is sampled here: the entry is read from the file at
    ``initdata_dir_mf``.

    Args:
        n: Accepted but not used by this implementation.
        seed: Key of the entry to read from the loaded file.

    Returns:
        The tuple stored under ``seed``.
    """
    trials_by_seed = _load_tensors(initdata_dir_mf)
    return trials_by_seed[seed]

# Experiment configuration for the Branin-Currin problem. It bundles the
# problem definition, cost functions and data loaders defined in this module,
# the model entries imported from experiments.problems.branincurrin.models,
# and the method/optimizer settings into a single ExpConfig instance.
BraninCurrinConfig = ExpConfig(
    # --- Problem configs (module-level values defined above) ---
    PROBLEM_MF=PROBLEM,
    PROBLEM_SF=PROBLEM_SF,
    HAS_CONSTRAINTS=HAS_CONSTRAINTS,
    IS_DISCRETE_FIDELITIES=IS_DISCRETE_FIDELITIES,
    FIDELITY_LEVELS=FIDELITY_LEVELS,
    DESIGN_VAR_MF=DESIGN_VAR_MF,
    DESIGN_VAR_SF=DESIGN_VAR_SF,
    OBJECTIVES_VAR=OBJECTIVES_VAR,
    CONSTRAINTS_VAR=CONSTRAINTS_VAR,
    KPI_VAR=KPI_VAR,
    FIDELITY_PARAM_NAME=FIDELITY_PARAM_NAME,
    OBJECTIVES_INDICES=OBJECTIVES_INDICES,
    CONSTRAINTS_INDICES=CONSTRAINTS_INDICES,
    TARGET_FIDELITIES=TARGET_FIDELITIES,
    BOUNDS_MF=BOUNDS_MF,
    BOUNDS_SF=BOUNDS_SF,
    REF_POINT=REF_POINT,
    REF_POINT_MOMF=REF_POINT_MOMF,
    # Maximum hypervolume as reported by the problem instance itself.
    MAX_HV=PROBLEM.max_hv,
    CostFn_MF=CostFn_MF,
    CostFn_SF=CostFn_SF,
    init_sample_MF=init_sample_MF,
    init_sample_SF=init_sample_SF,
    # --- Method configs ---
    # Models (imported from experiments.problems.branincurrin.models)
    RESCUE_model=causal_multitask_multifidelity_gp_model,
    MULTITASK_CGP_model=causal_multitask_gp_model,
    MULTITASK_GP_model=gp_model,
    # Rescue: the obs-data entries are the loader functions defined above,
    # each returning a DataFrame for a given seed.
    RESCUE_obs_data_MF=observational_data_mf,
    RESCUE_obs_data_SF=observational_data_sf,
    RESCUE_causal_net_epochs=200,
    RESCUE_causal_intervention_samples=500,
    RESCUE_causal_discovery_method="PC",
    RESCUE_use_rescue_fit=False,
    RESCUE_causal_discovery_alpha=None,
    RESCUE_fit_causal_model_after_each=None,
    RESCUE_is_design_var_independent=True,
    RESCUE_acqfn_use_causalmomf=False,
    # Acquisition function settings
    HVKG_num_fantasies=8,
    HVKG_num_pareto=10,
    HVKG_use_posterior_mean=True,
    HVKG_acqfn_curr_val_mc_samples=None,
    HVKG_inner_mc_samples=None,
    HVKG_acqfn_curr_val_num_restarts=1,
    HVKG_num_restarts=1,
    ACQFN_MC_SAMPLES=128,
    ACQFN_RAW_SAMPLES=512,
    ACQFN_NUM_RESTARTS=10,
    OPTIM_INNER_ACQFN_OPTIONS={
        "nonnegative": True,
        "maxiter": 200
    },
    OPTIM_ACQFN_OPTIONS={
        "maxiter": 200
    },
    COMPUTE_METRICS=True,
    # Budgets
    # NOTE(review): presumably expressed in the units returned by
    # CostFn_MF / CostFn_SF — confirm against ExpConfig's consumers.
    OPTIMIZATION_BUDGET_MF=20.0,
    OPTIMIZATION_BUDGET_SF=3000.0,
    INCLUDE_INITCOST_TO_BUDGET=False,
)
