from __future__ import print_function
import argparse
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import sys
sys.path.append('./')

from utils.misc import *
from utils.test_helpers import *
from utils.prepare_dataset import *
from utils.rotation import *
from utils.prepare_attack_dataset import *
from adv_test_calls.advtest_TTT import * 
from shutil import copyfile
from utils.prepare_corruption_dataset import *
from defense.DANN import *
import os

# Command-line options. The model/source-data options (dataset, dataroot,
# depth, width, batch_size, group_norm, shared) are consumed via `args` by the
# project helpers build_model / prepare_test_data / prepare_train_data below;
# the rest configure the DANN adaptation experiment.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='cifar10')
parser.add_argument('--dataroot', default='/nobackup/yguo/datasets')
parser.add_argument('--depth', default=26, type=int)
parser.add_argument('--width', default=1, type=int)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--group_norm', default=0, type=int)
parser.add_argument('--shared', default='layer2')
# Restrict to the two modes the training loop implements; any other value would
# otherwise run every epoch without training and save empty result arrays.
parser.add_argument('--mode', default=0, type=int, choices=[0, 1],
                    help='0: adapt on the target test set (standard TTT DANN); '
                         '1: adapt on the target train set, evaluate on the target test set')
parser.add_argument('--subsample_size', default=10000, type=int,
                    help='number of target test examples to use (1..10000)')
parser.add_argument('--alpha_scale', default=1, type=float,
                    help='forwarded to train_one_epoch in mode 0')

args = parser.parse_args()
# The epoch budget is scaled by 10000 / subsample_size. A value of 0 would
# divide by zero, and values above 10000 would yield zero epochs, so reject both.
if not 1 <= args.subsample_size <= 10000:
    parser.error('--subsample_size must be between 1 and 10000, got {}'.format(
        args.subsample_size))

save_dir = 'results/pretrain/cifar10_adv_l2_pgd7/pgd20_DANN_sub{}'.format(args.subsample_size)
os.makedirs(save_dir, exist_ok=True)

# Seed before constructing the network so its random initialisation is
# reproducible. The training section re-seeds before optimisation as well.
init_random_seed(0)
net, _, _, _ = build_model(args)
model = DANNWrapper(net)
_, test_source_loader = prepare_test_data(args)
_, train_source_loader = prepare_train_data(args)

#### Data Preparation
# Target domain: pre-generated adversarial examples loaded from .npy files.
target_train_data = ADVDataset('attack_data/cifar10_adv_l2_pgd20/train.npy')
target_test_data = ADVDataset('attack_data/cifar10_adv_l2_pgd20/test.npy')
train_target_loader = torch.utils.data.DataLoader(
    dataset=target_train_data,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=8)

# Optionally keep only the first `subsample_size` target test examples. The
# check uses the dataset's real length rather than a hard-coded 10000, so the
# full set is used whenever the request covers all of it.
if args.subsample_size < len(target_test_data):
    target_test_data = torch.utils.data.Subset(
        target_test_data, torch.arange(args.subsample_size))
test_target_loader = torch.utils.data.DataLoader(
    dataset=target_test_data,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=8)

# Model Preparation
init_random_seed(0)

lr = 3e-4
# The epoch budget and checkpoint interval scale inversely with the target
# subset size, relative to the full 10000-image test set, so smaller subsets
# get proportionally more passes. Clamp both to >= 1: otherwise subset sizes
# above the full set would silently give zero epochs, or a modulo by zero in
# the checkpoint test.
full_test_size = 10000
n_epoch = 150 * max(1, full_test_size // args.subsample_size)
ckpt_every = max(1, (5 * full_test_size) // args.subsample_size)

# Move the model to the GPU before constructing the optimizer, as the PyTorch
# docs recommend, so the optimizer is built over the parameters in their final
# device placement.
model = model.cuda()
# setup optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)

source_dataset_name = 'cifar10'
target_dataset_name = 'cifar10-adv-l2-pgd20'

# Per-epoch target-test results are re-saved every epoch. Create the output
# directory up front, because np.save does not create missing parent directories.
results_dir = 'results/DANN_homo'
os.makedirs(results_dir, exist_ok=True)
results_path = '{}/dann_cifar10_adv_l2_mode{}_sub{}.npy'.format(
    results_dir, args.mode, args.subsample_size)

if args.mode not in (0, 1):
    raise ValueError('unsupported --mode {}; expected 0 or 1'.format(args.mode))

# DANN training
adv_test_result = []
for epoch in range(n_epoch):
    if args.mode == 0:  # Standard TTT DANN setting: adapt on the target test set itself
        train_one_epoch(model, train_source_loader,
                        test_target_loader, optimizer, epoch, n_epoch, alpha_scale=args.alpha_scale)
        adv_test_result.append(test_one_epoch(model, test_target_loader, target_dataset_name, epoch))
    elif args.mode == 1:  # TTT DANN generalization setting: adapt on target train, evaluate on target test
        # NOTE(review): alpha_scale is not forwarded here (unlike mode 0), so
        # train_one_epoch's own default applies — confirm this is intended.
        train_one_epoch(model, train_source_loader,
                        train_target_loader, optimizer, epoch, n_epoch)
        # Accuracy on the adaptation (target train) data; the return value is not recorded.
        test_one_epoch(model, train_target_loader, target_dataset_name, epoch)
        print("Adaptation Accuracy ")
        adv_test_result.append(test_one_epoch(model, test_target_loader, target_dataset_name, epoch))
    # np is assumed to come from one of the star imports above (it is not imported explicitly).
    np.save(results_path, np.array(adv_test_result))
    if epoch % ckpt_every == 0:
        # Pickles the whole module object (not a state_dict); kept for compatibility with existing loaders.
        torch.save({'model': model}, "{}/epoch{}".format(save_dir, epoch))
 