from matplotlib import pyplot as plt
import numpy as np
import argparse

# Sizes that appear in the experiment name (see experiment_name below).
n, m = 100, 300

# Read the data input model from the command line. argparse signals any
# parsing failure (e.g. unrecognized arguments) by raising SystemExit; in that
# case fall back to a fixed default instead of terminating the script.
# Note: `-h` also raises SystemExit and is therefore caught here as well.
parser = argparse.ArgumentParser(description='Experiment arguments')
parser.add_argument('--datainputmodel', '-dim', type=str, default='markov')
try:
    args = parser.parse_args()
except SystemExit:
    print('not parsing command line inputs. use given parameters.')
    data_input_model = 'periodic'
else:
    data_input_model = args.datainputmodel

# One fresh, independent list per logged metric; each collects that metric
# across sample paths. Several are only used by the commented-out metrics below.
(
    u_hseq_all,
    u_ave_baseline_all,
    inf_norm_to_beta_hseq_all,
    ave_one_norm_to_beta_hseq_all,
    inf_norm_to_u_hseq_all,
    ave_one_norm_to_u_hseq_all,
    max_rel_buyer_regret_all,
    ave_rel_buyer_regret_all,
    inf_norm_u_ave_baseline_all,
) = ([] for _ in range(9))

# File-name stem shared by the result logs and the output plot; the model name
# is lower-cased and spaces become underscores.
experiment_name = f"movielens_{n}_{m}_{data_input_model.lower().replace(' ', '_')}"

# Load the logged metrics of each sample path, then stack each metric so that
# axis 0 of the resulting array indexes the sample path.
for sample_path_idx in range(10):

    # NpzFile keeps the archive open; the context manager closes it once the
    # arrays have been read (unpacking reads every array while still open).
    with np.load(
        'results/{}_all_logs_{}.npz'.format(experiment_name, sample_path_idx)
    ) as logs:
        # NOTE(review): positional unpacking relies on the arrays having been
        # saved in exactly this order — confirm against the logging script.
        (
            u_hseq, B,
            inf_norm_u_ave_baseline,
            inf_norm_to_beta_hseq,  # ave_one_norm_to_beta_hseq,
            inf_norm_to_u_hseq,  # ave_one_norm_to_u_hseq,
            inf_norm_to_B,  # ave_one_norm_to_B,
            max_rel_buyer_regret,  # ave_rel_buyer_regret,
        ) = logs.values()

    u_hseq_all.append(u_hseq)
    inf_norm_u_ave_baseline_all.append(inf_norm_u_ave_baseline)
    inf_norm_to_beta_hseq_all.append(inf_norm_to_beta_hseq)
    # ave_one_norm_to_beta_hseq_all.append(ave_one_norm_to_beta_hseq)
    inf_norm_to_u_hseq_all.append(inf_norm_to_u_hseq)
    # ave_one_norm_to_u_hseq_all.append(ave_one_norm_to_u_hseq)
    max_rel_buyer_regret_all.append(max_rel_buyer_regret)
    # ave_rel_buyer_regret_all.append(ave_rel_buyer_regret)

# Stack the per-path lists into arrays (previously `u_hseq = np.array(u_hseq)`
# re-wrapped only the last path's array and left u_hseq_all as a list).
u_hseq_all = np.array(u_hseq_all)
inf_norm_u_ave_baseline_all = np.array(inf_norm_u_ave_baseline_all)
inf_norm_to_beta_hseq_all = np.array(inf_norm_to_beta_hseq_all)
# ave_one_norm_to_beta_hseq_all = np.array(ave_one_norm_to_beta_hseq_all)
inf_norm_to_u_hseq_all = np.array(inf_norm_to_u_hseq_all)
# ave_one_norm_to_u_hseq_all = np.array(ave_one_norm_to_u_hseq_all)
max_rel_buyer_regret_all = np.array(max_rel_buyer_regret_all)
# ave_rel_buyer_regret_all = np.array(ave_rel_buyer_regret_all)

# Horizon of the run, taken from the last loaded trajectory.
T = len(inf_norm_to_beta_hseq)
t0 = T // 50                    # first plotted iteration: skip the first 2%
skip_size = max(5, T // 2000)   # subsampling stride, never finer than 5
num_dp = (T - t0) // skip_size  # approximate number of plotted points
fig = plt.figure(figsize=(6, 4))

import seaborn as sns
plt.clf()
sns.set_theme()

# Each curve is declared together with its legend label so the two cannot
# drift out of sync; they are then split into two parallel tuples.
# Other stacked metrics (e.g. max_rel_buyer_regret_all) can be added here as
# another (array, label) pair.
curve_specs = (
    (
        inf_norm_to_beta_hseq_all,
        r'max$_i$ $|\beta_i^t - \beta^{\rm HS}_i|/\beta^{\rm HS}_i$',
    ),
    (
        inf_norm_to_u_hseq_all,
        r'max$_i$ $|\bar{u}_i^t - u^{\rm HS}_i|/u^{\rm HS}_i$',
    ),
    (
        inf_norm_u_ave_baseline_all,
        r'max$_i$ $|\bar{u}^{\rm baseline}_i - u^{\rm HS}_i|/u^{\rm HS}_i$',
    ),
)
all_data_arrays, all_labels = zip(*curve_specs)

def get_title_text(data_input_model):
    """Return the human-readable plot title for a data input model name.

    The name is normalized like in ``experiment_name`` (lower-cased, spaces
    replaced by underscores) before lookup, so e.g. ``'Markov'`` and
    ``'markov'`` give the same title. Names without a known title are
    returned unchanged instead of ``None``.
    """
    titles = {
        'markov': 'Markov',
        'iid': 'i.i.d.',
        'periodic': 'Periodic',
        'mild': 'Mild Corruption',
    }
    key = data_input_model.lower().replace(' ', '_')
    return titles.get(key, data_input_model)

import os

# Plot each metric's mean over sample paths with standard-error bars, on a
# subsampled grid of iterations starting at t0 (x axis is 1-based).
plot_idx = np.arange(t0, T, skip_size)
for data_array, label in zip(all_data_arrays, all_labels):
    num_paths = data_array.shape[0]  # rows are sample paths
    subsampled = data_array[:, plot_idx]
    plt.errorbar(
        plot_idx + 1,
        np.mean(subsampled, axis=0),
        np.std(subsampled, axis=0) / np.sqrt(num_paths),
        # matplotlib requires a positive step; short runs would give 0.
        errorevery=max(1, num_dp // 10),
        linestyle='solid',
        label=label,
    )

# Dotted vertical lines at every multiple of 20*n iterations within [t0, T].
line_period = n * 20
first_line = -(-t0 // line_period) * line_period  # ceil(t0 / period) * period
for pt in range(first_line, T + 1, line_period):
    plt.axvline(pt, linewidth=1.0, linestyle='dotted')

plt.xticks(range(0, T + 1, max(1, T // 5)))  # step must be >= 1 for tiny T
plt.xlabel('t')
plt.legend(prop={'size': 12}, loc='center right')
plt.title(get_title_text(data_input_model))
os.makedirs('../plots', exist_ok=True)  # savefig does not create directories
plt.savefig(f'../plots/{experiment_name}.pdf', bbox_inches='tight')
plt.show()