import datetime
import os

import matplotlib.pyplot as plt
import numpy as np
from scipy.cluster.hierarchy import dendrogram
from sklearn import metrics
from sklearn.cluster import AgglomerativeClustering, MiniBatchKMeans, KMeans
from sklearn.decomposition import TruncatedSVD
from sklearn.manifold import TSNE
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import Normalizer

from classifiers.tfidf import create_tfidf_vectorizer


def run_unsupervised_clustering(xx_train, yy_train, **kwargs):
    """Run the unsupervised clustering experiment on the training split.

    Thin entry point that hands everything to ``cluster_texts``:
    ``xx_train`` are the documents, ``yy_train`` their class labels (used
    downstream as indices into ``kwargs["categories"]``), and ``kwargs`` is
    forwarded untouched (it must carry ``categories``).
    """
    documents, labels = xx_train, yy_train
    cluster_texts(documents, labels, **kwargs)


def cluster_texts(texts_list, y, **kwargs):
    """TF-IDF-vectorise ``texts_list``, cluster it with k-means and save plots.

    The number of clusters equals the number of categories. The reduction,
    clustering, metrics and plotting are done by
    ``diagram_svd_clustering_plot``.

    Args:
        texts_list: the documents to vectorise.
        y: true class of each document, used as an index into ``categories``.
        **kwargs: must contain ``categories`` (a sequence whose items' first
            element is used as the display name by the plotting helpers);
            the whole of ``kwargs`` is also forwarded to
            ``create_tfidf_vectorizer``.

    Raises:
        TypeError: if ``categories`` is not supplied.
        ValueError: if ``categories`` is empty, or has more entries than there
            are plot colours (checked up front so we fail before the
            expensive vectorisation / clustering work).
    """
    categories = kwargs.get("categories")
    if categories is None:
        raise TypeError("cluster_texts() requires a 'categories' keyword argument")
    num_clusters = len(categories)
    if num_clusters == 0:
        raise ValueError("'categories' must not be empty")

    # One colour per class / cluster; the plotting helpers index this by
    # position (0, 1, 2, ...), so it must have at least num_clusters entries.
    labels_color_map = {
        0: 'green', 1: 'cyan', 2: 'blue', 3: 'orange', 4: 'purple',
        5: 'pink', 6: 'red', 7: 'olive', 8: 'darkred', 9: 'fuchsia', 10: 'blueviolet', 11: 'silver',
        12: 'gold', 13: 'navy', 14: 'deepskyblue', 15: 'peru', 16: 'lime', 17: 'darkcyan',
        18: 'yellow', 19: 'peachpuff', 20: 'springgreen', 21: 'slateblue', 22: 'mediumvioletred',
    }
    if num_clusters > len(labels_color_map):
        raise ValueError(
            "only %d plot colours are defined but %d categories were given"
            % (len(labels_color_map), num_clusters))

    # Mathtext marker glyph per category index; passed through to the
    # plotting helpers' signatures.
    labels_letter_map = {
        0: '$A$', 1: '$U$', 2: '$S$', 3: '$T$', 4: '$M$',
        5: '$I$', 6: '$t$', 7: '$H$', 8: '$K$', 9: '$P$',
        10: '$B$', 11: '$S$', 12: '$J$', 13: '$V$', 14: '$i$',
        15: '$G$', 16: '$u$', 17: '$E$', 18: '$R$', 19: '$m$',
        20: '$F$', 21: '$h$', 22: '$C$',
    }

    tf_idf_vectorizer = create_tfidf_vectorizer(**kwargs)
    tf_idf_matrix = tf_idf_vectorizer.fit_transform(texts_list)

    # Other diagrams defined in this module that can be swapped in here:
    #   diagram_svd_plot(categories, labels_color_map, labels_letter_map, num_clusters, tf_idf_matrix, y)
    #   diagram_svd_varience_plot(tf_idf_matrix)
    #   diagram_tsne_plot(categories, labels_color_map, labels_letter_map, num_clusters, tf_idf_matrix, y)
    diagram_svd_clustering_plot(categories, labels_color_map, labels_letter_map, num_clusters,
                                tf_idf_vectorizer, tf_idf_matrix, y)


