import torch
import random
import matplotlib
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import numpy as np
from copy import deepcopy

from adjustText import adjust_text

from dataset import tokenize_line

# Configure matplotlib's Agg path chunking (rcParam "agg.path.chunksize").
matplotlib.rcParams.update({'agg.path.chunksize': 10000})

def plot_tsne(source_embs, sources,
              single_embs, single_words,
              neighbor_embs, neighbor_words,
              target_embs, targets,
              origin, full_path,
              num_labels=30):
    """Draw the projected embeddings and save the figure as a PNG file.

    Sources and targets are drawn as dashed-arrow paths. Each path is a run of
    points that starts at ``origin``, and the matching run of labels starts
    with "" (see ``split_by_origin`` below).

    :param source_embs: 2-D points of the source phrases
    :param sources: phrase strings parallel to ``source_embs``
    :param single_embs: 2-D points of the individual words
    :param single_words: words parallel to ``single_embs``
    :param neighbor_embs: 2-D points of the neighboring words
    :param neighbor_words: words parallel to ``neighbor_embs``
    :param target_embs: 2-D points of the target phrases
    :param targets: phrase strings parallel to ``target_embs``
    :param origin: 2-D point of the all-zero embedding
    :param full_path: path of the PNG file to write
    :param num_labels: approximate number of neighbor labels to plot; each
        neighbor is labelled with probability num_labels / len(neighbor_words)
    """
    fig = plt.figure(figsize=(16, 9))
    try:
        scale = 4.0
        # all neighbors may have been pruned; avoid dividing by zero
        neighbor_sparsity = num_labels / len(neighbor_words) if len(neighbor_words) > 0 else 0.0
        texts = []
        labels = set()  # labels already drawn, so the same label is never plotted twice

        def point(xy, label, point_color, size, text_color=None, sparsity=1.0):
            """Scatter ``xy``; with probability ``sparsity`` also annotate it with ``label``."""
            plt.scatter(xy[0], xy[1], c=point_color)
            if random.random() < sparsity and label and label not in labels:
                labels.add(label)
                words = label.split(' ')
                myscale = scale
                if len(words) > 5:
                    # shorten long phrases to "first ... second-to-last last"
                    label = words[0] + " ... " + words[-2] + " " + words[-1]
                    myscale *= 0.7
                if text_color is not None:
                    annotation = plt.annotate(
                        label, alpha=1, xy=(xy[0], xy[1]), xytext=(5, 2), c=text_color,
                        textcoords='offset points', ha='right', va='bottom', size=size * myscale)
                else:
                    annotation = plt.annotate(
                        label, alpha=0.3, xy=(xy[0], xy[1]), xytext=(5, 2),
                        textcoords='offset points', ha='right', va='bottom', size=size * myscale)
                texts.append(annotation)

        def arrow(src, dst, color):
            """Draw a dashed arrow from ``src`` to ``dst``."""
            plt.arrow(x=src[0], y=src[1],
                      dx=dst[0] - src[0], dy=dst[1] - src[1],
                      linestyle="--", color=color,
                      length_includes_head=True)

        # neighbors (skip words that are drawn elsewhere anyway)
        for emb, word in zip(neighbor_embs, neighbor_words):
            if (word not in sources) and (word not in single_words) and (word not in targets):
                point(emb, word, "gray", 8, sparsity=neighbor_sparsity)

        # origin
        point(origin, "ORIGIN", "blue", 8, "blue")

        def split_by_origin(embs, origin):
            # in: [origin, emb11, ... emb1N, origin, emb21, ... emb2M]
            # out:[[origin, emb11,... emb1N],[origin, emb21, ..., emb2M]]
            acc = []
            current = []
            for emb in embs:
                if np.all(origin == emb):  # does not care if the result is array or just True
                    if len(current) > 0:  # nothing to flush on the first iteration
                        acc.append(current)
                    current = []
                current.append(emb)
            acc.append(current)
            return acc

        def draw_paths(embs, words, color, size, labelcolor=None):
            """Draw each origin-rooted run as a chain of arrows; label only a few of its points."""
            if labelcolor is None:
                labelcolor = color
            split_embs = split_by_origin(embs, origin)
            split_words = split_by_origin(words, "")
            for _embs, _words in zip(split_embs, split_words):
                labelled = (1, 2, len(_embs) - 1)
                for i, (emb, word) in enumerate(zip(_embs, _words)):
                    # only the first two steps and the final phrase get a label
                    point(emb, word if i in labelled else None, color, size, text_color=labelcolor)
                    if i != 0:
                        arrow(_embs[i - 1], emb, color)

        # targets
        draw_paths(target_embs, targets, "green", 8)
        # phrases
        draw_paths(source_embs, sources, "red", 8, "brown")

        # single words
        for emb, word in zip(single_embs, single_words):
            if (word not in sources) and (word not in targets):
                point(emb, word, "orange", 8, "orange")

        plt.grid(False)
        plt.axis('off')
        adjust_text(texts)
        plt.savefig(full_path, format='png', dpi=150, bbox_inches='tight')
    finally:
        # callers draw several figures in a row; release each one even if plotting failed
        plt.close(fig)


