import torch.backends.cudnn as cudnn
import torch
import torch.utils.data
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
from advertorch.context import ctx_noparamgrad_and_eval
from advertorch.attacks import LinfPGDAttack
import argparse
import os
import sys
sys.path.append('./')
from utils.prepare_dataset import *
from utils.misc import init_random_seed
from utils.test_helpers import build_model
from utils.prepare_attack_dataset import *
from utils.prepare_corruption_dataset import *
from shutil import copyfile


def _str2bool(value):
    """Convert a command-line string into a bool for argparse ``type=``.

    ``type=bool`` must not be used with argparse: it calls ``bool(str)``,
    which is ``True`` for every non-empty string, so ``--fix_bn False``
    would *enable* the flag. This converter accepts common spellings
    explicitly and rejects anything ambiguous.

    Args:
        value: Raw argument string (or an already-boolean default).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized
            boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    # '' stays False to match the previous bool('') behavior.
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='cifar10')
parser.add_argument('--level', default=0, type=int)
parser.add_argument('--corruption', default='original')
parser.add_argument('--dataroot', default='/nobackup/yguo/datasets/')
parser.add_argument('--shared', default=None)
########################################################################
parser.add_argument('--depth', default=26, type=int)
parser.add_argument('--width', default=1, type=int)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--group_norm', default=0, type=int)
# Parsed with _str2bool so that "--fix_bn False" / "--fix_bn 0" really
# disable the option (plain type=bool would turn them into True).
parser.add_argument('--fix_bn', default=False, type=_str2bool)
parser.add_argument('--fix_ssh', default=False, type=_str2bool)

args = parser.parse_args()
model, _, _, _ = build_model(args)

# Which adversarially-pretrained checkpoint to attack; 'linf' or 'l2'.
attack_type = 'linf'

# Pick the per-attack settings: output dataset name, checkpoint path, and the
# helper that generates/saves the attacked data. The if/elif keeps the lookup
# lazy so the l2 helper is only referenced when the l2 attack is selected.
if attack_type == 'linf':
    name = "cifar10_adv_pgd20"
    ckpt_path = './results/pretrain/cifar10_adv_pgd7/ckpt.pth'
    make_attack_data = prepare_pgd_attack_data
elif attack_type == 'l2':
    name = "cifar10_adv_l2_pgd20"
    ckpt_path = './results/pretrain/cifar10_adv_l2_pgd7/ckpt.pth'
    make_attack_data = prepare_pgd_attack_data_l2
else:
    # Previously an unknown attack type fell through silently and the
    # script exited without producing anything.
    raise ValueError("unsupported attack_type: {!r}".format(attack_type))

# Number of PGD iterations used to craft the attack data (the "pgd20" in
# the output names above).
nb_iter = 20

# Load onto the CPU first: load_state_dict copies each tensor into the
# model's own parameters, so this works on CPU-only hosts and avoids a
# redundant GPU copy of the checkpoint.
ckpt = torch.load(ckpt_path, map_location='cpu')
model.load_state_dict(ckpt['net'])

# Prepare test and train datasets (same order as before).
_, test_loader = prepare_test_data(args)
_, train_loader = prepare_train_data(args)

make_attack_data(args, train_loader, model, name, train=True, nb_iter=nb_iter)
make_attack_data(args, test_loader, model, name, train=False, nb_iter=nb_iter)