import numpy as np 
import torch
import torch.nn as nn

import random
from random import shuffle
import copy
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F

# Shared cross-entropy criterion built with PyTorch's default arguments.
xent_loss = torch.nn.CrossEntropyLoss()

class MultiAttacker():
    """Run a text attack followed by an (optionally text-guided) image attack.

    The text attacker produces adversarial captions first. When the image
    attacker has ``alpha_sup > 0``, those captions are encoded with
    ``model.inference_text`` and handed to the image attacker as supervision.
    """

    def __init__(
        self, 
        model, 
        img_attacker,
        txt_attacker,
        tokenizer,
        ):
        self.model=model
        self.tokenizer=tokenizer

        self.txt_attacker = txt_attacker
        self.img_attacker = img_attacker

    
    def attack(self, imgs, txts, txt2img, device='cpu', max_length=30, 
               scales=None, 
               img_attack_loss="sim", 
               gt_caps_list=None,
               is_train=False,
               **kwargs):
        """Return ``(adv_imgs, adv_txts)`` for a batch of images and captions.

        Args:
            imgs: image batch; ``imgs.shape[0]`` is used as the text chunk size.
            txts: captions passed to the text attacker.
            txt2img: for each caption, the index of its image (forwarded to
                the image attacker).
            device: device the tokenized captions are moved to.
            max_length: tokenizer padding/truncation length.
            scales: forwarded to the image attacker.
            img_attack_loss: forwarded to the image attacker as ``loss_metric``.
            gt_caps_list: accepted for interface compatibility; unused here.
            is_train: when true, assert that one adversarial image was
                produced per adversarial caption.
            **kwargs: accepted and ignored.
        """

        ######## Step 1: text attack ########
        try:
            adv_txts = self.txt_attacker.attack(self.model, txts, k=10)
        except Exception:
            # Best-effort fallback for attackers (e.g. EDA) whose ``attack``
            # only takes the texts. ``Exception`` rather than a bare except so
            # KeyboardInterrupt / SystemExit still propagate.
            adv_txts = self.txt_attacker.attack(txts)

        ######## Step 2. (text-guided) img attack ########
        txt_supervisions = None
        if self.img_attacker.alpha_sup > 0:
            # Encode the captions in chunks of the image batch size. Stepping
            # with range(0, n, B) keeps the trailing partial chunk, so every
            # caption gets an embedding even when n is not a multiple of B.
            chunk = max(int(imgs.shape[0]), 1)
            feats = []
            with torch.no_grad():
                for start in range(0, len(adv_txts), chunk):
                    txts_input = self.tokenizer(
                        adv_txts[start:start + chunk],
                        padding="max_length",
                        truncation=True,
                        max_length=max_length,
                        return_tensors="pt",
                    ).to(device)
                    this_txts_output = self.model.inference_text(txts_input)
                    feats.append(this_txts_output["text_feat"])
            if feats:
                txt_supervisions = torch.cat(feats, 0)

        adv_imgs = self.img_attacker.attack(
            self.model, imgs, txt2img, device, 
            scales=scales, txt_embeds=txt_supervisions, loss_metric=img_attack_loss)
        
        if is_train:
            assert len(adv_imgs) == len(adv_txts)
        return adv_imgs, adv_txts

                

