#!/usr/bin/env python3

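r"""
Command-line entry point that runs one optimization method (RESCUE or one
of the baselines below) on one configured problem.

Baseline References
-------------------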

.. [Daulton2020]
    Daulton, S., Balandat, M., & Bakshy, E. (2020).
    Differentiable expected hypervolume improvement
    for parallel multi-objective Bayesian optimization.
    Advances in Neural Information Processing Systems,
    33, 9851-9864.

.. [Daulton2023]
    Daulton, S., Balandat, M., & Bakshy, E. (2023, July). 
    Hypervolume knowledge gradient: a lookahead approach 
    for multi-objective Bayesian optimization with partial 
    information. In International Conference on Machine 
    Learning (pp. 7167-7204). PMLR.

.. [Irshad2024]
    Irshad, F., Karsch, S., & Döpp, A. (2024).
    Leveraging trust for joint multi-objective
    and multi-fidelity optimization. Machine
    Learning: Science and Technology, 5(1), 015056.

"""

from __future__ import annotations

import os
import argparse

from experiments.config import PROBLEM_NAMES

def _set_experiment_env(device: str) -> None:
    """Export the environment variables the experiment modules read.

    Args:
        device: Device string stored in ``EXP_DEVICE`` (e.g. ``"cpu"``).
    """
    os.environ["EXP_DEVICE"] = device
    # Every runner uses double precision; it is not configurable from the CLI.
    os.environ["EXP_DTYPE"] = "double"


def _load_problem_configs(device: str):
    """Set the experiment environment, then import the problem registry.

    ``PROBLEM_CONFIGS`` must be imported only after ``EXP_DEVICE`` is set,
    which is why the import is deferred to call time instead of living at
    module level.

    NOTE(review): if ``experiments.problems.problems`` was already imported
    earlier in this process, Python's module cache means it is not
    re-evaluated with the new environment.

    Args:
        device: Device string exported as ``EXP_DEVICE``.

    Returns:
        The ``PROBLEM_CONFIGS`` registry (looked up by problem name).
    """
    _set_experiment_env(device)
    from experiments.problems.problems import PROBLEM_CONFIGS

    return PROBLEM_CONFIGS


def _run_method(
    method,
    problem_configs,
    *,
    wandb_project: str,
    problem_name: str,
    wandb_mode: str,
    seed: int,
) -> None:
    """Look up the config for ``problem_name`` and invoke ``method`` with it.

    Args:
        method: Experiment entry point. It is called with the keyword
            arguments ``wandb_project``, ``exp_config``, ``problem_name``,
            ``wandb_mode`` and ``seed``.
        problem_configs: Registry supporting ``in`` and ``[]`` lookups by
            problem name (``PROBLEM_CONFIGS``).
        wandb_project: W&B project name, forwarded unchanged.
        problem_name: Key into ``problem_configs``; also forwarded.
        wandb_mode: W&B mode, forwarded unchanged.
        seed: Random seed, forwarded unchanged.

    Raises:
        ValueError: If ``problem_name`` is not a key of ``problem_configs``.
    """
    if problem_name not in problem_configs:
        raise ValueError(f"Unknown problem_name: {problem_name}")
    method(
        wandb_project=wandb_project,
        exp_config=problem_configs[problem_name],
        problem_name=problem_name,
        wandb_mode=wandb_mode,
        seed=seed,
    )


def run_rescue(
    wandb_project: str,
    problem_name: str,
    device: str,
    wandb_mode: str,
    seed: int
):
    """Run RESCUE (``experiments.methods.rescue.rescue``) on ``problem_name``.

    ``device`` is exported as ``EXP_DEVICE`` before any experiment module is
    imported; the other arguments are forwarded (see ``_run_method``).

    Raises:
        ValueError: If ``problem_name`` is not a key of ``PROBLEM_CONFIGS``.
    """
    problem_configs = _load_problem_configs(device)
    # Method modules are imported only after the environment is set.
    from experiments.methods.rescue import rescue

    _run_method(
        rescue,
        problem_configs,
        wandb_project=wandb_project,
        problem_name=problem_name,
        wandb_mode=wandb_mode,
        seed=seed,
    )


def run_single_fidelity_gp_qehvi(
    wandb_project: str,
    problem_name: str,
    device: str,
    wandb_mode: str,
    seed: int
):
    """Run the single-fidelity GP + qEHVI baseline [Daulton2020].

    ``device`` is exported as ``EXP_DEVICE`` before any experiment module is
    imported; the other arguments are forwarded to
    ``single_fidelity_gp_qehvi`` (see ``_run_method``).

    Raises:
        ValueError: If ``problem_name`` is not a key of ``PROBLEM_CONFIGS``.
    """
    problem_configs = _load_problem_configs(device)
    from experiments.methods.singlefidelity_gp_qehvi import single_fidelity_gp_qehvi

    _run_method(
        single_fidelity_gp_qehvi,
        problem_configs,
        wandb_project=wandb_project,
        problem_name=problem_name,
        wandb_mode=wandb_mode,
        seed=seed,
    )


def run_single_fidelity_cgp_hvkg(
    wandb_project: str,
    problem_name: str,
    device: str,
    wandb_mode: str,
    seed: int
):
    """Run the single-fidelity ``sf_cgp_qehvi`` baseline.

    Despite the ``hvkg`` in its name, this function runs ``sf_cgp_qehvi``
    from ``experiments.methods.singlefidelity_cgp_qehvi``. The name is kept
    unchanged for backward compatibility with existing callers.

    ``device`` is exported as ``EXP_DEVICE`` before any experiment module is
    imported; the other arguments are forwarded (see ``_run_method``).

    Raises:
        ValueError: If ``problem_name`` is not a key of ``PROBLEM_CONFIGS``.
    """
    problem_configs = _load_problem_configs(device)
    from experiments.methods.singlefidelity_cgp_qehvi import sf_cgp_qehvi

    _run_method(
        sf_cgp_qehvi,
        problem_configs,
        wandb_project=wandb_project,
        problem_name=problem_name,
        wandb_mode=wandb_mode,
        seed=seed,
    )


def run_multi_fidelity_momf(
    wandb_project: str,
    problem_name: str,
    device: str,
    wandb_mode: str,
    seed: int
):
    """Run the multi-fidelity GP + MOMF baseline [Irshad2024].

    ``device`` is exported as ``EXP_DEVICE`` before any experiment module is
    imported; the other arguments are forwarded to ``mf_gp_momf`` (see
    ``_run_method``).

    Raises:
        ValueError: If ``problem_name`` is not a key of ``PROBLEM_CONFIGS``.
    """
    problem_configs = _load_problem_configs(device)
    from experiments.methods.multifidelity_momf import mf_gp_momf

    _run_method(
        mf_gp_momf,
        problem_configs,
        wandb_project=wandb_project,
        problem_name=problem_name,
        wandb_mode=wandb_mode,
        seed=seed,
    )


def run_multi_fidelity_hvkg(
    wandb_project: str,
    problem_name: str,
    device: str,
    wandb_mode: str,
    seed: int
):
    """Run the multi-fidelity GP + HVKG baseline [Daulton2023].

    ``device`` is exported as ``EXP_DEVICE`` before any experiment module is
    imported; the other arguments are forwarded to ``mf_gp_hvkg`` (see
    ``_run_method``).

    Raises:
        ValueError: If ``problem_name`` is not a key of ``PROBLEM_CONFIGS``.
    """
    problem_configs = _load_problem_configs(device)
    from experiments.methods.multifidelity_gp_hvkg import mf_gp_hvkg

    _run_method(
        mf_gp_hvkg,
        problem_configs,
        wandb_project=wandb_project,
        problem_name=problem_name,
        wandb_mode=wandb_mode,
        seed=seed,
    )


def run_multi_fidelity_gp_qnehvi(
    wandb_project: str,
    problem_name: str,
    device: str,
    wandb_mode: str,
    seed: int
):
    """Run the multi-fidelity GP + qNEHVI baseline.

    ``device`` is exported as ``EXP_DEVICE`` before any experiment module is
    imported; the other arguments are forwarded to
    ``multi_fidelity_gp_qnehvi`` (see ``_run_method``).

    Raises:
        ValueError: If ``problem_name`` is not a key of ``PROBLEM_CONFIGS``.
    """
    problem_configs = _load_problem_configs(device)
    # NOTE(review): the module path is spelled "multifielity". It is left
    # as-is because it must match the module file on disk; verify the name.
    from experiments.methods.multifielity_gp_qnehvi import multi_fidelity_gp_qnehvi

    _run_method(
        multi_fidelity_gp_qnehvi,
        problem_configs,
        wandb_project=wandb_project,
        problem_name=problem_name,
        wandb_mode=wandb_mode,
        seed=seed,
    )


def run_multi_fidelity_gp_qehvi(
    wandb_project: str,
    problem_name: str,
    device: str,
    wandb_mode: str,
    seed: int
):
    """Run the multi-fidelity GP + qEHVI baseline [Daulton2020].

    ``device`` is exported as ``EXP_DEVICE`` before any experiment module is
    imported; the other arguments are forwarded to
    ``multi_fidelity_gp_qehvi`` (see ``_run_method``).

    Raises:
        ValueError: If ``problem_name`` is not a key of ``PROBLEM_CONFIGS``.
    """
    problem_configs = _load_problem_configs(device)
    # NOTE(review): the module path is spelled "multifielity". It is left
    # as-is because it must match the module file on disk; verify the name.
    from experiments.methods.multifielity_gp_qehvi import multi_fidelity_gp_qehvi

    _run_method(
        multi_fidelity_gp_qehvi,
        problem_configs,
        wandb_project=wandb_project,
        problem_name=problem_name,
        wandb_mode=wandb_mode,
        seed=seed,
    )

def main(argv: list[str] | None = None):
    """Parse command-line arguments and run the selected experiment.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            parse ``sys.argv[1:]``, as before.

    Raises:
        ValueError: If ``--method`` has no registered runner. Argparse's
            ``choices`` already rejects such values, so this is defensive.
            It is also raised by the runner for an unknown problem name.
    """
    # Single source of truth for method names: the keys are also the argparse
    # choices, so the CLI and the dispatch cannot drift apart. (Previously
    # "sf_cgp_qehvi" was accepted by argparse but dispatched on "sf_cgp_hvkg",
    # so it always failed with "Unknown method".)
    runners = {
        "rescue": run_rescue,
        "mf_gp_hvkg": run_multi_fidelity_hvkg,
        "mf_gp_momf": run_multi_fidelity_momf,
        "mf_gp_qnehvi": run_multi_fidelity_gp_qnehvi,
        "mf_gp_qehvi": run_multi_fidelity_gp_qehvi,
        "sf_gp_qehvi": run_single_fidelity_gp_qehvi,
        # Runner name says "hvkg", but it runs the sf_cgp_qehvi method.
        "sf_cgp_qehvi": run_single_fidelity_cgp_hvkg,
    }
    wandb_mode_choices = ["disabled", "online", "offline"]

    parser = argparse.ArgumentParser("experiments")
    parser.add_argument("-m", "--method", required=True, choices=list(runners))
    parser.add_argument("-p", "--problem", required=True, choices=PROBLEM_NAMES)
    parser.add_argument("-s", "--seed", type=int, required=True)
    parser.add_argument("-d", "--device", default="cpu")
    parser.add_argument("-wp", "--wandb_project", type=str, required=True)
    parser.add_argument(
        "-wm",
        "--wandb_mode",
        default="disabled",
        choices=wandb_mode_choices,
    )
    args = parser.parse_args(argv)

    runner = runners.get(args.method)
    if runner is None:
        raise ValueError(f"Unknown method {args.method}")
    runner(
        wandb_project=args.wandb_project,
        problem_name=args.problem,
        device=args.device,
        wandb_mode=args.wandb_mode,
        seed=args.seed,
    )


if __name__ == "__main__":
    main()
