import matplotlib
import matplotlib.pyplot as plt

# Render every text element (labels, ticks, legend) through LaTeX.
matplotlib.rcParams['text.usetex'] = True
# 'text.latex.unicode' only exists on old Matplotlib releases: it was
# deprecated in 3.0 (unicode became the default) and later removed, so
# assigning it unconditionally raises KeyError on current versions.
if 'text.latex.unicode' in matplotlib.rcParams:
    matplotlib.rcParams['text.latex.unicode'] = True

import torch
import numpy as np
import numpy

from matplotlib.pyplot import MultipleLocator

# plt.rcParams['figure.figsize'] = (7.0, 5.5)

'''
The generalization gap is estimated as the absolute difference between the training error and the test error.

Originally, the model achieving the smallest training error was saved, evaluated on the test tasks, and that error was taken as the test error, i.e. the gap was abs(model['val_loss'][0]-min(model['train_loss'])). The code below instead uses the smallest recorded test error, i.e. abs(min(model['val_loss'])-min(model['train_loss'])).
'''
import seaborn as sns
sns.set_theme(style='darkgrid')

# Root directories holding the saved training logs.
path, path_loo = './', './loo_regression_bilevel/'  # path_loo is not read in this section

# Global base font size (rcParams['font.size']); set after the seaborn theme,
# which also configures font sizes.
font = dict(size=20)
matplotlib.rc('font', **font)

task_num = 1000  # not read in this section
print_task_num = True  # enables the task-count sweep plot




def _load_trlog(num_tasks, query, shot, run_time):
    """Load the training log of one bilevel run.

    The file is ``<path>bilevel_sq_task<num_tasks>_query<query>_shot<shot>_<run_time>/trlog``;
    ``path`` is concatenated as-is, so it must end with a separator.
    """
    log_file = f"{path}bilevel_sq_task{num_tasks}_query{query}_shot{shot}_{run_time}/trlog"
    # torch.load unpickles the file -- only load logs you produced yourself.
    return torch.load(log_file)


def _plot_with_markers(x, y, marker_size, **line_kwargs):
    """Plot ``y`` against ``x`` as a line and overlay 'v' markers of the same colour.

    The markers are a separate, unlabelled artist so the legend shows only the
    line. Reusing the line's colour keeps each marker series matched to its
    curve instead of pulling the next colour from the property cycle.

    Returns the Line2D of the line.
    """
    (line,) = plt.plot(x, y, **line_kwargs)
    plt.plot(x, y, 'v', ms=marker_size, color=line.get_color())
    return line


# Legend keyword arguments per (shot, query); combinations not listed get no legend.
_LEGEND_KWARGS = {
    (5, 1): {'bbox_to_anchor': (0.56, 0.65), 'loc': 'lower left'},
    (5, 15): {'bbox_to_anchor': (0.56, 0.65), 'loc': 'lower left'},
    (1, 1): {'bbox_to_anchor': (0.57, 0.6)},
    (1, 15): {'bbox_to_anchor': (0.57, 0.57)},
}

if print_task_num:
    # Plot best test error, best training error and their absolute gap
    # against log10(number of training tasks).
    from matplotlib.font_manager import FontProperties

    query = 1
    shot = 5
    name = 'gap'  # only used in the output PDF file name
    run_time = 0
    save_pdf = False  # set True to write ./ours_<name>_query<query>_shot<shot>.pdf
    # Alternative sweep: [10, 34, 100, 334, 1000]
    num_task = [10, 33, 100, 333, 1000, 3333, 10000]
    log_num_task = [np.log10(i) for i in num_task]

    l_val = []
    l_train = []
    l_gap = []
    for i in num_task:
        model = _load_trlog(i, query, shot, run_time)
        best_val = min(model['val_loss'])
        best_train = min(model['train_loss'])
        l_val.append(best_val)
        l_train.append(best_train)
        # Gap between the smallest recorded test and training errors.
        l_gap.append(abs(best_val - best_train))

    _plot_with_markers(log_num_task, l_val, 14, linewidth=5,
                       label=r'$er(\mathbf{A}(\mathbf{S}),\tau)$')
    _plot_with_markers(log_num_task, l_train, 14, linewidth=5,
                       label=r'$\hat{er}(\mathbf{A}(\mathbf{S}),\mathbf{S})$')
    _plot_with_markers(log_num_task, l_gap, 16, linewidth=6, linestyle=':',
                       color='black', label=r'$|er-\hat{er}|$')
    plt.ylabel('Error', fontsize=24, labelpad=0)

    y_major_locator = MultipleLocator(0.25)
    ax = plt.gca()
    ax.yaxis.set_major_locator(y_major_locator)
    plt.ylim(-0.08, 1.80)
    plt.tick_params(axis='both', labelsize=12, pad=-4.5)

    # Raw string: '\l' and '\#' are invalid Python escapes (SyntaxWarning on
    # 3.12+); LaTeX needs the backslashes verbatim.
    plt.xlabel(r'$\log$\# of tasks', fontsize=24, labelpad=0)
    plt.grid(True)

    fontP = FontProperties()
    fontP.set_size('small')
    if shot == 1:
        # Only the 1-shot layout is tightened (before the legend is placed).
        plt.tight_layout()
    legend_kwargs = _LEGEND_KWARGS.get((shot, query))
    if legend_kwargs is not None:
        plt.legend(prop=fontP, **legend_kwargs)

    if save_pdf:
        # Save before show(): the figure may be released once the window closes.
        plt.savefig('./ours_{}_query{}_shot{}.pdf'.format(name, query, shot), dpi=200)
    plt.show()