#!/usr/bin/env python3

from __future__ import annotations

import os
from functools import partial
import torch
from gpytorch.utils.warnings import NumericalWarning
from rescue.algorithms.multi_fidelity import RescueAlgorithmMultifidelity
from rescue.utils.setup_wandb import wandb_init
from experiments.config import ExpConfig

import warnings
# Hide gpytorch's NumericalWarning category process-wide (the warnings filter
# list is global interpreter state, so this affects every module).
warnings.simplefilter("ignore", category=NumericalWarning)

def _dtype_from_env(var: str = "EXP_DTYPE", default: str = "double") -> torch.dtype:
    """Return the ``torch.dtype`` named by environment variable *var*.

    The value (or *default* when *var* is unset) is looked up as an attribute
    of the ``torch`` module, so any dtype alias torch exposes, such as
    ``"double"``, ``"float32"`` or ``"half"``, is accepted.

    Raises:
        ValueError: if the name does not resolve to a ``torch.dtype``. A bare
            ``getattr`` would raise an opaque ``AttributeError`` for typos and
            would silently return a non-dtype (e.g. the ``torch.cuda``
            submodule for ``"cuda"``).
    """
    name = os.getenv(var, default)
    dtype = getattr(torch, name, None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(
            f"{var}={name!r} is not a torch dtype name "
            "(expected e.g. 'double', 'float32', 'float16')"
        )
    return dtype


# Device/dtype keyword arguments forwarded via **tkwargs to
# RescueAlgorithmMultifidelity in rescue(); overridable through the
# EXP_DEVICE and EXP_DTYPE environment variables.
tkwargs = {
    "device": torch.device(os.getenv("EXP_DEVICE", "cpu")),
    "dtype": _dtype_from_env(),
}

# Release unoccupied memory held by torch's CUDA caching allocator before the
# experiment starts.
if torch.cuda.is_available():
    torch.cuda.empty_cache()


# Method label used in the W&B experiment name, and the tags attached to the run.
METHOD_NAME = "rescue"
WANDB_TAGS = ["multi_fidelity", "rescue", "cgp"]

def rescue(
    exp_config: ExpConfig,
    problem_name: str,
    seed: int,
    wandb_mode: str = "disabled",
    wandb_project: str = "rescue",
) -> object:
    """Run one seed of the RESCUE multi-fidelity optimiser on a configured problem.

    A W&B run is opened through ``wandb_init``, named ``f"{problem_name}_{METHOD_NAME}"``
    and tagged with ``WANDB_TAGS``. A ``RescueAlgorithmMultifidelity`` is then
    built from the multi-fidelity fields of *exp_config* plus the module-level
    ``tkwargs`` device/dtype. Finally, its ``run`` method is invoked with the
    optimisation, causal-model and acquisition-function settings from
    *exp_config*.

    Args:
        exp_config: Experiment configuration. Every problem, model and optimiser
            setting used below is read from its attributes.
        problem_name: Used only to build the W&B experiment name.
        seed: Forwarded to the initial-data sampler, the observational-data
            generator, ``wandb_init`` (as ``run_id``) and ``run``.
        wandb_mode: W&B mode forwarded to ``wandb_init``; ``"disabled"`` by
            default.
        wandb_project: W&B project name forwarded to ``wandb_init``.

    Returns:
        Whatever ``RescueAlgorithmMultifidelity.run`` returns, so callers can
        inspect the optimisation result.
    """
    cfg = exp_config

    # Run used for logging stats; it is handed to the algorithm below.
    wandb_run = wandb_init(
        mode=wandb_mode,
        project=wandb_project,
        exp_name=f"{problem_name}_{METHOD_NAME}",
        wandb_id=None,
        run_id=seed,
        tags=WANDB_TAGS,
    )

    algorithm = RescueAlgorithmMultifidelity(
        problem=cfg.PROBLEM_MF,
        custom_model=cfg.RESCUE_model,
        # The sampler is passed uncalled with the seed pre-bound; presumably
        # the algorithm invokes it when it needs initial data.
        gen_initial_data=partial(cfg.init_sample_MF, seed=seed),
        design_variables=cfg.DESIGN_VAR_MF,
        objective_variables=cfg.OBJECTIVES_VAR,
        constraint_variables=cfg.CONSTRAINTS_VAR,
        kpi_variables=cfg.KPI_VAR,
        target_fidelities=cfg.TARGET_FIDELITIES,
        fidelity_param_name=cfg.FIDELITY_PARAM_NAME,
        bounds=cfg.BOUNDS_MF,
        cost_fn=cfg.CostFn_MF,
        is_discrete_fidelities=cfg.IS_DISCRETE_FIDELITIES,
        fidelity_levels=cfg.FIDELITY_LEVELS,
        wandb_run=wandb_run,
        **tkwargs,
    )

    # Built here, before run() starts, using the same seed as the rest of the run.
    observational_data = cfg.RESCUE_obs_data_MF(seed=seed)

    return algorithm.run(
        budget=cfg.OPTIMIZATION_BUDGET_MF,
        include_initcost_to_budget=cfg.INCLUDE_INITCOST_TO_BUDGET,
        ref_point=cfg.REF_POINT,
        objective_indices=cfg.OBJECTIVES_INDICES,
        constraints_indices=cfg.CONSTRAINTS_INDICES,
        max_hv=cfg.MAX_HV,
        causal_save_graph=False,
        use_rescue_fit=cfg.RESCUE_use_rescue_fit,
        causal_observational_data=observational_data,
        causal_discovery=cfg.RESCUE_causal_discovery_method,
        causal_intervention_samples=cfg.RESCUE_causal_intervention_samples,
        causal_net_epochs=cfg.RESCUE_causal_net_epochs,
        causal_discovery_alpha=cfg.RESCUE_causal_discovery_alpha,
        causal_is_design_var_independent=cfg.RESCUE_is_design_var_independent,
        fit_causal_model_after_each=cfg.RESCUE_fit_causal_model_after_each,
        acqfn_use_causalmomf=cfg.RESCUE_acqfn_use_causalmomf,
        acqfn_causalmomf_mc_samples=cfg.ACQFN_MC_SAMPLES,
        acqfn_raw_samples=cfg.ACQFN_RAW_SAMPLES,
        acqfn_num_restarts=cfg.ACQFN_NUM_RESTARTS,
        acqfn_curr_val_mc_samples=cfg.HVKG_acqfn_curr_val_mc_samples,
        acqfn_inner_mc_samples=cfg.HVKG_inner_mc_samples,
        # NOTE(review): the 2x multiplier over ACQFN_RAW_SAMPLES is hard-coded
        # here rather than read from the config; confirm it is intended.
        acqfn_curr_val_raw_samples=2 * cfg.ACQFN_RAW_SAMPLES,
        acqfn_curr_val_num_restarts=cfg.HVKG_acqfn_curr_val_num_restarts,
        acqfn_num_fantasies=cfg.HVKG_num_fantasies,
        acqfn_num_pareto=cfg.HVKG_num_pareto,
        optim_acqfn_options=cfg.OPTIM_ACQFN_OPTIONS,
        optim_inner_acqfn_options=cfg.OPTIM_INNER_ACQFN_OPTIONS,
        compute_metrics=cfg.COMPUTE_METRICS,
        show_status=True,
        seed=seed,
    )