import torch 
from model import NEURAL
from dataset import DataMain
from torchvision import transforms
import time
import argparse
import os
import numpy as np

from corruptions import gaussian_noise
from corruptions import shot_noise
from corruptions import impulse_noise
from corruptions import defocus_blur
from corruptions import glass_blur
from corruptions import motion_blur
from corruptions import zoom_blur
from corruptions import snow
from corruptions import frost
from corruptions import fog
from corruptions import brightness
from corruptions import contrast
from corruptions import elastic_transform
from corruptions import pixelate
from corruptions import jpeg_compression
from ROA import ROA

from dataset import return_data


# Perturbation name -> severity level handed to the corruption function.
# Key order matters: the evaluation loop at the bottom of the script walks
# this mapping in insertion order.
# NOTE(review): contrast (6), brightness (8) and pixelate (7) use levels above
# 5 — confirm the local `corruptions` module defines those levels.
severity_map = dict(
    # noise
    gaussian_noise=5,
    shot_noise=5,
    impulse_noise=5,
    # blur
    glass_blur=5,
    defocus_blur=5,
    motion_blur=5,
    zoom_blur=5,
    # weather / photometric / geometric
    fog=5,
    frost=5,
    snow=5,
    contrast=6,
    brightness=8,
    elastic_transform=5,
    # compression / resampling
    jpeg_compression=5,
    pixelate=7,
    # gradient / occlusion attacks: no severity level applies
    pgd_attack_random=None,
    ROA=None,
)


# Run on the first CUDA device when one is present, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Batch size requested from the data loaders.
batch_size = 1


def pgd_attack(model, images, labels, eps=0.3, alpha=2/255, iters=40):
    """Untargeted L_inf PGD attack starting from the clean images.

    Each of ``iters`` steps moves the images by ``alpha`` along the sign of
    the cross-entropy gradient. It then projects back into the ``eps`` L_inf
    ball around the original images and clamps to [-0.5, 0.5].

    Args:
        model: classifier mapping a batch of images to logits.
        images: clean input batch; it is copied, never modified.
        labels: ground-truth class indices for ``images``.
        eps: L_inf radius of the perturbation.
        alpha: step size per iteration.
        iters: number of gradient steps.

    Returns:
        Detached adversarial images on ``device``.
    """
    # Detach first so a caller tensor that requires grad cannot leak its graph
    # into the attack.
    images = images.detach().clone().to(device)
    labels = labels.to(device)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Independent copy of the starting point used for projection.
    ori_images = images.clone()

    for _ in range(iters):
        images.requires_grad_(True)
        loss = loss_fn(model(images), labels)
        # autograd.grad returns only d(loss)/d(images) and leaves the model's
        # parameter .grad fields untouched (no model.zero_grad() needed).
        grad, = torch.autograd.grad(loss, images)
        with torch.no_grad():
            adv_images = images + alpha * grad.sign()
            eta = torch.clamp(adv_images - ori_images, min=-eps, max=eps)
            images = torch.clamp(ori_images + eta, min=-0.5, max=0.5)

    return images


def pgd_attack_random(model, images, labels, eps=1, alpha=1, iters=40, randomize=True):
    """Construct L_inf PGD adversarial examples for the batch ``images``.

    The model is switched to eval mode. The perturbation ``delta`` starts
    uniformly in [-eps, eps] when ``randomize`` is set, and at zero otherwise.
    Each step adds ``alpha`` times the sign of the loss gradient, clamps
    ``delta`` to [-eps, eps], and keeps ``images + delta`` inside [-0.5, 0.5].

    Args:
        model: classifier mapping a batch of images to logits.
        images: clean input batch; it is not modified.
        labels: ground-truth class indices for ``images``.
        eps: L_inf radius of the perturbation.
        alpha: step size per iteration.
        iters: number of gradient steps.
        randomize: use a random start instead of zero.

    Returns:
        Detached adversarial batch ``images + delta`` on ``device``.
    """
    model.eval()
    images = images.to(device)
    labels = labels.to(device)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Build delta on the same device as images. Calling .to(device) on a
    # tensor that already requires grad would produce a non-leaf tensor
    # whose .grad stays None.
    with torch.no_grad():
        if randomize:
            delta = torch.rand_like(images) * 2 * eps - eps
            delta = (delta + images).clamp(-0.5, 0.5) - images
        else:
            delta = torch.zeros_like(images)

    for _ in range(iters):
        delta.requires_grad_(True)
        loss = loss_fn(model(images + delta), labels)
        # Gradient w.r.t. delta only; model parameter grads are left untouched.
        grad, = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            delta = (delta + alpha * grad.sign()).clamp(-eps, eps)
            delta = (delta + images).clamp(-0.5, 0.5) - images

    return (images + delta).detach()


