# Xingchen Wan <xwan@robots.ox.ac.uk> | 2021
# Study the hyperparameter sensitivity of the method on selected problems.

from localglobal.test_funcs import *
from localglobal.mixed_test_func import *
import numpy as np
import random
from localglobal.bo.optimizer import TurboOptimizer, TurboMOptimizer
from localglobal.bo.optimizer_mixed import MixedTurboOptimizer
import logging
import argparse
import os
import pickle
import pandas as pd
import time, datetime
from localglobal.test_funcs.random_seed_config import *

# Command-line options for the sensitivity study.
# NOTE: the __main__ block below overrides ``args.problem`` (used only for the save path) and
# ``args.n_trials`` / ``args.max_iters`` per experiment, so those flags act only as initial values.
parser = argparse.ArgumentParser('Run Sensitivity Studies')
parser.add_argument('-p', '--problem', type=str, default='pest')
parser.add_argument('--max_iters', type=int, default=150, help='Maximum number of BO iterations.')
parser.add_argument('--lamda', type=float, default=1e-6, help='the noise to inject for some problems')
parser.add_argument('--batch_size', type=int, default=1, help='batch size for BO.')
parser.add_argument('--n_trials', type=int, default=5)
parser.add_argument('--n_init', type=int, default=20)
parser.add_argument('--save_path', type=str, default='output/sensitivity/')
parser.add_argument('--ard', action='store_true')
# ``choices`` lets argparse reject unknown values with a usage error; unlike ``assert``,
# this check is not stripped when Python runs with -O.
parser.add_argument('-a', '--acq', type=str, default='ei', choices=['ucb', 'ei', 'thompson'],
                    help='choice of the acquisition function.')
parser.add_argument('--random_seed_objective', type=int, default=20,
                    help='The default value of 20 is provided also in COMBO')
parser.add_argument('-d', '--debug', action='store_true', help='Whether to turn on debugging mode (a lot of output will '
                                                               'be generated).')
parser.add_argument('--no_save', action='store_true',
                    help='If activated, do not save the current run into a log folder.')
parser.add_argument('--seed', type=int, default=None, help='**initial** seed setting')
parser.add_argument('--data_dir', default='./data/')
parser.add_argument('--infer_noise_var', action='store_true')
parser.add_argument('--grid_search', action='store_true')

args = parser.parse_args()
# ``vars`` returns args.__dict__ itself, so later attribute changes on ``args`` are visible here.
options = vars(args)
print(options)

if args.debug:
    logging.basicConfig(level=logging.INFO)


def run_exp(problem, kwargs, save_path):
    """Run ``args.n_trials`` independent optimisation trials on one problem.

    Each trial builds a fresh objective and optimizer, then performs
    ``args.max_iters`` suggest/evaluate/observe rounds of ``args.batch_size``
    points each. One row per evaluation is recorded in a DataFrame with columns
    ``Index, LastValue, BestValue, Time, TRCat, TRCont``, where ``BestValue`` is
    the best (minimum) objective value seen up to and including that evaluation.

    :param problem: ``'pest'`` selects PestControl with ``TurboOptimizer``; any
        other value selects Ackley53 with ``MixedTurboOptimizer``.
    :param kwargs: extra keyword arguments forwarded to the optimizer
        constructor (e.g. ``length_init_discrete``, ``failtol``).
    :param save_path: directory in which ``trial-<t>.pickle`` is rewritten after
        every iteration, or ``None`` to skip saving.
    """
    batch_size = args.batch_size
    for t in range(args.n_trials):
        if args.seed is not None:
            # Honour --seed ("initial" seed): offset by the trial index so trials differ
            # while the whole study stays reproducible.
            # NOTE(review): torch's RNG (if used by the optimizer) is not seeded here.
            np.random.seed(args.seed + t)
            random.seed(args.seed + t)

        if problem == 'pest':
            random_seed_ = sorted(generate_random_seed_pestcontrol())[args.random_seed_objective]
            f = PestControl(random_seed=random_seed_)
        else:
            f = Ackley53(lamda=args.lamda)
        print('----- Starting trial %d / %d -----' % ((t + 1), args.n_trials))
        res = pd.DataFrame(np.nan, index=np.arange(int(args.max_iters * batch_size)),
                           columns=['Index', 'LastValue', 'BestValue', 'Time', 'TRCat', 'TRCont'])
        if args.infer_noise_var:
            noise_variance = None
        else:
            noise_variance = f.lamda if hasattr(f, 'lamda') else None

        if problem == 'pest':
            turbo = TurboOptimizer(f.config, n_init=args.n_init, use_ard=args.ard, acq=args.acq,
                                   kernel_type='type2',
                                   noise_variance=noise_variance, **kwargs)
        else:
            turbo = MixedTurboOptimizer(f.config, f.lb, f.ub, f.continuous_dims, f.categorical_dims,
                                        n_init=args.n_init, use_ard=args.ard, acq=args.acq,
                                        kernel_type='mixed',
                                        noise_variance=noise_variance,
                                        **kwargs)

        for i in range(args.max_iters):
            start = time.time()
            x_next = turbo.suggest(args.batch_size)
            y_next = f.compute(x_next, normalize=f.normalize)
            turbo.observe(x_next, y_next)
            end = time.time()

            # Every observation so far, mapped back to the original scale and flattened to 1-D
            # so that indexing yields scalars.
            Y = np.array(turbo.turbo.fX)
            if f.normalize:
                Y = Y * f.std + f.mean
            Y = Y.ravel()
            # running_best[k] is the minimum over the first k + 1 evaluations.
            running_best = np.minimum.accumulate(Y)

            # The evaluations made in this iteration are the last ``batch_size`` entries of Y.
            first_new = len(Y) - batch_size
            for k in range(batch_size):
                row = i * batch_size + k
                res.iloc[row, :] = [row, float(Y[first_new + k]), float(running_best[first_new + k]),
                                    end - start, turbo.turbo.length_discrete, turbo.turbo.length]

            best_idx = int(np.argmin(Y))
            best_x = np.ravel(turbo.turbo.X[best_idx])
            print('Iter %d, Last X %s; fX:  %.4f. X_best: %s, fX_best: %.4f'
                  % (i, x_next.flatten(),
                     float(Y[-1]),
                     ''.join(str(int(v)) for v in best_x),
                     float(Y[best_idx])))

            if save_path is not None:
                # Checkpoint after every iteration so partial results survive an interrupted run.
                with open(os.path.join(save_path, 'trial-%d.pickle' % t), 'wb') as fh:
                    pickle.dump(res, fh)


