from __future__ import print_function
import argparse
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import sys
sys.path.append('./')

from utils.misc import *
from utils.test_helpers import *
from utils.prepare_dataset import *
from utils.rotation import *
from utils.prepare_attack_dataset import *
from defense.DANN import *
import argparse

# Command-line options describing the source dataset and the backbone network.
# Each entry is (flag, keyword arguments forwarded to add_argument).
parser = argparse.ArgumentParser()
cli_options = (
    ('--dataset', {'default': 'cifar10'}),
    ('--dataroot', {'default': '/nobackup/yguo/datasets'}),
    ('--depth', {'default': 26, 'type': int}),
    ('--width', {'default': 1, 'type': int}),
    ('--batch_size', {'default': 128, 'type': int}),
    ('--group_norm', {'default': 0, 'type': int}),
    ('--shared', {'default': 'layer2'}),
)
for option_flag, option_kwargs in cli_options:
    parser.add_argument(option_flag, **option_kwargs)

args = parser.parse_args()

# build_model returns four values; only the first (the network) is used here,
# and it is wrapped in DANNWrapper before training.
net, _, _, _ = build_model(args)
model = DANNWrapper(net)

# Both helpers return a pair; only the second element (the loader) is kept.
_, test_source_loader = prepare_test_data(args)
_, train_source_loader = prepare_train_data(args)


#### Data Preparation
# Target-domain datasets are built from the .npy files under attack_data/.
target_train_data = ADVDataset('attack_data/cifar10c_fog_adv_l2_none_gn/train.npy')
target_test_data = ADVDataset('attack_data/cifar10c_fog_adv_l2_none_gn/test.npy')

# Both target loaders use the same settings: fixed (unshuffled) order,
# the batch size from the command line, and 8 worker processes.
target_loader_kwargs = dict(
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=8,
)
train_target_loader = torch.utils.data.DataLoader(
    dataset=target_train_data, **target_loader_kwargs)
test_target_loader = torch.utils.data.DataLoader(
    dataset=target_test_data, **target_loader_kwargs)

# Model Preparation
# NOTE(review): the model has already been constructed earlier in this script,
# so this seed call happens too late to influence its initial weights; it only
# affects randomness consumed from here on.
init_random_seed(0)

n_epoch = 100

# Move the model to the GPU *before* building the optimizer. The torch.optim
# documentation recommends this order so that the optimizer is guaranteed to
# reference the parameters that are actually trained on the device.
model = model.cuda()

# setup optimizer
lr = 3e-4
optimizer = optim.Adam(model.parameters(), lr=lr)

# Alternative SGD + step-decay configuration, kept for reference:
# optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(
# optimizer, [75, 125], gamma=0.1, last_epoch=-1)

# Dataset labels; only target_dataset_name is passed on (to test_one_epoch).
source_dataset_name = 'cifar10'
target_dataset_name = 'cifar10c-fog-pgd80-l2'

# DANN training
for epoch in range(n_epoch):
    # NOTE(review): the target-domain batches used for training come from
    # test_target_loader, while train_target_loader is created but never used.
    # Confirm this is intentional before switching loaders.
    train_one_epoch(model, train_source_loader, test_target_loader, optimizer, epoch, n_epoch)
    # scheduler.step()
    test_one_epoch(model, test_target_loader, target_dataset_name, epoch)