import os
import pdb
import sys 
import torch
import pickle
import argparse
import torch.optim
import torch.nn as nn
import torch.utils.data
import matplotlib.pyplot as plt
import torchvision.models as models
from advertorch.utils import NormalizeByChannelMeanStd

from models.resnetv2 import ResNet18
from datasets import *
import utils

# Command-line interface for evaluating a pretrained CIFAR-10 model on clean
# and PGD-attacked test data. Parsed by main().
parser = argparse.ArgumentParser(description='PyTorch Cifar10 Evaluation (standard and PGD robust accuracy)')

########################## base setting ##########################
parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
parser.add_argument('--print_freq', default=50, type=int, help='print frequency')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
# Required: main() passes this straight to torch.load, which cannot load None.
parser.add_argument('--pretrained', type=str, required=True,
                    help='path to the pretrained model checkpoint (state_dict) file to evaluate')
parser.add_argument('--batch_size', type=int, default=128, help='batch size')

########################## attack setting ##########################
parser.add_argument('--norm', default='linf', type=str, help='attack norm: linf or l2')
parser.add_argument('--test_eps', default=8, type=float,
                    help='attack budget epsilon on the 0-255 pixel scale (divided by 255 before use)')
parser.add_argument('--test_step', default=20, type=int, help='number of PGD steps')
parser.add_argument('--test_gamma', default=2, type=float,
                    help='PGD step size on the 0-255 pixel scale (divided by 255 before use)')
# store_false: the attribute is True unless the flag is given, so passing
# --test_randinit turns random initialization OFF.
parser.add_argument('--test_randinit', action='store_false',
                    help='pass this flag to DISABLE random initialization of the PGD attack (default: on)')


def main():
    """Evaluate a pretrained CIFAR-10 ResNet18 on clean and PGD-attacked data.

    Parses the module-level ``parser`` into the global ``args``. It then
    rescales the attack budget and step size from the 0-255 pixel scale by
    1/255. Next it builds a ResNet18 with a CIFAR-10 normalization layer and
    loads the ``--pretrained`` state_dict onto the selected GPU. Finally it
    prints:

    * standard accuracy (SA), the second value returned by ``utils.validate``;
    * robust accuracy (RA), the second value returned by ``utils.validate_adv``.

    Exits with a usage error (via ``parser.error``) if ``--pretrained`` is not
    given.
    """
    global args  # kept module-global, as in the original script
    args = parser.parse_args()

    if args.pretrained is None:
        # torch.load(None) would only fail after the model and data loaders
        # are built, with an obscure error; fail fast with a usage message.
        parser.error('--pretrained is required: path to the model state_dict file')

    # CLI values are on the 0-255 pixel scale. Presumably the data loaders
    # yield images in [0, 1] -- confirm in datasets.cifar10_dataloaders.
    args.test_eps = args.test_eps / 255
    args.test_gamma = args.test_gamma / 255
    print(args)

    torch.cuda.set_device(args.gpu)
    device = torch.device('cuda', args.gpu)

    ########################## prepare model & dataset ##########################
    model = ResNet18(num_classes=10)
    # CIFAR-10 per-channel mean/std, attached as `model.normal`. Presumably
    # ResNet18.forward applies it, so inputs and attack perturbations stay in
    # pixel space -- confirm in models/resnetv2.
    model.normal = NormalizeByChannelMeanStd(
        mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
    # Only the third (test) loader is used; the first two are discarded.
    _, _, test_loader = cifar10_dataloaders(
        train_batch_size=args.batch_size,
        test_batch_size=args.batch_size,
        data_dir=args.data)
    model.cuda()
    criterion = nn.CrossEntropyLoss()

    print('loading pretrained model from {}'.format(args.pretrained))
    # NOTE: torch.load unpickles the file, which can execute arbitrary code;
    # only point --pretrained at trusted checkpoints.
    pre_weight = torch.load(args.pretrained, map_location=device)
    model.load_state_dict(pre_weight)
    # Put BatchNorm/Dropout in inference mode. utils.validate* may already do
    # this; setting it here means evaluation does not depend on that.
    model.eval()

    ########################## evaluation ##########################
    rule = '*' * 50

    print('Start testing on standard test dataset')
    _, test_sa = utils.validate(test_loader, model, criterion)
    print(rule)
    print('standard accuracy (SA) = {}'.format(test_sa))
    print(rule)

    print('Start testing under PGD Attack')
    print('norm = {}'.format(args.norm))
    print('epsilon = {}'.format(args.test_eps))
    print('step_size = {}'.format(args.test_gamma))
    print('number of steps = {}'.format(args.test_step))
    print('randinit = {}'.format(args.test_randinit))

    # validate_adv reads the attack settings (norm, eps, step size, steps,
    # randinit) from args.
    _, test_ra = utils.validate_adv(test_loader, model, criterion, args)
    print(rule)
    print('Robust accuracy (RA) = {}'.format(test_ra))
    print(rule)
        
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()