def _corrupt_batch(X, corruptor, severity):
    """Apply ``corruptor`` at ``severity`` to every image of batch ``X``.

    Each image goes through ``transforms.ToPILImage`` and then ``corruptor``.
    The array-like result is divided by 255, converted to float32, and
    permuted from HWC to CHW. The images are then stacked back into a batch
    on ``device``.
    """
    # NOTE(review): ToPILImage multiplies float tensors by 255 (it expects
    # [0, 1]), and the corrupted output here is rescaled to [0, 1]. The PGD
    # attacks in this file, however, clamp inputs to [-0.5, 0.5]. Confirm the
    # loader's value range so clean and corrupted inputs share one scale.
    to_pil = transforms.ToPILImage()
    corrupted = []
    for img in X:
        # np.asarray accepts both ndarray outputs and PIL-image outputs
        # (the original special-cased jpeg_compression / pixelate for this).
        out = np.asarray(corruptor(to_pil(img), severity))
        out = torch.from_numpy(out / 255.).to(dtype=torch.float32)
        corrupted.append(out.permute(2, 0, 1))  # HWC -> CHW
    return torch.stack(corrupted).to(device)


def _perturb(model, X, GT, attack, severity):
    """Return the perturbed version of batch ``X`` for the named ``attack``."""
    if attack == 'pgd_attack_random':
        attack_eps = 8  # L_inf budget, in units of 1/255
        return pgd_attack_random(model, X, GT, eps=attack_eps / 255.0,
                                 alpha=1 / 255, iters=100, randomize=True)
    if attack == 'ROA':
        # Positional arguments are forwarded unchanged; see
        # ROA.exhaustive_search for their meaning.
        attacker = ROA(model, size=32)
        return attacker.exhaustive_search(X, GT, 0.05, 30, 5, 5, 2, 2, False)
    # Every other key names a corruption function imported at module level.
    # A namespace lookup resolves it without eval().
    return _corrupt_batch(X, globals()[attack], severity)


def test(data, model, attack):
    """Print clean and perturbed top-1 accuracy of ``model`` over ``data``.

    Args:
        data: iterable of ``(inputs, targets)`` batches.
        model: classifier mapping a batch of images to logits.
        attack: a key of ``severity_map``. Either 'pgd_attack_random',
            'ROA', or the name of a corruption function.

    Returns:
        ``(correct, correct_adv, tot)``: counts of correct clean predictions,
        correct perturbed predictions, and samples seen.

    Raises:
        KeyError: if ``attack`` is not a key of ``severity_map``.
    """
    severity = severity_map[attack]
    correct = correct_adv = tot = 0

    for inputs, targets in data:
        # as_tensor avoids the copy-construct warning torch.tensor(tensor) emits.
        X = torch.as_tensor(inputs, dtype=torch.float32).to(device)
        GT = targets

        X_adv = _perturb(model, X, GT, attack, severity)

        # Pure inference: no autograd graph is needed for the predictions.
        with torch.no_grad():
            Y = torch.argmax(model(X), dim=1)
            Y_adv = torch.argmax(model(X_adv), dim=1)

        labels = torch.as_tensor(GT).reshape(-1).to(Y.device)
        tot += len(Y)
        correct += (Y == labels).sum().item()
        correct_adv += (Y_adv == labels).sum().item()

    if tot == 0:
        print('acc = n/a, adv_acc = n/a (no samples)')
        return correct, correct_adv, tot

    print('acc = %d/%d (%.2f%%), adv_acc = %d/%d (%.2f%%)' % (correct, tot, (100 * correct / tot),
                                                            correct_adv, tot, (100 * correct_adv / tot)))
    return correct, correct_adv, tot





# --- Data -------------------------------------------------------------------
print('[Data] Preparing .... ')
dataset = 'real_data'
trainloader, testloader, len_trainset, len_testset = return_data(
    dataset,
    batch_size,
    image_size=32,
)
print('[Data] Done .... ')


# --- Model ------------------------------------------------------------------
print('[Model] Preparing .... ')
# nn.Module.to returns the module itself, so chaining keeps the same object.
model = NEURAL(n_class=8, n_channel=3).to(device)
model.eval()
print('[Model] Done .... ')

# Checkpoint to evaluate (a state_dict loaded onto `device`).
# NOTE(review): for a plain state_dict, torch.load(..., weights_only=True)
# avoids unpickling arbitrary objects on PyTorch versions that support it.
ckpt_path = './adv_train_8_ckpt/model_1_adv_acc=0.617500.ckpt'
ckpt = torch.load(ckpt_path, map_location=device)
model.load_state_dict(ckpt)

# --- Evaluation: one accuracy report per perturbation, in severity_map order -
for attack in severity_map:
    print(f'------------- Adv acc for {attack} ---------------')
    test(testloader, model, attack)
