import torch
import numpy as np

# Print summary statistics for the timing measurements stored in a
# torch-serialized file.
#
# NOTE(review): the divisors below presumably encode the number of training
# epochs (3) and batches per epoch (391, which matches ceil(50000 / 128), e.g.
# a 50k-sample dataset at batch size 128). Confirm against the training script
# that produced this file.
TIMING_FILE = 'output_time/interval_16.pth'
NUM_EPOCHS = 3
BATCHES_PER_EPOCH = 391

# NOTE: this module-level name shadows the stdlib ``time`` module; it is kept
# as-is in case other code refers to it.
# NOTE(review): the file is loaded with torch.load's default arguments; if it
# holds non-tensor objects (e.g. numpy arrays), newer PyTorch releases that
# default to weights_only=True may refuse to load it. Verify.
time = torch.load(TIMING_FILE)

# Total of the recorded 'train' times, divided by epochs and then by batches
# per epoch (same division order as before, so the float result is unchanged).
train_total = np.sum(time['train'])
avg_batch_time = train_total / NUM_EPOCHS / BATCHES_PER_EPOCH
print(avg_batch_time)

# Mean over every recorded 'submod_sim' measurement.
avg_sim_time = np.mean(time['submod_sim'])
print(avg_sim_time)

# Dump the raw 'submod_order' measurements, then their mean along axis 0.
order_times = time['submod_order']
print(order_times)
avg_order_time = np.mean(order_times, axis=0)
print(avg_order_time)