import random
from copy import deepcopy

import torch
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

# Have the Agg backend render long paths in chunks of this many vertices
# (matplotlib rcParam); presumably set to avoid Agg rendering failures on
# large plots -- NOTE(review): confirm intent.
matplotlib.rcParams['agg.path.chunksize'] = 10000

import analogy

def plot_tsne(clusters, clusters_words, model, transform_model):
    """
    Scatter-plot 2-D cluster points, one colour per category, and save a PNG.

    :param clusters: dict mapping category -> 2-D points; only columns 0 and
        1 of each point are plotted
    :param clusters_words: dict mapping category -> labels, aligned
        index-for-index with the points in ``clusters[category]``
    :param model: object exposing ``local(filename)``, used to build the
        path the figure is saved to
    :param transform_model: name of the projection used (e.g. "tsne");
        embedded in the plot title and output file name
    """
    fig = plt.figure(figsize=(16, 9))
    try:
        colors = cm.rainbow(np.linspace(0, 1, len(clusters)))
        for (category, embs), color in zip(clusters.items(), colors):
            points = np.array(embs)
            # Pass the RGBA colour as a single-row 2-D array: a bare 1-D
            # RGBA sequence is ambiguous to scatter() and gets colour-mapped
            # as values when the cluster has exactly 4 points.
            plt.scatter(points[:, 0], points[:, 1], c=[color], label=category)
            for point, word in zip(points, clusters_words[category]):
                plt.annotate(word, alpha=0.5, xy=(point[0], point[1]), xytext=(5, 2),
                             textcoords='offset points', ha='right', va='bottom', size=12)
        # Separator between the projection name and the category list.
        title = "analogy_cluster_" + transform_model + "_" + "_".join(clusters.keys())
        plt.legend(loc=4)
        plt.title(title)
        plt.grid(True)
        plt.savefig(model.local(f"{title}.png"), format='png', dpi=150, bbox_inches='tight')
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

def t_sne_transform(clusters, transform_model):
    """
    Project every cluster's embeddings to 2-D with one shared fit.

    All clusters are stacked into one matrix so a single projection is fit
    over every category (shared coordinate system), then the projected rows
    are split back per category in the dict's iteration order.

    :param clusters: dict mapping category -> list of embedding vectors
        (all vectors must have the same length)
    :param transform_model: "tsne" selects sklearn TSNE; any other value
        selects PCA
    :return: the same ``clusters`` dict, updated in place so each category
        maps to an (n_category, 2) numpy array
    """
    all_embs = []
    category_len = {}
    for category, cluster in clusters.items():
        all_embs.extend(cluster)
        category_len[category] = len(cluster)
    all_embs = np.array(all_embs)

    if transform_model == "tsne":
        # sklearn requires perplexity < n_samples; clamp so small samples
        # (e.g. a small ``k`` in plot_analogy_cluster) can still be projected.
        perplexity = max(1, min(30, len(all_embs) - 1))
        common = dict(perplexity=perplexity, n_components=2, init='pca', random_state=42)
        try:
            # Newer scikit-learn names the iteration budget ``max_iter``
            # (``n_iter`` was deprecated and then removed).
            model = TSNE(max_iter=3500, **common)
        except TypeError:
            # Older scikit-learn only understands ``n_iter``.
            model = TSNE(n_iter=3500, **common)
    else:
        model = PCA(n_components=2)
    projected = model.fit_transform(all_embs)

    # Split the stacked projection back into per-category slices.
    offset = 0
    for category, count in category_len.items():
        clusters[category] = projected[offset:offset + count, :]
        offset += count

    return clusters

