from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt

# Every run shares the same hyperparameter configuration; only the method
# prefix of the run directory differs, so the common parts are kept once.
RESULTS_DIR = Path("Bilevel_exp") / "news_exp"
RUN_SUFFIX = "bs_5657_vbs_5657_olrmu_100.0_0.0_ilrmu_95.0_0.0_eta_0.5_T_10_hessianq_10"


def load_results(method):
    """Load ``results.npy`` from the run directory of *method*.

    The path is assembled with ``pathlib`` so the script works on any OS;
    backslash-separated string literals only resolve on Windows.

    Args:
        method: Directory prefix of the run (e.g. ``"AID"``, ``"F2BA"``).

    Returns:
        Whatever ``np.load`` returns for that file. Presumably a 2-D array
        whose columns are indexed by the plots below; confirm against the
        script that wrote it.
    """
    return np.load(RESULTS_DIR / f"{method}_{RUN_SUFFIX}" / "results.npy")


data1 = load_results("AID")
data3 = load_results("BA-CG")
data4 = load_results("ITD")
data6 = load_results("RAHGD")
data18 = load_results("RAGD-GS")
data19 = load_results("F2BA")



# Panel 1: test accuracy against running time, saved to accu_time.eps.
# The figure size is set globally; the later panels inherit it via rcParams.
plt.rcParams['figure.figsize'] = (8.0, 6.0)
# Open a dedicated figure. Under a non-interactive backend (e.g. Agg),
# plt.show() does not close the current figure, so without this every
# later panel would be drawn on top of this one.
fig = plt.figure()
# Rows: (results array, number of leading rows to plot, format, legend label).
# Column 2 goes on the x-axis (running time) and column 1 on the y-axis
# (test accuracy), matching the axis labels below.
# NOTE(review): the AID-BiO / RAHGD cut-offs here (43 / 42) are swapped
# relative to the loss-vs-time panel (42 / 43); confirm this is intentional.
for arr, n_rows, fmt, label in (
    (data18, 26, 'k-.', 'RAGD-GS'),
    (data1, 43, 'g-x', 'AID-BiO'),
    (data3, 40, 'y-*', 'BA-CG'),
    (data4, 45, 'c-d', 'ITD-BiO'),
    (data6, 42, 'm--p', 'RAHGD'),
    (data19, 26, 'b-o', 'F2BA'),
):
    plt.plot(arr[:n_rows, 2], arr[:n_rows, 1], fmt, label=label)
plt.xlabel('running time  (s)', fontsize=20)
plt.ylabel('Test Accuracy', fontsize=20)
plt.xlim(0, 70)
plt.legend(fontsize=18)
plt.tick_params(labelsize=15)
plt.grid(True)
# Save before show(): once an interactive window is closed the figure is gone.
plt.savefig('accu_time.eps', format='eps')
plt.show(block=True)
# Release the figure so the next panel starts clean regardless of backend.
plt.close(fig)


# Panel 2: test loss against number of oracle calls, saved to loss_oracle.eps.
# Start on a fresh figure. With a non-interactive backend, plt.show() leaves
# the previous panel's figure current, and its curves would otherwise be
# mixed into this plot.
fig = plt.figure()
# Rows: (results array, number of leading rows to plot, format, legend label).
# Column 3 goes on the x-axis (oracle calls) and column 0 on the y-axis
# (test loss), matching the axis labels below.
for arr, n_rows, fmt, label in (
    (data18, 26, 'k-.', 'RAGD-GS'),
    (data1, 23, 'g-x', 'AID-BiO'),
    (data3, 32, 'y-*', 'BA-CG'),
    (data4, 41, 'c-d', 'ITD-BiO'),
    (data6, 42, 'm--p', 'RAHGD'),
    (data19, 26, 'b-o', 'F2BA'),
):
    plt.plot(arr[:n_rows, 3], arr[:n_rows, 0], fmt, label=label)

plt.xlabel('# oracle calls', fontsize=20)
plt.ylabel('Test Loss', fontsize=20)
plt.xlim(0, 2.5e6)
plt.legend(fontsize=18)
plt.tick_params(labelsize=15)
plt.grid(True)
# Save before show(): once an interactive window is closed the figure is gone.
plt.savefig('loss_oracle.eps', format='eps')
plt.show(block=True)
# Release the figure so the next panel starts clean regardless of backend.
plt.close(fig)



# Panel 3: test accuracy against number of oracle calls, saved to
# accu_oracle.eps.
# Start on a fresh figure so no curves from a previous panel can leak in when
# plt.show() is a no-op (non-interactive backend).
fig = plt.figure()
# Rows: (results array, number of leading rows to plot, format, legend label).
# Column 3 goes on the x-axis (oracle calls) and column 1 on the y-axis
# (test accuracy), matching the axis labels below.
for arr, n_rows, fmt, label in (
    (data18, 26, 'k-.', 'RAGD-GS'),
    (data1, 23, 'g-x', 'AID-BiO'),
    (data3, 32, 'y-*', 'BA-CG'),
    (data4, 41, 'c-d', 'ITD-BiO'),
    (data6, 42, 'm--p', 'RAHGD'),
    (data19, 26, 'b-o', 'F2BA'),
):
    plt.plot(arr[:n_rows, 3], arr[:n_rows, 1], fmt, label=label)

plt.xlabel('# oracle calls', fontsize=20)
plt.ylabel('Test Accuracy', fontsize=20)
plt.xlim(0, 2.5e6)
plt.legend(fontsize=18)
plt.tick_params(labelsize=15)
plt.grid(True)
# Save before show(): once an interactive window is closed the figure is gone.
plt.savefig('accu_oracle.eps', format='eps')
plt.show(block=True)
# Release the figure so the next panel starts clean regardless of backend.
plt.close(fig)


# Panel 4: test loss against running time, saved to loss_time.eps.
# Start on a fresh figure so no curves from a previous panel can leak in when
# plt.show() is a no-op (non-interactive backend).
fig = plt.figure()
# Rows: (results array, number of leading rows to plot, format, legend label).
# Column 2 goes on the x-axis (running time) and column 0 on the y-axis
# (test loss), matching the axis labels below.
# NOTE(review): the AID-BiO / RAHGD cut-offs here (42 / 43) are swapped
# relative to the accuracy-vs-time panel (43 / 42); confirm this is intentional.
for arr, n_rows, fmt, label in (
    (data18, 26, 'k-.', 'RAGD-GS'),
    (data1, 42, 'g-x', 'AID-BiO'),
    (data3, 40, 'y-*', 'BA-CG'),
    (data4, 45, 'c-d', 'ITD-BiO'),
    (data6, 43, 'm--p', 'RAHGD'),
    (data19, 26, 'b-o', 'F2BA'),
):
    plt.plot(arr[:n_rows, 2], arr[:n_rows, 0], fmt, label=label)

plt.xlabel('running time  (s)', fontsize=20)
plt.ylabel('Test Loss', fontsize=20)
plt.xlim(0, 70)
plt.legend(fontsize=18)
plt.tick_params(labelsize=15)
plt.grid(True)
# Save before show(): once an interactive window is closed the figure is gone.
plt.savefig('loss_time.eps', format='eps')
plt.show(block=True)
# Release the figure so nothing lingers after the script's last panel.
plt.close(fig)