"""
https://github.com/ferjad/Universal_Adversarial_Perturbation_pytorch
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>

"""
from deeprobust.image.attack import deepfool
import collections
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torch
import torch.optim as optim
import torch.utils.data as data_utils
import math
from PIL import Image
import torchvision.models as models
import sys
import random
import time
from tqdm import tqdm

def zero_gradients(x):
    """Reset accumulated gradients to zero, in place.

    Parameters
    ----------
    x :
        A tensor, or an (arbitrarily nested) iterable of tensors. A tensor
        whose ``.grad`` is ``None`` is left untouched. Objects that are
        neither tensors nor iterables are ignored.
    """
    if isinstance(x, torch.Tensor):
        grad = x.grad
        if grad is None:
            return
        # Cut the gradient loose from any graph before clearing it.
        grad.detach_()
        grad.zero_()
        return

    if not isinstance(x, collections.abc.Iterable):
        return

    # Containers are handled by recursing into every member.
    for member in x:
        zero_gradients(member)

def get_model(model,device):
    """Build a torchvision classifier in evaluation mode on ``device``.

    Parameters
    ----------
    model :
        architecture name; ``'vgg16'`` and ``'resnet18'`` are supported
    device :
        device the network is moved to

    Returns
    -------
        the network, requested with ``pretrained=True``, switched to
        ``eval()`` mode and moved to ``device``.

    Raises
    ------
    ValueError
        if ``model`` is not one of the supported architecture names.
    """
    if model == 'vgg16':
        net = models.vgg16(pretrained=True)
    elif model == 'resnet18':
        net = models.resnet18(pretrained=True)
    else:
        # Without this branch ``net`` would be unbound and the function would
        # die with an opaque UnboundLocalError further down.
        raise ValueError(
            "Unsupported model %r; expected 'vgg16' or 'resnet18'" % (model,)
        )

    net.eval()
    net = net.to(device)
    return net

def data_input_init(xi):
    """Return the normalisation statistics and the image preprocessing pipeline.

    Parameters
    ----------
    xi :
        accepted for interface compatibility; not used by this function

    Returns
    -------
        a ``(mean, std, transform)`` tuple, where ``mean`` and ``std`` are the
        per-channel lists handed to ``Normalize`` and ``transform`` resizes to
        256, center-crops to 224, converts to a tensor and normalises.
    """
    mean = [ 0.485, 0.456, 0.406 ]
    std = [ 0.229, 0.224, 0.225 ]

    # Steps are applied in list order.
    pipeline = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean = mean, std = std),
    ]
    transform = transforms.Compose(pipeline)

    return mean, std, transform

def proj_lp(v, xi, p):
    """Bring ``v`` back inside the l_p ball of radius ``xi`` centered at 0.

    Parameters
    ----------
    v :
        perturbation tensor
    xi :
        radius of the ball
    p :
        norm order; ``np.inf`` clamps element-wise, any other value rescales
        ``v`` by its ``p``-norm

    Returns
    -------
        the constrained perturbation.
    """
    if p == np.inf:
        # For the infinity norm, clip every entry to [-xi, xi].
        return torch.clamp(v, -xi, xi)

    # Shrink only when the norm exceeds xi; the small constant guards
    # against division by zero for an all-zero perturbation.
    shrink = xi / (torch.norm(v, p) + 0.00001)
    return v * min(1, shrink)

def get_fooling_rate(data_list,v,model, device):
    """Measure how often adding ``v`` changes the model's predicted label.

    Parameters
    ----------
    data_list :
        sized collection of image file paths (anything ``Image.open`` accepts)
    v :
        perturbation added to every preprocessed image; must be broadcastable
        to a ``(1, C, H, W)`` batch and live on ``device``
    model :
        target model
    device :
        device the images are moved to before inference

    Returns
    -------
        ``(fooling_rate, model)``: the fraction of images whose prediction
        changed (``0.0`` for an empty ``data_list``) and the model, whose
        parameters have all had ``requires_grad`` set to ``False``.
    """
    num_images = len(data_list)
    fooled = 0

    if num_images:
        # data_input_init returns (mean, std, transform); only the transform
        # is needed here.
        tf = data_input_init(0)[2]

        # Pure evaluation: no autograd graph is needed for the forward passes.
        with torch.no_grad():
            for name in tqdm(data_list):
                # Close the file promptly and force three channels so that
                # grayscale/RGBA files do not break the 3-channel Normalize.
                with Image.open(name) as raw:
                    image = tf(raw.convert('RGB'))
                image = image.unsqueeze(0).to(device)

                _, pred = torch.max(model(image), 1)
                _, adv_pred = torch.max(model(image + v), 1)
                if pred != adv_pred:
                    fooled += 1

    # Compute the fooling rate (guarding against an empty image list)
    fooling_rate = fooled / num_images if num_images else 0.0
    print('Fooling Rate = ', fooling_rate)

    # Side effect kept from the original implementation: the model's
    # parameters are frozen before it is handed back.
    for param in model.parameters():
        param.requires_grad = False

    return fooling_rate, model

