{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7750490f",
   "metadata": {},
   "source": [
    "# Predicting cause of death from PHMRC VA text narratives using BERT\n",
    "References: <br>\n",
    "https://github.com/theartificialguy/NLP-with-Deep-Learning/blob/master/BERT/Multi-Class%20classification%20TF-BERT/multi_class.ipynb <br>\n",
    "https://medium.com/@roshmitadey/understanding-language-modeling-from-n-grams-to-transformer-based-neural-models-d2bdf1532c6d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c4c1ae99",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-03-05 17:20:31.965218: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import random \n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "import tensorflow_hub as hub\n",
    "import tensorflow_text as text\n",
    "import json\n",
    "from tqdm.auto import tqdm\n",
    "from transformers import BertTokenizer\n",
    "from nltk.tokenize import word_tokenize\n",
    "from nltk import pos_tag\n",
    "from nltk.corpus import stopwords\n",
    "from nltk.corpus import wordnet as wn\n",
    "from nltk.stem import WordNetLemmatizer\n",
    "from sklearn.preprocessing import LabelEncoder, OneHotEncoder\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "from sklearn import model_selection, naive_bayes, svm\n",
    "from sklearn.metrics import accuracy_score, f1_score, classification_report\n",
    "from sklearn.utils.class_weight import compute_sample_weight\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.model_selection import train_test_split\n",
    "from collections import defaultdict\n",
    "from transformers import TFBertModel\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c692059d",
   "metadata": {},
   "source": [
    "## Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2f36a070",
   "metadata": {},
   "outputs": [],
   "source": [
    "# GitHub CSV file URL\n",
    "# raw-file URL of the tokenized PHMRC adult dataset in the va_nlp GitHub repo\n",
    "repo_raw = 'https://raw.githubusercontent.com/avisokay/va_nlp/main'\n",
    "url = f'{repo_raw}/data/phmrc/phmrc_adult_tokenized.csv'\n",
    "df = pd.read_csv(url)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7106a88",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train_X, Test_X, Train_Y, Test_Y = model_selection.train_test_split(df['tags'],df['gs_cod'],test_size=0.3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "924c1100",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # just age covariate\n",
    "# Test_X_covariates = df.loc[Test_X.index]['age_yr']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01364194",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # train\n",
    "# # Read in CSV files and store in dictionary\n",
    "# train_excluded_dict = {}\n",
    "# for region in regions:\n",
    "#     file_path = f'https://raw.githubusercontent.com/avisokay/va_nlp/main/data/train_test_val/train_ex_{region.lower()}.csv'\n",
    "#     train_excluded_dict[region] = pd.read_csv(file_path)\n",
    "    \n",
    "# # assign training data df names\n",
    "# train_ex_ap = train_excluded_dict['ap']\n",
    "# train_ex_dar = train_excluded_dict['dar']\n",
    "# train_ex_pemba = train_excluded_dict['pemba']\n",
    "# train_ex_mexico = train_excluded_dict['mexico']\n",
    "# train_ex_bohol = train_excluded_dict['bohol']\n",
    "# train_ex_up = train_excluded_dict['up']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd9983d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # test / val\n",
    "\n",
    "# # Dictionary to store DataFrames\n",
    "# test_dict = {}\n",
    "# val_dict = {}\n",
    "\n",
    "# # Read in test and validation CSV files and store in dictionaries\n",
    "# for region in regions:\n",
    "#     test_file_path = f'../../data/train_test_val/test_{region}.csv'\n",
    "#     val_file_path = f'../../data/train_test_val/val_{region}.csv'\n",
    "    \n",
    "#     test_dict[region] = pd.read_csv(test_file_path)\n",
    "#     val_dict[region] = pd.read_csv(val_file_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8d621b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # assign test and val data df names\n",
    "# test_ap = test_dict['ap']\n",
    "# test_dar = test_dict['dar']\n",
    "# test_pemba = test_dict['pemba']\n",
    "# test_mexico = test_dict['mexico']\n",
    "# test_bohol = test_dict['bohol']\n",
    "# test_up = test_dict['up']\n",
    "\n",
    "# val_ap = val_dict['ap']\n",
    "# val_dar = val_dict['dar']\n",
    "# val_pemba = val_dict['pemba']\n",
    "# val_mexico = val_dict['mexico']\n",
    "# val_bohol = val_dict['bohol']\n",
    "# val_up = val_dict['up']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c4c90e4",
   "metadata": {},
   "source": [
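    "## Leave-one-site-out evaluation\n",
    "Set `site_excluded` to one site: the model is trained on narratives from all other sites and tested on the held-out site."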
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3ddaee64",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['mexico', 'ap', 'up', 'dar', 'bohol', 'pemba'], dtype=object)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# sites available for leave-one-site-out evaluation (order of first appearance)\n",
    "pd.unique(df['site'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "92d45318",
   "metadata": {},
   "outputs": [],
   "source": [
    "site_excluded = 'pemba'\n",
    "\n",
    "# fail fast on a typo: an unknown site would silently produce an empty test set\n",
    "assert site_excluded in set(df['site']), f'unknown site: {site_excluded!r}'\n",
    "\n",
    "# leave-one-site-out split: train on every other site, test on the held-out site\n",
    "is_test_site = df['site'] == site_excluded\n",
    "\n",
    "Train_X = df.loc[~is_test_site, 'tags']\n",
    "Test_X = df.loc[is_test_site, 'tags']\n",
    "Train_Y = df.loc[~is_test_site, 'gs_cod']\n",
    "Test_Y = df.loc[is_test_site, 'gs_cod']\n",
    "Test_X_covariates = df.loc[is_test_site, 'age_yr']\n",
    "\n",
    "Test_X_covariates.to_csv(f'{site_excluded}_covariates.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f6f98d4b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "76f71bf9be034ad7b248ebe74b2c939b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "All the weights of TFBertModel were initialized from the PyTorch model.\n",
      "If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions without further training.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "324/324 [==============================] - 6950s 21s/step - loss: 1.2102 - accuracy: 0.5858 - val_loss: 1.0088 - val_accuracy: 0.6608\n",
      "INFO:tensorflow:Assets written to: ../../models/bert_cod_pemba/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: ../../models/bert_cod_pemba/assets\n"
     ]
    }
   ],
   "source": [
    "## Pre-processing: encode the training narratives for BERT\n",
    "\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-cased')\n",
    "\n",
    "# preallocate one 256-wide row of token ids and one of attention-mask values per\n",
    "# training narrative (np.zeros default dtype: float64)\n",
    "X_input_ids, X_attn_masks = (np.zeros((len(Train_X), 256)) for _ in range(2))\n",
    "\n",
    "# # for inputting whole dataframe\n",
    "# def generate_training_data(df, ids, masks, tokenizer):\n",
    "#     for i, text in tqdm(enumerate(df['narrative'])):\n",
    "#         tokenized_text = tokenizer.encode_plus(\n",
    "#             text,\n",
    "#             max_length=256, \n",
    "#             truncation=True, \n",
    "#             padding='max_length', \n",
    "#             add_special_tokens=True,\n",
    "#             return_tensors='tf'\n",
    "#         )\n",
    "#         ids[i, :] = tokenized_text.input_ids\n",
    "#         masks[i, :] = tokenized_text.attention_mask\n",
    "#     return ids, masks\n",
    "\n",
    "# for inputting a specific list of narratives\n",
    "def generate_training_data(narratives, ids, masks, tokenizer):\n",
    "    \"\"\"Fill `ids` and `masks` in place with the BERT encoding of each narrative.\n",
    "\n",
    "    Row i (by position, not index label) receives the input ids / attention mask of\n",
    "    the i-th narrative, padded or truncated to 256 tokens with special tokens added.\n",
    "    `ids` and `masks` must be preallocated with at least one 256-wide row per\n",
    "    narrative. Returns the same two arrays.\n",
    "    \"\"\"\n",
    "    # give tqdm the length so it can show real progress (enumerate() hides it)\n",
    "    total = len(narratives) if hasattr(narratives, '__len__') else None\n",
    "    for i, text in tqdm(enumerate(narratives), total=total):\n",
    "        tokenized_text = tokenizer.encode_plus(\n",
    "            text,\n",
    "            max_length=256,\n",
    "            truncation=True,\n",
    "            padding='max_length',\n",
    "            add_special_tokens=True,\n",
    "            return_tensors='np'  # plain numpy rows; no per-row TF tensor is needed here\n",
    "        )\n",
    "        ids[i, :] = tokenized_text.input_ids\n",
    "        masks[i, :] = tokenized_text.attention_mask\n",
    "    return ids, masks\n",
    "\n",
    "X_input_ids, X_attn_masks = generate_training_data(Train_X, X_input_ids, X_attn_masks, tokenizer)\n",
    "\n",
    "# create one-hot encoded target tensor from output classes\n",
    "nominal_vector = np.array(Train_Y).reshape(-1, 1)\n",
    "\n",
    "# Create an instance of OneHotEncoder (categories='auto': learned from Train_Y, sorted)\n",
    "# NOTE: the `sparse=` keyword was renamed `sparse_output=` in scikit-learn 1.2 and removed\n",
    "# in 1.4, so keep the default sparse output and densify with .toarray(); this works on\n",
    "# every scikit-learn version\n",
    "encoder = OneHotEncoder()\n",
    "\n",
    "# Fit and transform the nominal vector into a dense (n_samples, n_classes) array\n",
    "labels = encoder.fit_transform(nominal_vector).toarray()\n",
    "\n",
    "# creating a data pipeline using tensorflow dataset utility, creates batches of data for easy training\n",
    "dataset = tf.data.Dataset.from_tensor_slices((X_input_ids, X_attn_masks, labels))\n",
    "\n",
    "# map function to return correct batch\n",
    "def SentimentDatasetMapFunction(input_ids, attn_masks, labels):\n",
    "    \"\"\"Repackage one (ids, mask, label) element as the (features, label) pair Keras expects.\n",
    "\n",
    "    The feature-dict keys match the names of the model's two Input layers.\n",
    "    \"\"\"\n",
    "    return {\n",
    "        'input_ids': input_ids,\n",
    "        'attention_mask': attn_masks\n",
    "    }, labels\n",
    "\n",
    "# converting to required format for tensorflow dataset\n",
    "dataset = dataset.map(SentimentDatasetMapFunction)\n",
    "\n",
    "# Shuffle ONCE with a fixed seed, then batch (dropping the last partial batch).\n",
    "# reshuffle_each_iteration=False is essential: take()/skip() below slice the shuffled\n",
    "# stream, and with the default (True) every pass over train_dataset or val_dataset\n",
    "# draws a fresh permutation, so validation batches would contain training examples\n",
    "# and val_accuracy would be optimistic. (Training batch order is therefore the same\n",
    "# every epoch.)\n",
    "SHUFFLE_SEED = 42\n",
    "dataset = dataset.shuffle(10000, seed=SHUFFLE_SEED, reshuffle_each_iteration=False).batch(16, drop_remainder=True)\n",
    "\n",
    "# 80/20 train/validation split, counted in batches of 16\n",
    "p = 0.8\n",
    "train_size = int((len(Train_X)//16)*p)\n",
    "\n",
    "train_dataset = dataset.take(train_size)\n",
    "val_dataset = dataset.skip(train_size)\n",
    "\n",
    "## Build the BERT model\n",
    "\n",
    "model = TFBertModel.from_pretrained('bert-base-cased') # bert base model with pretrained weights\n",
    "\n",
    "# number of output classes, read from the one-hot targets so the softmax width\n",
    "# cannot drift from the label vectors the loss is computed against\n",
    "layers = labels.shape[1]\n",
    "\n",
    "# defining 2 input layers for input_ids and attn_masks (names must match the dataset dict keys)\n",
    "input_ids = tf.keras.layers.Input(shape=(256,), name='input_ids', dtype='int32')\n",
    "attn_masks = tf.keras.layers.Input(shape=(256,), name='attention_mask', dtype='int32')\n",
    "\n",
    "bert_embds = model.bert(input_ids, attention_mask=attn_masks)[1] # 0 -> activation layer (3D), 1 -> pooled output layer (2D)\n",
    "intermediate_layer = tf.keras.layers.Dense(512, activation='relu', name='intermediate_layer')(bert_embds)\n",
    "output_layer = tf.keras.layers.Dense(layers, activation='softmax', name='output_layer')(intermediate_layer) # softmax -> calcs probs of classes\n",
    "\n",
    "cod_model = tf.keras.Model(inputs=[input_ids, attn_masks], outputs=output_layer)\n",
    "# cod_model.summary()\n",
    "\n",
    "# loss function, optimizer, and accuracy metric\n",
    "optim = tf.keras.optimizers.legacy.Adam(learning_rate=1e-5, decay=1e-6)\n",
    "loss_func = tf.keras.losses.CategoricalCrossentropy()\n",
    "acc = tf.keras.metrics.CategoricalAccuracy('accuracy')\n",
    "\n",
    "cod_model.compile(optimizer=optim, loss=loss_func, metrics=[acc])\n",
    "\n",
    "## Train BERT model\n",
    "\n",
    "# LONG RUN TIME: the recorded run took ~6950 s for a single epoch\n",
    "N_EPOCHS = 1  # can use more epochs (20-25) on better machine with good GPU\n",
    "hist = cod_model.fit(train_dataset, validation_data=val_dataset, epochs=N_EPOCHS)\n",
    "\n",
    "# save the model trained without `site_excluded`, named after the held-out site\n",
    "model_dir = f'../../models/bert_cod_{site_excluded}'\n",
    "cod_model.save(model_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6223ed32",
   "metadata": {},
   "source": [
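    "## Prediction with the trained BERT model\n",
    "Optionally reload the model saved for `site_excluded` instead of retraining, then define the tokenization and prediction helpers."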
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90f5c8ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the model trained WITHOUT the held-out site; loading another site's model\n",
    "# (e.g. bert_cod_mexico while site_excluded = 'pemba') would evaluate on narratives it was trained on\n",
    "cod_model = tf.keras.models.load_model(f'../../models/bert_cod_{site_excluded}/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "39120e7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = BertTokenizer.from_pretrained('bert-base-cased')\n",
    "\n",
    "# class names in the model's output order: the one-hot encoder was fit on Train_Y, so its\n",
    "# categories are the sorted unique TRAINING causes. Taking them from all of df would shift\n",
    "# the indices whenever a cause occurs only at the held-out site.\n",
    "unique_cod = list(np.unique(Train_Y.values))\n",
    "\n",
    "def prepare_data(input_text, tokenizer):\n",
    "    \"\"\"Encode one narrative into the feature dict the classifier takes.\n",
    "\n",
    "    The text is padded/truncated to 256 BERT tokens (special tokens included) and\n",
    "    both tensors are cast to float64, the dtype of the np.zeros training arrays.\n",
    "    \"\"\"\n",
    "    encoded = tokenizer.encode_plus(\n",
    "        input_text,\n",
    "        max_length=256,\n",
    "        truncation=True,\n",
    "        padding='max_length',\n",
    "        add_special_tokens=True,\n",
    "        return_tensors='tf',\n",
    "    )\n",
    "    feature_names = ('input_ids', 'attention_mask')\n",
    "    return {name: tf.cast(encoded[name], tf.float64) for name in feature_names}\n",
    "\n",
    "def make_prediction(model, processed_data, classes=unique_cod):\n",
    "    \"\"\"Return the class name with the highest predicted probability for one encoded narrative.\n",
    "\n",
    "    `processed_data` is a single-example feature dict such as prepare_data returns;\n",
    "    `classes` must list the class names in the model's output-unit order.\n",
    "    \"\"\"\n",
    "    # verbose=0: this is called once per narrative, and the default progress bar\n",
    "    # prints a line for every call\n",
    "    probs = model.predict(processed_data, verbose=0)[0]\n",
    "    return classes[np.argmax(probs)]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fdfbdbbb",
   "metadata": {},
   "source": [
    "## Interactive tester (single narrative)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "faa106bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# input_text = input('Enter death narrative here: ')\n",
    "# processed = prepare_data(input_text, tokenizer)\n",
    "# result = make_prediction(cod_model, processed_data=processed)\n",
    "# print(f\"Predicted cause of death: {result}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71dd4070",
   "metadata": {},
   "source": [
    "## Predict causes of death for the held-out site"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bd343419",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c3df832e02ec489da7ce73ada26e246f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/260 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1/1 [==============================] - 3s 3s/step\n",
      "1/1 [==============================] - 0s 290ms/step\n",
      "1/1 [==============================] - 0s 292ms/step\n",
      "1/1 [==============================] - 0s 297ms/step\n",
      "1/1 [==============================] - 0s 301ms/step\n",
      "1/1 [==============================] - 0s 294ms/step\n",
      "1/1 [==============================] - 0s 315ms/step\n",
      "1/1 [==============================] - 0s 295ms/step\n",
      "1/1 [==============================] - 0s 293ms/step\n",
      "1/1 [==============================] - 0s 293ms/step\n",
      "1/1 [==============================] - 0s 293ms/step\n",
      "1/1 [==============================] - 0s 293ms/step\n",
      "1/1 [==============================] - 0s 301ms/step\n",
      "1/1 [==============================] - 0s 297ms/step\n",
      "1/1 [==============================] - 0s 298ms/step\n",
      "1/1 [==============================] - 0s 292ms/step\n",
      "1/1 [==============================] - 0s 296ms/step\n",
      "1/1 [==============================] - 0s 290ms/step\n",
      "1/1 [==============================] - 0s 292ms/step\n",
      "1/1 [==============================] - 0s 294ms/step\n",
      "1/1 [==============================] - 0s 292ms/step\n",
      "1/1 [==============================] - 0s 296ms/step\n",
      "1/1 [==============================] - 0s 302ms/step\n",
      "1/1 [==============================] - 0s 315ms/step\n",
      "1/1 [==============================] - 0s 318ms/step\n",
      "1/1 [==============================] - 0s 326ms/step\n",
      "1/1 [==============================] - 0s 335ms/step\n",
      "1/1 [==============================] - 0s 350ms/step\n",
      "1/1 [==============================] - 0s 335ms/step\n",
      "1/1 [==============================] - 0s 331ms/step\n",
      "1/1 [==============================] - 0s 333ms/step\n",
      "1/1 [==============================] - 0s 323ms/step\n",
      "1/1 [==============================] - 0s 322ms/step\n",
      "1/1 [==============================] - 0s 319ms/step\n",
      "1/1 [==============================] - 0s 316ms/step\n",
      "1/1 [==============================] - 0s 317ms/step\n",
      "1/1 [==============================] - 0s 320ms/step\n",
      "1/1 [==============================] - 0s 318ms/step\n",
      "1/1 [==============================] - 0s 309ms/step\n",
      "1/1 [==============================] - 0s 312ms/step\n",
      "1/1 [==============================] - 0s 316ms/step\n",
      "1/1 [==============================] - 0s 314ms/step\n",
      "1/1 [==============================] - 0s 324ms/step\n",
      "1/1 [==============================] - 0s 310ms/step\n",
      "1/1 [==============================] - 0s 313ms/step\n",
      "1/1 [==============================] - 0s 313ms/step\n",
      "1/1 [==============================] - 0s 310ms/step\n",
      "1/1 [==============================] - 0s 323ms/step\n",
      "1/1 [==============================] - 0s 322ms/step\n",
      "1/1 [==============================] - 0s 328ms/step\n",
      "1/1 [==============================] - 0s 327ms/step\n",
      "1/1 [==============================] - 0s 353ms/step\n",
      "1/1 [==============================] - 0s 324ms/step\n",
      "1/1 [==============================] - 0s 322ms/step\n",
      "1/1 [==============================] - 0s 311ms/step\n",
      "1/1 [==============================] - 0s 300ms/step\n",
      "1/1 [==============================] - 0s 292ms/step\n",
      "1/1 [==============================] - 0s 296ms/step\n",
      "1/1 [==============================] - 0s 294ms/step\n",
      "1/1 [==============================] - 0s 294ms/step\n",
      "1/1 [==============================] - 0s 297ms/step\n",
      "1/1 [==============================] - 0s 304ms/step\n",
      "1/1 [==============================] - 0s 304ms/step\n",
      "1/1 [==============================] - 0s 309ms/step\n",
      "1/1 [==============================] - 0s 318ms/step\n",
      "1/1 [==============================] - 0s 318ms/step\n",
      "1/1 [==============================] - 0s 332ms/step\n",
      "1/1 [==============================] - 0s 338ms/step\n",
      "1/1 [==============================] - 0s 330ms/step\n",
      "1/1 [==============================] - 0s 331ms/step\n",
      "1/1 [==============================] - 0s 341ms/step\n",
      "1/1 [==============================] - 0s 330ms/step\n",
      "1/1 [==============================] - 0s 318ms/step\n",
      "1/1 [==============================] - 0s 308ms/step\n",
      "1/1 [==============================] - 0s 312ms/step\n",
      "1/1 [==============================] - 0s 316ms/step\n",
      "1/1 [==============================] - 0s 314ms/step\n",
      "1/1 [==============================] - 0s 312ms/step\n",
      "1/1 [==============================] - 0s 315ms/step\n",
      "1/1 [==============================] - 0s 310ms/step\n",
      "1/1 [==============================] - 0s 318ms/step\n",
      "1/1 [==============================] - 0s 324ms/step\n",
      "1/1 [==============================] - 0s 322ms/step\n",
      "1/1 [==============================] - 0s 317ms/step\n",
      "1/1 [==============================] - 0s 331ms/step\n",
      "1/1 [==============================] - 0s 323ms/step\n",
      "1/1 [==============================] - 0s 321ms/step\n",
      "1/1 [==============================] - 0s 326ms/step\n",
      "1/1 [==============================] - 0s 321ms/step\n",
      "1/1 [==============================] - 0s 322ms/step\n",
      "1/1 [==============================] - 0s 322ms/step\n",
      "1/1 [==============================] - 0s 322ms/step\n",
      "1/1 [==============================] - 0s 316ms/step\n",
      "1/1 [==============================] - 0s 320ms/step\n",
      "1/1 [==============================] - 0s 316ms/step\n",
      "1/1 [==============================] - 0s 313ms/step\n",
      "1/1 [==============================] - 0s 315ms/step\n",
      "1/1 [==============================] - 0s 312ms/step\n",
      "1/1 [==============================] - 0s 315ms/step\n",
      "1/1 [==============================] - 0s 319ms/step\n",
      "1/1 [==============================] - 0s 320ms/step\n",
      "1/1 [==============================] - 0s 323ms/step\n",
      "1/1 [==============================] - 0s 324ms/step\n",
      "1/1 [==============================] - 0s 325ms/step\n",
      "1/1 [==============================] - 0s 320ms/step\n",
      "1/1 [==============================] - 0s 323ms/step\n",
      "1/1 [==============================] - 0s 323ms/step\n",
      "1/1 [==============================] - 0s 325ms/step\n",
      "1/1 [==============================] - 0s 333ms/step\n",
      "1/1 [==============================] - 0s 322ms/step\n",
      "1/1 [==============================] - 0s 323ms/step\n",
      "1/1 [==============================] - 0s 326ms/step\n",
      "1/1 [==============================] - 0s 323ms/step\n",
      "1/1 [==============================] - 0s 316ms/step\n",
      "1/1 [==============================] - 0s 315ms/step\n",
      "1/1 [==============================] - 0s 315ms/step\n",
      "1/1 [==============================] - 0s 317ms/step\n",
      "1/1 [==============================] - 0s 310ms/step\n",
      "1/1 [==============================] - 0s 311ms/step\n",
      "1/1 [==============================] - 0s 320ms/step\n",
      "1/1 [==============================] - 0s 331ms/step\n",
      "1/1 [==============================] - 0s 324ms/step\n",
      "1/1 [==============================] - 0s 319ms/step\n",
      "1/1 [==============================] - 0s 321ms/step\n",
      "1/1 [==============================] - 0s 323ms/step\n",
      "1/1 [==============================] - 0s 320ms/step\n",
      "1/1 [==============================] - 0s 324ms/step\n",
      "1/1 [==============================] - 0s 322ms/step\n",
      "1/1 [==============================] - 0s 320ms/step\n",
      "1/1 [==============================] - 0s 321ms/step\n",
      "1/1 [==============================] - 0s 323ms/step\n",
      "1/1 [==============================] - 0s 321ms/step\n",
      "1/1 [==============================] - 0s 327ms/step\n",
      "1/1 [==============================] - 0s 322ms/step\n",
      "1/1 [==============================] - 0s 326ms/step\n",
      "1/1 [==============================] - 0s 325ms/step\n",
      "1/1 [==============================] - 0s 326ms/step\n",
      "1/1 [==============================] - 0s 327ms/step\n",
      "1/1 [==============================] - 0s 312ms/step\n",
      "1/1 [==============================] - 0s 312ms/step\n",
      "1/1 [==============================] - 0s 313ms/step\n",
      "1/1 [==============================] - 0s 316ms/step\n",
      "1/1 [==============================] - 0s 310ms/step\n",
      "1/1 [==============================] - 0s 317ms/step\n",
      "1/1 [==============================] - 0s 329ms/step\n",
      "1/1 [==============================] - 0s 320ms/step\n",
      "1/1 [==============================] - 0s 321ms/step\n",
      "1/1 [==============================] - 0s 320ms/step\n",
      "1/1 [==============================] - 0s 322ms/step\n",
      "1/1 [==============================] - 0s 329ms/step\n",
      "1/1 [==============================] - 0s 323ms/step\n",
      "1/1 [==============================] - 0s 324ms/step\n",
      "1/1 [==============================] - 0s 322ms/step\n",
      "1/1 [==============================] - 0s 321ms/step\n",
      "1/1 [==============================] - 0s 325ms/step\n",
      "1/1 [==============================] - 0s 325ms/step\n",
      "1/1 [==============================] - 0s 323ms/step\n",
      "1/1 [==============================] - 0s 321ms/step\n",
      "1/1 [==============================] - 0s 320ms/step\n",
      "1/1 [==============================] - 0s 321ms/step\n",
      "1/1 [==============================] - 0s 320ms/step\n",
      "1/1 [==============================] - 0s 320ms/step\n",
      "1/1 [==============================] - 0s 321ms/step\n",
      "1/1 [==============================] - 0s 321ms/step\n",
      "1/1 [==============================] - 0s 320ms/step\n",
      "1/1 [==============================] - 0s 319ms/step\n",
      "1/1 [==============================] - 0s 321ms/step\n",
      "1/1 [==============================] - 0s 321ms/step\n",
      "1/1 [==============================] - 0s 321ms/step\n",
      "1/1 [==============================] - 0s 324ms/step\n",
      "1/1 [==============================] - 0s 324ms/step\n",
      "1/1 [==============================] - 0s 317ms/step\n",
      "1/1 [==============================] - 0s 317ms/step\n",
      "1/1 [==============================] - 0s 320ms/step\n",
      "1/1 [==============================] - 0s 320ms/step\n",
      "1/1 [==============================] - 0s 312ms/step\n",
      "1/1 [==============================] - 0s 312ms/step\n",
      "1/1 [==============================] - 0s 318ms/step\n",
      "1/1 [==============================] - 0s 312ms/step\n",
      "1/1 [==============================] - 0s 314ms/step\n",
      "1/1 [==============================] - 0s 317ms/step\n",
      "1/1 [==============================] - 0s 322ms/step\n",
      "1/1 [==============================] - 0s 315ms/step\n",
      "1/1 [==============================] - 0s 328ms/step\n",
      "1/1 [==============================] - 0s 317ms/step\n",
      "1/1 [==============================] - 0s 319ms/step\n",
      "1/1 [==============================] - 0s 340ms/step\n",
      "1/1 [==============================] - 0s 337ms/step\n",
      "1/1 [==============================] - 0s 340ms/step\n",
      "1/1 [==============================] - 0s 331ms/step\n",
      "1/1 [==============================] - 0s 324ms/step\n",
      "1/1 [==============================] - 0s 308ms/step\n",
      "1/1 [==============================] - 0s 311ms/step\n",
      "1/1 [==============================] - 0s 314ms/step\n",
      "1/1 [==============================] - 0s 321ms/step\n",
      "1/1 [==============================] - 0s 312ms/step\n",
      "1/1 [==============================] - 0s 317ms/step\n",
      "1/1 [==============================] - 0s 313ms/step\n",
      "1/1 [==============================] - 0s 320ms/step\n",
      "1/1 [==============================] - 0s 322ms/step\n",
      "1/1 [==============================] - 0s 325ms/step\n",
      "1/1 [==============================] - 0s 324ms/step\n",
      "1/1 [==============================] - 0s 333ms/step\n",
      "1/1 [==============================] - 0s 334ms/step\n",
      "1/1 [==============================] - 0s 328ms/step\n",
      "1/1 [==============================] - 0s 336ms/step\n",
      "1/1 [==============================] - 0s 323ms/step\n",
      "1/1 [==============================] - 0s 321ms/step\n",
      "1/1 [==============================] - 0s 322ms/step\n",
      "1/1 [==============================] - 0s 321ms/step\n",
      "1/1 [==============================] - 0s 327ms/step\n",
      "1/1 [==============================] - 0s 335ms/step\n",
      "1/1 [==============================] - 0s 318ms/step\n",
      "1/1 [==============================] - 0s 313ms/step\n",
      "1/1 [==============================] - 0s 316ms/step\n",
      "1/1 [==============================] - 0s 315ms/step\n",
      "1/1 [==============================] - 0s 315ms/step\n",
      "1/1 [==============================] - 0s 324ms/step\n",
      "1/1 [==============================] - 0s 322ms/step\n",
      "1/1 [==============================] - 0s 330ms/step\n",
      "1/1 [==============================] - 0s 333ms/step\n",
      "1/1 [==============================] - 0s 334ms/step\n",
      "1/1 [==============================] - 0s 331ms/step\n",
      "1/1 [==============================] - 0s 337ms/step\n",
      "1/1 [==============================] - 0s 328ms/step\n",
      "1/1 [==============================] - 0s 325ms/step\n",
      "1/1 [==============================] - 0s 326ms/step\n",
      "1/1 [==============================] - 0s 327ms/step\n",
      "1/1 [==============================] - 0s 316ms/step\n",
      "1/1 [==============================] - 0s 315ms/step\n",
      "1/1 [==============================] - 0s 306ms/step\n",
      "1/1 [==============================] - 0s 319ms/step\n",
      "1/1 [==============================] - 0s 319ms/step\n",
      "1/1 [==============================] - 0s 321ms/step\n",
      "1/1 [==============================] - 0s 328ms/step\n",
      "1/1 [==============================] - 0s 326ms/step\n",
      "1/1 [==============================] - 0s 336ms/step\n",
      "1/1 [==============================] - 0s 330ms/step\n",
      "1/1 [==============================] - 0s 347ms/step\n",
      "1/1 [==============================] - 0s 345ms/step\n",
      "1/1 [==============================] - 0s 338ms/step\n",
      "1/1 [==============================] - 0s 328ms/step\n",
      "1/1 [==============================] - 0s 317ms/step\n",
      "1/1 [==============================] - 0s 321ms/step\n",
      "1/1 [==============================] - 0s 314ms/step\n",
      "1/1 [==============================] - 0s 313ms/step\n",
      "1/1 [==============================] - 0s 309ms/step\n",
      "1/1 [==============================] - 0s 318ms/step\n",
      "1/1 [==============================] - 0s 326ms/step\n",
      "1/1 [==============================] - 0s 320ms/step\n",
      "1/1 [==============================] - 0s 318ms/step\n",
      "1/1 [==============================] - 0s 323ms/step\n",
      "1/1 [==============================] - 0s 331ms/step\n",
      "1/1 [==============================] - 0s 329ms/step\n",
      "1/1 [==============================] - 0s 336ms/step\n",
      "1/1 [==============================] - 0s 334ms/step\n",
      "1/1 [==============================] - 0s 334ms/step\n",
      "1/1 [==============================] - 0s 326ms/step\n",
      "1/1 [==============================] - 0s 320ms/step\n",
      "1/1 [==============================] - 0s 317ms/step\n"
     ]
    }
   ],
   "source": [
    "# SLOW\n",
    "predictions_bert_text = []\n",
    "for narrative in tqdm(Test_X):\n",
    "    processed = prepare_data(narrative, tokenizer)\n",
    "    predictions_bert_text.append(make_prediction(cod_model, processed_data=processed))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e598180",
   "metadata": {},
   "source": [
    "## Convert cause-of-death labels to integer codes using the dictionary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e192b1bf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1be451d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the code -> cause-of-death label dictionary from the JSON file\n",
    "with open('../classic_nlp/cod_embeddings.json', 'r') as file:\n",
    "    loaded_dict = json.load(file)\n",
    "\n",
    "# JSON object keys are always strings; convert them back to integer codes\n",
    "cod_embeddings = {int(key): value for key, value in loaded_dict.items()}\n",
    "\n",
    "# Invert the mapping once (label -> code) so each lookup is O(1) instead of a scan over the whole dict.\n",
    "# A label shared by several codes would make the inversion ambiguous, so refuse it.\n",
    "label_to_code = {value: key for key, value in cod_embeddings.items()}\n",
    "assert len(label_to_code) == len(cod_embeddings), 'cod_embeddings maps several codes to the same label'\n",
    "\n",
    "# A label missing from the dictionary must not be dropped silently: that would shift every later\n",
    "# row of Y / Y_hat out of alignment. (Assumes the predictions are hashable label strings -- confirm.)\n",
    "unknown_predictions = {label for label in predictions_bert_text if label not in label_to_code}\n",
    "unknown_truths = {label for label in Test_Y if label not in label_to_code}\n",
    "if unknown_predictions or unknown_truths:\n",
    "    raise ValueError(f'labels missing from cod_embeddings: predictions={sorted(unknown_predictions, key=str)}, '\n",
    "                     f'truth={sorted(unknown_truths, key=str)}')\n",
    "\n",
    "# Convert the label strings to integer codes, exactly one code per input row\n",
    "predictions_bert_embedding = pd.Series([label_to_code[label] for label in predictions_bert_text])\n",
    "Test_Y_embedding = pd.Series([label_to_code[label] for label in Test_Y])\n",
    "assert len(predictions_bert_embedding) == len(Test_Y_embedding), 'expected one prediction per true label'\n",
    "\n",
    "# combine true labels, predictions, covariates into one df and save results\n",
    "results_df = pd.DataFrame({'Y': list(Test_Y_embedding),\n",
    "                           'Y_hat': list(predictions_bert_embedding),\n",
    "                           'X': list(Test_X_covariates)})\n",
    "\n",
    "# write out to csv\n",
    "results_df.to_csv(f'../../data/results/{site_excluded}_bert.csv', index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e157d7f",
   "metadata": {},
   "source": [
    "## Compute and Compare Accuracy and F1 Score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16583e5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): no cell in this notebook writes 'predictions_bert.csv' (the conversion cell above saves\n",
    "# ../../data/results/{site_excluded}_bert.csv with Y / Y_hat / X columns) -- confirm which file is intended.\n",
    "# squeeze('columns') turns the expected single-column frame into a Series so it can be assigned as one column.\n",
    "predictions_bert = pd.read_csv('predictions_bert.csv').squeeze('columns')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b53dd9fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline predictions written by the classic NLP notebook\n",
    "baseline_predictions_path = '../classic_nlp/baseline_predictions.csv'\n",
    "predictions_classic = pd.read_csv(baseline_predictions_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68e8f0dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the BERT predictions as a new column. assign() builds a new frame instead of mutating in place.\n",
    "# Rows are aligned on the index, so a length mismatch would silently pad with NaN -- check first.\n",
    "assert len(predictions_bert) == len(predictions_classic), 'BERT and baseline predictions cover different rows'\n",
    "predictions_classic = predictions_classic.assign(predictions_bert=predictions_bert)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02ceba50",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate accuracy and weighted F1 score of each prediction column against the true labels\n",
    "columns_to_evaluate = ['predictions_NB', 'predictions_SVM', 'predictions_KNN', 'predictions_bert']\n",
    "y_true = predictions_classic['Test_Y']\n",
    "accuracy_scores = []\n",
    "f1_scores = []\n",
    "\n",
    "for column in columns_to_evaluate:\n",
    "    y_pred = predictions_classic[column]\n",
    "    # sklearn metrics take (y_true, y_pred). Accuracy is symmetric, but weighted F1 weights each class\n",
    "    # by its support in y_true, so swapping the arguments gives a different (wrong) score.\n",
    "    accuracy_scores.append(accuracy_score(y_true, y_pred))\n",
    "    f1_scores.append(f1_score(y_true, y_pred, average='weighted'))\n",
    "\n",
    "# Plot accuracy and F1 side by side for each model\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "bar_width = 0.35\n",
    "index = np.arange(len(columns_to_evaluate))\n",
    "\n",
    "bar1 = ax.bar(index, accuracy_scores, bar_width, label='Accuracy')\n",
    "bar2 = ax.bar(index + bar_width, f1_scores, bar_width, label='F1 Score')\n",
    "\n",
    "ax.set_xlabel('Predictions')\n",
    "ax.set_ylabel('Scores')\n",
    "ax.set_title('Accuracy and F1 Score for Each Prediction')\n",
    "# Center each tick between its pair of bars\n",
    "ax.set_xticks(index + bar_width / 2)\n",
    "ax.set_xticklabels(columns_to_evaluate)\n",
    "ax.legend()\n",
    "# Set y-axis range to 0-1\n",
    "ax.set_ylim(0, 1)\n",
    "\n",
    "# Add scores on top of each bar (bars are centered on their x positions)\n",
    "for i, (acc, f1) in enumerate(zip(accuracy_scores, f1_scores)):\n",
    "    ax.text(i, acc + 0.01, f'{acc:.2f}', ha='center', va='bottom', color='black', fontweight='bold')\n",
    "    ax.text(i + bar_width, f1 + 0.01, f'{f1:.2f}', ha='center', va='bottom', color='black', fontweight='bold')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05ca73e7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1740e202",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch: load gpt35_fewshot_text.csv for a quick look\n",
    "test = pd.read_csv(filepath_or_buffer='../../src/gpt_nlp/gpt35_fewshot_text.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d152e8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Frequency of each distinct value in column '0'\n",
    "test.loc[:, '0'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a9303b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pd.Series(test) raises ValueError for a DataFrame input; select the '0' column, which is already a Series.\n",
    "# NOTE(review): confirm this was the intended view.\n",
    "test['0']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d219c3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show a preview rather than dumping the whole frame\n",
    "test.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "becd4541",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50484596",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
