import numpy as np 
import torch
import torch.nn as nn

import random
from random import shuffle
import copy
from torchvision import transforms
from PIL import Image
import torch.nn.functional as F

# Module-level cross-entropy criterion, constructed once with PyTorch's default settings.
xent_loss = torch.nn.CrossEntropyLoss()

class MMA():
    """Multi-modal attack driver.

    Builds the caption set used as text supervision, encodes it with the
    model's text branch, and delegates the image perturbation to
    ``img_attacker``.  The two text-augmentation stages (``is_aug_txt`` and
    ``is_txt_aug``) are not implemented and raise ``NotImplementedError``
    when enabled.
    """

    def __init__(
        self, 
        model, 
        img_attacker,
        tokenizer,
        attack_config,
        ):
        """Store collaborators and read the attack configuration.

        Args:
            model: object exposing ``inference_text(inputs)`` that returns a
                mapping containing ``"text_feat"``.
            img_attacker: object exposing ``alpha_sup`` and
                ``attack(model, imgs, txt2img, device, scales=...,
                txt_embeds=..., loss_metric=...)``.
            tokenizer: callable invoked with a list of strings plus
                ``padding``/``truncation``/``max_length``/``return_tensors``;
                its result must support ``.to(device)``.
            attack_config: dict-like.  ``"scale_ver"`` is required and indexes
                the built-in scale presets (0, 1 or 2); every other key is
                optional.

        Raises:
            KeyError: if ``"scale_ver"`` is missing.
            AssertionError: if ``"txt_aug"`` is not a recognised mode.
        """
        self.model=model
        self.tokenizer=tokenizer

        # Step1: prepare texts for text supervision in image attack
        self.is_aug_txt = attack_config.get("is_aug_txt", False)
        self.alpha_sr = attack_config.get("alpha_sr", 0.1)
        self.alpha_ri = attack_config.get("alpha_ri", 0.1)
        self.alpha_rs = attack_config.get("alpha_rs", 0.1)
        self.p_rd = attack_config.get("p_rd", 0.1)
        # number of texts supervision for image attack
        self.txt_sup_k = attack_config.get("txt_sup_k", 1)

        # Step2: image attack
        scales_list = [
            [],
            [0.75,1.25],
            [0.5,0.75,1.25,1.5],
        ]
        self.scale_ver = attack_config["scale_ver"] # specify!
        self.img_att_scales = scales_list[self.scale_ver]
        self.img_attacker = img_attacker

        # Step3: augment txt
        self.is_txt_aug = attack_config.get("is_txt_aug", False)
        self.txt_aug = attack_config.get("txt_aug", "sr")
        self.txt_aug_alpha = attack_config.get("txt_aug_alpha", 0.1)
        assert self.txt_aug in ["sr", "ri", "rs", "rd", "rand", "all"], self.txt_aug

        print("MMA attack config:")
        print(f"[img_att_scale: {self.img_att_scales}, txt_sup_k: {self.txt_sup_k}]")
        print(f"- alpha_sr: {self.alpha_sr}")
        print(f"- alpha_ri: {self.alpha_ri}")
        print(f"- alpha_rs: {self.alpha_rs}")
        print(f"- p_rd: {self.p_rd}")
        print(f"- is_aug_txt: {self.is_aug_txt}")

    
    def attack(self, imgs, txts, txt2img, device='cpu', max_length=30, 
               scales=None, 
               img_attack_loss="sim", 
               gt_caps_list=None,
               is_train=False,
               **kwargs):
        """Run the text-guided image attack.

        Args:
            imgs: image batch; ``imgs.shape[0]`` is used as the batch size.
            txts: list of captions.
            txt2img: for each caption, the index of the image it describes.
            device: device the tokenised text is moved to and that is passed
                on to the image attacker.
            max_length: tokenizer ``max_length``.
            scales: image scales for the attacker; defaults to the preset
                chosen by ``scale_ver``.
            img_attack_loss: forwarded to the attacker as ``loss_metric``.
            gt_caps_list: optional list of caption lists, indexed as
                ``gt_caps_list[k][i]`` (k-th caption of image i).  When given
                it replaces ``txts`` as supervision.
            is_train: when true, assert one returned caption per image.
            **kwargs: accepted and ignored.

        Returns:
            ``(adv_imgs, adv_txts)``.  ``adv_txts`` is ``txts`` unless
            ``gt_caps_list`` was supplied, in which case it is the flattened
            (image-major) ground-truth caption list.

        Raises:
            NotImplementedError: if ``is_aug_txt`` or ``is_txt_aug`` is set.
            ValueError: if ``txt_sup_k`` is smaller than 1 and no
                ``gt_caps_list`` is given.
        """
        # Both augmentation stages are unimplemented; fail before doing any
        # expensive work instead of after it.
        if self.is_aug_txt:
            raise NotImplementedError("is_aug_txt is not implemented")
        if self.is_txt_aug:
            raise NotImplementedError("is_txt_aug is not implemented")

        if scales is None:
            scales = self.img_att_scales

        ######## Step 1: text attack ########
        if gt_caps_list is not None:
            # len(gt_caps_list) = caps_k,  len(gt_caps_list[0]) = B
            B = len(imgs)
            caps_k = len(gt_caps_list)
            # Image-major order: all captions of image 0, then image 1, ...
            adv_txts = [
                gt_caps_list[j][i] 
                for i in range(B) for j in range(caps_k)
            ]
            adv_txt2img = [i for i in range(B) for _ in range(caps_k)]
        elif self.txt_sup_k == 1:
            adv_txts = txts
            adv_txt2img = txt2img
        elif self.txt_sup_k > 1:
            adv_txts_list = [
                eda(
                    sentence, 
                    alpha_sr=self.alpha_sr, 
                    alpha_ri=self.alpha_ri, 
                    alpha_rs=self.alpha_rs, 
                    p_rd=self.p_rd, 
                    num_aug=self.txt_sup_k - 1
                )
                for sentence in txts
            ]
            adv_txts = [txt for group in adv_txts_list for txt in group]
            # Derive the mapping from what eda actually returned so the two
            # lists can never get out of step.
            adv_txt2img = [
                i for i, group in enumerate(adv_txts_list) for _ in group
            ]
        else:
            raise ValueError(f"txt_sup_k must be >= 1, got {self.txt_sup_k}")

        ######## Step 2. text-guided attack ########
        txt_supervisions = None
        if self.img_attacker.alpha_sup > 0:
            with torch.no_grad():
                # Encode in chunks of the image batch size.  Stepping with
                # range() covers a final partial chunk, which a plain
                # ``len(adv_txts) // B`` loop would drop.
                chunk = max(int(imgs.shape[0]), 1)
                feats = []
                for start in range(0, len(adv_txts), chunk):
                    txts_input = self.tokenizer(
                        adv_txts[start:start + chunk],
                        padding="max_length",
                        truncation=True,
                        max_length=max_length,
                        return_tensors="pt",
                    ).to(device)
                    feats.append(self.model.inference_text(txts_input)["text_feat"])
                if feats:
                    txt_supervisions = torch.cat(feats, 0)

        adv_imgs = self.img_attacker.attack(
            self.model, imgs, adv_txt2img, device, 
            scales=scales, txt_embeds=txt_supervisions, loss_metric=img_attack_loss)
        
        ######## Step 3. select 1 text for each image ########
        if gt_caps_list is None:
            adv_txts = txts
        # else: keep all ground-truth captions (caps_k per image).
        
        if is_train:
            assert len(adv_imgs) == len(adv_txts)
        return adv_imgs, adv_txts

                

