import numpy as np 
import torch
import torch.nn as nn
import random
import copy
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F
from attacks.step import LinfStep, L2Step
import torchvision.transforms.functional as TF

# Maps the constraint name read from ``args.constraint`` to the step class
# that ImageAttacker instantiates for the perturbation update.
STEPS = dict(Linf=LinfStep, L2=L2Step)

class SGAttacker():
    """Pairs an image attacker with a text attacker around shared encoders.

    ``attack7`` perturbs only the images, guided by the features of the given
    captions. ``attack8`` first rewrites the captions (guided by the clean
    image features) and then perturbs the images guided by the rewritten
    captions.
    """

    def __init__(self, model_inf, model_infi, model_inft, img_attacker, txt_attacker, dc_model = None, optimizer = None, model_infi1 = None):
        """Store the encoders and the two attackers.

        ``dc_model`` and ``optimizer`` are accepted but not stored.
        """
        self.img_attacker, self.txt_attacker = img_attacker, txt_attacker
        self.model_inf = model_inf
        self.model_infi, self.model_infi1 = model_infi, model_infi1
        self.model_inft = model_inft

    def _embed_captions(self, captions, max_length, device):
        """Tokenize ``captions`` and return the text encoder's ``'text_feat'``.

        Runs entirely under ``torch.no_grad()``; only the input ids and the
        attention mask are moved to ``device`` before the encoder call.
        """
        with torch.no_grad():
            batch = self.txt_attacker.tokenizer(
                captions,
                padding='max_length',
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )
            batch.input_ids = batch.input_ids.to(device)
            batch.attention_mask = batch.attention_mask.to(device)
            encoded = self.model_inft(batch.input_ids, batch.attention_mask)
            return encoded['text_feat']

    def attack7(self, imgs, txts, txt2img, scales_num, args, device='cpu', max_length=30, scales=None, **kwargs):
        """Image-only attack.

        Returns:
            ``(adv_imgs, txts)`` - the perturbed images and the captions,
            which are passed through unchanged. Extra ``kwargs`` are ignored.
        """
        caption_feats = self._embed_captions(txts, max_length, device)
        perturbed = self.img_attacker.txt_guided_attack(
            self.model_infi, imgs, txt2img, scales_num, args, device,
            scales=scales, txt_embeds=caption_feats,
        )
        return perturbed, txts

    def attack8(self, imgs, txts, txt2img, scales_num, args, device='cpu', max_length=30, scales=None, **kwargs):
        """Text-then-image attack.

        Returns:
            ``(adv_imgs, adv_txts)`` - the perturbed images and the rewritten
            captions. Extra ``kwargs`` are ignored.
        """
        # Clean image features, gathered per caption through txt2img, guide
        # the text attack.
        with torch.no_grad():
            clean_output = self.model_infi(self.img_attacker.normalization(imgs))
            per_caption_feats = clean_output['image_feat'][txt2img]
        adv_txts = self.txt_attacker.img_guided_attack2(
            self.model_inft, txts, img_embeds=per_caption_feats)

        # The rewritten captions then guide the image attack.
        caption_feats = self._embed_captions(adv_txts, max_length, device)
        perturbed = self.img_attacker.txt_guided_attack(
            self.model_infi, imgs, txt2img, scales_num, args, device,
            scales=scales, txt_embeds=caption_feats,
        )
        return perturbed, adv_txts


