import torch
import torch.nn as nn

from tqdm import tqdm
import sys
sys.path.append('..')

from utils import AverageMeter, accuracy_top1, accuracy
from attacks.step import LinfStep, L2Step

# Lookup from the ``args.constraint`` name to the step class that is
# instantiated as ``cls(orig_input, eps, step_size)`` and provides the
# ``step`` / ``project`` methods used by the attack loops below.
STEPS = {'Linf': LinfStep, 'L2': L2Step}


def batch_tf_attack(args, model, x, target):
    """Craft attacked versions of one batch with a nested iterative procedure.

    Each of the ``args.num_steps`` outer iterations does two things:

    1. Inner loop (``args.num_steps`` iterations): starting from a copy of
       the current outer iterate, repeatedly compute the input-gradient of
       ``-CrossEntropy(model(.), target)`` and hand it to an inner step
       object's ``step`` / ``project`` methods, then clamp to ``[0, 1]``.
       The inner step object is constructed around the current outer iterate.
    2. Evaluate ``+CrossEntropy`` at the inner loop's output, take the
       gradient with respect to that output, and apply it to the outer
       iterate through the outer step object (constructed around the
       original ``x``), again followed by ``project`` and a ``[0, 1]`` clamp.

    The actual update and projection rules are defined by the step class
    selected through ``args.constraint`` and are not visible in this file.

    Restart handling (``args.random_restarts``):

    * ``0``  -- a single run starting from ``x`` itself.
    * ``n >= 1`` -- ``n`` runs, each starting from the *original* ``x`` plus
      Gaussian noise (std 0.01) clamped to ``[0, 1]``. The first run's
      output is the baseline; for every later run, samples flagged by
      ``accuracy(..., topk=(1,), exact=True)`` are overwritten with that
      run's output.

    Args:
        args: Namespace-like object providing ``constraint`` (a key of
            ``STEPS``), ``eps``, ``step_size``, ``num_steps`` and
            ``random_restarts``.
        model: Callable mapping an input batch to logits.
        x: Input batch. Presumably scaled to ``[0, 1]`` given the clamping
            applied throughout -- confirm against the data pipeline.
        target: Labels accepted by ``nn.CrossEntropyLoss`` for ``x``.

    Returns:
        A detached tensor shaped like ``x`` holding the crafted inputs.

    Raises:
        ValueError: If ``args.random_restarts`` is negative.
        KeyError: If ``args.constraint`` is not a key of ``STEPS``.
    """
    if args.random_restarts < 0:
        # Previously this fell through to an empty loop and crashed with an
        # opaque AttributeError on ``None``.
        raise ValueError(
            'random_restarts must be >= 0, got {}'.format(args.random_restarts))

    orig_x = x.clone().detach()
    step_cls = STEPS[args.constraint]
    outer_step = step_cls(orig_x, args.eps, args.step_size)
    criterion = nn.CrossEntropyLoss()  # stateless, so build it once
    noise_std = 0.01

    # ``enable_grad`` is required because callers (e.g. ``tf_attack``) may
    # invoke this function from inside a ``torch.no_grad()`` context.
    @torch.enable_grad()
    def inner_attack(start):
        """Run the inner ``-CrossEntropy`` loop from ``start``."""
        center = start.clone().detach()
        inner_step = step_cls(center, args.eps, args.step_size)
        point = start
        for _ in range(args.num_steps):
            # Fresh leaf every iteration: keeps the graph one step deep and
            # guarantees ``center`` / ``start`` are never aliased by ``point``.
            point = point.clone().detach().requires_grad_(True)
            loss = -criterion(model(point), target)
            grad = torch.autograd.grad(loss, [point])[0]
            with torch.no_grad():
                point = inner_step.step(point, grad)
                point = inner_step.project(point)
                point = torch.clamp(point, 0, 1)
        return point.clone().detach()

    @torch.enable_grad()
    def run_attack(start):
        """Run the full outer loop from ``start`` and return the result."""
        point = start
        for _ in range(args.num_steps):
            point = point.clone().detach()
            probe = inner_attack(point).requires_grad_(True)

            loss = criterion(model(probe), target)
            # The gradient is taken at the inner attack's output and applied
            # to the outer iterate directly; the inner loop itself is not
            # differentiated through.
            grad = torch.autograd.grad(loss, [probe])[0]
            with torch.no_grad():
                point = outer_step.step(point, grad)
                point = outer_step.project(point)
                point = torch.clamp(point, 0, 1)
        return point.clone().detach()

    if args.random_restarts == 0:
        return run_attack(orig_x)

    best = None
    for _ in range(args.random_restarts):
        # Always perturb the ORIGINAL input so restarts are independent;
        # perturbing the previous start would compound the noise.
        noisy_start = orig_x + noise_std * torch.randn_like(orig_x)
        noisy_start = torch.clamp(noisy_start, 0, 1)

        candidate = run_attack(noisy_start)
        if best is None:
            # First restart is the baseline; selecting against itself
            # would be a no-op, so skip the extra forward pass.
            best = candidate
            continue

        # Selection only needs predictions, never gradients.
        with torch.no_grad():
            logits = model(candidate)
        # TODO: use adversarial accuracy as metric
        # NOTE(review): assumes ``exact=True`` yields a per-sample
        # correctness vector usable as a mask -- confirm in ``utils``.
        corr, = accuracy(logits, target, topk=(1,), exact=True)
        corr = corr.bool()
        best[corr] = candidate[corr]

    return best.detach().requires_grad_(False)


@torch.no_grad()
def tf_attack(args, model, loader):
    """Apply ``batch_tf_attack`` to every batch of ``loader``.

    The model is switched to eval mode. Each batch is moved to the CUDA
    device, attacked, and the crafted inputs are re-evaluated with the
    model; the running cross-entropy loss and ``accuracy_top1`` averages
    are shown in the progress-bar description.

    Args:
        args: Forwarded to ``batch_tf_attack``; ``num_steps`` and
            ``random_restarts`` are also used to build the attack name.
        model: Network under attack.
        loader: Sized iterable yielding ``(input, target)`` batches.

    Returns:
        ``(inputs, targets, name)`` where ``inputs`` and ``targets`` are the
        crafted inputs and their labels concatenated along dim 0 on the CPU,
        and ``name`` is the ``'TF-<num_steps>-<random_restarts>'`` tag.
    """
    model.eval()
    attack_name = 'TF-{}-{}'.format(args.num_steps, args.random_restarts)
    criterion = nn.CrossEntropyLoss()

    loss_meter = AverageMeter()
    acc_meter = AverageMeter()
    crafted_batches = []
    target_batches = []

    progress = tqdm(loader, total=len(loader), ncols=110)
    for inp, target in progress:
        inp = inp.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        crafted = batch_tf_attack(args, model, inp, target)
        crafted_batches.append(crafted.detach().cpu())
        target_batches.append(target.detach().cpu())

        # Score the model on the crafted inputs for the progress display.
        # TODO: use adversarial accuracy as metric
        logits = model(crafted)
        batch_size = inp.size(0)
        loss_meter.update(criterion(logits, target).item(), batch_size)
        acc_meter.update(accuracy_top1(logits, target), batch_size)

        progress.set_description(
            '[{}] | Loss {:.4f} | Accuracy {:.4f} ||'
            .format(attack_name, loss_meter.avg, acc_meter.avg))

    all_inputs = torch.cat(crafted_batches, dim=0)
    all_targets = torch.cat(target_batches, dim=0)
    return all_inputs, all_targets, attack_name