class ImageAttacker():
    """Iterative sign-gradient image attacker with an eps-ball projection.

    Starts from a uniform random point inside the eps-ball around the clean
    images, then repeatedly steps along the sign of the loss gradient and
    projects back into the ball and into the ``[0, 1]`` pixel range.
    """

    def __init__(self, normalization, eps=2/255, steps=10, step_size=0.5/255, 
                 alpha_unsup=0.0, alpha_sup=1.0):
        """
        Args:
            normalization: callable applied to images before the model, or
                ``None`` to feed images unchanged.
            eps: radius of the allowed per-pixel perturbation.
            steps: number of gradient steps.
            step_size: per-step perturbation magnitude.
            alpha_unsup: weight of the KL term against the clean embeddings
                (term is skipped when <= 0).
            alpha_sup: weight of the text-guided term (skipped when <= 0).
        """
        self.normalization = normalization
        self.eps = eps
        self.steps = steps 
        self.step_size = step_size 
        
        self.kl_loss = nn.KLDivLoss(reduction="batchmean")
        self.alpha_unsup = alpha_unsup
        self.alpha_sup = alpha_sup

        print("ImageAttacker config:")
        print(f"- eps: {self.eps}")
        print(f"- steps: {self.steps}")
        print(f"- step_size: {self.step_size}")
        print(f"- alpha_unsup: {self.alpha_unsup}")
        print(f"- alpha_sup: {self.alpha_sup}")

    def unsup_loss(self, adv_embeds, origin_embeds):
        """KL divergence between softmaxed adversarial and clean embeddings."""
        loss = self.kl_loss(
            adv_embeds.log_softmax(dim=-1), origin_embeds.softmax(dim=-1)
        )
        return loss

    def sup_loss(self, model, adv_imgs_embeds, txts_embeds, txt2img, loss_metric="sim"):  
        """
        return supervised loss (text-guided loss)

        loss_metric:
            sim: maximize the distance between correct pairs
            clip: maximize the clip training loss

        Returns:
            "sim": a scalar loss tensor.
            "clip": a tuple ``(loss, -(logits_per_image * it_labels))``;
            ``model.logit_scale`` is read for the logit temperature.
        """
        assert loss_metric in ["sim", "clip"], loss_metric
        
        device = adv_imgs_embeds.device    

        bi = adv_imgs_embeds.shape[0]
        bj = txts_embeds.shape[0]
        # it_labels[i, j] == 1 iff caption j belongs to image i.
        it_labels = torch.zeros(bi, bj, device=device)
        num_pairs = len(txt2img)
        if num_pairs > 0:
            if torch.is_tensor(txt2img):
                rows = txt2img.to(device=device, dtype=torch.long)
            else:
                rows = torch.tensor(
                    [int(t) for t in txt2img], dtype=torch.long, device=device
                )
            cols = torch.arange(num_pairs, device=device)
            it_labels[rows, cols] = 1
        
        if loss_metric == "clip":
            # normalized features
            image_features = adv_imgs_embeds / adv_imgs_embeds.norm(dim=-1, keepdim=True)
            text_features = txts_embeds / txts_embeds.norm(dim=-1, keepdim=True)

            # cosine similarity as logits
            logit_scale = model.logit_scale.exp()
            logits_per_image = logit_scale * image_features @ text_features.t()
            logits_per_text = logits_per_image.t()

            # The text-side logits are (bj, bi), so their targets must be the
            # transposed label matrix; otherwise shapes only agree when
            # bi == bj.
            loss = (
                F.cross_entropy(logits_per_image, it_labels)
                + F.cross_entropy(logits_per_text, it_labels.t())
            ) / 2
            return loss, -(logits_per_image * it_labels)

        # loss_metric == "sim"
        it_sim_matrix = adv_imgs_embeds @ txts_embeds.T
        loss_matrix = -(it_sim_matrix * it_labels)
        return loss_matrix.sum(-1).mean()
    
    def loss_func(self, model, origin_embeds, adv_imgs_embeds, txts_embeds, txt2img, loss_metric="sim"):
        """
        return unsupervised loss and supervised loss

        The result is ``alpha_unsup * unsup + alpha_sup * sup`` with each
        term skipped when its weight is <= 0 (returns the int ``0`` if both
        are skipped).
        """
        loss = 0
        if self.alpha_unsup > 0:
            unsup_loss = self.unsup_loss(adv_imgs_embeds, origin_embeds)
            loss += self.alpha_unsup * unsup_loss
        if self.alpha_sup > 0:
            sup_loss = self.sup_loss(model, adv_imgs_embeds, txts_embeds, txt2img, loss_metric=loss_metric)
            # The "clip" metric returns (loss, per-pair matrix); only the
            # scalar takes part in the objective.
            if isinstance(sup_loss, tuple):
                sup_loss = sup_loss[0]
            loss += self.alpha_sup * sup_loss
        return loss
    

    def attack(self, model, imgs, txt2img, device, scales=None, txt_embeds=None, loss_metric="sim"):
        """Return adversarial versions of ``imgs``.

        Args:
            model: object exposing ``inference_image(x)`` returning a mapping
                with ``'image_feat'``, and ``zero_grad()``.
            imgs: 4-D image batch with values expected in ``[0, 1]``.
            txt2img: caption-to-image index mapping for the supervised loss.
            device: device on which the random noise is placed.
            scales: extra image scales (see ``get_scaled_imgs``); ``None``
                uses only the original resolution.
            txt_embeds: text embeddings for the supervised loss.
            loss_metric: "sim" or "clip".

        Returns:
            Tensor shaped like ``imgs`` within ``eps`` of it and in ``[0, 1]``.
        """
        b = imgs.shape[0]

        # forward
        origin_embeds = None
        if self.alpha_unsup > 0:
            if self.normalization is not None:
                imgs_output = model.inference_image(self.normalization(imgs))
            else:
                imgs_output = model.inference_image(imgs)
            origin_embeds = imgs_output['image_feat'].detach() # detach!
        
        scales_num = 1 if scales is None else len(scales) + 1

        noise = torch.from_numpy(
            np.random.uniform(-self.eps, self.eps, imgs.shape)
        ).float().to(device)
        adv_imgs = torch.clamp(imgs.detach() + noise, 0.0, 1.0)

        for _ in range(self.steps):
            # The whole forward pass must record a graph, even when the
            # caller runs under torch.no_grad().
            with torch.enable_grad():
                adv_imgs.requires_grad_()
                scaled_imgs = self.get_scaled_imgs(adv_imgs, scales, device)        
        
                if self.normalization is not None:
                    adv_imgs_output = model.inference_image(self.normalization(scaled_imgs))
                else:
                    adv_imgs_output = model.inference_image(scaled_imgs)
                
                adv_imgs_embeds = adv_imgs_output['image_feat']
                model.zero_grad()

                # get_scaled_imgs stacks the scales along dim 0, b rows each.
                loss = torch.tensor(0.0, dtype=torch.float32).to(device)
                for k in range(scales_num):
                    loss = loss + self.loss_func(
                        model, origin_embeds, adv_imgs_embeds[k*b:k*b+b], txt_embeds, txt2img,
                        loss_metric=loss_metric
                    )
                loss.backward()
            
            # sign() is scale-invariant, so no gradient normalisation is
            # needed; dividing by mean(|grad|) would yield NaN for an
            # all-zero gradient.
            perturbation = self.step_size * adv_imgs.grad.sign()
            adv_imgs = adv_imgs.detach() + perturbation
            adv_imgs = torch.min(torch.max(adv_imgs, imgs - self.eps), imgs + self.eps)
            adv_imgs = torch.clamp(adv_imgs, 0.0, 1.0)
        
        return adv_imgs


    def get_scaled_imgs(self, imgs, scales=None, device=None):
        """Stack ``imgs`` with noisy, rescaled-and-restored copies.

        For every ratio in ``scales`` the images get Gaussian noise
        (std 0.05), are resized by that ratio, clamped to ``[0, 1]`` and
        resized back to the original size (bicubic both ways).

        Args:
            imgs: image batch whose last two dims are height and width.
            scales: iterable of ratios, or ``None`` to return ``imgs`` as is.
            device: device for the noise; defaults to ``imgs.device``.

        Returns:
            ``imgs`` itself when ``scales`` is ``None``; otherwise the
            concatenation along dim 0 of ``imgs`` and one copy per ratio.
        """
        if scales is None:
            return imgs
        if device is None:
            device = imgs.device

        ori_shape = (imgs.shape[-2], imgs.shape[-1])
        
        reverse_transform = transforms.Resize(ori_shape,
                                interpolation=transforms.InterpolationMode.BICUBIC)
        result = []
        for ratio in scales:
            # Never request a zero-sized resize for tiny inputs.
            scale_shape = (max(1, int(ratio*ori_shape[0])), 
                           max(1, int(ratio*ori_shape[1])))
            scale_transform = transforms.Resize(scale_shape,
                                  interpolation=transforms.InterpolationMode.BICUBIC)
            scaled_imgs = imgs + torch.from_numpy(np.random.normal(0.0, 0.05, imgs.shape)).float().to(device)
            scaled_imgs = scale_transform(scaled_imgs)
            scaled_imgs = torch.clamp(scaled_imgs, 0.0, 1.0)
            
            reversed_imgs = reverse_transform(scaled_imgs)
            
            result.append(reversed_imgs)
        
        return torch.cat([imgs,]+result, 0)



