import matplotlib
import matplotlib.pyplot as plt
from matplotlib.pyplot import MultipleLocator
from matplotlib.font_manager import FontProperties
import numpy as np
import numpy
import seaborn as sns

# Headless backend: the figure is only written to disk, never displayed.
plt.switch_backend('agg')

# seaborn's set_theme() rewrites rcParams (including font.family), so it must
# run BEFORE the custom font configuration below, otherwise that is discarded.
sns.set_theme(style='darkgrid')

# Typeset all text with LaTeX.
plt.rc('text', usetex=True)
# With usetex, only the generic families (serif, sans-serif, ...) are honoured;
# inside 'serif', 'Times' selects LaTeX's Times font, while 'Times New Roman'
# is picked up if usetex is ever switched off.
plt.rc('font', family='serif', serif=['Times New Roman', 'Times'])

# Task-averaged regret with respect to the sample size per task over the
# Omniglot dataset, one curve per number of tasks T (T = 1, 5, 10).

font = {'size': 14}
matplotlib.rc('font', **font)

num_task = 10
# x positions 1..num_task; position i is labelled with sample size m = 5 * i.
tasks = np.arange(1, num_task + 1)
SAMPLES_PER_TICK = 5

T_1_regret = np.array([0.02668, 0.02002, 0.0163, 0.01499, 0.01318, 0.01215, 0.01196, 0.01126, 0.00982, 0.00936])  # T = 1
T_2_regret = np.array([0.02535, 0.01989, 0.0158, 0.01461, 0.01303, 0.01191, 0.01114, 0.01029, 0.00968, 0.00912])  # T = 5
T_3_regret = np.array([0.0244, 0.01895, 0.01561, 0.01411, 0.0127, 0.01109, 0.01083, 0.00992, 0.00927, 0.00899])  # T = 10

# (regret values, colour, legend label) for each curve.
series = [
    (T_1_regret, 'red', r"$T=1$"),
    (T_2_regret, 'green', r"$T=5$"),
    (T_3_regret, 'blue', r"$T=10$"),
]

handles = []
for values, color, label in series:
    # One Line2D draws both the line and the circle markers, so it can be
    # used directly as the legend handle.
    line, = plt.plot(tasks, values, '-o', color=color, linewidth=2, label=label)
    handles.append(line)

plt.ylabel('Task-Averaged Regret ' r'$\frac{\bar{R}_{T,m}}{m}$', fontsize=14, labelpad=2)
plt.xlabel('Number of Samples per Task ' r'$m$', fontsize=14, labelpad=2)

plt.grid(True)
plt.xlim(0.5, num_task + 0.5)
plt.xticks(tasks, [int(i) * SAMPLES_PER_TICK for i in tasks])

plt.ylim(0.008, 0.028)
ax = plt.gca()
ax.yaxis.set_major_locator(MultipleLocator(0.002))

# Pass only `fontsize`: when `prop` is also given, matplotlib ignores
# `fontsize`, which made the original call ambiguous.
plt.legend(handles=handles, loc='upper right', bbox_to_anchor=(0.95, 1.0), fontsize=14)

# Compute the layout only once every artist (ticks, labels, legend) is final.
plt.tight_layout()
plt.savefig('./omniglot_regret.pdf', dpi=200)
# The Agg backend cannot display figures (plt.show() would only warn);
# release the figure instead.
plt.close()