import os
import torch
ch = torch
from tqdm import tqdm
from torchvision.utils import make_grid

from robustness.tools import helpers


def eval_model(args, model, loader, logger):
    """
    Evaluate a model for standard (natural) accuracy on a validation loader.

    Only the natural (non-adversarial) pass is run; the adversarial and
    training entries of the returned dict are NaN placeholders.

    Args:
        args (object) : A list of arguments---should be a python object
            implementing ``getattr()`` and ``setattr()``. Passed through
            unchanged to ``_model_loop``.
        model (AttackerModel) : model to evaluate. It must not already be
            wrapped in ``DataParallel``; it is wrapped here.
        loader (iterable) : a dataloader serving `(input, label)` batches from
            the validation set
        logger : object exposing an ``info(str)`` method, used to report
            progress and results.

    Returns:
        dict with keys ``epoch`` (always 0), ``nat_prec1``, ``adv_prec1``,
        ``nat_loss``, ``adv_loss``, ``train_prec1`` and ``train_loss``.
        Only ``nat_prec1`` and ``nat_loss`` are measured; the others are NaN.
    """
    assert not hasattr(model, "module"), "model is already in DataParallel."
    model = ch.nn.DataParallel(model)

    # Single natural pass: 'val' mode, no optimizer, epoch 0, adv=False.
    prec1, nat_loss = _model_loop(args, 'val', loader, model, None, 0, False,
                                  logger)

    # No adversarial or training pass is performed here, so these stay NaN.
    nan = float('nan')
    return {
        'epoch': 0,
        'nat_prec1': prec1,
        'adv_prec1': nan,
        'nat_loss': nat_loss,
        'adv_loss': nan,
        'train_prec1': nan,
        'train_loss': nan,
    }

