# this file plots the real data results from the appendix (figure 3)
import itertools
import os
from os.path import exists

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Load the per-genre results. Any missing input aborts the script with a
# nonzero exit status and a pointer to the README.
movielens_file = 'data/movies.npz'
if not exists(movielens_file):
    # SystemExit(msg) writes msg to stderr and exits with status 1; a bare
    # print() followed by exit() would report the error but exit with status 0
    raise SystemExit('error -- file ' + movielens_file + ' does not exist -- see README')
# the npz archive is closed as soon as the genre names have been read out of it
with np.load(movielens_file) as movielens_data:
    genre_list = movielens_data['arr_1']  # genre names: used for the result file paths and legend labels
num_genre = len(genre_list)
rel_reg_avg = np.zeros(num_genre)  # average of the regret of our alg relative to existing
rel_reg_std = np.zeros(num_genre)  # standard deviation of the regret of our alg relative to existing
# row 0: statistic plotted on both x axes ('Click probability for greedy');
# row 1: statistic plotted on the y axis of the second panel ('Number movies')
stats_avg = np.zeros((2, num_genre))  # average of the other statistics needed in the plot
stats_std = np.zeros((2, num_genre))  # standard deviation of the other statistics needed in the plot
for genre_idx, genre in enumerate(genre_list):
    # the results for every genre are required; abort on the first missing file
    result_file = 'results/real_lin_' + genre + '.npz'
    if not exists(result_file):
        raise SystemExit('error -- file ' + result_file + ' does not exist -- see README')
    # open each archive once and close it when done
    with np.load(result_file) as results:
        regret = results['arr_0']
        statistics = results['arr_1']
    # ratio of the final regret of our alg (row 1) to that of the existing alg (row 0);
    # NOTE(review): assumes regret is indexed (algorithm, time, run) -- confirm against
    # the script that writes these result files
    final_ratio = regret[1, -1, :] / regret[0, -1, :]
    rel_reg_avg[genre_idx] = np.mean(final_ratio)
    rel_reg_std[genre_idx] = np.std(final_ratio)
    stats_avg[0, genre_idx] = np.mean(statistics[2, :])  # x axis: click probability for greedy
    stats_std[0, genre_idx] = np.std(statistics[2, :])
    stats_avg[1, genre_idx] = np.mean(statistics[0, :])  # second panel y axis: number of movies
    stats_std[1, genre_idx] = np.std(statistics[0, :])

# Reaching this point means every input file was loaded; what remains is plotting.
# Disclaimer: the plotting code below was pieced together from several unrelated
# matplotlib examples, so parts of it may look odd.

# global font size, then a 1x2 grid of panels sized and padded for the paper figure
# (the large top margin leaves room for the shared legend)
matplotlib.rcParams['font.size'] = 18
fig, ax = plt.subplots(nrows=1, ncols=2)
fig.set_size_inches((11, 6))
fig.subplots_adjust(left=0.09, right=0.98, bottom=0.12, top=0.52)
# 9 shapes x 2 fill styles = 18 distinct (fill, shape) pairs: all filled shapes
# first, then the same shapes hollow
marker_shape = ['o', 's', 'D', 'v', '^', '<', '>', 'P', 'X']
marker_fill = ['full', 'none']
all_markers = [(fill, shape) for fill in marker_fill for shape in marker_shape]

# One black errorbar point per genre on each panel. A genre keeps the same
# marker on both panels so the single shared legend identifies it in both.
for genre_idx, genre_name in enumerate(genre_list):
    fill, shape = all_markers[genre_idx]
    point_style = dict(markersize=8, color='k', marker=shape, fillstyle=fill,
                       linestyle='none', label=genre_name)
    # both panels use the same x value: click probability for greedy
    x_mean = stats_avg[0, genre_idx]
    x_err = stats_std[0, genre_idx]
    ax[0].errorbar(x_mean, rel_reg_avg[genre_idx], xerr=x_err,
                   yerr=rel_reg_std[genre_idx], **point_style)
    ax[1].errorbar(x_mean, stats_avg[1, genre_idx], xerr=x_err,
                   yerr=stats_std[1, genre_idx], **point_style)

# Format both panels. They share the x axis (click probability for greedy), the
# title and the tick layout; only the y axis differs. Each entry below is
# (y label, y limits, y ticks, use scientific notation on y).
panel_title = r'$(n,d,K)=(10^6,20,4)$'
y_axis_settings = [
    ('Relative regret', [0.5, 0.75], [0.5, 0.55, 0.6, 0.65, 0.7, 0.75], False),
    ('Number movies', [0, 18e2], [0, 3e2, 6e2, 9e2, 12e2, 15e2, 18e2], True),
]
for panel, (y_label, y_limits, y_ticks, y_sci) in zip(ax, y_axis_settings):
    panel.set(xlabel='Click probability for greedy')
    panel.set(xlim=[0.35, 0.85])
    panel.set_xticks([0.4, 0.5, 0.6, 0.7, 0.8])
    panel.set(ylabel=y_label)
    panel.set(ylim=y_limits)
    panel.set_yticks(y_ticks)
    if y_sci:
        # movie counts are in the hundreds, so label them with a shared exponent
        panel.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
    panel.set_title(panel_title, fontdict={'fontsize': 18}, loc='right')
    # ticks on all four sides, labels only on the bottom and left
    panel.tick_params(labelbottom=True, labeltop=False, labelleft=True, labelright=False,
                      bottom=True, top=True, left=True, right=True)

# Add one legend for the whole figure (both panels use identical markers and
# labels per genre, so the first panel's handles suffice), then save it.
handles, labels = ax[0].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', ncol=3)
output_file = 'plots/real_lin_app.png'
# create the output directory if needed; otherwise savefig raises
# FileNotFoundError after all the loading and plotting work is already done
os.makedirs(os.path.dirname(output_file), exist_ok=True)
fig.savefig(output_file, dpi=300)
