# https://github.com/Zoky-2020/SGA?tab=readme-ov-file

import numpy as np
import torch
import torch.nn as nn

import random
import copy
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F

# Module-level cross-entropy criterion (PyTorch default reduction: mean).
xent_loss = nn.CrossEntropyLoss()


def set_mode_for_attack(model):
    """
    Put `model` in train mode so that an attack can be conducted on it.

    Models containing recurrent layers (GRU, LSTM, RNN and their cell
    variants) should have those layers in train mode to conduct the attack.
    Each recurrent submodule found is switched to train mode and reported on
    stdout.

    NOTE(review): ``model.train()`` switches *every* submodule (dropout,
    batch-norm, ...) to train mode and overrides any earlier ``model.eval()``
    by the caller, which makes the per-module ``train()`` below redundant.
    Confirm that a whole-model switch is intended, rather than train mode for
    the recurrent layers only.
    """
    model.train()

    # Match both the torch recurrent base classes and name prefixes, so that
    # custom wrappers named e.g. "GRUEncoder" are still reported.
    rnn_types = (nn.RNNBase, nn.RNNCellBase)
    rnn_prefixes = ("GRU", "LSTM", "RNN")
    for module in model.modules():
        name = module.__class__.__name__
        if isinstance(module, rnn_types) or name.startswith(rnn_prefixes):
            module.train()
            print(f"Set {name} in train mode.")


class SupPGD:
    """
    Text-supervised image attack.

    Encodes the captions with the model's text encoder and hands the resulting
    text features to ``img_attacker.txt_guided_attack`` as the supervision
    signal for perturbing the images. The texts themselves are returned
    unchanged.
    """

    def __init__(self, model, img_attacker, tokenizer):
        # model: must provide ``inference_text`` (used here) and whatever
        #   ``img_attacker`` needs for the image side.
        # img_attacker: object exposing ``txt_guided_attack``.
        # tokenizer: callable turning a list of strings into model inputs.
        self.model = model
        self.img_attacker = img_attacker
        self.tokenizer = tokenizer

    def attack(
        self,
        imgs,
        txts,
        txt2img,
        device="cpu",
        max_length=30,
        scales=None,
        txt_att_k=0,
        img_attack_loss="sim",
        is_train=True,
        **kwargs,
    ):
        """
        Produce adversarial images guided by the (clean) texts.

        Args:
            imgs: image batch forwarded to ``img_attacker.txt_guided_attack``.
            txts: list of caption strings.
            txt2img: ``txt2img[j]`` is the index of the image described by ``txts[j]``.
            device: device the tokenized texts are moved to.
            max_length: tokenizer padding/truncation length.
            scales: optional resize ratios forwarded to the image attacker.
            txt_att_k: accepted for interface compatibility; unused here.
            img_attack_loss: loss metric name forwarded to the image attacker.
            is_train: when True, require one adversarial image per text.
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            (adv_imgs, txts)

        Raises:
            AssertionError: if ``is_train`` and ``len(adv_imgs) != len(txts)``.
        """
        # Text features are fixed supervision targets, so no graph is needed.
        with torch.no_grad():
            txts_input = self.tokenizer(
                txts,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            ).to(device)
            txt_supervisions = self.model.inference_text(txts_input)["text_feat"]

        adv_imgs, _ = self.img_attacker.txt_guided_attack(
            self.model,
            imgs,
            txt2img,
            device,
            scales=scales,
            txt_embeds=txt_supervisions,
            loss_metric=img_attack_loss,
        )

        # Explicit raise: a bare assert is stripped under -O, and the previous
        # ``print(...)`` message evaluated to None.
        if is_train and len(adv_imgs) != len(txts):
            raise AssertionError(
                f"expected one adversarial image per text, got {len(adv_imgs)} "
                f"images for {len(txts)} texts"
            )

        return adv_imgs, txts


