{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "018cf7e4",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-27 04:28:51.338556: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "import transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1e6ab69e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\n",
      "                                 Dload  Upload   Total   Spent    Left  Speed\n",
      "100 11.1M  100 11.1M    0     0   371k      0  0:00:30  0:00:30 --:--:--  431k\n",
      "SNLI_Corpus/\n",
      "SNLI_Corpus/snli_1.0_dev.csv\n",
      "SNLI_Corpus/snli_1.0_train.csv\n",
      "SNLI_Corpus/snli_1.0_test.csv\n"
     ]
    }
   ],
   "source": [
    "!curl -LO https://raw.githubusercontent.com/MohamadMerchant/SNLI/master/data.tar.gz\n",
    "!tar -xvzf data.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "cb9c1eb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "max_length = 128  # Maximum length of input sentence to the model.\n",
    "batch_size = 32\n",
    "epochs = 2\n",
    "\n",
    "# Labels in our dataset.\n",
    "labels = [\"contradiction\", \"entailment\", \"neutral\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "69390fb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# There are more than 550k samples in total; we will use 100k for this example.\n",
    "train_df = pd.read_csv(\"SNLI_Corpus/snli_1.0_train.csv\", nrows=100000)\n",
    "valid_df = pd.read_csv(\"SNLI_Corpus/snli_1.0_dev.csv\")\n",
    "test_df = pd.read_csv(\"SNLI_Corpus/snli_1.0_test.csv\")\n",
    "\n",
    "train_df.dropna(axis=0, inplace=True)\n",
    "\n",
    "train_df = (\n",
    "    train_df[train_df.similarity != \"-\"]\n",
    "    .sample(frac=1.0, random_state=42)\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "valid_df = (\n",
    "    valid_df[valid_df.similarity != \"-\"]\n",
    "    .sample(frac=1.0, random_state=42)\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "\n",
    "train_df[\"label\"] = train_df[\"similarity\"].apply(\n",
    "    lambda x: 0 if x == \"contradiction\" else 1 if x == \"entailment\" else 2\n",
    ")\n",
    "y_train = tf.keras.utils.to_categorical(train_df.label, num_classes=3)\n",
    "\n",
    "valid_df[\"label\"] = valid_df[\"similarity\"].apply(\n",
    "    lambda x: 0 if x == \"contradiction\" else 1 if x == \"entailment\" else 2\n",
    ")\n",
    "y_val = tf.keras.utils.to_categorical(valid_df.label, num_classes=3)\n",
    "\n",
    "test_df[\"label\"] = test_df[\"similarity\"].apply(\n",
    "    lambda x: 0 if x == \"contradiction\" else 1 if x == \"entailment\" else 2\n",
    ")\n",
    "y_test = tf.keras.utils.to_categorical(test_df.label, num_classes=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a2f6895b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class BertSemanticDataGenerator(tf.keras.utils.Sequence):\n",
    "    \"\"\"Generates batches of data.\n",
    "\n",
    "    Args:\n",
    "        sentence_pairs: Array of premise and hypothesis input sentences.\n",
    "        labels: Array of labels.\n",
    "        batch_size: Integer batch size.\n",
    "        shuffle: boolean, whether to shuffle the data.\n",
    "        include_targets: boolean, whether to incude the labels.\n",
    "\n",
    "    Returns:\n",
    "        Tuples `([input_ids, attention_mask, `token_type_ids], labels)`\n",
    "        (or just `[input_ids, attention_mask, `token_type_ids]`\n",
    "         if `include_targets=False`)\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        sentence_pairs,\n",
    "        labels,\n",
    "        batch_size=batch_size,\n",
    "        shuffle=True,\n",
    "        include_targets=True,\n",
    "    ):\n",
    "        self.sentence_pairs = sentence_pairs\n",
    "        self.labels = labels\n",
    "        self.shuffle = shuffle\n",
    "        self.batch_size = batch_size\n",
    "        self.include_targets = include_targets\n",
    "        # Load our BERT Tokenizer to encode the text.\n",
    "        # We will use base-base-uncased pretrained model.\n",
    "        self.tokenizer = transformers.BertTokenizer.from_pretrained(\n",
    "            \"bert-base-uncased\", do_lower_case=True\n",
    "        )\n",
    "        self.indexes = np.arange(len(self.sentence_pairs))\n",
    "        self.on_epoch_end()\n",
    "\n",
    "    def __len__(self):\n",
    "        # Denotes the number of batches per epoch.\n",
    "        return len(self.sentence_pairs) // self.batch_size\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # Retrieves the batch of index.\n",
    "        indexes = self.indexes[idx * self.batch_size : (idx + 1) * self.batch_size]\n",
    "        sentence_pairs = self.sentence_pairs[indexes]\n",
    "\n",
    "        # With BERT tokenizer's batch_encode_plus batch of both the sentences are\n",
    "        # encoded together and separated by [SEP] token.\n",
    "        encoded = self.tokenizer.batch_encode_plus(\n",
    "            sentence_pairs.tolist(),\n",
    "            add_special_tokens=True,\n",
    "            max_length=max_length,\n",
    "            return_attention_mask=True,\n",
    "            return_token_type_ids=True,\n",
    "            pad_to_max_length=True,\n",
    "            return_tensors=\"tf\",\n",
    "        )\n",
    "\n",
    "        # Convert batch of encoded features to numpy array.\n",
    "        input_ids = np.array(encoded[\"input_ids\"], dtype=\"int32\")\n",
    "        attention_masks = np.array(encoded[\"attention_mask\"], dtype=\"int32\")\n",
    "        token_type_ids = np.array(encoded[\"token_type_ids\"], dtype=\"int32\")\n",
    "\n",
    "        # Set to true if data generator is used for training/validation.\n",
    "        if self.include_targets:\n",
    "            labels = np.array(self.labels[indexes], dtype=\"int32\")\n",
    "            return [input_ids, attention_masks, token_type_ids], labels\n",
    "        else:\n",
    "            return [input_ids, attention_masks, token_type_ids]\n",
    "\n",
    "    def on_epoch_end(self):\n",
    "        # Shuffle indexes after each epoch if shuffle is set to True.\n",
    "        if self.shuffle:\n",
    "            np.random.RandomState(42).shuffle(self.indexes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05f53d7a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)\n"
     ]
    }
   ],
   "source": [
    "# Create the model under a distribution strategy scope.\n",
    "strategy = tf.distribute.MirroredStrategy()\n",
    "\n",
    "with strategy.scope():\n",
    "    # Encoded token ids from BERT tokenizer.\n",
    "    input_ids = tf.keras.layers.Input(\n",
    "        shape=(max_length,), dtype=tf.int32, name=\"input_ids\"\n",
    "    )\n",
    "    # Attention masks indicates to the model which tokens should be attended to.\n",
    "    attention_masks = tf.keras.layers.Input(\n",
    "        shape=(max_length,), dtype=tf.int32, name=\"attention_masks\"\n",
    "    )\n",
    "    # Token type ids are binary masks identifying different sequences in the model.\n",
    "    token_type_ids = tf.keras.layers.Input(\n",
    "        shape=(max_length,), dtype=tf.int32, name=\"token_type_ids\"\n",
    "    )\n",
    "    # Loading pretrained BERT model.\n",
    "    bert_model = transformers.TFBertModel.from_pretrained(\"bert-base-uncased\")\n",
    "    # Freeze the BERT model to reuse the pretrained features without modifying them.\n",
    "    bert_model.trainable = False\n",
    "\n",
    "    bert_output = bert_model.bert(\n",
    "        input_ids, attention_mask=attention_masks, token_type_ids=token_type_ids\n",
    "    )\n",
    "    sequence_output = bert_output.last_hidden_state\n",
    "    pooled_output = bert_output.pooler_output\n",
    "    # Add trainable layers on top of frozen layers to adapt the pretrained features on the new data.\n",
    "    bi_lstm = tf.keras.layers.Bidirectional(\n",
    "        tf.keras.layers.LSTM(64, return_sequences=True)\n",
    "    )(sequence_output)\n",
    "    # Applying hybrid pooling approach to bi_lstm sequence output.\n",
    "    avg_pool = tf.keras.layers.GlobalAveragePooling1D()(bi_lstm)\n",
    "    max_pool = tf.keras.layers.GlobalMaxPooling1D()(bi_lstm)\n",
    "    concat = tf.keras.layers.concatenate([avg_pool, max_pool])\n",
    "    dropout = tf.keras.layers.Dropout(0.3)(concat)\n",
    "    output = tf.keras.layers.Dense(3, activation=\"softmax\")(dropout)\n",
    "    model = tf.keras.models.Model(\n",
    "        inputs=[input_ids, attention_masks, token_type_ids], outputs=output\n",
    "    )\n",
    "\n",
    "    model.compile(\n",
    "        optimizer=tf.keras.optimizers.Adam(),\n",
    "        loss=\"categorical_crossentropy\",\n",
    "        metrics=[\"acc\"],\n",
    "    )\n",
    "\n",
    "\n",
    "print(f\"Strategy: {strategy}\")\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8191f0c1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
