from __future__ import print_function
import argparse
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import sys
sys.path.append('./')

from utils.misc import *
from utils.test_helpers import *
from utils.prepare_dataset import *
from utils.rotation import *
from utils.prepare_attack_dataset import *
from utils.DANN_model import DANNWrapper
from defense.DANN import *

# Experiment: check whether DANN (domain-adversarial training) can adapt a model
# to the adversarial examples that were generated against the pretrained TTT
# network.
#
# Command-line options describing the source (clean) dataset and the backbone.
# The parsed `args` namespace is handed to build_model / prepare_*_data below.
_SOURCE_ARG_SPECS = [
    ('--dataset', {'default': 'cifar10'}),
    ('--dataroot', {'default': '/nobackup/yguo/datasets'}),
    ('--depth', {'default': 26, 'type': int}),
    ('--width', {'default': 1, 'type': int}),
    ('--batch_size', {'default': 128, 'type': int}),
    ('--group_norm', {'default': 0, 'type': int}),
    ('--shared', {'default': 'layer2'}),
]

parser = argparse.ArgumentParser()
for _flag, _options in _SOURCE_ARG_SPECS:
    parser.add_argument(_flag, **_options)
args = parser.parse_args()

# Seed before the network is constructed so that its weight initialization is
# reproducible from run to run. NOTE(review): assumes init_random_seed seeds the
# torch RNG (and presumably numpy/random); confirm in utils.
init_random_seed(0)

# Backbone described by the CLI args; the other return values of build_model
# are not needed here.
net, _, _, _ = build_model(args)
# Wrap the backbone for DANN training (wrapper lives in utils.DANN_model).
model = DANNWrapper(net)

# Source (clean) domain loaders; only the loaders are kept, not the datasets.
_, test_source_loader = prepare_test_data(args)
_, train_source_loader = prepare_train_data(args)

#### Data Preparation
# Target domain: adversarial examples stored as .npy files under
# attack_data/prTTT_pgd8 (PGD attack on the pretrained TTT network, judging by
# the directory name — confirm against the attack script).


def _make_target_loader(dataset):
    """Wrap *dataset* in a DataLoader: batches of 32, original order, 8 workers.

    NOTE(review): shuffle=False is also used for the *training* target loader;
    confirm this is intentional for DANN training.
    """
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=32,
        shuffle=False,
        num_workers=8,
    )


target_train_data = ADVDataset('attack_data/prTTT_pgd8/train.npy')
target_test_data = ADVDataset('attack_data/prTTT_pgd8/test.npy')
train_target_loader, test_target_loader = [
    _make_target_loader(split) for split in (target_train_data, target_test_data)
]

# Model Preparation
# Re-seed so that optimizer setup and everything downstream start from a fixed
# RNG state.
init_random_seed(0)

n_epoch = 100
# Learning rate actually used by the SGD optimizer below. The previous
# `lr = 3e-4` was left over from an abandoned Adam setup and was never used.
lr = 0.1

# Move the model to the GPU *before* constructing the optimizer, as PyTorch
# recommends, so the optimizer is built over the final parameter tensors.
model = model.cuda()
# Label-classification loss and domain-discrimination loss.
loss_class = torch.nn.CrossEntropyLoss().cuda()
loss_domain = torch.nn.CrossEntropyLoss().cuda()

optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
# Multiply the LR by 0.1 at scheduler steps 75 and 125. This only takes effect if
# scheduler.step() is called once per epoch. NOTE(review): with n_epoch = 100
# the 125 milestone is never reached.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, [75, 125], gamma=0.1, last_epoch=-1)

# Labels passed to test() to identify which domain is being evaluated.
source_dataset_name = 'cifar10'
target_dataset_name = 'cifar10-pgd8'

# DANN training: each epoch trains on the source and target training loaders,
# then evaluates on the clean source test set and the adversarial target test set.
# NOTE(review): `optimizer`, `loss_class` and `loss_domain` are not passed to
# train_one_epoch. Confirm that defense.DANN.train_one_epoch updates this
# script's optimizer rather than one of its own.
for epoch in range(n_epoch):
    train_one_epoch(model, train_source_loader, train_target_loader, epoch)

    test(model, test_source_loader, source_dataset_name, epoch)
    test(model, test_target_loader, target_dataset_name, epoch)

    # Advance the MultiStepLR schedule once per epoch. Without this call the
    # configured learning-rate decay never happens.
    scheduler.step()
        