# Function words, contraction fragments and punctuation tokens. The list is
# turned into a set right below so membership tests are constant-time.
# NOTE(review): the list contains repeated '-' entries and a multi-word entry
# 'a the'; the set conversion removes the duplicates — confirm 'a the' is intended.
filter_words = ['a', 'about', 'above', 'across', 'after', 'afterwards', 'again', 'against', 'ain', 'all', 'almost',
                'alone', 'along', 'already', 'also', 'although', 'am', 'among', 'amongst', 'an', 'and', 'another',
                'any', 'anyhow', 'anyone', 'anything', 'anyway', 'anywhere', 'are', 'aren', "aren't", 'around', 'as',
                'at', 'back', 'been', 'before', 'beforehand', 'behind', 'being', 'below', 'beside', 'besides',
                'between', 'beyond', 'both', 'but', 'by', 'can', 'cannot', 'could', 'couldn', "couldn't", 'd', 'didn',
                "didn't", 'doesn', "doesn't", 'don', "don't", 'down', 'due', 'during', 'either', 'else', 'elsewhere',
                'empty', 'enough', 'even', 'ever', 'everyone', 'everything', 'everywhere', 'except', 'first', 'for',
                'former', 'formerly', 'from', 'hadn', "hadn't", 'hasn', "hasn't", 'haven', "haven't", 'he', 'hence',
                'her', 'here', 'hereafter', 'hereby', 'herein', 'hereupon', 'hers', 'herself', 'him', 'himself', 'his',
                'how', 'however', 'hundred', 'i', 'if', 'in', 'indeed', 'into', 'is', 'isn', "isn't", 'it', "it's",
                'its', 'itself', 'just', 'latter', 'latterly', 'least', 'll', 'may', 'me', 'meanwhile', 'mightn',
                "mightn't", 'mine', 'more', 'moreover', 'most', 'mostly', 'must', 'mustn', "mustn't", 'my', 'myself',
                'namely', 'needn', "needn't", 'neither', 'never', 'nevertheless', 'next', 'no', 'nobody', 'none',
                'noone', 'nor', 'not', 'nothing', 'now', 'nowhere', 'o', 'of', 'off', 'on', 'once', 'one', 'only',
                'onto', 'or', 'other', 'others', 'otherwise', 'our', 'ours', 'ourselves', 'out', 'over', 'per',
                'please', 's', 'same', 'shan', "shan't", 'she', "she's", "should've", 'shouldn', "shouldn't", 'somehow',
                'something', 'sometime', 'somewhere', 'such', 't', 'than', 'that', "that'll", 'the', 'their', 'theirs',
                'them', 'themselves', 'then', 'thence', 'there', 'thereafter', 'thereby', 'therefore', 'therein',
                'thereupon', 'these', 'they', 'this', 'those', 'through', 'throughout', 'thru', 'thus', 'to', 'too',
                'toward', 'towards', 'under', 'unless', 'until', 'up', 'upon', 'used', 've', 'was', 'wasn', "wasn't",
                'we', 'were', 'weren', "weren't", 'what', 'whatever', 'when', 'whence', 'whenever', 'where',
                'whereafter', 'whereas', 'whereby', 'wherein', 'whereupon', 'wherever', 'whether', 'which', 'while',
                'whither', 'who', 'whoever', 'whole', 'whom', 'whose', 'why', 'with', 'within', 'without', 'won',
                "won't", 'would', 'wouldn', "wouldn't", 'y', 'yet', 'you', "you'd", "you'll", "you're", "you've",
                'your', 'yours', 'yourself', 'yourselves', '.', '-', 'a the', '/', '?', 'some', '"', ',', 'b', '&', '!',
                '@', '%', '^', '*', '(', ')', "-", '-', '+', '=', '<', '>', '|', ':', ";", '～', '·']