def _model_loop(args, loop_type, loader, model, opt, epoch, adv, logger):
    """
    Run one pass of ``model`` over ``loader`` and report loss and accuracy.

    Besides the running loss and top-1/top-5 meters, every label and argmax
    prediction is collected and saved as ``labels.pt`` / ``preds.pt`` under
    ``os.path.join(args.out_dir, args.resume.split('/')[-2])``. Per-class
    sample counts, correct counts and accuracies are also logged.

    Note: no backward pass or optimizer step is taken, even in 'train'
    mode; ``loop_type`` only switches the model between train/eval mode.

    Args:
        args (object) : arguments object. Reads ``num_classes``, ``out_dir``
            and ``resume``, and optionally ``regularizer`` and
            ``iteration_hook``.
        loop_type (str) : 'train' or 'val'.
        loader (iterable) : yields ``(input, label)`` batches; must support
            ``len()`` (used for the progress bar).
        model : called as ``model(inp, target=..., make_adv=...)`` and
            expected to return a 2-tuple ``(output, final_inp)``.
        opt : unused here; kept for interface compatibility.
        epoch (int) : epoch number, used only for display and logging.
        adv (bool) : forwarded to the model as ``make_adv``. It also selects
            the 'Nat'/'Adv' labels in the logs.
        logger : object exposing an ``info(str)`` method.

    Returns:
        tuple ``(top1.avg, losses.avg)`` accumulated over the whole loader.

    Raises:
        ValueError: if ``loop_type`` is not 'train' or 'val'.
    """
    if loop_type not in ('train', 'val'):
        err_msg = "loop_type ({0}) must be 'train' or 'val'".format(loop_type)
        raise ValueError(err_msg)
    is_train = (loop_type == 'train')

    losses = helpers.AverageMeter()
    top1 = helpers.AverageMeter()
    top5 = helpers.AverageMeter()

    prec = 'AdvPrec' if adv else 'NatPrec'
    loop_msg = 'Train' if is_train else 'Val'

    # switch to train/eval mode depending on the loop type
    model = model.train() if is_train else model.eval()

    train_criterion = ch.nn.CrossEntropyLoss()
    attack_kwargs = {}

    accum_per_class = {class_id: 0 for class_id in range(args.num_classes)}
    correct_per_class = {class_id: 0 for class_id in range(args.num_classes)}
    all_labels, all_preds = [], []

    iterator = tqdm(enumerate(loader), total=len(loader))
    for i, (inp, target) in iterator:
        target = target.cuda(non_blocking=True)

        output, _ = model(inp, target=target, make_adv=adv, **attack_kwargs)
        loss = train_criterion(output, target)
        if len(loss.shape) > 0:
            loss = loss.mean()
        # Record the loss independently of the accuracy computation below,
        # so a failure there does not drop this batch from the loss average.
        losses.update(loss.item(), inp.size(0))

        model_logits = output[0] if isinstance(output, tuple) else output
        # argmax on the raw logits: applying sigmoid first can saturate
        # several large logits to exactly 1.0 (float32), creating ties that
        # make these predictions disagree with the top-1 metric.
        preds = ch.argmax(model_logits, dim=1)

        # Move each batch to the CPU once rather than once per element.
        labels_cpu = target.detach().cpu()
        preds_cpu = preds.detach().cpu()
        for label, pred in zip(labels_cpu.tolist(), preds_cpu.tolist()):
            accum_per_class[int(label)] += 1
            if int(label) == int(pred):
                correct_per_class[int(label)] += 1
        all_labels.append(labels_cpu)
        all_preds.append(preds_cpu)

        # measure accuracy (best effort: a failure is logged, not raised)
        top1_acc = float('nan')
        top5_acc = float('nan')
        try:
            maxk = min(5, model_logits.shape[-1])
            prec1, prec5 = helpers.accuracy(model_logits, target, topk=(1, maxk))
            prec1, prec5 = prec1[0], prec5[0]
        except Exception as exc:
            logger.info(f'Failed to calculate the accuracy: {exc!r}')
        else:
            top1.update(prec1, inp.size(0))
            top5.update(prec5, inp.size(0))
            top1_acc = top1.avg
            top5_acc = top5.avg

        # The regularizer value is only reported; no backward pass uses it.
        reg_term = 0.0
        if helpers.has_attr(args, "regularizer"):
            reg_term = args.regularizer(model, inp, target)

        # ITERATOR
        desc = ('{2} Epoch:{0} | Loss {loss.avg:.4f} | '
                '{1}1 {top1_acc:.3f} | {1}5 {top5_acc:.3f} | '
                'Reg term: {reg} ||'.format(epoch, prec, loop_msg,
                loss=losses, top1_acc=top1_acc, top5_acc=top5_acc, reg=reg_term))

        # USER-DEFINED HOOK
        if helpers.has_attr(args, 'iteration_hook'):
            args.iteration_hook(model, i, loop_type, inp, target)

        iterator.set_description(desc)
        iterator.refresh()

    _save_predictions(args, all_labels, all_preds)
    _log_per_class_accuracy(logger, accum_per_class, correct_per_class)

    prec_type = 'adv' if adv else 'nat'
    for d, v in zip(['loss', 'top1', 'top5'], [losses, top1, top5]):
        logger.info(
            f"{'_'.join([prec_type, loop_type, d])}, {v.avg}, {epoch}")

    return top1.avg, losses.avg


def _save_predictions(args, labels, preds):
    """Concatenate per-batch label/prediction tensors and save them as
    ``labels.pt`` / ``preds.pt`` under the run directory derived from
    ``args.out_dir`` and ``args.resume``; an empty list saves an empty tensor."""
    save_dir = os.path.join(args.out_dir, args.resume.split('/')[-2])
    # torch.save does not create missing parent directories.
    os.makedirs(save_dir, exist_ok=True)
    labels_t = torch.cat(labels) if labels else torch.empty(0, dtype=torch.long)
    preds_t = torch.cat(preds) if preds else torch.empty(0, dtype=torch.long)
    torch.save(labels_t, os.path.join(save_dir, 'labels.pt'))
    torch.save(preds_t, os.path.join(save_dir, 'preds.pt'))


def _log_per_class_accuracy(logger, accum_per_class, correct_per_class):
    """Log per-class sample counts, correct counts and accuracies; a class
    with no samples gets accuracy NaN instead of raising ZeroDivisionError."""
    acc_per_class = {
        k: (correct_per_class[k] / accum_per_class[k])
        if accum_per_class[k] else float('nan')
        for k in correct_per_class
    }
    logger.info(f'Accum per class {accum_per_class}')
    logger.info(f'Correct per class {correct_per_class}')
    logger.info(f'Accuracy per class {acc_per_class}')