def diagram_svd_plot(categories, labels_color_map, labels_letter_map, num_clusters, tf_idf_matrix, y):
    """Scatter-plot documents on their first two TruncatedSVD components, coloured by true class.

    The figure is written to ``images/<this file>_svd_<unix time>.png``; the
    ``images`` directory is created if missing.

    Args:
        categories: indexed by class id; ``categories[i][0]`` is the legend name.
        labels_color_map: position -> matplotlib colour; needs one entry per
            distinct class present in ``y``.
        labels_letter_map: unused; kept for signature compatibility.
        num_clusters: unused; kept for signature compatibility.
        tf_idf_matrix: document-term matrix, one row per document.
        y: true class id of each row.
    """
    print("svd reduction...")
    svd_reduced = TruncatedSVD(n_components=2, random_state=0).fit_transform(tf_idf_matrix)

    print("plotting...")
    points_by_lang = {}
    for point, target in zip(svd_reduced, y):
        points_by_lang.setdefault(categories[target][0], []).append(point)

    fig, ax = plt.subplots(figsize=(10, 10))
    # Sorted so colour assignment is deterministic and matches the ordering
    # used by diagram_svd_clustering_plot.
    for i, lang in enumerate(sorted(points_by_lang)):
        pts = np.array(points_by_lang[lang])
        ax.scatter(pts[:, 0], pts[:, 1], marker='x', c=labels_color_map[i], label=lang)
    ax.set_title('SVD')
    ax.set_xlabel('dimension 1')
    ax.set_ylabel('dimension 2')
    ax.legend()

    # strftime('%s') is a non-portable glibc extension; compute epoch seconds directly.
    timestamp = str(int(datetime.datetime.now().timestamp()))
    os.makedirs("images", exist_ok=True)
    fig.savefig("images/" + os.path.basename(__file__) + "_svd_" + timestamp + ".png",
                dpi=250, bbox_inches='tight')
    plt.close(fig)  # free the figure; pyplot keeps every open figure alive otherwise
    print("DONE!")


def diagram_svd_clustering_plot(categories, labels_color_map, labels_letter_map, num_clusters, tf_idf_vectorizer, tf_idf_matrix, y):
    """Cluster TF-IDF vectors with k-means and plot true classes beside the clusters.

    With ``lsa_enabled`` the TF-IDF matrix is first reduced with TruncatedSVD
    followed by L2 normalisation (LSA) before k-means. Clustering metrics and
    the top terms of each cluster centroid are printed. A two-panel figure
    (true classes | k-means clusters) is saved to
    ``images/<this file>_svd_<unix time>.png``; it is drawn on a separate
    2-D TruncatedSVD projection.

    Args:
        categories: indexed by class id; ``categories[i][0]`` is the display name.
        labels_color_map: position -> matplotlib colour; needs at least
            ``num_clusters`` entries.
        labels_letter_map: unused; kept for signature compatibility.
        num_clusters: number of k-means clusters.
        tf_idf_vectorizer: the fitted vectoriser that produced ``tf_idf_matrix``.
        tf_idf_matrix: document-term matrix, one row per document.
        y: true class id of each row.
    """
    lsa_enabled = True
    lsa_components = 100

    # 2-D projection used only to draw both scatter panels.
    print("svd reduction 2n...")
    svd_2n = TruncatedSVD(n_components=2, random_state=0, n_iter=5).fit_transform(tf_idf_matrix)

    print("svd reduction 1k...")
    X = tf_idf_matrix
    svd_lsa = None
    if lsa_enabled:
        # Randomized TruncatedSVD rejects n_components > n_features; clamp so
        # small corpora / vocabularies still work.
        lsa_components = min(lsa_components, *tf_idf_matrix.shape)
        svd_lsa = TruncatedSVD(n_components=lsa_components, random_state=0)
        lsa = make_pipeline(svd_lsa, Normalizer(copy=False))
        X = lsa.fit_transform(tf_idf_matrix)

    print("cluster prediction...")
    # Seeded so runs are reproducible (consistent with the seeded SVDs above).
    km = KMeans(n_clusters=num_clusters, init='k-means++', n_init=1, random_state=0)
    y_pred = km.fit_predict(X)

    print("Homogeneity: %0.3f" % metrics.homogeneity_score(y, y_pred))
    print("Completeness: %0.3f" % metrics.completeness_score(y, y_pred))
    print("V-measure: %0.3f" % metrics.v_measure_score(y, y_pred))
    print("Adjusted Rand-Index: %.3f" % metrics.adjusted_rand_score(y, y_pred))
    print("Silhouette Coefficient: %0.3f"
          % metrics.silhouette_score(X, y_pred, sample_size=1000, random_state=0))

    print("Top terms per cluster:")
    if svd_lsa is not None:
        # Centroids live in LSA space; map them back to term space to rank terms.
        centroids = svd_lsa.inverse_transform(km.cluster_centers_)
    else:
        centroids = km.cluster_centers_
    order_centroids = np.asarray(centroids).argsort()[:, ::-1]

    # get_feature_names() was removed in scikit-learn 1.2.
    if hasattr(tf_idf_vectorizer, "get_feature_names_out"):
        terms = tf_idf_vectorizer.get_feature_names_out()
    else:
        terms = tf_idf_vectorizer.get_feature_names()
    for i in range(num_clusters):
        print("Cluster %d:" % i, end='')
        for ind in order_centroids[i, :10]:
            print(' %s' % terms[ind], end='')
        print()

    print("plotting...")
    points_by_lang = {}
    for point, target in zip(svd_2n, y):
        points_by_lang.setdefault(categories[target][0], []).append(point)

    # k-means cluster ids are arbitrary and do NOT correspond to class ids, so
    # group by cluster id and label each cluster with its majority true class.
    points_by_cluster = {}
    for point, cluster in zip(svd_2n, y_pred):
        points_by_cluster.setdefault(int(cluster), []).append(point)
    majority = _cluster_majority_classes(y, y_pred, num_clusters)

    fig, (ax_true, ax_pred) = plt.subplots(1, 2, figsize=(20, 10))
    for i, lang in enumerate(sorted(points_by_lang)):
        pts = np.array(points_by_lang[lang])
        ax_true.scatter(pts[:, 0], pts[:, 1], marker='x', c=labels_color_map[i], label=lang)
    ax_true.set_title('Actual Classes')
    ax_true.set_xlabel('dimension 1')
    ax_true.set_ylabel('dimension 2')
    ax_true.legend()

    for i, cluster in enumerate(sorted(points_by_cluster)):
        pts = np.array(points_by_cluster[cluster])
        label = "cluster %d" % cluster
        if majority[cluster] is not None:
            label += " (mostly %s)" % categories[majority[cluster]][0]
        ax_pred.scatter(pts[:, 0], pts[:, 1], marker='x', c=labels_color_map[i], label=label)
    if lsa_enabled:
        ax_pred.set_title('k-Means Clustering using TruncatedSVD (n_components=%d, n_iter=5)'
                          % lsa_components)
    else:
        ax_pred.set_title('k-Means Clustering on TF-IDF')
    ax_pred.set_xlabel('dimension 1')
    ax_pred.set_ylabel('dimension 2')
    ax_pred.legend()

    # strftime('%s') is a non-portable glibc extension; compute epoch seconds directly.
    timestamp = str(int(datetime.datetime.now().timestamp()))
    os.makedirs("images", exist_ok=True)
    fig.savefig("images/" + os.path.basename(__file__) + "_svd_" + timestamp + ".png",
                dpi=250, bbox_inches='tight')
    plt.close(fig)  # free the figure; pyplot keeps every open figure alive otherwise
    print("DONE!")


