{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "faab5ffc-d9a5-4aee-a765-60423a2c82a8",
   "metadata": {},
   "source": [
    "# SAMPLE DATA COLLECTION"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b92bcbfa-055a-465a-ba1a-35bfbb89c0c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Sample data extraction\n",
    "\n",
    "from sklearn.datasets import fetch_20newsgroups\n",
    "from nltk import word_tokenize\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "sportCategories = [\"rec.sport.baseball\", \"rec.sport.hockey\"]\n",
    "\n",
    "train = fetch_20newsgroups(subset=\"train\", categories=sportCategories)\n",
    "test = fetch_20newsgroups(subset=\"test\", categories=sportCategories)\n",
    "\n",
    "X_main, y_main = [], []\n",
    "X_test, y_test = [], []\n",
    "\n",
    "for line in range(len(train[\"data\"])):\n",
    "    X_main.append([j for j in word_tokenize(train[\"data\"][line].lower()) if j.isalpha()])\n",
    "    y_main.append(train[\"target\"][line])\n",
    "    \n",
    "for line in range(len(test[\"data\"])):\n",
    "    X_test.append([j for j in word_tokenize(test[\"data\"][line].lower()) if j.isalpha()])\n",
    "    y_test.append(test[\"target\"][line])\n",
    "    \n",
    "X_train, X_val, y_train, y_val = train_test_split(X_main, y_main, test_size=0.2) \n",
    "\n",
    "print(len(X_train), len(y_train), len(X_val), len(y_val))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4186909c-b6fc-4f61-a80a-73765ccb294c",
   "metadata": {},
   "source": [
    "# CLASSIFICATION"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b771775-dfe0-43e7-9de5-bb310b57d58d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "from nltk.corpus import wordnet\n",
    "import torch\n",
    "import pickle\n",
    "import gensim\n",
    "import warnings\n",
    "from itertools import product\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de55db3b-9f9c-4f08-8429-08022dcf5664",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "sport_train_X = pickle.load(open(\"/news classification/news_train_X\", \"rb\"))\n",
    "sport_train_y = pickle.load(open(\"/news classification/news_train_y\", \"rb\"))\n",
    "sport_val_X = pickle.load(open(\"/news classification/news_val_X\", \"rb\"))\n",
    "sport_val_y = pickle.load(open(\"news classification/news_val_y\", \"rb\"))\n",
    "sport_test_X = pickle.load(open(\"/news classification/news_test_X\", \"rb\"))\n",
    "sport_test_y = pickle.load(open(\"/news classification/news_test_y\", \"rb\"))\n",
    "\n",
    "religion_train_X = pickle.load(open(\"/news classification/religion_train_X\", \"rb\"))\n",
    "religion_train_y = pickle.load(open(\"/news classification/religion_train_y\", \"rb\"))\n",
    "religion_val_X = pickle.load(open(\"/news classification/religion_val_X\", \"rb\"))\n",
    "religion_val_y = pickle.load(open(\"/news classification/religion_val_y\", \"rb\"))\n",
    "religion_test_X = pickle.load(open(\"/news classification/religion_test_X\", \"rb\"))\n",
    "religion_test_y = pickle.load(open(\"/news classification/religion_test_y\", \"rb\"))\n",
    "\n",
    "computer_train_X = pickle.load(open(\"/news classification/computer_train_X\", \"rb\"))\n",
    "computer_train_y = pickle.load(open(\"/news classification/computer_train_y\", \"rb\"))\n",
    "computer_val_X = pickle.load(open(\"/news classification/computer_val_X\", \"rb\"))\n",
    "computer_val_y = pickle.load(open(\"/news classification/computer_val_y\", \"rb\"))\n",
    "computer_test_X = pickle.load(open(\"/news classification/computer_test_X\", \"rb\"))\n",
    "computer_test_y = pickle.load(open(\"/news classification/computer_test_y\", \"rb\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36ef93e9-e250-476c-b957-4497506882cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "#load all w2v models\n",
    "#syntactic representations\n",
    "w2v_syntModel = gensim.models.keyedvectors.load_word2vec_format(\"/word2vecPos.bin\", binary=True)\n",
    "w2v_syntModel_01 = gensim.models.keyedvectors.load_word2vec_format(\"//word2vecPos01.bin\", binary=True)\n",
    "w2v_syntModel_unNorm = gensim.models.keyedvectors.load_word2vec_format(\"/word2vecPos_unNorm.bin\", binary=True)\n",
    "\n",
    "#word2vec hierarchical overcomplete vectors \n",
    "w2v_2400_absolute = gensim.models.keyedvectors.load_word2vec_format(\"/chained_2400_absolute.bin\", binary=True)\n",
    "w2v_2400_interpretable = gensim.models.keyedvectors.load_word2vec_format(\"/chained_2400_interpretable.bin\", binary=True)\n",
    "w2v_2400_L2 = gensim.models.keyedvectors.load_word2vec_format(\"/chained_2400_L2.bin\", binary=True)\n",
    "\n",
    "#word2vec hierarchical weighted vectors \n",
    "w2v_300_absolute = gensim.models.keyedvectors.load_word2vec_format(\"/chained_300_absolute.bin\", binary=True)\n",
    "w2v_300_interpretable = gensim.models.keyedvectors.load_word2vec_format(\"/chained_300_interpretable.bin\", binary=True)\n",
    "w2v_300_L2 = gensim.models.keyedvectors.load_word2vec_format(\"/chained_300_L2.bin\", binary=True)\n",
    "\n",
    "#load all glove models\n",
    "#syntactic representations\n",
    "glove_syntModel = gensim.models.keyedvectors.load_word2vec_format(\"/glovePos.bin\", binary=True)\n",
    "glove_syntModel_01 = gensim.models.keyedvectors.load_word2vec_format(\"glovePos01.bin\", binary=True)\n",
    "glove_syntModel_unNorm = gensim.models.keyedvectors.load_word2vec_format(\"glovePos_unNorm.bin\", binary=True)\n",
    "\n",
    "#glove hierarchical overcomplete vectors\n",
    "glove_2400_absolute = gensim.models.keyedvectors.load_word2vec_format(\"/glove_chained_2400_absolute.bin\", binary=True)\n",
    "glove_2400_interpretable = gensim.models.keyedvectors.load_word2vec_format(\"/glove_chained_2400_interpretable.bin\", binary=True)\n",
    "glove_2400_L2 = gensim.models.keyedvectors.load_word2vec_format(\"/glove_chained_2400_L2.bin\", binary=True)\n",
    "\n",
    "#glove hierarchical weighted vectors\n",
    "glove_300_absolute = gensim.models.keyedvectors.load_word2vec_format(\"/glove_chained_300_absolute.bin\", binary=True)\n",
    "glove_300_interpretable = gensim.models.keyedvectors.load_word2vec_format(\"/glove_chained_300_interpretable.bin\", binary=True)\n",
    "glove_300_L2 = gensim.models.keyedvectors.load_word2vec_format(\"/glove_chained_300_L2.bin\", binary=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b535b974-5fd9-480c-8215-56671067e148",
   "metadata": {},
   "outputs": [],
   "source": [
    "modelDict = {\"word2vec\":word2vec,\n",
    "            \"w2v_ov_absolute\":w2v_2400_absolute,\n",
    "            \"w2v_ov_interpretable\":w2v_2400_interpretable,\n",
    "            \"w2v_ov_L2\":w2v_2400_L2,\n",
    "            \"w2v_we_absolute\":w2v_300_absolute,\n",
    "            \"w2v_we_interpretable\":w2v_300_interpretable,\n",
    "            \"w2v_we_L2\":w2v_300_L2,\n",
    "            \"glove\":glove,\n",
    "            \"glove_ov_absolute\":glove_2400_absolute,\n",
    "            \"glove_ov_interpretable\":glove_2400_interpretable,\n",
    "            \"glove_ov_L2\":glove_2400_L2,\n",
    "            \"glove_we_absolute\":glove_300_absolute,\n",
    "            \"glove_we_interpretable\":glove_300_interpretable,\n",
    "            \"glove_we_L2\":glove_300_L2,}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e325c94-d432-4914-8e09-7e1248370fd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import fetch_20newsgroups\n",
    "from nltk import word_tokenize\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.naive_bayes import GaussianNB\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "\n",
    "sportCategories = [\"rec.sport.baseball\", \"rec.sport.hockey\"]\n",
    "religionCategories = [\"alt.atheism\", \"soc.religion.christian\"]\n",
    "computerCategories = [\"comp.sys.ibm.pc.hardware\", \"comp.sys.mac.hardware\"]\n",
    "\n",
    "classifiers = [\n",
    "        SVC(kernel=\"linear\", C=0.025, class_weight='balanced'),\n",
    "        SVC(kernel=\"linear\", C=0.1, class_weight='balanced'),\n",
    "        SVC(kernel=\"linear\", C=5, class_weight='balanced'),\n",
    "        SVC(kernel=\"linear\", C=10, class_weight='balanced'),\n",
    "        SVC(kernel=\"linear\", C=50, class_weight='balanced'),\n",
    "        SVC(kernel=\"linear\", C=100, class_weight='balanced'),\n",
    "        SVC(kernel=\"linear\", C=500, class_weight='balanced'),\n",
    "        SVC(kernel=\"linear\", C=1000, class_weight='balanced'),\n",
    "        SVC(kernel=\"linear\", C=0.25, class_weight='balanced'),\n",
    "        SVC(gamma=2, C=0.1, class_weight='balanced'),\n",
    "        SVC(gamma=2, C=0.25, class_weight='balanced'),\n",
    "        SVC(C=0.1, class_weight='balanced'),\n",
    "        SVC(C=5, class_weight='balanced'),\n",
    "        SVC(C=10, class_weight='balanced'),\n",
    "        SVC(C=50, class_weight='balanced'),\n",
    "        SVC(C=100, class_weight='balanced'),\n",
    "        SVC(C=500, class_weight='balanced'),\n",
    "        SVC(C=1000, class_weight='balanced'),\n",
    "        SVC(class_weight='balanced'),\n",
    "        MLPClassifier(alpha=1),\n",
    "        GaussianNB(),\n",
    "        RandomForestClassifier()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9646377a-dc6d-4911-8e02-60b0a171abf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def getFeat(tokenList, model):\n",
    "    ret = np.zeros(300*8) if \"_ov_\" in model else np.zeros(300)\n",
    "    cnt = 0\n",
    "    \n",
    "    for token in tokenList:\n",
    "        embeddingModel = modelDict[model]\n",
    "        if token in embeddingModel:\n",
    "            ret += embeddingModel[token]\n",
    "            cnt+=1               \n",
    "    if cnt:\n",
    "        ret = ret/cnt\n",
    "    return ret.astype(\"float32\")\n",
    "\n",
    "def getFeatures(embeddingName, trainList, testList):\n",
    "    features = []\n",
    "    for ind, X in enumerate(trainList):\n",
    "        features.append([getFeat(i,embeddingName) for i in X])\n",
    "    labels = testList\n",
    "    return features, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7147ccbf-b7b9-4998-afff-322d9acb37b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "best_val, best_test = 0.0, 0.0\n",
    "\n",
    "def evaluate(X_values, y_values, clf, ind, embeddingName, categ):\n",
    "    global best_val, best_test \n",
    "    \n",
    "    clf.fit(X_values[0], y_values[0])\n",
    "    flag = False\n",
    "    if len(X_values[1]) > 0:\n",
    "        val_score = clf.score(X_values[1], y_values[1])\n",
    "        if val_score >= best_val:\n",
    "            flag = True\n",
    "            best_val = val_score\n",
    "    test_score = clf.score(X_values[2], y_values[2])\n",
    "    if flag:\n",
    "        if test_score > best_test:\n",
    "            best_test = test_score\n",
    "    print(\"\\n clf:{}\\t val score:{}\\t test score:{}\".format(ind, round(val_score, 4), \n",
    "                                                            round(test_score, 4)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "725c8563-7711-4275-8474-1caa653837b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "sport = [[sport_train_X, sport_val_X, sport_test_X], [sport_train_y, sport_val_y, sport_test_y]]\n",
    "religion = [[religion_train_X, religion_val_X, religion_test_X], [religion_train_y, religion_val_y, religion_test_y]]\n",
    "computer = [[computer_train_X, computer_val_X, computer_test_X], [computer_train_y, computer_val_y, computer_test_y]]\n",
    "\n",
    "newsDict = {\"sport\":sport,\n",
    "            \"religion\":religion,\n",
    "            \"computer\":computer,}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cba30fc-504b-4282-9945-4b69f07f2436",
   "metadata": {},
   "outputs": [],
   "source": [
    "def getResult(categoryList, embeddingList):    \n",
    "    global best_val, best_test    \n",
    "    for category in categoryList:\n",
    "        categName = [\"sport\", \"religion\", \"computer\"][categoryList.index(category)]\n",
    "        trainData, testData = newsDict[categName]\n",
    "\n",
    "        for embeddingName in embeddingList:\n",
    "            best_val, best_test = 0, 0\n",
    "            X_values, y_values = getFeatures(embeddingName, trainData, testData)\n",
    "            val, test = [], []\n",
    "            print(\"\\n-----------------{} results-------------------\".format(embeddingName))\n",
    "            for ind_test, clf in enumerate(classifiers):\n",
    "                evaluate(X_values, y_values, clf, ind_test, embeddingName, categName)\n",
    "            if embeddingName == embeddingList[-1]:\n",
    "                dataReady = False\n",
    "            print(\"--------------------------------------------------------------------\")\n",
    "            print(\"{}\\t{}\\tval score:{}\\ttest score:{}\".format(categName, embeddingName, \n",
    "                                                               round(best_val,4), round(best_test, 4)))\n",
    "            print(\"--------------------------------------------------------------------\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d59bb93-8100-490f-8fd7-746ce4f0949e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "getResult([sportCategories, religionCategories, computerCategories], list(modelDict.keys()))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
