
import torch.backends.cudnn as cudnn
import torch
import torch.utils.data
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
from advertorch.context import ctx_noparamgrad_and_eval
from advertorch.attacks import LinfPGDAttack
import argparse
import os
import sys
sys.path.append('./')
from utils.prepare_dataset import *
from utils.misc import init_random_seed
from utils.test_helpers import build_model
from utils.prepare_attack_dataset import *
from utils.prepare_corruption_dataset import *
from shutil import copyfile


def _str2bool(value):
    """Parse a command-line token into a real boolean.

    ``type=bool`` is a trap with argparse: it calls ``bool()`` on the raw
    string, so every non-empty value -- including ``"False"`` and ``"0"`` --
    turns into ``True``.  This converter interprets the usual spellings
    instead.

    Args:
        value: Raw option value (a string from the command line, or an
            actual ``bool``).

    Returns:
        The boolean that the token spells out.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised
            spelling of true/false.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays False, exactly as it was under ``bool("")``.
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got %r' % (value,))


# Command-line options.  The parsed namespace is handed as-is to the project
# helpers (build_model, prepare_*_data, prepare_pgd_attack_data).
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='cifar10')
parser.add_argument('--dataroot', default='/nobackup/yguo/datasets/')
parser.add_argument('--shared', default='layer2')
########################################################################
parser.add_argument('--depth', default=26, type=int)
parser.add_argument('--width', default=1, type=int)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--group_norm', default=8, type=int)
# Boolean switches go through _str2bool so that "--fix_bn False" really
# disables the option instead of being read as a truthy string.
parser.add_argument('--fix_bn', default=False, type=_str2bool)
parser.add_argument('--fix_ssh', default=False, type=_str2bool)

args = parser.parse_args()

# ----------------------------------------------------------------------
# Attack cifar10 with the adversarially pretrained TTT model
# ----------------------------------------------------------------------
# Build the networks once; only the first one returned by build_model is
# used in this section, so the remaining return values are discarded.
model, _, _, _ = build_model(args)

# Tag handed to prepare_pgd_attack_data for this run.
name = "TTT_cifar10_pgd20"

ckpt_src = './results/pretrain/cifar10_adv_layer2_gn_expand/ckpt.pth'
ckpt_tmp = './results/pretrain/cifar10_adv_layer2_gn_expand/ckpt_temp.pth'
# The weights are read from a copy of the checkpoint rather than from the
# checkpoint itself.
# NOTE(review): presumably this avoids reading ckpt.pth while a training job
# is still writing it -- confirm.
copyfile(ckpt_src, ckpt_tmp)
ckpt = torch.load(ckpt_tmp)
model.load_state_dict(ckpt['net'])

# Build the test and train loaders, then run the 20-iteration PGD attack
# over each split (train first, then test).
_, test_loader = prepare_test_data(args)
_, train_loader = prepare_train_data(args)
prepare_pgd_attack_data(args, train_loader, model, name, nb_iter=20, train=True)
prepare_pgd_attack_data(args, test_loader, model, name, nb_iter=20, train=False)

# ----------------------------------------------------------------------
# Attack cifar10c-fog with the adversarially pretrained TTT model
# ----------------------------------------------------------------------
# Start from freshly built networks; only the first one returned by
# build_model is used in this section, so the remaining return values are
# discarded.
model, _, _, _ = build_model(args)

# Tag handed to prepare_pgd_attack_data for this run.
name = "TTT_cifar10_fog_pgd20"

ckpt_src = './results/pretrain/cifar10_adv_layer2_gn_expand/ckpt.pth'
ckpt_tmp = './results/pretrain/cifar10_adv_layer2_gn_expand/ckpt_temp.pth'
# The weights are read from a copy of the checkpoint rather than from the
# checkpoint itself.
# NOTE(review): presumably this avoids reading ckpt.pth while a training job
# is still writing it -- confirm.
copyfile(ckpt_src, ckpt_tmp)
ckpt = torch.load(ckpt_tmp)
model.load_state_dict(ckpt['net'])

# prepare_fog_data() yields a (train, test) pair of 2-tuples; the loader is
# the second element of each.  Run the 20-iteration PGD attack over each
# split (train first, then test).
(_, train_loader), (_, test_loader) = prepare_fog_data()
prepare_pgd_attack_data(args, train_loader, model,
                        name, nb_iter=20, train=True)
prepare_pgd_attack_data(args, test_loader, model,
                        name, nb_iter=20, train=False)
