from __future__ import division, print_function
import os, sys, time, logging, json
import argparse
from argparse import Namespace
import torch, numpy as np
from torchvision import transforms, datasets
import math
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset
from shutil import copyfile

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold
from models import *

from functions import *
from my_datasets import TripletDataset

def bool_func(x):
    """Parse a config/CLI string as a boolean: only 'True' and 'true' are truthy."""
    return x in ('True', 'true')

# Outer CLI: the only command-line option is the path to the config file that
# holds all training parameters (parsed further below by ``parser_params``).
parser = argparse.ArgumentParser(description='Train Bayesian Neural Net on MNIST with Variational Inference')
# Required: without it the script would otherwise crash later on open(None).
parser.add_argument("--config_file", type=str, required=True,
                    help='Path to the key=value config file with the training parameters.')
app_args = parser.parse_args()

# Parser for the parameters read from the config file. Help strings use
# argparse's %(default)s placeholder so documented defaults match the real ones.
parser_params = argparse.ArgumentParser()
# MODEL
parser_params.add_argument('--model', type=str, nargs='?', action='store', default='Gaussian_prior', help="Model to run. Options are 'Gaussian_prior', 'Laplace_prior', 'GMM_prior'. Default: %(default)s.")
parser_params.add_argument('--prior_sig', type=float, nargs='?', action='store', default=0.5, help='Standard deviation of prior. Default: %(default)s.')
parser_params.add_argument('--n_samples', type=float, nargs='?', action='store', default=3, help='How many MC samples to take when approximating the ELBO. Default: %(default)s.')
parser_params.add_argument('--embed_size', type=int, default=100, help='Embedding dimension. Default: %(default)s.')
parser_params.add_argument('--hidden_sizes', type=str, default="[512,512]", help='Hidden layer sizes of the network as a JSON list. Default: %(default)s.')
parser_params.add_argument('--reweight_constant', type=float, default=0.1)
parser_params.add_argument('--use_pred_loss', type=bool_func, default=False)
parser_params.add_argument('--KNN_test', type=int, default=3, help='Number of nearest neighbours used to determine the label of a datapoint. Default: %(default)s.')
parser_params.add_argument('--loss_func', type=str, nargs='?', action='store', default='LMNN', help="Choose likelihood function p(D|w). Options are 'Gaussian_approx', 'LMNN'. Default: %(default)s.")
parser_params.add_argument('--ensemble', type=int, default=0, help='Use ensemble at test phase or not. Default: %(default)s.')
parser_params.add_argument('--margin', type=float, default=1.0, help='Margin for LMNN likelihood. Default: %(default)s.')

# DATA
parser_params.add_argument('--datafile', type=str, default='cifar10_featues_SimCLR-0.81.dat')
parser_params.add_argument('--samples_per_class', type=int, default=500, help='Training samples drawn per class. Default: %(default)s.')
parser_params.add_argument('--noise_rates', type=str, default='[0.3]', help='JSON list of label-noise rates: one value gives symmetric noise, several give per-class noise. Default: %(default)s.')
parser_params.add_argument('--noise_type', type=str, default='symmetric', help="'asymmetric' replaces --noise_rates with random per-class rates. Default: %(default)s.")
parser_params.add_argument('--use_LMNN_triplets', type=bool_func, default=False, help='Default: %(default)s.')
parser_params.add_argument('--use_PCA', type=bool_func, default=True, help='Use PCA to reduce the feature size. Default: %(default)s.')
parser_params.add_argument('--PCA_dim', type=int, default=100, help='Target dimension for PCA. Default: %(default)s.')
parser_params.add_argument('--KNN_triplet', type=int, default=11, help='Number of nearest neighbours used to form triplets, bigger helps. Default: %(default)s.')
parser_params.add_argument('--balance_noise', type=bool_func, default=True)
parser_params.add_argument('--standardize', type=bool_func, default=True)
parser_params.add_argument('--preprocessed_data', type=str, default=None)
# TRAINING
parser_params.add_argument('--randomize', type=bool_func, default=False, help='Do not fix the random seeds. Default: %(default)s.')
parser_params.add_argument('--lr_milestones', type=str, default='[10, 20, 40, 50, 80]', help='JSON list of epochs at which the learning rate is multiplied by --lr_gamma. Default: %(default)s.')
parser_params.add_argument('--lr_gamma', type=float, default=0.1)
parser_params.add_argument('--log_lag', type=int, default=5)
parser_params.add_argument('--use_pretrain', type=str, default=None, help='Path to pretrained model. Default: %(default)s.')
parser_params.add_argument('--batch_size', type=int, default=128, help='Default: %(default)s.')
parser_params.add_argument('--epochs', type=int, nargs='?', action='store', default=50, help='How many epochs to train. Default: %(default)s.')
parser_params.add_argument('--lr', type=float, nargs='?', action='store', default=1e-4, help='Learning rate. Default: %(default)s.')
# NOTE(review): this flag is not read in this file; the driver always runs 5-fold CV.
parser_params.add_argument('--cross_validation', type=bool_func, nargs='?', default=True, help='Use cross validation. Default: %(default)s.')

