{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "caa0de87-fdf9-4e6c-8613-066975cad0bc",
   "metadata": {},
   "source": [
    "# Download dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3624342a-4c8c-4398-a283-56520ca4007c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import numpy as np\n",
    "import torch\n",
    "from transformers import BertTokenizer, BertForPreTraining, AdamW\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from sklearn.metrics import accuracy_score\n",
    "from tensorflow.keras.datasets import imdb\n",
    "\n",
    "# Load the IMDB dataset\n",
    "(X_train, y_train), (X_test, y_test) = imdb.load_data(num_words=20000)\n",
    "\n",
    "# Decode back to text\n",
    "word_index = imdb.get_word_index()\n",
    "index_word = {index: word for word, index in word_index.items()}\n",
    "\n",
    "def decode_review(encoded_review):\n",
    "    return ' '.join([index_word.get(i - 3, '?') for i in encoded_review])\n",
    "\n",
    "# Convert integer sequences back to text\n",
    "X_train_text = [decode_review(review) for review in X_train]\n",
    "X_test_text = [decode_review(review) for review in X_test]\n",
    "\n",
    "# Initialize tokenizer and prepare dataset\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
    "# Tokenize the dataset\n",
    "train_encodings = tokenizer(X_train_text, truncation=True, padding=True, max_length=128)\n",
    "test_encodings = tokenizer(X_test_text, truncation=True, padding=True, max_length=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34e0f0ab-5b0c-411f-9b98-3feef7f6a40c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from transformers import BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup\n",
    "\n",
    "class IMDbDataset(Dataset):\n",
    "    def __init__(self, encodings, labels):\n",
    "        self.encodings = encodings\n",
    "        self.labels = labels\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.labels)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n",
    "        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)\n",
    "        return item\n",
    "\n",
    "# Create PyTorch datasets\n",
    "train_dataset = IMDbDataset(train_encodings, y_train)\n",
    "test_dataset = IMDbDataset(test_encodings, y_test)\n",
    "\n",
    "# Initialize BERT model for classification\n",
    "model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)\n",
    "model.train()\n",
    "\n",
    "# Set up optimizer and scheduler\n",
    "optimizer = AdamW(model.parameters(), lr=5e-5)\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)\n",
    "test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)\n",
    "\n",
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "model.to(device)\n",
    "\n",
    "# Set up scheduler\n",
    "total_steps = len(train_loader) * 3  # 3 epochs\n",
    "scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)\n",
    "\n",
    "# Training loop\n",
    "for epoch in range(3):  # Adjust the number of epochs as needed\n",
    "    total_loss = 0\n",
    "    \n",
    "    model.train()  # Ensure the model is in training mode\n",
    "    \n",
    "    for batch in train_loader:\n",
    "        optimizer.zero_grad()  # Clear gradients from the previous step\n",
    "        \n",
    "        # Move tensors to the correct device\n",
    "        input_ids = batch['input_ids'].to(device)\n",
    "        attention_mask = batch['attention_mask'].to(device)\n",
    "        labels = batch['labels'].to(device)\n",
    "        \n",
    "        # Forward pass\n",
    "        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
    "        loss = outputs.loss\n",
    "        \n",
    "        # Backward pass\n",
    "        loss.backward()  # Compute gradients\n",
    "        \n",
    "        # Gradient clipping\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
    "        \n",
    "        optimizer.step()  # Update parameters\n",
    "        scheduler.step()  # Update learning rate scheduler\n",
    "        \n",
    "        # Accumulate loss\n",
    "        total_loss += loss.item()\n",
    "    \n",
    "    print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader)}')\n",
    "\n",
    "# Save the fine-tuned model\n",
    "model.save_pretrained(\"my_pretrained_bert\")\n",
    "tokenizer.save_pretrained(\"my_pretrained_bert\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60aa5259-4199-4c97-bff2-e49fbcb9f9bd",
   "metadata": {},
   "source": [
    "# Augment dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23a93686-ccb2-4d4e-a059-3e94e2f13b36",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import random\n",
    "from transformers import BertTokenizer, BertModel\n",
    "from tqdm import tqdm\n",
    "from concurrent.futures import ThreadPoolExecutor\n",
    "\n",
    "# Load pre-trained BERT model and tokenizer\n",
    "tokenizer = BertTokenizer.from_pretrained('my_pretrained_bert')\n",
    "bert_model = BertModel.from_pretrained('my_pretrained_bert')\n",
    "\n",
    "def get_bert_embedding(word, bert_model, tokenizer):\n",
    "    inputs = tokenizer(word, return_tensors=\"pt\", truncation=True, padding=True)\n",
    "    with torch.no_grad():\n",
    "        outputs = bert_model(**inputs)\n",
    "    # Ignore special tokens and take the mean of token embeddings\n",
    "    hidden_states = outputs.last_hidden_state\n",
    "    input_ids = inputs['input_ids']\n",
    "    mask = input_ids != tokenizer.pad_token_id\n",
    "    word_embeddings = hidden_states[mask].mean(dim=0)  # Mean over token embeddings\n",
    "    return word_embeddings\n",
    "\n",
    "def cosine_similarity(vec1, vec2):\n",
    "    # Compute cosine similarity between two vectors\n",
    "    return torch.nn.functional.cosine_similarity(vec1.unsqueeze(0), vec2.unsqueeze(0)).item()\n",
    "\n",
    "def precompute_cosine_similarities(word_embeddings):\n",
    "    cosine_similarities = {}\n",
    "    words = list(word_embeddings.keys())\n",
    "    for i, word in enumerate(tqdm(words, desc=\"Precomputing Cosine Similarities\")):\n",
    "        for j in range(i+1, len(words)):\n",
    "            other_word = words[j]\n",
    "            similarity = cosine_similarity(word_embeddings[word], word_embeddings[other_word])\n",
    "            cosine_similarities[(word, other_word)] = similarity\n",
    "            cosine_similarities[(other_word, word)] = similarity\n",
    "    return cosine_similarities\n",
    "\n",
    "def get_similar_words(word, cosine_similarities, all_words, topn=5, dissimilar=False):\n",
    "    similarities = [(cosine_similarities.get((word, other_word), -1), other_word) for other_word in all_words]\n",
    "    similarities.sort(reverse=not dissimilar)\n",
    "    return [word for _, word in similarities[:topn]]\n",
    "\n",
    "def augment_document(doc, label, word_embeddings, cosine_similarities, all_words, percent=5):\n",
    "    tokens = doc.split()\n",
    "    num_words_to_change = int(len(tokens) * (percent / 100))\n",
    "    indices_to_change = set(random.sample(range(len(tokens)), num_words_to_change))\n",
    "\n",
    "    new_tokens = []\n",
    "    for i, word in enumerate(tokens):\n",
    "        if i in indices_to_change and word in word_embeddings:\n",
    "            similar_words = get_similar_words(word, cosine_similarities, all_words, topn=5 if label == 1 else 1000, dissimilar=(label == 0))\n",
    "            if similar_words:  # Check if similar words were found\n",
    "                chosen_word = random.choice(similar_words)\n",
    "                new_tokens.append(chosen_word)\n",
    "            else:\n",
    "                new_tokens.append(word)  # Keep the original word if no similar word found\n",
    "        else:\n",
    "            new_tokens.append(word)\n",
    "\n",
    "    return ' '.join(new_tokens)\n",
    "\n",
    "# Precompute word embeddings\n",
    "all_words = list(set([word for doc in X_train_text for word in doc.split()]))\n",
    "word_embeddings = {word: get_bert_embedding(word, bert_model, tokenizer) for word in tqdm(all_words, desc=\"Embedding Words\")}\n",
    "\n",
    "# Precompute cosine similarities\n",
    "cosine_similarities = precompute_cosine_similarities(word_embeddings)\n",
    "percent_to_change = 5\n",
    "# Parallelize document augmentation\n",
    "def augment_document_parallel(doc_index):\n",
    "    doc = X_train_text[doc_index]\n",
    "    label = y_train[doc_index]\n",
    "    return augment_document(doc, label, word_embeddings, cosine_similarities, all_words, percent_to_change)\n",
    "\n",
    "with ThreadPoolExecutor() as executor:\n",
    "    X_train_augmented = list(tqdm(executor.map(augment_document_parallel, range(len(X_train_text))), total=len(X_train_text)))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9af9c6e0-b248-4239-9771-356d7209ab62",
   "metadata": {},
   "source": [
    "# Classify dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4bff02e-d669-4d08-adab-48426af1f7a5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from sklearn.feature_extraction.text import CountVectorizer\n",
    "# Use CountVectorizer to convert text to vectors\n",
    "vectorizer = CountVectorizer()\n",
    "X_train_vec = vectorizer.fit_transform(X_train_augmented)\n",
    "X_test_vec = vectorizer.transform(X_test_text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7460ce85-15a2-4a56-b53f-a4a71cf2897d",
   "metadata": {},
   "source": [
    "### RandomForestClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c456a7b-666b-4c2e-a1ea-1e33da666f95",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier\n",
    "# Train the classifier on augmented data\n",
    "classifier = RandomForestClassifier(n_estimators=100, random_state=42)\n",
    "classifier.fit(X_train_vec, y_train)\n",
    "\n",
    "# Predict and evaluate on original test data\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "print(f\"Test Accuracy: {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f73e435c-a864-4d1d-b105-283f39a8dfbc",
   "metadata": {},
   "source": [
    "### LogisticRegression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2727f318-f2c4-4c93-9616-585b4d956908",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "classifier = LogisticRegression(random_state=42)\n",
    "classifier.fit(X_train_vec, y_train)\n",
    "\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "print(f\"Test Accuracy (Logistic Regression): {accuracy:.4f}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "457413ed-dbb9-4a8f-bbe0-44cf299bc4d3",
   "metadata": {},
   "source": [
    "### Naive Bayes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c544a8d0-ce68-49cb-8c12-4c3e7b2e62bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.naive_bayes import MultinomialNB\n",
    "\n",
    "classifier = MultinomialNB()\n",
    "classifier.fit(X_train_vec, y_train)\n",
    "\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "print(f\"Test Accuracy (Naive Bayes): {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7588d57-0c05-46ce-bef6-a97b99d8e61d",
   "metadata": {},
   "source": [
    "### SVM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "449b7dec-aeac-4419-b3c0-d121c8cee59f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.svm import LinearSVC\n",
    "\n",
    "classifier = LinearSVC(random_state=42)\n",
    "classifier.fit(X_train_vec, y_train)\n",
    "\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "print(f\"Test Accuracy (SVM): {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e90fede3-56c9-4ffa-ba88-3d4bc061ac0a",
   "metadata": {},
   "source": [
    "### MLP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b7ed7b3-6a3a-4a97-ad67-841f300ed52d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.neural_network import MLPClassifier\n",
    "\n",
    "classifier = MLPClassifier(random_state=42)\n",
    "classifier.fit(X_train_vec, y_train)\n",
    "\n",
    "y_pred = classifier.predict(X_test_vec)\n",
    "accuracy = accuracy_score(y_test, y_pred)\n",
    "print(f\"Test Accuracy (MLP): {accuracy:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37cc8102-219a-4746-999d-8277296bc999",
   "metadata": {},
   "source": [
    "### TM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87999304-113e-4d01-be9c-390793513d95",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from tmu.models.classification.vanilla_classifier import TMClassifier\n",
    "import numpy as np\n",
    "import pickle\n",
    "\n",
    "# Convert labels to one-hot encoding for Tsetlin Machine\n",
    "X_train_tm = np.array(X_train_vec.toarray(), dtype=np.uint32)\n",
    "Y_train_tm = y_train.astype(np.uint32)\n",
    "\n",
    "X_test_tm = np.array(X_test_vec.toarray(), dtype=np.uint32)\n",
    "Y_test_tm = y_test.astype(np.uint32)\n",
    "\n",
    "num_clauses = 1000\n",
    "T = 8000\n",
    "s = 2.0\n",
    "device = \"CPU\"\n",
    "weighted_clauses = True\n",
    "epochs = 10\n",
    "clause_drop_p = 0.75\n",
    "\n",
    "print(\"started\")\n",
    "tm = TMClassifier(num_clauses, T, s, platform=device, weighted_clauses=weighted_clauses,clause_drop_p=clause_drop_p)\n",
    "for epoch in range(epochs):\n",
    "    tm.fit(X_train_tm, Y_train_tm)\n",
    "    result = 100 * (tm.predict(X_test_tm) == Y_test_tm).mean()\n",
    "    print(f\"Accuracy: {result:.2f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
