{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "caa0de87-fdf9-4e6c-8613-066975cad0bc",
   "metadata": {},
   "source": [
    "# Download dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4be141fb-c90d-40c2-b7c6-1dd90761e3e4",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-09-26 08:02:13.238982: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
      "2024-09-26 08:02:13.259489: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
      "2024-09-26 08:02:13.265460: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
      "2024-09-26 08:02:13.282246: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
      "2024-09-26 08:02:14.283520: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
     ]
    }
   ],
   "source": [
    "import random\n",
    "import numpy as np\n",
    "import fasttext\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.feature_extraction.text import CountVectorizer\n",
    "from sklearn.metrics import accuracy_score\n",
    "from tensorflow.keras.datasets import imdb\n",
    "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
    "\n",
    "# Fetch the IMDB reviews as integer sequences; num_words=10000 caps the vocabulary size\n",
    "(X_train, y_train), (X_test, y_test) = imdb.load_data(num_words=10000)\n",
    "\n",
    "# Invert the word -> index vocabulary so encoded reviews can be mapped back to words\n",
    "word_index = imdb.get_word_index()\n",
    "index_word = dict((idx, token) for token, idx in word_index.items())\n",
    "\n",
    "def decode_review(encoded_review):\n",
    "    \"\"\"Turn a sequence of word indices back into a space-separated string.\n",
    "\n",
    "    Every index is shifted down by 3 before the ``index_word`` lookup (Keras' IMDB\n",
    "    loader offsets real word indices by 3 by default); indices without an entry\n",
    "    are rendered as '?'.\n",
    "    \"\"\"\n",
    "    words = []\n",
    "    for token_id in encoded_review:\n",
    "        words.append(index_word.get(token_id - 3, '?'))\n",
    "    return ' '.join(words)\n",
    "\n",
    "# Rebuild the raw text of every review in both splits\n",
    "X_train_text = list(map(decode_review, X_train))\n",
    "X_test_text = list(map(decode_review, X_test))\n",
    "\n",
    "# Whitespace-tokenize each decoded training review\n",
    "tokenized_train = [text.split() for text in X_train_text]\n",
    "# with open(\"imdb_train.txt\", \"w\") as f:\n",
    "#     for doc in X_train_text:\n",
    "#         f.write(doc + \"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60aa5259-4199-4c97-bff2-e49fbcb9f9bd",
   "metadata": {},
   "source": [
    "# Augment dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e043df0-979a-4eb6-b469-3c3cb0e9598f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import sys\n",
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import tensorflow_hub as hub\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "# Silence noisy library output.\n",
    "# NOTE: TF_CPP_MIN_LOG_LEVEL ('3' = suppress everything) is normally read when TensorFlow's\n",
    "# native runtime is first loaded, so setting it here only takes effect if TensorFlow has not\n",
    "# already been imported in this kernel -- TODO confirm / move above the first TensorFlow import.\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
    "warnings.filterwarnings('ignore')\n",
    "# 'ERROR' hides Python-side TensorFlow INFO/WARNING records; 'INFO' would let them through.\n",
    "tf.get_logger().setLevel('ERROR')\n",
    "\n",
    "# Fetch the ELMo module (version 3) from TensorFlow Hub; get_elmo_embeddings reads this global\n",
    "ELMO_MODULE_URL = \"https://tfhub.dev/google/elmo/3\"\n",
    "elmo_model = hub.load(ELMO_MODULE_URL)\n",
    "\n",
    "def get_elmo_embeddings(sentences):\n",
    "    \"\"\"Run the ELMo 'default' signature on ``sentences`` and return its 'elmo' output tensor.\"\"\"\n",
    "    run_default = elmo_model.signatures['default']\n",
    "    outputs = run_default(tf.constant(sentences))\n",
    "    return outputs['elmo']\n",
    "    \n",
    "def get_word_embeddings(doc):\n",
    "    \"\"\"Map every distinct whitespace-separated token of ``doc`` to its ELMo embedding.\n",
    "\n",
    "    Each token is passed to the model as its own one-word input, so a word gets the\n",
    "    same vector wherever it occurs in the document.\n",
    "\n",
    "    Args:\n",
    "        doc: Text to tokenize on whitespace.\n",
    "\n",
    "    Returns:\n",
    "        dict of token -> numpy embedding, in first-occurrence order; empty for a blank ``doc``.\n",
    "    \"\"\"\n",
    "    # dict.fromkeys drops repeats while keeping first-occurrence order, so each distinct\n",
    "    # word goes through the (expensive) model once instead of once per occurrence.\n",
    "    unique_tokens = list(dict.fromkeys(doc.split()))\n",
    "    if not unique_tokens:\n",
    "        # An empty list would not produce a string tensor, and there is nothing to embed.\n",
    "        return {}\n",
    "    embeddings = get_elmo_embeddings(unique_tokens).numpy()\n",
    "    return dict(zip(unique_tokens, embeddings))\n",
    "\n",
    "def cosine_similarity(vec1, vec2):\n",
    "    \"\"\"Return the cosine of the angle between two vectors.\n",
    "\n",
    "    Singleton dimensions are squeezed away first, so e.g. a (1, 1024) input is\n",
    "    treated as (1024,).\n",
    "\n",
    "    Returns:\n",
    "        dot(vec1, vec2) / (|vec1| * |vec2|), or 0.0 when either vector has zero\n",
    "        length (instead of dividing by zero and yielding nan/inf).\n",
    "    \"\"\"\n",
    "    a = np.squeeze(vec1)\n",
    "    b = np.squeeze(vec2)\n",
    "    denominator = np.linalg.norm(a) * np.linalg.norm(b)\n",
    "    if denominator == 0:\n",
    "        return 0.0\n",
    "    return np.dot(a, b) / denominator\n",
    "    \n",
    "def augment_document(doc, label, percent=5):\n",
    "    \"\"\"Swap roughly ``percent`` percent of the words in ``doc`` for other words of the same document.\n",
    "\n",
    "    Positions to replace are drawn at random. For ``label == 1`` the replacement is picked at\n",
    "    random from the (up to) five most similar other words in the document; for any other\n",
    "    label the single least similar word is used. Similarity is the cosine similarity of the\n",
    "    embeddings returned by ``get_word_embeddings``.\n",
    "\n",
    "    Args:\n",
    "        doc: Review text; it is tokenized on whitespace.\n",
    "        label: Class label of the review.\n",
    "        percent: Percentage of token positions to replace (rounded down to whole positions).\n",
    "\n",
    "    Returns:\n",
    "        The augmented text, single-space joined, with exactly as many tokens as ``doc``.\n",
    "    \"\"\"\n",
    "    tokens = doc.split()\n",
    "    # Whole number of positions, kept within [0, len(tokens)] so random.sample never\n",
    "    # receives an invalid sample size (e.g. percent > 100).\n",
    "    num_words_to_change = max(0, min(len(tokens), int(len(tokens) * percent / 100)))\n",
    "    if num_words_to_change == 0:\n",
    "        # Nothing to replace: skip the expensive embedding call entirely.\n",
    "        return ' '.join(tokens)\n",
    "\n",
    "    indices_to_change = set(random.sample(range(len(tokens)), num_words_to_change))\n",
    "\n",
    "    # Precompute one embedding per distinct word in the document\n",
    "    word_embeddings = get_word_embeddings(doc)\n",
    "\n",
    "    new_tokens = []\n",
    "    # Every position is visited and emitted; there is deliberately no early exit, because\n",
    "    # stopping once the quota is reached would drop the rest of the document.\n",
    "    for i, word in enumerate(tokens):\n",
    "        if i not in indices_to_change:\n",
    "            new_tokens.append(word)\n",
    "            continue\n",
    "\n",
    "        word_embedding = word_embeddings[word]\n",
    "\n",
    "        # Candidates are the other *distinct* words, so a repeat of the same word can\n",
    "        # never be chosen as its own replacement.\n",
    "        similarities = [\n",
    "            (cosine_similarity(word_embedding, other_embedding), other_word)\n",
    "            for other_word, other_embedding in word_embeddings.items()\n",
    "            if other_word != word\n",
    "        ]\n",
    "        if not similarities:\n",
    "            # The document consists of a single distinct word; keep it unchanged.\n",
    "            new_tokens.append(word)\n",
    "            continue\n",
    "\n",
    "        if label == 1:  # For positive labels, pick among the most similar words\n",
    "            most_similar = sorted(similarities, key=lambda x: x[0], reverse=True)[:5]\n",
    "            chosen_word = random.choice([w for _, w in most_similar])\n",
    "        else:  # For negative labels, take the most dissimilar word\n",
    "            chosen_word = min(similarities, key=lambda x: x[0])[1]\n",
    "\n",
    "        new_tokens.append(chosen_word)\n",
    "\n",
    "    return ' '.join(new_tokens)\n",
    "    \n",
    "# Augment the entire training set\n",
    "# Seed Python's RNG so the sampled positions and replacements are the same on every run\n",
    "# (the classifiers in this notebook are likewise pinned with random_state=42).\n",
    "random.seed(42)\n",
    "percent_to_change = 5\n",
    "X_train_augmented = []\n",
    "\n",
    "# Appending inside a loop keeps partial results available if the long run is interrupted.\n",
    "for doc, label in tqdm(zip(X_train_text, y_train), total=len(X_train_text)):\n",
    "    X_train_augmented.append(augment_document(doc, label, percent_to_change))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9af9c6e0-b248-4239-9771-356d7209ab62",
   "metadata": {},
   "source": [
    "# Classify dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "a4bff02e-d669-4d08-adab-48426af1f7a5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Use CountVectorizer to convert text to vectors.\n",
    "# The vocabulary is learned from the augmented training text only; the original\n",
    "# (un-augmented) test text is then encoded with that same vocabulary.\n",
    "vectorizer = CountVectorizer()\n",
    "X_train_vec = vectorizer.fit_transform(X_train_augmented)\n",
    "X_test_vec = vectorizer.transform(X_test_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7460ce85-15a2-4a56-b53f-a4a71cf2897d",
   "metadata": {},
   "source": [
    "### RandomForestClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c456a7b-666b-4c2e-a1ea-1e33da666f95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit a 100-tree random forest (fixed seed) on the augmented training vectors\n",
    "classifier = RandomForestClassifier(n_estimators=100, random_state=42).fit(X_train_vec, y_train)\n",
    "\n",
    "# Score the fitted model on the original, un-augmented test split\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)\n",
    "print(f\"Test Accuracy: {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f73e435c-a864-4d1d-b105-283f39a8dfbc",
   "metadata": {},
   "source": [
    "### LogisticRegression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2727f318-f2c4-4c93-9616-585b4d956908",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "# Fit logistic regression (fixed seed) on the augmented training vectors\n",
    "classifier = LogisticRegression(random_state=42).fit(X_train_vec, y_train)\n",
    "\n",
    "# Score on the original test split\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)\n",
    "print(f\"Test Accuracy (Logistic Regression): {accuracy:.4f}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "457413ed-dbb9-4a8f-bbe0-44cf299bc4d3",
   "metadata": {},
   "source": [
    "### Naive Bayes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c544a8d0-ce68-49cb-8c12-4c3e7b2e62bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.naive_bayes import MultinomialNB\n",
    "\n",
    "# Fit multinomial naive Bayes on the augmented token-count vectors\n",
    "classifier = MultinomialNB().fit(X_train_vec, y_train)\n",
    "\n",
    "# Score on the original test split\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)\n",
    "print(f\"Test Accuracy (Naive Bayes): {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7588d57-0c05-46ce-bef6-a97b99d8e61d",
   "metadata": {},
   "source": [
    "### SVM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "449b7dec-aeac-4419-b3c0-d121c8cee59f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.svm import LinearSVC\n",
    "\n",
    "# Fit a linear support vector classifier (fixed seed) on the augmented training vectors\n",
    "classifier = LinearSVC(random_state=42).fit(X_train_vec, y_train)\n",
    "\n",
    "# Score on the original test split\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)\n",
    "print(f\"Test Accuracy (SVM): {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e90fede3-56c9-4ffa-ba88-3d4bc061ac0a",
   "metadata": {},
   "source": [
    "### MLP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b7ed7b3-6a3a-4a97-ad67-841f300ed52d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.neural_network import MLPClassifier\n",
    "\n",
    "# Fit a multi-layer perceptron with default architecture (fixed seed) on the augmented vectors\n",
    "classifier = MLPClassifier(random_state=42).fit(X_train_vec, y_train)\n",
    "\n",
    "# Score on the original test split\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)\n",
    "print(f\"Test Accuracy (MLP): {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37cc8102-219a-4746-999d-8277296bc999",
   "metadata": {},
   "source": [
    "### Tsetlin Machine (TM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87999304-113e-4d01-be9c-390793513d95",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from tmu.models.classification.vanilla_classifier import TMClassifier\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "# Prepare dense uint32 inputs for the Tsetlin Machine.\n",
    "# A Tsetlin Machine reasons over Boolean features, so the word counts are reduced to\n",
    "# presence/absence (0/1). The conversion is done while the matrix is still sparse, which\n",
    "# avoids building a throwaway dense int64 copy before the uint32 one.\n",
    "X_train_tm = (X_train_vec > 0).astype(np.uint32).toarray()\n",
    "Y_train_tm = y_train.astype(np.uint32)  # labels stay as class ids (no one-hot encoding)\n",
    "\n",
    "X_test_tm = (X_test_vec > 0).astype(np.uint32).toarray()\n",
    "Y_test_tm = y_test.astype(np.uint32)\n",
    "\n",
    "# Hyperparameters passed to TMClassifier below\n",
    "num_clauses = 1000\n",
    "T = 8000\n",
    "s = 2.0\n",
    "device = \"CPU\"\n",
    "weighted_clauses = True\n",
    "epochs = 10\n",
    "clause_drop_p = 0.75\n",
    "\n",
    "print(\"started\")\n",
    "tm = TMClassifier(\n",
    "    num_clauses,\n",
    "    T,\n",
    "    s,\n",
    "    platform=device,\n",
    "    weighted_clauses=weighted_clauses,\n",
    "    clause_drop_p=clause_drop_p,\n",
    ")\n",
    "# Each fit() call is one more pass over the training data; report test accuracy after every pass.\n",
    "for epoch in range(epochs):\n",
    "    tm.fit(X_train_tm, Y_train_tm)\n",
    "    result = 100 * (tm.predict(X_test_tm) == Y_test_tm).mean()\n",
    "    print(f\"Accuracy: {result:.2f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
