import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.autograd import Variable
import resnet
#from adabound import AdaBound
#from attacker import LinfPGDAttack

import attack_generator as attack

import numpy as np
import random

# Architectures offered through --arch: every callable attribute of the local
# ``resnet`` module whose name is lowercase, not dunder, and starts with
# "resnet", listed alphabetically.
model_names = sorted(
    attr_name
    for attr_name, attr_value in resnet.__dict__.items()
    if attr_name.islower()
    and not attr_name.startswith("__")
    and attr_name.startswith("resnet")
    and callable(attr_value)
)

print(model_names)

# Command-line interface for the robustness evaluation below.
parser = argparse.ArgumentParser(description='Proper ResNets for CIFAR10 in pytorch')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet32',
                    choices=model_names,
                    help='model architecture: ' + ' | '.join(model_names) +
                    ' (default: resnet32)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('-b', '--batch-size', default=128, type=int,
                    metavar='N', help='mini-batch size (default: 128)')
parser.add_argument('--val-batch-size', type=int, default=200, metavar='N',
                    help='input batch size for validating (default: 200)')
# NOTE: main() batches the CIFAR-10 test split with --val-batch-size;
# this option is currently not read anywhere in this script.
parser.add_argument('--test-batch-size', type=int, default=128, metavar='N',
                    help='input batch size for testing (default: 128)')
# Help text previously claimed "(default: 1)" while the real default is 50.
parser.add_argument('--print-freq', '-p', default=50, type=int,
                    metavar='N', help='print frequency (default: 50)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--save-dir', dest='save_dir',
                    help='The directory used to save the trained models',
                    default='save_temp', type=str)
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
# Only the 'dat' and 'trades' protocols are implemented by main(); restrict the
# choices so a typo fails at parse time instead of silently running nothing.
parser.add_argument('--method', type=str, default='dat', choices=['dat', 'trades'],
                    help="attack evaluation protocol: 'dat' or 'trades' (default: dat)")

# Best top-1 precision placeholder kept at module level; nothing in this
# script currently updates it.
best_prec1 = 0

# results = {'trainloss_val':[],'trainloss_avg':[],'trainprec_val':[],'trainprec_avg':[],
#             'valloss_val':[],'valloss_avg':[],'valprec_val':[],'valprec_avg':[],
#             'adv_valloss_val':[],'adv_valloss_avg':[],'adv_valprec_val':[],'adv_valprec_avg':[],
#            'testprec':[],'testloss':[],'adv_testprec':[],'adv_testloss':[]}

def setup_seed(seed):
    """Seed every random number generator this script relies on.

    Applies ``seed`` to PyTorch's CPU generator, the current CUDA device,
    all CUDA devices, NumPy's global generator and Python's ``random``
    module (in that order), then switches cuDNN to deterministic mode.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    torch.backends.cudnn.deterministic = True

# Attack protocols selectable with --method, run in order after the natural
# accuracy pass. Each entry is (report label, keyword arguments forwarded to
# attack.eval_robust). The values reproduce the per-protocol evaluations this
# script has always run; their exact meaning is defined by attack_generator,
# which is not visible here.
ATTACK_SUITES = {
    # Evaluations the same as DAT (random start for every attack).
    'dat': [
        ('FGSM', dict(perturb_steps=1, epsilon=0.031, step_size=0.031,
                      loss_fn="cent", category="Madry", rand_init=True)),
        ('PGD20', dict(perturb_steps=20, epsilon=0.031, step_size=0.031 / 4,
                       loss_fn="cent", category="Madry", rand_init=True)),
        ('CW', dict(perturb_steps=30, epsilon=0.031, step_size=0.031 / 4,
                    loss_fn="cw", category="Madry", rand_init=True)),
    ],
    # Evaluations the same as TRADES: each attack first without, then with,
    # a random start.
    'trades': [
        ('FGSM without Random Start',
         dict(perturb_steps=1, epsilon=0.031, step_size=0.031,
              loss_fn="cent", category="Madry", rand_init=False)),
        ('PGD20 without Random Start',
         dict(perturb_steps=20, epsilon=0.031, step_size=0.003,
              loss_fn="cent", category="Madry", rand_init=False)),
        ('CW without Random Start',
         dict(perturb_steps=30, epsilon=0.031, step_size=0.003,
              loss_fn="cw", category="Madry", rand_init=False)),
        ('FGSM with Random Start',
         dict(perturb_steps=1, epsilon=0.031, step_size=0.031,
              loss_fn="cent", category="Madry", rand_init=True)),
        ('PGD20 with Random Start',
         dict(perturb_steps=20, epsilon=0.031, step_size=0.003,
              loss_fn="cent", category="Madry", rand_init=True)),
        ('CW with Random Start',
         dict(perturb_steps=30, epsilon=0.031, step_size=0.003,
              loss_fn="cw", category="Madry", rand_init=True)),
    ],
}


def _run_attack_suite(model, loader, method):
    """Run every attack registered for ``method`` and print its accuracy.

    Returns a dict mapping each report label to the accuracy value returned
    by ``attack.eval_robust`` (printed scaled by 100 as a percentage).
    """
    accuracies = {}
    for label, eval_kwargs in ATTACK_SUITES[method]:
        _, acc = attack.eval_robust(model, loader, **eval_kwargs)
        print('{} Test Accuracy: {:.2f}%'.format(label, 100. * acc))
        accuracies[label] = acc
    return accuracies


def main():
    """Evaluate a trained CIFAR-10 ResNet checkpoint under white-box attacks.

    Parses the command line, seeds all RNGs, restores ``--resume`` into the
    ``--arch`` network, prints natural test accuracy, then runs the attack
    protocol selected by ``--method`` (a key of ``ATTACK_SUITES``).

    Exits through ``parser.error`` (usage message, exit status 2) when
    ``--resume`` is empty or not a file, or ``--method`` is unknown.
    """
    global args
    args = parser.parse_args()
    setup_seed(args.seed)

    # Validate with parser.error rather than assert: asserts are stripped
    # under ``python -O`` and give the user no hint about what went wrong.
    if not args.resume:
        parser.error('--resume is required: path to the checkpoint to evaluate')
    if not os.path.isfile(args.resume):
        parser.error("checkpoint not found: '{}'".format(args.resume))
    if args.method not in ATTACK_SUITES:
        # Previously an unknown method silently ran only the natural evaluation.
        parser.error("unknown --method '{}' (choose from: {})".format(
            args.method, ', '.join(sorted(ATTACK_SUITES))))

    net = torch.nn.DataParallel(resnet.__dict__[args.arch](), device_ids=[0])
    net.cuda()

    # Deserialize onto the CPU so a checkpoint saved from a different GPU
    # index still loads; load_state_dict copies tensors onto the model's device.
    checkpoint = torch.load(args.resume, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint. ")

    # CIFAR-10 test split, batched with --val-batch-size (--test-batch-size is
    # not consulted). No Normalize transform is applied: the model receives
    # the ToTensor output directly.
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root='./data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
        ])),
        batch_size=args.val_batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)
    print('val_loader.dataset:', len(test_loader.dataset))

    model = net

    print('==> Evaluating Performance under White-box Adversarial Attack')

    _, test_nat_acc = attack.eval_clean(model, test_loader)
    print('Natural Test Accuracy: {:.2f}%'.format(100. * test_nat_acc))
    _run_attack_suite(model, test_loader, args.method)


# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
