# Xingchen Wan <xwan@robots.ox.ac.uk> | 2020

from localglobal.test_funcs import *
from localglobal.mixed_test_func import *
import numpy as np
import random
from localglobal.bo.optimizer import TurboOptimizer, TurboMOptimizer
from localglobal.bo.optimizer_mixed import MixedTurboOptimizer
import logging
import argparse
import os
import pickle
import pandas as pd
import time, datetime
from localglobal.test_funcs.random_seed_config import *
from localglobal.mixed_test_func.offline_rl import uncert_types
from localglobal.parse_string_map import parse_string_map

# Set up the command-line interface; every experiment setting is read from here.
parser = argparse.ArgumentParser('Run Experiments')
parser.add_argument('-p', '--problem', type=str, default='offline_rl',
                    help='Name of the benchmark problem to optimise.')
parser.add_argument('-n', '--n_trust_regions', type=int, default=1,
                    help='Number of trust regions (>1 selects TuRBO-M; not supported for mixed search spaces).')
parser.add_argument('--max_iters', type=int, default=200,
                    help='Maximum number of BO iterations per trial.')
parser.add_argument('--lamda', type=float, default=1e-6, help='the noise to inject for some problems')
parser.add_argument('--batch_size', type=int, default=4, help='batch size for BO.')
parser.add_argument('--n_trials', type=int, default=50, help='Number of independent trials to run.')
parser.add_argument('--n_init', type=int, default=20,
                    help='Number of initial points passed to the optimiser.')
parser.add_argument('--save_path', type=str, default='output/',
                    help='Root output directory; results go to <save_path>/<problem>/<timestamp>/.')
parser.add_argument('--ard', action='store_true', default=False,
                    help='Use automatic relevance determination (passed to the optimiser as use_ard).')
parser.add_argument('-a', '--acq', type=str, default='thompson',
                    help='choice of the acquisition function: ucb, ei or thompson.')
parser.add_argument('--random_seed_objective', type=int, default=20,
                    help='1-indexed objective seed in [1, 25]. The default value of 20 is provided also in COMBO')
parser.add_argument('-d', '--debug', action='store_true',
                    help='Whether to turn on debugging mode (a lot of output will be generated).')
parser.add_argument('--no_save', action='store_true',
                    help='If activated, do not save the current run into a log folder.')
parser.add_argument('--seed', type=int, default=None,
                    help='**initial** seed setting (incremented by one after every trial)')
parser.add_argument('--global_bo', action='store_true',
                    help='whether to use global BO modelling only (disabling the local BO modelling)')
parser.add_argument('--random_restart', action='store_true',
                    help='whether to use random restarting strategy (vs UCB strategy)')
parser.add_argument('-k', '--kernel_type', type=str, default=None,
                    help='specifies the kernel type (default: "mixed" for mixed problems, "type2" otherwise)')
parser.add_argument('--data_dir', default='./data/',
                    help='Directory containing benchmark data files (used by nasbench101).')
parser.add_argument('--infer_noise_var', action='store_true',
                    help="Do not fix the noise variance to the objective's lamda (the optimiser gets None).")
# Offline-RL problem settings
parser.add_argument('--offline_rl_epochs', type=int, default=500,
                    help='Number of training epochs passed to the OfflineRL objective.')
parser.add_argument('--offline_rl_yaml', type=str, default=r"args_yml/bo_test_rig_hcmed.yml",
                    help='YAML config for OfflineRL; also the key used to look up pre-seeding data.')

args = parser.parse_args()
options = vars(args)
print(options)

# Timestamp of this run (YYYYmmdd_HHMMSS); it names the output sub-directory.
time_string = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

# In debug mode, emit INFO-level log records.
if args.debug:
    logging.basicConfig(level=logging.INFO)

# Sanity checks. Explicit raises rather than asserts, so the checks survive `python -O`.
if args.n_trust_regions < 1:
    raise ValueError('n_trust_regions must be at least 1, got ' + str(args.n_trust_regions))
