# MVRSM demo
# By Laurens Bliek, 16-03-2020
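# Supported functions: 'func2C', 'func3C', 'dim10Rosenbrock',
# 'linearmivabo', 'dim53Rosenbrock', 'dim53Ackley', 'dim238Rosenbrock',
# 'xgboost' and 'adv_attack' (the default).
# Example: python demo.py -f dim10Rosenbrock -n 10 -tl 4
# Here, -f is the function to be optimised, -n is the number of iterations, and -tl is the total number of runs
# (-tl is ignored for 'adv_attack', which always runs over a fixed grid of image offsets and target labels).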
# Afterward, use plot_result.py for visualisation.

import sys
# sys.path.append('../bayesopt')
# sys.path.append('../ml_utils')
import argparse
import os
import numpy as np
import pickle
import time
import testFunctions.syntheticFunctions
from methods.CoCaBO import CoCaBO
from methods.BatchCoCaBO import BatchCoCaBO
import MVRSM
from hyperopt import fmin, tpe, rand, hp, STATUS_OK, Trials
from functools import partial

from scipy.optimize import rosen
from linear_MIVABOfunction import Linear

from localglobal.mixed_test_func import *


if __name__ == '__main__':
    # Read arguments
    parser = argparse.ArgumentParser(description="Run BayesOpt Experiments")
    parser.add_argument('-f', '--func', help='Objective function',
                        default='adv_attack',
                        type=str)  # Supported functions: 'func2C', 'func3C', 'dim10Rosenbrock',
    # 'linearmivabo', 'dim53Rosenbrock', 'dim53Ackley', 'dim238Rosenbrock'
    parser.add_argument('-mix', '--kernel_mix',
                        help='Mixture weight for production and summation kernel. Default = 0.0', default=0.5,
                        type=float)
    parser.add_argument('-n', '--max_itr', help='Max Optimisation iterations. Default = 100',
                        default=500, type=int)
    parser.add_argument('-tl', '--trials', help='Number of random trials. Default = 20',
                        default=1, type=int)
    parser.add_argument('-b', '--batch',
                        help='Batch size (>1 for batch CoCaBO and =1 for sequential CoCaBO). Default = 1',
                        default=1, type=int)

    args = parser.parse_args()
    print(f"Got arguments: \n{args}")
    obj_func = args.func
    kernel_mix = args.kernel_mix
    n_itrs = args.max_itr
    n_trials = args.trials
    batch = args.batch

    folder = os.path.join(os.path.curdir, 'data', 'syntheticFns', obj_func)
    if not os.path.isdir(folder):
        os.mkdir(folder)

    if obj_func == 'dim10Rosenbrock':
        ff = testFunctions.syntheticFunctions.dim10Rosenbrock
        d = 10  # Total number of variables
        lb = -2 * np.ones(d).astype(int)  # Lower bound
        ub = 2 * np.ones(d).astype(int)  # Upper bound
        num_int = 3  # number of integer variables
        lb[0:num_int] = 0
        ub[0:num_int] = num_int + 1
    elif obj_func == 'func3C':
        ff = testFunctions.syntheticFunctions.func3C
        d = 5  # Total number of variables
        lb = -1 * np.ones(d).astype(int)  # Lower bound for continuous variables
        ub = 1 * np.ones(d).astype(int)  # Upper bound for continuous variables
        num_int = 3  # number of integer variables
        lb[0:num_int] = 0
        ub[0] = 2
        ub[1] = 4
        ub[2] = 3
    elif obj_func == 'func2C':
        ff = testFunctions.syntheticFunctions.func2C
        d = 4  # Total number of variables
        lb = -1 * np.ones(d).astype(int)  # Lower bound for continuous variables
        ub = 1 * np.ones(d).astype(int)  # Upper bound for continuous variables
        num_int = 2  # number of integer variables
        lb[0:num_int] = 0
        ub[0] = 2
        ub[1] = 4
    elif obj_func == 'linearmivabo':
        LM = Linear(laplace=False)
        ff = LM.objective_function
        d = 16  # Total number of variables
        lb = 0 * np.ones(d).astype(int)  # Lower bound for continuous variables
        ub = 3 * np.ones(d).astype(int)  # Upper bound for continuous variables
        num_int = 8  # number of integer variables
        lb[0:num_int] = 0
        ub[0:num_int] = 3
    elif obj_func == 'dim53Rosenbrock':
        ff = testFunctions.syntheticFunctions.dim53Rosenbrock
        d = 53  # Total number of variables
        lb = -2 * np.ones(d).astype(int)  # Lower bound
        ub = 2 * np.ones(d).astype(int)  # Upper bound
        num_int = 50  # number of integer variables
        lb[0:num_int] = 0
        ub[0:num_int] = 1
    elif obj_func == 'dim53Ackley':
        ff = testFunctions.syntheticFunctions.dim53Ackley
        d = 53  # Total number of variables
        lb = -1 * np.ones(d).astype(float)  # Lower bound
        ub = 1 * np.ones(d).astype(float)  # Upper bound
        num_int = 50  # number of integer variables
        lb[0:num_int] = 0
        ub[0:num_int] = 1
    elif obj_func == 'dim238Rosenbrock':
        ff = testFunctions.syntheticFunctions.dim238Rosenbrock
        d = 238  # Total number of variables
        lb = -2 * np.ones(d).astype(int)  # Lower bound
        ub = 2 * np.ones(d).astype(int)  # Upper bound
        num_int = 119  # number of integer variables
        lb[0:num_int] = 0
        ub[0:num_int] = 4
    elif obj_func == 'xgboost':
        ff = testFunctions.syntheticFunctions.xgboost
        d = 8  # Total number of variables
        lb = -1 * np.ones(d).astype(int)  # Lower bound
        ub = 1 * np.ones(d).astype(int)  # Upper bound
        num_int = 3
        lb[0:num_int] = 0
        ub[0:num_int] = 1
    elif obj_func == 'adv_attack':
        # Define ff later for this problem, we need to iterate through images etc.
        # ff = testFunctions.syntheticFunctions.adv_attack
        d = 85  # Total number of variables
        lb = -1 * np.ones(d).astype(int)  # Lower bound
        ub = 1 * np.ones(d).astype(int)  # Upper bound
        num_int = 43
        lb[0:num_int] = 0
        ub[0:num_int] = 13
        ub[num_int-1] = 2
    else:
        raise NotImplementedError

    x0 = np.zeros(d)  # Initial guess
    x0[0:num_int] = np.round(
        np.random.rand(num_int) * (ub[0:num_int] - lb[0:num_int]) + lb[0:num_int])  # Random initial guess (integer)
    x0[num_int:d] = np.random.rand(d - num_int) * (ub[num_int:d] - lb[num_int:d]) + lb[
                                                                                    num_int:d]  # Random initial guess (continuous)

    rand_evals = 0  # Number of random iterations, same as initN above (24)
    max_evals = n_itrs + rand_evals  # Maximum number of MVRSM iterations, the first <rand_evals> are random

    runs = []
    if obj_func != 'adv_attack':
        for i in range(n_trials):
            def obj_MVRSM(x):
                # print(x[0:num_int])
                h = np.copy(x[0:num_int]).astype(int)
                if obj_func == 'func3C' or obj_func == 'func2C':
                    result = ff(h, x[num_int:])[0][0]
                elif obj_func == 'linearmivabo':
                    result = ff(x)
                else:
                    result = ff(h, x[num_int:])
                return result
            def run_MVRSM():
                solX, solY, model, logfile, run_data = MVRSM.MVRSM_minimize(obj_MVRSM, x0, lb, ub, num_int, max_evals,
                                                                            rand_evals)
                os.rename(logfile, os.path.join(folder, logfile))
                print("Solution found: ")
                print(f"X = {solX}")
                print(f"Y = {solY}")
                return run_data
            print("Start MVRSM trials")
            run_data = run_MVRSM()
            runs.append(run_data)
    else:
        for t in range(50):
            for i in range(9):
                f = AdversarialAttack(f'./tf_models/',
                                      save_dir='output/',
                                      target_label=i,
                                      img_offset=t,
                                      )
                ff = lambda hs, xs: testFunctions.syntheticFunctions.adv_attack(hs, xs, f)


                def obj_MVRSM(x):
                    # print(x[0:num_int])
                    h = np.copy(x[0:num_int]).astype(int)
                    if obj_func == 'func3C' or obj_func == 'func2C':
                        result = ff(h, x[num_int:])[0][0]
                    elif obj_func == 'linearmivabo':
                        result = ff(x)
                    else:
                        result = ff(h, x[num_int:])
                    return result


                def run_MVRSM():
                    solX, solY, model, logfile, run_data = MVRSM.MVRSM_minimize(obj_MVRSM, x0, lb, ub, num_int,
                                                                                max_evals,
                                                                                rand_evals)
                    os.rename(logfile, os.path.join(folder, logfile))
                    print("Solution found: ")
                    print(f"X = {solX}")
                    print(f"Y = {solY}")
                    return run_data


                print(f"Start MVRSM trials t={t}, i={i}")
                run_data = run_MVRSM()
                runs.append(run_data)

    runs = np.array(runs)
    print(runs)

    mvrsm_file = open(os.path.join(folder, obj_func + '_baseline_result_mvrsm_fullruns.pkl'), 'wb')
    pickle.dump({'runs': runs}, mvrsm_file)
    mvrsm_file.close()