def tsne_transform(transform_model, *embs_list):
    """Project several groups of embeddings jointly into 2-D.

    All groups are stacked and de-duplicated before fitting, so identical
    vectors (e.g. repeated all-zero embeddings) land on exactly the same 2-D
    point. The projected rows are then scattered back and split into groups of
    the original sizes.

    :param transform_model: "tsne" selects t-SNE; any other value selects PCA
    :param embs_list: List[List[np array]]. each embedding represent a phrase;
        all vectors must share one length
    :return: list of arrays of shape (len(group), 2), one per input group
    """
    groups = [np.asarray(embs) for embs in embs_list]
    # an empty group has shape (0,) and cannot be concatenated with 2-D groups
    all_embs = np.concatenate([g for g in groups if len(g) > 0], axis=0)
    # remove duplicates for the transform, and add them back after transform
    unique_embs, indices = np.unique(all_embs, axis=0, return_inverse=True)
    # NumPy 2.0.0 returned the inverse with an extra axis when axis= is given
    indices = np.asarray(indices).reshape(-1)

    if transform_model == "tsne":
        # t-SNE rejects perplexity >= number of samples
        perplexity = 15
        if len(unique_embs) > 1:
            perplexity = min(perplexity, len(unique_embs) - 1)
        try:
            model = TSNE(perplexity=perplexity, n_components=2, init='pca',
                         max_iter=3500, random_state=42)
        except TypeError:
            # scikit-learn < 1.5 only knows the old parameter name
            model = TSNE(perplexity=perplexity, n_components=2, init='pca',
                         n_iter=3500, random_state=42)
    else:  # PCA
        model = PCA(n_components=2)

    projected = model.fit_transform(unique_embs)[indices]
    result_embs_list = []
    start = 0
    for embs in embs_list:
        result_embs_list.append(projected[start:start + len(embs)])
        start += len(embs)
    return result_embs_list


def get_neighbor_embs(embeddings, model, emb_dataset, k):
    """Collect the nearest vocabulary words of every given embedding.

    :param embeddings: List[np array]. each embedding represent a phrase
    :param model: word embedding model; ``model.predict(max_k=k, pred_emb=...)``
        supplies neighbor token ids and ``model.embed_sentence`` embeds them
    :param emb_dataset: word embedding dataset, used to map token ids to words
    :param k: number of neighbors to find for each phrase
    :return: (neighbor_embs, neighbor_words): one embedding (np array) and one
        word per unique neighbor token, in ascending token-id order
    """
    device = model.embeddings[0].weight.device
    predicted = [
        model.predict(max_k=k, pred_emb=torch.tensor(emb).unsqueeze(0).to(device)).cpu().numpy().ravel()
        for emb in embeddings
    ]
    # gather once instead of np.append in a loop (which copies every time)
    neighbor_tokens = np.unique(np.concatenate(predicted)) if predicted else np.array([])
    neighbor_tokens = neighbor_tokens.astype(int)
    neighbor_words = [emb_dataset.idx_to_word(t) for t in neighbor_tokens]
    # iterate the 1-D token array directly: .squeeze() would make a single token
    # a 0-d array, which cannot be iterated. detach() matches the other
    # embed_sentence calls in this file before converting to numpy.
    neighbor_embs = [
        model.embed_sentence(torch.tensor([t]).long().to(device)).cpu().detach().numpy()
        for t in neighbor_tokens
    ]
    return neighbor_embs, neighbor_words


