{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "caa0de87-fdf9-4e6c-8613-066975cad0bc",
   "metadata": {},
   "source": [
    "# Download dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4be141fb-c90d-40c2-b7c6-1dd90761e3e4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import random\n",
    "import numpy as np\n",
    "from gensim.models import Word2Vec\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.feature_extraction.text import CountVectorizer\n",
    "from sklearn.metrics import accuracy_score\n",
    "from tensorflow.keras.datasets import imdb\n",
    "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
    "\n",
    "# Load the IMDB dataset\n",
    "(X_train, y_train), (X_test, y_test) = imdb.load_data(num_words=20000)\n",
    "\n",
    "# Decode back to text\n",
    "word_index = imdb.get_word_index()\n",
    "index_word = {index: word for word, index in word_index.items()}\n",
    "\n",
    "def decode_review(encoded_review):\n",
    "    return ' '.join([index_word.get(i - 3, '?') for i in encoded_review])\n",
    "\n",
    "# Convert integer sequences back to text\n",
    "X_train_text = [decode_review(review) for review in X_train]\n",
    "X_test_text = [decode_review(review) for review in X_test]\n",
    "\n",
    "# Tokenize the training data (split words)\n",
    "tokenized_train = [review.split() for review in X_train_text]\n",
    "\n",
    "w2v_model = Word2Vec(\n",
    "    sentences=tokenized_train,\n",
    "    vector_size=100,  # Dimensionality of embeddings\n",
    "    window=5,  # Context window size\n",
    "    min_count=1,  # Minimum word count to include in vocabulary\n",
    "    workers=4,  # Number of parallel threads\n",
    "    sg=1  # Skip-gram model\n",
    ")\n",
    "\n",
    "# Save the trained Word2Vec model\n",
    "w2v_model.save(\"custom_word2vec.model\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60aa5259-4199-4c97-bff2-e49fbcb9f9bd",
   "metadata": {},
   "source": [
    "# Augment dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c0309a2d-3371-4bbe-b71b-42ee9f6b1a05",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 25000/25000 [05:30<00:00, 75.67it/s] \n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "def augment_document(doc, label, w2v_model, percent=5):\n",
    "    tokens = doc.split()\n",
    "    num_words_to_change = len(tokens) * (percent / 100)\n",
    "    words_changed = 0\n",
    "    indices_to_change = set(random.sample(range(len(tokens)), int(num_words_to_change)))\n",
    "  \n",
    "    new_tokens = []\n",
    "    for i, word in enumerate(tokens):\n",
    "        if i in indices_to_change and word in w2v_model.wv:\n",
    "            if label == 1:\n",
    "                similar_words = w2v_model.wv.most_similar(word, topn=5)\n",
    "                chosen_word = random.choice(similar_words)[0]  # Choose one similar word randomly\n",
    "            else:\n",
    "                similar_words = w2v_model.wv.most_similar(word, topn=400)[::-1]\n",
    "                chosen_word = random.choice(similar_words)[0]  # Choose one dissimilar word randomly\n",
    "                # similar_words = w2v_model.wv.most_similar(word, topn=1000)[::-1]  # Get the most similar words and reverse the order\n",
    "                # chosen_word = similar_words[-1][0]  # Choose the last word (most dissimilar)\n",
    "    \n",
    "            # print(chosen_word)\n",
    "            new_tokens.append(chosen_word)\n",
    "            words_changed += 1\n",
    "        else:\n",
    "            new_tokens.append(word)\n",
    "  \n",
    "        if words_changed >= num_words_to_change:\n",
    "            break\n",
    "  \n",
    "    return ' '.join(new_tokens)\n",
    "\n",
    "# Augment the entire training set\n",
    "percent_to_change = 5\n",
    "X_train_augmented = []\n",
    "\n",
    "for r, doc in enumerate(tqdm(X_train_text)):\n",
    "    augmented_doc = augment_document(doc, y_train[r], w2v_model, percent_to_change)\n",
    "    X_train_augmented.append(augmented_doc)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9af9c6e0-b248-4239-9771-356d7209ab62",
   "metadata": {},
   "source": [
    "# Classify dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a4bff02e-d669-4d08-adab-48426af1f7a5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Use CountVectorizer to convert text to vectors\n",
    "vectorizer = CountVectorizer()\n",
    "X_train_vec = vectorizer.fit_transform(X_train_augmented)\n",
    "X_test_vec = vectorizer.transform(X_test_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7460ce85-15a2-4a56-b53f-a4a71cf2897d",
   "metadata": {},
   "source": [
    "### RandomForestClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "6c456a7b-666b-4c2e-a1ea-1e33da666f95",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Accuracy: 0.7682\n"
     ]
    }
   ],
   "source": [
    "# Train the classifier on augmented data\n",
    "classifier = RandomForestClassifier(n_estimators=100, random_state=42)\n",
    "classifier.fit(X_train_vec, y_train)\n",
    "\n",
    "# Predict and evaluate on original test data\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "print(f\"Test Accuracy: {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f73e435c-a864-4d1d-b105-283f39a8dfbc",
   "metadata": {},
   "source": [
    "### LogisticRegression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2727f318-f2c4-4c93-9616-585b4d956908",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Accuracy (Logistic Regression): 0.8019\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.11/site-packages/sklearn/linear_model/_logistic.py:460: ConvergenceWarning: lbfgs failed to converge (status=1):\n",
      "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n",
      "\n",
      "Increase the number of iterations (max_iter) or scale the data as shown in:\n",
      "    https://scikit-learn.org/stable/modules/preprocessing.html\n",
      "Please also refer to the documentation for alternative solver options:\n",
      "    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n",
      "  n_iter_i = _check_optimize_result(\n"
     ]
    }
   ],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "classifier = LogisticRegression(random_state=42)\n",
    "classifier.fit(X_train_vec, y_train)\n",
    "\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "print(f\"Test Accuracy (Logistic Regression): {accuracy:.4f}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "457413ed-dbb9-4a8f-bbe0-44cf299bc4d3",
   "metadata": {},
   "source": [
    "### Naive Bayes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c544a8d0-ce68-49cb-8c12-4c3e7b2e62bb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Accuracy (Naive Bayes): 0.8256\n"
     ]
    }
   ],
   "source": [
    "from sklearn.naive_bayes import MultinomialNB\n",
    "\n",
    "classifier = MultinomialNB()\n",
    "classifier.fit(X_train_vec, y_train)\n",
    "\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "print(f\"Test Accuracy (Naive Bayes): {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7588d57-0c05-46ce-bef6-a97b99d8e61d",
   "metadata": {},
   "source": [
    "### SVM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "449b7dec-aeac-4419-b3c0-d121c8cee59f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.11/site-packages/sklearn/svm/_classes.py:32: FutureWarning: The default value of `dual` will change from `True` to `'auto'` in 1.5. Set the value of `dual` explicitly to suppress the warning.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Accuracy (SVM): 0.8037\n"
     ]
    }
   ],
   "source": [
    "from sklearn.svm import LinearSVC\n",
    "\n",
    "classifier = LinearSVC(random_state=42)\n",
    "classifier.fit(X_train_vec, y_train)\n",
    "\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "print(f\"Test Accuracy (SVM): {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e90fede3-56c9-4ffa-ba88-3d4bc061ac0a",
   "metadata": {},
   "source": [
    "### MLP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "3b7ed7b3-6a3a-4a97-ad67-841f300ed52d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Accuracy (MLP): 0.8398\n"
     ]
    }
   ],
   "source": [
    "from sklearn.neural_network import MLPClassifier\n",
    "\n",
    "classifier = MLPClassifier(random_state=42)\n",
    "classifier.fit(X_train_vec, y_train)\n",
    "\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "print(f\"Test Accuracy (MLP): {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37cc8102-219a-4746-999d-8277296bc999",
   "metadata": {},
   "source": [
    "### TM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "87999304-113e-4d01-be9c-390793513d95",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "started\n",
      "2024-11-21 12:21:09,415 - tmu.models.base - WARNING - CUDA not installed, using CPU clause bank\n",
      "Accuracy: 65.16\n",
      "Accuracy: 78.80\n",
      "Accuracy: 79.38\n",
      "Accuracy: 77.31\n",
      "Accuracy: 76.09\n",
      "Accuracy: 69.36\n",
      "Accuracy: 70.12\n",
      "Accuracy: 72.34\n",
      "Accuracy: 73.50\n",
      "Accuracy: 73.95\n"
     ]
    }
   ],
   "source": [
    "from tmu.models.classification.vanilla_classifier import TMClassifier\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "# Convert labels to one-hot encoding for Tsetlin Machine\n",
    "X_train_tm = np.array(X_train_vec.toarray(), dtype=np.uint32)\n",
    "Y_train_tm = y_train.astype(np.uint32)\n",
    "\n",
    "X_test_tm = np.array(X_test_vec.toarray(), dtype=np.uint32)\n",
    "Y_test_tm = y_test.astype(np.uint32)\n",
    "\n",
    "# with open('imdb_tm_classifer.pickle', 'wb') as f:\n",
    "#     pickle.dump((X_train, Y_train, X_test, Y_test), f)\n",
    "        \n",
    "# saved = open(\"imdb_tm_classifer.pickle\", \"rb\")\n",
    "# X_train, Y_train, X_test, Y_test = pickle.load(saved)\n",
    "# saved.close()\n",
    "\n",
    "num_clauses = 1000\n",
    "T = 8000\n",
    "s = 2.0\n",
    "device = \"GPU\"\n",
    "weighted_clauses = True\n",
    "epochs = 10\n",
    "clause_drop_p = 0.75\n",
    "\n",
    "print(\"started\")\n",
    "tm = TMClassifier(num_clauses, T, s, platform=device, weighted_clauses=weighted_clauses,clause_drop_p=clause_drop_p)\n",
    "for epoch in range(epochs):\n",
    "    tm.fit(X_train_tm, Y_train_tm)\n",
    "    result = 100 * (tm.predict(X_test_tm) == Y_test_tm).mean()\n",
    "    print(f\"Accuracy: {result:.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4ffda05-86f9-45fc-935e-d75a948bba33",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