def _cluster_majority_classes(y, y_pred, num_clusters):
    """Return the most frequent true class of each cluster id in ``range(num_clusters)``.

    ``y`` must hold non-negative integer class ids. Ties resolve to the
    smallest class id. Clusters with no members map to ``None``.
    """
    y = np.asarray(y)
    y_pred = np.asarray(y_pred)
    majority = []
    for cluster in range(num_clusters):
        members = y[y_pred == cluster]
        majority.append(int(np.bincount(members).argmax()) if members.size else None)
    return majority

def diagram_svd_varience_plot(tf_idf_matrix):
    """Plot the cumulative share of total variance explained by the leading SVD components.

    Fits a TruncatedSVD with up to 100 components and saves the curve to
    ``images/<this file>_svd_var_<unix time>.png``. The ``images`` directory
    is created if missing.

    Args:
        tf_idf_matrix: document-term matrix, one row per document.
    """
    print("drawing svd varience diagram...")
    # Randomized TruncatedSVD rejects n_components > n_features; clamp for small inputs.
    n_components = min(100, *tf_idf_matrix.shape)
    svd = TruncatedSVD(n_components=n_components, random_state=0)
    svd.fit(tf_idf_matrix)
    # explained_variance_ratio_ is relative to the input's total variance, so
    # the curve shows how much variance is actually retained. Normalising by
    # the sum over only the kept components would force the last point to 1.0.
    cum_var_explained = np.cumsum(svd.explained_variance_ratio_)
    n_kept = np.arange(1, len(cum_var_explained) + 1)

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(n_kept, cum_var_explained, linewidth=2)
    ax.axis('tight')
    ax.grid()
    ax.set_xlabel('n_components')
    ax.set_ylabel('Cumulative explained variance')

    # strftime('%s') is a non-portable glibc extension; compute epoch seconds directly.
    timestamp = str(int(datetime.datetime.now().timestamp()))
    os.makedirs("images", exist_ok=True)
    fig.savefig("images/" + os.path.basename(__file__) + "_svd_var_" + timestamp + ".png",
                dpi=250, bbox_inches='tight')
    plt.close(fig)  # free the figure; pyplot keeps every open figure alive otherwise
    print("DONE!")


