import numpy as np
import matplotlib.pyplot as plt
import scipy.io

from utils.preprocessing_utils import get_environment_path
from utils.postprocessing_utils import set_size

# Global matplotlib configuration: render all text through LaTeX with a serif font.
plt.rcParams.update({
    "font.family": "serif",  # use serif/main font for text elements
    "text.usetex": True,  # typeset labels and tick text with LaTeX
    "pgf.rcfonts": False,  # don't setup fonts from rc parameters
    # The preamble must be a single string: list values were deprecated in
    # matplotlib 3.3 and are rejected by later releases.
    "text.latex.preamble": "\n".join([
        r"\usepackage{amsmath}",
        r"\usepackage{amssymb}",
    ]),
})

# Named colour palette: each entry is an [R, G, B] triple with components in [0, 1].
c = dict(
    blau100=[0, 0.32941176470588235, 0.6235294117647059],
    schwarz100=[0, 0, 0],
    schwarz50=[0.611764705882353, 0.6196078431372549, 0.6235294117647059],
    gruen100=[0.3411764705882353, 0.6705882352941176, 0.15294117647058825],
    orange100=[0.9647058823529412, 0.6588235294117647, 0],
    rot100=[0.8, 0.027450980392156862, 0.11764705882352941],
)

# One row per plotted method: (legend label, key in the .mat file, palette key).
# The three parallel lists below are derived from this table so that they
# always stay aligned index-by-index.
_METHODS = (
    ('Random', 'random_regrets_track', 'schwarz50'),
    ('IGP-UCB', 'gp_regrets_track', 'schwarz100'),
    ('BR-GP-UCB', 'r_gp_regrets_track', 'blau100'),
    ('SW-GP-UCB', 'sw_gp_regrets_track', 'orange100'),
    ('W-GP-UCB', 'w_gp_regrets_track', 'gruen100'),
    ('ET-GP-UCB (ours)', 'et_gp_regrets_track', 'rot100'),
)

LABELS = [label for label, _, _ in _METHODS]
REGRET_NAMES = [mat_key for _, mat_key, _ in _METHODS]
COLORS = [c[palette_key] for _, _, palette_key in _METHODS]

# Stride applied to the time axis and the regret curves (1 keeps every sample).
SUBSAMPLE = 1

# Shared errorbar-style keyword arguments.
linestyle = dict(
    linewidth=1,
    markeredgewidth=0.5,
    elinewidth=0.5,
    capsize=3,
    barsabove=False,
)

# Prefix that is prepended to every result-file path below.
env_path = get_environment_path()

# Figure dimensions come from the project helper; the height is stretched by
# a factor of 1.6 before the 1x3 grid of panels is created.
fig_width, fig_height = set_size(490, subplots=(1, 3), fraction=1.)
fig, axes = plt.subplots(
    nrows=1,
    ncols=3,
    figsize=(fig_width, fig_height * 1.6),
)

# FIGURE A: regret curves for the abruptly changing example (left panel).
file_path1 = env_path + 'matlab_examples_from_Deng_et_al/Fig_a_abruptly_changing/regret_abruptly_changing_reasonable_noise_2.mat'
results = scipy.io.loadmat(file_path1)
# loadmat returns MATLAB scalars as 2-D arrays. Extract a plain Python int so
# np.arange does not depend on the deprecated implicit conversion of a size-1
# array to a scalar, and so `T + 1` is computed in Python integer arithmetic.
T = int(np.squeeze(results['T']))
x = np.arange(1, T + 1)[::SUBSAMPLE]  # 1-based time axis

ax = axes[0]
for i, name in enumerate(REGRET_NAMES):
    regret = results[name]
    # NOTE(review): assumes rows are time steps and columns are independent
    # runs, so the statistics are taken across runs -- confirm against the data.
    mean_regret = np.mean(regret, axis=1)[::SUBSAMPLE]
    stdv_regret = np.std(regret, axis=1)[::SUBSAMPLE]
    # Shaded band of one standard deviation around the mean curve.
    ax.fill_between(x, mean_regret - stdv_regret, mean_regret + stdv_regret,
                    color=COLORS[i], alpha=0.1, lw=0.3)
    ax.plot(x, mean_regret, label=LABELS[i], color=COLORS[i], linewidth=1)
    print(f'{LABELS[i]}: Mean is {mean_regret[-1]}')

ax.set_ylim([0, 400])
ax.set_xlim([x[0], x[-1]])
ax.set_xticks([0, 100, 200, 300, 400, 500])
ax.set_ylabel('Regret $R_t$')
ax.set_xlabel('Time $t$\n(a) Abruptly-change, SE kernel', labelpad=2, linespacing=2)

