import os
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torchattacks


# def attack_collate(batch):
#     images, labels = zip(*batch)
#     atk = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=4)
#     adv_images = atk(images, labels)


class Attack:
    """Build a torchattacks attack once and apply it to (data, targets) batches.

    Args:
        model: model passed to the torchattacks constructor; the attack
            queries it to craft the perturbation.
        attack_type (str): one of "pgd" (default), "fgsm", "nifgsm", "fab",
            "jitter", "autoattack", "square". Any other value falls back to
            PGD (as before) but now emits a ``UserWarning`` so typos are
            not silently masked.
        eps (float): maximum perturbation.
        alpha (float): step size (used by "pgd", "nifgsm", "jitter").
        steps (int): number of steps (used by "pgd", "nifgsm", "jitter", "fab").
        n_class (int): number of classes (used by "autoattack" and "fab").
    """

    def __init__(self, model, attack_type="pgd",
                 eps=8/255, alpha=2/255, steps=4, n_class=10):
        # Lazy builders: only the selected attack is actually constructed.
        builders = {
            "fgsm": lambda: torchattacks.FGSM(model, eps=eps),
            "nifgsm": lambda: torchattacks.NIFGSM(
                model, eps=eps, alpha=alpha, steps=steps),
            # FAB is exported at the package top level; the previous
            # ``torchattacks.fab.FAB`` path is not a guaranteed public
            # attribute. Pass n_classes for consistency with AutoAttack.
            "fab": lambda: torchattacks.FAB(
                model, eps=eps, steps=steps, n_classes=n_class),
            "jitter": lambda: torchattacks.Jitter(
                model, eps=eps, alpha=alpha, steps=steps),
            "autoattack": lambda: torchattacks.AutoAttack(
                model, eps=eps, n_classes=n_class),
            "square": lambda: torchattacks.Square(model, eps=eps),
            "pgd": lambda: torchattacks.PGD(
                model, eps=eps, alpha=alpha, steps=steps),
        }
        builder = builders.get(attack_type)
        if builder is None:
            # Keep the historical PGD fallback, but make it visible.
            import warnings
            warnings.warn(
                f"Unknown attack_type {attack_type!r}; falling back to 'pgd'.",
                stacklevel=2,
            )
            builder = builders["pgd"]
        self.atk = builder()

    def __call__(self, data, targets):
        """Return adversarial versions of ``data`` crafted against ``targets``.

        Args:
            data: input batch forwarded unchanged to the attack (presumably an
                image tensor scaled the way torchattacks expects, i.e. [0, 1]
                — TODO confirm against the caller's preprocessing).
            targets: labels forwarded unchanged to the attack.

        Returns:
            Whatever the underlying torchattacks attack returns for this batch.
        """
        return self.atk(data, targets)