filter_words = set(filter_words)


########################################################################
########################################################################
# EDA: Easy Data Augmentation Techniques for Boosting Performance on Text Classification Tasks
# https://github.com/jasonwei20/eda_nlp
import nltk
# NOTE(review): this runs on every import of the module; presumably it fetches
# the WordNet corpus when missing and may need network access — confirm that
# an import-time download is acceptable for all callers.
nltk.download('wordnet')
from nltk.corpus import wordnet 

# The EDA helpers below refer to the shared filter-word set as ``stop_words``.
stop_words = filter_words

#cleaning up text
import re
def get_only_chars(line):
    """Normalise ``line`` to lowercase ASCII letters separated by single spaces.

    Apostrophes (straight and curly) are removed; hyphens, tabs, newlines and
    every other character outside ``a-z``/space become a space.  Runs of
    spaces are collapsed and one leading space is dropped.  A trailing space
    is kept if the input ends in a non-letter.

    Args:
        line: input string (may be empty).

    Returns:
        The cleaned string; ``""`` when nothing but non-letters was given.
    """
    line = line.replace("’", "")
    line = line.replace("'", "")
    line = line.replace("-", " ") #replace hyphens with spaces
    line = line.replace("\t", " ")
    line = line.replace("\n", " ")
    line = line.lower()

    allowed = 'qwertyuiopasdfghjklzxcvbnm '
    clean_line = ''.join(char if char in allowed else ' ' for char in line)

    clean_line = re.sub(' +',' ',clean_line) #delete extra spaces
    # startswith() is safe on an empty string, unlike indexing [0].
    if clean_line.startswith(' '):
        clean_line = clean_line[1:]
    return clean_line