def prune_neighbors(neighbor_embs, neighbor_words, *embs_list, num_std=3.0):
    """Drop neighbor points that lie far away from the points of interest.

    A neighbor is kept only if every coordinate lies strictly within
    ``num_std`` standard deviations of the mean of the de-duplicated reference
    points (all of ``embs_list`` combined).

    :param neighbor_embs: List[np array] neighbor embeddings to prune
    :param neighbor_words: List[str] neighbor words to prune, parallel to neighbor_embs
    :param embs_list: List[List[np array]] rest of embeddings (the reference cloud)
    :param num_std: half-width of the kept window in standard deviations (default 3.0)
    :return: (pruned_neighbor_embs, pruned_neighbor_words)
    """
    groups = [np.asarray(embs) for embs in embs_list]
    # an empty group has shape (0,) and cannot be concatenated with 2-D groups
    reference = np.unique(np.concatenate([g for g in groups if g.size > 0], axis=0), axis=0)
    mean = np.mean(reference, axis=0)
    spread = num_std * np.std(reference, axis=0)
    lower, upper = mean - spread, mean + spread

    pruned_neighbor_embs = []
    pruned_neighbor_words = []
    for emb, word in zip(neighbor_embs, neighbor_words):
        coords = np.asarray(emb)
        if np.all((lower < coords) & (coords < upper)):
            pruned_neighbor_embs.append(emb)
            pruned_neighbor_words.append(word)

    print(f"{len(neighbor_embs)-len(pruned_neighbor_embs)} neighbors pruned")
    return pruned_neighbor_embs, pruned_neighbor_words


def _zero_embedding(model):
    """Return the all-zero vector used for the empty phrase and for the origin.

    Its length is model.hyper["embedding"], doubled when model.hyper["model"]
    contains "hybrid".
    """
    dim = model.hyper["embedding"]
    if "hybrid" in model.hyper["model"].lower():
        dim *= 2
    return np.zeros(dim, dtype=np.float32)


def _embed_phrases(phrases, model, emb_dataset, device):
    """Embed each word list in ``phrases``; empty lists map to the zero embedding.

    :return: (embeddings, phrase strings joined by spaces, set of all words seen)
    """
    embs, texts, words_seen = [], [], set()
    for words in phrases:
        if len(words) == 0:
            emb = _zero_embedding(model)
        else:
            tokens = [emb_dataset.word_to_idx(word) for word in words]
            emb = model.embed_sentence(
                torch.tensor(tokens, dtype=torch.long).to(device)).cpu().detach().numpy()
            words_seen.update(words)
        embs.append(emb)
        texts.append(" ".join(words))
    return embs, texts, words_seen


def _print_plot_traceback(err):
    """Print the traceback of ``err`` via ``stacktrace`` when importable, else via stdlib."""
    try:
        import stacktrace
    except ImportError:
        import traceback
        traceback.print_exception(type(err), err, err.__traceback__)
    else:
        stacktrace.format(False)


