"""
比较不同的初始化方法，对贝叶斯优化方法在bandit问题上的影响。
总共比较四种初始化方法
1. lhs
2. lhd 
3. opd
4. ops
"""


import argparse
import pickle
import matplotlib.pyplot as plt
import numpy as np
import os
from scipy.optimize import minimize
import IPython
import copy
import logging

from utils.initialization import lhs, lhd, opd, opdnc, init_lhs, init_lhd, init_opdnc, init_opd


import sys
# sys.path.append('../bayesopt')
# sys.path.append('../ml_utils')
import argparse
import os
import testFunctions.syntheticFunctions
from methods.CoCaBO import CoCaBO
from methods.BatchCoCaBO import BatchCoCaBO


# Configure the root logger: INFO and above, each record prefixed with the
# timestamp and the source line number that emitted it.
logging.basicConfig(
    format="%(asctime)s %(lineno)d: %(message)s",
    level=logging.INFO,
)


# Continuous inputs shared by every synthetic problem below: (name, domain).
_CONTINUOUS_BOUNDS = (('x1', (-1, 1)), ('x2', (-1, 1)))

# obj_func name -> (function name in testFunctions.syntheticFunctions,
#                   number of values of each categorical input h1, h2, ...).
_SYNTHETIC_PROBLEMS = {
    'func2C': ('func2C', (3, 5)),
    'budget_func2C': ('budget_func2C', (3, 5, 10)),
    'func3C': ('func3C', (3, 5, 4)),
    'budget_func3C': ('budget_func3C', (3, 5, 4, 10)),
    'rosenbrock': ('rosenbrock', (10,)),
    'sixhumpcamp': ('sixhumpcamp', (10,)),
    # NOTE(review): 'beale' is wired to sixhumpcamp, as in the original code.
    # This looks like a copy-paste slip -- confirm whether
    # testFunctions.syntheticFunctions provides a dedicated beale function.
    'beale': ('sixhumpcamp', (10,)),
}


def _get_problem(obj_func):
    """Return ``(f, categories, bounds)`` for the named synthetic problem.

    ``categories`` lists the number of values of each categorical input.
    ``bounds`` is a list of dicts: categorical inputs h1, h2, ... (domain
    ``0..n-1``) first, then the continuous inputs x1, x2 on (-1, 1).

    Raises:
        NotImplementedError: if ``obj_func`` is not a known problem.
    """
    try:
        fn_name, n_values = _SYNTHETIC_PROBLEMS[obj_func]
    except KeyError:
        raise NotImplementedError(
            f"Unknown objective function: {obj_func!r}") from None
    objective = getattr(testFunctions.syntheticFunctions, fn_name)
    categories = list(n_values)
    bounds = [{'name': f'h{i}', 'type': 'categorical', 'domain': tuple(range(n))}
              for i, n in enumerate(categories, start=1)]
    bounds += [{'name': name, 'type': 'continuous', 'domain': domain}
               for name, domain in _CONTINUOUS_BOUNDS]
    return objective, categories, bounds


def _load_kernel(kernel_path):
    """Unpickle and return the fitted GP stored at ``kernel_path``.

    The returned object is expected to expose ``compute_Ka`` (used by the
    "opd"/"opdnc" designs).
    """
    # SECURITY: pickle.load can execute arbitrary code; only load trusted,
    # locally produced kernel files.
    with open(kernel_path, "rb") as fh:
        return pickle.load(fh)


def _initial_design(initialization, bounds, kernel_path, nb):
    """Build the initial design over ``bounds`` with ``nb`` requested points.

    Raises:
        RuntimeError: if ``initialization`` is not one of
            "lhs", "lhd", "opdnc", "opd".
    """
    if initialization == "lhs":
        return init_lhs(bounds, nb=nb)
    if initialization == "lhd":
        return init_lhd(bounds, nb=nb)
    if initialization in ("opdnc", "opd"):
        design_fn = init_opdnc if initialization == "opdnc" else init_opd
        gp = _load_kernel(kernel_path)
        init_data, _ = design_fn(gp.compute_Ka, bounds, nb=nb)
        return init_data
    raise RuntimeError(f"Unknown initialization method: {initialization!r}")