# LOG
parser_params.add_argument('--log_name', type=str, default='BBP_LMNN', help='Default: %(default)s.')
parser_params.add_argument('--runtime', type=str, default='0', help='Default: %(default)s.')
parser_params.add_argument('--verbose', type=bool_func, default=False, help='Print values each minibatch. Default: %(default)s.')
parser_params.add_argument('--results_dir', type=str, nargs='?', action='store', default='BBP_results', help='Where to save learnt training plots. Default: %(default)s.')
parser_params.add_argument('--plot_embed', type=bool_func, default=False, help='Default: %(default)s.')


### Parse argument
# The config file is a flat list of "key=value" lines (INI-style). Section
# headers "[...]", lines of at most one character and '#'/';' comment lines are
# skipped. Every other line becomes the CLI tokens "--key value" (or just
# "--key" when the line has no '='), which parser_params then parses.
with open(app_args.config_file, 'r') as f:
    param_tokens = []
    for line in f:
        line = line.strip()
        if len(line) <= 1:
            continue
        if line[0] == '[' and line[-1] == ']':
            continue
        if line[0] in '#;':
            continue
        # Split on the FIRST '=' only so values may themselves contain '=';
        # strip around it so "key = value" works as well as "key=value".
        key, sep, value = line.partition('=')
        param_tokens.append('--' + key.strip())
        if sep:
            param_tokens.append(value.strip())
    # Passing the token list directly also handles an empty config correctly.
    args = parser_params.parse_args(param_tokens)

# ---- Output locations ----
# Results go to <results_dir>/<log_name>/, images of this run to
# <results_dir>/<log_name>/<runtime>_images/.
results_dir = args.results_dir  # Where to save plots and error, accuracy vectors
nb_epochs = args.epochs
savefolder = os.path.join(results_dir, args.log_name)
imgs_save_folder = os.path.join(savefolder, "{}_images".format(args.runtime))
mkdir(results_dir)
mkdir(str(savefolder))
mkdir(imgs_save_folder)

# Keep a copy of the config next to the results; an existing copy is not overwritten.
config_file = os.path.join(savefolder, args.runtime + '.ini')
if not os.path.exists(config_file):
    copyfile(app_args.config_file, config_file)

use_cuda = torch.cuda.is_available()

# ---- Reproducibility: fixed seed 11 unless randomize is requested ----
if not args.randomize:
    manualSeed = 11
    np.random.seed(manualSeed)
    torch.manual_seed(manualSeed)
    if use_cuda:
        # Seed the GPU generators and switch cuDNN off / to deterministic mode.
        torch.cuda.manual_seed(manualSeed)
        torch.cuda.manual_seed_all(manualSeed)
        torch.backends.cudnn.enabled = False
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

# ---- Logging: record the full parameter set at the top of the log ----
flogger = get_logger(args.runtime, savefolder)
flogger.info("Param: {}".format(json.dumps(vars(args), indent=2, sort_keys=True)))

if use_cuda:
    cprint('g', '\nUse CUDA')

###===============###
### Configuration ###
###===============###

# Hyper-parameters copied out of ``args`` into module-level names.
batch_size, embed_size = args.batch_size, args.embed_size
noiserates = json.loads(args.noise_rates)      # JSON list of noise rates
hidden_sizes = json.loads(args.hidden_sizes)   # JSON list of hidden-layer widths
KNN_triplet = args.KNN_triplet  # number of nearest neighbours used to form triplets (bigger helps)
dim, runtime = args.PCA_dim, args.runtime
nsamples = int(args.n_samples)
preprocessed_data = args.preprocessed_data

# Likelihood used for the triplet objective, parameterised by the margin.
if args.loss_func not in ('LMNN', 'Gaussian_approx'):
    print('Invalid loss function')
    exit(1)
loss_func = (log_large_margin_loss if args.loss_func == 'LMNN'
             else log_gaussian_approx_large_margin_loss)(args.margin)

