import torch
import random
import numpy as np
from functools import partial
import torch.nn.functional as nnf
from torchvision import transforms as T

# Many of the approaches here are inspired by the wonderful paper by O'Connor and Andreas (2021).
# https://github.com/lingo-mit/context-ablations

def get_text_perturb_fn(text_perturb_fn):
    """Look up a text perturbation function by its configuration name.

    Args:
        text_perturb_fn: Name of the perturbation, or ``None`` when no text
            perturbation is wanted.

    Returns:
        The function registered under that name, or ``None`` when the
        argument is ``None`` or is not a recognised name (a message is
        printed in the latter case).
    """
    # No perturbation requested.
    if text_perturb_fn is None:
        return None

    # Guard clauses: each recognised name returns its function right away.
    if text_perturb_fn == "shuffle_nouns_and_adj":
        return shuffle_nouns_and_adj
    if text_perturb_fn == "shuffle_allbut_nouns_and_adj":
        return shuffle_allbut_nouns_and_adj
    if text_perturb_fn == "shuffle_within_trigrams":
        return shuffle_within_trigrams
    if text_perturb_fn == "shuffle_all_words":
        return shuffle_all_words
    if text_perturb_fn == "shuffle_trigrams":
        return shuffle_trigrams

    # Anything else is reported and treated like "no perturbation".
    message = "Unknown text perturbation function: {}, returning None"
    print(message.format(text_perturb_fn))
    return None
    
    
def get_image_perturb_fn(image_perturb_fn):
    """Look up an image perturbation function by its configuration name.

    Args:
        image_perturb_fn: Name of the perturbation, or ``None`` when no image
            perturbation is wanted.

    Returns:
        A ``functools.partial`` with the grid size already bound, or ``None``
        when the argument is ``None`` or is not a recognised name (a message
        is printed in the latter case).
    """
    # No perturbation requested.
    if image_perturb_fn is None:
        return None

    # The numeric suffix of each name is the grid size bound into the partial
    # (patches_9 is a 3 x 3 grid, hence n_ratio=3).
    if image_perturb_fn == "shuffle_rows_4":
        return partial(shuffle_rows, n_rows=4)
    if image_perturb_fn == "shuffle_patches_9":
        return partial(shuffle_patches, n_ratio=3)
    if image_perturb_fn == "shuffle_cols_4":
        return partial(shuffle_columns, n_cols=4)

    # Anything else is reported and treated like "no perturbation".
    message = "Unknown image perturbation function: {}, returning None"
    print(message.format(image_perturb_fn))
    return None
    
def shuffle_nouns_and_adj(ex):
    """Shuffle the nouns among themselves and the adjectives among themselves.

    The text is tokenized and tagged with spaCy's ``en_core_web_sm``
    pipeline. Tokens tagged NN/NNS/NNP/NNPS are permuted across the noun
    positions, tokens tagged JJ/JJR/JJS are permuted across the adjective
    positions, and every other token stays where it is.

    Args:
        ex: Text to perturb.

    Returns:
        The tokens of the perturbed text joined with single spaces.
    """
    # Loading a spaCy pipeline is slow, so it is loaded on first use and kept
    # on the function object (attribute ``_nlp``) instead of being reloaded
    # for every example.
    nlp = getattr(shuffle_nouns_and_adj, "_nlp", None)
    if nlp is None:
        import spacy
        nlp = spacy.load("en_core_web_sm")
        shuffle_nouns_and_adj._nlp = nlp

    doc = nlp(ex)
    text = np.array([token.text for token in doc])

    noun_tags = ('NN', 'NNS', 'NNP', 'NNPS')
    adjective_tags = ('JJ', 'JJR', 'JJS')
    noun_idx = [i for i, token in enumerate(doc) if token.tag_ in noun_tags]
    adjective_idx = [i for i, token in enumerate(doc) if token.tag_ in adjective_tags]

    # Permute within each group only, so nouns land on noun positions and
    # adjectives on adjective positions.
    text[noun_idx] = np.random.permutation(text[noun_idx])
    text[adjective_idx] = np.random.permutation(text[adjective_idx])

    return " ".join(text)

    
def shuffle_all_words(ex):
    """Return ``ex`` with its space-separated words in a random order.

    Words are obtained by splitting on a single space character, so runs of
    spaces produce empty-string "words" that take part in the shuffle too.
    """
    words = ex.split(" ")
    shuffled = np.random.permutation(words)
    return " ".join(shuffled)


