import numpy as np
import matplotlib.pyplot as plt
import math
# Global matplotlib style: every figure produced by this module uses a Times
# New Roman font (font size is set further down, before the experiment sweep).
plt.rcParams['font.family'] = 'Times New Roman'

def cal_sum(path):
    """Load the array saved at ``path`` and sum each entry along its first axis.

    Args:
        path: Location of a ``.npy`` file readable by ``np.load``.

    Returns:
        A list with one element per index of the first axis; element ``i`` is
        ``loaded[i].sum()`` (a NumPy scalar).
    """
    loaded = np.load(path)
    # Iterating an ndarray yields loaded[0], loaded[1], ... in order.
    return [entry.sum() for entry in loaded]

def make_path_list(method,rep,T,n,d_list,m,q,theo_flg,B,project,path):
    """Build the result-file path for ``method`` at every dimension in ``d_list``.

    Each path is ``path`` followed by a stem that encodes method, n, d, m, T, q
    and the constant configuration ('theo' if ``theo_flg`` else 'onlyT'). A
    suffix shared by all d is then appended:

    * ``diameter{B}``, only when ``B != 2``;
    * ``rep{rep}``;
    * ``project.npy`` for ``project == 'full'``, ``project_half.npy`` for
      ``project == 'half'``, and ``.npy`` otherwise.

    Returns:
        A list of path strings, one per entry of ``d_list``, in the same order.
    """
    constant_config = 'theo' if theo_flg else 'onlyT'

    # Everything after the d-specific stem is identical for every d.
    suffix = f'diameter{B}' if B != 2 else ''
    suffix += f'rep{rep}'
    if project == 'full':
        suffix += 'project.npy'
    elif project == 'half':
        suffix += 'project_half.npy'
    else:
        suffix += '.npy'

    # NOTE(review): 'prediciton' is misspelled, but it is part of the file name
    # being loaded; presumably it matches what the experiment code wrote, so
    # do not "fix" it here alone.
    return [
        path
        + f'multilabel_fixed_{method}_results_structured_prediciton_banditmultilabel_data_n{n}_d{d}_m{m}_T{T}_q{q}_{constant_config}'
        + suffix
        for d in d_list
    ]

def cal_sum_ave_list(path_list):
    """Return, for each path, the mean of the per-entry sums from ``cal_sum``.

    Args:
        path_list: Iterable of ``.npy`` paths.

    Returns:
        A list with one average per path, in input order. A file whose first
        axis is empty raises ``ZeroDivisionError``.
    """
    def mean_of(values):
        # Arithmetic mean using built-in sum/len.
        return sum(values) / len(values)

    return [mean_of(cal_sum(result_path)) for result_path in path_list]

def calculater(d, m):
    """Evaluate sum_{k=1..d} 2k * C(d-m, k) * C(m, k) / (C(d, m) * d).

    C(m, k) * C(d-m, k) / C(d, m) is the probability that a uniformly random
    m-subset of d items shares exactly m-k items with a fixed m-subset, and in
    that case the two subsets differ in 2k positions. The sum is therefore the
    expected symmetric-difference size divided by d. Presumably this is the
    per-label loss of a uniformly random m-label guess; confirm with the caller.

    Terms with k > min(m, d-m) are zero, because ``math.comb`` returns 0 when
    k exceeds n.
    """
    return sum(
        (k * 2) * math.comb(d - m, k) * math.comb(m, k) / (math.comb(d, m) * d)
        for k in range(1, d + 1)
    )

def make_random(d_list, m, T):
    """Scale ``calculater(d, m)`` by the horizon ``T`` for every d in ``d_list``.

    Returns:
        A list ``[calculater(d, m) * T for d in d_list]`` in input order.
    """
    return [calculater(d, m) * T for d in d_list]

def plot_loss(rep,T,n,d_list,m,q,theo_flg,B,project,path,save=False,ylim=(0, 5000)):
    """Plot the average loss of each estimator against the dimension ``d``.

    For the methods 'general' and 'self', result files are located with
    ``make_path_list`` and reduced to one average per ``d`` with
    ``cal_sum_ave_list``. The raw averages are printed, and one marked line
    per method is drawn on a new figure.

    Args:
        rep, T, n, d_list, m, q, theo_flg, B, project, path: Forwarded to
            ``make_path_list``. ``d_list`` is also used for the x-axis values
            and ticks. ``theo_flg``, ``B``, ``n``, ``m`` and ``project`` are
            also encoded in the output file name.
        save: If true, write the figure to an ``.eps`` file in the current
            working directory.
        ylim: ``(bottom, top)`` passed to ``Axes.set_ylim``. The default is
            the previously hard-coded ``(0, 5000)``. ``None`` leaves
            matplotlib's automatic limits.

    Returns:
        ``(fig, ax)``: the created figure and axes, so callers can adjust,
        show, or close them.
    """
    methods = ['general', 'self']
    ave_list_list = [
        cal_sum_ave_list(make_path_list(method,rep,T,n,d_list,m,q,theo_flg,B,project,path))
        for method in methods
    ]
    print(ave_list_list)

    # Labels/markers pair with `methods` by position ('general' -> first entry).
    # NOTE(review): the third 'Random' entry is currently unused; make_random()
    # is never plotted here.
    name_list = ['inverse-weighted estimator', 'pseudo-inverse\nmatrix estimator', 'Random']
    marker_list = ['o', 'v', 's']

    fig, ax = plt.subplots(figsize=(14,7))
    for averages, name, marker in zip(ave_list_list, name_list, marker_list):
        ax.plot(d_list, averages, marker=marker, linestyle='-', label=name, markersize=15)

    ax.set_xlabel("$d$")
    ax.set_ylabel("Loss")
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)
    ax.legend(fontsize=25)
    ax.set_xticks(d_list, labels=[str(x) for x in d_list])
    if ylim is not None:
        ax.set_ylim(ylim)

    # Lay out only after every axis change (limits change tick labels, which
    # changes the space tight_layout must reserve).
    fig.tight_layout()

    name_theo = 'theo' if theo_flg else 'nontheo'
    save_path = f'_comparing_regret_multilabel_{name_theo}_B{B}_n{n}_m{m}_project{project}_for_neurips.eps'
    if save:
        # Save this specific figure rather than whatever pyplot considers current.
        fig.savefig(save_path)
    return fig, ax


# Global font size for all figures. This is kept at module level (like the
# font family above) so importers of plot_loss get the same styling.
plt.rcParams["font.size"] = 30

# Guard the experiment sweep so that importing this module does not load
# result files, print, and write .eps figures as a side effect.
if __name__ == "__main__":
    T = 10000
    path = ''
    # Parameter sweep; each list currently holds a single setting.
    for theo_flg in [False]:
        for B in [50]:
            for m in [5]:
                for q in [0.01]:
                    for project in ['full']:
                        plot_loss(10, T, 50, [10, 12, 16, 20, 24], m, q, theo_flg, B,
                                  project=project, save=True, path=path)