def prune_embeddings(embs, words):
    """
    Drop outlier embeddings, keeping words aligned with their vectors.

    An embedding is kept when its squared distance from the mean embedding
    is at most 3x the total variance (sum of per-dimension variances), which
    equals 3x the average squared distance from the mean.

    :param embs: sequence of equal-length 1-D embedding vectors
    :param words: labels aligned index-for-index with ``embs``
    :return: (pruned_embs, pruned_words) -- lists of the kept vectors (as
        numpy rows) and their labels, in original order
    """
    if len(embs) == 0:
        return [], []
    embs = np.array(embs)
    mu = embs.mean(axis=0)
    total_var = (embs.std(axis=0) ** 2).sum()       # scalar
    sq_dist = np.sum((embs - mu) ** 2, axis=1)
    # ``<=`` (not ``<``): for a single vector or all-identical vectors both
    # sides are 0, and a strict comparison would prune the whole cluster.
    keep = np.flatnonzero(sq_dist <= 3 * total_var)
    return [embs[i] for i in keep], [words[i] for i in keep]

def _sample_category_words(pairs, k, rng):
    """
    Return the distinct words of ``k`` randomly drawn entries of ``pairs``.

    :param pairs: sequence of analogy entries, each an iterable of words
    :param k: number of entries to draw without replacement; 0 means all
    :param rng: ``random.Random`` instance used for sampling
    :return: list of unique words in first-seen order
    """
    if k == 0:
        chosen = pairs
    else:
        # Sample WITHOUT replacement so min(k, len) distinct entries are used
        # (the min() cap only makes sense for sampling without replacement).
        chosen = rng.sample(pairs, min(k, len(pairs)))
    # dict.fromkeys dedupes while keeping first-seen order; set() ordering
    # depends on string-hash randomization, which would change the input
    # order of the projection (and thus the plot) between runs.
    return list(dict.fromkeys(word for pair in chosen for word in pair))


def plot_analogy_cluster(model, emb_dataset, k, categories=None, transform_model="tsne"):
    """
    Embed words from analogy categories, project them to 2-D and plot them.

    :param model: word embedding model; must expose ``embeddings2.weight``
        (used to pick the device), ``embed_sentence(tensor)`` and
        ``local(filename)`` (used by plot_tsne for the output path)
    :param emb_dataset: dataset containing word to index mapping
        (``word_to_idx(word)``)
    :param k: (int), how many analogy entries to sample (without
        replacement) from each category. If k = 0, plot all words in a
        category
    :param categories: list[str] what categories to plot. If None or empty,
        plot all categories
    :param transform_model: "tsne" for t-SNE, anything else for PCA
    Errors raised while projecting or plotting are printed, not raised.
    """
    analogy_dataset = analogy.analogy_dataset()
    if categories is None or len(categories) == 0:
        categories = list(analogy_dataset.keys())
    # Dedicated fixed-seed RNG: reproducible sampling without reseeding the
    # global ``random`` module as a side effect.
    rng = random.Random(42)

    # sample words from analogy category
    clusters = {}
    for category in categories:
        clusters[category] = _sample_category_words(analogy_dataset[category], k, rng)

    # Labels, kept aligned index-for-index with the embeddings below.
    clusters_words = deepcopy(clusters)

    device = model.embeddings2.weight.device
    for category in categories:
        indices = np.array([emb_dataset.word_to_idx(word) for word in clusters[category]])
        # Inference only: no autograd graph is needed (and ``.numpy()``
        # refuses tensors that require grad).
        with torch.no_grad():
            clusters[category] = [
                model.embed_sentence(torch.tensor([idx]).to(device)).cpu().numpy()
                for idx in indices
            ]
        len_before = len(clusters[category])
        clusters[category], clusters_words[category] = prune_embeddings(
            clusters[category], clusters_words[category])
        print(f"Category {category}: {len_before - len(clusters[category])} words pruned")

    try:
        clusters = t_sne_transform(clusters, transform_model)
        plot_tsne(clusters, clusters_words, model, transform_model)
    except Exception as e:
        # Best-effort plotting: report the failure instead of aborting.
        print("error",e,"during tSNE clustering: ",*categories)

