from space_gen import run_mcts_transfer, run_lamcts
import argparse
import os
from generate_dataset import data_gen, data_gen_Sphere2D
import functions
from meta_learning_with_rgpe import run_rgpe
from FSBO.fsbo_metatrain import run_fsbo2
from FSBO.fsbo_test import fsbo_test
from transformer_pretrain import train
from transformer_test import transformer
from origin_bo import simpleBO
# Set WANDB_MODE=offline in this process's environment before any runner is
# invoked, so every method launched from this script sees the same setting.
os.environ["WANDB_MODE"] = "offline"

# Command-line interface of the launcher. The options are declared in a fixed
# order and without help text; the parsed namespace is handed unchanged (apart
# from the overrides noted below) to whichever runner --methods selects.
parser = argparse.ArgumentParser()

# Task selection. With --mode real, --search-space-id is also used as the key
# for looking up the problem dimension.
parser.add_argument(
    "--mode",
    type=str,
    default="hpob",
    choices=["hpob", "bbob", "Sphere2D", "real"],
)
parser.add_argument("--search-space-id", type=str, default="4796")
parser.add_argument("--dataset-id", type=str, default="3549")

# Which runner to dispatch to. Each choice must have an entry in the dispatch
# table built after parsing.
parser.add_argument(
    "--methods",
    type=str,
    default="mcts-transfer",
    choices=[
        "mcts-transfer",
        "lamcts",
        "mcts-transfer-with-weights",
        "mcts-transfer-exponential",
        "RGPE",
        "FSBO",
        "FSBO-test",
        "pretrain",
        "transformer",
        "transfer-transformer",
        "mcts-initialization-transformer",
        "GP-EI",
    ],
)

# Run sizing. --dims is overwritten after parsing when --mode is bbob or real.
parser.add_argument("--dims", type=int, default=10)
parser.add_argument("--iteration", type=int, default=100)
parser.add_argument("--rep", type=int, default=3)

# No default here: the attribute is None unless the flag is given.
parser.add_argument(
    "--similar",
    type=str,
    choices=[
        "similar",
        "unsimilar",
        "combine",
        "mix-similar",
        "mix-unsimilar",
        "mix-both",
    ],
)

# Overwritten after parsing for the mcts-transfer* and *-transformer methods,
# so the command-line value only reaches the remaining methods.
parser.add_argument(
    "--weight-update",
    type=str,
    default="linear-half",
    choices=["all-one", "linear-half", "exponential"],
)
parser.add_argument("--Cp", type=float, default=1.0)
parser.add_argument("--weight-decay", action="store_true")
parser.add_argument("--gamma", type=float, default=0.99)

# When set, the script runs dataset generation instead of an optimisation method.
parser.add_argument("--data-gen", action="store_true")
parser.add_argument("--threshold", type=int, default=10)
parser.add_argument("--kernel-type", type=str, default="rbf")
parser.add_argument("--local", action="store_true")

# --N is reset to 1.0 after parsing whenever --similarity is "optimal".
parser.add_argument(
    "--similarity",
    type=str,
    default="optimal",
    choices=["optimal", "topN", "Npercent", "distribution", "KL"],
)
parser.add_argument("--N", type=float, default=1.0)
parser.add_argument("--alpha", type=float, default=0.5)

# Parse the command line once. The namespace is then adjusted in place (weight
# scheme, dims, N) and passed to the runner selected by --methods.
args = parser.parse_args()
this_dir = os.path.abspath(os.path.dirname(__file__))
# NOTE(review): model_path is not read anywhere else in this script; kept in
# case something outside this file relies on it — confirm before removing.
model_path = f"{this_dir}/lamcts/model_{args.mode}/{args.kernel_type}/"

if args.data_gen:
    # Dataset generation replaces the run entirely; it receives the namespace
    # exactly as parsed, without any of the overrides applied below.
    data_gen(args)
    # data_gen_Sphere2D(args)
else:
    # Entry point for every --methods choice; each is called with the namespace.
    run_dict = {
        "lamcts": run_lamcts,
        "mcts-transfer": run_mcts_transfer,
        "mcts-transfer-with-weights": run_mcts_transfer,
        "mcts-transfer-exponential": run_mcts_transfer,
        "transfer-transformer": run_mcts_transfer,
        "mcts-initialization-transformer": run_mcts_transfer,
        "RGPE": run_rgpe,
        "FSBO": run_fsbo2,
        "FSBO-test": fsbo_test,
        "pretrain": train,
        "transformer": transformer,
        "GP-EI": simpleBO,
    }

    # These methods pin the weighting scheme, so a --weight-update given on the
    # command line is ignored for them. Methods not listed keep the CLI value.
    forced_weight_update = {
        "mcts-transfer-with-weights": "linear-half",
        "mcts-transfer": "all-one",
        "mcts-transfer-exponential": "exponential",
        "transfer-transformer": "linear-half",
        "mcts-initialization-transformer": "linear-half",
    }
    if args.methods in forced_weight_update:
        args.weight_update = forced_weight_update[args.methods]

    if args.mode == "bbob":
        # bbob runs always use 10 dimensions, regardless of --dims.
        args.dims = 10
    elif args.mode == "real":
        # In real mode --search-space-id names the task. The lookups are wrapped
        # in lambdas so that only the selected task's module is touched.
        dim_getters = {
            "LunarLander": lambda: functions.lunar_lander.get_dim(),
            "RobotPush": lambda: functions.push_function.get_dim(),
            "Rover": lambda: functions.rover_function.get_dim(),
        }
        if args.search_space_id not in dim_getters:
            # Report a usage error instead of a bare KeyError (the default
            # --search-space-id is not a valid real-mode task).
            parser.error(
                "--search-space-id must be one of "
                f"{sorted(dim_getters)} when --mode is 'real' "
                f"(got {args.search_space_id!r})"
            )
        args.dims = dim_getters[args.search_space_id]()

    if args.similarity == "optimal":
        # The "optimal" similarity always runs with N = 1.0, whatever --N says.
        args.N = 1.0

    run_dict[args.methods](args)
    
    
    