import torch
def eot_pgd_attack(model, x, y, eps, alpha, iters=10, eot=1, text_tokens=None, norm='linf'):
    """Run a PGD attack whose per-step gradient is averaged over ``eot`` samples.

    Starting from a random point inside the ``eps``-ball around ``x``, take
    ``iters`` gradient-ascent steps on the cross-entropy loss. Each step
    averages the input gradient over ``eot`` forward/backward passes. After
    each step the iterate is projected back onto the ``eps``-ball and clamped
    to the [0, 1] range.

    Inputs are CLIP-normalized (fixed mean/std below) before being fed to
    ``model``; the perturbation itself lives in un-normalized [0, 1] space.

    Args:
        model: Callable invoked as ``model(x_norm, text_tokens)``. It must return
            a 2-tuple whose first element is the logits used for cross-entropy.
        x: Clean inputs in [0, 1]. Presumably (B, 3, H, W) images, since the CLIP
            mean/std are broadcast with shape (1, 3, 1, 1) — TODO confirm.
        y: Class-index targets passed to ``F.cross_entropy``.
        eps: Radius of the perturbation ball (Linf or L2, per ``norm``).
        alpha: Step size per iteration.
        iters: Number of PGD steps.
        eot: Number of gradient samples averaged per step; must be >= 1. More
            than one sample only helps if ``model`` is stochastic.
        text_tokens: Forwarded unchanged as ``model``'s second argument.
        norm: ``'linf'`` (default) or ``'l2'``, case-insensitive; ``'inf'`` is
            accepted as an alias for ``'linf'``.

    Returns:
        Detached adversarial inputs with the same shape as ``x``, in [0, 1].

    Raises:
        ValueError: If ``norm`` is not supported or ``eot < 1``.
    """
    # Local import: the module only imports ``torch`` at top level.
    import torch.nn.functional as F

    # Normalize the norm name so the documented default 'linf' and the
    # 'Linf'/'L2' spellings all work.
    norm_key = str(norm).lower()
    if norm_key == 'inf':
        norm_key = 'linf'
    if norm_key not in ('linf', 'l2'):
        raise ValueError(f"Unsupported norm type: {norm!r}")
    if eot < 1:
        raise ValueError(f"eot must be >= 1, got {eot}")

    device = x.device
    B = x.size(0)
    # Broadcast shape for per-sample scalars, e.g. (B, 1, 1, 1) for 4-D input.
    per_sample = (B,) + (1,) * (x.dim() - 1)
    clip_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=device).view(1, 3, 1, 1)
    clip_std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device).view(1, 3, 1, 1)

    # Random start inside the eps-ball. Both norms are clamped to the valid
    # [0, 1] range (previously only L2 was).
    if norm_key == 'linf':
        delta = torch.empty_like(x).uniform_(-eps, eps)
    else:
        delta = torch.randn_like(x).reshape(B, -1)
        delta = delta / (delta.norm(p=2, dim=1, keepdim=True) + 1e-12)
        delta = delta * torch.empty(B, 1, device=device, dtype=x.dtype).uniform_(0, eps)
        delta = delta.reshape(x.shape)
    x_adv = (x + delta).clamp(0, 1).detach().requires_grad_(True)

    for _ in range(iters):
        # Expectation over transformation: average input gradients over eot passes.
        grad = torch.zeros_like(x)
        for _sample in range(eot):
            x_norm = (x_adv - clip_mean) / clip_std
            logits, _aux = model(x_norm, text_tokens)
            loss = F.cross_entropy(logits, y)
            # Gradient w.r.t. the input only; model parameters' .grad is untouched.
            grad += torch.autograd.grad(loss, x_adv)[0].detach()
        grad /= eot

        with torch.no_grad():
            if norm_key == 'linf':
                # Linf steepest ascent: move every coordinate by alpha along sign(grad).
                x_adv = x_adv + alpha * grad.sign()
                delta = (x_adv - x).clamp(-eps, eps)
            else:
                # L2 steepest ascent: move by alpha along the unit-L2 gradient
                # (sign() would be the Linf direction and overshoot by ~sqrt(dim)).
                grad_norm = grad.reshape(B, -1).norm(p=2, dim=1).view(per_sample)
                x_adv = x_adv + alpha * grad / (grad_norm + 1e-12)
                delta = x_adv - x
                # Project back onto the L2 ball: rescale only samples outside it.
                delta_norm = delta.reshape(B, -1).norm(p=2, dim=1)
                factor = torch.clamp(eps / (delta_norm + 1e-12), max=1.0)
                delta = delta * factor.view(per_sample)

            x_adv = (x + delta).clamp(0, 1).detach().requires_grad_(True)

    return x_adv.detach()