# Runs the standard BO and TuRBO baselines on our benchmark problems
import sys
sys.path.append('..')
sys.path.append('../..')
from localglobal.baselines.TuRBO.turbo import Turbo1, TurboM
from localglobal.baselines.vanilla_bo import BO
from localglobal.test_funcs import *
from localglobal.mixed_test_func import *
from localglobal.baselines.modifiedobjfunc import ModifiedObjectiveFunc
import logging
import argparse
import os
import pickle
import pandas as pd
import time, datetime
from localglobal.test_funcs.random_seed_config import *


# Command-line options for the BO / TuRBO baseline runs.
parser = argparse.ArgumentParser('Run Experiments for TurBO baseline')
parser.add_argument('-p', '--problem', type=str, default='pest')
# 'bo' runs vanilla BO; 'turbo' runs TuRBO (Turbo1 for one trust region, TurboM otherwise).
parser.add_argument('-o', '--optimizer', type=str, default='bo')
parser.add_argument('-n', '--n_trust_regions', type=int, default=1)
parser.add_argument('--max_iters', type=int, default=150, help='Maximum number of BO iterations.')
parser.add_argument('--lamda', type=float, default=1e-6, help='the noise to inject for some problems')
# NOTE(review): --batch_size does not appear to be read anywhere in this script.
parser.add_argument('--batch_size', type=int, default=1, help='batch size for BO.')
parser.add_argument('--n_trials', type=int, default=20)
parser.add_argument('--n_init', type=int, default=20)
parser.add_argument('--save_path', type=str, default='output/baselines/bo/')
parser.add_argument('--ard', action='store_true')
# The help strings below are split across lines; each fragment ends with a space so the
# concatenated text reads correctly.
parser.add_argument('--random_seed_objective', type=int, default=20, help='The default value of 20 is provided also in '
                                                                          'COMBO')
parser.add_argument('-d', '--debug', action='store_true', help='Whether to turn on debugging mode (a lot of output will '
                                                               'be generated).')
parser.add_argument('--no_save', action='store_true', help='If activated, do not save the current run into a log folder.')
parser.add_argument('--seed', type=int, default=None, help='**initial** seed setting')
parser.add_argument('--data_dir', default='./data/')
parser.add_argument('--normalize', action='store_true')
parser.add_argument('--infer_noise_var', action='store_true')
parser.add_argument('--random_restart', default=False, action='store_true')


args = parser.parse_args()
# Plain dict view of the parsed options, echoed to stdout for the run log.
options = vars(args)
print(options)

# Guided restarts are only wired into Turbo1; TurboM is constructed without that option, so warn
# when the user asked for TuRBO with several trust regions but did not request random restarts.
if args.optimizer == 'turbo' and args.n_trust_regions > 1 and not args.random_restart:
    logging.warning('Guided restart is not yet implemented for TuRBO-M. The random_restart flag is ignored!')


# Timestamp of the form YYYYmmdd_HHMMSS; it names this run's output directory.
time_string = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

# Debug mode turns on INFO-level logging.
if args.debug:
    logging.basicConfig(level=logging.INFO)
    # basicConfig() does nothing if the root logger already has handlers -- e.g. after an
    # earlier module-level logging.warning() call auto-configured it at WARNING level -- so set
    # the level explicitly as well.
    logging.getLogger().setLevel(logging.INFO)

# Create the output folder <save_path>/<problem>/<time_string> and record the parsed options in
# command.txt so the experiment can be reproduced. save_path is None when saving is disabled.
if not args.no_save:
    save_path = os.path.join(args.save_path, args.problem, time_string)
    os.makedirs(save_path, exist_ok=True)
    with open(os.path.join(save_path, 'command.txt'), 'w+') as option_file:
        option_file.write(str(options))
else:
    save_path = None


# Problems whose instance is selected by --random_seed_objective, a 1-based index (1..25) into
# the seed configurations returned by the generate_random_seed_* helpers.
_SEEDED_PROBLEMS = ('contamination', 'pest', 'pest_plus', 'ising')

# Minimum number of rows in each saved per-trial results table (shorter traces are NaN-padded).
_MIN_RESULT_ROWS = 2000


def _objective_seed_index(trial):
    """Return the 0-based seed-configuration index used in trial ``trial``.

    Trial 0 uses ``random_seed_objective - 1`` and every later trial steps down by one, so each
    trial runs on a different problem instance. ``args`` is not mutated.
    """
    return args.random_seed_objective - 1 - trial


def _seed_pair(random_seed_pair_, trial):
    """Pick ``(case_seed, init_seed)`` for ``trial`` from a ``{case_seed: init_seeds}`` mapping.

    The trial's seed index is split with ``divmod(index, 5)``: the quotient selects the (sorted)
    case seed and the remainder selects the (sorted) initialisation seed for that case.
    """
    case_idx, init_idx = divmod(_objective_seed_index(trial), 5)
    case_seed_ = sorted(random_seed_pair_.keys())[case_idx]
    init_seed_ = sorted(random_seed_pair_[case_seed_])[init_idx]
    return case_seed_, init_seed_