def CoCaBO_Exps(obj_func, budget, initN=24, trials=40, kernel_mix=0.5, batch=None, initialization="lhd", save_path=None, kernel_path=None, nb_init=10):
    """Run CoCaBO trials on a synthetic problem starting from a chosen initial design.

    Args:
        obj_func: name of the synthetic problem (a key of ``_SYNTHETIC_PROBLEMS``).
        budget: optimisation budget forwarded to ``runTrials``.
        initN: ``initN`` forwarded to the optimizer constructor.
        trials: number of trials forwarded to ``runTrials``.
        kernel_mix: kernel mixture weight forwarded to the optimizer.
        batch: 1 selects sequential ``CoCaBO``. Any other value selects
            ``BatchCoCaBO`` with ``batch_size=batch``.
        initialization: initial design method: "lhs", "lhd", "opdnc" or "opd".
        save_path: directory for results. It is created if missing.
        kernel_path: pickled GP file used by "opd"/"opdnc".
        nb_init: number of initial design points.

    Raises:
        NotImplementedError: unknown ``obj_func``.
        RuntimeError: unknown ``initialization``.
    """
    saving_path = save_path
    os.makedirs(saving_path, exist_ok=True)

    f, categories, bounds = _get_problem(obj_func)
    n_cat = len(categories)

    # Run the CoCaBO algorithm: sequential when batch == 1, batched otherwise.
    if batch == 1:
        mabbo = CoCaBO(objfn=f, initN=initN, bounds=bounds,
                       acq_type='LCB', C=categories,
                       kernel_mix=kernel_mix)
    else:
        mabbo = BatchCoCaBO(objfn=f, initN=initN, bounds=bounds,
                            acq_type='LCB', C=categories,
                            kernel_mix=kernel_mix,
                            batch_size=batch)

    # Pick an initial design suited to the mixed categorical/continuous inputs.
    # int() guards against nb_init arriving as a string from the CLI.
    init_data = _initial_design(initialization, bounds, kernel_path, int(nb_init))

    # Columns after the first n_cat are the continuous inputs (presumably laid
    # out in the same order as `bounds`, categorical first -- confirm).
    # NOTE(review): f receives the fixed list [0, ..., n_cat-1] rather than the
    # categorical columns init_data[:, :n_cat]; confirm this matches the
    # signature of the synthetic functions.
    init_result = f(list(range(n_cat)), init_data[:, n_cat:])

    print(f"init_data:{init_data.shape} {init_data}")
    print(f"init_result:{init_result.shape} {init_result}")

    mabbo.runTrials(trials, budget, saving_path, initData=[init_data], initResult=[init_result])


def _build_arg_parser():
    """Return the command-line parser for this experiment script."""
    parser = argparse.ArgumentParser(description="Run BayesOpt Experiments")
    parser.add_argument('-f', '--func', help='Objective function',
                        default='budget_func2C', type=str)
    parser.add_argument('-mix', '--kernel_mix',
                        help='Mixture weight for production and summation kernel. Default = 0.0', default=0.0,
                        type=float)
    parser.add_argument('-n', '--max_itr', help='Max Optimisation iterations. Default = 100',
                        default=100, type=int)
    parser.add_argument('-tl', '--trials', help='Number of random trials. Default = 20',
                        default=20, type=int)
    parser.add_argument('-b', '--batch', help='Batch size (>1 for batch CoCaBO and =1 for sequential CoCaBO). Default = 1',
                        default=1, type=int)
    parser.add_argument("-init", "--initialization", help="initialization method", default="lhs", choices=["lhs", "lhd", "opd", "opdnc"])
    parser.add_argument("--save_path", default="log/1_",
                        help="Output directory prefix; the objective name is appended.")
    # "gp.pkl" is joined onto this value below, so it must name a directory.
    # The old default already ended in "/gp.pkl", which produced ".../gp.pkl/gp.pkl".
    parser.add_argument("--kernel_path", default="log/1_lhd_budget_func2C",
                        help="Directory containing gp.pkl (used by opd/opdnc).")
    # type=int so values passed on the command line are not left as strings.
    parser.add_argument("--nb_init", default=10, type=int,
                        help="Number of initial design points. Default = 10")
    return parser


if __name__ == '__main__':
    args = _build_arg_parser().parse_args()
    print(f"Got arguments: \n{args}")
    obj_func = args.func
    kernel_mix = args.kernel_mix
    n_itrs = args.max_itr
    n_trials = args.trials
    batch = args.batch
    initialization = args.initialization
    save_path = f"{args.save_path}{obj_func}/"

    CoCaBO_Exps(obj_func=obj_func, budget=n_itrs,
                trials=n_trials, kernel_mix=kernel_mix, batch=batch,
                initialization=initialization, save_path=save_path,
                kernel_path=os.path.join(args.kernel_path, "gp.pkl"),
                nb_init=args.nb_init)