########################################################################
# Synonym replacement
# Replace n words in the sentence with synonyms from wordnet
########################################################################

def synonym_replacement(words, n):
    """Swap words of ``words`` for WordNet synonyms.

    Distinct non-stop-words are visited in random order; each one that has a
    synonym gets every occurrence replaced by one randomly chosen synonym.
    The count check happens after each visited word, so visiting stops once
    at least ``n`` words were replaced.

    Returns a new token list; multi-word synonyms are split into separate
    tokens.
    """
    result = words.copy()
    eligible = [token for token in words if token not in stop_words]
    candidates = list(set(eligible))
    random.shuffle(candidates)

    replaced_count = 0
    for target in candidates:
        options = get_synonyms(target)
        if len(options) >= 1:
            replacement = random.choice(list(options))
            result = [replacement if token == target else token for token in result]
            replaced_count += 1
        if replaced_count >= n:  # stop once enough words were replaced
            break

    # Re-tokenise so a synonym made of several words becomes several tokens.
    return ' '.join(result).split(' ')

def get_synonyms(word):
    """Collect lemma names from all WordNet synsets of ``word``.

    Each lemma name has underscores and hyphens turned into spaces, is
    lowercased and stripped of every character other than ``a-z`` and space.
    ``word`` itself is excluded from the result.

    Returns a list of unique strings (order unspecified).
    """
    allowed = ' qwertyuiopasdfghjklzxcvbnm'
    found = set()
    for synset in wordnet.synsets(word):
        for lemma in synset.lemmas():
            text = lemma.name().replace("_", " ").replace("-", " ").lower()
            found.add("".join(ch for ch in text if ch in allowed))
    found.discard(word)
    return list(found)


