import numpy as np
import matplotlib.pyplot as plt
# Default font family for every figure this module draws.
plt.rcParams.update({'font.family': 'Times New Roman'})

def cal_sum(path):
    """Load the array stored at ``path`` and sum each slice along axis 0.

    The array returned by ``np.load(path)`` is walked along its first axis,
    and every slice is summed over all of its elements. An array with ``n``
    entries on axis 0 therefore yields ``n`` sums.

    Args:
        path: file path handed to ``np.load``.

    Returns:
        list: one NumPy scalar per index of the first axis, in order.
    """
    return [row.sum() for row in np.load(path)]

def make_path_list(method,K_list,d,noise,theo_flg,B,project,rep):
    """Build the result-file path for ``method`` at every K in ``K_list``.

    Each path starts with a configuration stem derived from ``method``, ``K``,
    ``d``, ``noise`` and ``theo_flg`` ('theo' if truthy, else 'onlyT').
    The following suffixes are then appended, in this order:

    * ``'bandit'`` unless ``method`` is ``'_structured_prediciton_bandit'``;
    * ``'diameter{B}'`` when ``B != 2``;
    * ``'project'`` when ``project == 'full'`` or ``'project_half'`` when
      ``project == 'half'`` (any other value adds nothing);
    * ``'rep{rep}.npy'``.

    Returns:
        list of str: one path per entry of ``K_list``, in the same order.
    """
    config_tag = 'theo' if theo_flg else 'onlyT'
    folder_path = ''

    # Everything after the stem is identical for every K, so assemble it once.
    # NOTE: the misspelling 'prediciton' is part of the stored filenames; keep it.
    tail = []
    if method != '_structured_prediciton_bandit':
        tail.append('bandit')
    if B != 2:
        tail.append(f'diameter{B}')
    if project == 'full':
        tail.append('project')
    elif project == 'half':
        tail.append('project_half')
    tail.append(f'rep{rep}.npy')
    suffix = ''.join(tail)

    return [
        folder_path
        + f'results{method}SynSep_K{K}_keys{int(d*10)}_keylow{int(d)}.0_keyhigh{int(d*5)}.0'
          f'_nbad{int(d*30)}_nbadon{int(d*5)}_labelnoise{round(noise,1)}{config_tag}'
        + suffix
        for K in K_list
    ]

def cal_sum_ave_list(path_list):
    """Return, for each path, the mean of the per-slice sums from ``cal_sum``.

    ``cal_sum`` yields one sum per index of the stored array's first axis.
    Those sums are averaged with ``sum(...) / len(...)``, giving one float
    per path, in the order of ``path_list``.

    Raises:
        ZeroDivisionError: if a stored array has length 0 along its first axis.
    """
    return [sum(sums) / len(sums) for sums in map(cal_sum, path_list)]

def plot_loss(method_list,K_list,d,noise,theo_flg,B,project,rep,save=False,
              *, name_list=None, marker_list=None):
    """Plot each method's average loss against K on a shared axes.

    For every method, ``make_path_list`` locates the result file for each K in
    ``K_list``. ``cal_sum_ave_list`` then reduces each file to one number, and
    each method is drawn as one line.

    Args:
        method_list: method identifiers, each forwarded to ``make_path_list``.
        K_list: K values, used for file lookup and as the x-axis ticks.
        d, noise, theo_flg, B, project, rep: experiment configuration forwarded
            to ``make_path_list``. ``noise``, ``theo_flg``, ``B``, ``d`` and
            ``project`` also appear in the output filename.
        save: if True, write the figure as an ``.eps`` file (relative path,
            i.e. in the current working directory).
        name_list: optional legend labels, indexed in the same order as
            ``method_list``. The default labels assume the order: our
            algorithm, Gappletron logistic, Gaptron hinge, Gaptron logistic.
        marker_list: optional line markers, indexed like ``name_list``.

    Returns:
        The created matplotlib figure. It is left open; callers that make many
        plots should close it (e.g. ``plt.close(fig)``).

    Raises:
        IndexError: if ``method_list`` is longer than ``name_list`` or
            ``marker_list``.
    """
    ave_list_list = [
        cal_sum_ave_list(make_path_list(method,K_list,d,noise,theo_flg,B,project,rep))
        for method in method_list
    ]

    if name_list is None:
        name_list = ['Our\nAlgorithm', 'Gappletron\nLogistic','Gaptron\nHinge','Gaptron\nLogistic']
    if marker_list is None:
        marker_list = ['o', 'v', '^','s']

    fig, ax = plt.subplots(figsize=(11,7))
    for i, ave_list in enumerate(ave_list_list):
        ax.plot(K_list, ave_list, marker=marker_list[i], linestyle='-', label=name_list[i])

    ax.set_xlabel("$K$")
    ax.set_ylabel("Loss")
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax.legend(fontsize=25)
    ax.set_xticks(K_list, labels=[str(x) for x in K_list])
    ax.set_ylim(bottom=0)
    # Compute the layout only after every axis change. Changing ylim changes
    # the tick labels, and the saved margins must fit what is actually drawn.
    fig.tight_layout()

    name_theo = 'theo' if theo_flg else 'nontheo'
    # NOTE(review): the filename does not encode rep, K_list or method_list, so
    # runs that differ only in those overwrite each other.
    save_path = f'_comparing_regret_noise{noise}_{name_theo}_B{B}_d{d}_project{project}_for_neurips.eps'
    if save:
        fig.savefig(save_path)
    return fig

plt.rcParams["font.size"] = 32
method = ['_structured_prediciton_bandit', 'Gappletron-adaL2logistic', 'Gaptron-adaL2hinge', 'Gaptron-adaL2logistic']
rep = 20
for label in [True]:
    for B in [10]:
        for noise in [0.0, 0.1]:
            for d in [2,4]:
                plot_loss(method,[3,6,12,24,48,96],d,noise ,label,B,project='no', rep=rep,save=True)