if args.acq not in ('ucb', 'ei', 'thompson'):
    raise ValueError('Unknown acquisition function choice ' + str(args.acq))

# Create the output folder (<save_path>/<problem>/<timestamp>) and record the parsed arguments there,
# so that the experiment can be reproduced. With --no_save nothing is written and save_path is None.
if args.no_save:
    save_path = None
else:
    save_path = os.path.join(args.save_path, args.problem, time_string)
    os.makedirs(save_path, exist_ok=True)  # no exists()/makedirs race
    with open(os.path.join(save_path, "command.txt"), "w") as option_file:
        option_file.write(str(options))

def _make_objective(args, seed_idx, t):
    """Instantiate the benchmark selected by ``args.problem``.

    :param args: parsed command-line arguments.
    :param seed_idx: 0-based objective seed index (``--random_seed_objective`` minus one); read by the
        contamination, ising and pest* problems.
    :param t: 0-based trial index; nasbench101 seeds ``numpy`` and ``random`` with it.
    :return: ``(f, kwargs)``, i.e. the objective and the problem-specific keyword arguments forwarded to
        the optimiser constructor (an empty dict keeps the optimiser defaults).
    :raises ValueError: if ``args.problem`` is not recognised.
    """
    problem = args.problem
    kwargs = {}
    if problem == 'contamination':
        random_seed_pair_ = generate_random_seed_pair_contamination()
        case_seed_ = sorted(random_seed_pair_.keys())[seed_idx // 5]
        init_seed_ = sorted(random_seed_pair_[case_seed_])[seed_idx % 5]
        f = Contamination(lamda=args.lamda, random_seed_pair=(case_seed_, init_seed_))
        kwargs = {
            'length_max_discrete': 25,
            'length_init_discrete': 5,
            'failtol': 40,
        }
    elif problem == 'pest':
        random_seed_ = sorted(generate_random_seed_pestcontrol())[seed_idx]
        f = PestControl(random_seed=random_seed_)
        kwargs = {
            'length_max_discrete': 25,
            'length_init_discrete': 25,
            'failtol': 50,
        }
    elif problem == 'pest_plus':
        random_seed_ = sorted(generate_random_seed_pestcontrol())[seed_idx]
        f = PestControlDifficult(random_seed=random_seed_)
        kwargs = {
            'length_max_discrete': 40,
            'length_init_discrete': 10,
            'succtol': 3,
            'failtol': 30,
        }
    elif problem == 'pest_restart_ablation':
        random_seed_ = sorted(generate_random_seed_pestcontrol())[seed_idx]
        f = PestControl(random_seed=random_seed_)
        kwargs = {
            'length_max_discrete': 25,
            'length_init_discrete': 5,
            'failtol': 5,
            'succtol': 5,
            'tr_multiplier': 2.,
        }
    elif problem == 'branin':
        f = Branin(normalize=False)
    elif problem == 'ising':
        # TODO: fully validate the Ising Sparsification problem to ensure that this produces the desired behaviours.
        random_seed_pair_ = generate_random_seed_pair_ising()
        case_seed_ = sorted(random_seed_pair_.keys())[seed_idx // 5]
        init_seed_ = sorted(random_seed_pair_[case_seed_])[seed_idx % 5]
        f = Ising(lamda=args.lamda, random_seed_pair=(case_seed_, init_seed_))
    elif problem == 'func2C':
        f = Func2C(lamda=args.lamda)
    elif problem == 'func3C':
        f = Func3C(lamda=args.lamda)
    elif problem == 'ackley53':
        f = Ackley53(lamda=args.lamda)
        kwargs = {
            'length_max_discrete': 50,
            'length_init_discrete': 15,
            'failtol': 40,
        }
    elif problem == 'ackley106':
        f = Ackley106(lamda=args.lamda)
        kwargs = {
            'length_max_discrete': 100,
            'length_init_discrete': 15,
            'failtol': 40,
        }
    elif problem == 'rosen200':
        f = Rosen200(lamda=args.lamda)
        kwargs = {
            'length_max_discrete': 100,
            'length_init_discrete': 15,
            'failtol': 40,
        }
    elif problem == 'MaxSAT60':
        f = MaxSAT60()
        kwargs = {
            'length_max_discrete': 60,
            'failtol': 40,
        }
    elif problem == 'xgboost-mnist':
        f = XGBoostOptTask(lamda=args.lamda, task='mnist', seed=args.seed)
    elif problem == 'svm-boston':
        f = SVMOptTask(lamda=0., task='boston', seed=args.seed)
        kwargs = {
            'length_max_discrete': 6,
            'length_init_discrete': 6,
            'failtol': 40,
        }
    elif problem == 'nasbench101':
        f = NASBench101(data_dir=args.data_dir)
        # Fix the random seed, as per CoCaBO practice
        np.random.seed(t)
        random.seed(t)
        kwargs = {
            'length_max_discrete': 5,
            'length_init_discrete': 5,
            'failtol': 20,
        }
    elif problem == 'offline_rl':
        # Uses the optimiser's default trust-region settings (kwargs stays empty).
        f = OfflineRL(n_epochs=args.offline_rl_epochs, yaml_file=args.offline_rl_yaml)
    else:
        raise ValueError('Unrecognised problem type %s' % problem)
    return f, kwargs


def _make_optimizer(f, kwargs, args):
    """Create the optimiser for objective ``f`` (mixed, single trust region, or TuRBO-M).

    :param f: objective returned by ``_make_objective``.
    :param kwargs: problem-specific keyword arguments forwarded to the optimiser constructor.
    :param args: parsed command-line arguments.
    :return: the optimiser instance.
    :raises NotImplementedError: if several trust regions are requested for a mixed search space.
    :raises ValueError: if several trust regions are combined with ``--global_bo``.
    """
    # With --infer_noise_var the optimiser receives None; otherwise the objective's `lamda` (if it has one).
    noise_variance = None if args.infer_noise_var else getattr(f, 'lamda', None)
    int_dims = getattr(f, 'int_constrained_dims', None)
    problem_type = f.problem_type

    if args.kernel_type is not None:
        kernel_type = args.kernel_type
    else:
        kernel_type = 'mixed' if problem_type == 'mixed' else 'type2'

    if problem_type == 'mixed':
        if args.n_trust_regions != 1:
            raise NotImplementedError("TurBO-M optimiser for mixed search space is not yet implemented.")
        return MixedTurboOptimizer(f.config, f.lb, f.ub, f.continuous_dims, f.categorical_dims,
                                   int_constrained_dims=int_dims,
                                   n_init=args.n_init, use_ard=args.ard, acq=args.acq,
                                   kernel_type=kernel_type,
                                   noise_variance=noise_variance,
                                   guided_restart=not args.random_restart,
                                   global_bo=args.global_bo,
                                   **kwargs)

    if args.n_trust_regions == 1:
        return TurboOptimizer(f.config, n_init=args.n_init, use_ard=args.ard, acq=args.acq,
                              global_bo=args.global_bo,
                              kernel_type=kernel_type,
                              guided_restart=not args.random_restart,
                              noise_variance=noise_variance, **kwargs)

    if args.global_bo:
        raise ValueError("Multiple trust region setting is not compatible with global BO!")
    return TurboMOptimizer(f.config, n_trust_regions=args.n_trust_regions, n_init=args.n_init,
                           use_ard=args.ard,
                           kernel_type=kernel_type,
                           acq=args.acq, noise_variance=noise_variance, **kwargs)


def _preseed_offline_rl(turbo, args):
    """Feed previously evaluated offline-RL configurations to ``turbo`` before the BO loop.

    The stored configurations and values are looked up in ``parse_string_map`` under
    ``args.offline_rl_yaml``. Each non-blank configuration line is whitespace-tokenised and tokens
    3, 5, 7 and 9 are read as (uncertainty-type name, int, float, int); the name is encoded as its
    index in ``uncert_types``.

    :raises ValueError: if the number of stored configurations and values differ.
    """
    x_str, y_str = parse_string_map[args.offline_rl_yaml]
    xs = []
    for line in x_str.splitlines():
        tokens = line.split()
        if not tokens:  # tolerate blank lines (e.g. from triple-quoted literals)
            continue
        xs.append([uncert_types.index(tokens[3]), int(tokens[5]), float(tokens[7]), int(tokens[9])])
    xs = np.array(xs)
    ys = np.array([float(v) for v in y_str.split()])
    if len(xs) != len(ys):
        raise ValueError('Pre-seeding data mismatch: %d configurations but %d values' % (len(xs), len(ys)))

    print(xs)
    print(ys)
    print(f"{len(xs)} old samples added for pre-seeding.")

    turbo.suggest_init_(args.batch_size)
    turbo.observe(xs, ys, override_n_evals=True)
    # Trim down the initial points.
    turbo.trim_init_(n_initial=len(xs))


def _record_batch(res, Y, i, batch_size, elapsed):
    """Write the results of BO iteration ``i`` into the results table ``res``.

    :param res: table with columns Index / LastValue / BestValue / Time and one row per evaluation.
    :param Y: 1-D array of every objective value observed so far; its last ``batch_size`` entries
        are the evaluations of iteration ``i``.
    :param i: 0-based iteration index; rows ``i * batch_size`` .. ``(i + 1) * batch_size - 1`` are filled.
    :param batch_size: number of points evaluated per iteration.
    :param elapsed: wall-clock seconds spent on the iteration (shared by every row of the batch).
    """
    first = len(Y) - batch_size  # position of the current batch inside Y
    for k in range(batch_size):
        row = i * batch_size + k
        # The best-so-far value includes all earlier observations and this batch up to point k.
        res.iloc[row, :] = [row, Y[first + k], np.min(Y[:first + k + 1]), elapsed]


# --random_seed_objective is 1-indexed on the command line. Convert it once; previously it was
# decremented on every trial, which shifted the objective between trials and failed the range check
# after 20 trials with the defaults.
if args.random_seed_objective is None:
    seed_idx = None
else:
    if not 1 <= int(args.random_seed_objective) <= 25:
        raise ValueError('--random_seed_objective must be in [1, 25], got %s' % args.random_seed_objective)
    seed_idx = int(args.random_seed_objective) - 1

for t in range(args.n_trials):
    f, kwargs = _make_objective(args, seed_idx, t)

    print('----- Starting trial %d / %d -----' % ((t + 1), args.n_trials))
    res = pd.DataFrame(np.nan, index=np.arange(int(args.max_iters * args.batch_size)),
                       columns=['Index', 'LastValue', 'BestValue', 'Time'])

    turbo = _make_optimizer(f, kwargs, args)

    if args.problem == 'offline_rl':
        _preseed_offline_rl(turbo, args)

    for i in range(args.max_iters):
        start = time.time()
        x_next = turbo.suggest(args.batch_size)
        y_next = f.compute(x_next, normalize=f.normalize)
        turbo.observe(x_next, y_next)
        end = time.time()

        # All observed values so far, mapped back to the original scale when the objective normalises.
        Y = np.asarray(turbo.turbo.fX, dtype=float)
        if f.normalize:
            Y = Y * f.std + f.mean
        Y = Y.ravel()

        _record_batch(res, Y, i, args.batch_size, end - start)

        best = int(np.argmin(Y))
        print('Iter %d, Last X %s; fX:  %.4f. X_best: %s, fX_best: %.4f'
              % (i, x_next.flatten(),
                 Y[-1],
                 np.asarray(turbo.turbo.X)[best].flatten(),
                 Y[best]))

        if save_path is not None:
            with open(os.path.join(save_path, 'trial-%d.pickle' % t), 'wb') as result_file:
                pickle.dump(res, result_file)

    if args.seed is not None:
        args.seed += 1
