import torch.backends.cudnn as cudnn
import torch
import torch.utils.data
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
from advertorch.context import ctx_noparamgrad_and_eval
from advertorch.attacks import LinfPGDAttack
import argparse
import os
import sys
sys.path.append('./')
from utils.prepare_dataset import *
from utils.misc import init_random_seed
from utils.test_helpers import build_model
from utils.prepare_attack_dataset import *
from utils.prepare_corruption_dataset import *
from shutil import copyfile


def _str2bool(value):
    """Parse a boolean command-line value.

    argparse's ``type=bool`` is a trap: ``bool("False")`` is True because any
    non-empty string is truthy, so ``--fix_bn False`` would enable the flag.
    This accepts the common spellings case-insensitively and rejects anything
    else with an argparse error.

    Args:
        value: the raw string from the command line (or an actual bool when
            argparse passes a default/const through).

    Returns:
        The parsed bool.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='cifar10')
# NOTE(review): --level and --corruption do not select anything in the sweep
# below, which iterates over every entry of `common_corruptions` and every
# severity 1-5. They may be read inside build_model or the data helpers —
# confirm before removing.
parser.add_argument('--level', default=0, type=int)
parser.add_argument('--corruption', default='fog')
parser.add_argument('--dataroot', default='/nobackup/yguo/datasets/')
parser.add_argument('--shared', default=None)
########################################################################
parser.add_argument('--depth', default=26, type=int)
parser.add_argument('--width', default=1, type=int)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--group_norm', default=0, type=int)
# A bare `--fix_bn` means True; `--fix_bn False` now really means False.
parser.add_argument('--fix_bn', default=False, type=_str2bool, nargs='?', const=True)
parser.add_argument('--fix_ssh', default=False, type=_str2bool, nargs='?', const=True)
# Which checkpoints to evaluate:
#   'advS' - one shared checkpoint (results/pretrain/cifar10_adv_pgd7) for
#            every corruption;
#   'advT' - a separate checkpoint per corruption
#            (results/pretrain/cifar10c_<corruption>_adv_none_gn).
parser.add_argument('--adv_type', default='advT', choices=['advS', 'advT'])

args = parser.parse_args()
# Only the first of build_model's four return values is used; its weights are
# replaced by each checkpoint loaded in the evaluation loop.
model, _, _, _ = build_model(args)

# CIFAR-10-C corruption types to sweep over.
common_corruptions = ['gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', 'glass_blur',
                        'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog',
                        'brightness', 'contrast', 'elastic_transform', 'pixelate', 'jpeg_compression']
adv_type = args.adv_type

# For every corruption: load the checkpoint selected by `adv_type`, measure
# clean and PGD (nb_iter=20) accuracy at severity levels 1-5 via
# prepare_pgd_attack_data, and save both per-level lists to
# ./results/<adv_type>/<corruption>/accs_lvls.pt.
if adv_type not in ('advS', 'advT'):
    # Previously an unknown value surfaced later as a NameError on `ckpt`.
    raise ValueError("adv_type must be 'advS' or 'advT', got {!r}".format(adv_type))

for corruption in common_corruptions:
    if adv_type == 'advS':
        # Loads from a temporary copy rather than the checkpoint itself.
        # NOTE(review): presumably this protects against the original file being
        # rewritten while it is read (e.g. by a running training job) — confirm.
        copyfile('./results/pretrain/cifar10_adv_pgd7/ckpt.pth',
                 './results/pretrain/cifar10_adv_pgd7/ckpt_temp.pth')
        ckpt = torch.load('./results/pretrain/cifar10_adv_pgd7/ckpt_temp.pth')
    else:
        ckpt = torch.load(
            'results/pretrain/cifar10c_{}_adv_none_gn/ckpt.pth'.format(corruption))

    # `model` is built once at start-up. load_state_dict (strict by default)
    # overwrites all of its parameters and buffers, so the old per-corruption
    # build_model(...) call, whose result was never used, is unnecessary.
    model.load_state_dict(ckpt['net'])

    cls_accs = []
    adv_accs = []
    for level in range(1, 6):
        # Name under which the attacked test set for this corruption/level is
        # identified, e.g. "cifar10c_fog_none_gn_lvl3_advT".
        name = "cifar10c_{}_none_gn_lvl{}_{}".format(corruption, level, adv_type)
        _, _, (_, all_loader) = prepare_corruption_data_lvl(level=level, corruption=corruption)
        cls_acc, adv_acc = prepare_pgd_attack_data(args, all_loader, model, name, train=True, nb_iter=20)
        cls_accs.append(cls_acc)
        adv_accs.append(adv_acc)

    directory = os.path.join('./results', adv_type, corruption)
    os.makedirs(directory, exist_ok=True)
    torch.save({"cls_accs": cls_accs, "adv_accs": adv_accs},
               os.path.join(directory, 'accs_lvls.pt'))
