import presets
import torch
import torch.utils.data
import torchvision
from val_robust import utils
from torch import nn

# Intent: switch off autograd profiling / NVTX range emission to avoid overhead
# during evaluation.
# NOTE(review): these two calls only construct the profiler context-manager
# objects; they are never entered (no `with` / `__enter__`), so they most likely
# have no effect at all — confirm, and wrap the evaluation in them as context
# managers if profiler control is actually required.
torch.autograd.profiler.emit_nvtx(False)
torch.autograd.profiler.profile(False)

def clamp(X, lower_limit, upper_limit):
    """Element-wise clip ``X`` into ``[lower_limit, upper_limit]``.

    The bounds must be tensors; they broadcast against ``X`` (e.g. per-channel
    ``(C, 1, 1)`` bounds for an image batch).  The upper bound is applied first
    and the lower bound last, so wherever the bounds cross the result equals
    ``lower_limit``.
    """
    return torch.maximum(torch.minimum(X, upper_limit), lower_limit)


def PGDAttack(x, y, model, attack_epsilon, attack_alpha, lower_limit, loss_fn, upper_limit, max_iters, random_init):
    """Craft adversarial inputs with L-infinity projected gradient descent.

    Starting from ``x`` (optionally plus uniform noise), take ``max_iters``
    signed-gradient ascent steps on ``loss_fn(model(adv), y)``.  After every
    step, project back into the ``attack_epsilon`` ball around ``x`` and then
    into ``[lower_limit, upper_limit]``.  With ``max_iters=1``,
    ``random_init=False`` and ``attack_alpha == attack_epsilon`` this is a
    single FGSM step.

    Args:
        x: Clean input batch. All work happens on ``x.device``.
        y: Targets passed to ``loss_fn``.
        model: Network under attack. It is switched to ``eval()`` mode.
        attack_epsilon: Perturbation radius. A Python scalar or a tensor that
            broadcasts against ``x`` (e.g. per-channel ``(C, 1, 1)``).
        attack_alpha: Step size per iteration, broadcastable like
            ``attack_epsilon``.
        lower_limit: Tensor lower bound of the valid input range,
            broadcastable against ``x``.
        loss_fn: Callable ``loss_fn(outputs, y)`` returning a scalar loss.
        upper_limit: Tensor upper bound of the valid input range,
            broadcastable against ``x``.
        max_iters: Number of gradient steps (coerced with ``int``).
        random_init: If true, start from ``x`` plus noise drawn uniformly
            from ``[-attack_epsilon, attack_epsilon)``.

    Returns:
        The adversarial batch, detached from any autograd graph.
    """
    model.eval()

    # Build the starting point without tracking gradients, so it is a leaf
    # tensor whose requires_grad flag can be switched on below.
    with torch.no_grad():
        if random_init:
            # Allocate on x's device (no hard-coded .cuda()); broadcasting
            # applies a per-channel or scalar eps without a channel loop.
            delta = (torch.rand_like(x) * 2 - 1) * attack_epsilon
        else:
            delta = torch.zeros_like(x)
        adv_imgs = clamp(x + delta, lower_limit, upper_limit)

    for _ in range(int(max_iters)):
        adv_imgs.requires_grad_(True)
        # Callers (the evaluators) run under torch.no_grad(); gradients with
        # respect to the input are still needed here.
        with torch.enable_grad():
            loss = loss_fn(model(adv_imgs), y)
            (grads,) = torch.autograd.grad(loss, adv_imgs)

        # Update and project outside autograd. This replaces the previous
        # in-place `.data` mutation and keeps the graph from growing across
        # iterations.
        with torch.no_grad():
            adv_imgs = adv_imgs + attack_alpha * torch.sign(grads)
            adv_imgs = clamp(adv_imgs, x - attack_epsilon, x + attack_epsilon)
            adv_imgs = clamp(adv_imgs, lower_limit, upper_limit)

    return adv_imgs.detach()


