import torch.backends.cudnn as cudnn
import torch
import torch.utils.data
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
from advertorch.context import ctx_noparamgrad_and_eval
from advertorch.attacks import LinfPGDAttack, L2PGDAttack
import argparse
import os
import sys
sys.path.append('./')
from utils.prepare_dataset import *
from utils.misc import init_random_seed
from utils.test_helpers import build_model
 

# Per-channel (mean, std). transforms.Normalize computes (x - mean) / std, so
# with 0.5 / 0.5 the [0, 1] output of ToTensor is mapped onto [-1, 1].
NORM = ((0.5,) * 3, (0.5,) * 3)

# Evaluation pipeline: no augmentation, just tensor conversion + normalization.
te_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(*NORM),
])

# Training pipeline: a random 32x32 crop (after padding each side by 4 pixels)
# and a random horizontal flip, followed by the same steps as evaluation.
tr_transforms = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(*NORM),
])

# Single-channel pipeline resized to 32x32. The (0.1307, 0.3081) mean/std pair
# is presumably the usual MNIST training-set statistics — confirm.
mnist_transforms = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Names of the 15 corruption types (presumably the CIFAR-10-C set — confirm).
common_corruptions = [
    'gaussian_noise', 'shot_noise', 'impulse_noise',
    'defocus_blur', 'glass_blur', 'motion_blur', 'zoom_blur',
    'snow', 'frost', 'fog',
    'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression',
]

def test_single(model, inputs, label):
    """Print *model*'s predictions for one batch and count matches with *label*.

    Args:
        model: classifier; the row-wise argmax of its output is the prediction.
        inputs: batch tensor fed to ``model``.
        label: ground-truth class indices (tensor, list or scalar) compared
            element-wise with the predictions.

    Returns:
        0-d tensor with the number of correct predictions.

    Side effects: calls ``model.eval()`` (the model is left in eval mode) and
    prints the predicted indices.
    """
    model.eval()
    # Run on the model's own device instead of hard-coding CUDA; a model with
    # no parameters falls back to the device the inputs are already on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else inputs.device
    with torch.no_grad():
        outputs = model(inputs.to(device))
        _, predicted = outputs.max(1)
        print(predicted)

    # Labels typically arrive on the CPU (or as a list); comparing a non-scalar
    # CPU tensor with GPU predictions raises a device-mismatch RuntimeError.
    label = torch.as_tensor(label, device=predicted.device)
    return torch.sum(predicted == label)


# The "test_" prefix would make pytest collect this helper as a test (and fail
# on the missing fixtures); opt it out explicitly.
test_single.__test__ = False

def _model_device(model):
    """Return the device of *model*'s first parameter.

    Falls back to the default CUDA device (what this module used to hard-code)
    when the model has no parameters.
    """
    first_param = next(model.parameters(), None)
    if first_param is None:
        return torch.device('cuda')
    return first_param.device


def _save_attack_data(path, inputs, labels):
    """Write *inputs* and *labels* to *path* (``np.save`` appends ``.npy``).

    The file holds a 2-element object array ``[stacked_inputs, labels]`` so it
    reads back with ``data, labels = np.load(path, allow_pickle=True)``.
    Handing ``np.save`` the plain list ``[inputs, labels]`` instead makes NumPy
    build a ragged array (per-sample arrays beside scalar labels), which
    NumPy >= 1.24 rejects with ``ValueError``.

    Raises:
        ValueError: if *inputs* is empty (``np.stack`` needs at least one array).
    """
    payload = np.empty(2, dtype=object)
    payload[0] = np.stack(inputs)
    payload[1] = np.asarray(labels)
    np.save(path, payload)


def _attack_and_save(data_loader, model, adversary, name, train):
    """Perturb every batch of *data_loader*, report accuracies, save the result.

    Output goes to ``./attack_data/<name>/train.npy`` when *train* is true,
    otherwise ``./attack_data/<name>/test.npy``.

    Returns:
        (clean_accuracy, adversarial_accuracy) as floats in [0, 1].

    Raises:
        ValueError: if *data_loader* yields no samples.
    """
    device = _model_device(model)
    inputs_new = []
    labels_new = []
    n_correct = 0
    n_correct_adv = 0
    model.eval()
    for inputs, labels in data_loader:
        inputs_cls, labels_cls = inputs.to(device), labels.to(device)
        with ctx_noparamgrad_and_eval(model):
            inputs_adv = adversary.perturb(inputs_cls, labels_cls)
        inputs_new.extend(inputs_adv.detach().cpu().numpy())
        labels_new.extend(labels_cls.cpu().numpy())
        # Accuracy bookkeeping needs no autograd graph; outside no_grad each of
        # these forward passes would record one for nothing.
        with torch.no_grad():
            _, predicted_cls = model(inputs_cls).max(1)
            _, predicted_adv = model(inputs_adv).max(1)
        n_correct += (predicted_cls == labels_cls).sum().item()
        n_correct_adv += (predicted_adv == labels_cls).sum().item()

    n_total = len(labels_new)
    print('Process {}'.format(n_total))
    if n_total == 0:
        raise ValueError('data_loader yielded no samples; nothing to attack')
    cls_acc = n_correct / n_total
    print('Classification Accuracy:', cls_acc)
    adv_acc = n_correct_adv / n_total
    print('Adversarial Accuracy:', adv_acc)

    out_dir = './attack_data/{}'.format(name)
    os.makedirs(out_dir, exist_ok=True)
    _save_attack_data('{}/{}'.format(out_dir, 'train' if train else 'test'),
                      inputs_new, labels_new)
    return cls_acc, adv_acc