###==================================================###
### Stratified sampling and Injecting Noise to label ###
###==================================================###

def prepare_data(args):
    """Load the feature file and draw a class-balanced training subset.

    ``args.datafile`` is read with ``torch.load`` and must provide the keys
    'trX', 'trY', 'teX', 'teY' and optionally 'trY_gt' (clean training labels;
    a copy of 'trY' is used when absent). Then ``args.samples_per_class``
    training samples per class are selected with ``randomStratifiedSampleData``.

    Args:
        args: parsed hyper-parameters (``datafile``, ``samples_per_class``).

    Returns:
        ((trX, trY, trY_gt), (teX, teY), num_classes) where ``num_classes`` is
        the number of distinct labels in the full training set.
    """
    import inspect

    # SECURITY NOTE: torch.load unpickles the file and can run arbitrary code;
    # only load trusted data files.
    load_kwargs = {}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        # PyTorch >= 2.6 defaults to weights_only=True, whose restricted
        # unpickler rejects non-tensor objects such as the (presumably numpy,
        # given the .copy() calls below) arrays stored in these files.
        load_kwargs['weights_only'] = False
    data = torch.load(args.datafile, **load_kwargs)
    trX, trY, teX, teY = data['trX'], data['trY'], data['teX'], data['teY']
    if 'trY_gt' in data:
        trY_gt = data['trY_gt']
    else:
        trY_gt = trY.copy()
    del data

    num_classes = len(set(trY))
    #### Random sample small subset of data for training
    perN = np.ones((num_classes), dtype='int') * args.samples_per_class  # #per class
    selected_indices = randomStratifiedSampleData(trY, perN)
    trX = trX[selected_indices, :].copy()
    trY = trY[selected_indices].copy()
    trY_gt = trY_gt[selected_indices].copy()

    return (trX, trY, trY_gt), (teX, teY), num_classes

def split_train_val(trX, trY, val_ratio=0.1):
    """Hold out a stratified validation subset of (trX, trY).

    For every class, ``round(count * val_ratio)`` samples chosen with the
    global ``np.random`` state go to validation; the rest stay in training.
    Both parts keep the original sample order.

    Args:
        trX: feature array indexed by sample along axis 0 (indexed as ``[idx, :]``).
        trY: label array aligned with ``trX``.
        val_ratio: fraction of each class moved to the validation part.

    Returns:
        ((trX, trY, trY_gt), (valX, valY)) where ``trY_gt`` is a copy of the
        training labels.
    """
    labels = np.asarray(trY)
    val_parts = []
    for label in np.unique(labels):
        class_idx = np.flatnonzero(labels == label)
        n_val = int(round(len(class_idx) * val_ratio))
        if n_val > 0:
            val_parts.append(np.random.permutation(class_idx)[:n_val])
    if val_parts:
        val_indices = np.sort(np.concatenate(val_parts))
    else:
        val_indices = np.array([], dtype=int)

    # Everything not held out stays in training. A boolean mask yields an
    # integer index array (numpy cannot index with a Python set).
    is_train = np.ones(len(trX), dtype=bool)
    is_train[val_indices] = False
    train_indices = np.flatnonzero(is_train)

    valX = trX[val_indices, :].copy()
    valY = trY[val_indices].copy()
    trX = trX[train_indices, :].copy()
    trY = trY[train_indices].copy()
    trY_gt = trY.copy()
    return (trX, trY, trY_gt), (valX, valY)
    
#### Generate random noise
def inject_noise(trY_gt, noiserates):
    """Return a noisy copy of the clean labels ``trY_gt``.

    * no rates      -> an unmodified copy of ``trY_gt``;
    * one rate      -> ``generateRandomLabelNoise2`` with that rate (symmetric);
    * several rates -> ``generateImBalancedRandomLabelNoise2`` with all of them.
    The chosen mode is printed and logged.
    """
    rate_count = len(noiserates)
    if rate_count == 0:
        return trY_gt.copy()

    if rate_count == 1:
        note = "injecting symmetric noise ..."
        print(note)
        flogger.info(note)
        return generateRandomLabelNoise2(trY_gt, noiserates[0])

    note = "Asymmetric noises: {}".format(noiserates)
    print(note)
    flogger.info(note)
    return generateImBalancedRandomLabelNoise2(trY_gt, noiserates)

