import numpy as np 
import torch
import torch.nn as nn

import random
import copy
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F

from enum import Enum
import scipy.stats as st

def set_mode_for_attack(model):
    """
    Switch ``model`` into train mode before running an attack.

    Models containing recurrent layers (e.g. GRU, LSTM) must be in train mode
    for the attack's backward pass (presumably because cuDNN RNN backward is
    only supported in training mode). ``model.train()`` already applies to
    every submodule; GRU-like submodules are additionally re-set and reported.
    """
    model.train()
    gru_like = (
        module
        for module in model.modules()
        if type(module).__name__.startswith("GRU")
    )
    for module in gru_like:
        module.train()
        print(f"Set {type(module).__name__} in train mode.")


class UnsupPGD():
    """
    Unsupervised PGD.

    The image attack (``img_attacker.run_trades``) maximizes the KL divergence
    between the model's features for the adversarial image and for the clean
    image, i.e. it pushes the adversarial features *away* from the original
    ones. It therefore does not depend on the paired texts.

    Texts can optionally be perturbed separately by replacing one randomly
    chosen non-stopword with ``[UNK]``.
    """

    def __init__(self, model, img_attacker, txt_attacker):
        self.model = model
        self.img_attacker = img_attacker
        self.txt_attacker = txt_attacker

    def attack(self, imgs, txts, device='cpu', max_length=30, txt_att_k=0, is_train=True, **kwargs):
        """
        Build adversarial image/text pairs.

        Args:
            imgs: clean image batch, passed to ``img_attacker.run_trades``.
            txts: sequence of caption strings, one per image.
            device: device the tokenized texts are moved to.
            max_length: tokenizer padding/truncation length.
            txt_att_k: if > 0, each text is replaced by one randomly chosen
                masked candidate from ``txt_attacker.random_replace_attack``.
                A text with no candidates (every word filtered) is kept as-is,
                so the output stays aligned with the images.
            is_train: if True, assert that images and texts have equal length.
            **kwargs: ignored.

        Returns:
            ``(adv_imgs, adv_txts)``. ``adv_txts`` is ``txts`` itself when
            ``txt_att_k <= 0``.
        """
        # NOTE(review): these text features are never consumed -- run_trades
        # only sees the images. Kept to preserve the original call sequence;
        # consider removing the extra text forward pass.
        with torch.no_grad():
            txts_input = self.txt_attacker.tokenizer(
                txts,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            ).to(device)
            txts_output = self.model.inference_text(txts_input)
            txt_supervisions = txts_output["text_feat"]  # currently unused

        adv_imgs = self.img_attacker.run_trades(self.model, imgs)

        adv_txts = txts
        if txt_att_k > 0:
            candidates_per_text = self.txt_attacker.random_replace_attack(self.model, txts)
            adv_txts = []
            for original, candidates in zip(txts, candidates_per_text):
                if len(candidates) > 1:
                    # randomly keep a single adversarial text
                    adv_txts.extend(random.sample(candidates, 1))
                elif candidates:
                    adv_txts.extend(candidates)
                else:
                    # Nothing could be masked: keep the clean text so that
                    # adv_txts stays one-to-one with adv_imgs.
                    adv_txts.append(original)

        if is_train:
            assert len(adv_imgs) == len(adv_txts), (len(adv_imgs), len(adv_txts))

        return adv_imgs, adv_txts

      
class NormType(Enum):
    """Norm used to bound the image perturbation and shape gradient steps."""

    # L-infinity: element-wise clamp to [-eps, eps], sign-of-gradient steps.
    Linf = 0
    # L2: per-sample rescale into the eps ball, L2-normalized gradient steps.
    L2 = 1


def clamp_by_l2(x, max_norm):
    """
    Rescale each sample of ``x`` so that its L2 norm is at most ``max_norm``.

    Dimension 0 is treated as the batch; the norm is taken over all remaining
    dimensions, so any input of rank >= 2 works (e.g. ``(B, C, H, W)``).
    Samples already inside the ball are returned unchanged.

    Args:
        x: tensor with a leading batch dimension.
        max_norm: maximum allowed per-sample L2 norm.

    Returns:
        Tensor with the same shape as ``x``.
    """
    dims = tuple(range(1, x.dim()))
    norm = torch.norm(x, dim=dims, p=2, keepdim=True)
    # Avoid 0/0 -> NaN for an all-zero sample when max_norm == 0; for any
    # positive max_norm a tiny norm still yields a factor of 1 as before.
    norm = norm.clamp_min(torch.finfo(norm.dtype).tiny)
    factor = torch.clamp(max_norm / norm, max=1.0)
    return x * factor


