"""
Summarise total training time and final validation MSE for the
nested-spheres gtol ablation (adjoint_gq runs).
"""


import numpy as np
import matplotlib.pyplot as plt

import os
import os.path as osp




# Per-gtol summary of the ablation runs.
def get_values(i, root='results/nested_spheres', n_runs=5):
    """Summarise the repeated runs of one gtol setting of the ablation.

    For each run ``1..n_runs`` under
    ``<root>/gtol<i>/1500/adjoint_gq/<run>/`` this loads
    ``epoch_times.npy`` and ``epoch_val_metric_history.npy``. It takes the
    sum of the epoch times and the last entry of the validation-metric
    history, then prints the mean and population standard deviation
    (``ddof=0``) of both across runs. The metric is reported as "MSE".

    Args:
        i: gtol value. It is inserted via ``str(i)`` into the folder name, so
            pass it spelled exactly as on disk (e.g. ``'1e-06'``).
        root: Directory that contains the ``gtol<i>`` folders.
        n_runs: Number of run sub-folders, numbered from 1.

    Returns:
        Tuple ``(mean_time, std_time, mean_mse, std_mse)``.

    Raises:
        ValueError: If ``n_runs`` is smaller than 1.
        FileNotFoundError: If any expected ``.npy`` file is missing.
    """
    if n_runs < 1:
        # np.mean/np.std of an empty array would silently give NaN.
        raise ValueError('n_runs must be at least 1, got {}'.format(n_runs))

    folder = osp.join(root, 'gtol' + str(i), '1500', 'adjoint_gq')
    total_times = []
    final_metrics = []
    for ex in range(1, n_runs + 1):
        run_dir = osp.join(folder, str(ex))
        times = np.load(osp.join(run_dir, 'epoch_times.npy'))
        val_metric = np.load(osp.join(run_dir, 'epoch_val_metric_history.npy'))
        # Total time of the run is the sum of the stored epoch times.
        total_times.append(np.sum(times))
        # Only the metric after the last epoch is summarised.
        final_metrics.append(val_metric[-1])

    total_times = np.asarray(total_times)
    final_metrics = np.asarray(final_metrics)
    mean_time, std_time = np.mean(total_times), np.std(total_times)
    mean_mse, std_mse = np.mean(final_metrics), np.std(final_metrics)

    print(i)
    print('Time: {} +- {}'.format(mean_time, std_time))
    print('MSE: {} +- {}'.format(mean_mse, std_mse))
    print('\n')
    return mean_time, std_time, mean_mse, std_mse
    
# Tolerance values of the ablation, written exactly as they appear in the
# result folder names (get_values builds 'gtol' + str(value)).
gtols = ['0.1', '0.003', '1e-06', '1e-09']
# Iterate the list itself so adding or removing a tolerance needs no other edit.
for gtol in gtols:
    get_values(gtol)