# This code is for the most part taken from https://github.com/Harry24k/adversarial-attacks-pytorch
import torch
import torch.nn as nn
from torchattacks.attack import Attack
import numpy as np


def hard_threshold(arr, k):
    """Keep the ``k`` largest-magnitude entries of every sample and zero the rest.

    The first dimension of ``arr`` is treated as the batch dimension; all
    remaining dimensions are flattened together, so the ``k`` entries are
    selected per sample across every other axis.

    Args:
        arr: Tensor whose first dimension indexes the samples.
        k: Number of entries to keep per sample. Values larger than the number
            of entries in a sample keep the whole sample.

    Returns:
        A new tensor with the same shape as ``arr`` holding the selected
        entries at their original positions and zeros elsewhere.
    """
    if arr.numel() == 0:
        # Nothing to select from; also avoids the ambiguous ``-1`` reshape.
        return torch.zeros_like(arr)

    # ``reshape`` (unlike ``view``) also accepts non-contiguous inputs such as
    # permuted or channels-last tensors.
    flattened_arr = arr.reshape(arr.shape[0], -1)
    # ``topk`` raises if asked for more entries than exist.
    k = min(k, flattened_arr.shape[1])
    topk_indices = torch.topk(torch.abs(flattened_arr), k=k, dim=1).indices
    mask = torch.zeros_like(flattened_arr)
    # Copy the selected (signed) values back to their original positions.
    mask.scatter_(1, topk_indices, flattened_arr.gather(1, topk_indices))
    return mask.reshape(arr.shape)

class PGDGroupBB(Attack):
    r"""
    Black-box PGD-style attack with a sparse perturbation whose L2 norm is
    bounded separately on every cell of a 4 x 4 spatial grid.

    Every step
      1. estimates the gradient of the objective from loss values only
         (averaged random finite differences, no back-propagation),
      2. takes an ascent step of size ``alpha`` along that estimate,
      3. keeps the ``k`` largest-magnitude entries of each image's
         perturbation (``hard_threshold``),
      4. rescales every grid cell of the perturbation whose L2 norm exceeds
         ``D`` back to norm ``D``.

    Inputs are passed through ``torch.sigmoid`` before being given to the
    model, and the loss is ``nn.NLLLoss`` applied to ``get_logits``' output.

    Arguments:
        model: model to attack (forwarded to ``Attack``).
        device: device forwarded to ``Attack``.
        eps: half-width of the uniform random start. It is not used as a
            projection radius anywhere in this attack.
        alpha: step size.
        steps: number of iterations.
        random_start: whether to start from a uniformly perturbed copy.
        k: number of perturbation entries kept per image.
        D: maximum L2 norm of the perturbation inside one grid cell.
        num_random_directions: random directions averaged per gradient estimate.
        miu: finite-difference smoothing radius.

    ``forward`` returns a dict with keys ``'adv_images'`` (sigmoid of the final
    adversarial inputs), ``'delta'`` (final perturbation) and ``'costs'``
    (objective value recorded at the start of every step).
    """

    # Number of grid cells along each spatial axis.
    _NUM_DIVISIONS = 4

    def __init__(self, model, device=None, eps=8/255, alpha=2/255, steps=10, random_start=True, k=10, D=1., num_random_directions=100, miu=0.01):
        super().__init__('PGDGroupBB', model, device)
        self.eps = eps
        self.alpha = alpha
        self.steps = steps
        self.random_start = random_start
        self.supported_mode = ['default', 'targeted']
        self.k = k
        self.D = D
        self.num_random_directions = num_random_directions
        self.miu = miu

    def _objective(self, loss, inputs, labels):
        """Return the scalar objective that the attack maximises.

        This is the loss w.r.t. ``labels`` in the default mode and its negation
        in targeted mode (where ``labels`` are the target labels).
        """
        value = loss(self.get_logits(torch.sigmoid(inputs)), labels)
        return -value if self.targeted else value

    def _estimate_gradient(self, loss, adv_images, labels):
        """Estimate the objective's gradient for every image from loss values only.

        For each image, ``num_random_directions`` unit directions ``u`` are
        sampled and ``d / miu * (F(x + miu * u) - F(x)) * u`` is averaged, where
        ``d`` is the number of entries of one image.
        """
        G_total = torch.zeros_like(adv_images)
        for j, adv_image in enumerate(adv_images):
            d = adv_image.numel()
            label = labels[j].unsqueeze(0)
            base_value = self._objective(loss, adv_image.unsqueeze(0), label)
            grad_mean = torch.zeros_like(adv_image)
            for l in range(self.num_random_directions):
                direction = torch.randn_like(adv_image)
                direction /= torch.linalg.norm(direction)
                perturbed_img = adv_image + self.miu * direction
                perturbed_value = self._objective(loss, perturbed_img.unsqueeze(0), label)
                G_current = d / self.miu * (perturbed_value - base_value) * direction
                # stable computation of the mean
                grad_mean = grad_mean + 1 / (l + 1) * (G_current - grad_mean)
            G_total[j] = grad_mean
        return G_total

    @staticmethod
    def _region_bounds(size, num_divisions):
        """Split ``range(size)`` into ``num_divisions`` consecutive slices.

        The last slice absorbs the remainder when ``size`` is not divisible.
        """
        step = size // num_divisions
        bounds = [slice(i * step, (i + 1) * step) for i in range(num_divisions - 1)]
        bounds.append(slice((num_divisions - 1) * step, size))
        return bounds

    @classmethod
    def _project_regions(cls, delta, max_norm):
        """Rescale, in place, every grid cell of ``delta`` whose L2 norm exceeds ``max_norm``.

        ``delta`` is indexed as (image, channel, row, column); the norm of a
        cell is taken over all channels of that cell, separately per image.
        """
        height, width = delta.shape[-2:]
        for rows in cls._region_bounds(height, cls._NUM_DIVISIONS):
            for cols in cls._region_bounds(width, cls._NUM_DIVISIONS):
                region = delta[:, :, rows, cols]
                norms = torch.linalg.norm(region.flatten(start_dim=1), dim=1)
                scale = torch.where(norms > max_norm, max_norm / norms, torch.ones_like(norms))
                # ``region`` is a view, so this updates ``delta`` itself.
                region.mul_(scale.view(-1, 1, 1, 1))
        return delta

    def forward(self, images, labels):
        r"""
        Overridden.
        """

        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        # Labels the objective is evaluated against: the targets in targeted
        # mode, the true labels otherwise.
        if self.targeted:
            objective_labels = self.get_target_label(images, labels)
        else:
            objective_labels = labels

        loss = nn.NLLLoss()
        adv_images = images.clone().detach()

        if self.random_start:
            # NOTE(review): the start is clamped to [0, 1] although inputs are
            # passed through sigmoid before reaching the model -- confirm the
            # intended input space.
            adv_images = adv_images + \
                torch.empty_like(adv_images).uniform_(-self.eps, self.eps)
            adv_images = torch.clamp(adv_images, min=0, max=1).detach()

        costs = []
        # Defined up front so that ``steps=0`` returns the starting point
        # instead of failing on unbound names.
        delta = (adv_images - images).detach()
        adv_images_with_sigmoid = torch.sigmoid(adv_images).detach()

        for _ in range(self.steps):
            # Only loss values are needed, so no autograd graph is built.
            with torch.no_grad():
                cost = self._objective(loss, adv_images, objective_labels)
                costs.append(cost.item())
                G_total = self._estimate_gradient(loss, adv_images, objective_labels)

            adv_images = adv_images.detach() + self.alpha * G_total

            delta = (adv_images - images).detach()
            delta = hard_threshold(delta, self.k).detach()
            self._project_regions(delta, self.D)

            adv_images = (images + delta).detach()
            adv_images_with_sigmoid = torch.sigmoid(adv_images).detach()

        results = {'adv_images': adv_images_with_sigmoid, 'delta': delta, 'costs': costs}
        return results
    


