# this file plots the tabular portion of the synthetic data results (top row of figure 1)
from os import makedirs
from os.path import exists

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# load results
# per-K result files are read from results/syn_tab_<K>.npz; per-algorithm mean and
# standard deviation across trials are accumulated into the arrays below
n, L, trials = 100000, 100, 5  # horizon, number of items, and number of trials (L and trials are not used below)
all_tests = np.arange(4, 18, 2)  # all values of K tested
num_tests = all_tests.shape[0]  # number of tests
# first axis indexes the three algorithms (order matches the legend labels of the third plot)
regret_avg = np.zeros((3, n, num_tests))  # average of the regret across trials
regret_std = np.zeros((3, n, num_tests))  # standard deviation of the regret across trials
runtime_avg = np.zeros((3, num_tests))  # average of the runtime across trials
runtime_std = np.zeros((3, num_tests))  # standard deviation of the runtime across trials
for test_idx, K in enumerate(all_tests):
    # try to load results for current test, printing an error message and exiting
    # with a nonzero status if it doesn't exist
    result_file = 'results/syn_tab_' + str(K) + '.npz'
    if not exists(result_file):
        print('error -- file ' + result_file + ' does not exist -- see README')
        # exit() is an interactive-shell helper from the site module and reports
        # success (status 0); signal the failure to the caller instead
        raise SystemExit(1)
    # the NpzFile keeps the archive open until closed, and every key lookup re-reads
    # the member, so read each array exactly once inside a with-block
    with np.load(result_file) as data:
        regret = data['arr_0']  # presumably (3, n, trials) -- TODO confirm against the generator script
        runtime = data['arr_1']  # presumably (3, trials) -- TODO confirm
    regret_avg[:, :, test_idx] = np.mean(regret, axis=2)
    regret_std[:, :, test_idx] = np.std(regret, axis=2)
    runtime_avg[:, test_idx] = np.mean(runtime, axis=1)
    runtime_std[:, test_idx] = np.std(runtime, axis=1)

# if we've reached this line, we've loaded all the data and just need to make the plots
# as a disclaimer, this plotting code is hacked together from incongruous matplotlib examples so may be perplexing

# initialize plot: one row of three panels -- final regret vs K (ax[0]),
# runtime vs K (ax[1]), and regret over rounds for K=10 (ax[2])
matplotlib.rcParams.update({'font.size': 18})  # applied before the figure is created
fig, ax = plt.subplots(1, 3)
fig.set_size_inches(11, 3)
# NOTE(review): plt.tight_layout() on the next line recomputes subplot spacing and
# likely overrides this wspace -- confirm whether this call has any visible effect
fig.subplots_adjust(wspace=0.05)
plt.tight_layout()
# explicit margins applied after tight_layout; the low top=0.7 presumably leaves
# room for the figure-level legend placed at 'upper center' when saving
plt.subplots_adjust(left=0.075, right=0.98, bottom=0.21, top=0.7)

# first plot: regret at the last round (index n - 1) versus K, with a band of
# +/- one standard deviation across trials around each curve
avg = regret_avg[:, n - 1, :]
std = regret_std[:, n - 1, :]
# row order: 0 = red dashed, 1 = blue dash-dot, 2 = green dotted
# (the same color/dash scheme the third plot labels in the legend)
for alg, (color, dash) in enumerate((('r', '--'), ('b', '-.'), ('g', ':'))):
    centre = avg[alg, :]
    spread = std[alg, :]
    ax[0].plot(all_tests, centre, color + dash)
    ax[0].fill_between(all_tests, centre - spread, centre + spread, color=color, alpha=0.2)
ax[0].set(xlabel=r'$K$', xlim=[min(all_tests), max(all_tests)])
ax[0].set_xticks([4, 6, 8, 10, 12, 14, 16])
ax[0].set(ylabel='Regret', ylim=[0, 12e3])
ax[0].set_yticks([0, 3e3, 6e3, 9e3, 12e3])
ax[0].ticklabel_format(style='sci', axis='y', scilimits=(0, 0))  # y tick labels in scientific notation
ax[0].set_title(r'$n=10^5$', fontdict={'fontsize': 18}, loc='right')
# tick marks on all four sides, tick labels only on the bottom and left
ax[0].tick_params(bottom=True, top=True, left=True, right=True,
                  labelbottom=True, labeltop=False, labelleft=True, labelright=False)

