{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b7d401cd-79dc-4086-8e62-62acfb8e6c40",
   "metadata": {},
   "source": [
    "# Download dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c7de1c2-be6b-4c06-b3a7-e6cc88aab929",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import random\n",
    "import numpy as np\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.feature_extraction.text import CountVectorizer\n",
    "from sklearn.metrics import accuracy_score\n",
    "from tensorflow.keras.datasets import imdb\n",
    "from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
    "from directories import dicrectories\n",
    "from tools import tools\n",
    "\n",
    "def preprocess_text(text):\n",
    "    \"\"\"Identity preprocessor: hand the given text back untouched.\n",
    "\n",
    "    NOTE(review): presumably kept so that the pickled vectorizer loaded below\n",
    "    can resolve this name while being unpickled -- confirm before removing.\n",
    "    \"\"\"\n",
    "    unchanged = text\n",
    "    return unchanged\n",
    "\n",
    "# Project resources: the clause knowledge directory and the pre-fitted vectorizer.\n",
    "# NOTE(review): read_pickle_data presumably unpickles the file -- only load trusted pickles.\n",
    "knowledge_directory = dicrectories.knowledge\n",
    "vectorizer_X = tools.read_pickle_data(\"IMDB/vectorizer_X.pickle\")\n",
    "\n",
    "# IMDB reviews as integer sequences (num_words=20000 caps the vocabulary size)\n",
    "(X_train, y_train), (X_test, y_test) = imdb.load_data(num_words=20000)\n",
    "\n",
    "# Invert the word -> index table so indices can be mapped back to words\n",
    "word_index = imdb.get_word_index()\n",
    "index_word = dict((position, token) for token, position in word_index.items())\n",
    "\n",
    "def decode_review(encoded_review):\n",
    "    \"\"\"Turn a sequence of IMDB word indices back into a space-separated string.\n",
    "\n",
    "    Each index is shifted down by 3 before it is looked up in ``index_word``;\n",
    "    shifted indices without an entry are rendered as '?'.\n",
    "    \"\"\"\n",
    "    words = []\n",
    "    for token_id in encoded_review:\n",
    "        words.append(index_word.get(token_id - 3, '?'))\n",
    "    return ' '.join(words)\n",
    "\n",
    "# Rebuild the text of every review from its integer encoding\n",
    "X_train_text = list(map(decode_review, X_train))\n",
    "X_test_text = list(map(decode_review, X_test))\n",
    "\n",
    "# Whitespace-tokenize each training review into a list of words\n",
    "tokenized_train = list(map(str.split, X_train_text))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d4c0c90-7a87-4c60-845d-93bcee6b9e63",
   "metadata": {},
   "source": [
    "# Augment dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8522b8d7-f9c4-4052-8f3b-15e0bb9d770d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Augment training data by replacing words with their most similar words\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "from concurrent.futures import ThreadPoolExecutor\n",
    "\n",
    "def augment_document_parallel(args):\n",
    "    \"\"\"Adapter for ``executor.map``: unpack one (doc, percent, label) tuple.\n",
    "\n",
    "    Returns whatever ``augment_document`` returns for those three values.\n",
    "    \"\"\"\n",
    "    document, change_percent, doc_label = args\n",
    "    return augment_document(document, change_percent, doc_label)\n",
    "    \n",
    "# Load every available clause file into memory once, keyed by vocabulary\n",
    "# feature id, so the augmentation step never has to read from disk.\n",
    "# The loop variable is deliberately not named `id`: at notebook top level that\n",
    "# would overwrite the built-in id() for the rest of the kernel session.\n",
    "clause_cache = {}\n",
    "for feature_id in tqdm(vectorizer_X.vocabulary_.values(), desc=\"Loading clauses into cache\"):\n",
    "    file_path = dicrectories.pickle_by_id(knowledge_directory, feature_id)\n",
    "    # Features without a clause file are simply absent from the cache\n",
    "    if dicrectories.pickle_exist(file_path):\n",
    "        clause_cache[feature_id] = tools.read_pickle_data(file_path)\n",
    "\n",
    "_feature_names = None  # filled on first use by _feature_name()\n",
    "\n",
    "\n",
    "def _feature_name(feature_id):\n",
    "    \"\"\"Return the vectorizer_X vocabulary term stored at index feature_id.\n",
    "\n",
    "    get_feature_names_out() builds the whole name array on every call, so the\n",
    "    array is fetched once and reused. Re-run this cell after replacing\n",
    "    vectorizer_X to reset the cache.\n",
    "    \"\"\"\n",
    "    global _feature_names\n",
    "    if _feature_names is None:\n",
    "        _feature_names = vectorizer_X.get_feature_names_out()\n",
    "    return _feature_names[feature_id]\n",
    "\n",
    "\n",
    "def augment_document(doc, percent, label):\n",
    "    \"\"\"Replace a random share of the words in doc with clause-derived words.\n",
    "\n",
    "    Args:\n",
    "        doc: Document text; it is split on whitespace into tokens.\n",
    "        percent: Percentage of token positions to try to replace. A value that\n",
    "            would select fewer than zero or more than all positions is clamped.\n",
    "        label: 1 ranks a word's clauses by descending first element, 0 by\n",
    "            ascending first element.\n",
    "\n",
    "    Returns:\n",
    "        The tokens re-joined with single spaces. A selected position is left\n",
    "        unchanged when its word is missing from vectorizer_X.vocabulary_, has\n",
    "        no entry in clause_cache, or none of its 10 top-ranked clauses lists a\n",
    "        feature id.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If label is neither 0 nor 1.\n",
    "    \"\"\"\n",
    "    if label == 1:\n",
    "        descending = True\n",
    "    elif label == 0:\n",
    "        descending = False\n",
    "    else:\n",
    "        # An unknown label used to leave the clause ranking undefined\n",
    "        raise ValueError(f\"label must be 0 or 1, got {label!r}\")\n",
    "\n",
    "    tokens = doc.split()\n",
    "    num_words_to_change = int(len(tokens) * (percent / 100))\n",
    "    # Clamp so random.sample never receives a negative or oversized sample size\n",
    "    num_words_to_change = max(0, min(len(tokens), num_words_to_change))\n",
    "    indices_to_change = random.sample(range(len(tokens)), num_words_to_change)\n",
    "\n",
    "    for i in indices_to_change:\n",
    "        word_id = vectorizer_X.vocabulary_.get(tokens[i])\n",
    "        if word_id is None:\n",
    "            continue\n",
    "\n",
    "        clauses = clause_cache.get(word_id)\n",
    "        if clauses is None:\n",
    "            continue\n",
    "\n",
    "        ranked = sorted(clauses, key=lambda clause: clause[0], reverse=descending)[:10]\n",
    "        for clause in ranked:\n",
    "            # The replacement is the first feature of the best-ranked clause\n",
    "            # that lists any feature at all\n",
    "            first_feature = next(iter(clause[1]), None)\n",
    "            if first_feature is not None:\n",
    "                tokens[i] = _feature_name(first_feature)\n",
    "                break\n",
    "    return ' '.join(tokens)\n",
    "\n",
    "# Run the augmentation over the full training set\n",
    "# NOTE(review): augment_document draws from the global `random` state and no\n",
    "# seed is set in this notebook, so the augmented text differs between runs.\n",
    "percent_to_change = 5\n",
    "\n",
    "augmentation_jobs = [\n",
    "    (doc, percent_to_change, y_train[r]) for r, doc in enumerate(X_train_text)\n",
    "]\n",
    "\n",
    "with ThreadPoolExecutor() as executor:\n",
    "    # executor.map yields results in input order; tqdm shows progress as they are consumed\n",
    "    augmented_iter = executor.map(augment_document_parallel, augmentation_jobs)\n",
    "    X_train_augmented = list(\n",
    "        tqdm(augmented_iter, total=len(X_train_text), desc=\"Augmenting documents\")\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ca5487a-c351-4532-856b-6e2c3fcc2eee",
   "metadata": {},
   "source": [
    "# Classify dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cd801caf-680d-4d12-986b-d0b8e040c0c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bag-of-words features: the vocabulary is learned from the augmented training\n",
    "# text only, and the untouched test text is projected onto that same vocabulary.\n",
    "vectorizer = CountVectorizer()\n",
    "X_train_vec = vectorizer.fit_transform(X_train_augmented)\n",
    "X_test_vec = vectorizer.transform(X_test_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "440e39e4-0748-4dce-8749-e48025674b3b",
   "metadata": {},
   "source": [
    "### RandomForestClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0305a603-a99b-4c8b-9926-656b7553b043",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Accuracy: 0.7910\n"
     ]
    }
   ],
   "source": [
    "# Fit a 100-tree random forest on the augmented training vectors\n",
    "classifier = RandomForestClassifier(n_estimators=100, random_state=42).fit(X_train_vec, y_train)\n",
    "\n",
    "# Score it on the original (non-augmented) test vectors\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)\n",
    "print(f\"Test Accuracy: {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1626ffbc-f88f-4983-9c6e-7bc98d870436",
   "metadata": {},
   "source": [
    "### LogisticRegression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c61305be-45cf-4774-9666-0f60ebeacd5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "# Logistic regression baseline, trained on the augmented vectors\n",
    "classifier = LogisticRegression(random_state=42).fit(X_train_vec, y_train)\n",
    "\n",
    "# Accuracy on the original test vectors\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)\n",
    "print(f\"Test Accuracy (Logistic Regression): {accuracy:.4f}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d5c5936b-98ba-4083-97f6-81bc9a4e73d6",
   "metadata": {},
   "source": [
    "### Naive Bayes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c57deef3-fd0a-445e-813b-d43b72927b3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.naive_bayes import MultinomialNB\n",
    "\n",
    "# Multinomial Naive Bayes baseline, trained on the augmented vectors\n",
    "classifier = MultinomialNB().fit(X_train_vec, y_train)\n",
    "\n",
    "# Accuracy on the original test vectors\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)\n",
    "print(f\"Test Accuracy (Naive Bayes): {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d897bc3-14fb-4997-92bc-c553cb98f2ba",
   "metadata": {},
   "source": [
    "### SVM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a331e491-ad3e-4426-b0f3-97ab0f68f63e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.svm import LinearSVC\n",
    "\n",
    "# Linear SVM baseline, trained on the augmented vectors\n",
    "classifier = LinearSVC(random_state=42).fit(X_train_vec, y_train)\n",
    "\n",
    "# Accuracy on the original test vectors\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)\n",
    "print(f\"Test Accuracy (SVM): {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55030f29-efc2-46d5-8d28-c13eca4fb68e",
   "metadata": {},
   "source": [
    "### MLP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0caf8010-bc8a-4471-9268-60efdb76ee52",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.neural_network import MLPClassifier\n",
    "\n",
    "# Multi-layer perceptron baseline, trained on the augmented vectors\n",
    "classifier = MLPClassifier(random_state=42).fit(X_train_vec, y_train)\n",
    "\n",
    "# Accuracy on the original test vectors\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)\n",
    "print(f\"Test Accuracy (MLP): {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c50967d-0431-4d62-906e-316aea590389",
   "metadata": {},
   "source": [
    "### Tsetlin Machine (TM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ae6d4f0-0bab-4738-b3f4-49364004e89c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from tmu.models.classification.vanilla_classifier import TMClassifier\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "# Build dense uint32 inputs for the Tsetlin Machine.\n",
    "# The model works on Boolean features, while CountVectorizer produced word\n",
    "# counts, so every count is clipped to presence/absence (0/1). The dtype is\n",
    "# converted while the matrix is still sparse to avoid a wider dense copy.\n",
    "# Labels are only cast to uint32 -- no one-hot encoding takes place.\n",
    "X_train_tm = (X_train_vec > 0).astype(np.uint32).toarray()\n",
    "Y_train_tm = y_train.astype(np.uint32)\n",
    "\n",
    "X_test_tm = (X_test_vec > 0).astype(np.uint32).toarray()\n",
    "Y_test_tm = y_test.astype(np.uint32)\n",
    "\n",
    "# Hyperparameters handed to TMClassifier below\n",
    "num_clauses = 1000\n",
    "T = 8000\n",
    "s = 2.0\n",
    "device = \"CPU\"\n",
    "weighted_clauses = True\n",
    "epochs = 10\n",
    "clause_drop_p = 0.75\n",
    "\n",
    "print(\"started\")\n",
    "tm = TMClassifier(num_clauses, T, s, platform=device, weighted_clauses=weighted_clauses,clause_drop_p=clause_drop_p)\n",
    "# fit() is called once per epoch so test accuracy can be reported after each pass\n",
    "for epoch in range(epochs):\n",
    "    tm.fit(X_train_tm, Y_train_tm)\n",
    "    result = 100 * (tm.predict(X_test_tm) == Y_test_tm).mean()\n",
    "    print(f\"Accuracy: {result:.2f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