class ImageAttacker():
    def __init__(self, normalization, eps=2/255, steps=10, step_size=0.5/255, eps1=2/255, steps1=10, step_size1=1/255):
        self.normalization = normalization
        self.eps = eps
        self.steps = steps 
        self.step_size = step_size 

        self.eps1 = eps1
        self.steps1 = steps1
        self.step_size1 = step_size1
        
        self.criterion = torch.nn.KLDivLoss(reduction='batchmean')
        

    def loss_func(self, adv_imgs_embeds, txts_embeds, txt2img):  
        device = adv_imgs_embeds.device    

        it_sim_matrix = adv_imgs_embeds @ txts_embeds.T
        it_labels = torch.zeros(it_sim_matrix.shape).to(device)
        
        for i in range(len(txt2img)):
            it_labels[txt2img[i], i]=1
        
        loss_IaTcpos = -(it_sim_matrix * it_labels).sum(-1).mean()
        loss = loss_IaTcpos
        
        return loss
        
    def loss_func1(self, adv_imgs_embeds, imgs_embeds):  
        device = adv_imgs_embeds.device    

        it_sim_matrix = adv_imgs_embeds @ imgs_embeds.T
        it_labels = torch.eye(it_sim_matrix.size(0)).to(device)
        
        loss_IaTcpos = (it_sim_matrix * it_labels).sum(-1).mean()
        loss = loss_IaTcpos
        
        return loss


    def txt_guided_attack(self, model, imgs, txt2img, scales_num, args, device, scales=None, txt_embeds=None):
        
        model.eval()
       
        b, _, h, w = imgs.shape
        
        delta = torch.zeros_like(imgs)
        delta = delta.to(device)
        
        orig_delta = delta.clone().detach()
        step = STEPS[args.constraint](orig_delta, self.eps, self.step_size)
        
        with torch.no_grad():
            imgs_output = model(self.normalization(imgs))
            imgs_embeds = imgs_output['image_feat'].detach()
            imgs_block_embeds = imgs_output['image_block_embed']
            imgs_block_embeds_1 = imgs_output['image_block_embed1']
            
        ori_shape = (imgs.shape[-2], imgs.shape[-1])

        num_iter = 10
        sub_iter = int(self.steps/num_iter)
        patch_size =16
  
        for i in range(num_iter):
            
            ratio = random.choice(scales)
            ratio1 = random.choice(scales)
            
            scale_shape = (int(1*ori_shape[0]), 
                                  int(ratio*ori_shape[1]))
            scale_transform = transforms.Resize(scale_shape,
                                  interpolation=transforms.InterpolationMode.BICUBIC)
        
            scale_shape1 = (int(ratio1*ori_shape[0]), 
                                  int(1*ori_shape[1]))
            scale_transform1 = transforms.Resize(scale_shape1,
                                  interpolation=transforms.InterpolationMode.BICUBIC)

        
            for j in range(sub_iter):#self.steps
                
                delta = delta.clone().detach().requires_grad_(True) 
                
                adv_imgs = imgs.detach() + delta
                adv_imgs = torch.clamp(adv_imgs, 0.0, 1.0)
                
                # Define the vertical translation distance (e.g., 50 pixels downward)
                translate_x = random.randint(4, 12) 
                translate_y = random.randint(4, 12) 
            
                # Perform cyclic vertical translation using torch.roll
                translated_images = torch.roll(adv_imgs, shifts=translate_x, dims=2)
                translated_images = torch.clamp(translated_images, 0.0, 1.0)
            
                translated_images1 = torch.roll(adv_imgs, shifts=translate_y, dims=3)
                translated_images1 = torch.clamp(translated_images1, 0.0, 1.0) 

                
                
                scaled_imgs = scale_transform(adv_imgs)
                scaled_imgs1 = scale_transform1(adv_imgs)
                left = 0 #
                top = 0 #
            
                # Crop the scaled image back to the original shape
                crop_x1 = translate_x 
                crop_x2 = int(w*(ratio-1))
                crop_x3 = int(w*(ratio-1) + translate_x)
                
                translated_imagesx1 = torch.roll(scaled_imgs, shifts=crop_x1, dims=2)
                translated_imagesx1 = torch.clamp(translated_imagesx1, 0.0, 1.0)
                
                translated_imagesx2 = torch.roll(scaled_imgs, shifts=crop_x2, dims=2)
                translated_imagesx2 = torch.clamp(translated_imagesx2, 0.0, 1.0)
                
                translated_imagesx3 = torch.roll(scaled_imgs, shifts=crop_x3, dims=2)
                translated_imagesx3 = torch.clamp(translated_imagesx3, 0.0, 1.0)
                
                folded_img = TF.crop(scaled_imgs, top, left, ori_shape[0], ori_shape[1])
                folded_img1 = TF.crop(translated_imagesx1, top, left, ori_shape[0], ori_shape[1])
                folded_img2 = TF.crop(translated_imagesx2, top, left, ori_shape[0], ori_shape[1])
                folded_img3 = TF.crop(translated_imagesx3, top, left, ori_shape[0], ori_shape[1])

                
                # Crop the scaled image back to the original shape
                crop_y1 = translate_y
                crop_y2 = int(h*(ratio1-1))
                crop_y3 = int(h*(ratio1-1) + crop_y1) #5*16 + 24
                
                translated_imagesy1 = torch.roll(scaled_imgs1, shifts=crop_y1, dims=3)
                translated_imagesy1 = torch.clamp(translated_imagesy1, 0.0, 1.0)
                
                translated_imagesy2 = torch.roll(scaled_imgs1, shifts=crop_y2, dims=3)
                translated_imagesy2 = torch.clamp(translated_imagesy2, 0.0, 1.0)
                
                translated_imagesy3 = torch.roll(scaled_imgs1, shifts=crop_y3, dims=3)
                translated_imagesy3 = torch.clamp(translated_imagesy3, 0.0, 1.0)
                
                folded_imgy = TF.crop(scaled_imgs1, top, left , ori_shape[0], ori_shape[1])
                folded_imgy1 = TF.crop(translated_imagesy1, top, left , ori_shape[0], ori_shape[1])
                folded_imgy2 = TF.crop(translated_imagesy2, top, left , ori_shape[0], ori_shape[1])
                folded_imgy3 = TF.crop(translated_imagesy3, top, left , ori_shape[0], ori_shape[1])
                
            
                if self.normalization is not None:
                    adv_imgs_output = model(self.normalization(adv_imgs))
                    adv_imgs_output1 = model(self.normalization(translated_images))
                    adv_imgs_output2 = model(self.normalization(translated_images1))
                
                    roll_imgs_output = model(self.normalization(folded_img))
                    roll_imgs_output1 = model(self.normalization(folded_img1))
                    roll_imgs_output2 = model(self.normalization(folded_img2))
                    roll_imgs_output3 = model(self.normalization(folded_img3))
                    
                    rolly_imgs_output = model(self.normalization(folded_imgy))
                    rolly_imgs_output1 = model(self.normalization(folded_imgy1))
                    rolly_imgs_output2 = model(self.normalization(folded_imgy2))
                    rolly_imgs_output3 = model(self.normalization(folded_imgy3))
                    
                else:
                    adv_imgs_output = model(adv_imgs)
                
                adv_imgs_embeds = adv_imgs_output['image_feat']
                adv_imgs_block_embeds = adv_imgs_output['image_block_embed']
                adv_imgs_block_embeds_1 = adv_imgs_output['image_block_embed1']
            
                adv_imgs_embeds1 = adv_imgs_output1['image_feat']
                adv_imgs_block_embeds1 = adv_imgs_output1['image_block_embed']
                adv_imgs_block_embeds1_1 = adv_imgs_output1['image_block_embed1']
            
                adv_imgs_embeds2 = adv_imgs_output2['image_feat']
                adv_imgs_block_embeds2 = adv_imgs_output2['image_block_embed']
                adv_imgs_block_embeds2_1 = adv_imgs_output2['image_block_embed1']
            
                roll_imgs_embeds = roll_imgs_output['image_feat']
                roll_imgs_block_embeds = roll_imgs_output['image_block_embed']
                roll_imgs_block_embeds_1 = roll_imgs_output['image_block_embed1']
            
                roll_imgs_embeds1 = roll_imgs_output1['image_feat']
                roll_imgs_block_embeds1 = roll_imgs_output1['image_block_embed']
                roll_imgs_block_embeds1_1 = roll_imgs_output1['image_block_embed1']
            
                roll_imgs_embeds2 = roll_imgs_output2['image_feat']
                roll_imgs_block_embeds2 = roll_imgs_output2['image_block_embed']
                roll_imgs_block_embeds2_1 = roll_imgs_output2['image_block_embed1']
            
                roll_imgs_embeds3 = roll_imgs_output3['image_feat']
                roll_imgs_block_embeds3 = roll_imgs_output3['image_block_embed']
                roll_imgs_block_embeds3_1 = roll_imgs_output3['image_block_embed1']
                
                rolly_imgs_embeds = rolly_imgs_output['image_feat']
                rolly_imgs_block_embeds = rolly_imgs_output['image_block_embed']
                rolly_imgs_block_embeds_1 = rolly_imgs_output['image_block_embed1']
                
                rolly_imgs_embeds1 = rolly_imgs_output1['image_feat']
                rolly_imgs_block_embeds1 = rolly_imgs_output1['image_block_embed']
                rolly_imgs_block_embeds1_1 = rolly_imgs_output1['image_block_embed1']
            
                rolly_imgs_embeds2 = rolly_imgs_output2['image_feat']
                rolly_imgs_block_embeds2 = rolly_imgs_output2['image_block_embed']
                rolly_imgs_block_embeds2_1 = rolly_imgs_output2['image_block_embed1']
            
                rolly_imgs_embeds3 = rolly_imgs_output3['image_feat']
                rolly_imgs_block_embeds3 = rolly_imgs_output3['image_block_embed']
                rolly_imgs_block_embeds3_1 = rolly_imgs_output3['image_block_embed1']
                
                
                f = []
                f_adv = []
                f_adv1 = []
                f_adv2 = []
                f_adv3 = []
                f_adv4 = []
                
                
                f_1 = []
                f_adv_1 = []
                f_adv1_1 = []
                f_adv2_1 = []
                
                f_r_adv = []
                f_r_adv1 = []
                f_r_adv2 = []
                f_r_adv3 = []
                
                f_r_adv_1 = []
                f_r_adv1_1 = []
                f_r_adv2_1 = []
                f_r_adv3_1 = []
                
                f_ry_adv = []
                f_ry_adv1 = []
                f_ry_adv2 = []
                f_ry_adv3 = []
                
                f_ry_adv_1 = []
                f_ry_adv1_1 = []
                f_ry_adv2_1 = []
                f_ry_adv3_1 = []
                
                for i, (feats, feats_adv, feats_adv1, feats_adv2) in enumerate(zip(imgs_block_embeds, adv_imgs_block_embeds, adv_imgs_block_embeds1, adv_imgs_block_embeds2)):
                    f.append(feats)
                    f_adv.append(feats_adv)
                    f_adv1.append(feats_adv1)
                    f_adv2.append(feats_adv2)
                                 
                for i, (feats_1, feats_adv_1, feats_adv1_1, feats_adv2_1) in enumerate(zip(imgs_block_embeds_1, adv_imgs_block_embeds_1, adv_imgs_block_embeds1_1, adv_imgs_block_embeds2_1)):
                    f_1.append(feats_1)
                    f_adv_1.append(feats_adv_1)
                    f_adv1_1.append(feats_adv1_1)
                    f_adv2_1.append(feats_adv2_1)
                    
                for i, (feats_r_adv, feats_r_adv1, feats_r_adv2, feats_r_adv3) in enumerate(zip(roll_imgs_block_embeds, roll_imgs_block_embeds1, roll_imgs_block_embeds2, roll_imgs_block_embeds3)):

                    f_r_adv.append(feats_r_adv)
                    f_r_adv1.append(feats_r_adv1)
                    f_r_adv2.append(feats_r_adv2)
                    f_r_adv3.append(feats_r_adv3)
                
                for i, (feats_r_adv_1, feats_r_adv1_1, feats_r_adv2_1, feats_r_adv3_1) in enumerate(zip(roll_imgs_block_embeds_1, roll_imgs_block_embeds1_1, roll_imgs_block_embeds2_1, roll_imgs_block_embeds3_1)):

                    f_r_adv_1.append(feats_r_adv_1)
                    f_r_adv1_1.append(feats_r_adv1_1)
                    f_r_adv2_1.append(feats_r_adv2_1)
                    f_r_adv3_1.append(feats_r_adv3_1)
                
                for i, (feats_ry_adv, feats_ry_adv1, feats_ry_adv2, feats_ry_adv3) in enumerate(zip(rolly_imgs_block_embeds, rolly_imgs_block_embeds1, rolly_imgs_block_embeds2, rolly_imgs_block_embeds3)):

                    f_ry_adv.append(feats_ry_adv)
                    f_ry_adv1.append(feats_ry_adv1)
                    f_ry_adv2.append(feats_ry_adv2)
                    f_ry_adv3.append(feats_ry_adv3)
                
                for i, (feats_ry_adv_1, feats_ry_adv1_1, feats_ry_adv2_1, feats_ry_adv3_1) in enumerate(zip(rolly_imgs_block_embeds_1, rolly_imgs_block_embeds1_1, rolly_imgs_block_embeds2_1, rolly_imgs_block_embeds3_1)):

                    f_ry_adv_1.append(feats_ry_adv_1)
                    f_ry_adv1_1.append(feats_ry_adv1_1)
                    f_ry_adv2_1.append(feats_ry_adv2_1)
                    f_ry_adv3_1.append(feats_ry_adv3_1)
                
                if args.source_image_encoder == "ViT-B/16" or args.source_image_encoder == "ViT-L/14" or args.source_image_encoder == "ViT-B/32" :
                    f=torch.cat(f,dim=0).detach()
                    f_adv = torch.cat(f_adv, dim=0)
                    f_adv1 = torch.cat(f_adv1, dim=0)
                    f_adv2 = torch.cat(f_adv2, dim=0)
                    
                    f_r_adv  = torch.cat(f_r_adv, dim=0)
                    f_r_adv1 = torch.cat(f_r_adv1,dim=0)
                    f_r_adv2 = torch.cat(f_r_adv2, dim=0)
                    f_r_adv3 = torch.cat(f_r_adv3, dim=0)

                    f_ry_adv = torch.cat(f_ry_adv,dim=0)
                    f_ry_adv1 = torch.cat(f_ry_adv1,dim=0)
                    f_ry_adv2 = torch.cat(f_ry_adv2, dim=0)
                    f_ry_adv3 = torch.cat(f_ry_adv3, dim=0)
                
                    f_1=torch.cat(f_1,dim=0).detach()
                    f_adv_1 = torch.cat(f_adv_1, dim=0)
                    f_adv1_1 = torch.cat(f_adv1_1, dim=0)
                    f_adv2_1 = torch.cat(f_adv2_1, dim=0)
            
                    f_r_adv_1  = torch.cat(f_r_adv_1, dim=0)
                    f_r_adv1_1 = torch.cat(f_r_adv1_1,dim=0)
                    f_r_adv2_1 = torch.cat(f_r_adv2_1, dim=0)
                    f_r_adv3_1 = torch.cat(f_r_adv3_1, dim=0)

                    f_ry_adv_1 = torch.cat(f_ry_adv_1,dim=0)
                    f_ry_adv1_1 = torch.cat(f_ry_adv1_1,dim=0)
                    f_ry_adv2_1 = torch.cat(f_ry_adv2_1, dim=0)
                    f_ry_adv3_1 = torch.cat(f_ry_adv3_1, dim=0)
                
                
                model.zero_grad()
                with torch.enable_grad():
                
                    if args.source_image_encoder == "ViT-B/16" or args.source_image_encoder == "ViT-L/14" or args.source_image_encoder == "ViT-B/32":                   
                    
                        cos = nn.CosineSimilarity(dim=1, eps=1e-6)
                        cos1 = nn.CosineSimilarity(dim=1, eps=1e-6)

                        loss1 = (torch.sum(-cos1(f, f_adv))/f.shape[1]) + (torch.sum(-cos1(f, f_adv1))/f.shape[1]) + (torch.sum(-cos1(f, f_adv2))/f.shape[1]) #+ (torch.sum(-cos1(f, f_adv3))/f.shape[1]) + (torch.sum(-cos1(f, f_adv4))/f.shape[1])
                        loss2 = (torch.sum(-cos1(f, f_r_adv))/f.shape[1])  + (torch.sum(-cos1(f, f_r_adv1))/f.shape[1]) + (torch.sum(-cos1(f, f_r_adv2))/f.shape[1])+ (torch.sum(-cos1(f, f_r_adv3))/f.shape[1]) 
                        loss3 = (torch.sum(-cos1(f, f_ry_adv))/f.shape[1]) + (torch.sum(-cos1(f, f_ry_adv1))/f.shape[1]) + (torch.sum(-cos1(f, f_ry_adv2))/f.shape[1])+ (torch.sum(-cos1(f, f_ry_adv3))/f.shape[1]) 
                        
                        loss1_1 = (torch.sum(-cos1(f_1, f_adv_1))/f_1.shape[1]) + (torch.sum(-cos1(f_1, f_adv1_1))/f_1.shape[1]) + (torch.sum(-cos1(f_1, f_adv2_1))/f_1.shape[1])
                        loss2_1 = (torch.sum(-cos1(f_1, f_r_adv_1))/f_1.shape[1])  + (torch.sum(-cos1(f_1, f_r_adv1_1))/f_1.shape[1]) + (torch.sum(-cos1(f_1, f_r_adv2_1))/f_1.shape[1])+ (torch.sum(-cos1(f_1, f_r_adv3_1))/f_1.shape[1]) 
                        loss3_1 = (torch.sum(-cos1(f_1, f_ry_adv_1))/f_1.shape[1]) + (torch.sum(-cos1(f_1, f_ry_adv1_1))/f_1.shape[1]) + (torch.sum(-cos1(f_1, f_ry_adv2_1))/f_1.shape[1])+ (torch.sum(-cos1(f_1, f_ry_adv3_1))/f_1.shape[1])
                        
                        loss_ii = (torch.sum(-cos1(imgs_embeds, adv_imgs_embeds))/imgs_embeds.shape[0]) + (torch.sum(-cos1(imgs_embeds, adv_imgs_embeds1))/imgs_embeds.shape[0])+ (torch.sum(-cos1(imgs_embeds, adv_imgs_embeds2))/imgs_embeds.shape[0])
                        loss_ii1 = (torch.sum(-cos1(imgs_embeds, roll_imgs_embeds))/imgs_embeds.shape[0]) + (torch.sum(-cos1(imgs_embeds, roll_imgs_embeds1))/imgs_embeds.shape[0])+ (torch.sum(-cos1(imgs_embeds, roll_imgs_embeds2))/imgs_embeds.shape[0])+ (torch.sum(-cos1(imgs_embeds, roll_imgs_embeds3))/imgs_embeds.shape[0])
                        loss_ii2 = (torch.sum(-cos1(imgs_embeds, rolly_imgs_embeds))/imgs_embeds.shape[0]) + (torch.sum(-cos1(imgs_embeds, rolly_imgs_embeds1))/imgs_embeds.shape[0])+ (torch.sum(-cos1(imgs_embeds, rolly_imgs_embeds2))/imgs_embeds.shape[0])+ (torch.sum(-cos1(imgs_embeds, rolly_imgs_embeds3))/imgs_embeds.shape[0])
                        loss_it = self.loss_func(adv_imgs_embeds, txt_embeds, txt2img) + self.loss_func(adv_imgs_embeds1, txt_embeds, txt2img) + self.loss_func(adv_imgs_embeds2, txt_embeds, txt2img)
                        loss_it1 = self.loss_func(roll_imgs_embeds, txt_embeds, txt2img) + self.loss_func(roll_imgs_embeds1, txt_embeds, txt2img) + self.loss_func(roll_imgs_embeds2, txt_embeds, txt2img) + self.loss_func(roll_imgs_embeds3, txt_embeds, txt2img)
                        loss_it2 = self.loss_func(rolly_imgs_embeds, txt_embeds, txt2img) + self.loss_func(rolly_imgs_embeds1, txt_embeds, txt2img) + self.loss_func(rolly_imgs_embeds2, txt_embeds, txt2img) + self.loss_func(rolly_imgs_embeds3, txt_embeds, txt2img)
                        
                        loss= - loss1 - loss2 - loss3 - loss1_1 - loss2_1- loss3_1 - loss_ii - loss_ii1- loss_ii2 - loss_it - loss_it1 - loss_it2
            
                grad = torch.autograd.grad(loss, [delta])[0]
                with torch.no_grad():
                    delta = step.step(delta, grad)
                    delta = step.project(delta)
        
        print(delta.max())
        final_imgs = imgs.detach() + delta
        final_imgs = torch.clamp(final_imgs, 0.0, 1.0)
        
        return final_imgs