# second plot: runtime versus K on a log-scaled y axis, with a band of
# +/- one standard deviation across trials around each curve
avg = runtime_avg
std = runtime_std
# row order: 0 = red dashed, 1 = blue dash-dot, 2 = green dotted
for alg, (color, dash) in enumerate((('r', '--'), ('b', '-.'), ('g', ':'))):
    centre = avg[alg, :]
    spread = std[alg, :]
    ax[1].semilogy(all_tests, centre, color + dash)
    # NOTE(review): wherever spread >= centre the lower band edge is non-positive
    # and cannot be drawn faithfully on the log axis
    ax[1].fill_between(all_tests, centre - spread, centre + spread, color=color, alpha=0.2)
ax[1].set(xlabel=r'$K$', xlim=[min(all_tests), max(all_tests)])
ax[1].set_xticks([4, 6, 8, 10, 12, 14, 16])
ax[1].set(ylabel='Runtime', ylim=[1e1, 1e4])
ax[1].set_yticks([1e1, 1e2, 1e3, 1e4])
ax[1].set_title(r'$n=10^5$', fontdict={'fontsize': 18}, loc='right')
# tick marks on all four sides, tick labels only on the bottom and left
ax[1].tick_params(bottom=True, top=True, left=True, right=True,
                  labelbottom=True, labeltop=False, labelleft=True, labelright=False)

# third plot: regret over rounds for a single value of K, with a band of
# +/- one standard deviation across trials around each curve
K_SHOWN = 10  # value of K shown in this panel (also drives the panel title)
# look K_SHOWN up in all_tests rather than hard-coding its position, so the plotted
# data and the title cannot silently disagree if the tested K values change
shown_idx = np.flatnonzero(all_tests == K_SHOWN)
if shown_idx.size == 0:
    raise SystemExit('error -- K=' + str(K_SHOWN) + ' is not among the tested values of K')
avg = regret_avg[:, :, int(shown_idx[0])]
std = regret_std[:, :, int(shown_idx[0])]
rounds = range(n)
# (legend label, color, dash) per algorithm row; these labels feed the figure legend
for alg, (label, color, dash) in enumerate((('CascadeUCB1', 'r', '--'),
                                            ('CascadeUCB-V', 'b', '-.'),
                                            ('CascadeKL-UCB', 'g', ':'))):
    centre = avg[alg, :]
    spread = std[alg, :]
    ax[2].plot(rounds, centre, color + dash, label=label)
    ax[2].fill_between(rounds, centre - spread, centre + spread, color=color, alpha=0.2)
ax[2].set(xlabel=r'$n$')
ax[2].set(xlim=[0, n])
ax[2].set_xticks([0, 2e4, 4e4, 6e4, 8e4, 10e4])
ax[2].ticklabel_format(style='sci', axis='x', scilimits=(0, 0))  # x tick labels in scientific notation
ax[2].set(ylabel='Regret')
ax[2].set(ylim=[0, 12e3])
ax[2].set_yticks([0, 3e3, 6e3, 9e3, 12e3])
ax[2].ticklabel_format(style='sci', axis='y', scilimits=(0, 0))  # y tick labels in scientific notation
ax[2].set_title(r'$K=' + str(K_SHOWN) + '$', fontdict={'fontsize': 18}, loc='right')
# tick marks on all four sides, tick labels only on the bottom and left
ax[2].tick_params(labelbottom=True, labeltop=False, labelleft=True, labelright=False,
                  bottom=True, top=True, left=True, right=True)

# add a single figure-level legend (taken from the labelled curves of the third
# plot) and save the figure
handles, labels = ax[2].get_legend_handles_labels()
# one column per entry so the legend sits on a single row above the panels
fig.legend(handles, labels, loc='upper center', ncol=len(labels))
# savefig does not create missing directories, so make sure the output folder exists
makedirs('plots', exist_ok=True)
fig.savefig('plots/syn_tab.png', dpi=300)