class ImageAttacker():
    """Iterative sign-gradient image attacker bounded by an L-inf budget.

    The objective is ``alpha_unsup * unsup_loss + alpha_sup * sup_loss``;
    each term is only evaluated when its weight is positive. Images are
    clamped to ``[0, 1]`` and to ``imgs +/- eps`` after every step.
    """

    def __init__(self, normalization, eps=2/255, steps=10, step_size=0.5/255, 
                 alpha_unsup=0.0, alpha_sup=1.0):
        """Store the attack hyper-parameters and print them.

        Args:
            normalization: callable applied to images before the model, or
                ``None`` to feed images unchanged.
            eps: maximum per-pixel perturbation.
            steps: number of update iterations.
            step_size: per-iteration step magnitude.
            alpha_unsup: weight of the KL term against the clean embedding.
            alpha_sup: weight of the text-guided term.
        """
        self.normalization = normalization
        self.eps = eps
        self.steps = steps 
        self.step_size = step_size 
        
        self.kl_loss = nn.KLDivLoss(reduction="batchmean")
        self.alpha_unsup = alpha_unsup
        self.alpha_sup = alpha_sup

        print("ImageAttacker config:")
        print(f"- eps: {self.eps}")
        print(f"- steps: {self.steps}")
        print(f"- step_size: {self.step_size}")
        print(f"- alpha_unsup: {self.alpha_unsup}")
        print(f"- alpha_sup: {self.alpha_sup}")

    def unsup_loss(self, adv_embeds, origin_embeds):
        """Return the KL divergence between softmaxed clean and adversarial embeddings."""
        loss = self.kl_loss(
            adv_embeds.log_softmax(dim=-1), origin_embeds.softmax(dim=-1)
        )
        return loss

    def sup_loss(self, model, adv_imgs_embeds, txts_embeds, txt2img, loss_metric="sim"):  
        """
        return supervised loss (text-guided loss)

        loss_metric:
            sim: maximize the distance between correct pairs
            clip: maximize the clip training loss

        Returns a scalar loss for "sim". For "clip" the return value is the
        tuple ``(loss, -(logits_per_image * it_labels))``, kept for backward
        compatibility with existing callers.
        """
        assert loss_metric in ["sim", "clip"], loss_metric
        
        device = adv_imgs_embeds.device    

        bi = adv_imgs_embeds.shape[0]
        bj = txts_embeds.shape[0]
        # it_labels[i, j] == 1 when text j is paired with image i.
        it_labels = torch.zeros(bi, bj, device=device)
        for txt_idx, img_idx in enumerate(txt2img):
            it_labels[img_idx, txt_idx] = 1
        
        if loss_metric == "clip":
            image_features = adv_imgs_embeds
            text_features = txts_embeds
            # normalized features
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # cosine similarity as logits
            logit_scale = model.logit_scale.exp()
            logits_per_image = logit_scale * image_features @ text_features.t()
            logits_per_text = logits_per_image.t()

            # The text-side logits are (texts, images), so their targets must
            # be the transposed label matrix; using it_labels directly only
            # has a matching shape when bi == bj.
            loss = (F.cross_entropy(logits_per_image, it_labels)
                    + F.cross_entropy(logits_per_text, it_labels.t())) / 2
            return loss, -(logits_per_image * it_labels)
        
        elif loss_metric == "sim":
            it_sim_matrix = adv_imgs_embeds @ txts_embeds.T
            loss_matrix = -(it_sim_matrix * it_labels)
            loss = loss_matrix.sum(-1).mean()
        
        return loss
    
    def loss_func(self, model, origin_embeds, adv_imgs_embeds, txts_embeds, txt2img, loss_metric="sim"):
        """
        return the weighted sum of the enabled unsupervised and supervised losses

        Returns the int ``0`` when both weights are non-positive.
        """
        loss = 0
        if self.alpha_unsup > 0:
            unsup_loss = self.unsup_loss(adv_imgs_embeds, origin_embeds)
            loss += self.alpha_unsup * unsup_loss
        if self.alpha_sup > 0:
            sup_loss = self.sup_loss(model, adv_imgs_embeds, txts_embeds, txt2img, loss_metric=loss_metric)
            # The "clip" metric returns (loss, per-pair matrix); only the
            # scalar participates in the objective.
            if isinstance(sup_loss, tuple):
                sup_loss = sup_loss[0]
            loss += self.alpha_sup * sup_loss
        return loss
    

    def attack(self, model, imgs, txt2img, device, scales=None, txt_embeds=None, loss_metric="sim"):
        """Return adversarial images within ``eps`` of ``imgs``.

        Args:
            model: object exposing ``inference_image`` (returning a mapping
                with key ``'image_feat'``) and ``zero_grad``.
            imgs: clean image batch; values are clamped to ``[0, 1]``.
            txt2img: text-to-image index list forwarded to ``sup_loss``.
            device: device for the random start noise and loss accumulator.
            scales: optional resize ratios; see ``get_scaled_imgs``.
            txt_embeds: text embeddings for the supervised term.
            loss_metric: "sim" or "clip".
        """
        b = imgs.shape[0]

        # Clean embeddings are only needed as the KL target.
        origin_embeds = None
        if self.alpha_unsup > 0:
            with torch.no_grad():
                if self.normalization is not None:
                    imgs_output = model.inference_image(self.normalization(imgs))
                else:
                    imgs_output = model.inference_image(imgs)
            origin_embeds = imgs_output['image_feat'].detach() # detach!
        
        if scales is None:
            scales_num = 1
        else:
            scales_num = len(scales) +1

        # Random start inside the eps-ball.
        adv_imgs = imgs.detach() + torch.from_numpy(np.random.uniform(-self.eps, self.eps, imgs.shape)).float().to(device)
        adv_imgs = torch.clamp(adv_imgs, 0.0, 1.0)

        for _ in range(self.steps):
            adv_imgs.requires_grad_()
            # The forward pass must be recorded too, otherwise backward()
            # fails when the caller is inside a torch.no_grad() context.
            with torch.enable_grad():
                scaled_imgs = self.get_scaled_imgs(adv_imgs, scales, device)        
            
                if self.normalization is not None:
                    adv_imgs_output = model.inference_image(self.normalization(scaled_imgs))
                else:
                    adv_imgs_output = model.inference_image(scaled_imgs)
                    
                adv_imgs_embeds = adv_imgs_output['image_feat']
                model.zero_grad()
                loss = torch.tensor(0.0, dtype=torch.float32).to(device)
                # One loss term per scaled copy (the first chunk is unscaled).
                for k in range(scales_num):
                    loss += self.loss_func(
                        model, origin_embeds, adv_imgs_embeds[k*b:k*b+b], txt_embeds, txt2img,
                        loss_metric=loss_metric
                    )
            loss.backward()
            
            # Only the sign of the gradient is used, so no normalisation is
            # applied: dividing by the mean magnitude cannot change the sign
            # and yields NaN for an all-zero gradient.
            perturbation = self.step_size * adv_imgs.grad.sign()
            adv_imgs = adv_imgs.detach() + perturbation
            adv_imgs = torch.min(torch.max(adv_imgs, imgs - self.eps), imgs + self.eps)
            adv_imgs = torch.clamp(adv_imgs, 0.0, 1.0)
        
        return adv_imgs


    def get_scaled_imgs(self, imgs, scales=None, device='cuda'):
        """Return ``imgs`` followed by one noisy, down/up-resized copy per scale.

        With ``scales=None`` the input is returned unchanged. Otherwise each
        copy gets Gaussian noise (std 0.05), is resized by the ratio with
        bicubic interpolation, clamped to ``[0, 1]`` and resized back to the
        original height/width; all copies are concatenated along dim 0.
        """
        if scales is None:
            return imgs

        ori_shape = (imgs.shape[-2], imgs.shape[-1])
        
        reverse_transform = transforms.Resize(ori_shape,
                                interpolation=transforms.InterpolationMode.BICUBIC)
        result = []
        for ratio in scales:
            scale_shape = (int(ratio*ori_shape[0]), 
                                  int(ratio*ori_shape[1]))
            scale_transform = transforms.Resize(scale_shape,
                                  interpolation=transforms.InterpolationMode.BICUBIC)
            scaled_imgs = imgs + torch.from_numpy(np.random.normal(0.0, 0.05, imgs.shape)).float().to(device)
            scaled_imgs = scale_transform(scaled_imgs)
            scaled_imgs = torch.clamp(scaled_imgs, 0.0, 1.0)
            
            reversed_imgs = reverse_transform(scaled_imgs)
            
            result.append(reversed_imgs)
        
        return torch.cat([imgs,]+result, 0)



