import os

import numpy as np
import matplotlib.pyplot as plt
# Global Matplotlib setting: render all text in figures drawn by this script
# with the Times New Roman font family.
plt.rcParams['font.family'] = 'Times New Roman'

def cal_sum(path, start=40000):
    """Return the per-row mean of a saved result array after skipping a prefix.

    Loads the array stored at *path* with ``np.load``. For every row
    ``data[i]``, it averages the entries from index *start* onward. Each row
    is presumably one repetition of an experiment holding per-round 0/1 error
    indicators, which makes each mean an error rate (TODO confirm against the
    code that writes the ``.npy`` files).

    Args:
        path: Path to a ``.npy`` file accepted by ``np.load``.
        start: Index of the first entry included in each row's average.
            Earlier entries are skipped. Defaults to 40000, the value that
            was previously hard-coded here.

    Returns:
        A list with one mean value per row, in row order.

    Raises:
        ValueError: If a row has no entries at or after *start*. Without this
            check the mean would be 0/0, i.e. a silent NaN.
    """
    data = np.load(path)
    error_rate_list = []
    for i, row in enumerate(data):
        tail = row[start:]
        n_kept = len(tail)
        if n_kept == 0:
            raise ValueError(
                f"row {i} of {path!r} has {len(row)} entries; "
                f"nothing left after skipping the first {start}"
            )
        error_rate_list.append(tail.sum() / n_kept)
    return error_rate_list

def make_path_list(theo_flg, folder_path=''):
    """Build the list of result-file paths compared in the box plot.

    Args:
        theo_flg: If truthy, select the files whose names contain ``'theo'``;
            otherwise select those containing ``'onlyT'``.
        folder_path: Directory holding the result files. Each file name is
            joined to it with ``os.path.join``. The default ``''`` yields the
            bare file names, resolved relative to the current working
            directory.

    Returns:
        Four paths, in this order: structured-prediction ("our") algorithm,
        Gappletron logistic, Gaptron hinge, Gaptron logistic. ``make_box``
        pairs these positionally with its axis labels, so the order matters.
    """
    constant_config = 'theo' if theo_flg else 'onlyT'
    # NOTE(review): "prediciton" (sic) is kept as-is; presumably it matches the
    # file names on disk. Confirm before correcting the spelling.
    file_names = [
        f'results_structured_prediciton_banditMNIST{constant_config}diameter10rep20.npy',
        f'resultsGappletron-adaL2logisticMNIST{constant_config}banditdiameter10rep20.npy',
        f'resultsGaptron-adaL2hingeMNIST{constant_config}banditdiameter10rep20.npy',
        f'resultsGaptron-adaL2logisticMNIST{constant_config}banditdiameter10rep20.npy',
    ]
    return [os.path.join(folder_path, name) for name in file_names]

def make_box(theo_flg, save=False):
    """Draw a box plot of per-row error rates for the four compared algorithms.

    Args:
        theo_flg: Passed to ``make_path_list`` to choose the ``'theo'`` or
            ``'onlyT'`` result files. It also selects the ``theo``/``nontheo``
            tag in the saved file name.
        save: If True, write the figure as an ``.eps`` file in the current
            working directory.

    Returns:
        The ``(fig, ax)`` pair, so callers can show, further customise, or
        close the figure. The original returned None, so callers that ignore
        the return value are unaffected.
    """
    path_list = make_path_list(theo_flg)
    name_list = ['Our\nAlgorithm', 'Gappletron\nLogistic', 'Gaptron\nHinge', 'Gaptron\nLogistic']
    name_theo = 'theo' if theo_flg else 'nontheo'

    # One list of per-row error rates per result file, in path_list order.
    error_rate_list_list = [cal_sum(path) for path in path_list]

    fig, ax = plt.subplots(figsize=(14, 10))
    ax.boxplot(error_rate_list_list)
    # Label the boxes through the tick API instead of boxplot(labels=...).
    # That keyword was renamed to tick_labels in Matplotlib 3.9 and is
    # deprecated, whereas set_xticklabels works on old and new versions.
    # boxplot places the boxes at x = 1..n, which these labels line up with.
    ax.set_xticklabels(name_list)
    ax.set_ylim(bottom=0)
    ax.set_ylabel('Error Rate')
    ax.grid(True)
    fig.tight_layout()
    if save:
        save_path = f'_experiment_box_MNIST{name_theo}_B10_rep20_for_neurips.eps'
        fig.savefig(save_path)
    return fig, ax
        

# Run only when executed as a script. Importing this module (e.g. to reuse
# cal_sum) then no longer changes the global font size or regenerates and
# overwrites the .eps figure.
if __name__ == "__main__":
    # Large base font, presumably so text stays legible once the figure is
    # scaled down for the paper.
    plt.rcParams["font.size"] = 45
    make_box(True, True)