import torch.backends.cudnn as cudnn
import torch
import torch.utils.data
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
from advertorch.context import ctx_noparamgrad_and_eval
from advertorch.attacks import LinfPGDAttack, L2PGDAttack
import argparse
import os
import sys
sys.path.append('./')
from utils.prepare_dataset import *
from utils.misc import init_random_seed
from utils.test_helpers import build_model
from utils.DANN_model import DANNWrapper


# (mean, std) statistics shared by the three-channel normalization steps below.
NORM = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Evaluation preprocessing: tensor conversion followed by normalization.
te_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORM[0], NORM[1]),
])

# Training preprocessing: padded random crop and random horizontal flip, then
# the same tensor conversion and normalization used for evaluation.
tr_transforms = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(NORM[0], NORM[1]),
])

# MNIST preprocessing: resize to 32x32, then single-channel normalization.
mnist_transforms = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Corruption-type names, grouped five per line.
common_corruptions = [
    'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', 'glass_blur',
    'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
    'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression',
]

def test_single(model, inputs, label):
    """Classify ``inputs`` with ``model`` and count predictions equal to ``label``.

    The model is switched to eval mode and the forward pass runs without
    gradient tracking. The predicted class indices are printed.

    Args:
        model: Module whose output can be reduced with ``.max(1)``.
        inputs: Batch of inputs; moved to the model's device before the
            forward pass.
        label: Expected class index (tensor or plain number). A tensor is
            moved to the prediction's device before the comparison.

    Returns:
        A scalar tensor holding the number of predictions equal to ``label``.
    """
    model.eval()
    # Follow the model's own device instead of assuming CUDA, so the helper
    # also works on CPU-only machines and on non-default GPUs. A model with no
    # parameters leaves the inputs where they already are.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else inputs.device
    with torch.no_grad():
        outputs = model(inputs.to(device))
        _, predicted = outputs.max(1)
        print(predicted)

    # Comparing tensors that live on different devices raises, so align the
    # label with the prediction first.
    if torch.is_tensor(label):
        label = label.to(predicted.device)
    return torch.sum(predicted == label)


# This is an evaluation helper, not a test case: opt out of pytest's
# name-based collection of ``test_*`` functions.
test_single.__test__ = False


# Perturb the target data within the l-inf ball to achieve the worst DANN loss (not
# just the prediction loss). Note: we don't need to compute the source-domain
# loss, since the perturbation is applied only to the target data.
def dann_perturb(x, y, alpha, model, nb_iter = 7, eps_iter = 4/255, eps =  16/255):
    """Perturb ``x`` with iterative signed-gradient steps on the domain loss.

    Each iteration moves the perturbation by ``eps_iter`` along the sign of
    the gradient of the cross-entropy between the model's domain output and
    an all-ones domain label, then projects it back into the ``eps`` l-inf
    ball and keeps ``x + delta`` inside ``[-1, 1]``. The classification loss
    is not part of the objective.

    Args:
        x: Input batch; it is copied, never modified in place.
        y: Labels for ``x``. Only used as the template (shape/dtype) for the
            all-ones domain label.
        alpha: Forwarded to the model as its ``alpha`` keyword argument.
        model: Called as ``model(input_data=..., alpha=...)`` and expected to
            return ``(class_output, domain_output)``. Switched to eval mode.
        nb_iter: Number of gradient steps.
        eps_iter: Step size of each iteration.
        eps: Radius of the l-inf ball constraining the perturbation.

    Returns:
        The perturbed batch ``x + delta`` as a detached tensor on the model's
        device.
    """
    loss_domain = torch.nn.CrossEntropyLoss()

    # Work on the model's device rather than assuming CUDA; a model without
    # parameters leaves the data where it already is.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else x.device

    # Note: x, y are both target data.
    x = x.detach().clone().to(device)
    y = y.detach().to(device)
    # yd: domain label of the data (all ones, same shape and dtype as y).
    yd = torch.ones_like(y)

    model.eval()
    delta = torch.zeros_like(x, requires_grad=True)

    # Anomaly detection is scoped to this attack loop so the global autograd
    # state is restored on exit instead of staying enabled for the whole run.
    with torch.autograd.set_detect_anomaly(True):
        for _ in range(nb_iter):
            _, domain_output = model(input_data=x + delta, alpha=alpha)
            err = loss_domain(domain_output, yd)
            # Differentiate w.r.t. delta only: backward() would also
            # accumulate stale gradients into every model parameter.
            (grad,) = torch.autograd.grad(err, delta)

            with torch.no_grad():
                delta.add_(eps_iter * grad.sign())
                # Project back into the eps ball ...
                delta.clamp_(min=-eps, max=eps)
                # ... and keep the perturbed input inside [-1, 1].
                delta.copy_(torch.clamp(x + delta, min=-1, max=1) - x)
    return (x + delta).detach()