class ImageAttacker:
    """
    Text-guided PGD attack on images under an L-infinity budget.

    Images are perturbed by signed-gradient ascent on a loss comparing image
    embeddings with fixed text embeddings. After each step they are projected
    back into the eps-ball around the clean images and clamped to [0, 1].
    """

    def __init__(self, normalization, eps=2 / 255, steps=10, step_size=0.5 / 255):
        """
        Args:
            normalization: optional callable applied to the (scaled) images
                right before ``model.inference_image``; ``None`` skips it.
            eps: L-infinity radius of the allowed perturbation.
            steps: number of PGD iterations (must be >= 1 to run an attack).
            step_size: magnitude of each signed-gradient step.
        """
        self.normalization = normalization
        self.eps = eps
        self.steps = steps
        self.step_size = step_size

    @staticmethod
    def _clip_logits(model, imgs_embeds, txts_embeds):
        """Return CLIP-style logits (scaled cosine similarity), shape (n_imgs, n_txts)."""
        imgs_embeds = imgs_embeds / imgs_embeds.norm(dim=-1, keepdim=True)
        txts_embeds = txts_embeds / txts_embeds.norm(dim=-1, keepdim=True)
        return model.logit_scale.exp() * imgs_embeds @ txts_embeds.t()

    def loss_func(self, model, adv_imgs_embeds, txts_embeds, txt2img, loss_metric="sim"):
        """
        Compute the attack loss between image and text embeddings.

        Args:
            model: only read by the clip metrics (``model.logit_scale``).
            adv_imgs_embeds: image embeddings, one row per image.
            txts_embeds: text embeddings, one row per text.
            txt2img: ``txt2img[j]`` is the index of the image described by text ``j``.
            loss_metric:
                "sim"    -- negative dot-product similarity of matched pairs,
                            summed per image and averaged over images;
                            ascending it pushes matched pairs apart.
                "clip"   -- symmetric CLIP loss: image->text and text->image
                            cross-entropy with the matching matrix as targets.
                "clip_i" -- image->text half of the CLIP loss only.

        Returns:
            (loss, loss_matrix): scalar loss and an (n_imgs, n_txts) matrix that
            is non-zero only on matched image/text pairs.
        """
        assert loss_metric in ["sim", "clip", "clip_i"], loss_metric

        # it_labels[i, j] = 1 iff text j describes image i.
        it_labels = torch.zeros(
            adv_imgs_embeds.shape[0], txts_embeds.shape[0], device=adv_imgs_embeds.device
        )
        for txt_idx, img_idx in enumerate(txt2img):
            it_labels[img_idx, txt_idx] = 1

        if loss_metric == "sim":
            loss_matrix = -((adv_imgs_embeds @ txts_embeds.T) * it_labels)
            return loss_matrix.sum(-1).mean(), loss_matrix

        logits_per_image = self._clip_logits(model, adv_imgs_embeds, txts_embeds)
        loss = F.cross_entropy(logits_per_image, it_labels)
        if loss_metric == "clip":
            # Text->image direction: the targets must be transposed as well.
            # Otherwise shapes only agree when n_imgs == n_txts, and the result
            # is wrong for any non-identity pairing.
            loss = (loss + F.cross_entropy(logits_per_image.t(), it_labels.t())) / 2
        return loss, -(logits_per_image * it_labels)

    def _pgd_step(self, adv_imgs, grad, imgs):
        """Take one signed-gradient ascent step, then project into the eps-ball and [0, 1]."""
        # sign() is invariant to positive rescaling, so the gradient is not
        # normalised first. Dividing by mean |grad| would give 0/0 = NaN for
        # samples whose gradient is exactly zero.
        adv_imgs = adv_imgs.detach() + self.step_size * grad.sign()
        adv_imgs = torch.min(torch.max(adv_imgs, imgs - self.eps), imgs + self.eps)
        return torch.clamp(adv_imgs, 0.0, 1.0)

    def txt_guided_attack(
        self, model, imgs, txt2img, device, scales=None, txt_embeds=None, loss_metric="sim"
    ):
        """
        Run PGD on `imgs` so that their embeddings move away from `txt_embeds`.

        Args:
            model: must provide ``inference_image(x)`` returning a dict with an
                "image_feat" entry (plus ``logit_scale`` for the clip metrics).
            imgs: 4-D image batch; pixel values are kept in [0, 1].
            txt2img: ``txt2img[j]`` is the index of the image described by text ``j``.
            device: device on which the random start is placed.
            scales: optional resize ratios. When given, the attack also scores
                noisy rescaled copies (see ``get_scaled_imgs``), whose
                embeddings are expected stacked along dim 0 in chunks of
                ``len(imgs)``.
            txt_embeds: text embeddings used as the attack target.
            loss_metric: "sim", "clip" or "clip_i" (see ``loss_func``).

        Returns:
            (adv_imgs, loss_matrix) where loss_matrix is a numpy array taken
            from the last step's last scale, computed before that step's update.

        Raises:
            ValueError: if ``self.steps < 1``.
        """
        if self.steps < 1:
            raise ValueError(f"steps must be >= 1 to run the attack, got {self.steps}")

        model.eval()
        # NOTE(review): set_mode_for_attack may switch modes again after the
        # eval() above; confirm the resulting mode is the intended one.
        set_mode_for_attack(model)

        b, _, _, _ = imgs.shape
        scales_num = 1 if scales is None else len(scales) + 1

        # Random start inside the eps-ball, kept in the valid pixel range.
        adv_imgs = imgs.detach() + torch.from_numpy(
            np.random.uniform(-self.eps, self.eps, imgs.shape)
        ).float().to(device)
        adv_imgs = torch.clamp(adv_imgs, 0.0, 1.0)

        loss_matrix = None
        for _ in range(self.steps):
            adv_imgs.requires_grad_()
            model.zero_grad()
            # The forward pass must build a graph even if the caller is
            # running inside torch.no_grad().
            with torch.enable_grad():
                scaled_imgs = self.get_scaled_imgs(adv_imgs, scales, device)
                if self.normalization is not None:
                    scaled_imgs = self.normalization(scaled_imgs)
                adv_imgs_embeds = model.inference_image(scaled_imgs)["image_feat"]

                loss = torch.zeros((), dtype=torch.float32, device=device)
                for scale_idx in range(scales_num):
                    loss_item, loss_matrix = self.loss_func(
                        model,
                        adv_imgs_embeds[scale_idx * b : (scale_idx + 1) * b],
                        txt_embeds,
                        txt2img,
                        loss_metric=loss_metric,
                    )
                    loss = loss + loss_item
            loss.backward()

            adv_imgs = self._pgd_step(adv_imgs, adv_imgs.grad, imgs)

        return adv_imgs, loss_matrix.detach().cpu().numpy()

    def get_scaled_imgs(self, imgs, scales=None, device="cuda"):
        """
        Return `imgs` stacked along dim 0 with one noisy, rescaled copy per ratio.

        Each copy gets Gaussian noise (std 0.05) and is resized by ``ratio``
        with bicubic interpolation. It is then clamped to [0, 1] and resized
        back to the original spatial size. With ``scales=None`` the input is
        returned as-is.

        Args:
            imgs: image batch whose last two dims are (H, W).
            scales: iterable of resize ratios, or None.
            device: kept for backward compatibility. Noise is created on
                ``imgs.device`` so that it always matches the input.
        """
        if scales is None:
            return imgs

        ori_shape = (imgs.shape[-2], imgs.shape[-1])
        reverse_transform = transforms.Resize(
            ori_shape, interpolation=transforms.InterpolationMode.BICUBIC
        )

        result = [imgs]
        for ratio in scales:
            # Guard against tiny ratios collapsing a side to zero pixels.
            scale_shape = (
                max(1, int(ratio * ori_shape[0])),
                max(1, int(ratio * ori_shape[1])),
            )
            scale_transform = transforms.Resize(
                scale_shape, interpolation=transforms.InterpolationMode.BICUBIC
            )
            noise = torch.from_numpy(np.random.normal(0.0, 0.05, imgs.shape)).float().to(imgs.device)
            scaled_imgs = torch.clamp(scale_transform(imgs + noise), 0.0, 1.0)
            result.append(reverse_transform(scaled_imgs))

        return torch.cat(result, 0)
