# -*- coding: utf-8 -*-
# ==========================================
# Title:  run_cocabo_exps.py
# Author: Binxin Ru and Ahsan Alvi
# Date:   20 August 2019
# Link:   https://arxiv.org/abs/1906.08878
# ==========================================

# =============================================================================
#  CoCaBO Algorithms 
# =============================================================================
import sys
# sys.path.append('../bayesopt')
# sys.path.append('../ml_utils')
import argparse
import os
import testFunctions.syntheticFunctions
from methods.CoCaBO import CoCaBO
from methods.BatchCoCaBO import BatchCoCaBO


def _build_problem(obj_func):
    """Return ``(f, bounds, categories)`` for the named benchmark.

    ``f`` is the objective callable handed to the optimiser. ``bounds`` is a
    GPyOpt-style list of variable specs; in the hand-written spaces below the
    categorical variables come first, then the continuous ones. ``categories``
    lists the number of choices of each categorical variable.

    Raises:
        NotImplementedError: if ``obj_func`` is not a known benchmark name.
    """
    if obj_func == 'func2C':
        f = testFunctions.syntheticFunctions.func2C
        categories = [3, 5]
        bounds = [{'name': 'h1', 'type': 'categorical', 'domain': (0, 1, 2)},
                  {'name': 'h2', 'type': 'categorical', 'domain': (0, 1, 2, 3, 4)},
                  {'name': 'x1', 'type': 'continuous', 'domain': (-1, 1)},
                  {'name': 'x2', 'type': 'continuous', 'domain': (-1, 1)}]

    elif obj_func == 'func3C':
        f = testFunctions.syntheticFunctions.func3C
        # Three binary categorical variables plus two continuous ones.
        categories = [2, 2, 2]
        bounds = [{'name': 'h1', 'type': 'categorical', 'domain': (0, 1)},
                  {'name': 'h2', 'type': 'categorical', 'domain': (0, 1)},
                  {'name': 'h3', 'type': 'categorical', 'domain': (0, 1)},
                  {'name': 'x1', 'type': 'continuous', 'domain': (-1, 1)},
                  {'name': 'x2', 'type': 'continuous', 'domain': (-1, 1)}]

    elif obj_func == 'nasbench101':
        # The tabular benchmark data location is machine specific; override it
        # with the NASBENCH_DATA_DIR environment variable instead of editing code.
        data_dir = os.environ.get('NASBENCH_DATA_DIR',
                                  '/nfs/home/xingchenw/nas/data/')
        nasbench101 = testFunctions.NAS_Bench(data_dir, deterministic=False)
        bounds, categories = nasbench101.make_gpyopt_space()
        f = nasbench101.evaluate

    elif obj_func == 'ackley53':
        f = testFunctions.ackley53
        # 50 binary categorical variables followed by 3 continuous ones.
        categories = [2] * 50
        bounds_cat = [{'name': 'h' + str(i), 'type': 'categorical', 'domain': (0, 1)}
                      for i in range(50)]
        bounds_cont = [{'name': 'x' + str(i), 'type': 'continuous', 'domain': (-1, 1)}
                       for i in range(1, 4)]
        bounds = bounds_cat + bounds_cont

    else:
        raise NotImplementedError(f"Unknown objective function: {obj_func!r}")

    return f, bounds, categories


def CoCaBO_Exps(obj_func, budget, initN=8, trials=40, kernel_mix=0.5, batch=None):
    """Run CoCaBO (sequential or batch) on a benchmark and save the results.

    Results are written under ``data/syntheticFns/<obj_func>/``.

    Args:
        obj_func: benchmark name, one of 'func2C', 'func3C', 'nasbench101',
            'ackley53'.
        budget: number of optimisation iterations, passed to ``runTrials``.
        initN: initial design size passed to the optimiser (presumably the
            number of initial evaluations -- confirm against CoCaBO).
        trials: number of independent trials, passed to ``runTrials``.
        kernel_mix: mixture weight between the product and summation kernels.
        batch: batch size. ``None`` or 1 runs sequential CoCaBO; values > 1
            run BatchCoCaBO with that batch size.

    Raises:
        ValueError: if ``batch`` is given and smaller than 1.
        NotImplementedError: if ``obj_func`` is not a known benchmark name.
    """
    if batch is not None and batch < 1:
        raise ValueError(f"batch must be >= 1 (or None for sequential), got {batch!r}")

    # Resolve the problem before touching the filesystem so that an unknown
    # benchmark name does not leave an empty results directory behind.
    f, bounds, categories = _build_problem(obj_func)

    # define saving path for saving the results
    saving_path = f'data/syntheticFns/{obj_func}/'
    os.makedirs(saving_path, exist_ok=True)

    # Run CoCaBO Algorithm
    if batch is None or batch == 1:
        # sequential CoCaBO
        mabbo = CoCaBO(objfn=f, initN=initN, bounds=bounds,
                       acq_type='LCB', C=categories,
                       kernel_mix=kernel_mix)
    else:
        # batch CoCaBO
        mabbo = BatchCoCaBO(objfn=f, initN=initN, bounds=bounds,
                            acq_type='LCB', C=categories,
                            kernel_mix=kernel_mix,
                            batch_size=batch)
    mabbo.runTrials(trials, budget, saving_path)


def _build_arg_parser():
    """Return the command-line parser for this experiment script.

    Help texts use ``%(default)s`` so the documented defaults always match
    the values argparse actually applies.
    """
    parser = argparse.ArgumentParser(description="Run BayesOpt Experiments")
    parser.add_argument('-f', '--func', help='Objective function',
                        default='func2C', type=str)
    parser.add_argument('-mix', '--kernel_mix',
                        help='Mixture weight for product and summation kernel. Default = %(default)s',
                        default=0.5, type=float)
    parser.add_argument('-n', '--max_itr', help='Max Optimisation iterations. Default = %(default)s',
                        default=100, type=int)
    parser.add_argument('-tl', '--trials', help='Number of random trials. Default = %(default)s',
                        default=1, type=int)
    parser.add_argument('-b', '--batch',
                        help='Batch size (>1 for batch CoCaBO and =1 for sequential CoCaBO). Default = %(default)s',
                        default=2, type=int)
    return parser


if __name__ == '__main__':
    args = _build_arg_parser().parse_args()
    print(f"Got arguments: \n{args}")

    CoCaBO_Exps(obj_func=args.func, budget=args.max_itr,
                trials=args.trials, kernel_mix=args.kernel_mix, batch=args.batch)

