import torch
from tqdm import trange 
import torch.nn.functional as F
class BPDA:
    """Untargeted L-infinity BPDA attack with Expectation over Transformation (EOT).

    BPDA (Backward Pass Differentiable Approximation) deals with a
    preprocessing step that is run without gradient tracking. On the backward
    pass it treats that step as the identity: the gradient of the
    classification loss with respect to the *preprocessed* input is applied
    directly to the *raw* input.

    Args:
        get_logit: Object exposing two methods used by :meth:`forward`:
            ``get_img_logits(x)``, whose ``[0]`` element is taken as the
            preprocessed input, and ``classify(preprocessed)``, which returns
            logits passed to ``F.cross_entropy``. The preprocessed input must
            be shaped like ``x`` (or broadcastable to it), because its
            gradient is accumulated into a buffer shaped like ``x``.
        attack_steps: Number of signed-gradient ascent steps.
        eps: Radius of the L-infinity ball around the clean input.
        step_size: Size of each signed update.
        target: Stored on the instance but not used by :meth:`forward`; the
            attack always maximises the loss with respect to ``y``.
        eot: Number of preprocessing passes whose gradients are summed per
            step (presumably for a stochastic preprocessing step).
    """

    def __init__(self, get_logit, attack_steps=200, eps=0.5, step_size=0.007, target=None, eot=20):
        # NOTE(review): unused by forward(); kept for API compatibility.
        self.target = target
        # Valid input range; every iterate is clipped into it.
        self.clamp = (0, 1)
        self.eps = eps
        self.step_size = step_size
        self.get_logit = get_logit
        self.attack_steps = attack_steps
        self.eot = eot

    def _random_init(self, x):
        """Return ``x`` plus uniform noise in ``[-eps, eps)``, clipped to ``self.clamp``.

        Not called by :meth:`forward`; available to callers that want a
        random start inside the eps-ball.
        """
        x = x + (torch.rand(x.size(), dtype=x.dtype, device=x.device) - 0.5) * 2 * self.eps
        x = torch.clamp(x, *self.clamp)
        return x

    def forward(self, x, y):
        """Run the attack and return the adversarial examples.

        Args:
            x: Clean input batch (assumed to lie within ``self.clamp``).
            y: Class-index targets passed to ``F.cross_entropy``.

        Returns:
            A detached tensor shaped like ``x``. Each entry is clipped into
            ``self.clamp`` and, when ``x`` itself lies in that range, is
            within ``eps`` of ``x`` in the L-infinity norm.
        """
        x = x.detach()
        x_adv = x.clone()
        for _ in trange(self.attack_steps):
            grad = torch.zeros_like(x_adv)
            for _ in range(self.eot):
                with torch.no_grad():
                    preprocessed = self.get_logit.get_img_logits(x_adv)[0]
                # Make a fresh leaf instead of flipping requires_grad on a tensor
                # owned by get_img_logits (which may even be x_adv itself).
                preprocessed = preprocessed.detach().requires_grad_(True)
                # Re-enable autograd in case forward() runs under torch.no_grad().
                with torch.enable_grad():
                    logits = self.get_logit.classify(preprocessed)
                    loss = F.cross_entropy(logits, y, reduction="sum")
                    # BPDA: the gradient w.r.t. the preprocessed input stands
                    # in for the gradient w.r.t. x_adv (identity backward).
                    grad += torch.autograd.grad(loss, [preprocessed])[0]
            # sign() ignores positive scaling, so averaging over eot is
            # unnecessary; skipping it also avoids 0/0 = NaN when eot == 0.
            x_adv = x_adv + self.step_size * grad.sign()
            # Project back into the eps-ball around x, then into the valid range.
            x_adv = x + torch.clamp(x_adv - x, min=-self.eps, max=self.eps)
            x_adv = torch.clamp(x_adv, *self.clamp).detach()
        return x_adv