def _build_objective(trial):
    """Instantiate the (unwrapped) objective selected by ``args.problem`` for trial ``trial``.

    :raises ValueError: if ``args.problem`` is not a recognised problem name.
    """
    problem = args.problem
    if problem == 'contamination':
        return Contamination(lamda=args.lamda,
                             random_seed_pair=_seed_pair(generate_random_seed_pair_contamination(), trial),
                             normalize=False)
    if problem == 'pest':
        random_seed_ = sorted(generate_random_seed_pestcontrol())[_objective_seed_index(trial)]
        return PestControl(random_seed=random_seed_, normalize=args.normalize)
    if problem == 'pest_plus':
        random_seed_ = sorted(generate_random_seed_pestcontrol())[_objective_seed_index(trial)]
        return PestControlDifficult(random_seed=random_seed_, normalize=args.normalize)
    if problem == 'branin':
        return Branin(normalize=args.normalize)
    if problem == 'ising':
        return Ising(lamda=args.lamda,
                     random_seed_pair=_seed_pair(generate_random_seed_pair_ising(), trial),
                     normalize=args.normalize)
    if problem == 'func2C':
        return Func2C(lamda=args.lamda, normalize=args.normalize)
    if problem == 'func3C':
        return Func3C(lamda=args.lamda, normalize=args.normalize)
    if problem == 'ackley53':
        return Ackley53(lamda=args.lamda, normalize=args.normalize)
    if problem == 'ackley200':
        return ContAckley(lamda=args.lamda, d=200)
    if problem == 'ackley106':
        return Ackley106(lamda=args.lamda)
    if problem == 'rosen200':
        return Rosen200(lamda=args.lamda)
    if problem == 'xgboost-mnist':
        return XGBoostOptTask(lamda=args.lamda, task='mnist', seed=args.seed, normalize=args.normalize)
    if problem == 'svm-boston':
        return SVMOptTask(lamda=0., task='boston', seed=args.seed, normalize=args.normalize)
    if problem == 'nasbench101':
        cache_path = os.path.join(args.data_dir, 'nasbench101.pickle')
        try:
            # NOTE: unpickling can execute arbitrary code -- only point --data_dir at trusted files.
            with open(cache_path, 'rb') as cache_file:
                return pickle.load(cache_file)
        except Exception as err:  # best-effort cache: fall back to building from the raw data
            logging.warning('Could not load %s (%r); building NASBench101 from %s instead.',
                            cache_path, err, args.data_dir)
            return NASBench101(data_dir=args.data_dir)
    if problem == 'MaxSAT60':
        return MaxSAT60()
    raise ValueError('Unrecognised problem type %s' % args.problem)


def _build_optimizer(f):
    """Construct the optimizer selected by ``args.optimizer`` for the wrapped objective ``f``.

    :raises NotImplementedError: if ``args.optimizer`` is neither 'turbo' nor 'bo'.
    """
    common = dict(f=f, lb=f.lb, ub=f.ub, n_init=args.n_init, max_evals=args.max_iters,
                  use_ard=args.ard, verbose=True)
    if args.optimizer == 'turbo':
        if args.n_trust_regions == 1:
            return Turbo1(guided_restart=not args.random_restart, **common)
        # Forward -n: the upstream TuRBO TurboM constructor takes n_trust_regions.
        return TurboM(n_trust_regions=args.n_trust_regions, **common)
    if args.optimizer == 'bo':
        return BO(**common)
    raise NotImplementedError(args.optimizer + ' is not implemented as a valid optimizer choice!')


def _results_frame(fX):
    """Tabulate one optimisation trace.

    Column 'LastValue' holds each evaluated objective value and 'BestValue' the running minimum
    up to and including that evaluation; 'Index' and 'Time' are left as NaN. The table has
    max(_MIN_RESULT_ROWS, number of evaluations) rows, NaN-padded past the last evaluation.
    """
    values = np.ravel(fX)  # accept both (n,) and (n, 1) traces
    n_evals = values.shape[0]
    res = pd.DataFrame(np.nan, index=np.arange(max(_MIN_RESULT_ROWS, n_evals)),
                       columns=['Index', 'LastValue', 'BestValue', 'Time'])
    res.iloc[:n_evals, 1] = values
    res.iloc[:n_evals, 2] = np.minimum.accumulate(values)
    return res


# Validate the objective seed before running anything: trial t uses index
# random_seed_objective - 1 - t, which must stay within [0, 24] for every trial.
if args.problem in _SEEDED_PROBLEMS and args.random_seed_objective is not None and args.n_trials > 0:
    if not args.n_trials <= args.random_seed_objective <= 25:
        raise ValueError('--random_seed_objective must be in [%d, 25] when running %d trials of %s, got %d'
                         % (args.n_trials, args.n_trials, args.problem, args.random_seed_objective))

for t in range(args.n_trials):
    f = ModifiedObjectiveFunc(_build_objective(t))

    print('----- Starting trial %d / %d -----' % ((t + 1), args.n_trials))
    # NOTE(review): --infer_noise_var, f.lamda and f.int_constrained_dims are not forwarded to
    # any optimizer here.
    bo = _build_optimizer(f)
    bo.optimize()

    # Undo output normalisation so the saved values are on the objective's original scale.
    fX = (bo.fX * f.std) + f.mean if f.normalize else bo.fX
    res = _results_frame(fX)
    if save_path is not None:
        with open(os.path.join(save_path, 'trial-%d.pickle' % t), 'wb') as out_file:
            pickle.dump(res, out_file)
    # args.seed (read by the xgboost-mnist / svm-boston problems) advances by one per trial.
    if args.seed is not None:
        args.seed += 1



