# -*- coding: utf-8 -*-
import os
import sys
import re
import pickle
import concurrent.futures


os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"

import numpy as np
import numpy.matlib
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib import ticker
import GPy
from scipy.stats import sem

import script.test_functions as test_functions


# Global figure style shared by every plot produced by this script.
# Type-42 (TrueType) fonts keep text editable/searchable in the exported PDF/PS.
plt.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'sans-serif',
    # Text sizes.
    "font.size": 20,
    'xtick.labelsize': 20,
    'ytick.labelsize': 20,
    'legend.fontsize': 18,
    'figure.figsize': (6.5, 5.5),
    # Error bars and lines.
    'errorbar.capsize': 4.0,
    'lines.linewidth': 2.5,
    'lines.markeredgewidth': 1.5,
    'lines.markersize': 10.,
    # Compact legend layout.
    'legend.borderaxespad': 0.1,
    'legend.labelspacing': 0.2,
    'legend.borderpad': 0.1,
    'legend.columnspacing': 0.25,
    "legend.handletextpad": 0.25,
    'legend.handlelength': 2.5,
    'legend.handleheight': 0.5,
})
# plt.rcParams['figure.constrained_layout.use'] = True

def main(params):
    """Plot the mean simple regret (with standard-error bars) of every BO method.

    Per-seed regret curves are unpickled from
    ``<save_name>_results/<method>/seed=<seed>/`` and the averaged curves are
    saved as a PDF under ``figures/``. A method is skipped entirely if any of
    its seeds has a missing or empty pickle file.

    Args:
        params: tuple ``(func_name, kernel_name, ell, noise_std)``. ``ell`` only
            appears in the file name / title when ``func_name == "GP"``.
    """
    func_name, kernel_name, ell, noise_std = params

    # plt.style.use('tableau-colorblind10')
    seeds_num = 16
    seeds = np.arange(seeds_num)

    BO_methods = ["EIMS", "EI", "EI_wang", "EEEI", "IRGPUCB", "GPUCB", "PIMS", "TS", "MES", "JES"]
    # One dash pattern per method (indexed in parallel with BO_methods).
    linestyles = ["solid", (0, (6, 1)), (0, (10, 3)), "dashed", (0, (3, 1, 1, 1)), (0, (3, 1, 1, 1, 1, 1)), (0, (3, 1, 1, 1, 1, 1, 1, 1)), (0, (3, 1, 1, 1, 1, 1, 1, 1, 1, 1)), (0, (3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)), (0, (1, 1))]
    if func_name == "GP":
        save_name = 'GP_{}-lengthscale={}-noise_std={}'.format(kernel_name, ell, noise_std)
    else:
        save_name = '{}_{}-noise_std={}'.format(func_name, kernel_name, noise_std)
    result_path = '{}_results/'.format(save_name)

    # Noiseless benchmark functions are evaluated with the best-observed regret;
    # every other setting uses the inference regret. Loop-invariant, so pick once.
    if func_name != "GP" and noise_std == 0:
        regret_file = 'BestRegret.pickle'
    else:
        regret_file = 'InferenceRegret.pickle'

    max_iteration = 0
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    print(save_name + '--------------------------------------')

    for j, method in enumerate(BO_methods):
        # Offset the error-bar positions per method so bars of different
        # methods do not sit on top of each other.
        if noise_std < 1.:
            errorevery = (5*j, 40)
        else:
            errorevery = (10*j, 80)

        plot = True
        InfReg_all = list()
        for seed in seeds:
            temp_file = result_path + method + '/seed=' + str(seed) + '/' + regret_file
            if os.path.exists(temp_file) and os.path.getsize(temp_file) > 0:
                with open(temp_file, 'rb') as f:
                    InfReg_all.append(pickle.load(f))
            else:
                # One missing seed disqualifies the whole method.
                plot = False
                break

        if plot:
            # Truncate all seeds to the shortest run so they can be stacked.
            min_len = np.min([np.size(reg) for reg in InfReg_all])
            # x_max = np.min([min_len, 50*input_dim])
            x_max = int(min_len)

            InfReg_all = [reg[:x_max] for reg in InfReg_all]
            InfReg_all = np.vstack(InfReg_all)
            InfReg_ave = np.mean(InfReg_all, axis=0)
            InfReg_se = sem(InfReg_all, axis=0, ddof=1)

            max_iteration = np.max([np.size(InfReg_ave), max_iteration])

            linestyle = linestyles[j]
            marker = None
            color = None
            label = method.replace("EI_wang", r"EI-$\mu^{\rm max}$").replace("EEEI", r"E$^3$I")
            if func_name == "GP" and method in ["EIMS", "EI_wang", "IRGPUCB", "GPUCB", "TS", "PIMS"]:
                label = label + r"$^\dagger$"

            if func_name != "GP" and method in ["EIMS", "TS", "PIMS"]:
                label = label + r"$^\dagger$"

            ax.errorbar(np.arange(np.size(InfReg_ave)), InfReg_ave, yerr=InfReg_se, errorevery=errorevery, capsize=4, elinewidth=2, label=label, marker=marker, markevery=5, linestyle=linestyle, color=color, markerfacecolor="None")

    if func_name == "GP":
        ax.set_title(r'GP ({}, $\ell$={}, $\sigma$={}) '.format(kernel_name, ell, noise_std), fontsize=20, loc="right")
    else:
        # getattr instead of eval: same lookup, no arbitrary code execution.
        test_func = getattr(test_functions, func_name)(rng=np.random.default_rng(1), noise_std=noise_std)
        ax.set_title(r'{} ({}, $d$={}) '.format(func_name, kernel_name, test_func.d), fontsize=20, loc="right")

    ax.set_xlim(0, max_iteration)
    ax.set_xlabel('Iteration')
    # ax.set_ylabel('Simple regret')
    ax.grid(which='major')
    ax.grid(which='minor')

    if func_name == "Ackley":
        ax.legend(ncol=2, loc="upper right")

    # Many worker processes run this concurrently: an exists()/makedirs() pair
    # can race and raise FileExistsError, so let makedirs tolerate it.
    os.makedirs("figures/", exist_ok=True)

    if "GP" not in func_name:
        ax.set_yscale('log')
        fig.tight_layout()
        fig.savefig('figures/log_SimpleRegret_{}.pdf'.format(save_name))
    else:
        ax.set_ylim(0, 1.5)
        if ell == 0.1 and noise_std == 1.0:
            ax.set_ylim(0, 2.2)
        formatter = ticker.ScalarFormatter(useMathText=True)
        formatter.set_scientific(True)
        formatter.set_powerlimits((-3, 1))
        ax.yaxis.set_major_formatter(formatter)
        fig.tight_layout()
        fig.savefig('figures/SimpleRegret_{}.pdf'.format(save_name))

    # Stand-alone legend figure, produced once from a representative setting.
    # NOTE(review): both the SE and Matern52 GP tasks match this condition and
    # write the same file from different processes; the last writer wins.
    if ell == 0.1 and noise_std == 0.01:
        fig_legend = plt.figure(figsize=(10, 2))
        fig_legend.legend(*ax.get_legend_handles_labels(), loc="center", ncol=5)
        fig_legend.gca().axis("off")
        fig_legend.tight_layout()
        fig_legend.savefig('figures/legend.pdf')
        plt.close(fig_legend)
    # Close the main figure explicitly: plt.close() alone would only close the
    # *current* figure (the legend one, when it exists) and leak this one.
    plt.close(fig)


