"""
get image times and accuracies
"""


import numpy as np
import matplotlib.pyplot as plt

import os
import os.path as osp


# Experiment (dataset) name and size that select which results sub-directory
# is summarized by the calls at the bottom of this script.
experiment, size = 'cifar10', 500



# now get the ablation values
def get_values(dataset, size, method, num_runs=3, root='results'):
    """Print (and return) the total time and final validation metric of a method.

    Run ``ex`` (for ``ex`` in ``1..num_runs``) is read from
    ``<root>/<dataset>/16/<size>/<method>/<ex>/``. That directory must contain
    ``epoch_times.npy`` (per-epoch times) and ``epoch_val_metric_history.npy``
    (per-epoch validation metric, reported here as accuracy).

    Args:
        dataset: experiment / dataset name, e.g. ``'cifar10'``.
        size: value used as the ``<size>`` path component.
        method: method sub-directory name, e.g. ``'adjoint_ode'``.
        num_runs: number of run sub-directories, numbered from 1, to read.
        root: top-level results directory.

    Returns:
        Tuple ``(mean_time, std_time, mean_acc, std_acc)`` across runs.
        Standard deviations use numpy's default (population, ``ddof=0``).

    Raises:
        ValueError: if ``num_runs`` is less than 1.
        Errors from ``np.load`` (e.g. a missing file) propagate unchanged.
    """
    if num_runs < 1:
        raise ValueError('num_runs must be >= 1, got {}'.format(num_runs))

    folder = osp.join(root, dataset, '16', str(size), method)
    run_times = []
    run_accs = []
    for ex in range(1, num_runs + 1):
        run_dir = osp.join(folder, str(ex))
        epoch_times = np.load(osp.join(run_dir, 'epoch_times.npy'))
        val_metric = np.load(osp.join(run_dir, 'epoch_val_metric_history.npy'))
        # A run's total time is the sum over all of its epochs.
        run_times.append(np.sum(epoch_times))
        # Only the metric after the last epoch is reported.
        run_accs.append(val_metric[-1])

    run_times = np.array(run_times)
    print(run_times)
    run_accs = np.array(run_accs)
    mean_time, std_time = np.mean(run_times), np.std(run_times)
    mean_acc, std_acc = np.mean(run_accs), np.std(run_accs)

    print(method)
    print('Time: {} +- {}'.format(mean_time, std_time))
    print('Accuracy: {} +- {}'.format(mean_acc, std_acc))
    print('\n')
    return mean_time, std_time, mean_acc, std_acc




if __name__ == '__main__':
    # Summarize every method for the configured experiment. Guarded so that
    # importing this module does not trigger file loading and printing.
    for method in ('adjoint_gq', 'adjoint_ode', 'adjoint_seminorm'):
        get_values(experiment, size, method)