class PGDvanillal2BB(Attack):
    r"""
    Black-box PGD-style attack with a sparse perturbation bounded in L2 norm.

    Every step
      1. estimates the gradient of the objective from loss values only
         (averaged random finite differences, no back-propagation),
      2. takes an ascent step of size ``alpha`` along that estimate,
      3. keeps the ``k`` largest-magnitude entries of each image's
         perturbation (``hard_threshold``),
      4. rescales each image's perturbation to L2 norm ``D`` if it is larger.

    Inputs are passed through ``torch.sigmoid`` before being given to the
    model, and the loss is ``nn.NLLLoss`` applied to ``get_logits``' output.

    Arguments:
        model: model to attack (forwarded to ``Attack``).
        device: device forwarded to ``Attack``.
        eps: half-width of the uniform random start. It is not used as a
            projection radius anywhere in this attack.
        alpha: step size.
        steps: number of iterations.
        random_start: whether to start from a uniformly perturbed copy.
        k: number of perturbation entries kept per image.
        D: maximum L2 norm of one image's perturbation.
        num_random_directions: random directions averaged per gradient estimate.
        miu: finite-difference smoothing radius.

    ``forward`` returns a dict with keys ``'adv_images'`` (sigmoid of the final
    adversarial inputs), ``'delta'`` (final perturbation) and ``'costs'``
    (objective value recorded at the start of every step).
    """

    def __init__(self, model, device=None, eps=8/255, alpha=2/255, steps=10, random_start=True, k=10, D=1., num_random_directions=100, miu=0.01):
        super().__init__('PGDvanillal2BB', model, device)
        self.eps = eps
        self.alpha = alpha
        self.steps = steps
        self.random_start = random_start
        self.supported_mode = ['default', 'targeted']
        self.k = k
        self.D = D
        self.num_random_directions = num_random_directions
        self.miu = miu

    def _objective(self, loss, inputs, labels):
        """Return the scalar objective that the attack maximises.

        This is the loss w.r.t. ``labels`` in the default mode and its negation
        in targeted mode (where ``labels`` are the target labels).
        """
        value = loss(self.get_logits(torch.sigmoid(inputs)), labels)
        return -value if self.targeted else value

    def _estimate_gradient(self, loss, adv_images, labels):
        """Estimate the objective's gradient for every image from loss values only.

        For each image, ``num_random_directions`` unit directions ``u`` are
        sampled and ``d / miu * (F(x + miu * u) - F(x)) * u`` is averaged, where
        ``d`` is the number of entries of one image.
        """
        G_total = torch.zeros_like(adv_images)
        for j, adv_image in enumerate(adv_images):
            d = adv_image.numel()
            label = labels[j].unsqueeze(0)
            base_value = self._objective(loss, adv_image.unsqueeze(0), label)
            grad_mean = torch.zeros_like(adv_image)
            for l in range(self.num_random_directions):
                direction = torch.randn_like(adv_image)
                direction /= torch.linalg.norm(direction)
                perturbed_img = adv_image + self.miu * direction
                perturbed_value = self._objective(loss, perturbed_img.unsqueeze(0), label)
                G_current = d / self.miu * (perturbed_value - base_value) * direction
                # stable computation of the mean
                grad_mean = grad_mean + 1 / (l + 1) * (G_current - grad_mean)
            G_total[j] = grad_mean
        return G_total

    @staticmethod
    def _project_l2(delta, max_norm):
        """Rescale, in place, every sample of ``delta`` whose L2 norm exceeds ``max_norm``."""
        norms = torch.linalg.norm(delta.flatten(start_dim=1), dim=1)
        scale = torch.where(norms > max_norm, max_norm / norms, torch.ones_like(norms))
        # Broadcast one factor per sample over all remaining dimensions.
        delta.mul_(scale.view(-1, *([1] * (delta.dim() - 1))))
        return delta

    def forward(self, images, labels):
        r"""
        Overridden.
        """

        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        # Labels the objective is evaluated against: the targets in targeted
        # mode, the true labels otherwise.
        if self.targeted:
            objective_labels = self.get_target_label(images, labels)
        else:
            objective_labels = labels

        loss = nn.NLLLoss()
        adv_images = images.clone().detach()

        if self.random_start:
            # NOTE(review): the start is clamped to [0, 1] although inputs are
            # passed through sigmoid before reaching the model -- confirm the
            # intended input space.
            adv_images = adv_images + \
                torch.empty_like(adv_images).uniform_(-self.eps, self.eps)
            adv_images = torch.clamp(adv_images, min=0, max=1).detach()

        costs = []
        # Defined up front so that ``steps=0`` returns the starting point
        # instead of failing on unbound names.
        delta = (adv_images - images).detach()
        adv_images_with_sigmoid = torch.sigmoid(adv_images).detach()

        for _ in range(self.steps):
            # Only loss values are needed, so no autograd graph is built.
            with torch.no_grad():
                cost = self._objective(loss, adv_images, objective_labels)
                costs.append(cost.item())
                G_total = self._estimate_gradient(loss, adv_images, objective_labels)

            adv_images = adv_images.detach() + self.alpha * G_total

            delta = (adv_images - images).detach()
            delta = hard_threshold(delta, self.k).detach()
            self._project_l2(delta, self.D)

            adv_images = (images + delta).detach()
            adv_images_with_sigmoid = torch.sigmoid(adv_images).detach()

        results = {'adv_images': adv_images_with_sigmoid, 'delta': delta, 'costs': costs}
        return results
    

def max_l2_norm_in_regions(image):
    """Return the largest L2 norm found among the cells of a 4 x 4 grid over ``image``.

    The last two dimensions of ``image`` are split into 4 x 4 cells; the norm
    of a cell is taken over every leading index (e.g. all channels) of that
    cell. When a side length is not divisible by 4, the last row / column of
    cells absorbs the remainder so that every pixel is covered.

    Args:
        image: Tensor indexed as (channel, row, column).

    Returns:
        The maximum cell norm as a detached 0-dim tensor (``-inf`` as a float
        if no cell norm compares greater, e.g. when all norms are NaN).
    """
    num_divisions = 4
    height, width = image.shape[-2:]
    division_h = height // num_divisions
    division_w = width // num_divisions

    def bounds(step, size):
        # Consecutive slices; the final one runs to the edge of the image.
        slices = [slice(i * step, (i + 1) * step) for i in range(num_divisions - 1)]
        slices.append(slice((num_divisions - 1) * step, size))
        return slices

    max_norm = float("-inf")
    for rows in bounds(division_h, height):
        for cols in bounds(division_w, width):
            norm_region = torch.linalg.norm(image[:, rows, cols]).detach()
            if norm_region > max_norm:
                max_norm = norm_region
    return max_norm