# prepare the dann perturbed data for a dataset. 
def prepare_dann_pgd_attack_data(args, data_loader, model, name, alpha =1, train = False,
                            eps = 16/255, nb_iter = 7, seed = 0):
    """Perturb every batch of ``data_loader`` and save the result to disk.

    Each batch is perturbed with ``dann_perturb`` (step size fixed at
    ``4/255``). The perturbed inputs and their labels are collected and
    written to ``./attack_data/<name>/train.npy`` or ``.../test.npy``; the
    model's accuracy on the clean and on the perturbed inputs is printed.

    Args:
        args: Unused; kept so existing callers keep working.
        data_loader: Iterable yielding ``(inputs, labels)`` tensor batches.
        model: Called as ``model(inputs, alpha)`` and expected to return
            ``(class_output, domain_output)``. Switched to eval mode.
        name: Sub-directory of ``./attack_data`` receiving the file.
        alpha: Forwarded to the model and to ``dann_perturb``.
        train: Selects the output file name (``train`` if true, else ``test``).
        eps: l-inf radius of the perturbation.
        nb_iter: Number of attack iterations.
        seed: Passed to ``init_random_seed`` before processing.

    Returns:
        ``(cls_acc, adv_acc)``: accuracy on clean and on perturbed inputs.

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    init_random_seed(seed)
    model.eval()
    # Follow the model's device instead of assuming CUDA; a model without
    # parameters leaves each batch where the loader put it.
    first_param = next(model.parameters(), None)

    inputs_new = []
    labels_new = []
    n_correct = 0
    n_correct_adv = 0
    for inputs, labels in data_loader:
        device = first_param.device if first_param is not None else inputs.device
        inputs_cls, labels_cls = inputs.to(device), labels.to(device)

        # Create adversarial data
        inputs_adv = dann_perturb(inputs_cls, labels_cls, alpha, model, nb_iter, eps_iter = 4/255, eps =  eps)

        inputs_new.extend(inputs_adv.cpu().numpy())
        labels_new.extend(labels_cls.cpu().numpy())

        # Pure evaluation: no autograd graph is needed for the accuracy counts.
        with torch.no_grad():
            outputs_cls, _ = model(inputs_cls, alpha)
            outputs_adv, _ = model(inputs_adv, alpha)
        _, predicted_cls = outputs_cls.max(1)
        _, predicted_adv = outputs_adv.max(1)
        n_correct += int(torch.sum(predicted_cls == labels_cls))
        n_correct_adv += int(torch.sum(predicted_adv == labels_cls))

    n_total = len(labels_new)
    print('Process {}'.format(n_total))
    if n_total == 0:
        # Without samples the accuracies are undefined; fail with a clear
        # message rather than an obscure error further down.
        raise ValueError('data_loader yielded no samples')

    cls_acc = np.float64(n_correct) / n_total
    print('Classification Accuracy:',
          cls_acc)
    adv_acc = np.float64(n_correct_adv) / n_total
    print('Adversarial Accuracy:',
          adv_acc)

    out_dir = './attack_data/{}'.format(name)
    os.makedirs(out_dir, exist_ok=True)

    # Images and labels have different shapes, so [inputs_new, labels_new] is a
    # ragged sequence that NumPy >= 1.24 refuses to convert implicitly. Build
    # the (2, N) object array explicitly: row 0 holds the perturbed inputs,
    # row 1 the labels.
    payload = np.empty((2, n_total), dtype=object)
    for idx, (sample, target) in enumerate(zip(inputs_new, labels_new)):
        payload[0, idx] = sample
        payload[1, idx] = target
    np.save('{}/{}'.format(out_dir, 'train' if train else 'test'), payload)
    return cls_acc, adv_acc

def str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"``, as true; this converter reads the text instead.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays false, as it was under ``type=bool``.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='cifar10')
    parser.add_argument('--level', default=0, type=int)
    parser.add_argument('--corruption', default='original')
    parser.add_argument('--dataroot', default='/nobackup/yguo/datasets/')
    parser.add_argument('--shared', default='layer2')
    ########################################################################
    parser.add_argument('--depth', default=26, type=int)
    parser.add_argument('--width', default=1, type=int)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--group_norm', default=8, type=int)
    parser.add_argument('--fix_bn', default=False, type=str2bool)
    parser.add_argument('--fix_ssh', default=False, type=str2bool)
    # Checkpoint to attack; the default is the path that used to be hard-coded.
    parser.add_argument('--ckpt', default='./results/cifar10_layer2_gn_expand/ckpt0.pth')

    args = parser.parse_args()
    # Only the first element returned by build_model is used here, so the
    # model is built once.
    model, _, _, _ = build_model(args)

    name = "DANN_attack"
    ckpt = torch.load(args.ckpt)
    model.load_state_dict(ckpt['net'])
    model = DANNWrapper(model)
    model.cuda()
    # Prepare Test dataset
    _, data_loader = prepare_test_data(args)
    prepare_dann_pgd_attack_data(args, data_loader, model, name)
