"""Adversarial attacks on neural networks.
"""

import torch
import torch.nn.functional as F

import functions.datasets as datasets

def fgsm_attack(batch, model, eps, data='cifar10'):
    """Craft adversarial examples with the Fast Gradient Sign Method (FGSM).

    Each input is moved by ``eps`` in the direction of the sign of the
    gradient of the cross-entropy loss with respect to that input:
    ``x_adv = x + eps * sign(grad_x cross_entropy(model(x), y))``.

    Args:
        batch: ``(images, labels)`` pair. ``images`` is passed directly to
            ``model``; ``labels`` are the class targets given to
            ``F.cross_entropy``.
        model: Callable returning logits for ``images``.
        eps: Magnitude of the per-element perturbation.
        data: Dataset name. Currently unused by this function.
            NOTE(review): possibly intended for clipping to the dataset's
            valid input range — confirm.

    Returns:
        ``(adv_images, labels)``. ``adv_images`` is a new tensor with the
        same shape as ``images``, detached from the autograd graph. No
        clipping to a valid input range is applied.
    """
    images, labels = batch
    # Work on a detached copy so the caller's tensor is never mutated
    # (requires_grad flag / .grad) and non-leaf inputs are handled.
    x = images.detach().clone().requires_grad_(True)
    # The attack needs gradients even when the caller runs under
    # torch.no_grad() (e.g. inside an evaluation loop).
    with torch.enable_grad():
        loss = F.cross_entropy(model(x), labels)
        # autograd.grad computes only d(loss)/d(x); unlike loss.backward() it
        # does not accumulate into the model parameters' .grad fields.
        (grad,) = torch.autograd.grad(loss, x)
    adv_images = images.detach() + eps * grad.sign()
    return (adv_images, labels)