def feature_engineering(trX, trY, valX, valY, teX, teY, args):
    """Build the model-input features, reporting KNN test accuracy along the way.

    Steps (each accuracy is ``KNNtest(args.KNN_test, ...)`` on the test split):
      1. baseline accuracy on the raw features;
      2. if the features are wider than ``args.PCA_dim``, reduce train/val/test
         with ``Wrapper_PCA`` and report the PCA accuracy;
      3. use the PCA features only if ``args.use_PCA`` is set AND step 2 ran;
         otherwise the raw features are kept;
      4. if ``args.standardize``, z-score every split with the training mean/std
         (zero-variance columns divided by 1) and report the accuracy.

    Args:
        trX, trY: training features and (possibly noisy) labels.
        valX: validation features; may be None (left as None when standardizing).
        valY: unused; kept for interface compatibility.
        teX, teY: test features/labels, used only for the reported accuracies.
        args: parsed hyper-parameters.

    Returns:
        (s0trX, s0valX, s0teX, (acc0, pca_acc0, acc0_rescale)); an accuracy is
        0. when its step did not run.
    """
    dim = args.PCA_dim
    #### Baseline KNN accuracy on the raw features
    [_, acc0, _, _] = KNNtest(args.KNN_test, trX, trY, teX, teY)
    mess = 'Initial Test performance (acc%%): %.2f\n' % (acc0 * 100)
    print(mess), flogger.info(mess)

    pca_acc0 = 0.
    acc0_rescale = 0.
    pca_features = None
    #### Dimensionality reduction with PCA
    if trX.shape[1] > dim:
        print('Dimensionality reduction with PCA...\n')
        trX_pca, valX_pca, teX_pca = Wrapper_PCA(trX, valX, teX, dim)
        pca_features = (trX_pca, valX_pca, teX_pca)
        [_, pca_acc0, _, _] = KNNtest(args.KNN_test, trX_pca, trY, teX_pca, teY)
        mess = 'PCA Test performance (acc%%): %.2f\n' % (pca_acc0 * 100)
        print(mess), flogger.info(mess)

    # Features that are already at most PCA_dim wide have no PCA output; fall
    # back to the raw features instead of referencing undefined names.
    if args.use_PCA and pca_features is not None:
        s0trX, s0valX, s0teX = pca_features
    else:
        s0trX, s0valX, s0teX = trX, valX, teX

    if args.standardize:
        print("rescale data ...")
        # Standardize input values to N(0,1) using training statistics only
        mean = s0trX.mean(axis=0, keepdims=True)
        std = s0trX.std(axis=0, keepdims=True)
        std[std == 0] = 1  # avoid dividing constant columns by zero

        s0trX = (s0trX - mean)/std
        s0teX = (s0teX - mean)/std
        if s0valX is not None:
            s0valX = (s0valX - mean)/std

        [_, acc0_rescale, _, _] = KNNtest(args.KNN_test, s0trX, trY, s0teX, teY)
        mess = 'Initial Test performance rescale (acc%%): %.2f\n' % (acc0_rescale * 100)
        print(mess), flogger.info(mess)

    return s0trX, s0valX, s0teX, (acc0, pca_acc0, acc0_rescale)

###====================###
### Train Model        ###
###====================###
class EearlyStopping:
    """Early stopping on a per-epoch metric, optionally checkpointing the best model.

    Call ``step(metric)`` once per epoch; it returns True once the metric has
    not improved for ``patience`` consecutive epochs.

    Attributes:
        best_perf: best metric seen so far (starts at -inf for 'max', +inf for 'min').
        best_epoch: 1-based index (count of ``step`` calls) of ``best_perf``;
            None until the first improvement.
        num_epoch: number of ``step`` calls so far.
        cnt_patience: consecutive non-improving epochs since the last improvement.
    """
    # NOTE: the misspelled class name is kept because callers reference it.

    def __init__(self, net, patience=10, criterion='max', path=None):
        """
        Args:
            net: object saved via ``net.save(path)`` on every improvement.
            patience: consecutive non-improving epochs that trigger a stop.
            criterion: 'max' if larger metric values are better, 'min' otherwise.
            path: checkpoint path; None disables checkpointing.

        Raises:
            ValueError: if ``criterion`` is neither 'max' nor 'min'.
        """
        if criterion not in ('max', 'min'):
            raise ValueError("criterion must be 'max' or 'min', got {!r}".format(criterion))
        self.patience = patience
        self.cnt_patience = 0
        self.criterion = criterion
        self.best_perf = -np.inf if criterion == 'max' else np.inf
        self.best_epoch = None
        self.num_epoch = 0
        self.path = path
        self.net = net

    def _is_improvement(self, curr_perf):
        """Return True if ``curr_perf`` strictly beats ``best_perf``."""
        if self.criterion == 'max':
            return curr_perf > self.best_perf
        return curr_perf < self.best_perf

    def step(self, curr_perf):
        """Record one epoch's metric; return True if training should stop."""
        self.num_epoch += 1
        if self._is_improvement(curr_perf):
            self.cnt_patience = 0
            self.best_epoch = self.num_epoch
            self.best_perf = curr_perf
            if self.path is not None:
                self.net.save(self.path)
        else:
            self.cnt_patience += 1
        # '>=' keeps signalling a stop if the caller keeps stepping afterwards.
        return self.cnt_patience >= self.patience

