# this file plots the real data results from the main text (figure 2)
from os import makedirs
from os.path import exists

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# try to load results for the all genre test, aborting with an error message if the file doesn't exist
result_file = 'results/real_lin_all.npz'
if not exists(result_file):
    # SystemExit with a message prints it to stderr and exits with status 1, so
    # callers (shell scripts, make) see the failure; a bare exit() would report success
    raise SystemExit('error -- file ' + result_file + ' does not exist -- see README')
# the context manager closes the .npz archive once the array has been read into memory
with np.load(result_file) as all_genre_data:
    regret_all_genre = all_genre_data['arr_0']
# axis 2 is reduced as the trial axis; later code reads axis 1 as the round index n
# and axis 0 as the algorithm (0 = CascadeLinUCB, 1 = CascadeWOFUL)
regret_avg_all_genre = np.mean(regret_all_genre, axis=2)  # average of the regret across trials
regret_std_all_genre = np.std(regret_all_genre, axis=2)  # standard deviation of the regret across trials

# load results for the per genre test; the genre names come from the movielens data file
movielens_file = 'data/movies.npz'
if not exists(movielens_file):
    # abort with a non-zero exit status (message goes to stderr) instead of a silent success
    raise SystemExit('error -- file ' + movielens_file + ' does not exist -- see README')
with np.load(movielens_file) as movielens_data:
    genre_list = movielens_data['arr_1']
# n (number of rounds) is taken from the all-genre results; each per-genre file is
# assumed to have the same number of rounds -- TODO confirm against the experiment script
n, num_genre = regret_avg_all_genre.shape[1], len(genre_list)
regret_avg_per_genre = np.zeros((2, n, num_genre))  # average of the regret across trials
regret_std_per_genre = np.zeros((2, n, num_genre))  # standard deviation of the regret across trials
for genre_idx, genre in enumerate(genre_list):
    # try to load current genre file, aborting with an error message if it doesn't exist
    result_file = 'results/real_lin_' + genre + '.npz'
    if not exists(result_file):
        raise SystemExit('error -- file ' + result_file + ' does not exist -- see README')
    # reuse the checked path so the existence test and the load cannot diverge
    with np.load(result_file) as genre_data:
        regret = genre_data['arr_0']
    regret_avg_per_genre[:, :, genre_idx] = np.mean(regret, axis=2)
    regret_std_per_genre[:, :, genre_idx] = np.std(regret, axis=2)

# if we've reached this line, we've loaded all the data and just need to make the plots
# as a disclaimer, this plotting code is hacked together from incongruous matplotlib examples so may be perplexing

# initialize plot: a 1x3 grid where ax1 takes the left cell (all-genre regret
# curves) and ax2 spans the right two cells (per-genre final regret bars)
matplotlib.rcParams.update({'font.size': 18})  # global default font size for text drawn afterwards
fig = plt.figure()
ax1 = plt.subplot2grid((1, 3), (0, 0), colspan=1)
ax2 = plt.subplot2grid((1, 3), (0, 1), colspan=2)
fig.set_size_inches(12.3, 3.5)  # wide, short figure (width, height in inches)
# NOTE(review): the tight_layout() call below typically recomputes subplot spacing
# (wspace included), so this 0.05 may be overridden -- confirm before relying on it
fig.subplots_adjust(wspace=0.05)
plt.tight_layout()
# explicit margins applied after tight_layout, so they replace its left/right/bottom/top;
# presumably bottom/top leave room for the rotated genre tick labels and the
# figure-level legends added later -- confirm before changing the layout
plt.subplots_adjust(left=0.07, right=0.99, bottom=0.32, top=0.74)

# first plot: regret versus round for the all-genre test, mean with a +/- one std band
rounds = range(n)
mean_regret, spread = regret_avg_all_genre, regret_std_all_genre
# (row in the result arrays, line format, band colour, legend label), drawn in this order
curve_specs = (
    (0, 'r--', 'r', 'CascadeLinUCB'),
    (1, 'g:', 'g', 'CascadeWOFUL'),
)
for row, line_fmt, band_colour, algo_name in curve_specs:
    mu, sigma = mean_regret[row, :], spread[row, :]
    ax1.plot(rounds, mu, line_fmt, label=algo_name)
    ax1.fill_between(rounds, mu - sigma, mu + sigma, color=band_colour, alpha=0.2)
# x axis: limits must be set before the ticks, since setting ticks can widen the view
ax1.set(xlabel=r'$n$', xlim=[0, n])
ax1.set_xticks([0, 2e5, 4e5, 6e5, 8e5, 1e6])
ax1.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
# y axis: same ordering -- limits, then ticks, then scientific formatting
ax1.set(ylabel='Regret', ylim=[0, 2.4e4])
ax1.set_yticks([0, 0.6e4, 1.2e4, 1.8e4, 2.4e4])
ax1.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
# figure-level legend for the two regret curves, anchored top-left
fig.legend(*ax1.get_legend_handles_labels(), loc='upper left', ncol=2)
ax1.set_title(r'$(d,K) = (20,4)$', fontdict={'fontsize': 18}, loc='right')
ax1.tick_params(labelbottom=True, labeltop=False, labelleft=True, labelright=False,
                bottom=True, top=True, left=True, right=True)

# second plot: regret at the final round (index n - 1) for every genre, as hatched
# bars with one-std error bars
final_mean = regret_avg_per_genre[:, n - 1, :]
final_std = regret_std_per_genre[:, n - 1, :]
# each genre occupies three x slots: CascadeLinUCB bar, CascadeWOFUL bar, empty gap
linucb_slots = range(0, 3 * num_genre, 3)
woful_slots = range(1, 3 * num_genre + 1, 3)
# (x positions, row in the result arrays, edge colour, hatch pattern, legend label)
bar_specs = (
    (linucb_slots, 0, 'r', '/' * 5, 'CascadeLinUCB'),
    (woful_slots, 1, 'g', '\\' * 5, 'CascadeWOFUL'),
)
for slots, row, edge, pattern, algo_name in bar_specs:
    ax2.bar(slots, final_mean[row, :], yerr=final_std[row, :], color='none', edgecolor=edge,
            align='center', capsize=2, hatch=pattern, label=algo_name)
ax2.set(xlim=[-1, 3 * num_genre - 1])
# one tick per genre, centred between its two bars (slot 3*g and 3*g + 1)
ax2.set_xticks(np.arange(num_genre) * 3 + 0.5)
ax2.set_xticklabels(genre_list, rotation=30, ha='right')
ax2.set(ylabel='Regret', ylim=[0, 3.6e4])
ax2.set_yticks([0, 0.9e4, 1.8e4, 2.7e4, 3.6e4])
ax2.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
# figure-level legend for the two bar series, anchored top-right
fig.legend(*ax2.get_legend_handles_labels(), loc='upper right', ncol=2)
ax2.set_title(r'per genre, $(n,d,K)=(10^6,20,4)$', fontdict={'fontsize': 18}, loc='right')
ax2.tick_params(labelbottom=True, labeltop=False, labelleft=True, labelright=False,
                bottom=True, top=True, left=True, right=True)

# save figure; savefig does not create missing directories, so make sure plots/
# exists first (it may be absent in a fresh checkout)
makedirs('plots', exist_ok=True)
fig.savefig('plots/real_lin_main.png', dpi=300)