# Words and symbols that the text attacker never replaces and never uses as a
# replacement. Kept as a set for constant-time membership tests.
filter_words = {
    'a', 'about', 'above', 'across', 'after', 'afterwards', 'again', 'against', 'ain', 'all', 'almost',
    'alone', 'along', 'already', 'also', 'although', 'am', 'among', 'amongst', 'an', 'and', 'another',
    'any', 'anyhow', 'anyone', 'anything', 'anyway', 'anywhere', 'are', 'aren', "aren't", 'around', 'as',
    'at', 'back', 'been', 'before', 'beforehand', 'behind', 'being', 'below', 'beside', 'besides',
    'between', 'beyond', 'both', 'but', 'by', 'can', 'cannot', 'could', 'couldn', "couldn't", 'd', 'didn',
    "didn't", 'doesn', "doesn't", 'don', "don't", 'down', 'due', 'during', 'either', 'else', 'elsewhere',
    'empty', 'enough', 'even', 'ever', 'everyone', 'everything', 'everywhere', 'except', 'first', 'for',
    'former', 'formerly', 'from', 'hadn', "hadn't", 'hasn', "hasn't", 'haven', "haven't", 'he', 'hence',
    'her', 'here', 'hereafter', 'hereby', 'herein', 'hereupon', 'hers', 'herself', 'him', 'himself', 'his',
    'how', 'however', 'hundred', 'i', 'if', 'in', 'indeed', 'into', 'is', 'isn', "isn't", 'it', "it's",
    'its', 'itself', 'just', 'latter', 'latterly', 'least', 'll', 'may', 'me', 'meanwhile', 'mightn',
    "mightn't", 'mine', 'more', 'moreover', 'most', 'mostly', 'must', 'mustn', "mustn't", 'my', 'myself',
    'namely', 'needn', "needn't", 'neither', 'never', 'nevertheless', 'next', 'no', 'nobody', 'none',
    'noone', 'nor', 'not', 'nothing', 'now', 'nowhere', 'o', 'of', 'off', 'on', 'once', 'one', 'only',
    'onto', 'or', 'other', 'others', 'otherwise', 'our', 'ours', 'ourselves', 'out', 'over', 'per',
    'please', 's', 'same', 'shan', "shan't", 'she', "she's", "should've", 'shouldn', "shouldn't", 'somehow',
    'something', 'sometime', 'somewhere', 'such', 't', 'than', 'that', "that'll", 'the', 'their', 'theirs',
    'them', 'themselves', 'then', 'thence', 'there', 'thereafter', 'thereby', 'therefore', 'therein',
    'thereupon', 'these', 'they', 'this', 'those', 'through', 'throughout', 'thru', 'thus', 'to', 'too',
    'toward', 'towards', 'under', 'unless', 'until', 'up', 'upon', 'used', 've', 'was', 'wasn', "wasn't",
    'we', 'were', 'weren', "weren't", 'what', 'whatever', 'when', 'whence', 'whenever', 'where',
    'whereafter', 'whereas', 'whereby', 'wherein', 'whereupon', 'wherever', 'whether', 'which', 'while',
    'whither', 'who', 'whoever', 'whole', 'whom', 'whose', 'why', 'with', 'within', 'without', 'won',
    "won't", 'would', 'wouldn', "wouldn't", 'y', 'yet', 'you', "you'd", "you'll", "you're", "you've",
    'your', 'yours', 'yourself', 'yourselves', '.', '-', 'a the', '/', '?', 'some', '"', ',', 'b', '&', '!',
    '@', '%', '^', '*', '(', ')', '+', '=', '<', '>', '|', ':', ";", '～', '·',
}
    

