from __future__ import print_function
import argparse
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import sys
sys.path.append('./')

from utils.misc import *
from utils.test_helpers import *
from utils.prepare_dataset import *
from utils.rotation import *
from utils.prepare_attack_dataset import *
from shutil import copyfile    
from adv_test_calls.advtest_TTT import test_robustness

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='cifar10')
parser.add_argument('--level', default=0, type=int)
parser.add_argument('--corruption', default='original')
parser.add_argument('--dataroot', default='/nobackup/yguo/datasets/')
parser.add_argument('--shared', default='layer2')
########################################################################
parser.add_argument('--depth', default=26, type=int)
parser.add_argument('--width', default=1, type=int)
parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--group_norm', default=8, type=int)
parser.add_argument('--fix_bn', action='store_false')
parser.add_argument('--fix_ssh', action='store_false')
########################################################################
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--niter', default=1, type=int)
parser.add_argument('--online', action='store_true')
parser.add_argument('--threshold', default=1, type=float)
parser.add_argument('--dset_size', default=0, type=int)
########################################################################
parser.add_argument('--resume', default='results/pretrain/cifar10_adv_layer2_gn_expand')
parser.add_argument('--outf', default='results/pretrain/cifar10_adv_layer2_gn_expand')

args = parser.parse_args()
print(args)
test_robustness(args, adv_data_dir='attack_data/TTT_cifar10_pgd20')