def diagram_tsne_plot(categories, labels_color_map, labels_letter_map, num_clusters, tf_idf_matrix, y):
    """Plot a 2-D t-SNE embedding of the documents, coloured by true class.

    The TF-IDF matrix is first reduced to ``num_clusters`` dimensions with
    TruncatedSVD. scikit-learn's TSNE docs recommend TruncatedSVD for sparse
    input before t-SNE. The figure is saved to
    ``images/<this file>_tsne_<unix time>.png``.

    Args:
        categories: indexed by class id; ``categories[i][0]`` is the legend name.
        labels_color_map: position -> matplotlib colour; needs one entry per
            distinct class present in ``y``.
        labels_letter_map: unused; kept for signature compatibility.
        num_clusters: number of SVD components fed to t-SNE.
        tf_idf_matrix: document-term matrix, one row per document.
        y: true class id of each row.
    """
    import inspect

    print("drawing t-SNE diagram...")
    tfs_reduced = TruncatedSVD(n_components=num_clusters, random_state=0).fit_transform(tf_idf_matrix)

    n_iter = 1000
    # TSNE requires perplexity < n_samples.
    perplexity = min(50, tfs_reduced.shape[0] - 1)
    tsne_kwargs = dict(n_components=2, perplexity=perplexity, verbose=2, n_jobs=-1, random_state=0)
    # scikit-learn 1.5 renamed n_iter to max_iter (the old name is gone in 1.7).
    if "max_iter" in inspect.signature(TSNE).parameters:
        tsne_kwargs["max_iter"] = n_iter
    else:
        tsne_kwargs["n_iter"] = n_iter
    tfs_embedded = TSNE(**tsne_kwargs).fit_transform(tfs_reduced)

    points_by_lang = {}
    for point, target in zip(tfs_embedded, y):
        points_by_lang.setdefault(categories[target][0], []).append(point)

    fig, ax = plt.subplots(figsize=(10, 10))
    # Sorted so colour assignment is deterministic and matches the SVD plots.
    for i, lang in enumerate(sorted(points_by_lang)):
        pts = np.array(points_by_lang[lang])
        ax.scatter(pts[:, 0], pts[:, 1], marker='x', c=labels_color_map[i], label=lang)
    ax.set_title('t-SNE (perplexity=%g, n_iter=%d)' % (perplexity, n_iter))
    ax.set_xlabel('dimension 1')
    ax.set_ylabel('dimension 2')
    ax.legend()

    # strftime('%s') is a non-portable glibc extension; compute epoch seconds directly.
    timestamp = str(int(datetime.datetime.now().timestamp()))
    os.makedirs("images", exist_ok=True)
    fig.savefig("images/" + os.path.basename(__file__) + "_tsne_" + timestamp + ".png",
                dpi=250, bbox_inches='tight')
    plt.close(fig)  # free the figure; pyplot keeps every open figure alive otherwise
    print("DONE!")


def plot_dendrogram(model, **kwargs):
    """Draw a SciPy dendrogram for a fitted ``AgglomerativeClustering`` model.

    Builds a SciPy linkage matrix from the model's ``children_`` (merge
    pairs), ``distances_`` (merge heights) and ``labels_`` (one per sample),
    then passes it to ``scipy.cluster.hierarchy.dendrogram``. The model must
    expose ``distances_`` (in scikit-learn that requires fitting with
    ``compute_distances=True`` or a ``distance_threshold``).

    Args:
        model: the fitted clustering model.
        **kwargs: forwarded verbatim to ``dendrogram`` (e.g. ``truncate_mode``, ``p``).
    """
    merges = model.children_
    n_leaves = len(model.labels_)
    # Row k of `merges` creates node n_leaves + k, so a child id below
    # n_leaves is an original sample and anything else is an earlier merge.
    cluster_sizes = np.zeros(merges.shape[0])
    for step, pair in enumerate(merges):
        cluster_sizes[step] = sum(
            1 if node < n_leaves else cluster_sizes[node - n_leaves]
            for node in pair
        )

    linkage = np.column_stack([merges, model.distances_, cluster_sizes]).astype(float)
    dendrogram(linkage, **kwargs)
