# Runs REMBO
import sys
sys.path.append('..')
sys.path.append('../..')
from localglobal.baselines.REMBO.rembo import REMBO
from localglobal.test_funcs import *
from localglobal.mixed_test_func import *
from localglobal.baselines.modifiedobjfunc import ModifiedObjectiveFunc
import logging
import argparse
import os
import pickle
import pandas as pd
import time, datetime
from localglobal.test_funcs.random_seed_config import *


# Command-line interface for the REMBO baseline experiment script.
# The text is passed as `description=`: the first positional parameter of
# ArgumentParser is `prog`, which would otherwise replace the program name in the
# usage line.
parser = argparse.ArgumentParser(description='Run Experiments for REMBO baseline')
parser.add_argument('-p', '--problem', type=str, default='pest')
parser.add_argument('-o', '--optimizer', type=str, default='rembo')
parser.add_argument('--d_embed', type=int, default=5)
parser.add_argument('--max_iters', type=int, default=150, help='Maximum number of BO iterations.')
parser.add_argument('--lamda', type=float, default=1e-6, help='the noise to inject for some problems')
parser.add_argument('--batch_size', type=int, default=1, help='batch size for BO.')
parser.add_argument('--n_trials', type=int, default=20)
parser.add_argument('--n_init', type=int, default=20)
parser.add_argument('--save_path', type=str, default='output/baselines/rembo/')
parser.add_argument('--ard', action='store_true')
# Adjacent string literals are concatenated verbatim, so the separating space
# has to be part of one of the literals.
parser.add_argument('--random_seed_objective', type=int, default=20,
                    help='The default value of 20 is provided also in '
                         'COMBO')
parser.add_argument('-d', '--debug', action='store_true',
                    help='Whether to turn on debugging mode (a lot of output will '
                         'be generated).')
parser.add_argument('--no_save', action='store_true', help='If activated, do not save the current run into a log folder.')
parser.add_argument('--seed', type=int, default=None, help='**initial** seed setting')
parser.add_argument('--data_dir', default='./data/')
parser.add_argument('--normalize', action='store_true')
parser.add_argument('--infer_noise_var', action='store_true')
parser.add_argument('--random_restart', default=False, action='store_true')
# NOTE(review): --n_init, --ard and --random_restart are parsed but do not appear
# to be referenced in the visible part of this script - confirm before removing.


args = parser.parse_args()
# Plain-dict view of the parsed arguments; echoed so every run logs its configuration.
options = vars(args)
print(options)

# Timestamp (YYYYmmdd_HHMMSS) used as the name of this run's output directory.
time_string = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

if args.debug:
    logging.basicConfig(level=logging.INFO)

# Create the output folder <save_path>/<problem>/<timestamp> and store the parsed
# arguments next to the results so the experiment can be reproduced.
# With --no_save nothing is written and save_path stays None.
if args.no_save:
    save_path = None
else:
    save_path = os.path.join(args.save_path, args.problem, time_string)
    # exist_ok=True avoids the race between an existence check and the creation.
    os.makedirs(save_path, exist_ok=True)
    # The context manager closes the file even if the write raises.
    with open(os.path.join(save_path, 'command.txt'), 'w') as option_file:
        option_file.write(str(options))


def ensure_not_1D(x):
    """
    Promote a one-dimensional input of size (D,) to a single-row input of size (1, D).

    Inputs whose ``ndim`` is not 1 are returned untouched. A 1D input that is
    neither a numpy array nor a torch tensor is also returned untouched.

    :param x: np.ndarray or torch.Tensor (anything exposing ``ndim``)
    :return: x with a leading axis of length 1 added when it was 1D, otherwise x itself
    """
    import torch

    # Guard clause: nothing to do unless the input is exactly 1D.
    if x.ndim != 1:
        return x

    if isinstance(x, np.ndarray):
        return np.expand_dims(x, axis=0)
    if isinstance(x, torch.Tensor):
        return x.unsqueeze(0)
    return x


# --random_seed_objective is 1-based on the command line; validate it once and
# convert it to a 0-based index WITHOUT mutating `args`. Decrementing
# args.random_seed_objective inside the trial loop would shift the objective
# instance on every trial and fail the range check as soon as n_trials exceeds
# the seed value.
if args.random_seed_objective is not None:
    if not 1 <= int(args.random_seed_objective) <= 25:
        raise ValueError('--random_seed_objective must be in [1, 25], got %s' % args.random_seed_objective)
    objective_seed_idx = int(args.random_seed_objective) - 1
else:
    objective_seed_idx = None