def universal_adversarial_perturbation(dataloader, model, device, xi=10, delta=0.2, max_iter_uni = 10, p=np.inf,
                                       num_classes=10, overshoot=0.02, max_iter_df=10,t_p = 0.2):
    """universal_adversarial_perturbation.

    Builds one perturbation ``v`` by sweeping over ``dataloader``: for every
    sample that ``v`` does not fool yet, DeepFool is run on the perturbed
    sample and its result is added to ``v``, which is then projected back
    onto the l_p ball of radius ``xi``. Sweeps repeat until the fooling rate
    measured on ``dataloader`` reaches ``1 - delta`` or ``max_iter_uni``
    sweeps have been made.

    Parameters
    ----------
    dataloader :
        re-iterable of ``(image, label)`` pairs; it is traversed once per
        sweep for the update and once more to measure the fooling rate.
        NOTE(review): the ``pred == adv_pred`` test is only unambiguous for
        a single prediction, so a batch size of 1 is assumed - confirm.
    model :
        target model
    device :
        device
    xi :
        controls the l_p magnitude of the perturbation (default = 10)
    delta :
        controls the desired fooling rate (default = 0.2, i.e. 80% fooling rate)
    max_iter_uni :
        maximum number of sweeps over the dataloader (default = 10)
    p :
        norm to be used (default = np.inf)
    num_classes :
        num_classes passed to deepfool (default = 10)
    overshoot :
        to prevent vanishing updates (default = 0.02)
    max_iter_df :
        maximum number of iterations for deepfool (default = 10)
    t_p :
        truth percentage, for how many flipped labels in a batch.
        (default = 0.2). Currently not used by this function.

    Returns
    -------
        the universal perturbation, a ``(1, 3, 224, 224)`` tensor on
        ``device`` that does not track gradients.
    """
    # The perturbation is updated by hand, never by backprop, so it is kept
    # out of autograd; otherwise every update would extend the graph.
    v = torch.zeros(1, 3, 224, 224).to(device)

    fooling_rate = 0.0
    itr = 0

    while fooling_rate < 1 - delta and itr < max_iter_uni:

        # Iterate over the dataset and compute the perturbation incrementally
        for i, (img, _) in enumerate(dataloader):
            # The model and v live on `device`; the loader output may not.
            img = img.to(device)

            with torch.no_grad():
                _, pred = torch.max(model(img), 1)
                _, adv_pred = torch.max(model(img + v), 1)

            if pred == adv_pred:
                # v does not fool this sample yet: ask DeepFool for the extra
                # perturbation needed on top of the current one.
                # NOTE(review): confirm against the installed deeprobust that
                # `deepfool` is callable as `deepfool(model, device)` and that
                # `generate` / `getpurb` exist with these keyword names.
                perturb = deepfool(model, device)
                _ = perturb.generate(img + v, num_classes = num_classes, overshoot = overshoot, max_iter = max_iter_df)
                dr, df_iters = perturb.getpurb()

                # Only accept the update when DeepFool stopped before
                # exhausting its iteration budget.
                if df_iters < max_iter_df - 1:
                    # Match v's dtype so the sum is not promoted to a type
                    # the model cannot consume.
                    dr = torch.from_numpy(dr).to(device=device, dtype=v.dtype)
                    v = proj_lp(v + dr, xi, p)

            if i % 10 == 0:
                print('Norm of v: '+str(torch.norm(v).detach().cpu().numpy()))

        fooling_rate = _fooling_rate_on_loader(dataloader, v, model, device)
        print('Fooling Rate = ', fooling_rate)
        itr = itr + 1

    return v


def _fooling_rate_on_loader(dataloader, v, model, device):
    """Return the fraction of samples in ``dataloader`` whose label flips under ``v``."""
    fooled = 0
    total = 0
    with torch.no_grad():
        for img, _ in dataloader:
            img = img.to(device)
            _, pred = torch.max(model(img), 1)
            _, adv_pred = torch.max(model(img + v), 1)
            fooled += int((pred != adv_pred).sum())
            total += pred.numel()
    # An empty loader yields a rate of 0.0 instead of dividing by zero.
    return fooled / total if total else 0.0
