"""
Fast Gradient Sign Method (FGSM) for generating adversarial examples.
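Adapted from a reference implementation with local modifications (notably, the
default step uses the L2-normalized gradient rather than its sign, and output
clipping is off), so it is not a verbatim copy.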
"""
import torch
import torch.nn.functional as F


def fgsm_attack(model, data, target, epsilon, debug=False, *, norm="l2", clip=None):
    """Perturb a single input in the direction that increases the model's loss.

    Despite the name, the default step is the L2-normalized input gradient
    (the L2 "fast gradient method"), not the sign of the gradient used by
    classic FGSM; pass ``norm="linf"`` for the classic sign step.

    Args:
        model: Callable mapping one unbatched input to a 1-D tensor of class
            logits (cross-entropy is applied here, so raw logits are
            expected). If it exposes ``parameters()``, inputs are moved to the
            device of its first parameter; otherwise CUDA is used when
            available.
        data: A single floating-point input tensor (no batch dimension).
        target: The true class index, as a tensor holding one element.
        epsilon: Step size: the L2 length of the perturbation for ``"l2"``,
            the per-element magnitude for ``"linf"`` (before any clipping).
        debug: Print summary statistics of intermediate tensors.
        norm: ``"l2"`` (default) or ``"linf"``.
        clip: Optional ``(min, max)`` bounds to clamp the result to. ``None``
            (default) leaves the result unclipped.

    Returns:
        A tensor detached from autograd, on the chosen device, with the same
        shape as ``data``. If the model already misclassifies ``data``, or the
        input gradient is zero, the input is returned unperturbed.

    Raises:
        ValueError: If ``norm`` is not ``"l2"`` or ``"linf"``.

    Neither ``data`` itself nor the gradients stored on the model's
    parameters are modified.
    """
    if norm not in ("l2", "linf"):
        raise ValueError(f"norm must be 'l2' or 'linf', got {norm!r}")

    # Follow the model's device so a CPU model still works on a CUDA machine;
    # fall back to the CUDA-if-available heuristic for parameter-less callables.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Work on a detached alias so the caller's tensor never gets
    # requires_grad set or a .grad accumulated on it.
    x = data.detach().to(device).requires_grad_(True)
    target = target.to(device)

    output = model(x)

    if debug:
        print(f"  Original data: min={x.min():.4f}, max={x.max():.4f}, mean={x.mean():.4f}")

    # Index of the largest logit; assumes ``output`` is 1-D (a single sample).
    init_pred = output.argmax(dim=0)

    # If the initial prediction is wrong, don't bother attacking, just move on
    if init_pred.item() != target.item():
        return x.detach()

    # cross_entropy == nll_loss(log_softmax(.)): it expects raw logits.
    loss = F.cross_entropy(output.unsqueeze(0), target.reshape(1))

    # Gradient w.r.t. the input only. Unlike model.zero_grad() + backward(),
    # this leaves gradients accumulated on the model's parameters untouched.
    (data_grad,) = torch.autograd.grad(loss, x, allow_unused=True)
    if data_grad is None:
        data_grad = torch.zeros_like(x)

    if debug:
        print(f"  Gradient: min={data_grad.min():.4f}, max={data_grad.max():.4f}, mean={data_grad.mean():.4f}")

    if norm == "linf":
        direction = data_grad.sign()
    else:
        grad_norm = torch.linalg.vector_norm(data_grad)
        # A zero gradient would give 0/0 = NaN; take no step instead.
        direction = data_grad / grad_norm if grad_norm > 0 else torch.zeros_like(data_grad)

    if debug:
        print(f"  Normalized grad: min={direction.min():.4f}, max={direction.max():.4f}")

    perturbed_data = x.detach() + epsilon * direction

    if debug:
        print(f"  Before clip: min={perturbed_data.min():.4f}, max={perturbed_data.max():.4f}, mean={perturbed_data.mean():.4f}")

    # Clipping is opt-in: the reference implementation clamps to [0, 1], but
    # on data not already in [0, 1] that clamp destroys the perturbation
    # wherever the gradient pushes values below 0.
    if clip is not None:
        perturbed_data = perturbed_data.clamp(*clip)

    if debug:
        label = "no clip" if clip is None else f"clipped to {clip}"
        print(f"  Final ({label}): min={perturbed_data.min():.4f}, max={perturbed_data.max():.4f}, mean={perturbed_data.mean():.4f}")

    return perturbed_data


def perturb_dataset(model, X, y, epsilon, debug_first=False):
    """Run ``fgsm_attack`` on every sample of ``X`` and stack the results.

    Args:
        model: Passed through to ``fgsm_attack``.
        X: Tensor whose first dimension indexes samples; each ``X[i]`` is
            attacked individually (without a batch dimension).
        y: Targets indexed in step with ``X`` (``y[i]`` belongs to ``X[i]``).
        epsilon: Step size forwarded to ``fgsm_attack``.
        debug_first: Enable ``fgsm_attack``'s debug output for sample 0 only.

    Returns:
        The per-sample results stacked along a new dim 0, detached from
        autograd. An empty ``X`` yields a detached copy of ``X``.
    """
    if X.shape[0] == 0:
        # torch.stack rejects an empty list, so handle the no-sample case here.
        return X.detach().clone()

    perturbed_X = []
    for i, data in enumerate(X):
        perturbed_data = fgsm_attack(model, data, y[i], epsilon, debug=(debug_first and i == 0))
        # Detach so the stacked result does not keep every sample's
        # autograd graph alive.
        perturbed_X.append(perturbed_data.detach())

    return torch.stack(perturbed_X)
