import torch.nn.functional as F
import torch
import numpy as np


# `device` is the accelerator device the model and adversarial examples are
# placed on. Constructing a 'cuda' device object does not itself require a
# GPU; moving tensors to it does.
# `device1` is the host (CPU) device used by the attack for host-side work.
device, device1 = torch.device('cuda'), torch.device('cpu')


class LinfPGDAttack:
    """L-infinity projected gradient descent (PGD) attack on input features.

    Each step moves the features along the sign of the gradient of the
    negative log-likelihood loss on the masked entries, then projects the
    result back into the L-infinity ball of radius ``epsilon`` around the
    clean input.

    Args:
        model: callable invoked as ``model(x, edges)``; its output is indexed
            with ``mask`` and passed to ``F.nll_loss``, so it is presumably
            log-probabilities per node (confirm against the model).
        epsilon: radius of the L-infinity ball; ``0`` disables the attack.
        alpha: step size of each signed-gradient ascent step.
        k: number of PGD steps (stored as ``self.t``).
        random_start: if True, start from ``input`` plus uniform noise in
            ``[-epsilon, epsilon]`` (stored as ``self.randstart``).
        batch_size: stored as ``self.batch``; not used by ``pertub``.
    """

    def __init__(self, model, epsilon, alpha, k, random_start=False, batch_size=64):
        self.model = model
        self.epsilon = epsilon
        self.alpha = alpha
        self.t = k  # number of PGD iterations
        self.randstart = random_start
        self.batch = batch_size  # kept for compatibility; unused in pertub

    def pertub(self, input, edges, label, mask):
        """Return an adversarially perturbed version of ``input``.

        Args:
            input: clean feature tensor to perturb. It is never modified, and
                its ``requires_grad`` / ``.grad`` are left untouched.
            edges: graph connectivity passed straight to ``self.model``.
            label: class-index targets for ``F.nll_loss``.
            mask: index/boolean mask selecting which outputs/labels enter
                the loss.

        Returns:
            A detached tensor on ``device`` within ``epsilon`` (L-inf) of
            ``input``. ``input`` itself is returned unchanged when
            ``epsilon == 0``, or when ``k == 0`` and ``random_start`` is False.

        Note:
            ``self.model.zero_grad()`` is called before every backward pass,
            so the model's parameter gradients hold the last attack step's
            gradients when this returns.
        """
        if self.epsilon == 0:
            return input

        clean = input.detach()

        if self.randstart:
            noise = torch.empty(clean.size()).uniform_(-self.epsilon, self.epsilon)
            x_adv = (clean.to(device1) + noise).to(device)
        else:
            x_adv = input

        # The projection bounds and targets depend only on the clean inputs,
        # so compute them once instead of once per iteration.
        lower = (clean - self.epsilon).to(device)
        upper = (clean + self.epsilon).to(device)
        target = label[mask].to(device)

        for _ in range(int(self.t)):
            # Fresh leaf each step: gradients never accumulate across steps,
            # and the caller's tensor never gets requires_grad/.grad set.
            x_adv = x_adv.detach().requires_grad_(True)

            out = self.model(x_adv.to(device), edges.to(device))
            loss = F.nll_loss(out[mask], target)
            self.model.zero_grad()
            loss.backward()

            stepped = x_adv.detach() + self.alpha * x_adv.grad.sign()
            # Project back into the epsilon ball around the clean input.
            x_adv = torch.clamp(stepped.to(device), min=lower, max=upper)

        return x_adv