# FIGURE B: regret curves for the slowly changing example (middle panel).
file_path1 = env_path + 'matlab_examples_from_Deng_et_al/Fig_b_slowly_changing/regret_slowly_changing_reasonable_noise_2.mat'
results = scipy.io.loadmat(file_path1)
# loadmat returns MATLAB scalars as 2-D arrays. Extract a plain Python int so
# np.arange does not depend on the deprecated implicit conversion of a size-1
# array to a scalar, and so `T + 1` is computed in Python integer arithmetic.
T = int(np.squeeze(results['T']))
x = np.arange(1, T + 1)[::SUBSAMPLE]  # 1-based time axis

ax = axes[1]
for i, name in enumerate(REGRET_NAMES):
    regret = results[name]
    # NOTE(review): assumes rows are time steps and columns are independent
    # runs, so the statistics are taken across runs -- confirm against the data.
    mean_regret = np.mean(regret, axis=1)[::SUBSAMPLE]
    stdv_regret = np.std(regret, axis=1)[::SUBSAMPLE]
    # Shaded band of one standard deviation around the mean curve.
    ax.fill_between(x, mean_regret - stdv_regret, mean_regret + stdv_regret,
                    color=COLORS[i], alpha=0.1, lw=0.3)
    ax.plot(x, mean_regret, label=LABELS[i], color=COLORS[i], linewidth=1)
    print(f'{LABELS[i]}: Mean is {mean_regret[-1]}')

ax.set_ylim([0, 250])
ax.set_xlim([x[0], x[-1]])
ax.set_xticks([0, 100, 200, 300, 400, 500])
# No y-label is set on this panel.
ax.set_xlabel('Time $t$\n(b) Slowly-change, SE kernel', labelpad=2, linespacing=2)

# FIGURE C: regret curves for the stock market example (right panel).
file_path1 = env_path + 'matlab_examples_from_Deng_et_al/Fig_c_stock_market_refactored/regret_refactored.mat'
results = scipy.io.loadmat(file_path1)
# This file stores the random baseline under a differently spelled key; alias
# it so the shared REGRET_NAMES lookup below works unchanged.
results['random_regrets_track'] = results['random_regret_track']
# loadmat returns MATLAB scalars as 2-D arrays. Extract a plain Python int so
# np.arange does not depend on the deprecated implicit conversion of a size-1
# array to a scalar, and so `T + 1` is computed in Python integer arithmetic.
T = int(np.squeeze(results['T']))
x = np.arange(1, T + 1)[::SUBSAMPLE]  # 1-based time axis

ax = axes[2]
for i, name in enumerate(REGRET_NAMES):
    regret = results[name]
    # NOTE(review): assumes rows are time steps and columns are independent
    # runs, so the statistics are taken across runs -- confirm against the data.
    mean_regret = np.mean(regret, axis=1)[::SUBSAMPLE]
    stdv_regret = np.std(regret, axis=1)[::SUBSAMPLE]
    # Shaded band of one standard deviation around the mean curve.
    ax.fill_between(x, mean_regret - stdv_regret, mean_regret + stdv_regret,
                    color=COLORS[i], alpha=0.1, lw=0.3)
    ax.plot(x, mean_regret, label=LABELS[i], color=COLORS[i], linewidth=1)
    print(f'{LABELS[i]}: Mean is {mean_regret[-1]}')

ax.set_ylim(bottom=0, top=0.6 * 10 ** 5)
# Force scientific notation with a fixed exponent of 10^4 on the y axis.
ax.ticklabel_format(axis='y', scilimits=[4, 4])
ax.set_xlim([x[0], x[-1]])
ax.set_xticks([0, 200, 400, 600, 800])
# No y-label is set on this panel.
ax.set_xlabel('Time $t$\n(c) Stock market data', labelpad=2, linespacing=2)

# Single legend for the whole figure, attached to the right-most panel. The
# negative x anchor moves its lower-left corner far to the left of that
# panel, and one column per label lays the entries out in a single row.
legend_options = dict(
    bbox_to_anchor=(-2.5, 1.13),
    loc="lower left",
    handletextpad=0.2,
    borderpad=0.2,
    borderaxespad=0.1,
    ncol=len(LABELS),
    columnspacing=0.7,
    prop={'size': 9},
    handlelength=1.5,
)
axes[2].legend(**legend_options)

# Set the figure margins and inter-panel spacing explicitly.
fig.subplots_adjust(
    bottom=0.26,
    top=0.82,
    left=0.09,
    right=0.98,
    wspace=0.2,
    hspace=0.45,
)

plt.show()
plt.close()
print('Done!')