########################################################################
# Random deletion
# Randomly delete words from the sentence with probability p
########################################################################

def random_deletion(words, p):
    """Drop each word of ``words`` independently with probability ``p``.

    Args:
        words: list of tokens.
        p: deletion probability per token.

    Returns:
        ``words`` itself when it has at most one element; otherwise a new
        list of the surviving tokens, or a single randomly chosen token if
        every token was deleted.
    """
    # Nothing sensible to delete from an empty or single-word sentence
    # (an empty list would also make randint(0, -1) below raise).
    if len(words) <= 1:
        return words

    # One uniform draw per word; the word survives when the draw exceeds p.
    new_words = [word for word in words if random.uniform(0, 1) > p]

    # If everything was deleted, fall back to one random word.
    if len(new_words) == 0:
        rand_int = random.randint(0, len(words)-1)
        return [words[rand_int]]

    return new_words

########################################################################
# Random swap
# Randomly swap two words in the sentence n times
########################################################################

def random_swap(words, n):
    """Return a copy of ``words`` after ``n`` calls to ``swap_word``."""
    shuffled = words.copy()
    for _step in range(n):
        shuffled = swap_word(shuffled)
    return shuffled

def swap_word(new_words):
    """Swap two randomly chosen, distinct positions of ``new_words`` in place.

    The second position is re-drawn while it equals the first; after four
    failed draws the list is returned unchanged.  Lists with fewer than two
    elements are returned as is.

    Returns:
        The same list object that was passed in.
    """
    # No pair to swap (and randint(0, -1) would raise for an empty list).
    if len(new_words) < 2:
        return new_words

    random_idx_1 = random.randint(0, len(new_words)-1)
    random_idx_2 = random_idx_1
    counter = 0
    while random_idx_2 == random_idx_1:
        random_idx_2 = random.randint(0, len(new_words)-1)
        counter += 1
        if counter > 3:
            return new_words
    new_words[random_idx_1], new_words[random_idx_2] = new_words[random_idx_2], new_words[random_idx_1] 
    return new_words

