from __future__ import print_function
import argparse
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import sys
sys.path.append('./')

from utils.misc import *
from utils.test_helpers import *
from utils.prepare_dataset import *
from utils.rotation import *
from utils.prepare_attack_dataset import *
from adv_test_calls.advtest_TTT import * 
from shutil import copyfile
from utils.prepare_corruption_dataset import *
from defense.DANN import *
import os

# NOTE: To run this test, disable the attack-data removal step in DANN_FPA.py
# (presumably so that the attack_data/<name>/*.npy files read below are kept on disk).
def test_FPA(args, FPA_iter=7,
             ckpt_dir='./results/pretrain/cifar10c_fog_none_gn/DANN_FPA_RI',
             split='train'):
    """Evaluate one DANN FPA checkpoint on the PGD-8 fog attack data of every round.

    Loads ``ckpt{FPA_iter}.pth`` from *ckpt_dir* into a ``DANNWrapper`` built
    around the network returned by ``build_model(args)``. It then calls ``test``
    on the attack data saved for each round ``1..FPA_iter`` under
    ``attack_data/DANN_FPA_RI{round}_fog_pgd8/``.

    Args:
        args: parsed command-line namespace forwarded to ``build_model``.
        FPA_iter: index of the checkpoint to load; also the last round evaluated.
        ckpt_dir: directory containing the ``ckpt{N}.pth`` files. The default is
            the path that was previously hard-coded.
        split: which attack split to evaluate, ``'train'`` (the original
            behavior) or ``'test'``. Only that split's ``.npy`` file is loaded.

    Raises:
        ValueError: if *split* is not ``'train'`` or ``'test'``.
    """
    if split not in ('train', 'test'):
        raise ValueError("split must be 'train' or 'test', got {!r}".format(split))

    # Load the model once, then move it to the GPU once (it is loop-invariant).
    net, _, _, _ = build_model(args)
    ckpt = torch.load(os.path.join(ckpt_dir, 'ckpt{}.pth'.format(FPA_iter)))
    model = DANNWrapper(net)
    model.load_state_dict(ckpt['model'])
    model = model.cuda()

    for prev_iter in range(1, FPA_iter + 1):
        name = "DANN_FPA_RI{}_fog_pgd8".format(prev_iter)
        # Build a loader only for the split that is actually evaluated.
        target_data = ADVDataset('attack_data/{}/{}.npy'.format(name, split))
        target_loader = torch.utils.data.DataLoader(
            dataset=target_data,
            batch_size=32,
            shuffle=False,
            num_workers=8)
        target_dataset_name = 'cifar10c-FPA_RI{}-fog-pgd8'.format(prev_iter)
        # NOTE(review): the trailing `1` is passed straight through to `test`;
        # its meaning is defined where `test` is declared — confirm there.
        test(model, target_loader, target_dataset_name, 1)
            
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='cifar10')
    parser.add_argument('--corruption', default='fog')
    parser.add_argument('--level', default=5, type=int)
    parser.add_argument('--dataroot', default='/nobackup/yguo/datasets/')
    parser.add_argument('--shared', default='layer2')
    ########################################################################
    parser.add_argument('--depth', default=26, type=int)
    parser.add_argument('--width', default=1, type=int)
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--group_norm', default=0, type=int)
    # NOTE: store_false means these default to True; passing the flag sets
    # them to False.
    parser.add_argument('--fix_bn', action='store_false')
    parser.add_argument('--fix_ssh', action='store_false')
    ########################################################################
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--niter', default=1, type=int)
    parser.add_argument('--online', action='store_true')
    parser.add_argument('--threshold', default=1, type=float)
    parser.add_argument('--dset_size', default=0, type=int)
    ########################################################################
    parser.add_argument('--resume', default='./results/pretrain/cifar10c_fog_none_gn/DANN_FPA_RI')
    # FPA round whose checkpoint is evaluated (previously hard-coded to 6);
    # test_FPA evaluates the attack data of rounds 1..fpa_iter.
    parser.add_argument('--fpa_iter', default=6, type=int)
    args = parser.parse_args()
    test_FPA(args, FPA_iter=args.fpa_iter)

