import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
from torch.utils.data.distributed import DistributedSampler

import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.autograd import Variable
import numpy as np
import math
import utils
import sys
import calibration as cal

from models import WideResNet, WideResNetEnsemble
from models import TempCNN, TempConst, TempResNet, TempLinearOnReps, TempNNOnReps
from models import CalibrationMatrixScaling, CalibrationMonotone, CalibrationNNOnReps, CalibrationNNOnRepsUnshared
from models_resnet import wide_resnet50_2
from data import get_loaders

# used for logging to TensorBoard
from tensorboard_logger import configure, log_value

def _str2bool(value):
    """Parse a command-line boolean ('true'/'false', 'yes'/'no', '1'/'0', ...).

    ``type=bool`` is unusable with argparse because ``bool('False')`` is True.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got '{value}'")


parser = argparse.ArgumentParser(description='PyTorch WideResNet Training')
parser.add_argument('--dataset', default='cifar10', type=str,
                    help='dataset (cifar10, cifar100, cifar100c, imagenet)')
parser.add_argument('--shift_intensity', default=1, type=int,
                    help='Shift intensity for cifar{10, 100}-c dataset')
parser.add_argument('--split_size', default=-1, type=int,
                    help='Split the dataset into length _ and n_train-_')
parser.add_argument('--split_size_2', default=-1, type=int,
                    help='Further split the second split of the training dataset')
parser.add_argument('--use_split', default='train', type=str,
                    help='Use which split: {train|train_val}')
parser.add_argument('--epochs', default=200, type=int,
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int,
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=128, type=int,
                    help='mini-batch size (default: 128)')
parser.add_argument('--val_batch_size', default=128, type=int,
                    help='mini-batch size for validation')
parser.add_argument('--loss', default='cross_entropy', type=str,
                    help='loss function (default: cross_entropy)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    help='initial learning rate')
parser.add_argument('--scheduler', default='cosine', type=str,
                    help='learning rate scheduler')
parser.add_argument('--ddp', action='store_true',
                    help='Use DDP training')
parser.add_argument('--local_rank', default=0, type=int,
                    help='local rank for DDP training')
parser.add_argument('--parallel', action='store_true',
                    help='perform multi-gpu training through dataparallel')
parser.add_argument('--model_seed', default=None, type=int,
                    help='Random seed for initializing model')
parser.add_argument('--data_seed', default=None, type=int,
                    help='Random seed for sgd noises (minibatch noise, data augmentation noise)')
parser.add_argument('--reset_seed_epoch', default=-1, type=int,
                    help='Reset data seed at this epoch')
parser.add_argument('--data_seed_reset', default=None, type=int,
                    help='New data seed for resetting')
parser.add_argument('--optimizer', default='momentum', type=str,
                    help='Optimizer')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--nesterov', default=True, type=_str2bool,
                    help='nesterov momentum (true/false)')
parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,
                    help='weight decay (default: 5e-4)')
parser.add_argument('--temp_l2', default=0.0, type=float,
                    help='L2 regularization for temp layer (weights only, not on biases)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    help='print frequency (default: 10)')
parser.add_argument('--train_num_ensemble', default=-1, type=int,
                    help='Train an ensembled model directly with _ submodels')
parser.add_argument('--model', default='wideresnet', type=str,
                    help='model')
parser.add_argument('--layers', default=28, type=int,
                    help='total number of layers (default: 28)')
parser.add_argument('--widen-factor', default=10, type=int,
                    help='widen factor (default: 10)')
parser.add_argument('--droprate', default=0, type=float,
                    help='dropout probability (default: 0.0)')
parser.add_argument('--no-augment', dest='augment', action='store_false',
                    help='whether to use standard augmentation (default: True)')
parser.add_argument('--output_all_layers', action='store_true',
                    help='ask the model to output all hidden layers')
parser.add_argument('--use_temp', action='store_true',
                    help='Use non-trivial temp in forward-passing through individualized temperature models')
parser.add_argument('--use_calib', action='store_true',
                    help='Use a generalized calibration model')
parser.add_argument('--calib_output_mode', default='logprobs', type=str,
                    help='Output mode for calib model: {logprobs|logits}')
parser.add_argument('--num_temps', default=5, type=int,
                    help='Number of temperatures in generalized calibration model')
parser.add_argument('--temp_model', default=None, type=str,
                    help='Separate temperature model: {cnn|resnet|const|linear_on_reps|nn_on_reps}')
parser.add_argument('--calib_model', default=None, type=str,
                    help='Separate calibration model: {monotone|nn_on_reps|nn_on_reps_unshared|matrix_scale}')
parser.add_argument('--temp_model_depth', default=4, type=int,
                    help='depth for temperature model')
parser.add_argument('--temp_model_width', default=512, type=int,
                    help='width for temperature model')
parser.add_argument('--temp_train_fc_only', action='store_true',
                    help='train fc layer only for temp model')
parser.add_argument('--temp_model_pretrained', action='store_true',
                    help='use pretrained model to initialized temp model')
parser.add_argument('--temp_zero_init_residual', action='store_true',
                    help='use zero init (on residual branches) for temp model')
parser.add_argument('--min_temp', default=1.0, type=float,
                    help='minimal temperature for temperature model')
parser.add_argument('--neg_slope', default=0.5, type=float,
                    help='negative slope for leaky relu activation in calib model')
parser.add_argument('--temp_init_increment', default=0.5, type=float,
                    help='increment for initializing bias in generalized calib model')
parser.add_argument('--activation', default='leaky_relu', type=str,
                    help='activation for general monotone calibrator')
parser.add_argument('--save_path_optimal_temps', default=None, type=str,
                    help='Path for saving optimal temperatures')
parser.add_argument('--distill_temps', default=None, type=str,
                    help='Path for distilling temperatures (from)')
parser.add_argument('--nll_weight', default=1.0, type=float,
                    help='weight for nll loss')
parser.add_argument('--ece_weight', default=0, type=float,
                    help='weight for ece loss')
parser.add_argument('--kl_ece_weight', default=0, type=float,
                    help='weight for kl-ece loss')
parser.add_argument('--ece_num_partitions_train', default=5, type=int,
                    help='Number of partitions for computing ece')
parser.add_argument('--ece_num_partitions_val', default=15, type=int,
                    help='Number of partitions for computing ece')
parser.add_argument('--resume', default='', type=str,
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--resume_from_ddp', action='store_true',
                    help='resume model from ddp checkpoint')
parser.add_argument('--name', default='WideResNet-28-10', type=str,
                    help='name of experiment')
parser.add_argument('--tensorboard',
                    help='Log progress to TensorBoard', action='store_true')
parser.add_argument('--no_cudnn_deterministic', dest='cudnn_deterministic',
                    action='store_false', help='disable cudnn deterministic')
parser.add_argument('--save_epoch', default=-1, type=int,
                    help='save model checkpoint per _ epoch')
parser.add_argument('--save_model_per_x_batch', default=-1, type=int,
                    help='save model checkpoint per _ batch')
parser.add_argument('--evaluate_on_train', action='store_true',
                    help='also evaluate on the training dataset')
parser.add_argument('--compute_prr', action='store_true',
                    help='compute prr for a given model')
parser.add_argument('--debug_mode', action='store_true',
                    help='enter debug mode')
parser.add_argument('--debug_dataset', default='val',
                    help='dataset for debugging')
parser.set_defaults(augment=True)
parser.set_defaults(cudnn_deterministic=True)

# Best top-1 validation accuracy seen so far; main() updates it via `global`.
best_prec1 = 0
# Output-class count for every supported value of --dataset.
NUM_CLASSES = dict(
    cifar10=10,
    cifar100=100,
    cifar100c=100,
    imagenet=1000,
)


def _build_model(args):
    """Instantiate the backbone classifier selected by ``--model``.

    Raises:
        ValueError: if ``args.model`` is not a supported architecture.
    """
    if args.model == 'wideresnet':
        if args.train_num_ensemble > 0:
            return WideResNetEnsemble(args.layers, args.num_classes,
                                      args.widen_factor, dropRate=args.droprate,
                                      num_ensemble=args.train_num_ensemble)
        return WideResNet(args.layers, args.num_classes,
                          args.widen_factor, dropRate=args.droprate)
    if args.model == 'wideresnet_50_2':
        return wide_resnet50_2(num_classes=args.num_classes)
    raise ValueError(f"Unknown model '{args.model}'")


def _build_temp_model(args, device):
    """Build the optional temperature / calibration head and move it to ``device``.

    ``--temp_model`` is constructed first; a ``--calib_model`` selection then
    replaces it when both are given. Returns None when neither selects a model.
    """
    # NOTE(review): 640 / 2048 are assumed to be the representation widths of
    # the wideresnet / wideresnet_50_2 backbones fed to the *_on_reps heads.
    rep_dim = 640 if args.model == 'wideresnet' else 2048

    temp_model = None
    if args.temp_model == 'cnn':
        temp_model = TempCNN(args.temp_model_width, args.temp_model_depth,
                             train_fc_only=args.temp_train_fc_only,
                             num_classes=1, min_temp=args.min_temp)
    elif args.temp_model == 'resnet':
        temp_model = TempResNet(pretrained=args.temp_model_pretrained,
                                min_temp=args.min_temp,
                                zero_init_residual=args.temp_zero_init_residual)
    elif args.temp_model == 'const':
        temp_model = TempConst(min_temp=args.min_temp)
    elif args.temp_model == 'linear_on_reps':
        temp_model = TempLinearOnReps([5], [640], min_temp=args.min_temp)
    elif args.temp_model == 'nn_on_reps':
        temp_model = TempNNOnReps([5], [rep_dim],
                                  depth=args.temp_model_depth, width=args.temp_model_width,
                                  min_temp=args.min_temp)

    if args.calib_model == 'monotone':
        temp_model = CalibrationMonotone(
            min_temp=args.min_temp, temp_init_increment=args.temp_init_increment,
            num_temps=args.num_temps, neg_slope=args.neg_slope,
            output_mode=args.calib_output_mode)
    elif args.calib_model == 'nn_on_reps':
        temp_model = CalibrationNNOnReps(
            [5], [rep_dim],
            depth=args.temp_model_depth, width=args.temp_model_width,
            min_temp=args.min_temp, temp_init_increment=args.temp_init_increment,
            num_temps=args.num_temps, neg_slope=args.neg_slope,
            output_mode=args.calib_output_mode, activation=args.activation)
    elif args.calib_model == 'nn_on_reps_unshared':
        temp_model = CalibrationNNOnRepsUnshared(
            [5], [rep_dim],
            depth=args.temp_model_depth, width=args.temp_model_width,
            min_temp=args.min_temp, temp_init_increment=args.temp_init_increment,
            num_temps=args.num_temps, neg_slope=args.neg_slope,
            output_mode=args.calib_output_mode)
    elif args.calib_model == 'matrix_scale':
        temp_model = CalibrationMatrixScaling(args.num_classes)

    return None if temp_model is None else temp_model.to(device)


def _make_optimizer(params, lr, args):
    """Create the optimizer named by ``--optimizer`` for ``params``.

    Raises:
        ValueError: if ``args.optimizer`` is not one of sgd|momentum|adam.
    """
    if args.optimizer == 'sgd':
        return torch.optim.SGD(params, lr, weight_decay=args.weight_decay)
    if args.optimizer == 'momentum':
        return torch.optim.SGD(params, lr,
                               momentum=args.momentum, nesterov=args.nesterov,
                               weight_decay=args.weight_decay)
    if args.optimizer == 'adam':
        return torch.optim.Adam(params, lr, weight_decay=args.weight_decay)
    raise ValueError(f"Unknown optimizer '{args.optimizer}'")


def _make_scheduler(optimizer, steps_per_epoch, args):
    """Create the per-batch LR scheduler named by ``--scheduler``.

    Milestones are expressed in batches (epochs * ``steps_per_epoch``) because
    ``train`` calls ``scheduler.step()`` once per batch.

    Raises:
        ValueError: if ``args.scheduler`` is unknown.
    """
    lr_sched = torch.optim.lr_scheduler
    name = args.scheduler
    if name == 'cosine':
        return lr_sched.CosineAnnealingLR(optimizer, steps_per_epoch * args.epochs)
    if name == 'oneepoch':
        return lr_sched.MultiStepLR(optimizer, milestones=[steps_per_epoch], gamma=0.1)
    if name in ('cifar', 'cifar_long'):
        # NOTE(review): 'cifar_long' currently uses the same milestones as 'cifar'.
        return lr_sched.MultiStepLR(
            optimizer, milestones=[c * steps_per_epoch for c in [60, 120, 160]], gamma=0.2)
    if name == 'imagenet':
        return lr_sched.MultiStepLR(
            optimizer, milestones=[c * steps_per_epoch for c in [30, 60, 90]], gamma=0.1)
    if name == 'onecycle':
        return lr_sched.OneCycleLR(optimizer, max_lr=args.lr,
                                   total_steps=steps_per_epoch * args.epochs)
    if name == 'const':
        return lr_sched.MultiStepLR(optimizer, milestones=[], gamma=1.0)
    raise ValueError(f"Unknown scheduler '{name}'")


def main():
    """Entry point: parse CLI args, build models, then train or evaluate.

    Depending on flags this either (a) saves per-example optimal temperatures
    on the train_val split and exits, (b) runs one debug evaluation and exits,
    or (c) trains for epochs ``[args.start_epoch, args.epochs)``, validating and
    checkpointing on rank 0 after every epoch.

    Sets the module-level ``args`` (read by ``save_checkpoint``) and
    ``best_prec1``.

    Raises:
        ValueError: for an unknown model / optimizer / scheduler, when
            ``--use_temp``/``--use_calib`` is given without a temperature or
            calibration model, or when ``--save_path_optimal_temps`` is given
            but no train_val loader exists.
    """
    global args, best_prec1
    args = parser.parse_args()
    if args.debug_mode and not args.compute_prr:
        args.name = 'debug'
    if args.local_rank == 0 and args.tensorboard:
        if args.name != 'debug':
            configure("runs/%s"%(args.name))
        else:
            configure("/tmp/debug")

    # configure distributed backend
    if args.ddp:
        torch.distributed.init_process_group(backend="nccl")
        device = torch.device("cuda:{}".format(args.local_rank))
    else:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Data loading
    train_loader, train_val_loader, val_loader = get_loaders(
        args.dataset, batch_size=args.batch_size, val_batch_size=args.val_batch_size,
        augment=args.augment, split_size=args.split_size, split_size_2=args.split_size_2,
        distill_temps=args.distill_temps,
        ddp=args.ddp, shift=args.shift_intensity
    )

    # `loader` is the split that is actually trained on (see --use_split).
    loader = train_loader if args.use_split == 'train' else train_val_loader
    args.n_train, args.n_test = len(loader.dataset), len(val_loader.dataset)
    args.num_classes = NUM_CLASSES[args.dataset]

    # `is not None` so that --model_seed 0 is honoured.
    if args.model_seed is not None:
        torch.random.manual_seed(args.model_seed)

    # create model and optionally cast as ddp model
    model = _build_model(args).to(device)
    if args.ddp:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)

    # create temp model and optimizer
    temp_model = _build_temp_model(args, device)
    if (args.use_temp or args.use_calib) and temp_model is None:
        raise ValueError('--use_temp/--use_calib require --temp_model or --calib_model')
    temp_model_optimizer = None
    if temp_model is not None:
        temp_model_optimizer = _make_optimizer(temp_model.parameters(), args.lr, args)

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # map_location keeps tensors on this process's device (needed for
            # CPU-only runs, and so DDP ranks do not all load onto cuda:0).
            checkpoint = torch.load(args.resume, map_location=device)
            best_prec1 = checkpoint['best_prec1']
            state_dict = checkpoint['state_dict']
            if args.resume_from_ddp:
                state_dict = utils.process_ddp_statedict(state_dict)
                print("=> processed ddp state_dict")
            model.load_state_dict(state_dict)
            if temp_model is not None and 'temp_model_state_dict' in checkpoint:
                temp_model.load_state_dict(checkpoint['temp_model_state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # for training on multiple GPUs.
    # Use CUDA_VISIBLE_DEVICES=0,1 to specify which GPUs to use
    if args.parallel:
        model = nn.DataParallel(model).cuda()

    # cuDNN autotuning (benchmark=True) may select different kernels from run
    # to run, so deterministic mode must turn it off.
    if args.cudnn_deterministic:
        cudnn.deterministic = True
        cudnn.benchmark = False
    else:
        cudnn.benchmark = True

    if args.start_epoch == 0:
        optimizer = _make_optimizer(model.parameters(), args.lr, args)
        # When only the calibration head is trained on the held-out split, the
        # LR schedule drives its optimizer instead of the backbone's.
        if args.use_split == 'train_val' and (args.use_temp or args.use_calib):
            scheduled_optimizer = temp_model_optimizer
        else:
            scheduled_optimizer = optimizer
        scheduler = _make_scheduler(scheduled_optimizer, len(loader), args)
    else:
        # Resume partway through a cosine schedule: start from the cosine LR at
        # start_epoch and let CosineAnnealingLR continue from that step.
        # NOTE(review): this path always uses SGD+momentum and a cosine
        # schedule, regardless of --optimizer / --scheduler.
        new_base_lr = .5 * args.lr * (1 + math.cos(math.pi * args.start_epoch / args.epochs))
        optimizer = torch.optim.SGD(model.parameters(), new_base_lr,
                                    momentum=args.momentum, nesterov=args.nesterov,
                                    weight_decay=args.weight_decay)
        for group in optimizer.param_groups:
            group.setdefault('initial_lr', group['lr'])
        # len(loader) (not train_loader) so the step count matches the split
        # that train() actually iterates.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(loader)*args.epochs,
                                                               last_epoch=len(loader)*args.start_epoch)

    if args.save_path_optimal_temps is not None:
        if train_val_loader is None:
            raise ValueError('--save_path_optimal_temps requires a train_val split')
        # Unshuffled pass so the saved temperatures follow dataset order.
        # NOTE(review): assumes train_val_loader wraps the held-out split that
        # the previous (undefined) `dataset_split[1]` referred to.
        unshuffled_train_val_loader = torch.utils.data.DataLoader(
            train_val_loader.dataset, batch_size=args.batch_size, shuffle=False)
        _ = validate(unshuffled_train_val_loader, model, 0, args,
                     device=device, save_path_optimal_temps=args.save_path_optimal_temps)
        sys.exit()
    elif args.debug_mode:
        debug_loader = train_val_loader if args.debug_dataset == 'train_val' else val_loader
        if temp_model is not None:
            temp_model.min_temp = 0.2
        _ = validate(debug_loader, model, 0, args,
                     device=device, debug_mode=True,
                     compute_prr=args.compute_prr,
                     temp_model=temp_model)
        sys.exit()

    for epoch in range(args.start_epoch, args.epochs):
        # Re-seed per epoch so minibatch/augmentation noise is reproducible;
        # `is not None` so that --data_seed 0 is honoured.
        if args.data_seed is not None:
            torch.random.manual_seed(args.data_seed + epoch)
        # DistributedSampler shuffles with its own epoch-based seed; without
        # set_epoch every epoch would reuse the same ordering.
        sampler = getattr(loader, 'sampler', None)
        if isinstance(sampler, DistributedSampler):
            sampler.set_epoch(epoch)
        # optionally save the weights at the start of the epoch (batch 0)
        if args.local_rank == 0 and args.save_model_per_x_batch > 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': -1.,
            }, False, filename='checkpoint_batch_0.pth.tar')

        # train for one epoch
        train(loader, model, optimizer, scheduler, epoch, args,
              device=device, temp_model=temp_model, temp_model_optimizer=temp_model_optimizer)

        # optionally evaluate on train set
        if args.local_rank == 0 and args.evaluate_on_train:
            _ = validate(loader, model, epoch, args,
                         device=device, temp_model=temp_model, tag='eval_train')

        # evaluate on validation set
        if args.local_rank == 0:
            prec1 = validate(val_loader, model, epoch, args,
                             device=device, temp_model=temp_model)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            checkpoint_dict = {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
            }
            if temp_model is not None:
                checkpoint_dict['temp_model_state_dict'] = temp_model.state_dict()
            save_checkpoint(checkpoint_dict, is_best)

            if args.save_epoch > 0 and (epoch + 1) % args.save_epoch == 0:
                # Periodic snapshot with the same contents (including any
                # temp-model state) so it can be resumed like the rolling one.
                save_checkpoint(checkpoint_dict, False,
                                filename=f'checkpoint_epoch_{epoch+1}.pth.tar')
            print('Best accuracy: ', best_prec1)

def train(train_loader, model, optimizer, scheduler, epoch,
          args, device=torch.device("cuda"),
          temp_model=None, temp_model_optimizer=None):
    """Train for one epoch on the training set.

    If ``args.use_temp`` or ``args.use_calib`` is set, the backbone is frozen
    (eval mode and ``requires_grad_(False)``) and only ``temp_model`` receives
    gradients; otherwise the backbone is put in train mode. ``scheduler.step()``
    is called once per batch.

    Args:
        train_loader: iterable of ``(input, target)`` batches.
        model: backbone, called as ``model(input, output_all_layers=True)`` and
            expected to return ``(output, all_layers)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: per-batch learning-rate scheduler.
        epoch: current epoch index (used for logging and checkpoint names).
        args: parsed command-line namespace.
        device: device the batches are moved to.
        temp_model: optional temperature / calibration head.
        temp_model_optimizer: optimizer for ``temp_model``; required whenever
            ``temp_model`` is given.

    Raises:
        ValueError: if no loss term (nll / ece / kl-ece) has a positive weight,
            since there would be nothing to back-propagate.
    """
    # Without a positively-weighted term `loss` stays the float 0. and the
    # .item()/.backward() calls below would fail with an unhelpful error.
    if args.nll_weight <= 0 and args.ece_weight <= 0 and args.kl_ece_weight <= 0:
        raise ValueError('train() needs at least one of --nll_weight, '
                         '--ece_weight, --kl_ece_weight to be positive')

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    eces = AverageMeter()

    # Freeze the backbone when only the temperature/calibration head is trained.
    if args.use_temp or args.use_calib:
        model.eval()
        model.requires_grad_(False)
    else:
        model.train()
    end = time.time()

    # Distillation regresses the predicted temperature onto the loaded targets.
    if args.distill_temps is not None:
        criterion = get_criterion('mse', device)
    else:
        criterion = get_criterion(args.loss, device)

    for i, (input, target) in enumerate(train_loader):
        target = target.to(device)
        input = input.to(device)
        batch_size = input.size(0)

        output, all_layers = model(input, output_all_layers=True)

        # Scale temperature by temp model
        if args.use_temp:
            if args.temp_model in ['linear_on_reps', 'nn_on_reps']:
                temp = temp_model(all_layers)
            elif args.temp_model in ['const', 'cnn', 'resnet']:
                temp = temp_model(input)
            # Explicit class axis; the implicit-dim form is deprecated.
            output = F.log_softmax(output / temp, dim=1)
        elif args.use_calib:
            output, temp = temp_model(output, all_layers)

        loss = 0.
        if args.nll_weight > 0:
            if args.distill_temps is not None:
                loss += args.nll_weight * criterion(temp, target)
            elif args.loss == 'mse':
                one_hot_target = F.one_hot(target, num_classes=args.num_classes).float()
                loss += args.nll_weight * criterion(output, one_hot_target)
            else:
                loss += args.nll_weight * criterion(output, target)

        # Accuracy is tracked regardless of which loss terms are enabled. With
        # --distill_temps the targets are regressed against `temp`, so they are
        # presumably not class labels and accuracy is skipped.
        if args.distill_temps is None:
            prec1, prec5 = accuracy(output.detach(), target, topk=(1, 5))
            top1.update(prec1.item(), batch_size)
            top5.update(prec5.item(), batch_size)

        # (optionally) add ece loss
        if args.ece_weight > 0 or args.kl_ece_weight > 0:
            probs = torch.exp(output)
            if args.ece_weight > 0:
                ece = utils.expected_calibration_error(probs, target,
                                                       num_partitions=args.ece_num_partitions_train)
                eces.update(ece.item(), batch_size)
                loss += args.ece_weight * ece
            if args.kl_ece_weight > 0:
                kl_ece = utils.kl_ece(probs, target,
                                      num_partitions=args.ece_num_partitions_train)
                loss += args.kl_ece_weight * kl_ece

        # NOTE: --temp_l2 is parsed but not applied to the loss here.

        losses.update(loss.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        if temp_model is not None:
            temp_model_optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if temp_model is not None:
            temp_model_optimizer.step()
        scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # optionally save models within an epoch
        if args.local_rank == 0 and args.save_model_per_x_batch > 0 and (i+1) % args.save_model_per_x_batch == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_prec1': -1.,
            }, False, filename=f'checkpoint_batch_{i+1}.pth.tar')

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'
                  'LR {lr:.4f}'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      loss=losses, top1=top1, top5=top5, lr=scheduler.get_last_lr()[0]))
    # log to TensorBoard
    if args.local_rank == 0 and args.tensorboard:
        log_value('train_loss', losses.avg, epoch)
        log_value('train_acc', top1.avg, epoch)
        log_value('train_acc_top5', top5.avg, epoch)
        if args.ece_weight > 0:
            log_value('train_ece', eces.avg, epoch)


def validate(val_loader, model, epoch,
             args, device=torch.device("cuda"), debug_mode=False,
             temp_model=None, save_path_optimal_temps=None,
             compute_prr=False,
             tag='val'):
    """Evaluate ``model`` on ``val_loader`` and report accuracy/calibration metrics.

    The model output is treated as log-probabilities: ``exp(output)`` is
    collected for every example. From these it computes ECE, debiased ECE
    (p=1 and p=2), predictive entropy, average confidence and the spread of
    per-class ECE. It prints them and, on rank 0 with ``--tensorboard``, logs
    them under ``tag``.

    Args:
        val_loader: iterable of ``(input, target)`` batches whose ``dataset``
            length sizes the result buffers.
        model: backbone returning ``(output, all_layers)``.
        epoch: epoch index used as the TensorBoard step.
        args: parsed command-line namespace.
        device: device for inputs and result buffers.
        debug_mode: if True and ``compute_prr``, write accuracy-vs-count
            curves and PRR under ``runs/<args.name>/`` and return None;
            if True otherwise, drop into pdb.
        temp_model: optional temperature / calibration head.
        save_path_optimal_temps: if given, fit per-example optimal temperatures
            via ``utils.optimal_individual_temp`` and save them to this path.
        compute_prr: see ``debug_mode``.
        tag: prefix for TensorBoard scalar names.

    Returns:
        Average top-1 accuracy in percent (None in ``compute_prr`` debug mode).
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    criterion = get_criterion(args.loss, device)

    n_examples = len(val_loader.dataset)
    probs_all = torch.zeros((n_examples, args.num_classes)).to(device)
    targets_all = torch.zeros(n_examples, dtype=torch.long).to(device)
    if args.use_temp:
        temps_all = torch.zeros((n_examples, 1)).to(device)
    elif args.use_calib:
        temps_all = torch.zeros((n_examples, 1, args.num_temps)).to(device)
    counter = 0
    end = time.time()
    for i, (input, target) in enumerate(val_loader):
        target = target.to(device)
        input = input.to(device)
        batch_size = input.size(0)

        # compute output
        with torch.no_grad():
            output, all_layers = model(input, output_all_layers=True)
            # Scale temperature
            if args.use_temp:
                if args.temp_model in ['linear_on_reps', 'nn_on_reps']:
                    temp = temp_model(all_layers)
                elif args.temp_model in ['const', 'cnn', 'resnet']:
                    temp = temp_model(input)
                # Explicit class axis; the implicit-dim form is deprecated.
                output = F.log_softmax(output / temp, dim=1)
            elif args.use_calib:
                output, temp = temp_model(output, all_layers)

            if args.loss == 'mse':
                # MSE compares against one-hot labels, exactly as train() does;
                # raw class indices have the wrong shape for this loss.
                one_hot_target = F.one_hot(target, num_classes=args.num_classes).float()
                loss = criterion(output, one_hot_target)
            else:
                loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), batch_size)
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Batches fill the buffers consecutively, in loader order.
        probs_all[counter:(counter + batch_size), :] = torch.exp(output)
        targets_all[counter:(counter + batch_size)] = target
        if (args.use_temp or args.use_calib) and temp is not None:
            temps_all[counter:(counter + batch_size)] = temp
        counter += batch_size

        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format(
                      i, len(val_loader), batch_time=batch_time, loss=losses,
                      top1=top1, top5=top5))

    # compute ece on entire val dataset
    ece = utils.expected_calibration_error(probs_all, targets_all,
                                           num_partitions=args.ece_num_partitions_val)
    pred_ent = utils.predictive_entropy(probs_all)
    avg_conf = probs_all.max(dim=1)[0].mean()
    ece_per_class = utils.ece_per_class(probs_all, targets_all,
                                        num_partitions=args.ece_num_partitions_val, num_classes=args.num_classes)
    debiased_ece_l1 = cal.get_calibration_error(probs_all.cpu(), targets_all.cpu(),
                                                p=1, debias=True, mode='top-label')
    debiased_ece_l2 = cal.get_calibration_error(probs_all.cpu(), targets_all.cpu(),
                                                p=2, debias=True, mode='top-label')
    print(' * Prec@1 {top1.avg:.3f}'.format(top1=top1))
    print(' * Prec@5 {top5.avg:.3f}'.format(top5=top5))
    print(' * Val loss {losses.avg:.3f}'.format(losses=losses))
    print(f' * Predictive entropy {pred_ent:.4f}')
    print(f' * Average confidence {avg_conf:.4f}')
    print(f' * ECE {ece:.4f}')
    print(f' * Debiased ECE p=1 {debiased_ece_l1:.5f}, p=2 {debiased_ece_l2:.5f}')
    print(f' * std(ECE per class) {np.std(ece_per_class)}')
    if args.use_temp or args.use_calib:
        print(f' * Average temp {temps_all.mean():.4f}, min temp {temps_all.min():.4f}, '
              f'max temp {temps_all.max():.4f}, avg temp {temps_all.mean():.4f}, '
              f'std temp {temps_all.std():.4f}')
    if debug_mode:
        if compute_prr:
            count_thres, accs, prr = utils.accuracy_against_counts(probs_all, targets_all)
            directory = "runs/%s/" % (args.name)
            # The run directory may not exist yet (no checkpoint saved so far).
            os.makedirs(directory, exist_ok=True)
            np_path = os.path.join(directory, 'acc_against_conf.npy')
            prr_path = os.path.join(directory, 'prr.npy')
            np.save(np_path, np.vstack([count_thres, accs]))
            np.save(prr_path, np.array([prr]))
            return
        else:
            import pdb
            pdb.set_trace()
    if save_path_optimal_temps is not None:
        utils.optimal_individual_temp(probs_all, targets_all, min_temp=args.min_temp,
                                      max_iter=10000, verbose=1000, eta=0.01, overwrite_correct_preds=False,
                                      save_path=save_path_optimal_temps)
    # log to TensorBoard
    if args.local_rank == 0 and args.tensorboard:
        log_value(f'{tag}_loss', losses.avg, epoch)
        log_value(f'{tag}_acc', top1.avg, epoch)
        log_value(f'{tag}_acc_top5', top5.avg, epoch)
        log_value(f'{tag}_ece', ece.item(), epoch)
        log_value(f'{tag}_predictive_entropy', pred_ent.item(), epoch)
        log_value(f'{tag}_average_confidence', avg_conf.item(), epoch)
        log_value(f'{tag}_ece_debiased/L1', debiased_ece_l1, epoch)
        log_value(f'{tag}_ece_debiased/L2', debiased_ece_l2, epoch)
        if args.use_temp:
            log_value(f'{tag}_temp/mean', temps_all.mean().item(), epoch)
            log_value(f'{tag}_temp/std', temps_all.std().item(), epoch)
            log_value(f'{tag}_temp/min', temps_all.min().item(), epoch)
            log_value(f'{tag}_temp/max', temps_all.max().item(), epoch)
        elif args.use_calib:
            for i in range(args.num_temps):
                log_value(f'{tag}_temp/{i}/mean', temps_all[:, :, i].mean().item(), epoch)
                log_value(f'{tag}_temp/{i}/std', temps_all[:, :, i].std().item(), epoch)
                log_value(f'{tag}_temp/{i}/min', temps_all[:, :, i].min().item(), epoch)
                log_value(f'{tag}_temp/{i}/max', temps_all[:, :, i].max().item(), epoch)
    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Save ``state`` to ``runs/<args.name>/<filename>`` via ``torch.save``.

    Reads the module-level ``args`` that ``main()`` sets.

    Args:
        state: object to serialize (typically a dict with 'epoch',
            'state_dict', 'best_prec1', ...).
        is_best: if True, also copy the file to ``model_best.pth.tar`` in the
            same directory.
        filename: file name inside the run directory.
    """
    directory = os.path.join('runs', args.name)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, filename)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(directory, 'model_best.pth.tar'))

class AverageMeter(object):
    """Keep the latest value plus a count-weighted running mean.

    Attributes:
        val: value passed to the most recent ``update``.
        sum: running total of ``val * n`` across updates.
        count: total weight ``n`` accumulated so far.
        avg: ``sum / count`` (0 until the first update).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and recompute the mean."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Compute precision@k (in percent) for each k in ``topk``.

    Args:
        output: scores indexed as (batch, class); higher means more likely.
        target: integer class labels, one per row of ``output``.
        topk: k values to evaluate.

    Returns:
        List of 0-dim tensors, one per k: the percentage of rows whose target
        is among the k highest-scoring classes. A k larger than the number of
        classes counts every class, i.e. 100%.
    """
    # Clamp so topk() does not fail when k exceeds the number of classes.
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: `correct` derives from a transposed (non-contiguous)
        # tensor, so view(-1) raises for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def get_criterion(loss, device):
    """Return the loss module named ``loss`` moved to ``device``.

    Args:
        loss: 'mse' for ``nn.MSELoss`` (mean reduction) or 'cross_entropy' for
            ``nn.NLLLoss``. NLLLoss expects log-probabilities as input.
        device: target device for the module.

    Returns:
        The loss module, or None for an unrecognized name.
    """
    if loss == 'mse':
        # reduction='mean' is the non-deprecated spelling of size_average=True.
        return nn.MSELoss(reduction='mean').to(device)
    if loss == 'cross_entropy':
        return nn.NLLLoss().to(device)
    return None

if __name__ == '__main__':
    # Script entry point: run main() only when executed directly, not on import.
    main()