class TextAttacker():
    def __init__(self, ref_net, tokenizer, device, cls=True, max_length=30, number_perturbation=1, topk=10, threshold_pred_score=0.3, batch_size=32):
        """Store the reference network, tokenizer and word-substitution settings.

        Args:
            ref_net: network whose ``.module`` produces the logits from which
                substitute candidates are taken.
            tokenizer: tokenizer applied to the texts.
            device: device the tokenized inputs are moved to.
            cls: if true, the first position of ``'text_feat'`` is used as the
                text embedding; otherwise the features are flattened.
            max_length: padding/truncation length for tokenization.
            number_perturbation: maximum number of words changed per text.
            topk: number of candidates kept per token position.
            threshold_pred_score: forwarded to ``get_substitues``.
            batch_size: forwarded to ``get_important_scores``.
        """
        # Collaborators.
        self.ref_net, self.tokenizer, self.device = ref_net, tokenizer, device

        # Tokenization and embedding selection.
        self.max_length, self.cls = max_length, cls

        # Substitution budget (epsilon_txt) and candidate selection.
        self.num_perturbation = number_perturbation
        self.topk, self.threshold_pred_score = topk, threshold_pred_score
        self.batch_size = batch_size

    def img_guided_attack(self, net, texts, img_embeds = None):
        """Rewrite each text by greedy word substitution scored against ``img_embeds``.

        For every text, words are visited in decreasing order of the scores
        returned by ``self.get_important_scores``. For each visited word the
        candidates proposed by ``get_substitues`` are tried one at a time, and
        the variant with the largest ``self.loss_func`` value is kept (the
        unchanged text is always among the variants). At most
        ``self.num_perturbation`` words are actually changed per text.

        Args:
            net: text encoder called as ``net(input_ids, attention_mask)``;
                its output must provide ``'text_feat'``.
            texts: sequence of input strings.
            img_embeds: image embeddings forwarded to ``self.loss_func``.

        Returns:
            List with one (possibly modified) string per input text.
        """
        def _on_device(batch):
            # Moves the three encoder inputs to self.device on the batch
            # object itself (token_type_ids is moved too, although only the
            # ids and the mask are passed to the networks below).
            for field in ('input_ids', 'attention_mask', 'token_type_ids'):
                setattr(batch, field, getattr(batch, field).to(self.device))
            return batch

        def _text_embeds(output):
            feats = output['text_feat']
            return feats[:, 0, :] if self.cls else feats.flatten(1)

        encoded = _on_device(self.tokenizer(
            texts, padding='max_length', truncation=True,
            max_length=self.max_length, return_tensors='pt'))

        # Top-k entries of the reference network's logits at every token
        # position are the pool of substitute candidates.
        ref_logits = self.ref_net.module(encoded.input_ids, attention_mask=encoded.attention_mask).logits
        topk_scores, topk_tokens = torch.topk(ref_logits, self.topk, -1)  # seq-len k

        # Embeddings of the unmodified texts.
        clean_embeds = _text_embeds(net(encoded.input_ids, encoded.attention_mask)).detach()

        adversarial_texts = []

        for text_idx, text in enumerate(texts):
            importance = self.get_important_scores(
                text, net, clean_embeds[text_idx], self.batch_size, self.max_length)
            ranked = sorted(enumerate(importance), key=lambda pair: pair[1], reverse=True)

            words, _sub_words, spans = self._tokenize(text)
            current_words = copy.deepcopy(words)
            num_changed = 0

            for word_idx, _score in ranked:
                if num_changed >= self.num_perturbation:
                    break

                target = words[word_idx]
                if target in filter_words:
                    continue
                span = spans[word_idx]
                if span[0] > self.max_length - 2:
                    continue

                positions = slice(span[0], span[1])
                span_tokens = topk_tokens[text_idx, positions]  # L, k
                span_scores = topk_scores[text_idx, positions]
                proposals = get_substitues(span_tokens, self.tokenizer, self.ref_net, 1, span_scores,
                                           self.threshold_pred_score)

                # Index 0 is always the current text, so "keep the word" is a
                # valid outcome of the argmax below.
                variant_texts = [' '.join(current_words)]
                variant_words = [target]
                for proposal in proposals:
                    # Skip the original word, sub-word pieces and filtered words.
                    if proposal == target or '##' in proposal or proposal in filter_words:
                        continue
                    # (An antonym filter based on w2i / cos_mat is disabled here.)
                    trial = copy.deepcopy(current_words)
                    trial[word_idx] = proposal
                    variant_words.append(proposal)
                    variant_texts.append(' '.join(trial))

                variant_batch = _on_device(self.tokenizer(
                    variant_texts, padding='max_length', truncation=True,
                    max_length=self.max_length, return_tensors='pt'))
                variant_embeds = _text_embeds(net(variant_batch.input_ids, variant_batch.attention_mask))

                variant_loss = self.loss_func(variant_embeds, img_embeds, text_idx)
                best = variant_loss.argmax()

                chosen = variant_words[best]
                current_words[word_idx] = chosen
                if chosen != target:
                    num_changed += 1

            adversarial_texts.append(' '.join(current_words))

        return adversarial_texts

    def img_guided_attack2(self, net, texts, img_embeds = None):
        """Greedily substitute words so each text drifts from its clean embedding and paired image.

        For every text, words are visited in decreasing order of the scores
        returned by ``get_important_scores``. Candidate replacements for a word
        come from the top-k predictions of the masked LM ``self.ref_net``
        (filtered through ``get_substitues``). The candidate whose embedding
        has the lowest summed dot product with the clean text embedding and
        with ``img_embeds[i]`` is kept. At most ``self.num_perturbation``
        words are changed per text.

        Args:
            net: Callable invoked as ``net(input_ids, attention_mask)`` that
                returns a mapping containing a ``'text_feat'`` tensor.
            texts: Sequence of space-separated strings to attack.
            img_embeds: Indexable by text position; ``img_embeds[i]`` is
                multiplied element-wise with the candidate embeddings of
                ``texts[i]``. It is dereferenced as soon as any word has a
                usable candidate, so the ``None`` default only works when no
                substitution is ever attempted.

        Returns:
            A list with one (possibly modified) string per input text.
        """
        text_inputs = self.tokenizer(texts, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt')
        # Only ids and mask are consumed below, so only those are moved.
        input_ids = text_inputs.input_ids.to(self.device)
        attention_mask = text_inputs.attention_mask.to(self.device)

        final_adverse = []
        # The search is discrete and never backpropagates, so no forward pass
        # needs an autograd graph.
        with torch.no_grad():
            # substitutes
            mlm_logits = self.ref_net.module(input_ids, attention_mask=attention_mask).logits
            word_pred_scores_all, word_predictions = torch.topk(mlm_logits, self.topk, -1)  # seq-len k

            # original state
            origin_output = net(input_ids, attention_mask)
            if self.cls:
                origin_embeds = origin_output['text_feat'][:, 0, :].detach()
            else:
                origin_embeds = origin_output['text_feat'].flatten(1).detach()

            for i, text in enumerate(texts):
                # word importance eval
                important_scores = self.get_important_scores(text, net, origin_embeds[i], self.batch_size, self.max_length)
                list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1], reverse=True)

                words, sub_words, keys = self._tokenize(text)
                final_words = copy.deepcopy(words)
                change = 0

                for word_idx, _ in list_of_index:
                    if change >= self.num_perturbation:
                        break

                    tgt_word = words[word_idx]
                    if tgt_word in filter_words:
                        continue
                    start, end = keys[word_idx]
                    if start > self.max_length - 2:
                        continue

                    # NOTE(review): `keys` are sub-word offsets computed without
                    # special tokens, while the MLM predictions are indexed over
                    # the full tokenizer output - confirm the alignment is intended.
                    substitutes = word_predictions[i, start:end]  # L, k
                    word_pred_scores = word_pred_scores_all[i, start:end]

                    substitutes = get_substitues(substitutes, self.tokenizer, self.ref_net, 1, word_pred_scores,
                                                 self.threshold_pred_score)

                    # Entry 0 is always the unmodified sentence / original word.
                    replace_texts = [' '.join(final_words)]
                    available_substitutes = [tgt_word]
                    for substitute in substitutes:
                        if substitute == tgt_word:
                            continue  # filter out original word
                        if '##' in substitute:
                            continue  # filter out sub-word
                        if substitute in filter_words:
                            continue
                        temp_replace = list(final_words)
                        temp_replace[word_idx] = substitute
                        available_substitutes.append(substitute)
                        replace_texts.append(' '.join(temp_replace))

                    if len(available_substitutes) == 1:
                        # Nothing to choose from: the word stays as it is, so
                        # the forward pass would be wasted.
                        continue

                    replace_text_input = self.tokenizer(replace_texts, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt')
                    replace_output = net(replace_text_input.input_ids.to(self.device),
                                         replace_text_input.attention_mask.to(self.device))
                    if self.cls:
                        replace_embeds = replace_output['text_feat'][:, 0, :]
                    else:
                        replace_embeds = replace_output['text_feat'].flatten(1)

                    # Negative similarity to the clean text plus negative
                    # similarity to the paired image; broadcasting replaces the
                    # explicit per-candidate copies.
                    loss = -replace_embeds.mul(origin_embeds[i]).sum(-1) - replace_embeds.mul(img_embeds[i]).sum(-1)
                    candidate_idx = int(loss.argmax())

                    final_words[word_idx] = available_substitutes[candidate_idx]

                    if available_substitutes[candidate_idx] != tgt_word:
                        change += 1

                final_adverse.append(' '.join(final_words))

        return final_adverse

    def loss_func(self, txt_embeds, img_embeds, label):
        """Return the negated dot product of every text embedding with one image embedding.

        ``img_embeds[label]`` is tiled once per row of ``txt_embeds``,
        multiplied element-wise and summed over the last dimension, so a
        larger value means a less similar text/image pair.
        """
        target = img_embeds[label]
        tiled_target = target.repeat(len(txt_embeds), 1)
        similarity = (txt_embeds * tiled_target).sum(dim=-1)
        return -similarity


    def img_guided_attack1(self, net, texts, img_embeds = None):
        """Greedily substitute words so each text drifts away from its paired image embedding.

        Words are visited in decreasing order of the scores returned by
        ``get_important_scores1``. Candidate replacements come from the top-k
        predictions of the masked LM ``self.ref_net`` (filtered through
        ``get_substitues``), and the candidate maximizing ``self.loss_func``
        is kept. At most ``self.num_perturbation`` words are changed per text.

        Args:
            net: Object exposing ``inference_text(encoding)`` that returns a
                mapping containing a ``'text_feat'`` tensor.
            texts: Sequence of space-separated strings to attack.
            img_embeds: Passed to ``self.loss_func`` together with the text
                position; it is dereferenced as soon as any word has a usable
                candidate, so the ``None`` default only works when no
                substitution is ever attempted.

        Returns:
            A list with one (possibly modified) string per input text.
        """
        device = self.ref_net.device

        text_inputs = self.tokenizer(texts, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt').to(device)

        final_adverse = []
        # The search is discrete and never backpropagates, so no forward pass
        # needs an autograd graph.
        with torch.no_grad():
            # substitutes
            mlm_logits = self.ref_net(text_inputs.input_ids, attention_mask=text_inputs.attention_mask).logits
            word_pred_scores_all, word_predictions = torch.topk(mlm_logits, self.topk, -1)  # seq-len k

            # original state
            origin_output = net.inference_text(text_inputs)
            if self.cls:
                origin_embeds = origin_output['text_feat'][:, 0, :].detach()
            else:
                origin_embeds = origin_output['text_feat'].flatten(1).detach()

            for i, text in enumerate(texts):
                # word importance eval
                important_scores = self.get_important_scores1(text, net, origin_embeds[i], self.batch_size, self.max_length)
                list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1], reverse=True)

                words, sub_words, keys = self._tokenize(text)
                final_words = copy.deepcopy(words)
                change = 0

                for word_idx, _ in list_of_index:
                    if change >= self.num_perturbation:
                        break

                    tgt_word = words[word_idx]
                    if tgt_word in filter_words:
                        continue
                    start, end = keys[word_idx]
                    if start > self.max_length - 2:
                        continue

                    substitutes = word_predictions[i, start:end]  # L, k
                    word_pred_scores = word_pred_scores_all[i, start:end]

                    substitutes = get_substitues(substitutes, self.tokenizer, self.ref_net, 1, word_pred_scores,
                                                 self.threshold_pred_score)

                    # Entry 0 is always the unmodified sentence / original word.
                    replace_texts = [' '.join(final_words)]
                    available_substitutes = [tgt_word]
                    for substitute in substitutes:
                        if substitute == tgt_word:
                            continue  # filter out original word
                        if '##' in substitute:
                            continue  # filter out sub-word
                        if substitute in filter_words:
                            continue
                        temp_replace = list(final_words)
                        temp_replace[word_idx] = substitute
                        available_substitutes.append(substitute)
                        replace_texts.append(' '.join(temp_replace))

                    if len(available_substitutes) == 1:
                        # Nothing to choose from: the word stays as it is, so
                        # the forward pass would be wasted.
                        continue

                    replace_text_input = self.tokenizer(replace_texts, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt').to(device)
                    replace_output = net.inference_text(replace_text_input)
                    if self.cls:
                        replace_embeds = replace_output['text_feat'][:, 0, :]
                    else:
                        replace_embeds = replace_output['text_feat'].flatten(1)

                    loss = self.loss_func(replace_embeds, img_embeds, i)
                    candidate_idx = int(loss.argmax())

                    final_words[word_idx] = available_substitutes[candidate_idx]

                    if available_substitutes[candidate_idx] != tgt_word:
                        change += 1

                final_adverse.append(' '.join(final_words))

        return final_adverse


    def get_important_scores1(self, text, net, origin_embeds, batch_size, max_length):
        """Score each word of ``text`` by how much masking it shifts the text embedding.

        Every word is replaced in turn by ``'[UNK]'`` (see ``_get_masked``),
        the masked sentences are embedded with ``net.inference_text`` in
        chunks of ``batch_size``, and each one is compared to ``origin_embeds``
        with an element-wise KL divergence over softmax-normalized embeddings.

        Args:
            text: Space-separated input string.
            net: Object exposing ``inference_text(encoding)`` that returns a
                mapping containing a ``'text_feat'`` tensor.
            origin_embeds: Embedding of the unmasked text; its device is where
                the tokenized batches are moved.
            batch_size: Number of masked sentences embedded per forward pass.
            max_length: Padding / truncation length given to the tokenizer.

        Returns:
            A tensor with one score per word (KL terms summed over the last
            dimension).
        """
        device = origin_embeds.device

        masked_words = self._get_masked(text)
        masked_texts = [' '.join(words) for words in masked_words]  # list of text of masked words

        masked_embeds = []
        # The embeddings are detached right away, so the forward passes do not
        # need to record an autograd graph at all.
        with torch.no_grad():
            for i in range(0, len(masked_texts), batch_size):
                masked_text_input = self.tokenizer(masked_texts[i:i+batch_size], padding='max_length', truncation=True, max_length=max_length, return_tensors='pt').to(device)
                masked_output = net.inference_text(masked_text_input)
                if self.cls:
                    masked_embed = masked_output['text_feat'][:, 0, :].detach()
                else:
                    masked_embed = masked_output['text_feat'].flatten(1).detach()
                masked_embeds.append(masked_embed)
        masked_embeds = torch.cat(masked_embeds, dim=0)

        criterion = torch.nn.KLDivLoss(reduction='none')

        import_scores = criterion(masked_embeds.log_softmax(dim=-1), origin_embeds.softmax(dim=-1).repeat(len(masked_texts), 1))

        return import_scores.sum(dim=-1)


    def attack(self, net, texts):
        """Greedily substitute words so each text embedding drifts from its clean embedding.

        Words are visited in decreasing order of the scores returned by
        ``get_important_scores``. Candidate replacements come from the top-k
        predictions of the masked LM ``self.ref_net`` (filtered through
        ``get_substitues``), and the candidate with the largest summed KL term
        against the clean embedding is kept. At most ``self.num_perturbation``
        words are changed per text.

        Args:
            net: Text encoder; see the review notes below for the call forms
                this method uses.
            texts: Sequence of space-separated strings to attack.

        Returns:
            A list with one (possibly modified) string per input text.
        """
        text_inputs = self.tokenizer(texts, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt')
        # Only ids and mask of the clean batch are consumed below.
        input_ids = text_inputs.input_ids.to(self.device)
        attention_mask = text_inputs.attention_mask.to(self.device)

        criterion = torch.nn.KLDivLoss(reduction='none')
        final_adverse = []
        # The search is discrete and never backpropagates, so no forward pass
        # needs an autograd graph.
        with torch.no_grad():
            # substitutes
            mlm_logits = self.ref_net.module(input_ids, attention_mask=attention_mask).logits
            word_pred_scores_all, word_predictions = torch.topk(mlm_logits, self.topk, -1)  # seq-len k

            # original state
            # NOTE(review): `net` is called here with the ids only, further
            # down with the whole tokenizer output, and `get_important_scores`
            # calls it as net(ids, mask) and reads 'text_feat' rather than
            # 'text_embed'. Confirm which convention the encoder expects.
            origin_output = net(input_ids)
            if self.cls:
                origin_embeds = origin_output['text_embed'][:, 0, :].detach()
            else:
                origin_embeds = origin_output['text_embed'].flatten(1).detach()

            for i, text in enumerate(texts):
                # word importance eval
                important_scores = self.get_important_scores(text, net, origin_embeds[i], self.batch_size, self.max_length)
                list_of_index = sorted(enumerate(important_scores), key=lambda x: x[1], reverse=True)

                words, sub_words, keys = self._tokenize(text)
                final_words = copy.deepcopy(words)
                change = 0

                for word_idx, _ in list_of_index:
                    if change >= self.num_perturbation:
                        break

                    tgt_word = words[word_idx]
                    if tgt_word in filter_words:
                        continue
                    start, end = keys[word_idx]
                    if start > self.max_length - 2:
                        continue

                    substitutes = word_predictions[i, start:end]  # L, k
                    word_pred_scores = word_pred_scores_all[i, start:end]

                    substitutes = get_substitues(substitutes, self.tokenizer, self.ref_net, 1, word_pred_scores,
                                                 self.threshold_pred_score)

                    # Entry 0 is always the unmodified sentence / original word.
                    replace_texts = [' '.join(final_words)]
                    available_substitutes = [tgt_word]
                    for substitute in substitutes:
                        if substitute == tgt_word:
                            continue  # filter out original word
                        if '##' in substitute:
                            continue  # filter out sub-word
                        if substitute in filter_words:
                            continue
                        temp_replace = list(final_words)
                        temp_replace[word_idx] = substitute
                        available_substitutes.append(substitute)
                        replace_texts.append(' '.join(temp_replace))

                    if len(available_substitutes) == 1:
                        # Nothing to choose from: the word stays as it is, so
                        # the forward pass would be wasted.
                        continue

                    replace_text_input = self.tokenizer(replace_texts, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt').to(self.device)
                    replace_output = net(replace_text_input)
                    if self.cls:
                        replace_embeds = replace_output['text_embed'][:, 0, :]
                    else:
                        replace_embeds = replace_output['text_embed'].flatten(1)

                    loss = criterion(replace_embeds.log_softmax(dim=-1), origin_embeds[i].softmax(dim=-1).repeat(len(replace_embeds), 1))
                    loss = loss.sum(dim=-1)
                    candidate_idx = int(loss.argmax())

                    final_words[word_idx] = available_substitutes[candidate_idx]

                    if available_substitutes[candidate_idx] != tgt_word:
                        change += 1

                final_adverse.append(' '.join(final_words))

        return final_adverse

 
    def _tokenize(self, text):
        """Split ``text`` on single spaces and record each word's sub-word span.

        Returns a triple ``(words, sub_words, keys)``: the space-separated
        words, the concatenation of ``self.tokenizer.tokenize(word)`` over all
        words, and for every word a ``[start, end]`` pair of offsets into
        ``sub_words`` (end exclusive).
        """
        words = text.split(' ')

        sub_words = []
        keys = []
        start = 0
        for word in words:
            pieces = self.tokenizer.tokenize(word)
            end = start + len(pieces)
            sub_words.extend(pieces)
            keys.append([start, end])
            start = end

        return words, sub_words, keys

    def _get_masked(self, text):
        """Return one word list per word of ``text``, with that word replaced by ``'[UNK]'``.

        ``text`` is split on single spaces; the i-th returned list is a copy
        of the words in which only position ``i`` has been masked.
        """
        words = text.split(' ')
        # list of words
        return [
            words[:position] + ['[UNK]'] + words[position + 1:]
            for position in range(len(words))
        ]

    def get_important_scores(self, text, net, origin_embeds, batch_size, max_length):
        """Score each word of ``text`` by how much masking it shifts the text embedding.

        Every word is replaced in turn by ``'[UNK]'`` (see ``_get_masked``),
        the masked sentences are embedded with ``net(input_ids,
        attention_mask)`` in chunks of ``batch_size``, and each one is
        compared to ``origin_embeds`` with an element-wise KL divergence over
        softmax-normalized embeddings.

        Args:
            text: Space-separated input string.
            net: Callable invoked as ``net(input_ids, attention_mask)`` that
                returns a mapping containing a ``'text_feat'`` tensor.
            origin_embeds: Embedding of the unmasked text.
            batch_size: Number of masked sentences embedded per forward pass.
            max_length: Padding / truncation length given to the tokenizer.

        Returns:
            A tensor with one score per word (KL terms summed over the last
            dimension).
        """
        masked_words = self._get_masked(text)
        masked_texts = [' '.join(words) for words in masked_words]  # list of text of masked words

        masked_embeds = []
        # The embeddings are detached right away, so the forward passes do not
        # need to record an autograd graph at all.
        with torch.no_grad():
            for i in range(0, len(masked_texts), batch_size):
                masked_text_input = self.tokenizer(masked_texts[i:i+batch_size], padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')
                # Inputs go to self.device (the encoder's device as configured
                # on the attacker), not to the device of origin_embeds.
                input_ids = masked_text_input.input_ids.to(self.device)
                attention_mask = masked_text_input.attention_mask.to(self.device)
                masked_output = net(input_ids, attention_mask)
                if self.cls:
                    masked_embed = masked_output['text_feat'][:, 0, :].detach()
                else:
                    masked_embed = masked_output['text_feat'].flatten(1).detach()
                masked_embeds.append(masked_embed)
        masked_embeds = torch.cat(masked_embeds, dim=0)

        criterion = torch.nn.KLDivLoss(reduction='none')

        import_scores = criterion(masked_embeds.log_softmax(dim=-1), origin_embeds.softmax(dim=-1).repeat(len(masked_texts), 1))

        return import_scores.sum(dim=-1)



def get_substitues(substitutes, tokenizer, mlm_model, use_bpe, substitutes_score=None, threshold=3.0):
    """Turn top-k masked-LM predictions for one word into candidate replacement strings.

    Args:
        substitutes: 2-D tensor of token ids, one row per sub-word position of
            the target word and one column per top-k prediction.
        tokenizer: Tokenizer exposing ``_convert_id_to_token``.
        mlm_model: Masked LM, only forwarded to ``get_bpe_substitues``.
        use_bpe: When equal to ``1``, multi-sub-word targets are expanded with
            ``get_bpe_substitues``; otherwise they yield no candidates.
        substitutes_score: Scores aligned with ``substitutes``. When ``None``
            no score filtering is applied on the single-sub-word path.
        threshold: On the single-sub-word path, candidates are taken in order
            until the first score below this value; ``0`` disables the cut.

    Returns:
        A list of candidate strings (empty when the target has no sub-words).
    """
    # substitues L,k
    sub_len, _ = substitutes.size()  # sub-len, k

    if sub_len == 0:
        return []

    if sub_len == 1:
        if substitutes_score is None:
            # No scores to threshold against: keep every prediction instead of
            # failing on ``None[0]``.
            return [tokenizer._convert_id_to_token(int(token_id)) for token_id in substitutes[0]]

        words = []
        for token_id, score in zip(substitutes[0], substitutes_score[0]):
            # Stop at the first low-scoring candidate; the caller passes
            # torch.topk output, so later ones score no higher.
            if threshold != 0 and score < threshold:
                break
            words.append(tokenizer._convert_id_to_token(int(token_id)))
        return words

    if use_bpe == 1:
        return get_bpe_substitues(substitutes, tokenizer, mlm_model)
    return []


def get_bpe_substitues(substitutes, tokenizer, mlm_model):
    """Build multi-sub-word replacement strings and rank them by masked-LM perplexity.

    The first 12 sub-word positions and first 4 predictions per position are
    combined into token-id sequences (one id per position, last position
    varying fastest). The first 24 sequences are scored with ``mlm_model`` and
    returned as strings ordered from lowest to highest perplexity.

    Args:
        substitutes: 2-D tensor of token ids (sub-word positions x top-k).
        tokenizer: Tokenizer exposing ``_convert_id_to_token`` and
            ``convert_tokens_to_string``.
        mlm_model: Masked LM, either bare or wrapped so that the underlying
            model is reachable as ``.module``; its first output is used as
            the per-token logits.

    Returns:
        A list of candidate strings, best (lowest perplexity) first; empty
        when ``substitutes`` has no elements.
    """
    from itertools import islice, product

    # substitutes L, k
    try:
        device = mlm_model.module.device
    except AttributeError:
        # Not wrapped (no ``.module``): the model itself carries the device.
        device = mlm_model.device

    substitutes = substitutes[0:12, 0:4]  # maximum BPE candidates
    if substitutes.numel() == 0:
        return []

    # Only the first 24 combinations are ever scored, so enumerate the
    # Cartesian product lazily instead of materializing up to 4**12 lists.
    candidate_ids = list(islice(product(*substitutes.tolist()), 24))

    # all substitutes  list of list of token-id (all candidates)
    c_loss = nn.CrossEntropyLoss(reduction='none')
    all_substitutes = torch.tensor(candidate_ids).to(device)  # [ N, L ]
    N, L = all_substitutes.size()

    # Scoring only ranks the candidates; no gradients are needed.
    with torch.no_grad():
        try:
            word_predictions = mlm_model(all_substitutes)[0]  # N L vocab-size
        except Exception:
            # Best-effort fallback to the wrapped model; if there is none,
            # surface the original failure instead of hiding it.
            if not hasattr(mlm_model, 'module'):
                raise
            word_predictions = mlm_model.module(all_substitutes)[0]

        ppl = c_loss(word_predictions.view(N * L, -1), all_substitutes.view(-1))  # [ N*L ]
        ppl = torch.exp(torch.mean(ppl.view(N, L), dim=-1))  # N
        _, ranking = torch.sort(ppl)

    final_words = []
    for candidate in (all_substitutes[idx] for idx in ranking):
        tokens = [tokenizer._convert_id_to_token(int(token_id)) for token_id in candidate]
        final_words.append(tokenizer.convert_tokens_to_string(tokens))
    return final_words
