import torchattacks
import torch
import torch.nn as nn
from utils import AverageMeter, accuracy_top1, accuracy
from tqdm import tqdm

def _build_attack(args, model):
    """Construct the torchattacks attack selected by ``args.eval_method``.

    Args:
        args: namespace providing ``eval_method``, ``constraint``, ``eps``,
            ``num_steps`` and ``random_restarts``.
        model: the classifier the attack will target.

    Returns:
        A torchattacks attack object that is callable as ``attack(inputs, labels)``.

    Raises:
        ValueError: if ``args.eval_method`` is not one of
            'DIFGSM', 'APGD', 'FAB' or 'Square'.
    """
    method = args.eval_method
    if method == 'DIFGSM':
        # DIFGSM takes no ``norm`` argument. ``random_start`` is forwarded the
        # restart count unchanged (the original behaviour); presumably only
        # its truthiness matters there -- confirm against torchattacks.
        return torchattacks.DIFGSM(model, eps=args.eps, random_start=args.random_restarts, steps=args.num_steps)
    if method == 'APGD':
        return torchattacks.APGD(model, norm=args.constraint, eps=args.eps, steps=args.num_steps, n_restarts=args.random_restarts, loss='ce')
    if method == 'FAB':
        # NOTE(review): n_classes is hard-coded to 10 -- confirm it matches the dataset.
        return torchattacks.FAB(model, norm=args.constraint, eps=args.eps, steps=args.num_steps, n_restarts=args.random_restarts, targeted=True, n_classes=10)
    if method == 'Square':
        return torchattacks.Square(model, norm=args.constraint, eps=args.eps, n_queries=5000, n_restarts=args.random_restarts)
    # Fail fast: previously an unknown method left ``attack`` unbound and the
    # function crashed later with an unrelated-looking error.
    raise ValueError(
        "Unsupported eval_method {!r}; expected one of "
        "'DIFGSM', 'APGD', 'FAB', 'Square'".format(method))


def torchattack(args, model, data_loader, writer=None, epoch=0, loop_type='test'):
    """Evaluate ``model`` on adversarial examples crafted by a torchattacks attack.

    Args:
        args: namespace providing ``eval_method`` ('DIFGSM', 'APGD', 'FAB' or
            'Square'), ``constraint``, ``eps``, ``num_steps`` and
            ``random_restarts``.
        model: classifier under attack. It is switched to eval mode and left
            in eval mode on return.
        data_loader: iterable of ``(inputs, targets)`` batches that supports
            ``len()``. Batches are moved to CUDA.
        writer: optional object with ``add_scalar`` (e.g. a TensorBoard
            ``SummaryWriter``). When given, the final loss and accuracy are
            logged under ``adv_<loop_type>_loss`` / ``adv_<loop_type>_accuracy``.
        epoch: step value passed to ``writer.add_scalar``.
        loop_type: tag used in the progress-bar text and the scalar names.

    Returns:
        ``(avg_loss, avg_accuracy, attack_name)``. The averages are the final
        ``AverageMeter.avg`` values, updated with each batch's size.
        ``attack_name`` has the form ``'<method>-<constraint>-<eps>'``.

    Raises:
        ValueError: if ``args.eval_method`` is not supported.
    """
    model.eval()
    attack = _build_attack(args, model)

    loss_logger = AverageMeter()
    acc_logger = AverageMeter()
    attack_name = '{}-{}-{}'.format(args.eval_method, args.constraint, args.eps)
    criterion = nn.CrossEntropyLoss()  # loop-invariant, build once

    iterator = tqdm(data_loader, total=len(data_loader), ncols=110)
    for inp, target in iterator:
        inp = inp.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # The attack needs gradients w.r.t. the input, so it runs with autograd on.
        inp_adv = attack(inp, target)

        # Scoring the adversarial batch needs no autograd graph.
        with torch.no_grad():
            logits = model(inp_adv)
            loss = criterion(logits, target)
            acc = accuracy_top1(logits, target)
            # Largest absolute per-element perturbation, shown in the progress bar.
            max_perturbation = (inp - inp_adv).abs().max()

        batch_size = inp.size(0)
        loss_logger.update(loss.item(), batch_size)
        acc_logger.update(acc, batch_size)

        desc = ('[{} {}] | Loss {:.4f} | Accuracy {:.4f} ||, {}'
                .format(attack_name, loop_type, loss_logger.avg, acc_logger.avg, max_perturbation))
        iterator.set_description(desc)

    if writer is not None:
        for tag, meter in (('loss', loss_logger), ('accuracy', acc_logger)):
            writer.add_scalar('adv_{}_{}'.format(loop_type, tag), meter.avg, epoch)

    return loss_logger.avg, acc_logger.avg, attack_name