import os
import matplotlib.pyplot as plt
import torch
import numpy as np

def get_data(names, *, folder_path='/path/to/classification_cifar10/curve'):
    """Load saved training-curve files by name.

    Each entry of ``names`` is joined onto ``folder_path`` and read with
    ``torch.load``. ``names`` is iterated exactly once, so generators and
    other one-shot iterables work as well as lists.

    Args:
        names: Iterable of file names located inside ``folder_path``.
        folder_path: Directory holding the curve files. Defaults to the
            original hard-coded location so existing callers are unaffected.

    Returns:
        dict mapping each name (in iteration order) to the object that
        ``torch.load`` returned for that file.
    """
    # NOTE(review): torch.load unpickles file contents; only point this at
    # trusted files.
    return {name: torch.load(os.path.join(folder_path, name)) for name in names}


# def plot(names, curve_type='train', labels=None, ylim=(80, 101), loc='upper left'):
#     plt.figure()
#     plt.ylim(ylim)  # if curve_type == 'train' else 96)
#     curve_data = get_data(names)
#     for i, label in zip(curve_data.keys(), labels):
#         acc = np.array(curve_data[i]['{}_acc'.format(curve_type.lower())])
#         print(acc.shape)
#         if label == 'AdaBelief':
#             plt.plot(acc, '-', label=label)
#         else:
#             plt.plot(acc, '--', label=label)
#
#     plt.grid()
#     plt.legend(fontsize=14, loc=loc)
#     plt.title('{} accuracy ~ Training epoch'.format(curve_type))
#     plt.xlabel('Training Epoch')
#     plt.ylabel('Accuracy')
#     plt.show()


# Summarize every saved curve file: print its name, the shapes of its train /
# test accuracy arrays, and the final test accuracy.

# Directory entries that are not curve files and must never reach torch.load.
# 'extra_curves' was always excluded here (presumably a sub-directory of extra
# curves); '.DS_Store' is macOS Finder metadata that was previously removed by
# hand. Filtering (instead of list.remove) avoids a ValueError when an entry
# is absent.
_SKIP_NAMES = {'extra_curves', '.DS_Store'}

# Sorted because os.listdir order is arbitrary; this keeps the printed summary
# stable across runs and filesystems.
paths = sorted(
    name
    for name in os.listdir('/path/to/classification_cifar10/curve')
    if name not in _SKIP_NAMES
)
print(len(paths))
curve_data = get_data(paths)
for name, record in curve_data.items():
    print(name)
    acc_train = np.array(record['train_acc'])

    acc_test = np.array(record['test_acc'])

    # An empty test curve would make acc_test[-1] raise IndexError; report
    # None instead so one bad file does not abort the whole summary.
    last = acc_test[-1] if acc_test.size else None
    print('Train shape: ', acc_train.shape, ' Test shape: ', acc_test.shape, ' last: ', last)