import torch
import torch.nn.functional as F

def fgsm_attack(model, data, target, epsilon, norm):
    """Craft a single-step FGSM adversarial example for one input sample.

    Takes one step of size ``epsilon`` in the direction that increases the
    negative log-likelihood loss w.r.t. the input, then clamps the result to
    ``[0, 1]``.

    Args:
        model: Callable applied to ``data``. Its output is fed to
            ``F.nll_loss``, so it presumably returns log-probabilities
            (e.g. ends in ``log_softmax``) -- TODO confirm. ``data`` is moved
            to CUDA when available, so the model must accept inputs there.
        data: A single input sample (passed as-is to ``model``). The caller's
            tensor is not modified.
        target: True class index for ``data`` (tensor or int).
        epsilon: Step size. For ``norm=2`` the step has L2 norm ``epsilon``
            (before clamping); for ``norm=inf`` each element moves by
            ``epsilon``.
        norm: ``2`` or ``float('inf')``.

    Returns:
        A tensor with no autograd history, on the selected device. If the
        model already misclassifies ``data``, the input is returned
        unchanged; otherwise the perturbed, clamped input.

    Raises:
        ValueError: If ``norm`` is neither 2 nor infinity.
    """
    # Validate up front: an `assert` would be stripped under `python -O` and
    # the attack would silently use an unnormalized gradient.
    if norm != 2 and norm != float("inf"):
        raise ValueError(f"Only L2 and Linf norms are supported. Got {norm}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # detach() yields a fresh leaf, so enabling grad never flips the flag on
    # the caller's tensor (which may be a view into a whole dataset).
    data = data.to(device).detach().requires_grad_(True)
    target = torch.as_tensor(target, device=device)

    output = model(data)
    # Index of the highest score along the last (class) dimension.
    init_pred = output.argmax(dim=-1)

    # If the initial prediction is wrong, don't bother attacking, just move on
    if init_pred.item() != target.item():
        return data.detach()

    loss = F.nll_loss(output, target)
    # Differentiate w.r.t. the input only; unlike model.zero_grad() +
    # loss.backward(), this leaves the model's parameter gradients untouched.
    (data_grad,) = torch.autograd.grad(loss, data)

    if norm == 2:
        grad_norm = torch.norm(data_grad, p=2)
        # An all-zero gradient would give 0/0 = NaN; keep it as a zero step.
        if grad_norm > 0:
            data_grad = data_grad / grad_norm
    else:
        data_grad = data_grad.sign()

    perturbed_data = data + epsilon * data_grad
    # Clip to maintain the [0, 1] range, and drop the autograd graph.
    return torch.clamp(perturbed_data, 0, 1).detach()


def perturb_dataset(model, X, y, epsilon, norm=2):
    """Apply ``fgsm_attack`` to every sample of a dataset, one at a time.

    Args:
        model: Model forwarded to ``fgsm_attack``.
        X: Input tensor; samples are taken along dim 0.
        y: Per-sample targets, iterated in step with ``X``.
        epsilon: Attack step size, forwarded to ``fgsm_attack``.
        norm: ``2`` (default) or ``float('inf')``, forwarded to
            ``fgsm_attack``.

    Returns:
        The per-sample results stacked along a new dim 0, with no autograd
        history. An empty ``X`` yields a detached copy of ``X``.

    Raises:
        ValueError: If ``X`` and ``y`` have different lengths.
    """
    if len(X) != len(y):
        raise ValueError(
            f"X and y must have the same length. Got {len(X)} and {len(y)}"
        )
    if len(X) == 0:
        # torch.stack rejects an empty list; an empty dataset maps to an
        # empty result (kept on X's device).
        return X.detach().clone()

    perturbed_X = [
        # detach() so the stacked tensor does not keep every per-sample
        # autograd graph alive.
        fgsm_attack(model, data, target, epsilon, norm).detach()
        for data, target in zip(X, y)
    ]
    return torch.stack(perturbed_X)