def shuffle_allbut_nouns_and_adj(ex):
    """Shuffle every token that is neither a noun nor an adjective.

    The text is tokenized and tagged with spaCy's ``en_core_web_sm``
    pipeline. Tokens tagged NN/NNS/NNP/NNPS/JJ/JJR/JJS keep their positions;
    all remaining tokens are permuted among the remaining positions.

    Args:
        ex: Text to perturb.

    Returns:
        The tokens of the perturbed text joined with single spaces.
    """
    # Loading a spaCy pipeline is slow, so it is loaded on first use and kept
    # on the function object (attribute ``_nlp``) instead of being reloaded
    # for every example.
    nlp = getattr(shuffle_allbut_nouns_and_adj, "_nlp", None)
    if nlp is None:
        import spacy
        nlp = spacy.load("en_core_web_sm")
        shuffle_allbut_nouns_and_adj._nlp = nlp

    doc = nlp(ex)
    text = np.array([token.text for token in doc])

    fixed_tags = ('NN', 'NNS', 'NNP', 'NNPS', 'JJ', 'JJR', 'JJS')
    noun_adj_idx = [i for i, token in enumerate(doc) if token.tag_ in fixed_tags]

    # Boolean mask that is True for every token that is NOT a noun/adjective.
    else_idx = np.ones(len(text), dtype=bool)
    else_idx[noun_adj_idx] = False

    # Shuffle only the tokens that are not nouns or adjectives.
    text[else_idx] = np.random.permutation(text[else_idx])
    return " ".join(text)
    

def get_trigrams(sentence):
    """Split a token sequence into consecutive, non-overlapping groups of three.

    Args:
        sentence: Indexable sequence of tokens.

    Returns:
        A list of new lists, each holding up to three consecutive tokens; the
        last group is shorter when the length is not a multiple of three. An
        empty input yields an empty list.
    """
    # Taken from https://github.com/lingo-mit/context-ablations/blob/478fb18a9f9680321f0d37dc999ea444e9287cc0/code/transformers/src/transformers/data/data_augmentation.py
    n_tokens = len(sentence)
    # list(...) guarantees each group is a fresh, mutable list regardless of
    # the input sequence type.
    return [list(sentence[start:start + 3]) for start in range(0, n_tokens, 3)]

def trigram_shuffle(sentence):
    """Shuffle tokens inside each consecutive trigram, keeping trigram order.

    Args:
        sentence: Sequence of string tokens.

    Returns:
        All tokens joined with single spaces, where each group of three (as
        produced by ``get_trigrams``) has been shuffled in place.
    """
    groups = get_trigrams(sentence)
    for group in groups:
        random.shuffle(group)
    # Groups are never empty, so a flat join gives the same string as joining
    # each group and then joining the groups.
    return " ".join(token for group in groups for token in group)


def shuffle_within_trigrams(ex):
    """Tokenize ``ex`` with NLTK and shuffle the words inside each trigram.

    Args:
        ex: Text to perturb.

    Returns:
        The perturbed tokens joined with single spaces.
    """
    # Imported lazily so NLTK is only needed when this perturbation is used.
    from nltk import word_tokenize
    return trigram_shuffle(word_tokenize(ex))


def shuffle_trigrams(ex):
    """Tokenize ``ex`` with NLTK and shuffle the order of its trigrams.

    The words inside each trigram keep their original order; only the
    trigrams themselves are reordered.

    Args:
        ex: Text to perturb.

    Returns:
        The perturbed tokens joined with single spaces.
    """
    # Imported lazily so NLTK is only needed when this perturbation is used.
    from nltk import word_tokenize
    chunks = get_trigrams(word_tokenize(ex))
    random.shuffle(chunks)
    # Chunks are never empty, so a flat join matches a nested join.
    return " ".join(word for chunk in chunks for word in chunk)


def _handle_image_4shuffle(x):
    return_image = False
    if not isinstance(x, torch.Tensor):
        # print(f"x is not a tensor: {type(x)}. Trying to handle but fix this or I'll annoy you with this log")
        t = torch.tensor(np.array(x)).unsqueeze(dim=0).float()
        t = t.permute(0, 3, 1, 2)
        return_image = True
        return t, return_image
    if len(x.shape) != 4:
        #print("You did not send a tensor of shape NxCxWxH. Unsqueezing not but fix this or I'll annoy you with this log")
        return x.unsqueeze(dim=0), return_image
    else:
        # Good boi
        return x, return_image
        