# Tokens treated as stop words: common function words, contraction fragments
# and punctuation marks. The list is converted to a set right after the
# literal, so duplicate entries collapse and membership tests are hash lookups.
filter_words = ['a', 'about', 'above', 'across', 'after', 'afterwards', 'again', 'against', 'ain', 'all', 'almost',
                'alone', 'along', 'already', 'also', 'although', 'am', 'among', 'amongst', 'an', 'and', 'another',
                'any', 'anyhow', 'anyone', 'anything', 'anyway', 'anywhere', 'are', 'aren', "aren't", 'around', 'as',
                'at', 'back', 'been', 'before', 'beforehand', 'behind', 'being', 'below', 'beside', 'besides',
                'between', 'beyond', 'both', 'but', 'by', 'can', 'cannot', 'could', 'couldn', "couldn't", 'd', 'didn',
                "didn't", 'doesn', "doesn't", 'don', "don't", 'down', 'due', 'during', 'either', 'else', 'elsewhere',
                'empty', 'enough', 'even', 'ever', 'everyone', 'everything', 'everywhere', 'except', 'first', 'for',
                'former', 'formerly', 'from', 'hadn', "hadn't", 'hasn', "hasn't", 'haven', "haven't", 'he', 'hence',
                'her', 'here', 'hereafter', 'hereby', 'herein', 'hereupon', 'hers', 'herself', 'him', 'himself', 'his',
                'how', 'however', 'hundred', 'i', 'if', 'in', 'indeed', 'into', 'is', 'isn', "isn't", 'it', "it's",
                'its', 'itself', 'just', 'latter', 'latterly', 'least', 'll', 'may', 'me', 'meanwhile', 'mightn',
                "mightn't", 'mine', 'more', 'moreover', 'most', 'mostly', 'must', 'mustn', "mustn't", 'my', 'myself',
                'namely', 'needn', "needn't", 'neither', 'never', 'nevertheless', 'next', 'no', 'nobody', 'none',
                'noone', 'nor', 'not', 'nothing', 'now', 'nowhere', 'o', 'of', 'off', 'on', 'once', 'one', 'only',
                'onto', 'or', 'other', 'others', 'otherwise', 'our', 'ours', 'ourselves', 'out', 'over', 'per',
                'please', 's', 'same', 'shan', "shan't", 'she', "she's", "should've", 'shouldn', "shouldn't", 'somehow',
                'something', 'sometime', 'somewhere', 'such', 't', 'than', 'that', "that'll", 'the', 'their', 'theirs',
                'them', 'themselves', 'then', 'thence', 'there', 'thereafter', 'thereby', 'therefore', 'therein',
                'thereupon', 'these', 'they', 'this', 'those', 'through', 'throughout', 'thru', 'thus', 'to', 'too',
                'toward', 'towards', 'under', 'unless', 'until', 'up', 'upon', 'used', 've', 'was', 'wasn', "wasn't",
                'we', 'were', 'weren', "weren't", 'what', 'whatever', 'when', 'whence', 'whenever', 'where',
                'whereafter', 'whereas', 'whereby', 'wherein', 'whereupon', 'wherever', 'whether', 'which', 'while',
                'whither', 'who', 'whoever', 'whole', 'whom', 'whose', 'why', 'with', 'within', 'without', 'won',
                "won't", 'would', 'wouldn', "wouldn't", 'y', 'yet', 'you', "you'd", "you'll", "you're", "you've",
                'your', 'yours', 'yourself', 'yourselves', '.', '-', 'a the', '/', '?', 'some', '"', ',', 'b', '&', '!',
                '@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";", '～', '·']
filter_words = set(filter_words)


########################################################################
########################################################################
# EDA: Easy Data Augmentation Techniques for Boosting Performance on Text Classification Tasks
# https://github.com/jasonwei20/eda_nlp
import nltk
nltk.download('wordnet')
from nltk.corpus import wordnet 

# Alias for the same set object; synonym_replacement skips words found in it.
stop_words = filter_words

#cleaning up text
import re
def get_only_chars(line):
    """Lower-case *line* and keep only ASCII letters and single spaces.

    Apostrophes are dropped; hyphens, tabs, newlines and every other
    character become a space. Runs of spaces are collapsed and one leading
    space is removed (a trailing space, if any, is left in place). An empty
    or letter-free input yields an empty string.
    """
    line = line.replace("’", "")
    line = line.replace("'", "")
    line = line.replace("-", " ") #replace hyphens with spaces
    line = line.replace("\t", " ")
    line = line.replace("\n", " ")
    line = line.lower()

    allowed = 'qwertyuiopasdfghjklzxcvbnm '
    clean_line = ''.join(char if char in allowed else ' ' for char in line)

    clean_line = re.sub(' +',' ',clean_line) #delete extra spaces
    # startswith() is safe on an empty string, unlike indexing clean_line[0].
    if clean_line.startswith(' '):
        clean_line = clean_line[1:]
    return clean_line

########################################################################
# Synonym replacement
# Replace n words in the sentence with synonyms from wordnet
########################################################################

def synonym_replacement(words, n):
    """Return a copy of *words* with up to *n* distinct non-stop words replaced.

    Candidate words are visited in random order; every occurrence of a chosen
    word is swapped for one randomly picked synonym. Words without synonyms
    are skipped and do not count towards *n*.
    """
    result = words.copy()
    candidates = list({token for token in words if token not in stop_words})
    random.shuffle(candidates)

    replaced = 0
    for target in candidates:
        options = get_synonyms(target)
        if options:
            replacement = random.choice(list(options))
            result = [replacement if token == target else token for token in result]
            replaced += 1
        # Stop as soon as the replacement budget is used up.
        if replaced >= n:
            break

    # A synonym may contain spaces; re-splitting turns it into separate tokens.
    return ' '.join(result).split(' ')

def get_synonyms(word):
    """Return the distinct WordNet lemma names for *word*, excluding *word* itself.

    Lemma names are lower-cased, underscores and hyphens become spaces, and
    any character other than an ASCII letter or space is removed. Candidates
    that end up empty or blank after that filtering are skipped, so callers
    never receive an empty "synonym".
    """
    allowed = ' qwertyuiopasdfghjklzxcvbnm'
    synonyms = set()
    for syn in wordnet.synsets(word):
        for lemma in syn.lemmas():
            synonym = lemma.name().replace("_", " ").replace("-", " ").lower()
            synonym = "".join(char for char in synonym if char in allowed)
            # A lemma made only of filtered characters collapses to ""
            # (or spaces); that is not a usable replacement word.
            if synonym.strip():
                synonyms.add(synonym)
    synonyms.discard(word)
    return list(synonyms)


########################################################################
# Random deletion
# Randomly delete words from the sentence with probability p
########################################################################

def random_deletion(words, p):
    """Drop each word independently with probability *p*.

    A list with zero or one word is returned unchanged (the same object).
    If every word would be deleted, a single randomly chosen word is kept.
    """
    # Nothing sensible to delete; this also keeps the fallback below from
    # calling random.randint(0, -1) on an empty list.
    if len(words) <= 1:
        return words

    #randomly delete words with probability p
    new_words = [word for word in words if random.uniform(0, 1) > p]

    #if you end up deleting all words, just return a random word
    if not new_words:
        rand_int = random.randint(0, len(words)-1)
        return [words[rand_int]]

    return new_words

########################################################################
# Random swap
# Randomly swap two words in the sentence n times
########################################################################

def random_swap(words, n):
    """Return a copy of *words* after *n* swap_word passes; *words* is not modified."""
    swapped = words.copy()
    for _pass in range(n):
        swapped = swap_word(swapped)
    return swapped

def swap_word(new_words):
    """Swap two randomly chosen positions of *new_words* in place and return it.

    A second, different index is drawn up to four times; if none is found the
    list is returned unchanged. An empty list is returned as is.
    """
    # random.randint(0, -1) raises ValueError, so bail out on empty input.
    if not new_words:
        return new_words

    random_idx_1 = random.randint(0, len(new_words)-1)
    random_idx_2 = random_idx_1
    counter = 0
    while random_idx_2 == random_idx_1:
        random_idx_2 = random.randint(0, len(new_words)-1)
        counter += 1
        # Give up after a few tries (always the case for a one-word list).
        if counter > 3:
            return new_words
    new_words[random_idx_1], new_words[random_idx_2] = new_words[random_idx_2], new_words[random_idx_1]
    return new_words

########################################################################
# Random insertion
# Randomly insert n words into the sentence
########################################################################

def random_insertion(words, n):
    """Return a copy of *words* after *n* add_word calls; *words* is not modified."""
    extended = words.copy()
    for _attempt in range(n):
        add_word(extended)
    return extended

def add_word(new_words):
    """Insert a synonym of a randomly chosen word into *new_words* in place.

    Up to ten random words are tried; if none has a synonym, or the list is
    empty, the list is left untouched. The first synonym returned by
    get_synonyms is inserted before a randomly chosen existing position.
    Returns None.
    """
    # random.randint(0, -1) raises ValueError, so there is nothing to do
    # for an empty list.
    if not new_words:
        return

    synonyms = []
    counter = 0
    while len(synonyms) < 1:
        random_word = new_words[random.randint(0, len(new_words)-1)]
        synonyms = get_synonyms(random_word)
        counter += 1
        if counter >= 10:
            return
    random_synonym = synonyms[0]
    random_idx = random.randint(0, len(new_words)-1)
    new_words.insert(random_idx, random_synonym)


########################################################################
# main data augmentation function
########################################################################

def eda(sentence, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=9):
    """Return augmented variants of *sentence*, with the cleaned original last.

    Args:
        sentence: input text; it is normalised with get_only_chars first.
        alpha_sr: fraction of words to replace with synonyms (0 disables).
        alpha_ri: fraction of words used as the number of insertions (0 disables).
        alpha_rs: fraction of words used as the number of swaps (0 disables).
        p_rd: per-word deletion probability (0 disables).
        num_aug: number of augmented sentences to keep when >= 1; when < 1
            each generated sentence is kept with probability
            ``num_aug / number_generated``.

    Returns:
        A list of augmented sentences followed by the cleaned original. If
        the cleaned sentence contains no words, only the cleaned original is
        returned.
    """
    sentence = get_only_chars(sentence)
    # Compare by value: `is not ''` tests object identity, not equality.
    words = [word for word in sentence.split(' ') if word != '']
    num_words = len(words)

    # The word-level operations cannot work on an empty word list.
    if num_words == 0:
        return [sentence]

    augmented_sentences = []
    num_new_per_technique = int(num_aug/4)+1

    #sr
    if (alpha_sr > 0):
        n_sr = max(1, int(alpha_sr*num_words))
        for _ in range(num_new_per_technique):
            a_words = synonym_replacement(words, n_sr)
            augmented_sentences.append(' '.join(a_words))

    #ri
    if (alpha_ri > 0):
        n_ri = max(1, int(alpha_ri*num_words))
        for _ in range(num_new_per_technique):
            a_words = random_insertion(words, n_ri)
            augmented_sentences.append(' '.join(a_words))

    #rs
    if (alpha_rs > 0):
        n_rs = max(1, int(alpha_rs*num_words))
        for _ in range(num_new_per_technique):
            a_words = random_swap(words, n_rs)
            augmented_sentences.append(' '.join(a_words))

    #rd
    if (p_rd > 0):
        for _ in range(num_new_per_technique):
            a_words = random_deletion(words, p_rd)
            augmented_sentences.append(' '.join(a_words))

    augmented_sentences = [get_only_chars(candidate) for candidate in augmented_sentences]
    shuffle(augmented_sentences)

    #trim so that we have the desired number of augmented sentences
    if num_aug >= 1:
        augmented_sentences = augmented_sentences[:num_aug]
    elif augmented_sentences:
        # Guarded so that disabling every technique does not divide by zero.
        keep_prob = num_aug / len(augmented_sentences)
        augmented_sentences = [s for s in augmented_sentences if random.uniform(0, 1) < keep_prob]

    #append the original sentence
    augmented_sentences.append(sentence)

    return augmented_sentences


if __name__=="__main__":
    # Small demo: augment one sentence using the same strength for all four
    # EDA operations and print every result on its own line.
    demo_sentence = "The quick brown fox jumps over the lazy dog"
    strength = 0.3

    demo_outputs = eda(
        demo_sentence,
        alpha_sr=strength,
        alpha_ri=strength,
        alpha_rs=strength,
        p_rd=strength,
        num_aug=5,
    )
    for augmented in demo_outputs:
        print(augmented)