def random_init(x, norm_type, epsilon):
    """
    Draw a random starting perturbation for PGD.

    Args:
        x: clean input batch; the result has the same shape, dtype and device.
        norm_type: ``NormType.Linf`` or ``NormType.L2``.
        epsilon: radius of the perturbation ball.

    Returns:
        A tensor ``delta`` (no autograd history):

        - Linf: each element drawn uniformly from ``[-epsilon, epsilon]``.
        - L2: ``(uniform [0, 1) noise image) - x``, rescaled per sample into
          the L2 ball of radius ``epsilon``.
        - any other ``norm_type``: all zeros.
    """
    delta = torch.zeros_like(x)
    if norm_type == NormType.Linf:
        # Sample symmetrically so the start can lie anywhere in the Linf ball,
        # not only in its non-negative orthant.
        delta.uniform_(-epsilon, epsilon)
    elif norm_type == NormType.L2:
        delta.uniform_(0.0, 1.0)
        # detach x so no autograd graph is attached to the starting point
        delta = clamp_by_l2(delta - x.detach(), epsilon)
    return delta
          

class ImageAttacker:
    """
    Projected gradient descent (PGD) image attacker.

    ``attack`` is a generator driven by the caller. Each ``next()`` yields the
    current perturbed image, passed through ``preprocess`` if one is set. The
    caller forwards it through a model and calls ``backward()`` on a loss.
    The following ``next()`` then reads ``self.delta.grad``, takes one ascent
    step, projects back into the epsilon ball and the ``bounding`` pixel
    range, and yields the next image. After ``steps`` updates the generator
    yields the final raw (not preprocessed) adversarial image, detached.
    """

    def __init__(
        self,
        preprocess,
        eps,
        steps,
        step_size,
        norm_type=NormType.Linf,
        random_init=True,
        cls=True,
        *args,
        **kwargs
    ):
        """
        Args:
            preprocess: callable applied to perturbed images before they are
                fed to the model, or ``None`` for no preprocessing.
            eps: radius of the perturbation ball.
            steps: number of ascent steps.
            step_size: step length per iteration.
            norm_type: ``NormType.Linf`` or ``NormType.L2``.
            random_init: start from a random point in the ball instead of 0.
            cls: in ``run_trades``, compare only index 0 along dim 1 of
                ``image_feat`` (presumably the CLS token) instead of the
                flattened features.
            *args: ignored.
            **kwargs: ``bounding`` -- ``(min, max)`` pixel range, default
                ``(0, 1)``; other keys are ignored.
        """
        self.norm_type = norm_type
        self.random_init = random_init
        self.cls = cls

        self.preprocess = preprocess
        self.epsilon = eps
        self.steps = steps
        self.step_size = step_size

        self.bounding = kwargs.get("bounding")
        if self.bounding is None:
            self.bounding = (0, 1)

    def input_diversity(self, image):
        """Hook for input transformations before each forward; identity here."""
        return image

    def _apply_preprocess(self, image):
        """Apply ``self.preprocess`` when set; otherwise return ``image`` unchanged."""
        if self.preprocess is None:
            return image
        return self.preprocess(image)

    def attack(self, image):
        """
        Run the PGD loop for the clean batch ``image`` as a generator.

        Yields ``self.steps`` perturbed (and preprocessed) images. A loss must
        be backpropagated to ``self.delta`` after each one before calling
        ``next()`` again. Finally yields ``(image + delta).detach()``.
        """
        if self.random_init:
            self.delta = random_init(image, self.norm_type, self.epsilon)
        else:
            self.delta = torch.zeros_like(image)

        # ``kernel`` is never set in this class (presumably a subclass hook);
        # ``grad`` exists after a previous ``get_grad`` call and is reset here.
        if hasattr(self, "kernel"):
            self.kernel = self.kernel.to(image.device)

        if hasattr(self, "grad"):
            self.grad = torch.zeros_like(image)

        step_size = self.step_size
        num_iters = self.steps

        for _ in range(num_iters):
            self.delta = self.delta.detach()
            self.delta.requires_grad = True

            image_diversity = self.input_diversity(image + self.delta)
            yield self._apply_preprocess(image_diversity)

            grad = self.normalize(self.get_grad())
            self.delta = self.delta.data + step_size * grad

            # constraint 1: stay inside the epsilon ball
            self.delta = self.project(self.delta, self.epsilon)
            # constraint 2: keep image + delta inside the valid pixel range
            self.delta = torch.clamp(image + self.delta, *self.bounding) - image

        yield (image + self.delta).detach()

    def get_grad(self):
        """Copy the gradient accumulated on ``self.delta`` into ``self.grad`` and return it."""
        self.grad = self.delta.grad.clone()
        return self.grad

    def project(self, delta, epsilon):
        """Project ``delta`` into the ``epsilon`` ball of ``self.norm_type``."""
        if self.norm_type == NormType.Linf:
            return torch.clamp(delta, -epsilon, epsilon)
        elif self.norm_type == NormType.L2:
            return clamp_by_l2(delta, epsilon)

    def normalize(self, grad):
        """
        Turn a raw gradient into a step direction.

        Linf: element-wise sign. L2: per-sample unit L2 norm over all non-batch
        dimensions; an all-zero gradient stays zero instead of becoming NaN.
        """
        if self.norm_type == NormType.Linf:
            return torch.sign(grad)
        elif self.norm_type == NormType.L2:
            dims = tuple(range(1, grad.dim()))
            norm = torch.norm(grad, dim=dims, p=2, keepdim=True)
            return grad / norm.clamp_min(torch.finfo(norm.dtype).tiny)

    def _pool_feat(self, feat):
        """Select the features compared by ``run_trades`` (index 0 of dim 1, or all flattened)."""
        return feat[:, 0, :] if self.cls else feat.flatten(1)

    def run_trades(self, net, image):
        """
        Craft adversarial images that maximize a TRADES-style KL divergence.

        The loss is ``KL(softmax(clean_feat) || softmax(adv_feat))`` over the
        ``image_feat`` output of ``net.inference_image``. It is maximized with
        ``self.steps`` PGD steps.

        Args:
            net: model exposing ``inference_image`` that returns a dict with
                key ``"image_feat"``. It is left in train mode (see
                ``set_mode_for_attack``).
            image: clean image batch in raw pixel space (before preprocess).

        Returns:
            Detached adversarial images in raw pixel space.
        """
        set_mode_for_attack(net)

        with torch.no_grad():
            origin_output = net.inference_image(self._apply_preprocess(image))
            origin_embed = self._pool_feat(origin_output["image_feat"]).detach()

        criterion = torch.nn.KLDivLoss(reduction="batchmean")
        attacker = self.attack(image)

        for _ in range(self.steps):
            image_adv = next(attacker)
            adv_embed = self._pool_feat(net.inference_image(image_adv)["image_feat"])

            loss = criterion(
                adv_embed.log_softmax(dim=-1), origin_embed.softmax(dim=-1)
            )
            # NOTE(review): backward() also accumulates gradients into net's
            # parameters; callers presumably zero them before their own
            # optimizer step -- confirm.
            loss.backward()

        return next(attacker)