def plot_cumulative_regret(params):
    """Plot the mean cumulative regret (with standard-error bars) of every BO method.

    Per-seed curves are unpickled from
    ``<save_name>_results/<method>/seed=<seed>/CumulativeRegret.pickle`` and the
    figure is saved as ``figures/CumulativeRegret_<save_name>.pdf``. A method is
    skipped entirely if any of its seeds has a missing or empty pickle file.

    Args:
        params: tuple ``(func_name, kernel_name, ell, noise_std)``. ``ell`` only
            appears in the file name / title when ``func_name == "GP"``.
    """
    func_name, kernel_name, ell, noise_std = params

    # plt.style.use('tableau-colorblind10')
    seeds_num = 16
    seeds = np.arange(seeds_num)

    BO_methods = ["EIMS", "EI", "EI_wang", "EEEI", "IRGPUCB", "GPUCB", "PIMS", "TS", "MES", "JES"]
    # One dash pattern per method (indexed in parallel with BO_methods).
    linestyles = ["solid", (0, (6, 1)), (0, (10, 3)), "dashed", (0, (3, 1, 1, 1)), (0, (3, 1, 1, 1, 1, 1)), (0, (3, 1, 1, 1, 1, 1, 1, 1)), (0, (3, 1, 1, 1, 1, 1, 1, 1, 1, 1)), (0, (3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1)), (0, (1, 1))]

    if func_name == "GP":
        save_name = 'GP_{}-lengthscale={}-noise_std={}'.format(kernel_name, ell, noise_std)
    else:
        save_name = '{}_{}-noise_std={}'.format(func_name, kernel_name, noise_std)
    result_path = '{}_results/'.format(save_name)

    max_iteration = 0
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    print(save_name + '--------------------------------------')

    for j, method in enumerate(BO_methods):
        # Offset the error-bar positions per method so bars do not overlap.
        if noise_std < 1.:
            errorevery = (5*j, 40)
        else:
            errorevery = (10*j, 80)

        plot = True
        InfReg_all = list()
        for seed in seeds:
            temp_file = result_path + method + '/seed=' + str(seed) + '/CumulativeRegret.pickle'
            if os.path.exists(temp_file) and os.path.getsize(temp_file) > 0:
                with open(temp_file, 'rb') as f:
                    InfReg_all.append(pickle.load(f))
            else:
                # One missing seed disqualifies the whole method.
                plot = False
                break

        if plot:
            # Truncate all seeds to the shortest run so they can be stacked.
            min_len = np.min([np.size(reg) for reg in InfReg_all])
            # x_max = np.min([min_len, 50*input_dim])
            x_max = int(min_len)

            InfReg_all = [reg[:x_max] for reg in InfReg_all]
            InfReg_all = np.vstack(InfReg_all)
            InfReg_ave = np.mean(InfReg_all, axis=0)
            InfReg_se = sem(InfReg_all, axis=0, ddof=1)

            max_iteration = np.max([np.size(InfReg_ave), max_iteration])

            linestyle = linestyles[j]
            marker = None
            color = None
            label = method.replace("EI_wang", r"EI-$\mu^{\rm max}$").replace("EEEI", r"E$^3$I")
            if func_name == "GP" and method in ["EIMS", "EI_wang", "IRGPUCB", "GPUCB", "TS", "PIMS"]:
                label = label + r"$^\dagger$"

            if func_name != "GP" and method in ["EIMS", "TS", "PIMS"]:
                label = label + r"$^\dagger$"

            ax.errorbar(np.arange(np.size(InfReg_ave)), InfReg_ave, yerr=InfReg_se, errorevery=errorevery, capsize=4, elinewidth=2, label=label, marker=marker, markevery=5, linestyle=linestyle, color=color, markerfacecolor="None")

    if func_name == "GP":
        ax.set_title(r'GP ({}, $\ell$={}, $\sigma$={}) '.format(kernel_name, ell, noise_std), fontsize=20, loc="right")
    else:
        # getattr instead of eval: same lookup, no arbitrary code execution.
        test_func = getattr(test_functions, func_name)(rng=np.random.default_rng(1), noise_std=noise_std)
        ax.set_title(r'{} ({}, $d$={}) '.format(func_name, kernel_name, test_func.d), fontsize=20, loc="right")

    ax.set_xlim(0, max_iteration)
    ax.set_xlabel('Iteration')
    # ax.set_ylabel('Simple regret')
    ax.grid(which='major')
    ax.grid(which='minor')
    if func_name == "Ackley":
        ax.legend(ncol=2)

    # Many worker processes run this concurrently: an exists()/makedirs() pair
    # can race and raise FileExistsError, so let makedirs tolerate it.
    os.makedirs("figures/", exist_ok=True)

    class CustomScalarFormatter(ticker.ScalarFormatter):
        """ScalarFormatter whose tick labels use a fixed number of decimals."""

        def __init__(self, decimals=1, useMathText=True, *args, **kwargs):
            super().__init__(useMathText=useMathText, *args, **kwargs)
            self.decimals = decimals

        def _set_format(self):
            # Overrides matplotlib's private hook that picks the tick format.
            self.format = f"%.{self.decimals}f"

    formatter = CustomScalarFormatter(decimals=1, useMathText=True)
    formatter.set_scientific(True)
    formatter.set_powerlimits((-3, 1))
    ax.yaxis.set_major_formatter(formatter)

    fig.tight_layout()
    fig.savefig('figures/CumulativeRegret_{}.pdf'.format(save_name))
    plt.close(fig)

