from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score
import numpy as np
import numpy
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.spatial.distance import cdist
from collections import Counter
import matplotlib.cm as cm
import sklearn

# Render all plot text, including axis labels, in bold.
plt.rcParams.update({
    "font.weight": "bold",
    "axes.labelweight": "bold",
})


# Load one document per line of exp.txt (surrounding whitespace stripped).
# The context manager guarantees the file handle is closed.
with open("exp.txt", "r", encoding="utf-8") as a_file:
    texts = [line.strip() for line in a_file]

# Vectorize the documents as TF-IDF rows, dropping English stop words.
vectorizer = TfidfVectorizer(stop_words="english")
X_f = vectorizer.fit_transform(texts)

# Vocabulary terms, i.e. the axes of the TF-IDF space.
# get_feature_names() no longer exists in scikit-learn >= 1.2.
words = list(vectorizer.get_feature_names_out())

# 2-D t-SNE embedding of the TF-IDF matrix. X_f is sparse, and t-SNE's PCA
# initialisation (the default since scikit-learn 1.2) rejects sparse input,
# so random initialisation is requested explicitly.
X = TSNE(n_components=2, init="random").fit_transform(X_f)


# Cluster the 2-D embedding with k-means and plot the result.
n_clusters = 10
# Fit exactly once. n_init=10 keeps ten k-means restarts; the library default
# became "auto" (a single run with k-means++ init) in scikit-learn 1.4.
kmeanModel = KMeans(n_clusters=n_clusters, n_init=10).fit(X)
labels = kmeanModel.labels_
centers = kmeanModel.cluster_centers_

fig, ax = plt.subplots()
fig.set_size_inches(21, 15)
vis_x = X[:, 0]
vis_y = X[:, 1]

# One colour per cluster, spread evenly across the nipy_spectral colormap.
colors = cm.nipy_spectral(labels.astype(float) / n_clusters)
ax.scatter(vis_x, vis_y, c=colors, edgecolor='k', s=600)

# Draw white circles at the cluster centers ...
ax.scatter(centers[:, 0], centers[:, 1], marker='o',
           c="white", alpha=1, s=1000, edgecolor='k')
# ... and mark each one with its cluster index (mathtext marker "$i$").
for i, c in enumerate(centers):
    ax.scatter(c[0], c[1], marker=f'${i}$', alpha=1,
               s=800, edgecolor='k', c='k')

ax.tick_params(axis='both', which='major', labelsize=50)
ax.set_xlabel("Feature space for the 1st feature", fontsize=35)
ax.set_ylabel("Feature space for the 2nd feature", fontsize=35)
plt.show()

# Print the 12 most frequent terms in each cluster's documents.
# Terms are produced by TfidfVectorizer's own analyzer: lowercasing, its
# token pattern, and English stop-word removal. Raw whitespace tokens are not
# compared against the vocabulary, because words such as "Hello" or "data,"
# would then never match. The analyzer is built once and needs no fitting,
# so a cluster that contains only stop words simply yields an empty list
# instead of an "empty vocabulary" error.
analyzer = TfidfVectorizer(stop_words="english").build_analyzer()
for i in range(n_clusters):
    # Indices of the documents that k-means assigned to cluster i.
    members = np.flatnonzero(labels == i)
    C = Counter(term for s in members for term in analyzer(texts[s]))
    print(C.most_common(12))
    print('\n \n')