def evaluate(model, criterion, data_loader, device='cuda'):
    """Report clean top-1 / top-5 accuracy of ``model`` over ``data_loader``.

    Args:
        model: Classifier to evaluate. It is put into ``eval()`` mode, matching
            ``evaluate_FGSM`` / ``evaluate_PGD``.
        criterion: Unused here. It is kept so the signature mirrors the
            adversarial evaluators.
        data_loader: Iterable of ``(image, target)`` batches.
        device: Device the batches are moved to. The default ``'cuda'``
            preserves the previous hard-coded behaviour.

    Returns:
        The global average of the ``acc1`` meter, as computed by
        ``utils.accuracy``.
    """
    print('**********_evaluate_**********')
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = "Test:"

    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, 100, header):
            image = image.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            output = model(image)
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))

            # Weight each batch's accuracy by its size so a short final batch
            # does not skew the global average.
            batch_size = image.shape[0]
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)

    # NOTE(review): presumably aggregates meters across distributed processes
    # when running multi-process — confirm in val_robust.utils.
    metric_logger.synchronize_between_processes()

    print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")

    return metric_logger.acc1.global_avg

@torch.no_grad()
def _evaluate_adversarial(model, criterion, data_loader, device, print_freq, log_suffix, tta, *,
                          attack_alpha_px, max_iters, random_init):
    """Shared evaluation loop: attack each batch with ``PGDAttack``, then score it.

    The perturbation budget is ``1/255`` and the step size is
    ``attack_alpha_px``, both in ``[0, 1]`` pixel units.  Both are converted
    into the normalized input space with the ImageNet mean/std below.
    NOTE(review): this assumes the data loader normalizes with exactly these
    statistics — confirm against ``presets.ClassificationPresetEval``.

    Returns:
        The global average of the ``acc1_adv`` meter.
    """
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"Test: {log_suffix}"

    # Loop-invariant attack constants: built once, on the same device as the
    # batches (previously rebuilt every batch with a hard-coded .cuda()).
    std_imagenet = torch.tensor((0.229, 0.224, 0.225), device=device).view(3, 1, 1)
    mu_imagenet = torch.tensor((0.485, 0.456, 0.406), device=device).view(3, 1, 1)
    attack_epsilon = (1 / 255.) / std_imagenet
    attack_alpha = attack_alpha_px / std_imagenet
    # Normalized images of the [0, 1] pixel range.
    upper_limit = (1 - mu_imagenet) / std_imagenet
    lower_limit = (0 - mu_imagenet) / std_imagenet

    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        image = image.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        adv_input = PGDAttack(image, target, model, attack_epsilon, attack_alpha, lower_limit, criterion,
                              upper_limit, max_iters=max_iters, random_init=random_init)

        output_adv = model(adv_input)
        if tta:
            # Test-time augmentation: add the logits of the batch flipped along
            # dim 3 (the width axis, assuming NCHW input).
            output_adv += model(torch.flip(adv_input, dims=[3]))

        acc1_adv, acc5_adv = utils.accuracy(output_adv, target, topk=(1, 5))

        batch_size = image.shape[0]
        metric_logger.meters["acc1_adv"].update(acc1_adv.item(), n=batch_size)
        metric_logger.meters["acc5_adv"].update(acc5_adv.item(), n=batch_size)

    metric_logger.synchronize_between_processes()

    print(
        f"{header} Acc@1_adv {metric_logger.acc1_adv.global_avg:.3f} Acc@5_adv {metric_logger.acc5_adv.global_avg:.3f}")

    return metric_logger.acc1_adv.global_avg