if __name__ == '__main__':
    NUM_WORKER = 32

    kernel_names = ["SE", "Matern52"]
    lengthscales = [0.1, 0.2]
    noise_stds = [0.01, 0.1, 1.0]

    # Synthetic GP-sample experiments: every kernel x lengthscale x noise level.
    params = [
        ("GP", kernel_name, lengthscale, noise_std)
        for kernel_name in kernel_names
        for lengthscale in lengthscales
        for noise_std in noise_stds
    ]

    # Full benchmark list, kept for reference:
    # ["Branin", "Shekel", "Styblinski_tang", "Hartmann3", "Hartmann4", "Hartmann6", "Ackley", "Bukin", "Cross_in_tray", "Eggholder", "Holder_table", "Langerman", "Levy", "Levy13", "Rastrigin", "Shubert", "Schwefel", "Rosenbrock", "Goldstein"]
    func_names = ["Branin", "Shekel", "Styblinski_tang", "Hartmann3", "Hartmann4", "Hartmann6", "Ackley", "Rosenbrock"]

    # Benchmark functions are noiseless; 0.01 is a placeholder lengthscale
    # (it is not used in the output name for non-GP functions).
    for func_name in func_names:
        params.append((func_name, "SE", 0.01, 0.0))

    # Executor.map re-raises worker exceptions only when its result iterator is
    # consumed; the results used to be discarded, silently hiding every failure.
    # Submit explicitly, report each failure, and keep going with the rest.
    failures = 0
    for plot_func in (main, plot_cumulative_regret):
        with concurrent.futures.ProcessPoolExecutor(max_workers=NUM_WORKER) as executor:
            futures = {executor.submit(plot_func, p): p for p in params}
            for future in concurrent.futures.as_completed(futures):
                try:
                    future.result()
                except Exception as err:
                    failures += 1
                    print('{} failed for {}: {!r}'.format(plot_func.__name__, futures[future], err), file=sys.stderr)

    if failures:
        sys.exit(1)