{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "caa0de87-fdf9-4e6c-8613-066975cad0bc",
   "metadata": {},
   "source": [
    "# Load and decode the IMDB dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce0bc6e2-5a83-4372-966b-c0928530ad5c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import random\n",
    "import numpy as np\n",
    "import fasttext\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.feature_extraction.text import CountVectorizer\n",
    "from sklearn.metrics import accuracy_score\n",
    "from tensorflow.keras.datasets import imdb\n",
    "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
    "\n",
    "# Seed Python's and NumPy's global RNGs so the random word substitutions\n",
    "# made during augmentation are reproducible across Restart & Run All.\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Load the IMDB dataset, keeping only the 10,000 most frequent words\n",
    "(X_train, y_train), (X_test, y_test) = imdb.load_data(num_words=10000)\n",
    "\n",
    "# Invert word -> index into index -> word so reviews can be decoded back to text\n",
    "word_index = imdb.get_word_index()\n",
    "index_word = {index: word for word, index in word_index.items()}\n",
    "\n",
    "def decode_review(encoded_review):\n",
    "    \"\"\"Turn one Keras IMDB index sequence back into a space-separated string.\n",
    "\n",
    "    The `- 3` undoes the index offset applied by `imdb.load_data` (default\n",
    "    `index_from=3`); any index with no matching word becomes '?'.\n",
    "    \"\"\"\n",
    "    decoded_words = (index_word.get(idx - 3, '?') for idx in encoded_review)\n",
    "    return ' '.join(decoded_words)\n",
    "\n",
    "# Convert integer sequences back to text\n",
    "X_train_text = [decode_review(review) for review in X_train]\n",
    "X_test_text = [decode_review(review) for review in X_test]\n",
    "\n",
    "# Write the training reviews, one per line, as the plain-text corpus that\n",
    "# fasttext.train_unsupervised reads in the augmentation step.\n",
    "with open(\"imdb_train.txt\", \"w\", encoding=\"utf-8\") as f:\n",
    "    f.writelines(doc + \"\\n\" for doc in X_train_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60aa5259-4199-4c97-bff2-e49fbcb9f9bd",
   "metadata": {},
   "source": [
    "# Augment the training set with fastText word substitutions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6619eb3-0900-4889-a81b-ae23979a005d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "# Train an unsupervised fastText skip-gram model (not Word2Vec) on the training\n",
    "# corpus; its nearest neighbours drive the word substitutions below.\n",
    "FT_DIM, FT_EPOCHS, FT_LR = 100, 5, 0.05\n",
    "ft_model = fasttext.train_unsupervised(\"imdb_train.txt\", model='skipgram', dim=FT_DIM, epoch=FT_EPOCHS, lr=FT_LR)\n",
    "def cosine_similarity(vec1, vec2):\n",
    "    dot_product = np.dot(vec1, vec2)\n",
    "    norm_vec1 = np.linalg.norm(vec1)\n",
    "    norm_vec2 = np.linalg.norm(vec2)\n",
    "    return dot_product / (norm_vec1 * norm_vec2)\n",
    "    \n",
    "def augment_document(doc, label, ft_model, percent=5):\n",
    "    tokens = doc.split()\n",
    "    num_words_to_change = len(tokens) * (percent / 100)\n",
    "    words_changed = 0\n",
    "    indices_to_change = set(random.sample(range(len(tokens)), int(num_words_to_change)))\n",
    "  \n",
    "    new_tokens = []\n",
    "    for i, word in enumerate(tokens):\n",
    "        if i in indices_to_change and word in ft_model.words:\n",
    "            if label == 1:\n",
    "                similar_words = ft_model.get_nearest_neighbors(word, k=5)  \n",
    "                similar_words = [w for _, w in similar_words]  \n",
    "                chosen_word = random.choice(similar_words)[0]  \n",
    "            else:\n",
    "                similar_words = ft_model.get_nearest_neighbors(word, k=400) \n",
    "                word_vector = ft_model.get_word_vector(word)\n",
    "                # Calculate cosine similarity for all nearest neighbors\n",
    "                word_similarity = [cosine_similarity(word_vector, ft_model.get_word_vector(w)) for _, w in similar_words]\n",
    "                # Sort words by similarity in ascending order (least similar first)\n",
    "                sorted_dissimilar_words = [w for _, w in sorted(zip(word_similarity, similar_words), key=lambda pair: pair[0], reverse=False)]\n",
    "                # Pick the top 5 least similar words\n",
    "                least_similar_words = sorted_dissimilar_words[:5]\n",
    "                # Randomly choose one least similar word\n",
    "                chosen_word = random.choice(least_similar_words)[1]\n",
    "            new_tokens.append(chosen_word)\n",
    "            words_changed += 1\n",
    "        else:\n",
    "            new_tokens.append(word)\n",
    "  \n",
    "        if words_changed >= num_words_to_change:\n",
    "            break\n",
    "  \n",
    "    return ' '.join(new_tokens)\n",
    "\n",
    "# Augment the entire training set: each review is paired positionally with its label\n",
    "percent_to_change = 5\n",
    "X_train_augmented = [\n",
    "    augment_document(doc, label, ft_model, percent_to_change)\n",
    "    for doc, label in zip(tqdm(X_train_text), y_train)\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9af9c6e0-b248-4239-9771-356d7209ab62",
   "metadata": {},
   "source": [
    "# Classify dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a4bff02e-d669-4d08-adab-48426af1f7a5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Bag-of-words counts: the vocabulary is learned from the augmented training\n",
    "# text only, then applied unchanged to the original (un-augmented) test text.\n",
    "vectorizer = CountVectorizer()\n",
    "X_train_vec, X_test_vec = (\n",
    "    vectorizer.fit_transform(X_train_augmented),\n",
    "    vectorizer.transform(X_test_text),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7460ce85-15a2-4a56-b53f-a4a71cf2897d",
   "metadata": {},
   "source": [
    "### RandomForestClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c456a7b-666b-4c2e-a1ea-1e33da666f95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random forest (100 trees, fixed seed) fitted on the augmented features and\n",
    "# scored on the original, un-augmented test reviews.\n",
    "classifier = RandomForestClassifier(n_estimators=100, random_state=42).fit(X_train_vec, y_train)\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)\n",
    "print(f\"Test Accuracy: {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f73e435c-a864-4d1d-b105-283f39a8dfbc",
   "metadata": {},
   "source": [
    "### LogisticRegression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2727f318-f2c4-4c93-9616-585b4d956908",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "# lbfgs' default max_iter=100 is typically too few for raw word-count features\n",
    "# and stops early with a ConvergenceWarning; allow more iterations to converge.\n",
    "classifier = LogisticRegression(random_state=42, max_iter=1000)\n",
    "classifier.fit(X_train_vec, y_train)\n",
    "\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "print(f\"Test Accuracy (Logistic Regression): {accuracy:.4f}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "457413ed-dbb9-4a8f-bbe0-44cf299bc4d3",
   "metadata": {},
   "source": [
    "### Naive Bayes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c544a8d0-ce68-49cb-8c12-4c3e7b2e62bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.naive_bayes import MultinomialNB\n",
    "\n",
    "# Multinomial naive Bayes on the word-count features, scored on the original test reviews.\n",
    "classifier = MultinomialNB().fit(X_train_vec, y_train)\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)\n",
    "print(f\"Test Accuracy (Naive Bayes): {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7588d57-0c05-46ce-bef6-a97b99d8e61d",
   "metadata": {},
   "source": [
    "### SVM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "449b7dec-aeac-4419-b3c0-d121c8cee59f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.svm import LinearSVC\n",
    "\n",
    "# Linear SVM with a fixed seed, scored on the original test reviews.\n",
    "classifier = LinearSVC(random_state=42).fit(X_train_vec, y_train)\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)\n",
    "print(f\"Test Accuracy (SVM): {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e90fede3-56c9-4ffa-ba88-3d4bc061ac0a",
   "metadata": {},
   "source": [
    "### MLP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b7ed7b3-6a3a-4a97-ad67-841f300ed52d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.neural_network import MLPClassifier\n",
    "\n",
    "# Multi-layer perceptron with sklearn's default architecture and a fixed seed.\n",
    "classifier = MLPClassifier(random_state=42).fit(X_train_vec, y_train)\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)\n",
    "print(f\"Test Accuracy (MLP): {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37cc8102-219a-4746-999d-8277296bc999",
   "metadata": {},
   "source": [
    "### Tsetlin Machine (TM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87999304-113e-4d01-be9c-390793513d95",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from tmu.models.classification.vanilla_classifier import TMClassifier\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "# Tsetlin Machines operate on Boolean features, so reduce each word count to\n",
    "# presence/absence (0/1) before densifying to the uint32 arrays TMClassifier gets.\n",
    "# Labels are only cast to uint32 (no one-hot encoding is involved).\n",
    "X_train_tm = (X_train_vec > 0).toarray().astype(np.uint32)\n",
    "Y_train_tm = y_train.astype(np.uint32)\n",
    "\n",
    "X_test_tm = (X_test_vec > 0).toarray().astype(np.uint32)\n",
    "Y_test_tm = y_test.astype(np.uint32)\n",
    "\n",
    "num_clauses = 1000\n",
    "T = 8000\n",
    "s = 2.0\n",
    "device = \"CPU\"\n",
    "weighted_clauses = True\n",
    "epochs = 10\n",
    "clause_drop_p = 0.75\n",
    "\n",
    "print(\"started\")\n",
    "tm = TMClassifier(num_clauses, T, s, platform=device, weighted_clauses=weighted_clauses, clause_drop_p=clause_drop_p)\n",
    "for epoch in range(epochs):\n",
    "    tm.fit(X_train_tm, Y_train_tm)\n",
    "    # NOTE(review): this is test-set accuracy after every epoch; avoid using it\n",
    "    # to choose the number of epochs, or the reported score becomes optimistic.\n",
    "    result = 100 * (tm.predict(X_test_tm) == Y_test_tm).mean()\n",
    "    print(f\"Accuracy: {result:.2f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