def train(net, s0trX=None, s0valX=None, s0teX=None, 
        triplet_train_loader=None, triplet_val_loader=None, triplet_test_loader=None,
        train_loader=None, val_loader=None, test_loader=None, checkpointer=None, args=None):
    """Train ``net`` on triplets, evaluating KNN accuracy after every epoch.

    Each epoch performs the following steps:
      * one pass over ``triplet_train_loader`` via ``net.train_batch``;
      * if both validation loaders are given, a validation pass. This uses
        ``net.eval_pred`` on ``val_loader`` when ``args.use_pred_loss``, and
        ``net.eval_dist`` on ``triplet_val_loader`` otherwise;
      * a ``net.scheduler`` step;
      * ``args.KNN_test``-NN accuracy of the embedded train and validation
        features, with the embedded train set as reference.
    The validation KNN accuracy is passed to ``checkpointer.step``, which may
    stop training early.

    Args:
        net: network wrapper (``train_batch``, ``eval_pred``, ``eval_dist``,
            ``set_mode_train``, ``epoch``, ``optimizer``, ``scheduler``).
        s0trX, s0valX: preprocessed train / validation features to embed.
        s0teX, triplet_test_loader, train_loader, test_loader: accepted for
            interface compatibility; not used here.
        triplet_train_loader, triplet_val_loader, val_loader: data loaders.
        checkpointer: an ``EearlyStopping`` instance.
        args: parsed hyper-parameters.

    Returns:
        (train_accs, eval_accs, imgs_train, imgs_eval). The accuracy lists have
        one entry per epoch run. The image lists are currently always empty,
        because embedding plots are only written to disk.
    """
    # NOTE(review): besides its arguments this reads the module-level globals
    # loss_func, nsamples, trY (train labels), valY (validation labels),
    # imgs_save_folder and flogger.
    epoch = (net.epoch + 1) if net.epoch > 0 else 0
    cprint('c', '\nTrain:')
    flogger.info("Train: ")
    print('  init cost variables:')
    nb_epochs, KNN_te = args.epochs, args.KNN_test
    nb_batches = len(triplet_train_loader)
    # Split label printed in the KNN accuracy message (was an undefined name).
    mode = 'Val'

    kl_train = np.zeros(nb_epochs)
    nll_triplet_train = np.zeros(nb_epochs)
    err_val = np.zeros(nb_epochs)

    imgs_train, imgs_eval = [], []
    train_accs, eval_accs = [], []

    Ntriplets = len(triplet_train_loader.dataset)
    for i in range(nb_epochs):
        # We draw more samples on the very first epoch in order to ensure convergence
        ELBO_samples = 10 if i + epoch == 0 else nsamples

        net.set_mode_train(True)
        tic = time.time()
        nb_samples = 0
        bad_triplets = 0

        for cnt_batch, (xijl_batch, targets) in enumerate(triplet_train_loader):
            info = {"n_samples": ELBO_samples, "cnt_batch": cnt_batch, "epoch": i, 'dataset_size': Ntriplets}
            if args.use_pred_loss:
                # The third output (prediction loss) is not tracked here.
                kldiv_batch, nll_batch, _, bad_triplets_batch = net.train_batch(xijl_batch, targets, loss_func=loss_func, **info)
            else:
                kldiv_batch, nll_batch, bad_triplets_batch = net.train_batch(xijl_batch, targets, loss_func=loss_func, **info)

            bad_triplets += bad_triplets_batch
            kl_train[i] += kldiv_batch
            nll_triplet_train[i] += nll_batch

            nb_samples += len(xijl_batch[0])
            if args.verbose:
                print('\rbatch {}/{}: dlk: {}, pred: {} '.format(cnt_batch, nb_batches, kldiv_batch, nll_batch), end="\n")
            else:
                print('\rbatch {}/{} '.format(cnt_batch, nb_batches), end="")

        # Normalise by number of samples in order to get comparable number to
        # the -log like; max(..., 1) guards against an empty loader.
        train_norm = max(nb_samples, 1)
        kl_train[i] /= train_norm
        nll_triplet_train[i] /= train_norm

        badtriplets_str = "badtriplets: Train {:.4f}".format(bad_triplets / train_norm)
        toc = time.time()
        net.epoch = i + epoch

        # ---- validation
        evalX, evalY = s0valX, valY
        if val_loader is not None and triplet_val_loader is not None:
            eval_bad_triplets = 0
            nb_samples = 0
            net.set_mode_train(False)
            if args.use_pred_loss:
                for x_batch, y_batch in val_loader:
                    _, err_pred = net.eval_pred(x_batch, y_batch)
                    err_val[i] += err_pred
                    nb_samples += len(x_batch)
            else:
                for xp1, xp2, xn in triplet_val_loader:
                    cost_pred, eval_bad_triplets_batch = net.eval_dist(xp1, xp2, xn, loss_func=loss_func)
                    eval_bad_triplets += eval_bad_triplets_batch
                    err_val[i] += cost_pred
                    nb_samples += len(xp1)
            eval_norm = max(nb_samples, 1)
            err_val[i] /= eval_norm
            badtriplets_str += " Eval {:.4f}".format(eval_bad_triplets / eval_norm)
        net.scheduler.step()
        message = "it %d/%d, KL: %f, nll_tr: %f, nll_te: %f, lr: %f, %s " % (
                net.epoch, nb_epochs+epoch, kl_train[i], nll_triplet_train[i], err_val[i], 
                net.optimizer.param_groups[0]['lr'], badtriplets_str)
        # ---- print
        message += ' time: %f sec' % (toc - tic)
        print(message), flogger.info(message)

        # ---- KNN accuracy in embedding space (reference set: embedded train points)
        s0trX_embed = get_embedding(net, s0trX, ensemble=args.ensemble)
        eval_embed = get_embedding(net, evalX, ensemble=args.ensemble)
        [_, acc_train, _, _] = KNNtest(KNN_te, s0trX_embed, trY, s0trX_embed, trY)
        [_, acc_eval, _, _] = KNNtest(KNN_te, s0trX_embed, trY, eval_embed, evalY)
        message = '  [%d] %d-KNN acc: Train: %.2f, %s: %.2f\n'%(i, KNN_te, acc_train*100, mode, acc_eval * 100)
        train_accs.append(acc_train)
        eval_accs.append(acc_eval)
        print(message), flogger.info(message)

        if args.plot_embed:
            # Plots are written to imgs_save_folder only; frames are not
            # collected into imgs_train / imgs_eval.
            plot_highest_var_dim(s0trX_embed, trY, epoch=i, path=os.path.join(imgs_save_folder, "train_{}.png".format(i)))
            plot_highest_var_dim(eval_embed, evalY, epoch=i, path=os.path.join(imgs_save_folder, "eval_{}.png".format(i)))
        early_stop = checkpointer.step(acc_eval)
        if early_stop:
            mess = "Early Stopped!"
            print(mess), flogger.info(mess)
            break
    return train_accs, eval_accs, imgs_train, imgs_eval


