import numpy as np
import torch
from attack_steps import L2Step, LinfStep
import matplotlib.pyplot as plt


def perturb(model, x, norm, eps, step_size, steps, random_start=False) -> torch.Tensor:
  """Run a projected-gradient (PGD) attack on ``model`` starting from ``x``.

  Each iteration:
    1. takes the gradient of ``model(x).sum()`` with respect to the current input;
    2. lets ``step.step`` update the input with that gradient;
    3. applies ``step.project`` to map it back into the constraint set built
       around the original input.

  Args:
    model: Callable mapping ``x`` to an output tensor. It must be in eval mode
      (``model.training`` is False).
    x: Starting input tensor. It must not require grad.
    norm: ``'L2'`` selects ``L2Step``. Any other value selects ``LinfStep``.
    eps: Perturbation budget, forwarded to the step object.
    step_size: Per-iteration step size, forwarded to the step object.
    steps: Number of gradient iterations. A value of 0 or less performs no
      gradient steps.
    random_start: If True, start from ``step.random_perturb(x)`` instead of
      ``x``.

  Returns:
    A new tensor, detached from any autograd graph, holding the final input.
    It is always a fresh clone, so it never aliases the caller's ``x``, even
    when ``steps == 0``.

  Raises:
    AssertionError: If ``model.training`` or ``x.requires_grad`` is True.
      These checks are skipped under ``python -O``.
  """
  assert not model.training, "perturb expects the model in eval mode"
  assert not x.requires_grad, "perturb expects an input that does not require grad"

  # Snapshot the original input so the step object's constraint set is
  # anchored to values that later updates to `x` cannot change.
  x0 = x.clone().detach()
  step_class = L2Step if norm == 'L2' else LinfStep
  step = step_class(eps=eps, orig_input=x0, step_size=step_size)

  if random_start:
    x = step.random_perturb(x)

  # range() is empty for steps <= 0, so no early return is needed.
  # Every path ends at the detached clone below. The old early return leaked
  # the caller's own tensor object when steps == 0.
  for _ in range(steps):
    # Fresh leaf tensor each iteration so gradients do not chain across steps.
    x = x.clone().detach().requires_grad_(True)
    logits = model(x)
    # Summing reduces the output to a scalar for autograd.
    # NOTE(review): this treats "increase the summed output" as the attack
    # objective. It assumes step.step moves x in the loss-increasing
    # direction; confirm in attack_steps.
    loss = logits.sum()
    # Differentiate w.r.t. x only, so model parameters' .grad is left untouched.
    grad, = torch.autograd.grad(loss, [x])
    with torch.no_grad():
      x = step.step(x, grad)
      x = step.project(x)

  return x.clone().detach()
