{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "from ste_d2vc_scws_similarity import *\n",
    "#import sklearn\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.decomposition import PCA\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import matplotlib.cm as cm\n",
    "#from sklearn.model_selection import GridSearchCV ## for sklearn version 0.18.1\n",
    "from sklearn.grid_search import GridSearchCV ## for sklearn version 0.17.1\n",
    "#from sklearn.model_selection import  \n",
    "from sklearn.svm import SVC, LinearSVC\n",
    "from ggplot import *\n",
    "#import time #depricated\n",
    "#from pandas.lib import Timestamp # use this instead of import time.\n",
    "from pandas import Timestamp\n",
    "from KaggleWord2VecUtility import KaggleWord2VecUtility\n",
    "import sys\n",
    "import joblib\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "#print(sklearn.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Trained topic-word embedding file to evaluate (point `url` at the TDE path to evaluate the TDE model instead)\n",
    "STE_WIKI_VECTORS = 'evaluation_data/EXPERIMENTS_STE_TDE/STE/STE_WIKI/vectors_kk_dc_10000lines_s400_n8_w10_b0_oit15_init15_k10_count5_sample0.000100_input.txt'\n",
    "#TDE_WIKI_VECTORS = 'evaluation_data/EXPERIMENTS_STE_TDE/TDE/TDE_WIKI/vectors_kk_dc_10000lines_s400_n8_w10_b0_oit15_init15_k10_count5_sample0.000100_input.txt'\n",
    "url = STE_WIKI_VECTORS\n",
    "id2topicword, topicword2id, id2emb, dtaset = getModeldata(url)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row 0 of the embedding file carries metadata (number of rows and embedding size), not a vector: drop it.\n",
    "normalize_data = dtaset.iloc[1:]\n",
    "# Shift the row labels down by one so the first real vector is row 0 again.\n",
    "normalize_data.index = normalize_data.index - 1\n",
    "normalize_data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>9</th>\n",
       "      <th>10</th>\n",
       "      <th>...</th>\n",
       "      <th>391</th>\n",
       "      <th>392</th>\n",
       "      <th>393</th>\n",
       "      <th>394</th>\n",
       "      <th>395</th>\n",
       "      <th>396</th>\n",
       "      <th>397</th>\n",
       "      <th>398</th>\n",
       "      <th>399</th>\n",
       "      <th>400</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.151448</td>\n",
       "      <td>0.340919</td>\n",
       "      <td>0.583997</td>\n",
       "      <td>-0.831073</td>\n",
       "      <td>-0.537173</td>\n",
       "      <td>-0.603222</td>\n",
       "      <td>-0.014596</td>\n",
       "      <td>-0.674210</td>\n",
       "      <td>-0.645246</td>\n",
       "      <td>-1.240574</td>\n",
       "      <td>...</td>\n",
       "      <td>-0.328991</td>\n",
       "      <td>-0.682700</td>\n",
       "      <td>0.469469</td>\n",
       "      <td>-0.314752</td>\n",
       "      <td>0.133564</td>\n",
       "      <td>0.997386</td>\n",
       "      <td>-0.440253</td>\n",
       "      <td>0.363071</td>\n",
       "      <td>0.083071</td>\n",
       "      <td>-1.663519</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>-0.222360</td>\n",
       "      <td>0.310732</td>\n",
       "      <td>-0.574816</td>\n",
       "      <td>0.120854</td>\n",
       "      <td>-0.101042</td>\n",
       "      <td>-0.026349</td>\n",
       "      <td>0.701724</td>\n",
       "      <td>-1.142572</td>\n",
       "      <td>-0.595960</td>\n",
       "      <td>-1.079364</td>\n",
       "      <td>...</td>\n",
       "      <td>1.230805</td>\n",
       "      <td>-0.883570</td>\n",
       "      <td>-0.745131</td>\n",
       "      <td>-0.737262</td>\n",
       "      <td>0.015814</td>\n",
       "      <td>0.487782</td>\n",
       "      <td>-0.171862</td>\n",
       "      <td>-1.136777</td>\n",
       "      <td>0.068124</td>\n",
       "      <td>-1.232696</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1.150247</td>\n",
       "      <td>-0.513735</td>\n",
       "      <td>1.335180</td>\n",
       "      <td>-0.288587</td>\n",
       "      <td>-1.229484</td>\n",
       "      <td>0.142813</td>\n",
       "      <td>-0.041410</td>\n",
       "      <td>-1.326419</td>\n",
       "      <td>-1.181551</td>\n",
       "      <td>-1.424592</td>\n",
       "      <td>...</td>\n",
       "      <td>0.300751</td>\n",
       "      <td>-0.692620</td>\n",
       "      <td>0.785064</td>\n",
       "      <td>0.316055</td>\n",
       "      <td>0.392912</td>\n",
       "      <td>1.305295</td>\n",
       "      <td>0.842255</td>\n",
       "      <td>0.141570</td>\n",
       "      <td>-0.233156</td>\n",
       "      <td>-0.598106</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1.079988</td>\n",
       "      <td>-0.270563</td>\n",
       "      <td>0.621979</td>\n",
       "      <td>-0.257028</td>\n",
       "      <td>-0.933015</td>\n",
       "      <td>-0.051577</td>\n",
       "      <td>0.295999</td>\n",
       "      <td>-0.881805</td>\n",
       "      <td>-0.035008</td>\n",
       "      <td>1.135281</td>\n",
       "      <td>...</td>\n",
       "      <td>0.228108</td>\n",
       "      <td>-0.179568</td>\n",
       "      <td>1.527568</td>\n",
       "      <td>-0.264951</td>\n",
       "      <td>0.543926</td>\n",
       "      <td>0.969544</td>\n",
       "      <td>0.867496</td>\n",
       "      <td>-0.453033</td>\n",
       "      <td>0.238750</td>\n",
       "      <td>-1.373506</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>4 rows × 400 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        1         2         3         4         5         6         7    \\\n",
       "0  0.151448  0.340919  0.583997 -0.831073 -0.537173 -0.603222 -0.014596   \n",
       "1 -0.222360  0.310732 -0.574816  0.120854 -0.101042 -0.026349  0.701724   \n",
       "2  1.150247 -0.513735  1.335180 -0.288587 -1.229484  0.142813 -0.041410   \n",
       "3  1.079988 -0.270563  0.621979 -0.257028 -0.933015 -0.051577  0.295999   \n",
       "\n",
       "        8         9         10   ...       391       392       393       394  \\\n",
       "0 -0.674210 -0.645246 -1.240574  ... -0.328991 -0.682700  0.469469 -0.314752   \n",
       "1 -1.142572 -0.595960 -1.079364  ...  1.230805 -0.883570 -0.745131 -0.737262   \n",
       "2 -1.326419 -1.181551 -1.424592  ...  0.300751 -0.692620  0.785064  0.316055   \n",
       "3 -0.881805 -0.035008  1.135281  ...  0.228108 -0.179568  1.527568 -0.264951   \n",
       "\n",
       "        395       396       397       398       399       400  \n",
       "0  0.133564  0.997386 -0.440253  0.363071  0.083071 -1.663519  \n",
       "1  0.015814  0.487782 -0.171862 -1.136777  0.068124 -1.232696  \n",
       "2  0.392912  1.305295  0.842255  0.141570 -0.233156 -0.598106  \n",
       "3  0.543926  0.969544  0.867496 -0.453033  0.238750 -1.373506  \n",
       "\n",
       "[4 rows x 400 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Column 0 holds each row's topic-word label; the remaining columns are the embedding values.\n",
    "label_column = normalize_data.columns[0]\n",
    "data = normalize_data.drop(label_column, axis=1)\n",
    "data.head(4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    'also',#1\n",
       "1    'also',#2\n",
       "2    'also',#3\n",
       "3    'also',#4\n",
       "Name: 0, dtype: object"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Topic-word label of every row (first column, as a Series), e.g. 'also',#1\n",
    "label = normalize_data.iloc[:, 0]\n",
    "label.head(4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0.09660261,  0.50063954,  0.36068714, ...,  0.76314361,\n",
       "        -0.4028437 , -1.28414372],\n",
       "       [-0.25278155,  0.46224476, -0.54454094, ..., -0.62555034,\n",
       "        -0.42093591, -0.80433678],\n",
       "       [ 1.03014215, -0.58639308,  0.94748755, ...,  0.5580581 ,\n",
       "        -0.78561259, -0.09759484],\n",
       "       ...,\n",
       "       [-1.59915852,  2.11686437, -2.76221782, ..., -2.32994401,\n",
       "         0.51541456,  0.74575448],\n",
       "       [-1.9292148 ,  1.6167999 , -2.00726467, ..., -1.80230441,\n",
       "         1.83301122, -0.14854877],\n",
       "       [-1.05663018, -0.15461046, -1.66516604, ..., -0.0858714 ,\n",
       "        -1.20663938,  1.42937047]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Work on the first 1000 embeddings only (presumably to keep t-SNE tractable).\n",
    "data_1000 = data[0:1000]\n",
    "# Standardize each embedding dimension to zero mean / unit variance before t-SNE.\n",
    "standardized_data = StandardScaler().fit_transform(data_1000)\n",
    "# data_1000 already has at most 1000 rows, so the t-SNE input is the standardized array itself.\n",
    "# Shape is (n_rows, n_embedding_dims), e.g. (1000, 400) for the s400 vector files loaded above.\n",
    "data_1000_rep = standardized_data\n",
    "data_1000_rep"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "## Annotate only the topic-specific senses of a few words in the t-SNE plot; every other point gets an empty label.\n",
    "## Labels look like 'government',#3 : the quoted word, a comma, then '#' and the topic index.\n",
    "## Other word groups explored earlier: band/album/music/show/party; show/band/music;\n",
    "## government/party/national/people/country/war/state/International/company/general/music.\n",
    "SELECTED_WORDS = ['government', 'party']\n",
    "NUM_TOPICS = 10\n",
    "selected_labels = set(\"'%s',#%d\" % (word, topic) for word in SELECTED_WORDS for topic in range(NUM_TOPICS))\n",
    "\n",
    "label_1000 = label[0:1000]\n",
    "labels = [label_name if label_name in selected_labels else \"\" for label_name in label_1000]\n",
    "# show only the labels that will actually be drawn (the full list is mostly empty strings)\n",
    "[label_name for label_name in labels if label_name]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## t-SNE projection of the standardized embeddings: one marker and colour per topic, selected labels annotated.\n",
    "tsne_model = TSNE(perplexity=40, n_components=2, init='pca', n_iter=2500, random_state=23)\n",
    "new_values = tsne_model.fit_transform(data_1000_rep)\n",
    "print(new_values.shape)\n",
    "x = new_values[:, 0]\n",
    "y = new_values[:, 1]\n",
    "markers = [\"d\", \"D\", \"o\", \"p\", \"1\", \"s\", \">\", \"x\", \"+\", \"*\"]\n",
    "plt.figure(figsize=(16, 16))\n",
    "# Rows are assumed to cycle through the 10 topics starting at topic 1 (first label is 'also',#1),\n",
    "# so row i belongs to topic (i + 1) % 10. One scatter call per topic gives all of a topic's points\n",
    "# the same marker AND the same colour, whatever the length of matplotlib's colour cycle.\n",
    "for offset, marker in enumerate(markers):\n",
    "    topic = (offset + 1) % len(markers)\n",
    "    plt.scatter(x[offset::len(markers)], y[offset::len(markers)], marker=marker, label=\"Topic \" + str(topic))\n",
    "### to hide the word labels, comment out the loop below\n",
    "for text, xi, yi in zip(labels, x, y):\n",
    "    if text:  # unselected points carry an empty label; nothing to draw\n",
    "        plt.annotate(text, xy=(xi, yi), xytext=(5, 2), textcoords='offset points', ha='right', va='bottom')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('word_id', 822)\n",
      "[\"'rome',#5\", \"'zia',#0\", \"'god',#7\", \"'clemenceau',#0\", \"'vol',#9\", \"'coup',#5\", \"'social',#9\", \"'oxford',#7\", \"'rabuka',#9\", \"'workers',#0\", \"'anti',#7\", \"'operation',#6\", \"'bhutan',#9\", \"'khrushchev',#8\", \"'solved',#7\", \"'security',#0\", \"'republic',#5\", \"'abdallahi',#5\", \"'republican',#0\", \"'buchman',#7\", \"'gongadze',#5\", \"'ppf',#9\", \"'rome',#9\", \"'namgyal',#9\", \"'policy',#0\", \"'islamic',#6\", \"'jewish',#7\", \"'japanese',#0\", \"'archer',#9\", \"'matoika',#7\", \"'social',#0\", \"'troops',#0\", \"'princess',#7\", \"'latvian',#5\", \"'java',#0\", \"'rabin',#2\", \"'hitler',#7\", \"'abbott',#5\", \"'security',#9\", \"'au',#4\", \"'hoa',#6\", \"'soviet',#5\", \"'leader',#9\", \"'democratic',#9\", \"'forces',#6\", \"'parliament',#5\", \"'boisclair',#4\", \"'bc',#5\", \"'benefits',#0\", \"'xii',#9\", \"'russia',#3\", \"'august',#5\", \"'taxes',#9\", \"'election',#9\", \"'abbott',#4\", \"'namgyal',#0\", \"'god',#2\", \"'zhirinovsky',#3\", \"'forces',#0\", \"'serve',#9\", \"'sentenced',#8\", \"'abdel',#5\", \"'elections',#7\", \"'laurier',#2\", \"'berlin',#0\", \"'desi',#9\", \"'ngawang',#9\", \"'government',#0\", \"'minister',#3\", \"'germany',#7\", \"'political',#0\", \"'dutch',#2\", \"'bhutan',#0\", \"'laurier',#4\", \"'operations',#6\", \"'davies',#3\", \"'counter',#6\", \"'princess',#9\", \"'government',#9\", \"'benefits',#9\", \"'prime',#3\", \"'germany',#2\", \"'khan',#1\", \"'said',#7\", \"'leader',#5\", \"'tibet',#9\", \"'pemuda',#0\", \"'anti',#0\", \"'atlantic',#3\", \"'dutch',#0\", \"'political',#9\", \"'cathedral',#3\", \"'indonesia',#0\", \"'muslim',#2\", \"'post',#9\", \"'democracy',#5\", \"'coup',#1\", \"'president',#1\", \"'aziz',#5\", \"'foulois',#8\"]\n"
     ]
    }
   ],
   "source": [
    "## Similarity block: nearest neighbours of one topic-specific word embedding.\n",
    "query_word = \"'party',#2\"\n",
    "word_id = topicword2id[query_word]\n",
    "print(\"word_id\", word_id)\n",
    "num_nns = 100  # how many neighbours to retrieve\n",
    "nearest_neighbors = show_nearest_neighbors(id2emb, id2topicword, word_id, num_nns)\n",
    "print(nearest_neighbors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Topic embedding dictionary, indexed as topic_emb[cluster_index][word]\n",
    "## (presumably the vector of `word` under topic `cluster_index` -- see get_topic_emb)\n",
    "topic_emb = get_topic_emb()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Topic embeddings indexed by cluster: indx_wise_emb_topic[cluster_index]\n",
    "indx_wise_emb_topic = get_topic_indx_wise_emb_val()\n",
    "print(indx_wise_emb_topic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Spearman correlation score returned by quantitative_scws_df (SCWS word-similarity evaluation)\n",
    "Spearman_correlation_Score = quantitative_scws_df()\n",
    "print(Spearman_correlation_Score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Document frequencies and co-frequencies over the 20News training documents.\n",
    "## The documents are read explicitly from train_v2.tsv (`all` is a Python builtin, not a DataFrame).\n",
    "## NOTE(review): assumes the document text is in a \"news\" column of train_v2.tsv -- confirm.\n",
    "train = pd.read_csv('evaluation_data/EXPERIMENTS_STE_TDE/TDE/TDE_20News/train_v2.tsv', header=0, delimiter=\"\\t\")\n",
    "# one space-joined token string per document\n",
    "traindata = [\" \".join(KaggleWord2VecUtility.review_to_wordlist(doc, True)) for doc in train[\"news\"]]\n",
    "\n",
    "num_cluster = 10\n",
    "num_topwords = 10\n",
    "doc_freq, doc_cofreq = get_doccofrequency(traindata)\n",
    "print(\"doc_freq:\", doc_freq)\n",
    "print(\"doc_cofreq\", doc_cofreq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Topic coherence: PMI of each topic's top words over the 20News training documents, written to pmi.txt.\n",
    "TDE_20NEWS_DIR = 'evaluation_data/EXPERIMENTS_STE_TDE/TDE/TDE_20News/'\n",
    "train = pd.read_csv(TDE_20NEWS_DIR + 'train_v2.tsv', header=0, delimiter=\"\\t\")\n",
    "## NOTE(review): assumes the document text is in a \"news\" column of train_v2.tsv -- confirm.\n",
    "## (`all` is a Python builtin, so the documents must come from `train`.)\n",
    "traindata = [\" \".join(KaggleWord2VecUtility.review_to_wordlist(doc, True)) for doc in train[\"news\"]]\n",
    "\n",
    "num_cluster = 10\n",
    "num_topwords = 10\n",
    "doc_freq, doc_cofreq = get_doccofrequency(traindata)\n",
    "## topic_centroid_prob_map : probability of each word in a topic\n",
    "vocab_dict = vocab_count_dict(TDE_20NEWS_DIR + 'vocab.txt')\n",
    "topic_centroid_prob_map = get_probability_topic_vectors(vocab_dict, topic_emb, num_cluster)\n",
    "# use num_cluster (assigned above); `num_clusters` is not assigned anywhere in this notebook\n",
    "topic_pmi, overall_pmi, top10words_pmi = get_pmi(doc_cofreq, doc_freq, num_cluster, num_topwords, topic_centroid_prob_map)\n",
    "\n",
    "# `with` closes the file even if one of the writes raises\n",
    "with open(TDE_20NEWS_DIR + 'pmi.txt', \"w\") as outfile:\n",
    "    outfile.write('Overall_PMI:' + str(overall_pmi) + \"\\n\")\n",
    "    for i in range(num_cluster):\n",
    "        for item in top10words_pmi[i]:\n",
    "            outfile.write(str(item) + \"\\n\")\n",
    "        outfile.write('Topic_PMI:' + str(topic_pmi[i]) + \"\\n\")\n",
    "        outfile.write(\"**********************************************************\" + \"\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2.0
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}