# This code is for the most part taken from https://github.com/Harry24k/adversarial-attacks-pytorch
import torch
import torch.nn as nn
from torchattacks.attack import Attack
import numpy as np

def hard_threshold(arr, k):
    """Keep the ``k`` largest-magnitude entries of each sample and zero the rest.

    The first axis of ``arr`` is treated as the batch axis; every sample
    (all remaining axes flattened together) keeps its own top-``k`` entries by
    absolute value, with their original signs.

    Args:
        arr: tensor with at least one dimension; may be non-contiguous.
        k: number of entries to keep per sample. Values larger than the
            number of entries per sample keep everything.

    Returns:
        A new contiguous tensor with the same shape as ``arr``.

    Raises:
        ValueError: if ``k`` is negative.
    """
    k = int(k)
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    # reshape (not view) so non-contiguous inputs such as transposes work too.
    flat = arr.reshape(arr.shape[0], -1)
    # torch.topk raises if k exceeds the axis length; keeping "all" is the natural meaning.
    k = min(k, flat.shape[1])
    top_idx = torch.topk(flat.abs(), k=k, dim=1).indices
    out = flat.new_zeros(flat.shape)
    out.scatter_(1, top_idx, flat.gather(1, top_idx))
    return out.reshape(arr.shape)


class PGDGroup(Attack):
    r"""PGD-style attack with a top-``k`` sparsity constraint and per-region L2 bounds.

    Adapted from torchattacks' ``PGD``. The model is queried on
    ``torch.sigmoid(images + delta)``. After every gradient-ascent step the
    perturbation ``delta`` is:

    1. hard-thresholded to its ``k`` largest-magnitude entries per sample, then
    2. split into a ``num_divisions x num_divisions`` grid over its last two
       (spatial) axes; every grid cell whose L2 norm (over channels and pixels)
       exceeds ``D`` is rescaled to have norm exactly ``D``.

    Arguments:
        model: classifier wrapped by ``Attack``.
        device: device forwarded to ``Attack``.
        eps: half-width of the uniform random start.
        alpha: step size multiplying the raw gradient (no sign / normalization).
        steps: number of ascent steps.
        random_start: whether to start from a uniformly perturbed point.
        k: number of perturbation entries kept per sample.
        D: L2 radius allowed per grid cell.
        num_divisions: grid cells per spatial axis (default 4, i.e. 16 cells).
    """

    def __init__(self, model, device=None, eps=8/255, alpha=2/255, steps=10,
                 random_start=True, k=10, D=1., num_divisions=4):
        super().__init__('PGDGroup', model, device)
        if num_divisions < 1:
            raise ValueError(f"num_divisions must be >= 1, got {num_divisions}")
        self.eps = eps
        self.alpha = alpha
        self.steps = steps
        self.random_start = random_start
        self.supported_mode = ['default', 'targeted']
        self.k = k
        self.D = D
        self.num_divisions = num_divisions

    @staticmethod
    def _project_regions_l2(delta, num_divisions, radius):
        """Rescale, in place, each grid cell of ``delta`` whose L2 norm exceeds ``radius``.

        ``delta`` is assumed to be batch-first with spatial axes last, e.g.
        (B, C, H, W) — TODO confirm with callers. Trailing rows/columns left
        over when H or W is not divisible by ``num_divisions`` are not
        constrained. Returns ``delta``.
        """
        batch = delta.shape[0]
        cell_h = delta.shape[-2] // num_divisions
        cell_w = delta.shape[-1] // num_divisions
        if cell_h == 0 or cell_w == 0:
            # Every cell is empty (norm 0), so nothing needs rescaling.
            return delta
        grid_h = cell_h * num_divisions
        grid_w = cell_w * num_divisions
        grid = delta[..., :grid_h, :grid_w]
        # (B, M, cell_row, h_in_cell, cell_col, w_in_cell); M merges the middle axes.
        cells = grid.reshape(batch, -1, num_divisions, cell_h, num_divisions, cell_w)
        # One norm per (sample, cell_row, cell_col), over channels and pixels.
        norms = torch.linalg.norm(
            cells.permute(0, 2, 4, 1, 3, 5).reshape(batch, num_divisions, num_divisions, -1),
            dim=-1,
        )
        # Only cells above the radius shrink; torch.where avoids 0/0 for empty cells.
        scale = torch.where(norms > radius, radius / norms, torch.ones_like(norms))
        cells = cells * scale[:, None, :, None, :, None]
        delta[..., :grid_h, :grid_w] = cells.reshape(grid.shape)
        return delta

    def forward(self, images, labels):
        r"""Run the attack.

        Arguments:
            images: clean inputs. The model is fed ``torch.sigmoid`` of them,
                so presumably they are in pre-sigmoid space — confirm with callers.
            labels: ground-truth labels for ``images``.

        Returns:
            dict with ``'adv_images'`` (``sigmoid(images + delta)``), ``'delta'``
            (final projected perturbation) and ``'costs'`` (loss at each step).
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        if self.targeted:
            target_labels = self.get_target_label(images, labels)

        loss = nn.CrossEntropyLoss()
        adv_images = images.clone().detach()

        if self.random_start:
            adv_images = adv_images + \
                torch.empty_like(adv_images).uniform_(-self.eps, self.eps)
            # NOTE(review): the [0, 1] clamp is inherited from torchattacks' PGD
            # (pixel space); here the model sees sigmoid(adv_images), so this
            # clamp may be unintended — confirm.
            adv_images = torch.clamp(adv_images, min=0, max=1).detach()

        costs = []
        # Defaults so the result is well defined even when steps == 0.
        delta = (adv_images - images).detach()
        adv_images_with_sigmoid = torch.sigmoid(adv_images).detach()

        for _ in range(self.steps):
            adv_images.requires_grad = True
            outputs = self.get_logits(torch.sigmoid(adv_images))

            # Calculate loss (negated for targeted mode so ascent moves toward the target).
            if self.targeted:
                cost = -loss(outputs, target_labels)
            else:
                cost = loss(outputs, labels)
            costs.append(cost.item())

            # Update adversarial images with a raw-gradient ascent step.
            grad = torch.autograd.grad(cost, adv_images,
                                       retain_graph=False, create_graph=False)[0]
            adv_images = adv_images.detach() + self.alpha * grad

            # Project: sparsify, then bound each grid cell's L2 norm.
            delta = (adv_images - images).detach()
            delta = hard_threshold(delta, self.k).detach()
            delta = self._project_regions_l2(delta, self.num_divisions, self.D)

            adv_images = (images + delta).detach()
            adv_images_with_sigmoid = torch.sigmoid(adv_images).detach()

        return {'adv_images': adv_images_with_sigmoid, 'delta': delta, 'costs': costs}
    


class PGDvanillal2(Attack):
    r"""PGD-style attack with a top-``k`` sparsity constraint and a global L2 bound.

    Adapted from torchattacks' ``PGD``. The model is queried on
    ``torch.sigmoid(images + delta)``. After every gradient-ascent step the
    perturbation ``delta`` is hard-thresholded to its ``k`` largest-magnitude
    entries per sample; then each sample's whole perturbation is projected
    onto the L2 ball of radius ``D``.

    Arguments:
        model: classifier wrapped by ``Attack``.
        device: device forwarded to ``Attack``.
        eps: half-width of the uniform random start.
        alpha: step size multiplying the raw gradient (no sign / normalization).
        steps: number of ascent steps.
        random_start: whether to start from a uniformly perturbed point.
        k: number of perturbation entries kept per sample.
        D: L2 radius of each sample's perturbation.
    """

    def __init__(self, model, device=None, eps=8/255, alpha=2/255, steps=10, random_start=True, k=10, D=1.):
        super().__init__('PGDvanillal2', model, device)
        self.eps = eps
        self.alpha = alpha
        self.steps = steps
        self.random_start = random_start
        self.supported_mode = ['default', 'targeted']
        self.k = k
        self.D = D

    def forward(self, images, labels):
        r"""Run the attack.

        Arguments:
            images: clean inputs. The model is fed ``torch.sigmoid`` of them,
                so presumably they are in pre-sigmoid space — confirm with callers.
            labels: ground-truth labels for ``images``.

        Returns:
            dict with ``'adv_images'`` (``sigmoid(images + delta)``), ``'delta'``
            (final projected perturbation) and ``'costs'`` (loss at each step).
        """
        images = images.clone().detach().to(self.device)
        labels = labels.clone().detach().to(self.device)

        if self.targeted:
            target_labels = self.get_target_label(images, labels)

        loss = nn.CrossEntropyLoss()
        adv_images = images.clone().detach()

        if self.random_start:
            adv_images = adv_images + \
                torch.empty_like(adv_images).uniform_(-self.eps, self.eps)
            # NOTE(review): the [0, 1] clamp is inherited from torchattacks' PGD
            # (pixel space); here the model sees sigmoid(adv_images), so this
            # clamp may be unintended — confirm.
            adv_images = torch.clamp(adv_images, min=0, max=1).detach()

        costs = []
        # Defaults so the result is well defined even when steps == 0.
        delta = (adv_images - images).detach()
        adv_images_with_sigmoid = torch.sigmoid(adv_images).detach()

        for _ in range(self.steps):
            adv_images.requires_grad = True
            outputs = self.get_logits(torch.sigmoid(adv_images))

            # Calculate loss (negated for targeted mode so ascent moves toward the target).
            if self.targeted:
                cost = -loss(outputs, target_labels)
            else:
                cost = loss(outputs, labels)
            costs.append(cost.item())

            # Update adversarial images with a raw-gradient ascent step.
            grad = torch.autograd.grad(cost, adv_images,
                                       retain_graph=False, create_graph=False)[0]
            adv_images = adv_images.detach() + self.alpha * grad

            delta = (adv_images - images).detach()
            delta = hard_threshold(delta, self.k).detach()

            # Project each sample's whole perturbation onto the L2 ball of radius D.
            norms = torch.linalg.norm(delta.reshape(delta.shape[0], -1), dim=1)
            # torch.where keeps samples inside the ball untouched and avoids 0/0.
            scale = torch.where(norms > self.D, self.D / norms, torch.ones_like(norms))
            delta = delta * scale.reshape(-1, *([1] * (delta.dim() - 1)))

            adv_images = (images + delta).detach()
            adv_images_with_sigmoid = torch.sigmoid(adv_images).detach()

        return {'adv_images': adv_images_with_sigmoid, 'delta': delta, 'costs': costs}
    

def max_l2_norm_in_regions(image, num_divisions=4):
    """Return the largest L2 norm among the cells of a square grid laid over ``image``.

    The last two axes of ``image`` are treated as height and width and split
    into ``num_divisions x num_divisions`` cells of size
    ``(H // num_divisions, W // num_divisions)``. Each cell's norm is taken
    over all leading axes (e.g. channels) and its pixels. Trailing rows or
    columns left over when H or W is not divisible by ``num_divisions`` are
    ignored.

    Args:
        image: tensor with at least two dimensions, e.g. (C, H, W).
        num_divisions: cells per spatial axis (default 4, i.e. 16 cells).

    Returns:
        0-d detached tensor holding the maximum cell norm (0 if every cell is empty).

    Raises:
        ValueError: if ``num_divisions`` is less than 1.
    """
    if num_divisions < 1:
        raise ValueError(f"num_divisions must be >= 1, got {num_divisions}")
    cell_h = image.shape[-2] // num_divisions
    cell_w = image.shape[-1] // num_divisions
    if cell_h == 0 or cell_w == 0:
        # All cells are empty, so all norms are zero.
        return torch.zeros((), dtype=image.dtype, device=image.device)
    grid = image.detach()[..., :cell_h * num_divisions, :cell_w * num_divisions]
    # (leading, cell_row, h_in_cell, cell_col, w_in_cell), leading axes merged.
    cells = grid.reshape(-1, num_divisions, cell_h, num_divisions, cell_w)
    per_cell = cells.permute(1, 3, 0, 2, 4).reshape(num_divisions, num_divisions, -1)
    return torch.linalg.norm(per_cell, dim=-1).max()