########################################################################
# Random insertion
# Randomly insert n words into the sentence
########################################################################

def random_insertion(words, n):
    """Return a copy of ``words`` after ``n`` in-place ``add_word`` calls."""
    grown = words.copy()
    for _step in range(n):
        add_word(grown)
    return grown

def add_word(new_words):
    """Insert a synonym of a random word of ``new_words`` at a random position.

    Up to ten random words are tried until one has a synonym; if none does,
    or the list is empty, the list is left untouched.  The first synonym
    returned by ``get_synonyms`` is used.  Mutates ``new_words`` in place
    and returns ``None``.
    """
    # An empty list has no word to draw from (randint(0, -1) would raise).
    if len(new_words) == 0:
        return

    synonyms = []
    counter = 0
    while len(synonyms) < 1:
        random_word = new_words[random.randint(0, len(new_words)-1)]
        synonyms = get_synonyms(random_word)
        counter += 1
        if counter >= 10:
            return
    random_synonym = synonyms[0]
    random_idx = random.randint(0, len(new_words)-1)
    new_words.insert(random_idx, random_synonym)


########################################################################
# main data augmentation function
########################################################################

def eda(sentence, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=9):
    """Generate augmented variants of ``sentence`` (Easy Data Augmentation).

    The sentence is cleaned with ``get_only_chars`` and each enabled
    technique (a technique is enabled when its rate is > 0) contributes
    ``int(num_aug / 4) + 1`` variants.  The variants are shuffled and then
    trimmed.

    Args:
        sentence: raw input sentence.
        alpha_sr: fraction of words to replace with synonyms.
        alpha_ri: fraction of words that sets the number of insertions.
        alpha_rs: fraction of words that sets the number of swaps.
        p_rd: per-word deletion probability.
        num_aug: when >= 1, the maximum number of variants kept; when < 1,
            it is divided by the number of generated variants to give a
            per-variant keep probability.

    Returns:
        List of augmented sentences followed by the cleaned original as the
        last element.  If the cleaned sentence has no words, the list
        contains only the cleaned original.
    """
    sentence = get_only_chars(sentence)
    # Truthiness test instead of an identity comparison against '' .
    words = [word for word in sentence.split(' ') if word]
    num_words = len(words)

    # Nothing to augment; the helpers below require at least one word.
    if num_words == 0:
        return [sentence]

    augmented_sentences = []
    num_new_per_technique = int(num_aug/4)+1

    #sr
    if (alpha_sr > 0):
        n_sr = max(1, int(alpha_sr*num_words))
        for _ in range(num_new_per_technique):
            a_words = synonym_replacement(words, n_sr)
            augmented_sentences.append(' '.join(a_words))

    #ri
    if (alpha_ri > 0):
        n_ri = max(1, int(alpha_ri*num_words))
        for _ in range(num_new_per_technique):
            a_words = random_insertion(words, n_ri)
            augmented_sentences.append(' '.join(a_words))

    #rs
    if (alpha_rs > 0):
        n_rs = max(1, int(alpha_rs*num_words))
        for _ in range(num_new_per_technique):
            a_words = random_swap(words, n_rs)
            augmented_sentences.append(' '.join(a_words))

    #rd
    if (p_rd > 0):
        for _ in range(num_new_per_technique):
            a_words = random_deletion(words, p_rd)
            augmented_sentences.append(' '.join(a_words))

    augmented_sentences = [get_only_chars(variant) for variant in augmented_sentences]
    shuffle(augmented_sentences)

    #trim so that we have the desired number of augmented sentences
    if num_aug >= 1:
        augmented_sentences = augmented_sentences[:num_aug]
    elif augmented_sentences:
        # Guarded: with every technique disabled the list is empty and the
        # division below would raise ZeroDivisionError.
        keep_prob = num_aug / len(augmented_sentences)
        augmented_sentences = [s for s in augmented_sentences if random.uniform(0, 1) < keep_prob]

    #append the original sentence
    augmented_sentences.append(sentence)

    return augmented_sentences


if __name__=="__main__":
    # Small demo: print augmented variants of a sample sentence.
    demo_text = "The quick brown fox jumps over the lazy dog"
    strength = 0.3
    variants = eda(
        demo_text,
        alpha_sr=strength,
        alpha_ri=strength,
        alpha_rs=strength,
        p_rd=strength,
        num_aug=5,
    )
    for variant in variants:
        print(variant)