def generate_save_path(func_args):
    """Create a fresh, timestamped output directory for one experiment run.

    The directory is ``<args.save_path>/<args.problem>/<YYYYmmdd_HHMMSS>``. If that
    path already exists (e.g. two runs started within the same second), a ``_1``,
    ``_2``, ... suffix is appended, so separate runs never share a directory or
    overwrite each other's files. A ``command.txt`` containing the CLI options
    merged with ``func_args`` is written into the new directory.

    :param func_args: per-run keyword arguments to record alongside the CLI options.
    :return: the created directory path, or ``None`` when ``--no_save`` is set.
    """
    if args.no_save:
        return None

    time_string = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    base_path = os.path.join(args.save_path, args.problem, time_string)
    save_path = base_path
    suffix = 0
    while True:
        try:
            # Without exist_ok, makedirs fails if the leaf already exists. That atomically
            # claims the directory, with no check-then-create race.
            os.makedirs(save_path)
            break
        except FileExistsError:
            suffix += 1
            save_path = '%s_%d' % (base_path, suffix)

    curr_option = dict(options)
    curr_option.update(func_args)
    with open(os.path.join(save_path, 'command.txt'), 'w+') as option_file:
        option_file.write(str(curr_option))
    return save_path


if __name__ == '__main__':
    def cross_search():
        """Sweep the TR multiplier (``tr_multiplier``) with the other TR settings fixed.

        Runs one trial per setting on PestControl (200 iterations) and then on
        Ackley53 (400 iterations).
        """
        args.n_trials = 1

        # --- 1. PestControl: sweep tr_multiplier ---
        args.max_iters = 200
        for multiplier in [1.2, 1.5, 2, 2.5]:
            kwgs = {'length_max_discrete': 25,  # this is fixed to the maximum dim
                    'length_init_discrete': 25,
                    'tr_multiplier': multiplier,
                    'failtol': 40}
            run_exp('pest', kwgs, generate_save_path(kwgs))

        # --- 2. Ackley53: sweep tr_multiplier ---
        args.max_iters = 400
        for multiplier in [1.2, 1.5, 2, 2.5]:
            kwgs = {'length_max_discrete': 50,  # this is fixed to the maximum dim
                    'length_init_discrete': 30,
                    'tr_multiplier': multiplier,
                    'failtol': 40}
            run_exp('ackley53', kwgs, generate_save_path(kwgs))

    def grid_search():
        """Grid over (``length_init_discrete``, ``failtol``), one trial per cell.

        Runs the PestControl grid (200 iterations) and then the Ackley53 grid
        (400 iterations).
        """
        args.n_trials = 1

        # --- 1. PestControl grid ---
        args.max_iters = 200
        for length_init in [5, 10, 15, 20, 25]:
            for fail_tol in [10, 25, 40, 55, 70]:
                kwgs = {'length_max_discrete': 25,  # this is fixed to the maximum dim
                        'length_init_discrete': length_init,
                        'failtol': fail_tol}
                run_exp('pest', kwgs, generate_save_path(kwgs))

        # --- 2. Ackley53 grid ---
        args.max_iters = 400
        for length_init in [10, 20, 30, 40, 50]:
            for fail_tol in [20, 30, 40, 50, 60]:
                # Fixed to the maximum dim as in cross_search. The previous value of 25 was below
                # most length_init values in this grid (30, 40, 50).
                kwgs = {'length_max_discrete': 50,
                        'length_init_discrete': length_init,
                        'failtol': fail_tol}
                run_exp('ackley53', kwgs, generate_save_path(kwgs))

    # All runs of this script are saved under <save_path>/sensitivity/, regardless of --problem.
    args.problem = 'sensitivity'
    if args.grid_search:
        grid_search()
    else:
        cross_search()