@torch.no_grad()
def evaluate_FGSM(model, criterion, data_loader, device='cuda', print_freq=100, log_suffix="", tta=False):
    """Accuracy under a single-step FGSM attack (eps = step = 1/255, no random start).

    Args:
        model: Classifier to attack and evaluate. It is put into ``eval()`` mode.
        criterion: Loss maximized by the attack.
        data_loader: Iterable of ``(image, target)`` batches.
        device: Device for batches and attack constants.
        print_freq: Logging interval passed to ``MetricLogger.log_every``.
        log_suffix: Text appended to the ``"Test: "`` log header.
        tta: If true, also add the logits of the flipped adversarial batch.

    Returns:
        Global average adversarial top-1 accuracy (``acc1_adv`` meter).
    """
    print('**********_evaluate_FGSM**********')
    return _evaluate_adversarial(model, criterion, data_loader, device, print_freq, log_suffix, tta,
                                 attack_alpha_px=1 / 255., max_iters=1, random_init=False)


@torch.no_grad()
def evaluate_PGD(model, criterion, data_loader, device='cuda', print_freq=100, log_suffix="", tta=False):
    """Accuracy under a 5-step PGD attack (eps = 1/255, step = 0.5/255, random start).

    Args:
        model: Classifier to attack and evaluate. It is put into ``eval()`` mode.
        criterion: Loss maximized by the attack.
        data_loader: Iterable of ``(image, target)`` batches.
        device: Device for batches and attack constants.
        print_freq: Logging interval passed to ``MetricLogger.log_every``.
        log_suffix: Text appended to the ``"Test: "`` log header.
        tta: If true, also add the logits of the flipped adversarial batch.

    Returns:
        Global average adversarial top-1 accuracy (``acc1_adv`` meter).
    """
    print('**********_evaluate_PGD**********')
    return _evaluate_adversarial(model, criterion, data_loader, device, print_freq, log_suffix, tta,
                                 attack_alpha_px=0.5 / 255., max_iters=5, random_init=True)

def load_data(valdir):
    """Build the evaluation dataset rooted at ``valdir`` and an in-order sampler.

    Images are loaded with ``torchvision.datasets.ImageFolder`` and passed
    through ``presets.ClassificationPresetEval`` configured with
    ``resize_size=256`` and ``crop_size=224``.

    Returns:
        ``(dataset, sampler)``, where ``sampler`` is a ``SequentialSampler``
        over the dataset.
    """
    eval_transform = presets.ClassificationPresetEval(resize_size=256, crop_size=224)
    dataset = torchvision.datasets.ImageFolder(valdir, eval_transform)
    sampler = torch.utils.data.SequentialSampler(dataset)
    return dataset, sampler

def eval_imgnet_adv(net, batch_size, num_workers, location, attack_type):
    """Evaluate ``net`` on the validation set at ``location``: clean, FGSM and PGD.

    Args:
        net: Classifier. It is moved to CUDA and put into ``eval()`` mode.
        batch_size: Batch size for the ``DataLoader``.
        num_workers: Number of ``DataLoader`` worker processes.
        location: Root directory of the validation images.
        attack_type: Must be ``'FGSM'`` or ``'PGD'`` (checked with ``assert``).
            NOTE(review): the value is only validated — both attacks are run
            regardless of which one is requested.

    Returns:
        ``(clean_acc1, fgsm_acc1, pgd_acc1)``.
    """
    assert attack_type in ('FGSM', 'PGD')
    net.cuda()
    net.eval()

    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.0)

    dataset, sampler = load_data(location)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=True,
    )

    clean_banner = "*************ImageNet-1k Results*****************"
    adv_banner = "*************ImageNet-adv Results*****************"

    print(clean_banner)
    clean_acc = evaluate(net, loss_fn, loader)
    print(clean_banner)

    print(adv_banner)
    fgsm_acc = evaluate_FGSM(net, loss_fn, loader)
    print(adv_banner)

    print(adv_banner)
    pgd_acc = evaluate_PGD(net, loss_fn, loader)
    print(adv_banner)

    return clean_acc, fgsm_acc, pgd_acc