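# Runs the ALEBO baseline (Ax's ALEBOStrategy, driven by ax.service.managed_loop.optimize)
# on the benchmark problems. HESBO and REMBO are not wired up here: any --optimizer other
# than 'alebo' raises NotImplementedError.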
import sys
sys.path.append('..')
sys.path.append('../..')

from localglobal.test_funcs import *
from localglobal.mixed_test_func import *
from localglobal.baselines.modifiedobjfunc import ModifiedObjectiveFunc
from ax.modelbridge.strategies.alebo import ALEBOStrategy
from ax.service.managed_loop import optimize

import logging
import argparse
import os
import pickle
import pandas as pd
import time, datetime
from localglobal.test_funcs.random_seed_config import *


# Set up the command-line interface.
# NOTE: the first positional argument of ArgumentParser is `prog` (the program name shown in
# the usage line), so the human-readable text is passed as `description` instead.
parser = argparse.ArgumentParser(description='Run Experiments for ALEBO baseline')
parser.add_argument('-p', '--problem', type=str, default='pest')
parser.add_argument('-o', '--optimizer', type=str, default='alebo')
parser.add_argument('--d_embed', type=int, default=5)
parser.add_argument('--max_iters', type=int, default=150, help='Maximum number of BO iterations.')
parser.add_argument('--lamda', type=float, default=1e-6, help='the noise to inject for some problems')
parser.add_argument('--batch_size', type=int, default=1, help='batch size for BO.')
parser.add_argument('--n_trials', type=int, default=20)
parser.add_argument('--n_init', type=int, default=20)
parser.add_argument('--save_path', type=str, default='output/baselines/alebo/')
parser.add_argument('--ard', action='store_true')
parser.add_argument('--random_seed_objective', type=int, default=20,
                    help='1-indexed problem-instance seed in [1, 25]. The default value of 20 is provided also in '
                         'COMBO')
parser.add_argument('-d', '--debug', action='store_true',
                    help='Whether to turn on debugging mode (a lot of output will be generated).')
parser.add_argument('--no_save', action='store_true', help='If activated, do not save the current run into a log folder.')
parser.add_argument('--seed', type=int, default=None, help='**initial** seed setting')
parser.add_argument('--data_dir', default='./data/')
parser.add_argument('--normalize', action='store_true')
parser.add_argument('--infer_noise_var', action='store_true')
parser.add_argument('--random_restart', default=False, action='store_true')


args = parser.parse_args()
options = vars(args)
print(options)

# Local timestamp (e.g. 20240131_235959) used as the per-run directory name.
time_string = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

if args.debug:
    logging.basicConfig(level=logging.INFO)

# Create <save_path>/<problem>/<time_string> and record the parsed options there so the
# experiment can be reproduced. `save_path` is None when saving is disabled.
if not args.no_save:
    save_path = os.path.join(args.save_path, args.problem, time_string)
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(save_path, exist_ok=True)
    # `with` guarantees the file is closed even if the write fails.
    with open(os.path.join(save_path, 'command.txt'), 'w') as option_file:
        option_file.write(str(options))
else:
    save_path = None


# Convert the 1-indexed --random_seed_objective into a 0-based instance index ONCE, before
# the trial loop. It used to be decremented inside the loop. That made every trial optimise a
# different problem instance, and the range assert failed once the trial count exceeded the
# starting value, even for problems that never use the seed.
if args.random_seed_objective is not None:
    if not 1 <= int(args.random_seed_objective) <= 25:
        raise ValueError('random_seed_objective must be in [1, 25], got %s' % args.random_seed_objective)
    seed_idx = int(args.random_seed_objective) - 1
else:
    seed_idx = None

for t in range(args.n_trials):
    # ---- Build the benchmark objective for this trial ----
    if args.problem == 'contamination':
        random_seed_pair_ = generate_random_seed_pair_contamination()
        # The 0-based index selects one of (case seed, init seed), 5 init seeds per case.
        case_idx, init_idx = divmod(seed_idx, 5)
        case_seed_ = sorted(random_seed_pair_.keys())[case_idx]
        init_seed_ = sorted(random_seed_pair_[case_seed_])[init_idx]
        f = Contamination(lamda=args.lamda,
                          random_seed_pair=(case_seed_, init_seed_), normalize=False)

    elif args.problem == 'pest':
        random_seed_ = sorted(generate_random_seed_pestcontrol())[seed_idx]
        f = PestControl(random_seed=random_seed_, normalize=args.normalize)
    elif args.problem == 'pest_plus':
        random_seed_ = sorted(generate_random_seed_pestcontrol())[seed_idx]
        f = PestControlDifficult(random_seed=random_seed_, normalize=args.normalize)
    elif args.problem == 'branin':
        f = Branin(normalize=args.normalize)
    elif args.problem == 'ising':
        random_seed_pair_ = generate_random_seed_pair_ising()
        case_idx, init_idx = divmod(seed_idx, 5)
        case_seed_ = sorted(random_seed_pair_.keys())[case_idx]
        init_seed_ = sorted(random_seed_pair_[case_seed_])[init_idx]
        f = Ising(lamda=args.lamda, random_seed_pair=(case_seed_, init_seed_), normalize=args.normalize)
    elif args.problem == 'func2C':
        f = Func2C(lamda=args.lamda, normalize=args.normalize)
    elif args.problem == 'func3C':
        f = Func3C(lamda=args.lamda, normalize=args.normalize)
    elif args.problem == 'ackley53':
        f = Ackley53(lamda=args.lamda, normalize=args.normalize)
    elif args.problem == 'ackley200':
        f = ContAckley(lamda=args.lamda, d=200)
    elif args.problem == 'ackley106':
        f = Ackley106(lamda=args.lamda)
    elif args.problem == 'rosen200':
        f = Rosen200(lamda=args.lamda)
    elif args.problem == 'xgboost-mnist':
        f = XGBoostOptTask(lamda=args.lamda, task='mnist', seed=args.seed, normalize=args.normalize)
    elif args.problem == 'svm-boston':
        f = SVMOptTask(lamda=0., task='boston', seed=args.seed, normalize=args.normalize)
    elif args.problem == 'nasbench101':
        # NOTE(review): pickle.load can execute arbitrary code; only point --data_dir at trusted files.
        try:
            with open(os.path.join(args.data_dir, 'nasbench101.pickle'), 'rb') as fh:
                f = pickle.load(fh)
        except Exception:
            # Best-effort cache: if the pickle is missing or unreadable, build the benchmark
            # from the raw data instead. (A bare `except:` would also trap Ctrl-C.)
            f = NASBench101(data_dir=args.data_dir)
    elif args.problem == 'MaxSAT60':
        f = MaxSAT60()
    else:
        raise ValueError('Unrecognised problem type %s' % args.problem)

    f = ModifiedObjectiveFunc(f)
    params = f.parse_bound()

    def objective_func(parameterisation):
        """Adapt the benchmark to the Ax evaluation-function interface.

        Reads parameters ``x_0 .. x_{D-1}`` (D = ``f.one_hot_dims``) from the Ax
        parameterisation dict. Evaluates ``f`` on them as a numpy vector and returns
        ``{'objective': (value, 0.0)}``, i.e. the observation is reported with SEM 0.
        """
        x = np.array([parameterisation[f'x_{i}'] for i in range(f.one_hot_dims)])
        return {'objective': (f(x), 0.0)}

    print('----- Starting trial %d / %d -----' % ((t + 1), args.n_trials))

    # NOTE(review): --infer_noise_var and the problem's `lamda` are not passed to ALEBO.
    # Observations are reported with SEM 0.0 (see objective_func above).
    if args.optimizer == 'alebo':
        bo = ALEBOStrategy(D=f.one_hot_dims, d=args.d_embed, init_size=args.n_init)
    else:
        raise NotImplementedError(args.optimizer + ' is not implemented as a valid optimizer choice!')

    _, _, experiment, _ = optimize(
        parameters=params,
        experiment_name=type(f).__name__,
        objective_name="objective",
        evaluation_function=objective_func,
        minimize=True,
        total_trials=args.max_iters,
        generation_strategy=bo,
    )

    # ---- Record the optimisation trace ----
    fX = np.array([trial.objective_mean for trial in experiment.trials.values()])
    n_evals = fX.shape[0]
    # Keep the historical NaN-padded 2000-row layout, but grow it when there are more
    # evaluations. Otherwise the slice assignments below fail with a length mismatch.
    res = pd.DataFrame(np.nan, index=np.arange(max(2000, n_evals)),
                       columns=['Index', 'LastValue', 'BestValue', 'Time'])
    res.iloc[:n_evals, 1] = fX
    # Running best-so-far (minimisation); O(n) instead of re-scanning each prefix.
    res.iloc[:n_evals, 2] = np.minimum.accumulate(fX)
    if save_path is not None:
        with open(os.path.join(save_path, 'trial-%d.pickle' % t), 'wb') as fh:
            pickle.dump(res, fh)
    if args.seed is not None:
        args.seed += 1