# Each trial builds a fresh objective and optimizer, runs up to max_iters BO
# iterations and (unless --no_save) pickles the running results table.
for t in range(args.n_trials):
    # ---- Build the objective function for the requested problem ----
    if args.problem == 'contamination':
        random_seed_pair_ = generate_random_seed_pair_contamination()
        # The index selects (case seed, init seed) as (index // 5, index % 5).
        case_idx, init_idx = divmod(objective_seed_idx, 5)
        case_seed_ = sorted(random_seed_pair_.keys())[case_idx]
        init_seed_ = sorted(random_seed_pair_[case_seed_])[init_idx]
        f = Contamination(lamda=args.lamda,
                          random_seed_pair=(case_seed_, init_seed_), normalize=False)
    elif args.problem == 'pest':
        random_seed_ = sorted(generate_random_seed_pestcontrol())[objective_seed_idx]
        f = PestControl(random_seed=random_seed_, normalize=args.normalize)
    elif args.problem == 'pest_plus':
        random_seed_ = sorted(generate_random_seed_pestcontrol())[objective_seed_idx]
        f = PestControlDifficult(random_seed=random_seed_, normalize=args.normalize)
    elif args.problem == 'branin':
        f = Branin(normalize=args.normalize)
    elif args.problem == 'ising':
        random_seed_pair_ = generate_random_seed_pair_ising()
        case_idx, init_idx = divmod(objective_seed_idx, 5)
        case_seed_ = sorted(random_seed_pair_.keys())[case_idx]
        init_seed_ = sorted(random_seed_pair_[case_seed_])[init_idx]
        f = Ising(lamda=args.lamda, random_seed_pair=(case_seed_, init_seed_), normalize=args.normalize)
    elif args.problem == 'func2C':
        f = Func2C(lamda=args.lamda, normalize=args.normalize)
    elif args.problem == 'func3C':
        f = Func3C(lamda=args.lamda, normalize=args.normalize)
    elif args.problem == 'ackley53':
        f = Ackley53(lamda=args.lamda, normalize=args.normalize)
    elif args.problem == 'ackley200':
        f = ContAckley(lamda=args.lamda, d=200)
    elif args.problem == 'ackley106':
        f = Ackley106(lamda=args.lamda)
    elif args.problem == 'rosen200':
        f = Rosen200(lamda=args.lamda)
    elif args.problem == 'xgboost-mnist':
        f = XGBoostOptTask(lamda=args.lamda, task='mnist', seed=args.seed, normalize=args.normalize)
    elif args.problem == 'svm-boston':
        f = SVMOptTask(lamda=0., task='boston', seed=args.seed, normalize=args.normalize)
    elif args.problem == 'nasbench101':
        nas_pickle_path = os.path.join(args.data_dir, 'nasbench101.pickle')
        try:
            # NOTE(security): unpickling can execute arbitrary code - only point
            # --data_dir at trusted files.
            with open(nas_pickle_path, 'rb') as nas_file:
                f = pickle.load(nas_file)
        except Exception:
            # Deliberate best-effort fallback: if the cached pickle is missing or
            # unreadable, construct the benchmark from data_dir instead.
            f = NASBench101(data_dir=args.data_dir, )
    elif args.problem == 'MaxSAT60':
        f = MaxSAT60()
    else:
        raise ValueError('Unrecognised problem type %s' % args.problem)

    f = ModifiedObjectiveFunc(f)
    # One [lower, upper] pair per input dimension.
    boundaries = np.array([f.lb, f.ub]).T.tolist()

    print('----- Starting trial %d / %d -----' % ((t + 1), args.n_trials))
    # Pre-allocated results table; only 'LastValue' and 'BestValue' are filled below.
    res = pd.DataFrame(np.nan, index=np.arange(2000), columns=['Index', 'LastValue', 'BestValue', 'Time'])

    # NOTE(review): noise_variance and int_dims are computed here but are not
    # passed to the REMBO constructor below - confirm whether they should be.
    if args.infer_noise_var:
        noise_variance = None
    else:
        noise_variance = f.lamda if hasattr(f, 'lamda') else None

    if hasattr(f, 'int_constrained_dims'):
        int_dims = f.int_constrained_dims
    else:
        int_dims = None

    if args.optimizer == 'rembo':
        bo = REMBO(d_embedding=args.d_embed, maxf=100, original_boundaries=boundaries)
    else:
        raise NotImplementedError(args.optimizer + ' is not implemented as a valid optimizer choice!')

    for i in range(args.max_iters):
        try:
            X_queries, X_queries_embedded = bo.select_query_point(batch_size=args.batch_size)
        except Exception:  # usually due to matrix not being psd
            # `except Exception` (not a bare except) so Ctrl-C and SystemExit still
            # propagate; the traceback is logged instead of being discarded.
            print('Error occurred. Terminating current trial!')
            logging.exception('select_query_point failed at iteration %d of trial %d', i, t)
            break

        # Ensure not 1D (i.e. size (D,))
        X_queries = ensure_not_1D(X_queries)

        # Evaluate the batch of query points 1-by-1
        for row_idx, X_query in enumerate(X_queries):
            X_query_embedded = X_queries_embedded[row_idx]

            # Ensure no 1D tensors (i.e. expand tensors of size (D,))
            X_query = ensure_not_1D(X_query)
            X_query_embedded = ensure_not_1D(X_query_embedded)

            # minimisation problem
            y_query = -f(X_query)
            bo.update(X_query, y_query[0], X_query_embedded)
        print(i, bo.y[-1])

        fX = bo.y.numpy().squeeze()
        if i > 1:
            n_obs = fX.shape[0]
            # Column 1 = 'LastValue', column 2 = 'BestValue' (running minimum).
            res.iloc[:n_obs, 1] = fX
            # Cumulative minimum in one pass instead of recomputing np.min per row.
            res.iloc[:n_obs, 2] = np.minimum.accumulate(fX)
            if save_path is not None:
                # Overwrite this trial's snapshot; `with` closes and flushes the file.
                with open(os.path.join(save_path, 'trial-%d.pickle' % t), 'wb') as out_file:
                    pickle.dump(res, out_file)

    # Advance the user-supplied seed so the next trial does not reuse it.
    if args.seed is not None:
        args.seed += 1



