import matplotlib.pyplot as plt
import torch
import numpy as np
import copy

def _load_on_cpu(path):
    """Deserialize one saved result file, mapping its storages onto the CPU.

    NOTE: torch.load unpickles the file, so only point this at result files
    you produced yourself or otherwise trust.
    """
    return torch.load(path, map_location=torch.device('cpu'))


# Five result files per experiment, numbered 1-5 in the file name. Each loaded
# object is later indexed by the keys 'timea' and 'testaccuracy'.
datasets_esj = [_load_on_cpu(f'ESJ_shots5_miniimagenet_T30{i+1}.pt') for i in range(5)]
datasets_foa = [_load_on_cpu(f'FOAminiimagenet{i+1}.pt') for i in range(5)]

def calc_mean_std(datasets, key, max_len=85):
    """Return the element-wise mean and standard deviation across runs.

    For every entry of ``datasets`` the sequence stored under ``key`` is cut
    to its first ``max_len`` items; the resulting rows are stacked and
    reduced along the run axis (axis 0).

    Args:
        datasets: Iterable of mappings, one per run, each indexable by ``key``.
        key: Name of the recorded series to aggregate.
        max_len: Maximum number of leading entries kept per run. ``None``
            keeps everything. Defaults to 85, the previously hard-coded value.

    Returns:
        Tuple ``(mean, std)`` of NumPy results computed with ``np.mean`` and
        ``np.std`` (population standard deviation, ``ddof=0``) over axis 0.
    """
    series = [ds[key][:max_len] for ds in datasets]

    # Runs that recorded fewer than ``max_len`` entries would make the stack
    # ragged, which NumPy cannot reduce along axis 0. Trim every run to the
    # shortest one so the statistics are taken over aligned positions only.
    common_len = min((len(row) for row in series), default=0)
    data = np.array([row[:common_len] for row in series])

    mean = np.mean(data, axis=0)
    std = np.std(data, axis=0)
    return mean, std

def calc_mean_std1(datasets, key, max_len=120):
    """Return the element-wise mean and standard deviation across runs.

    For every entry of ``datasets`` the sequence stored under ``key`` is cut
    to its first ``max_len`` items; the resulting rows are stacked and
    reduced along the run axis (axis 0).

    Args:
        datasets: Iterable of mappings, one per run, each indexable by ``key``.
        key: Name of the recorded series to aggregate.
        max_len: Maximum number of leading entries kept per run. ``None``
            keeps everything. Defaults to 120, the previously hard-coded value.

    Returns:
        Tuple ``(mean, std)`` of NumPy results computed with ``np.mean`` and
        ``np.std`` (population standard deviation, ``ddof=0``) over axis 0.
    """
    series = [ds[key][:max_len] for ds in datasets]

    # Runs that recorded fewer than ``max_len`` entries would make the stack
    # ragged, which NumPy cannot reduce along axis 0. Trim every run to the
    # shortest one so the statistics are taken over aligned positions only.
    common_len = min((len(row) for row in series), default=0)
    data = np.array([row[:common_len] for row in series])

    mean = np.mean(data, axis=0)
    std = np.std(data, axis=0)
    return mean, std

# Across-run mean/std of the recorded time and test-accuracy series.
# The ESJ runs go through calc_mean_std and the FOA runs through
# calc_mean_std1, which keeps a longer prefix of each run.
(mean_time_esj, std_time_esj), (mean_accuracy_esj, std_accuracy_esj) = [
    calc_mean_std(datasets_esj, series_key)
    for series_key in ('timea', 'testaccuracy')
]

(mean_time_foa, std_time_foa), (mean_accuracy_foa, std_accuracy_foa) = [
    calc_mean_std1(datasets_foa, series_key)
    for series_key in ('timea', 'testaccuracy')
]

# Font sizes: one for the axis labels and legend, one for the tick labels.
_LABEL_FONTSIZE = 30
_TICK_FONTSIZE = 20


def _plot_mean_with_band(times, mean_acc, std_acc, label, color):
    """Draw a mean curve and a translucent mean +/- std band on the current axes."""
    lower = mean_acc - std_acc
    upper = mean_acc + std_acc
    plt.plot(times, mean_acc, label=label, color=color, linestyle='-', linewidth=2)
    plt.fill_between(times, lower, upper, color=color, alpha=0.2)


plt.figure(figsize=(10, 9))

# The FOA results are drawn first (red), then the ESJ results (blue); both use
# their mean running time as the x coordinate.
_plot_mean_with_band(mean_time_foa, mean_accuracy_foa, std_accuracy_foa,
                     label='qNBO(BFGS)', color='red')
_plot_mean_with_band(mean_time_esj, mean_accuracy_esj, std_accuracy_esj,
                     label='PZOBO', color='blue')

plt.xlabel('Running time (s)', fontsize=_LABEL_FONTSIZE)
plt.ylabel('Test accuracy (%)', fontsize=_LABEL_FONTSIZE)

plt.xticks(fontsize=_TICK_FONTSIZE)
plt.yticks(fontsize=_TICK_FONTSIZE)

# Light dashed grid on both major and minor ticks.
plt.grid(visible=True, which='both', linestyle='--', alpha=0.5)

plt.legend(fontsize=_LABEL_FONTSIZE, loc='lower right')

# Write the figure to disk, cropping surrounding whitespace.
plt.savefig('mini.pdf', dpi=300, bbox_inches='tight')
