
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch.autograd import Variable
from torchsummary import summary
import sys
import numpy as np

import torchvision
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter



class Adversarial_Attack:
    """Gradient-based adversarial perturbations for a classifier.

    All attack methods share the signature
    ``(source_model, data, y, eps, n_iteration, step_size)`` so a caller can
    select one by name; each method ignores the arguments it does not use and
    returns a new detached tensor with the same shape as ``data``.

    ``source_model(data)`` must produce logits accepted by
    ``nn.CrossEntropyLoss`` together with the targets ``y``. The ``*_L2``
    methods treat dimension 0 as the batch dimension and measure norms over
    all remaining dimensions.

    While an attack runs, the model is put in eval mode and its parameters
    are frozen. Afterwards every module's training flag and every
    parameter's ``requires_grad`` flag are restored to their values before
    the call, even if the attack raises.
    """

    def __init__(self):
        self.loss_fn = nn.CrossEntropyLoss()

    # ------------------------------------------------------------- helpers

    @staticmethod
    def _per_sample_l2(t):
        """Return the L2 norm of each sample of ``t``.

        The norm is taken over every dimension except dim 0, and the result
        is shaped (batch, 1, ..., 1) so it broadcasts against ``t``.
        """
        norms = torch.norm(torch.flatten(t, 1), 2, 1)
        return norms.reshape((-1,) + (1,) * (t.dim() - 1))

    @staticmethod
    def _enter_attack_mode(model):
        """Freeze ``model``'s parameters and switch it to eval mode.

        Returns the saved flags that ``_exit_attack_mode`` needs to undo this.
        """
        grad_flags = [(p, p.requires_grad) for p in model.parameters()]
        train_flags = [(m, m.training) for m in model.modules()]
        for p, _ in grad_flags:
            p.requires_grad = False
        model.eval()
        return grad_flags, train_flags

    @staticmethod
    def _exit_attack_mode(state):
        """Restore the flags saved by ``_enter_attack_mode``."""
        grad_flags, train_flags = state
        # Restore per module (rather than calling model.train()) so that
        # submodules deliberately left in eval mode stay in eval mode.
        for m, was_training in train_flags:
            m.training = was_training
        for p, flag in grad_flags:
            p.requires_grad = flag

    def _input_gradient(self, source_model, x, y):
        """Return d(loss)/d(x), where loss = self.loss_fn(source_model(x), y)."""
        x = x.clone().detach().requires_grad_(True)
        loss = self.loss_fn(source_model(x), y)
        # autograd.grad returns the gradient without writing any .grad field.
        (grad,) = torch.autograd.grad(loss, x)
        return grad

    # ---------------------------------------------------------- projections

    def project_back_L2(self, x, adv_x, eps):
        """Project ``adv_x`` onto the per-sample L2 ball of radius ``eps`` around ``x``.

        Perturbations already inside the ball are kept (up to the 1e-9
        division guard); larger ones are rescaled to norm ``eps``.
        Returns a detached tensor.
        """
        diff = adv_x - x
        norm = self._per_sample_l2(diff)
        # Divide by max(1, ||diff|| / eps): a no-op inside the ball, a rescale
        # onto its surface outside it.
        scale = torch.maximum(torch.ones_like(norm), norm / eps)
        diff = torch.div(diff, 1e-9 + scale)
        projected_x = x + diff
        return projected_x.clone().detach()

    def project_back_Linf(self, x, adv_x, eps):
        """Clamp every coordinate of ``adv_x - x`` to [-eps, eps]; return a detached tensor."""
        diff = torch.clamp(adv_x - x, min=-eps, max=eps)
        projected_x = x + diff
        return projected_x.clone().detach()

    # -------------------------------------------------------------- attacks

    def ERM(self, source_model, data, y, eps, n_iteration, step_size):
        """Return ``data`` unchanged (no attack; presumably the clean-training baseline)."""
        return data

    def FGM_L2(self, source_model, data, y, eps, n_iteration, step_size):
        """Single L2 step of size ``eps`` along the normalized input gradient.

        ``n_iteration`` and ``step_size`` are ignored. This is exactly one
        ``PGM_L2`` iteration with ``step_size=eps``.
        """
        return self.PGM_L2(source_model, data, y, eps, 1, eps)

    def PGM_L2(self, source_model, data, y, eps, n_iteration, step_size):
        """Projected gradient ascent on the loss inside an L2 ball of radius ``eps``.

        Each of the ``n_iteration`` steps moves every sample by ``step_size``
        along its unit-L2-norm input gradient. The perturbation is then
        projected back onto the ball around ``data``.
        """
        state = self._enter_attack_mode(source_model)
        try:
            adv_data = data.clone().detach()
            for _ in range(n_iteration):
                grad = self._input_gradient(source_model, adv_data, y)
                grad = torch.div(grad, 1e-9 + self._per_sample_l2(grad))
                adv_data = adv_data + step_size * grad
                adv_data = self.project_back_L2(data, adv_data, eps)
        finally:
            self._exit_attack_mode(state)
        return adv_data

    def FGM_Linf(self, source_model, data, y, eps, n_iteration, step_size):
        """Fast gradient sign step: move every coordinate by ``eps * sign(grad)``.

        ``n_iteration`` and ``step_size`` are ignored. This is exactly one
        ``PGM_Linf`` iteration with ``step_size=eps``.
        """
        return self.PGM_Linf(source_model, data, y, eps, 1, eps)

    def PGM_Linf(self, source_model, data, y, eps, n_iteration, step_size):
        """Projected gradient-sign ascent on the loss inside an Linf ball of radius ``eps``.

        Each of the ``n_iteration`` steps moves by ``step_size * sign(grad)``.
        The perturbation is then clamped back to [-eps, eps] per coordinate.
        """
        state = self._enter_attack_mode(source_model)
        try:
            adv_data = data.clone().detach()
            for _ in range(n_iteration):
                grad = self._input_gradient(source_model, adv_data, y)
                adv_data = adv_data + step_size * torch.sign(grad)
                adv_data = self.project_back_Linf(data, adv_data, eps)
        finally:
            self._exit_attack_mode(state)
        return adv_data