def create_dataloaders(feature, target, train=False, transform=None, use_target=False, LMNN_triplets=False, KNN_triplet=21, batch_size=256, shuffle=False):
    """Wrap (feature, target) arrays into a TensorDataset plus two DataLoaders.

    Args:
        feature: per-sample features, converted with ``torch.tensor``.
        target: labels, converted to an int64 tensor.
        train: training split; pin_memory is only requested for non-training
            loaders and only when CUDA is in use.
        transform, use_target: forwarded to ``TripletDataset``.
        LMNN_triplets: if True, precompute triplet indices with
            ``find_KNN_triplets(feature, target, KNN_triplet)``; otherwise
            ``TripletDataset`` receives ``triplets_index=None``.
        KNN_triplet: neighbour count used when forming triplets.
        batch_size: batch size of both loaders.
        shuffle: whether the plain (non-triplet) loader shuffles. The triplet
            loader always shuffles. Defaults to False, as before.

    Returns:
        (dataset, triplet_loader, loader).
    """
    triplets_indices = find_KNN_triplets(feature, target, KNN_triplet) if LMNN_triplets else None

    feature, target = torch.tensor(feature), torch.tensor(target, dtype=torch.int64)
    dataset = TensorDataset(feature, target)
    kwargs = {'pin_memory': True} if (not train and use_cuda) else {}

    triplet_dataset = TripletDataset(dataset, train=train, transform=transform, use_target=use_target, triplets_index=triplets_indices)
    triplet_loader = torch.utils.data.DataLoader(triplet_dataset, batch_size=batch_size, shuffle=True, drop_last=False, **kwargs)
    # Previously ``shuffle`` was accepted but ignored; it now controls this loader.
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=False, **kwargs)
    return dataset, triplet_loader, loader

