from __future__ import print_function
import argparse
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from tqdm import tqdm

# Make the repository-local packages (utils/, defense/) importable.
sys.path.append('./')

from utils.misc import *
from utils.test_helpers import *
from utils.prepare_dataset import *
from utils.rotation import *
from utils.prepare_attack_dataset import *
from utils.prepare_corruption_dataset import *
from defense.DANN import *
import argparse

# Command-line options. The first group configures the source data and the
# network handed to build_model / prepare_train_data; --adv_type and
# --corruption select the DANN target domain.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='cifar10')
parser.add_argument('--dataroot', default='/nobackup/yguo/datasets')
parser.add_argument('--depth', default=26, type=int)
parser.add_argument('--width', default=1, type=int)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--group_norm', default=0, type=int)
parser.add_argument('--shared', default='layer2')
parser.add_argument('--level', type=int)
# Restrict to the values the training loop handles. Without choices, any
# other string would silently fall through to the 'none' (clean corruption)
# branch.
parser.add_argument('--adv_type', default='advS',
                    choices=['advS', 'advT', 'none'],
                    help='target data for DANN: advS, advT, or none')
parser.add_argument('--corruption', default='glass_blur')

args = parser.parse_args()

# DANN adaptation sweep. For each corruption severity level, build a fresh
# source model and train it for n_epoch epochs against the selected target
# loader. After every epoch, save the per-epoch target accuracy curve to
# ./results/DANN/<adv_type>/<corruption>/accs_level<level>.npy.

# Validate once, up front. Previously an invalid value only printed this
# message and then silently ran the 'none' branch.
if args.adv_type not in ('advS', 'advT', 'none'):
    parser.error("Choose either of the following adv_type for the target of DANN: advS, advT, none")

# --level was parsed but ignored; honor it when given. Otherwise sweep all
# levels, as before.
levels = [1, 2, 3, 4, 5] if args.level is None else [args.level]

n_epoch = 100
lr = 3e-4

results_dir = os.path.join('./results/DANN', args.adv_type, args.corruption)
os.makedirs(results_dir, exist_ok=True)

for level in levels:
    # Seed before the model is built so weight initialization is covered too.
    # Previously the seed was set only after build_model.
    # NOTE(review): assumes init_random_seed seeds the torch/numpy RNGs; confirm in utils.
    init_random_seed(0)

    net, _, _, _ = build_model(args)
    # Move to the GPU *before* constructing the optimizer, as the torch.optim
    # docs recommend.
    model = DANNWrapper(net).cuda()
    _, train_source_loader = prepare_train_data(args)

    if args.adv_type in ('advS', 'advT'):
        # The adversarial target sets differ only in their directory suffix,
        # which equals the adv_type value.
        target_train_data = ADVDataset(
            "attack_data/cifar10c_{}_none_gn_lvl{}_{}/train.npy".format(
                args.corruption, level, args.adv_type))
        train_target_loader = torch.utils.data.DataLoader(
            dataset=target_train_data,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=8)
    else:
        # 'none': use the clean-corruption loader for this level.
        _, _, (_, train_target_loader) = prepare_corruption_data_lvl(
            level, corruption=args.corruption)

    optimizer = optim.Adam(model.parameters(), lr=lr)

    # NOTE(review): this label says 'pgd8' even when adv_type == 'none'. It is
    # only passed through to test_one_epoch; confirm whether that matters.
    target_dataset_name = 'cifar10c-{}-pgd8-lvl{}'.format(args.corruption, level)

    accs = []
    for epoch in range(n_epoch):
        train_one_epoch(model, train_source_loader, train_target_loader, optimizer, epoch, n_epoch)
        # Evaluation uses the same (unlabelled-for-training) target loader.
        acc = test_one_epoch(model, train_target_loader, target_dataset_name, epoch)
        accs.append(acc)
        # Save after every epoch so a partial curve survives an interrupted run.
        np.save(os.path.join(results_dir, 'accs_level{}'.format(level)), accs)
        