# Create l-inf attack for the CIFAR train and test data from a given model
def prepare_pgd_attack_data(args, data_loader, model, name, train = False,
                            eps = 16/255, nb_iter = 7, seed = 0, eps_iter = 4/255):
    """Generate untargeted L-inf PGD examples and save them under ./attack_data/<name>/.

    Adversarial inputs are clipped to [-1, 1]. This assumes the loader's inputs
    were normalized into that range (as the module's NORM does) — confirm for
    other pipelines.

    Args:
        args: unused; kept for call-site compatibility.
        data_loader: iterable of (inputs, labels) batches.
        model: classifier under attack.
        name: sub-directory of ./attack_data/ to write into.
        train: save as ``train.npy`` if true, else ``test.npy``.
        eps: L-inf radius of the perturbation.
        nb_iter: number of PGD iterations.
        seed: passed to ``init_random_seed`` before attacking (the attack uses
            a random start).
        eps_iter: PGD step size (was fixed at 4/255).

    Returns:
        (clean_accuracy, adversarial_accuracy).
    """
    init_random_seed(seed)
    adversary = LinfPGDAttack(
        model, loss_fn=nn.CrossEntropyLoss().to(_model_device(model)), eps=eps,
        nb_iter=nb_iter, eps_iter=eps_iter, rand_init=True, clip_min=-1.0, clip_max=1.0,
        targeted=False)
    return _attack_and_save(data_loader, model, adversary, name, train)


# Create l-2 attack for the CIFAR train and test data from a given model
def prepare_pgd_attack_data_l2(args, data_loader, model, name, train = False,
                            eps = 160/255, nb_iter = 7, seed = 0, eps_iter = 40/255):
    """Generate untargeted L2 PGD examples and save them under ./attack_data/<name>/.

    Same contract as ``prepare_pgd_attack_data`` except that the perturbation
    is bounded in L2 norm. ``eps`` is the L2 radius and ``eps_iter`` the step
    size (was fixed at 40/255). Returns (clean_accuracy, adversarial_accuracy).
    """
    init_random_seed(seed)
    adversary = L2PGDAttack(
        model, loss_fn=nn.CrossEntropyLoss().to(_model_device(model)), eps=eps,
        nb_iter=nb_iter, eps_iter=eps_iter, rand_init=True, clip_min=-1.0, clip_max=1.0,
        targeted=False)
    return _attack_and_save(data_loader, model, adversary, name, train)

# Create a pytorch dataset from the saved samples
class ADVDataset(torch.utils.data.Dataset):
    """Map-style dataset over a saved ``[inputs, labels]`` ``.npy`` file.

    Args:
        data_path: path to a ``.npy`` file whose top level unpacks into exactly
            two sequences, (inputs, labels), of equal length.
        transform: optional callable applied to each input sample before it
            is returned.

    Items are ``(sample, int(label))`` pairs.
    """

    def __init__(self, data_path, transform=None):
        # NOTE(security): allow_pickle=True can execute arbitrary code while
        # unpickling; only load files this project generated itself.
        self.data, self.labels = np.load(data_path, allow_pickle=True)
        self.n_data = len(self.data)
        self.transform = transform

    def __getitem__(self, item):
        sample = self.data[item]
        # The transform used to be stored but silently ignored.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, int(self.labels[item])

    def __len__(self):
        return self.n_data


if __name__ == "__main__":
    # Command-line entry point: rebuild the model, load its checkpoint and write
    # L-inf PGD adversarial copies of the test split under ./attack_data/.
    # The result can be reloaded with, e.g.,
    #   ADVDataset('attack_data/prTTT_pgd8/test.npy')
    # and wrapped in a torch.utils.data.DataLoader.

    def _str2bool(value):
        """Parse a true/false flag; argparse's ``type=bool`` maps "False" to True."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean, got {!r}'.format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='cifar10')
    parser.add_argument('--level', default=0, type=int)
    parser.add_argument('--corruption', default='original')
    parser.add_argument('--dataroot', default='/nobackup/yguo/datasets/')
    parser.add_argument('--shared', default='layer2')
    ########################################################################
    parser.add_argument('--depth', default=26, type=int)
    parser.add_argument('--width', default=1, type=int)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--group_norm', default=8, type=int)
    parser.add_argument('--fix_bn', default=False, type=_str2bool)
    parser.add_argument('--fix_ssh', default=False, type=_str2bool)
    # Formerly hard-coded; the defaults reproduce the previous behavior.
    parser.add_argument('--ckpt_path',
                        default='./results/cifar10_layer2_gn_expand/ckpt0.pth')
    parser.add_argument('--attack_name', default='prTTT_pgd8')

    args = parser.parse_args()
    # Only the first value returned by build_model is attacked below, so the
    # model is built once; the remaining return values are unused here.
    model, _, _, _ = build_model(args)

    ckpt = torch.load(args.ckpt_path)
    model.load_state_dict(ckpt['net'])
    # Prepare Test dataset
    _, data_loader = prepare_test_data(args)
    # NOTE(review): the default name says "pgd8", but eps falls back to
    # prepare_pgd_attack_data's own default (16/255) — confirm the budget.
    prepare_pgd_attack_data(args, data_loader, model, args.attack_name)
