from __future__ import print_function
import argparse
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import sys
sys.path.append('./')

from utils.misc import *
from utils.test_helpers import *
from utils.prepare_dataset import *
from utils.rotation import *
from utils.prepare_attack_dataset import *
from defense.DANN import *
import argparse

# Command-line options describing the source model / dataset and which
# pre-generated adversarial target set to adapt to.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='cifar10')
parser.add_argument('--dataroot', default='/nobackup/yguo/datasets')
parser.add_argument('--depth', default=26, type=int)
parser.add_argument('--width', default=1, type=int)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--group_norm', default=0, type=int)
parser.add_argument('--shared', default='layer2')
parser.add_argument('--corruption', default='glass_blur')
# Restrict to the attack types the data-preparation section knows about, so a
# typo fails at parse time instead of as a NameError much later.
parser.add_argument('--adv_type', default='advS', choices=['advS', 'advT', 'none'])

args = parser.parse_args()

# Seed before the model is built so that its weight initialisation (and the
# loader construction below) is reproducible across runs.
init_random_seed(0)

# Only the backbone network is needed from build_model; wrap it for DANN.
net, _, _, _ = build_model(args)
model = DANNWrapper(net)
_, test_source_loader = prepare_test_data(args)
_, train_source_loader = prepare_train_data(args)


#### Data Preparation
# Target-domain data are pre-generated .npy files whose directory name encodes
# the corruption type and the attack type (advS / advT).
if args.adv_type in ('advS', 'advT'):
    target_dir = 'attack_data/cifar10c_{}_none_gn_{}'.format(args.corruption, args.adv_type)
    target_train_data = ADVDataset('{}/train.npy'.format(target_dir))
    target_test_data = ADVDataset('{}/test.npy'.format(target_dir))
elif args.adv_type == 'none':
    # No un-attacked target dataset is wired up in this script yet; fail
    # loudly rather than continuing with undefined target data.
    raise NotImplementedError("adv_type 'none' is not supported yet")
else:
    raise ValueError('Unknown adv_type: {}'.format(args.adv_type))

# NOTE(review): the target *training* loader is not shuffled — confirm this is
# intended for DANN adaptation.
train_target_loader = torch.utils.data.DataLoader(
    dataset=target_train_data,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=8)
test_target_loader = torch.utils.data.DataLoader(
    dataset=target_test_data,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=8)

# Model Preparation
init_random_seed(0)

n_epoch = 100

# Move the model to the GPU *before* building the optimizer, as recommended by
# the torch.optim documentation, so the optimizer tracks the final parameters.
model = model.cuda()

# Setup optimizer: Adam with a fixed learning rate.
lr = 3e-4
optimizer = optim.Adam(model.parameters(), lr=lr)

# Alternative SGD + step-decay schedule, kept for experimentation:
# optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(
#     optimizer, [75, 125], gamma=0.1, last_epoch=-1)

source_dataset_name = 'cifar10'
# Label used when reporting target-domain results; derived from the CLI so it
# matches the data actually loaded.
target_dataset_name = 'cifar10c-{}-{}'.format(args.corruption, args.adv_type)

# DANN training
for epoch in range(n_epoch):
    # Adapt on the target *training* split; the test split is held out and
    # used only for evaluation below.
    train_one_epoch(model, train_source_loader, train_target_loader, optimizer, epoch, n_epoch)
    # scheduler.step()
    test_one_epoch(model, test_target_loader, target_dataset_name, epoch)
