from __future__ import print_function
import argparse
from tqdm import tqdm
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import sys
sys.path.append('./')

from utils.misc import *
from utils.test_helpers import *
from utils.prepare_dataset import *
from utils.rotation import *
from utils.prepare_attack_dataset import *
from adv_test_calls.advtest_TTT import * 
from shutil import copyfile
from utils.prepare_corruption_dataset import * 
import os
from test_calls.test_adapt_adv import TTT_adapt_adv

def FPA(args, n_iter = 5): 
    """Alternate PGD attack generation and TTT adaptation for ``n_iter`` rounds.

    The pretrained checkpoint is first copied into the ``FPA`` working
    directory, both as ``ckpt0.pth`` (a snapshot of the starting point) and as
    ``ckpt.pth`` (the checkpoint read at the start of every round). Each round
    then:

    1. loads ``FPA/ckpt.pth`` into the model,
    2. builds a PGD attack dataset named ``prTTT_FPA<round>_pgd8`` against
       that model from the ``args.corruption`` test loader,
    3. runs ``TTT_adapt`` on ``attack_data/prTTT_FPA<round>_pgd8``,
    4. snapshots ``FPA/ckpt.pth`` as ``FPA/ckpt<round>.pth``.

    Args:
        args: Parsed command-line namespace. It is forwarded to
            ``build_model``, ``prepare_pgd_attack_data`` and ``TTT_adapt``;
            ``args.corruption`` is read here directly.
        n_iter: Number of attack/adapt rounds. Rounds are numbered from 1, so
            a value below 1 only performs the initial checkpoint copies.
    """
    # Generate the attack wrt the pretrained model on the corruption dataset.
    # Here the corruption dataset with different levels combined is used.
    model, _, _, _ = build_model(args)

    # Single source of truth for the checkpoint locations used below.
    base_dir = './results/cifar10_layer2_gn_expand'
    fpa_dir = os.path.join(base_dir, 'FPA')
    pretrained_ckpt = os.path.join(base_dir, 'ckpt.pth')
    working_ckpt = os.path.join(fpa_dir, 'ckpt.pth')

    # Create the working directory without a separate existence check, so a
    # directory created concurrently by another process is not an error.
    try:
        os.makedirs(fpa_dir)
    except OSError:
        if not os.path.isdir(fpa_dir):
            raise

    # ckpt0.pth keeps the untouched starting point; ckpt.pth is the copy that
    # each round loads from.
    copyfile(pretrained_ckpt, os.path.join(fpa_dir, 'ckpt0.pth'))
    copyfile(pretrained_ckpt, working_ckpt)

    for FPA_iter in range(1, n_iter+1): 
        ####### Attack generation 
        ckpt = torch.load(working_ckpt)
        model.load_state_dict(ckpt['net'])
        # Prepare Test dataset
        name = "prTTT_FPA{}_pgd8".format(FPA_iter)  
        (_,_), (_, data_loader) = prepare_corruption_data(args.corruption)
        prepare_pgd_attack_data(args, data_loader, model, name) 
        ####### TTT adaptation 
        # NOTE(review): this file explicitly imports TTT_adapt_adv only;
        # TTT_adapt has to be supplied by one of the wildcard imports --
        # confirm that this is the intended adaptation routine.
        TTT_adapt(args, 'attack_data/{}'.format(name))
        # Keep a per-round snapshot of the working checkpoint. Presumably
        # TTT_adapt has rewritten ckpt.pth by this point -- TODO confirm.
        copyfile(working_ckpt,
                 os.path.join(fpa_dir, 'ckpt{}.pth'.format(FPA_iter)))

if __name__ == '__main__': 
    # Command-line options as (flag, add_argument keyword arguments) pairs,
    # listed in the order they are registered with the parser.
    option_specs = [
        ('--dataset', dict(default='cifar10')),
        ('--corruption', dict(default='fog')),
        ('--level', dict(default=5, type=int)),
        ('--dataroot', dict(default='/nobackup/yguo/datasets/')),
        ('--shared', dict(default='layer2')),
        ####################################################################
        ('--depth', dict(default=26, type=int)),
        ('--width', dict(default=1, type=int)),
        ('--batch_size', dict(default=32, type=int)),
        ('--group_norm', dict(default=8, type=int)),
        # store_false: both flags default to True and passing the flag on the
        # command line switches the value to False.
        ('--fix_bn', dict(action='store_false')),
        ('--fix_ssh', dict(action='store_false')),
        ####################################################################
        ('--lr', dict(default=0.001, type=float)),
        ('--niter', dict(default=1, type=int)),
        ('--online', dict(action='store_true')),
        ('--threshold', dict(default=1, type=float)),
        ('--dset_size', dict(default=0, type=int)),
        ####################################################################
        ('--outf', dict(default='results/prTTT_FPA_pgd8/')),
        ('--resume', dict(default='results/cifar10_layer2_gn_expand/FPA')),
    ]

    parser = argparse.ArgumentParser()
    for opt_flag, opt_kwargs in option_specs:
        parser.add_argument(opt_flag, **opt_kwargs)

    # Parse the command line and run the procedure with its default number
    # of rounds.
    args = parser.parse_args()
    FPA(args)