def shuffle_rows(x, n_rows=7):
    """Randomly reorder ``n_rows`` full-width bands of an image.

    The input is cut along its second-to-last dimension into bands of
    ``x.shape[-2] // n_rows`` pixels, each spanning the whole last dimension,
    and the bands of every image in the batch are permuted independently.

    Args:
        x: A tensor (4-D, or 3-D which gets a batch dimension added) or a
            non-tensor image that ``_handle_image_4shuffle`` can convert.
        n_rows: Number of bands to cut the image into.

    Returns:
        The shuffled tensor with all singleton dimensions squeezed out, or a
        PIL image when the input was not a tensor.

    NOTE(review): when the size is not divisible by ``n_rows`` the trailing
    pixels are not covered by any band and appear to come back as zeros from
    ``fold`` -- confirm.
    """
    batch, came_from_image = _handle_image_4shuffle(x)
    band = batch.shape[-2] // n_rows
    window = (band, batch.shape[-1])

    # One column of `bands` per band, for every image in the batch.
    bands = nnf.unfold(batch, kernel_size=window, stride=band, padding=0)
    # Draw an independent permutation of the bands for each image.
    shuffled = torch.stack(
        [sample[:, torch.randperm(sample.shape[-1])] for sample in bands], dim=0
    )
    # Reassemble the permuted bands into the original spatial size.
    restored = nnf.fold(shuffled, batch.shape[-2:], kernel_size=window, stride=band, padding=0)

    result = restored.squeeze()
    if not came_from_image:
        return result
    return T.ToPILImage()(result.type(torch.uint8))


def shuffle_columns(x, n_cols=7):
    """Randomly reorder ``n_cols`` full-height strips of an image.

    The input is cut along its last dimension into strips of
    ``x.shape[-1] // n_cols`` pixels, each spanning the whole second-to-last
    dimension, and the strips of every image in the batch are permuted
    independently.

    Args:
        x: A tensor (4-D, or 3-D which gets a batch dimension added) or a
            non-tensor image that ``_handle_image_4shuffle`` can convert.
        n_cols: Number of strips to cut the image into.

    Returns:
        The shuffled tensor with all singleton dimensions squeezed out, or a
        PIL image when the input was not a tensor.

    NOTE(review): when the size is not divisible by ``n_cols`` the trailing
    pixels are not covered by any strip and appear to come back as zeros from
    ``fold`` -- confirm.
    """
    batch, came_from_image = _handle_image_4shuffle(x)
    strip = batch.shape[-1] // n_cols
    window = (batch.shape[-2], strip)

    # One column of `strips` per strip, for every image in the batch.
    strips = nnf.unfold(batch, kernel_size=window, stride=strip, padding=0)
    # Draw an independent permutation of the strips for each image.
    shuffled = torch.stack(
        [sample[:, torch.randperm(sample.shape[-1])] for sample in strips], dim=0
    )
    # Reassemble the permuted strips into the original spatial size.
    restored = nnf.fold(shuffled, batch.shape[-2:], kernel_size=window, stride=strip, padding=0)

    result = restored.squeeze()
    if not came_from_image:
        return result
    return T.ToPILImage()(result.type(torch.uint8))



def shuffle_patches(x, n_ratio=4):
    """Randomly reorder the cells of an ``n_ratio`` x ``n_ratio`` image grid.

    Each of the last two dimensions is divided by ``n_ratio`` to get the
    patch size, the input is cut into non-overlapping patches of that size,
    and the patches of every image in the batch are permuted independently.

    Args:
        x: A tensor (4-D, or 3-D which gets a batch dimension added) or a
            non-tensor image that ``_handle_image_4shuffle`` can convert.
        n_ratio: Number of patches along each of the two spatial dimensions.

    Returns:
        The shuffled tensor with all singleton dimensions squeezed out, or a
        PIL image when the input was not a tensor.

    NOTE(review): when a dimension is not divisible by ``n_ratio`` the
    trailing pixels are not covered by any patch and appear to come back as
    zeros from ``fold`` -- confirm.
    """
    batch, came_from_image = _handle_image_4shuffle(x)
    tile = (batch.shape[-2] // n_ratio, batch.shape[-1] // n_ratio)

    # One column of `tiles` per patch, for every image in the batch.
    tiles = nnf.unfold(batch, kernel_size=tile, stride=tile, padding=0)
    # Draw an independent permutation of the patches for each image.
    shuffled = torch.stack(
        [sample[:, torch.randperm(sample.shape[-1])] for sample in tiles], dim=0
    )
    # Reassemble the permuted patches into the original spatial size.
    restored = nnf.fold(shuffled, batch.shape[-2:], kernel_size=tile, stride=tile, padding=0)

    result = restored.squeeze()
    if not came_from_image:
        return result
    return T.ToPILImage()(result.type(torch.uint8))