from localglobal.test_funcs import *
from localglobal.mixed_test_func import *
import numpy as np
import random
from localglobal.bo.optimizer import TurboOptimizer, TurboMOptimizer
from localglobal.bo.optimizer_mixed import MixedTurboOptimizer
import logging
import argparse
import os
import pickle
import pandas as pd
import time, datetime
from localglobal.test_funcs.random_seed_config import *
import matplotlib.pyplot as plt
from copy import deepcopy

# Command-line interface for the Branin ordinal-vs-categorical kernel comparison.
parser = argparse.ArgumentParser('Comp of ordinal vs cat kernels')
parser.add_argument('-k', '--kernel', type=str, default='ordinal')
parser.add_argument('--max_iters', type=int, default=100, help='Maximum number of BO iterations.')
parser.add_argument('--n_trials', type=int, default=20)
parser.add_argument('--save_path', type=str, default='output/')
parser.add_argument('-a', '--acq', type=str, default='ei', help='choice of the acquisition function.')
parser.add_argument('--random_seed_objective', type=int, default=20, help='The default value of 20 is provided also in COMBO')
# NOTE: adjacent string literals are concatenated verbatim, so the trailing space
# before the line break is required (previously rendered as "willbe generated").
parser.add_argument('-d', '--debug', action='store_true', help='Whether to turn on debugging mode (a lot of output will '
                                                               'be generated).')
parser.add_argument('--no_save', action='store_true', help='If activated, do not save the current run into a log folder.')
parser.add_argument('-p', '--plot', action='store_true', help='whether to activate plotting for the posterior and '
                                                              'acquisition function.')
parser.add_argument('--plot_interval', type=int, default=1, help='iteration interval between plotting/saving')

args = parser.parse_args()
# Plain-dict view of the parsed arguments; printed here and written to command.txt when saving.
options = vars(args)
print(options)

# Timestamp of this run; used as the name of the run's output directory.
time_string = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

# Create the output folder and save the arguments so the experiment can be reproduced.
if not args.no_save:
    save_path = os.path.join(args.save_path, 'branin', time_string)
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(save_path, exist_ok=True)
    # The context manager guarantees the file is closed even if the write fails.
    with open(os.path.join(save_path, 'command.txt'), 'w') as option_file:
        option_file.write(str(options))
else:
    # None signals downstream code to show plots and skip pickling results.
    save_path = None

if args.plot:
    # Render the full Branin landscape once, by evaluating every grid vertex
    # on the unnormalised objective, so later GP plots can be compared against it.
    f = Branin()
    plt.figure(figsize=[5, 5])
    res = np.zeros(f.config)
    for row in range(f.config[0]):
        for col in range(f.config[1]):
            query = np.array([[row, col]])
            res[row, col] = f.compute(query, normalize=False)
    plt.title('Ground Truth')
    plt.imshow(res)
    # Without an output directory the figure is shown interactively instead of saved.
    if save_path is None:
        plt.show()
    else:
        plt.savefig(os.path.join(save_path, 'func_landscape.pdf'), dpi=200, bbox_inches='tight')
    plt.close()

# torch is used below to build GP query points; import it explicitly rather than
# relying on it leaking in through the star imports at the top of the file.
import torch

for t in range(args.n_trials):
    print('----- Starting trial %d / %d -----' % ((t + 1), args.n_trials))
    # One row per BO iteration; rows that are never filled stay NaN.
    res = pd.DataFrame(np.nan, index=np.arange(int(args.max_iters)),
                       columns=['Index', 'LastValue', 'BestValue', 'Time'])

    f = Branin()

    turbo = TurboOptimizer(f.config, n_init=10, use_ard=False,
                           acq=args.acq,
                           global_bo=True,  # No need for local in 2D
                           kernel_type=args.kernel, )

    for i in range(args.max_iters):
        # Time one full suggest -> evaluate -> observe cycle.
        start = time.time()
        x_next = turbo.suggest(1)
        y_next = f.compute(x_next, )
        turbo.observe(x_next, y_next)
        end = time.time()
        # All objective values recorded by the optimizer so far, flattened so that
        # indexing yields scalars.
        y = np.array(turbo.turbo.fX).ravel()
        if y.size:
            last_value = float(y[-1])
            # The incumbent must include the point just observed; the former
            # y[:i] slice dropped it (and skipped iteration 0 entirely).
            best_value = float(np.min(y))
            res.iloc[i, :] = [i, last_value, best_value, end - start]

            print('Iter %d, Last X %s; fX:  %.4f. fX_best: %.4f'
                  % (i, x_next.flatten(),
                     last_value,
                     best_value))
        if save_path is not None:
            # Re-pickle the whole table every iteration so partial runs are recoverable;
            # the context manager closes the handle that was previously leaked.
            with open(os.path.join(save_path, 'trial-%d.pickle' % t), 'wb') as result_file:
                pickle.dump(res, result_file)
        if args.plot and (i + 1) % args.plot_interval == 0 and turbo.turbo.gp is not None:
            # Plot the GP posterior variance over the whole grid, with all evaluated
            # points marked (the most recent one in red).
            plt.figure(figsize=[5, 5])
            # Query a copy so plotting cannot alter the optimizer's own model state.
            gp = deepcopy(turbo.turbo.gp)
            plotter = np.zeros(f.config)
            with torch.no_grad():
                for ii in range(f.config[0]):
                    for jj in range(f.config[1]):
                        loc = torch.tensor([[ii, jj]], dtype=torch.float32)
                        # .item() converts the single-point variance to a scalar explicitly
                        # (assigning a 1-element array into a scalar slot is deprecated in NumPy).
                        plotter[ii, jj] = gp(
                            loc
                        ).variance.detach().numpy().item()
            del gp
            # Columns of X are swapped so the scatter aligns with imshow's (row, col) layout.
            plt.scatter(turbo.turbo.X[:-1, 1], turbo.turbo.X[:-1, 0], marker='x', color='black')
            plt.scatter(turbo.turbo.X[-1, 1], turbo.turbo.X[-1, 0], marker='x', color='r')
            plt.imshow(plotter)
            plt.title('Iter %d: GP Posterior Variance' % i)
            if save_path is not None:
                plt.savefig(os.path.join(save_path, 'gp_posterior_trial%d_iter%d.pdf' % (t, i)),
                            dpi=200, bbox_inches='tight')
            else:
                plt.show()
            plt.close()