# Stopwords and punctuation that the random-replace text attack never masks.
# Stored as a set for O(1) membership tests; duplicate entries collapse.
filter_words = {'a', 'about', 'above', 'across', 'after', 'afterwards', 'again', 'against', 'ain', 'all', 'almost',
                'alone', 'along', 'already', 'also', 'although', 'am', 'among', 'amongst', 'an', 'and', 'another',
                'any', 'anyhow', 'anyone', 'anything', 'anyway', 'anywhere', 'are', 'aren', "aren't", 'around', 'as',
                'at', 'back', 'been', 'before', 'beforehand', 'behind', 'being', 'below', 'beside', 'besides',
                'between', 'beyond', 'both', 'but', 'by', 'can', 'cannot', 'could', 'couldn', "couldn't", 'd', 'didn',
                "didn't", 'doesn', "doesn't", 'don', "don't", 'down', 'due', 'during', 'either', 'else', 'elsewhere',
                'empty', 'enough', 'even', 'ever', 'everyone', 'everything', 'everywhere', 'except', 'first', 'for',
                'former', 'formerly', 'from', 'hadn', "hadn't", 'hasn', "hasn't", 'haven', "haven't", 'he', 'hence',
                'her', 'here', 'hereafter', 'hereby', 'herein', 'hereupon', 'hers', 'herself', 'him', 'himself', 'his',
                'how', 'however', 'hundred', 'i', 'if', 'in', 'indeed', 'into', 'is', 'isn', "isn't", 'it', "it's",
                'its', 'itself', 'just', 'latter', 'latterly', 'least', 'll', 'may', 'me', 'meanwhile', 'mightn',
                "mightn't", 'mine', 'more', 'moreover', 'most', 'mostly', 'must', 'mustn', "mustn't", 'my', 'myself',
                'namely', 'needn', "needn't", 'neither', 'never', 'nevertheless', 'next', 'no', 'nobody', 'none',
                'noone', 'nor', 'not', 'nothing', 'now', 'nowhere', 'o', 'of', 'off', 'on', 'once', 'one', 'only',
                'onto', 'or', 'other', 'others', 'otherwise', 'our', 'ours', 'ourselves', 'out', 'over', 'per',
                'please', 's', 'same', 'shan', "shan't", 'she', "she's", "should've", 'shouldn', "shouldn't", 'somehow',
                'something', 'sometime', 'somewhere', 'such', 't', 'than', 'that', "that'll", 'the', 'their', 'theirs',
                'them', 'themselves', 'then', 'thence', 'there', 'thereafter', 'thereby', 'therefore', 'therein',
                'thereupon', 'these', 'they', 'this', 'those', 'through', 'throughout', 'thru', 'thus', 'to', 'too',
                'toward', 'towards', 'under', 'unless', 'until', 'up', 'upon', 'used', 've', 'was', 'wasn', "wasn't",
                'we', 'were', 'weren', "weren't", 'what', 'whatever', 'when', 'whence', 'whenever', 'where',
                'whereafter', 'whereas', 'whereby', 'wherein', 'whereupon', 'wherever', 'whether', 'which', 'while',
                'whither', 'who', 'whoever', 'whole', 'whom', 'whose', 'why', 'with', 'within', 'without', 'won',
                "won't", 'would', 'wouldn', "wouldn't", 'y', 'yet', 'you', "you'd", "you'll", "you're", "you've",
                'your', 'yours', 'yourself', 'yourselves', '.', '-', 'a the', '/', '?', 'some', '"', ',', 'b', '&', '!',
                '@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";", '～', '·'}
    

class TextAttacker():
    """Generates masked-word variants of captions for the text side of the attack."""

    def __init__(self, tokenizer):
        # Not used by this class itself; callers (e.g. UnsupPGD.attack) read
        # it as ``txt_attacker.tokenizer``.
        self.tokenizer = tokenizer

    def random_replace_attack(self, net, texts):
        """
        Random replace attack: mask one content word at a time.

        Each text is split on single spaces. For every word that is not in
        ``filter_words`` (compared case-insensitively, so "I" and "The" count
        as stopwords), one candidate is produced with that word replaced by
        ``[UNK]``. Empty tokens from repeated spaces are skipped, because
        masking them would insert ``[UNK]`` instead of replacing a word.

        Args:
            net: unused; kept for interface compatibility.
            texts: iterable of caption strings.

        Returns:
            A list with one entry per text. Entry ``i`` is the list of masked
            candidates for ``texts[i]``: one per non-filtered word, and empty
            if every word was filtered. Example::

                random_replace_attack(net, ["I am a good person"])
                == [["I am a [UNK] person", "I am a good [UNK]"]]
        """
        final_adverse = []
        for text in texts:
            words = text.split(' ')
            masked_text = []
            for i, word in enumerate(words):
                # only replace real words that are not stopwords/punctuation
                if not word or word.lower() in filter_words:
                    continue
                masked_text.append(' '.join(words[:i] + ['[UNK]'] + words[i + 1:]))
            final_adverse.append(masked_text)
        return final_adverse
 