{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "0168ea3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "import keras\n",
    "import keras_nlp\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "15df8a8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the three SNLI splits in one call; tfds.load returns one dataset per\n",
    "# requested split, in the order given. Only the first 20% of the training\n",
    "# split is read.\n",
    "snli_train, snli_val, snli_test = tfds.load(\n",
    "    \"snli\",\n",
    "    split=[\"train[:20%]\", \"validation\", \"test\"],\n",
    "    data_dir='../tensorflow_datasets/',\n",
    ")\n",
    "\n",
    "def filter_labels(sample):\n",
    "    \"\"\"Return whether the example's label is non-negative.\n",
    "\n",
    "    Used as a `Dataset.filter` predicate below. Presumably this drops SNLI\n",
    "    examples without a gold label (encoded as -1) -- confirm against the\n",
    "    dataset card.\n",
    "    \"\"\"\n",
    "    label = sample[\"label\"]\n",
    "    return label >= 0\n",
    "\n",
    "def split_labels(sample):\n",
    "    \"\"\"Convert an example dict into a `(features, label)` pair.\n",
    "\n",
    "    The features are the tuple `(hypothesis, premise)`, in that order, and\n",
    "    the label is `sample[\"label\"]`.\n",
    "    \"\"\"\n",
    "    features = (sample[\"hypothesis\"], sample[\"premise\"])\n",
    "    return features, sample[\"label\"]\n",
    "\n",
    "BATCH_SIZE = 16  # examples per batch, shared by all three splits\n",
    "\n",
    "def make_dataset(split, batch_size=BATCH_SIZE):\n",
    "    \"\"\"Turn a raw SNLI split into batched `(features, label)` pairs.\n",
    "\n",
    "    Args:\n",
    "        split: dataset of example dicts, as returned by `tfds.load`.\n",
    "        batch_size: number of examples per batch.\n",
    "\n",
    "    Returns:\n",
    "        The dataset after filtering with `filter_labels`, mapping with\n",
    "        `split_labels`, batching and prefetching.\n",
    "    \"\"\"\n",
    "    return (\n",
    "        split.filter(filter_labels)\n",
    "        .map(split_labels, num_parallel_calls=tf.data.AUTOTUNE)\n",
    "        .batch(batch_size)\n",
    "        # Prepare upcoming batches while the current one is being consumed;\n",
    "        # this does not change the elements or their order.\n",
    "        .prefetch(tf.data.AUTOTUNE)\n",
    "    )\n",
    "\n",
    "ds_train = make_dataset(snli_train)\n",
    "ds_val = make_dataset(snli_val)\n",
    "ds_test = make_dataset(snli_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "7691e767",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a BERT classifier from the \"bert_tiny_en_uncased\" preset with three\n",
    "# output classes (presumably the three SNLI labels -- confirm against the dataset).\n",
    "bert_classifier = keras_nlp.models.BertClassifier.from_preset(\n",
    "    \"bert_tiny_en_uncased\",\n",
    "    num_classes=3,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b4aee89",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/100\n",
      "   6867/Unknown \u001b[1m62s\u001b[0m 7ms/step - loss: 0.8607 - sparse_categorical_accuracy: 0.6026"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:51:26.867264: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:51:26.867326: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_6]]\n",
      "2024-03-26 14:51:26.867407: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 17060731988147034527\n",
      "2024-03-26 14:51:26.867418: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 12498018463839324721\n",
      "2024-03-26 14:51:26.867427: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 6495130065868410671\n",
      "2024-03-26 14:51:26.867462: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 11843840089427515282\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6867/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m67s\u001b[0m 7ms/step - loss: 0.8607 - sparse_categorical_accuracy: 0.6026 - val_loss: 0.5836 - val_sparse_categorical_accuracy: 0.7604\n",
      "Epoch 2/100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:51:31.123512: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:51:31.123592: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_2]]\n",
      "2024-03-26 14:51:31.123651: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 11843840089427515282\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6858/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m━\u001b[0m \u001b[1m0s\u001b[0m 5ms/step - loss: 0.6515 - sparse_categorical_accuracy: 0.7282"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:52:05.700815: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:52:05.700884: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_2]]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6867/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m38s\u001b[0m 5ms/step - loss: 0.6515 - sparse_categorical_accuracy: 0.7282 - val_loss: 0.5412 - val_sparse_categorical_accuracy: 0.7815\n",
      "Epoch 3/100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:52:08.808901: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_4]]\n",
      "2024-03-26 14:52:08.808970: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:52:08.809034: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 11843840089427515282\n",
      "2024-03-26 14:52:08.809051: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 12498018463839324721\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6867/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 5ms/step - loss: 0.5901 - sparse_categorical_accuracy: 0.7588"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:52:43.420835: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:52:43.420913: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 11843840089427515282\n",
      "2024-03-26 14:52:43.420922: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_8]]\n",
      "2024-03-26 14:52:43.421183: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 17060731988147034527\n",
      "2024-03-26 14:52:43.421222: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 12498018463839324721\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6867/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m37s\u001b[0m 5ms/step - loss: 0.5901 - sparse_categorical_accuracy: 0.7588 - val_loss: 0.5218 - val_sparse_categorical_accuracy: 0.7925\n",
      "Epoch 4/100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:52:45.881909: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:52:45.881965: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_2]]\n",
      "2024-03-26 14:52:45.882013: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 11843840089427515282\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6864/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m━\u001b[0m \u001b[1m0s\u001b[0m 5ms/step - loss: 0.5472 - sparse_categorical_accuracy: 0.7815"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:53:20.080055: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_4]]\n",
      "2024-03-26 14:53:20.080133: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:53:20.080252: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 12498018463839324721\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6867/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m37s\u001b[0m 5ms/step - loss: 0.5472 - sparse_categorical_accuracy: 0.7815 - val_loss: 0.5167 - val_sparse_categorical_accuracy: 0.7934\n",
      "Epoch 5/100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:53:22.541906: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:53:22.541975: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_2]]\n",
      "2024-03-26 14:53:22.542024: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 11843840089427515282\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6863/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m━\u001b[0m \u001b[1m0s\u001b[0m 5ms/step - loss: 0.5099 - sparse_categorical_accuracy: 0.7975"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:53:57.214352: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:53:57.214425: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_2]]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6867/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m37s\u001b[0m 5ms/step - loss: 0.5099 - sparse_categorical_accuracy: 0.7975 - val_loss: 0.5095 - val_sparse_categorical_accuracy: 0.8005\n",
      "Epoch 6/100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:53:59.664372: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:53:59.664446: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_6]]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6867/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 5ms/step - loss: 0.4779 - sparse_categorical_accuracy: 0.8144"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:54:34.254760: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:54:34.254852: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_8]]\n",
      "2024-03-26 14:54:34.254934: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 11843840089427515282\n",
      "2024-03-26 14:54:34.254968: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 17060731988147034527\n",
      "2024-03-26 14:54:34.254979: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 12498018463839324721\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6867/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m37s\u001b[0m 5ms/step - loss: 0.4779 - sparse_categorical_accuracy: 0.8144 - val_loss: 0.5164 - val_sparse_categorical_accuracy: 0.8010\n",
      "Epoch 7/100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:54:36.734742: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:54:36.734795: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_2]]\n",
      "2024-03-26 14:54:36.734860: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 11843840089427515282\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6861/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m━\u001b[0m \u001b[1m0s\u001b[0m 5ms/step - loss: 0.4517 - sparse_categorical_accuracy: 0.8265"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:55:11.181993: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:55:11.182093: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_6]]\n",
      "2024-03-26 14:55:11.182245: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 17060731988147034527\n",
      "2024-03-26 14:55:11.182259: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 12498018463839324721\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6867/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m37s\u001b[0m 5ms/step - loss: 0.4517 - sparse_categorical_accuracy: 0.8265 - val_loss: 0.5310 - val_sparse_categorical_accuracy: 0.7989\n",
      "Epoch 8/100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:55:13.632124: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:55:13.632198: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_2]]\n",
      "2024-03-26 14:55:13.632262: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 11843840089427515282\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6862/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m━\u001b[0m \u001b[1m0s\u001b[0m 5ms/step - loss: 0.4282 - sparse_categorical_accuracy: 0.8371"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:55:47.721348: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:55:47.721442: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_4]]\n",
      "2024-03-26 14:55:47.721539: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 12498018463839324721\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6867/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m37s\u001b[0m 5ms/step - loss: 0.4282 - sparse_categorical_accuracy: 0.8371 - val_loss: 0.5317 - val_sparse_categorical_accuracy: 0.8067\n",
      "Epoch 9/100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:55:50.209389: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:55:50.209442: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_6]]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6864/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m━\u001b[0m \u001b[1m0s\u001b[0m 5ms/step - loss: 0.4038 - sparse_categorical_accuracy: 0.8472"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:56:24.623601: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:56:24.623707: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_4]]\n",
      "2024-03-26 14:56:24.623834: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 12498018463839324721\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6867/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m37s\u001b[0m 5ms/step - loss: 0.4038 - sparse_categorical_accuracy: 0.8472 - val_loss: 0.5299 - val_sparse_categorical_accuracy: 0.8045\n",
      "Epoch 10/100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:56:27.006417: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:56:27.006466: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_2]]\n",
      "2024-03-26 14:56:27.006526: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 11843840089427515282\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6858/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m━\u001b[0m \u001b[1m0s\u001b[0m 5ms/step - loss: 0.3826 - sparse_categorical_accuracy: 0.8564"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:57:01.094075: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:57:01.094153: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_8]]\n",
      "2024-03-26 14:57:01.094214: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 11843840089427515282\n",
      "2024-03-26 14:57:01.094228: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 17060731988147034527\n",
      "2024-03-26 14:57:01.094239: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 12498018463839324721\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6867/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m37s\u001b[0m 5ms/step - loss: 0.3826 - sparse_categorical_accuracy: 0.8564 - val_loss: 0.5496 - val_sparse_categorical_accuracy: 0.8068\n",
      "Epoch 11/100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:57:03.537706: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:57:03.537759: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_4]]\n",
      "2024-03-26 14:57:03.537825: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 11843840089427515282\n",
      "2024-03-26 14:57:03.537841: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 12498018463839324721\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6860/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m━\u001b[0m \u001b[1m0s\u001b[0m 5ms/step - loss: 0.3619 - sparse_categorical_accuracy: 0.8645"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:57:37.054383: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:57:37.054526: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_2]]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6867/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m37s\u001b[0m 5ms/step - loss: 0.3619 - sparse_categorical_accuracy: 0.8645 - val_loss: 0.5673 - val_sparse_categorical_accuracy: 0.8031\n",
      "Epoch 12/100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:57:40.276747: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:57:40.276807: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_2]]\n",
      "2024-03-26 14:57:40.276867: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 11843840089427515282\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6862/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m━\u001b[0m \u001b[1m0s\u001b[0m 5ms/step - loss: 0.3403 - sparse_categorical_accuracy: 0.8715"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:58:14.123251: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:58:14.123363: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_8]]\n",
      "2024-03-26 14:58:14.123439: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 11843840089427515282\n",
      "2024-03-26 14:58:14.123467: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 17060731988147034527\n",
      "2024-03-26 14:58:14.123479: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 12498018463839324721\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6867/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m36s\u001b[0m 5ms/step - loss: 0.3403 - sparse_categorical_accuracy: 0.8715 - val_loss: 0.5733 - val_sparse_categorical_accuracy: 0.8025\n",
      "Epoch 13/100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:58:16.666000: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:58:16.666067: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_2]]\n",
      "2024-03-26 14:58:16.666137: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 11843840089427515282\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6863/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m━\u001b[0m \u001b[1m0s\u001b[0m 5ms/step - loss: 0.3252 - sparse_categorical_accuracy: 0.8795"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:58:50.016614: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:58:50.016715: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_8]]\n",
      "2024-03-26 14:58:50.016873: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 11843840089427515282\n",
      "2024-03-26 14:58:50.016884: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 17060731988147034527\n",
      "2024-03-26 14:58:50.016893: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 12498018463839324721\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6867/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m36s\u001b[0m 5ms/step - loss: 0.3252 - sparse_categorical_accuracy: 0.8795 - val_loss: 0.6018 - val_sparse_categorical_accuracy: 0.7996\n",
      "Epoch 14/100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:58:52.469657: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:58:52.469720: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_2]]\n",
      "2024-03-26 14:58:52.469783: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 11843840089427515282\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6863/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m━\u001b[0m \u001b[1m0s\u001b[0m 5ms/step - loss: 0.3101 - sparse_categorical_accuracy: 0.8856"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:59:25.336007: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:59:25.336088: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_8]]\n",
      "2024-03-26 14:59:25.336171: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 11843840089427515282\n",
      "2024-03-26 14:59:25.336178: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 17060731988147034527\n",
      "2024-03-26 14:59:25.336185: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 12498018463839324721\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6867/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m35s\u001b[0m 5ms/step - loss: 0.3101 - sparse_categorical_accuracy: 0.8856 - val_loss: 0.6025 - val_sparse_categorical_accuracy: 0.8024\n",
      "Epoch 15/100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 14:59:27.758422: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 14:59:27.758477: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_2]]\n",
      "2024-03-26 14:59:27.758593: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 11843840089427515282\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6863/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m━\u001b[0m \u001b[1m0s\u001b[0m 5ms/step - loss: 0.2962 - sparse_categorical_accuracy: 0.8916"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 15:00:00.912645: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 15:00:00.912727: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_8]]\n",
      "2024-03-26 15:00:00.912761: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 11843840089427515282\n",
      "2024-03-26 15:00:00.912774: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 17060731988147034527\n",
      "2024-03-26 15:00:00.912790: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 12498018463839324721\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6867/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m36s\u001b[0m 5ms/step - loss: 0.2962 - sparse_categorical_accuracy: 0.8916 - val_loss: 0.6272 - val_sparse_categorical_accuracy: 0.7992\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 15:00:03.329242: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 15:00:03.329292: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_6]]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Making directory ../results/PI_Explainability/bert_snli/run_2\n",
      "Epoch 1/100\n",
      "   6862/Unknown \u001b[1m34s\u001b[0m 5ms/step - loss: 0.3612 - sparse_categorical_accuracy: 0.8653"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 15:00:38.571712: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 15:00:38.571813: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_2]]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6867/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m37s\u001b[0m 5ms/step - loss: 0.3612 - sparse_categorical_accuracy: 0.8653 - val_loss: 0.5636 - val_sparse_categorical_accuracy: 0.8065\n",
      "Epoch 2/100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 15:00:41.011111: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 15:00:41.011177: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_4]]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6864/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m━\u001b[0m \u001b[1m0s\u001b[0m 5ms/step - loss: 0.3437 - sparse_categorical_accuracy: 0.8720"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 15:01:14.176647: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 15:01:14.176713: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_2]]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6867/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m35s\u001b[0m 5ms/step - loss: 0.3437 - sparse_categorical_accuracy: 0.8720 - val_loss: 0.5857 - val_sparse_categorical_accuracy: 0.8001\n",
      "Epoch 3/100\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-26 15:01:16.524773: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "2024-03-26 15:01:16.524851: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
      "\t [[{{node IteratorGetNext}}]]\n",
      "\t [[IteratorGetNext/_2]]\n",
      "2024-03-26 15:01:16.524911: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 11843840089427515282\n",
      "2024-03-26 15:01:16.524924: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 17060731988147034527\n",
      "2024-03-26 15:01:16.524933: I tensorflow/core/framework/local_rendezvous.cc:422] Local rendezvous recv item cancelled. Key hash: 6495130065868410671\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m6577/6867\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m━\u001b[0m \u001b[1m1s\u001b[0m 5ms/step - loss: 0.3325 - sparse_categorical_accuracy: 0.8766"
     ]
    }
   ],
   "source": [
    "for run in range(5):\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    exp_name = f'../results/PI_Explainability/bert_snli/run_{run+1}'\n",
    "    if not os.path.exists(exp_name):\n",
    "        print(\"Making directory\", exp_name)\n",
    "        os.makedirs(exp_name)\n",
    "        \n",
    "    early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_sparse_categorical_accuracy', patience=5, restore_best_weights=True, mode='max')\n",
    "    bert_classifier.fit(ds_train, validation_data=val_ds, epochs=100, callbacks=[early_stop,])\n",
    "    bert_classifier.save(exp_name+'/trained_model.keras')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04bfa397",
   "metadata": {},
   "source": [
    "### PMI Estimation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8a4dd23",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train a PMI critic on the intermediate activity of the classifier saved for a single run (run index 0).\n",
    "critic = 'separable'\n",
    "estimator = 'density_ratio_fitting'\n",
    "epochs = 100\n",
    "\n",
    "run = 0\n",
    "tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "exp_name = f'../results/PI_Explainability/bert_snli/run_{run+1}/calibration/pmi/{critic}_{estimator}'\n",
    "if not os.path.exists(exp_name):\n",
    "    print(\"Making directory\", exp_name)\n",
    "    os.makedirs(exp_name)\n",
    "\n",
    "# Reload the classifier trained for this run and expose the 'pooled_output' entry of its\n",
    "# third-from-last layer as the activity that the critic is trained on.\n",
    "model = tf.keras.models.load_model(f'../results/PI_Explainability/bert_snli/run_{run+1}/trained_model.keras')\n",
    "int_model = tf.keras.Model(inputs=model.inputs, outputs=model.layers[-3].output['pooled_output'])\n",
    "\n",
    "print(f'Training PMI model ({critic}, {estimator})...')\n",
    "# ds_train is already batched (batch size 16) in the cell that builds the datasets, so it must not\n",
    "# be batched a second time here; doing so would yield batches of batches.\n",
    "# NOTE(review): int_model is built on model.inputs - confirm whether the raw (hypothesis, premise)\n",
    "# strings have to go through the classifier's preprocessor before being passed to int_model.\n",
    "ds_activity = ds_train.map(lambda x, y: (int_model(x), y)).cache().prefetch(tf.data.AUTOTUNE)\n",
    "# NOTE(review): train_critic_model is not imported in the cells visible in this notebook - import it\n",
    "# from the project module that defines it before running this cell.\n",
    "train_critic_model(ds_activity, critic=critic, estimator=estimator, epochs=epochs, save_path=f'{exp_name}/pmi_model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "6d0d9c83",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<KerasTensor shape=(None, 128), dtype=float32, sparse=False, name=keras_tensor_79>"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cb3cc0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# np is used below but is missing from the notebook's import cell, so it is imported here.\n",
    "import numpy as np\n",
    "\n",
    "# NOTE(review): model_name, dataset_name, cfg, load_dataset, preprocess_dataset, train_critic_model and\n",
    "# neural_pmi are not defined in the cells visible in this notebook - define or import them before running.\n",
    "\n",
    "critic = 'separable'\n",
    "estimator = 'density_ratio_fitting'\n",
    "epochs = 100\n",
    "\n",
    "\n",
    "def _pmi_for_all_classes(ds, int_model, pmi_model, n_classes, batch_size, estimator):\n",
    "    \"\"\"Evaluate the PMI model on every batch of `ds`, once per candidate class.\n",
    "\n",
    "    For each class index k, the activity of every sample (the output of `int_model`) is paired\n",
    "    with the one-hot encoding of k instead of the sample's own label and passed to `neural_pmi`.\n",
    "\n",
    "    Args:\n",
    "        ds: unbatched dataset of (x, y) pairs; it is batched here with `batch_size`.\n",
    "        int_model: model mapping an input batch to its intermediate activity.\n",
    "        pmi_model: trained PMI model handed to `neural_pmi`.\n",
    "        n_classes: number of classes to evaluate.\n",
    "        batch_size: batch size used when iterating over `ds`.\n",
    "        estimator: estimator name forwarded to `neural_pmi`.\n",
    "\n",
    "    Returns:\n",
    "        The per-class lists of PMI values stacked into one array and transposed.\n",
    "    \"\"\"\n",
    "    pmi_class = []\n",
    "    for k in range(n_classes):\n",
    "        ds_activity = ds.batch(batch_size).map(lambda x, y: (int_model(x), tf.one_hot(tf.fill([tf.shape(x)[0]], k), depth=n_classes))).cache().prefetch(tf.data.AUTOTUNE)\n",
    "        pmi_list = []\n",
    "        for (x_batch, y_batch) in ds_activity:\n",
    "            pmi = neural_pmi(x_batch, y_batch, pmi_model, estimator=estimator)\n",
    "            pmi_list += np.array(pmi).tolist()\n",
    "        pmi_class.append(pmi_list)\n",
    "    return np.array(pmi_class).T\n",
    "\n",
    "\n",
    "for run in range(5):\n",
    "    print(f'Run: {run+1}')\n",
    "    tf.keras.utils.set_random_seed(run+10) # set random seed for Python, NumPy, and TensorFlow\n",
    "    exp_name = f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/calibration/pmi/{critic}_{estimator}'\n",
    "    if not os.path.exists(exp_name):\n",
    "        print(\"Making directory\", exp_name)\n",
    "        os.makedirs(exp_name)\n",
    "\n",
    "    ds_train, ds_val, ds_test, ds_info = load_dataset(cfg)\n",
    "    n_classes = ds_info.features['label'].num_classes\n",
    "    ds_train = preprocess_dataset(ds_train, cfg, n_classes, resize=True, normalize=True, onehot=True)\n",
    "    ds_val = preprocess_dataset(ds_val, cfg, n_classes, resize=True, normalize=True, onehot=True)\n",
    "    ds_test = preprocess_dataset(ds_test, cfg, n_classes, resize=True, normalize=True, onehot=True)\n",
    "\n",
    "    # Reload the classifier trained for this run and expose its second-to-last layer as the activity.\n",
    "    model = tf.keras.models.load_model(f'../results/PI_Explainability/{model_name}_{dataset_name}/run_{run+1}/trained_model.keras')\n",
    "    int_model = tf.keras.Model(inputs=model.inputs, outputs=model.layers[-2].output)\n",
    "\n",
    "    ##############################################################\n",
    "    #\n",
    "    # Train PMI Model\n",
    "    #\n",
    "    ##############################################################\n",
    "\n",
    "    print(f'Training PMI model ({critic}, {estimator})...')\n",
    "    ds_activity = ds_train.batch(cfg['batch_size']).map(lambda x, y: (int_model(x), y)).cache().prefetch(tf.data.AUTOTUNE)\n",
    "    train_critic_model(ds_activity, critic=critic, estimator=estimator, epochs=epochs, save_path=f'{exp_name}/pmi_model')\n",
    "\n",
    "    ##############################################################\n",
    "    #\n",
    "    # Compute PMI for all validation and test samples\n",
    "    #\n",
    "    ##############################################################\n",
    "\n",
    "    pmi_model = tf.keras.models.load_model(f'{exp_name}/pmi_model')\n",
    "\n",
    "    print(f'Computing PMI for all validation samples and for all classes...')\n",
    "    pmi_val = _pmi_for_all_classes(ds_val, int_model, pmi_model, n_classes, cfg['batch_size'], estimator)\n",
    "    np.save(f'{exp_name}/pmi_class_val.npy', pmi_val)\n",
    "\n",
    "    print(f'Computing PMI for all test samples and for all classes...')\n",
    "    pmi_test = _pmi_for_all_classes(ds_test, int_model, pmi_model, n_classes, cfg['batch_size'], estimator)\n",
    "    np.save(f'{exp_name}/pmi_class_test.npy', pmi_test)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