def create_network(args, feature_size, embed_size, use_cuda, batch_size, nb_batches, hidden_sizes, num_classes):
    """Build the Bayesian FC network selected by ``args.model`` and attach its optimizer.

    ``args.model`` picks the weight / bias prior pair:
      * 'Laplace_prior'  - laplace_prior(mu=0, b=prior_sig) for both;
      * 'Gaussian_prior' - isotropic_gauss_prior(mu=0, sigma=prior_sig) for
                           weights, sigma=2*prior_sig for biases;
      * 'GMM_prior'      - spike_slab_2GMM(sigma1=prior_sig, sigma2=0.0005, pi=0.5)
                           for both.
    Any other value prints 'Invalid model type' and exits with status 1.

    The network gets Adam (``args.lr``) with a MultiStepLR schedule
    (``args.lr_milestones`` as a JSON list, ``args.lr_gamma``). When
    ``args.use_pretrain`` is set, weights are loaded from that path and the
    returned sizes come from the loaded network.

    Returns:
        (net, feature_size, embed_size)
    """
    lr_hyperparam = dict(lr=args.lr,
                         lr_milestones=json.loads(args.lr_milestones),
                         lr_gamma=args.lr_gamma)

    # Prior factories are only called for the selected model.
    def _laplace_priors():
        return (laplace_prior(mu=0, b=args.prior_sig),
                laplace_prior(mu=0, b=args.prior_sig))

    def _gaussian_priors():
        return (isotropic_gauss_prior(mu=0, sigma=args.prior_sig),
                isotropic_gauss_prior(mu=0, sigma=args.prior_sig*2))

    def _gmm_priors():
        return (spike_slab_2GMM(mu1=0, mu2=0, sigma1=args.prior_sig, sigma2=0.0005, pi=0.5),
                spike_slab_2GMM(mu1=0, mu2=0, sigma1=args.prior_sig, sigma2=0.0005, pi=0.5))

    build_priors = {
        'Laplace_prior': _laplace_priors,
        'Gaussian_prior': _gaussian_priors,
        'GMM_prior': _gmm_priors,
    }.get(args.model)
    if build_priors is None:
        print('Invalid model type')
        exit(1)

    weight_prior, bias_prior = build_priors()
    net = Wrapper_FCBNN(lr_hyperparam=lr_hyperparam, input_dim=feature_size, output_dim=embed_size, cuda=use_cuda,
                        classes=num_classes, batch_size=batch_size,
                        Nbatches=nb_batches, nhid=hidden_sizes,
                        prior_instance=weight_prior,
                        bias_prior_instance=bias_prior,
                        use_pred_loss=args.use_pred_loss, reweight_constant=args.reweight_constant)

    optimizer = torch.optim.Adam(net.model.parameters(), lr=lr_hyperparam['lr'],
                                 betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_hyperparam['lr_milestones'],
                                                     gamma=lr_hyperparam['lr_gamma'], last_epoch=-1)
    net.create_opt(optimizer, scheduler)

    if args.use_pretrain is None:
        return net, feature_size, embed_size

    # Loaded weights define the network's true input/output sizes.
    net.load(args.use_pretrain)
    return net, net.input_dim, net.output_dim

def save_plots(fold, imgs_train, imgs_eval):
    """Write one fold's collected train/eval embedding frames as 1-fps GIFs.

    Output files are <savefolder>/train.<runtime>.<fold>.gif and
    <savefolder>/eval.<runtime>.<fold>.gif. A split without frames is skipped.
    An empty GIF is useless, and imageio may refuse to write zero frames.
    ``train()`` currently returns empty frame lists.
    """
    # NOTE(review): ``imageio`` is not imported explicitly in this file; it is
    # presumably provided by one of the star imports — confirm.
    for split, frames in (('train', imgs_train), ('eval', imgs_eval)):
        if frames is None or len(frames) == 0:
            continue
        filepath = '{}/{}.{}.{}.gif'.format(savefolder, split, runtime, fold)
        imageio.mimsave(filepath, frames, fps=1)

if nb_epochs == 0:
    print("Epoch == 0")
    exit(0)

# ---- Driver: 5-fold stratified cross validation over the sampled training set.
# Folds are stratified on the clean labels. Noise is injected only into each
# fold's training part; the validation part keeps the clean labels.
skf = StratifiedKFold(n_splits=5)
eval_accs = []
eval_epochs = []

