"""Evaluates the model"""

import argparse
import logging
import torch

from torch.autograd import Variable
import utils
from my_loss_function import loss_label_smoothing, loss_kd_regularization, loss_kd, loss_kd_self, loss_pseudo_kd, loss_CE, loss_pseudo_kd_self

# from train_kd import AverageMeter

parser = argparse.ArgumentParser()
parser.add_argument('--model_dir', default='experiments/base_model',
                    help="Directory of params.json")
# Implicit string concatenation keeps source indentation out of the help text
# (a backslash line-continuation inside the literal embedded it verbatim).
parser.add_argument('--restore_file', default='best',
                    help="name of the file in --model_dir "
                         "containing weights to load")


class AverageMeter:
    """Keep the most recent value of a metric plus its count-weighted running mean.

    Args:
        name: label used as the prefix when the meter is printed.
        fmt: format-spec suffix (e.g. ``':6.2f'``) applied to both the latest
            value and the running average in ``__str__``.
    """

    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Set the latest value, average, running sum and sample count back to 0."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and recompute the running mean.

        Raises ZeroDivisionError if the accumulated count is still 0 afterwards
        (e.g. the very first call passes ``n=0``).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        """Render as ``'<name> <val> (<avg>)'``, formatting both numbers with ``fmt``."""
        spec = self.fmt
        template = '{name} {val%s} ({avg%s})' % (spec, spec)
        return template.format(**vars(self))

def evaluate(model, loss_fn, dataloader, params, args, name, device='cuda'):
    """Evaluate ``model`` on every batch of ``dataloader`` and log the metrics.

    Args:
        model: (torch.nn.Module) the neural network; switched to eval mode here.
        loss_fn: loss callable invoked as ``loss_fn(outputs, labels, params)``.
            It is only used when ``args.regularization`` or ``args.pseudo_kd``
            is set; every other mode reports ``loss_CE`` instead.
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        params: (Params) hyperparameters, forwarded to the loss function.
        args: parsed command-line flags; ``regularization`` / ``pseudo_kd``
            select which loss is reported.
        name: label inserted into the log line (e.g. the split being evaluated).
        device: device the batches are moved to (defaults to ``'cuda'``, the
            previous hard-coded behaviour).

    Returns:
        dict with ``'accuracy'`` (percent, 0.0 for an empty dataloader),
        ``'loss'`` (sample-weighted mean loss), ``'acc@1'`` and ``'acc@5'``
        (percent, from ``accuracy``).
    """
    model.eval()

    # Regularization and pseudo-KD modes score with the supplied loss; every
    # other mode (including pseudo_kd_beta) reports plain cross-entropy.
    # Decided once instead of reassigning loss_fn on every batch.
    eval_loss_fn = loss_fn if (args.regularization or args.pseudo_kd) else loss_CE

    losses = utils.AverageMeter()
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    total = 0
    correct = 0

    # Evaluation needs no autograd graph; building one only wastes memory/time.
    with torch.no_grad():
        for data_batch, labels_batch in dataloader:
            # `non_blocking` replaces the old `async=` kwarg, which has been a
            # reserved word (SyntaxError) since Python 3.7.
            data_batch = data_batch.to(device, non_blocking=True)
            labels_batch = labels_batch.to(device, non_blocking=True)
            batch_size = labels_batch.size(0)

            output_batch = model(data_batch)
            loss = eval_loss_fn(output_batch, labels_batch, params)
            losses.update(loss.detach(), batch_size)

            _, predicted = output_batch.max(1)
            total += batch_size
            correct += predicted.eq(labels_batch).sum().item()

            acc1, acc5 = accuracy(output_batch, labels_batch, topk=(1, 5))
            top1.update(acc1[0], batch_size)
            top5.update(acc5[0], batch_size)

    loss_avg = losses.avg
    # Guard against an empty dataloader instead of raising ZeroDivisionError.
    acc = 100. * correct / total if total else 0.0
    logging.info(f"- Eval {name} metrics, acc:{acc:.4f}, loss: {loss_avg:.4f}, acc@1: {top1.avg:.4f}, acc@5: {top5.avg:.4f}")
    return {'accuracy': acc, 'loss': loss_avg, 'acc@5': top5.avg, 'acc@1': top1.avg}


"""
This function duplicates "evaluate()" but skips the loss computation purely to speed up evaluation.
As a result, the validation loss reported in KD mode is always 0.
To restore it, compute the loss from the teacher outputs fetched during evaluation (see train.py).
"""
def evaluate_kd(model, dataloader, params, device='cuda'):
    """Evaluate ``model`` on ``dataloader``, reporting accuracy only.

    The loss is intentionally not computed (for speed), so the reported loss is
    always 0.0.

    Args:
        model: (torch.nn.Module) the neural network; switched to eval mode here.
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        params: (Params) hyperparameters; unused here, kept for interface
            compatibility with the other evaluation entry points.
        device: device the batches are moved to (defaults to ``'cuda'``, the
            previous hard-coded behaviour).

    Returns:
        dict with ``'accuracy'`` (percent, 0.0 for an empty dataloader) and
        ``'loss'`` (always 0.0).
    """
    model.eval()
    total = 0
    correct = 0
    # Bound before the loop so an empty dataloader cannot leave it undefined.
    loss = 0.0

    # Evaluation needs no autograd graph; building one only wastes memory/time.
    with torch.no_grad():
        for data_batch, labels_batch in dataloader:
            # `non_blocking` replaces the old `async=` kwarg, which has been a
            # reserved word (SyntaxError) since Python 3.7.
            data_batch = data_batch.to(device, non_blocking=True)
            labels_batch = labels_batch.to(device, non_blocking=True)

            output_batch = model(data_batch)
            _, predicted = output_batch.max(1)
            total += labels_batch.size(0)
            correct += predicted.eq(labels_batch).sum().item()

    acc = 100. * correct / total if total else 0.0
    logging.info("- Eval metrics, acc:{acc:.4f}, loss: {loss:.4f}".format(acc=acc, loss=loss))
    return {'accuracy': acc, 'loss': loss}

def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) of ``output`` against ``target``.

    Args:
        output: (Tensor) 2-D scores, one row per sample and one column per class.
        target: (Tensor) integer class labels, one per sample.
        topk: sequence of k values to evaluate.

    Returns:
        list with one 1-element float tensor per k, in the order of ``topk``.
        A k larger than the number of classes yields 100% (every label is
        trivially within the top-k) instead of raising from ``Tensor.topk``.
    """
    with torch.no_grad():
        batch_size = target.size(0)
        # Clamp so that e.g. top-5 on a 3-class problem does not crash topk().
        maxk = min(max(topk), output.size(1))

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # row i holds every sample's i-th highest-scoring class
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res