def plot_step_phrase(sources, targets, model, emb_dataset, k, transform_model):
    """Save one set of plots per (source, target) figure description.

    :param sources: List[List[List[str]]], per figure the step phrases to plot,
        e.g. [[], [apple], [red, apple], ..., [], [apple], [green, apple], ...]
    :param targets: List[List[List[str]]], per figure the target phrases,
        e.g. [[], [target1], [], [target2], ...]
    :param model: word embedding model
    :param emb_dataset: word embedding dataset (containing word to index)
    :param k: int, number of neighbors to add to each of the step phrase in plot
    :param transform_model: "tsne" for t-SNE, anything else for PCA
    credit of many lines of code here goes to:
    https://towardsdatascience.com/google-news-and-leo-tolstoy-visualizing-word2vec-
    word-embeddings-with-t-sne-11558d8bd4d
    """
    model = model.cpu()
    device = model.embeddings[0].weight.device
    for source, target in zip(sources, targets):
        # per figure: highlight all phrase embeddings, the embedding of every token
        # appearing in them, the target embeddings and the neighboring embeddings
        source_embs, source_phrases, source_words = _embed_phrases(source, model, emb_dataset, device)
        target_embs, target_phrases, target_words = _embed_phrases(target, model, emb_dataset, device)

        # sorted so the plotting order does not depend on set iteration order
        single_words = sorted(source_words | target_words)
        single_embs = [
            model.embed_sentence(
                torch.tensor([emb_dataset.word_to_idx(word)]).to(device)).cpu().detach().numpy()
            for word in single_words
        ]

        # Optional, for each word, add neighboring words into the batch
        neighbor_embs, neighbor_words = get_neighbor_embs(
            source_embs + single_embs + target_embs, model, emb_dataset, k)

        # identical to the empty-phrase embedding, so tsne_transform (which
        # de-duplicates rows) projects both onto the same 2-D point
        origin_embs = [_zero_embedding(model)]
        # build t-SNE model and convert to 2d
        source_embs, single_embs, neighbor_embs, target_embs, origin_embs = \
            tsne_transform(transform_model, source_embs, single_embs, neighbor_embs, target_embs, origin_embs)

        # remove far away neighbors
        neighbor_embs, neighbor_words = prune_neighbors(
            neighbor_embs, neighbor_words, source_embs, single_embs, target_embs, origin_embs)

        phrase_concat = " ".join([s for s in source_phrases + target_phrases if s != ""])
        # CCC filesystem requirements. Smaller than the actual limits
        phrase_concat = phrase_concat[:100]

        for num_labels in [30, 40, 20, 50]:
            title = "-".join(["craw", transform_model, phrase_concat, str(num_labels)])
            title = title.replace(" ", "_")
            path = model.local(title + ".png")
            try:
                plot_tsne(source_embs,   source_phrases,
                          single_embs,   single_words,
                          neighbor_embs, neighbor_words,
                          target_embs,   target_phrases,
                          origin_embs[0],
                          path,
                          num_labels=num_labels)
            except Exception as e:
                # best effort: report and continue with the next figure
                print("error while plotting", path, ":", e)
                _print_plot_traceback(e)


def plot_phrases(phrase_target_pairs, model, emb_dataset, k, from_end=True, transform_model="tsne"):
    """Plot how phrase embeddings move as each phrase is built up word by word.

    :param phrase_target_pairs: one entry per figure; each entry is a flat
       sequence (source, target, source, target, ...). Every source and target
       is a list of strings, each string is a word
    :param model: word embedding model
    :param emb_dataset: word embedding dataset (containing word to index)
    :param k: int, number of neighbors to add to each of the step phrase in plot
    :param from_end: if True, each source is consumed from its last word backwards
    :param transform_model: "tsne" for t-SNE, anything else for PCA
    :raises ValueError: if an entry does not consist of complete (source, target) pairs
    """
    def build_steps(words):
        # [], [w1], [w2, w1], ...: every step prepends the next word
        steps = [[]]
        for word in words:
            steps.append([word] + steps[-1])
        return steps

    # build sources input from each compositional phrase
    per_figure_sources = []
    per_figure_targets = []
    for pairs in phrase_target_pairs:
        items = list(pairs)
        if len(items) % 2 != 0:
            raise ValueError(f"expected (source, target) pairs, got {len(items)} items")
        sources = []
        targets = []
        for source, target in zip(items[0::2], items[1::2]):
            # NOTE(review): since each step prepends, the steps keep the original
            # word order only when from_end is True; with from_end=False, and for
            # multi-word targets, the words come out reversed. Confirm intended.
            sources.extend(build_steps(reversed(source) if from_end else source))
            targets.extend(build_steps(target))
        per_figure_sources.append(sources)
        per_figure_targets.append(targets)
    plot_step_phrase(per_figure_sources,
                     per_figure_targets,
                     model, emb_dataset, k, transform_model=transform_model)