(trX_origin, trY_origin, trY_gt_origin), (teX, teY), num_classes = prepare_data(args)
if args.noise_type == 'asymmetric':
    # Fixed (seed 0) per-class noise rates drawn from {0.1, ..., 0.6}.
    rng = np.random.RandomState(0)
    noiserates = rng.randint(1, 7, num_classes) / 10

for fold_idx, (train_index, val_index) in enumerate(skf.split(trX_origin, trY_gt_origin)):
    print("###Fold {}: ".format(fold_idx))
    trX, trY_gt = trX_origin[train_index], trY_gt_origin[train_index]
    valX, valY = trX_origin[val_index], trY_gt_origin[val_index]

    trY = inject_noise(trY_gt, noiserates)

    true_noiserate = np.mean(trY != trY_gt)
    message = 'Shape:%s || #Training:%d || #Val: %d || #Test:%d || label noise (%%): %.2f\n' %(
        trX.shape, len(trY), len(valY), len(teY), true_noiserate * 100)
    print(message), flogger.info(message)

    s0trX, s0valX, s0teX, init_acc = feature_engineering(trX, trY, valX, valY, teX, teY, args)

    feature_size = s0trX.shape[1]
    message = 'feature dim: %d, embed_dim: %d' % (feature_size, embed_size)
    cprint('c', message), flogger.info(message)

    # Dataloaders
    train_dataset, triplet_train_loader, train_loader = create_dataloaders(s0trX, trY, train=True, 
            use_target=True, LMNN_triplets=True, KNN_triplet=KNN_triplet, batch_size=batch_size, shuffle=False)
    test_dataset, triplet_test_loader, test_loader = create_dataloaders(s0teX, teY, train=False, 
            use_target=False, LMNN_triplets=True, KNN_triplet=KNN_triplet, batch_size=batch_size, shuffle=False)
    val_dataset, triplet_val_loader, val_loader = create_dataloaders(s0valX, valY, train=False, 
            use_target=False, LMNN_triplets=False, KNN_triplet=KNN_triplet, batch_size=batch_size, shuffle=False)
    nb_batches = len(triplet_train_loader)
    Ntriplets = len(triplet_train_loader.dataset)
    cprint('c', 'Ntriplets: {} \n'.format(Ntriplets))
    # Create network; its parameter dtype must match the input features.
    net, feature_size, embed_size = create_network(args, feature_size, embed_size, use_cuda, batch_size, nb_batches, hidden_sizes, num_classes)
    if train_dataset.tensors[0].dtype == torch.float32:
        net.model.float()
    elif train_dataset.tensors[0].dtype == torch.float64:
        net.model.double()
    else:
        print("datatype not expected: {}".format(train_dataset.tensors[0].dtype))
        exit(1)

    if fold_idx == 0:
        cprint('c', '\nNetwork:')
        print(str(net))
        flogger.info(str(net))

    tic0 = time.time()
    checkpointer = EearlyStopping(net, patience=20, criterion='max')
    train_accs_rt, eval_accs_rt, imgs_train, imgs_eval = train(net, s0trX, s0valX, s0teX, 
            triplet_train_loader, triplet_val_loader, triplet_test_loader,
            train_loader, val_loader, test_loader, checkpointer, args)
    
    if args.plot_embed:
        save_plots(fold_idx, imgs_train, imgs_eval)

    toc0 = time.time()
    # Average over the epochs actually run; early stopping may end training
    # before nb_epochs (train() records one accuracy per epoch run).
    epochs_run = max(len(train_accs_rt), 1)
    runtime_per_it = (toc0 - tic0) / float(epochs_run)
    # best_epoch may be missing/None if the metric never improved.
    best_epoch = getattr(checkpointer, 'best_epoch', None)
    print('Best validation acc: {} at epoch {}'.format(checkpointer.best_perf, best_epoch))
    cprint('r', ' total runtime: %.2f,  average time: %.2f seconds\n' % (toc0-tic0, runtime_per_it))
    flogger.info(' total runtime: %.2f,  average time: %.2f seconds\n' % (toc0-tic0, runtime_per_it))
    eval_accs.append(checkpointer.best_perf)
    eval_epochs.append(best_epoch)

message = "Valdiation Accuracies: {}, average: {}".format(eval_accs, np.array(eval_accs).mean())
print(message), flogger.info(message)
message = "Evaluation at: {}".format(eval_epochs)
print(message), flogger.info(message)