{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "from glob import glob\n",
    "pd.set_option('display.max_colwidth', None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set the path for the finetuned models\n",
    "\n",
    "Please place all of the finetuned models for each configuration into a separate folder. Set the finetuned model folder location in the cell below.\n",
    "\n",
    "This notebook will automatically pick up every model checkpoint saved in this location. To test only the best saved checkpoint of each model, delete the saved model of the last checkpoint for every model in this location before running the notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# When True, the word embeddings of each loaded model are rebuilt from embedding_gen() before testing.\n",
    "USE_VSA = True\n",
    "\n",
    "# Inputs: tokenizer directory and the on-disk test dataset.\n",
    "TOKENIZER_LOCATION = \"./tokenizer-mimic-iv-icd-final\"\n",
    "TESTSET_LOCATION = \"./data/Finetuning_Testset\"\n",
    "\n",
    "# Glob pattern that must match the trainer_state.json of every checkpoint to evaluate.\n",
    "FINETUNED_MODEL_LOCATION = \"[SET PATH HERE]/*/*/checkpoint-*/trainer_state.json\"\n",
    "\n",
    "# Outputs: validation metrics read from the trainer states, and test-set metrics per model.\n",
    "FINETUNED_MODEL_VAL_SAVE = \"./val_results.csv\"\n",
    "FINETUNED_MODEL_TEST_SAVE = \"./test_results.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "file_name = FINETUNED_MODEL_LOCATION\n",
    "\n",
    "# The model type is taken from the sixth '/'-separated component of the glob pattern.\n",
    "# Shorter patterns (including the unedited placeholder) have no such component, so fall\n",
    "# back to None instead of raising IndexError.\n",
    "# NOTE(review): which component names the model type depends on the configured path depth - confirm.\n",
    "pattern_parts = file_name.split('/')\n",
    "file_model_type = pattern_parts[5] if len(pattern_parts) > 5 else None\n",
    "\n",
    "# glob returns matches in arbitrary order; sort so the result tables are reproducible.\n",
    "filepaths = sorted(glob(file_name))\n",
    "file_model_type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse every matched trainer_state.json and key the parsed content by its path.\n",
    "trainer_state_dict = {}\n",
    "for state_path in filepaths:\n",
    "    with open(state_path) as state_file:\n",
    "        trainer_state_dict[state_path] = {\n",
    "            'json': json.load(state_file),\n",
    "            'vsa_type': file_model_type,\n",
    "        }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = {}\n",
    "\n",
    "# Hyperparameter names; not read anywhere else in this notebook.\n",
    "params = [\"attention_probs_dropout_prob\", \"hidden_dropout_prob\", \"learning_rate\", \"num_train_epochs\", \"per_device_train_batch_size\", \"weight_decay\"]\n",
    "metrics = [\"eval_accuracy\", \"eval_loss\"]\n",
    "\n",
    "\n",
    "def _latest_logged(log_history, key):\n",
    "    \"\"\"Return the most recent non-None value logged under `key`, or None if there is none.\"\"\"\n",
    "    for entry in reversed(log_history):\n",
    "        if entry.get(key) is not None:\n",
    "            return entry[key]\n",
    "    return None\n",
    "\n",
    "\n",
    "for file in trainer_state_dict:\n",
    "    state = trainer_state_dict[file]['json']\n",
    "    log_history = state['log_history']\n",
    "\n",
    "    results.setdefault(\"file\", []).append(file)\n",
    "\n",
    "    # Search backwards rather than assuming the last log entry is an evaluation entry.\n",
    "    for metric in metrics:\n",
    "        results.setdefault(metric, []).append(_latest_logged(log_history, metric))\n",
    "\n",
    "    # Always append exactly one value per file (None when missing) so that every\n",
    "    # column keeps the same length and the DataFrame can be built.\n",
    "    results.setdefault('train_loss', []).append(_latest_logged(log_history, 'loss'))\n",
    "\n",
    "    results.setdefault('current_epoch', []).append(state['epoch'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the validation table and keep only the reporting columns, in display order.\n",
    "cols = ['file', 'eval_accuracy', 'train_loss', 'current_epoch', 'eval_loss']\n",
    "results_df = pd.DataFrame(results)[cols]\n",
    "results_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the validation table; it is read back below to list the models to test.\n",
    "results_df.to_csv(path_or_buf=FINETUNED_MODEL_VAL_SAVE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "from datasets import Dataset\n",
    "from models.medair_models.bertha.modelling_bertha import BerthaForSequenceClassification\n",
    "from models.medair_models.bertha.configuration_bertha import BerthaConfig\n",
    "from transformers import AutoTokenizer\n",
    "from evaluate import load\n",
    "from collators import DataCollatorForMortalityPrediction\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from evaluate import load\n",
    "import random\n",
    "\n",
    "pd.set_option('display.max_colwidth', None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the four classification metrics, in the same order as the names on the left.\n",
    "accuracy_metric, precision_metric, recall_metric, f1_metric = (\n",
    "    load(metric_name) for metric_name in ('accuracy', 'precision', 'recall', 'f1')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_LOCATION)\n",
    "\n",
    "# Load the test split and drop the column that is not passed on to the collator.\n",
    "dataset_test = Dataset.load_from_disk(TESTSET_LOCATION).remove_columns('special_tokens_mask')\n",
    "\n",
    "# One sample per batch, collated by the mortality-prediction collator.\n",
    "mortality_collator = DataCollatorForMortalityPrediction(tokenizer, eol_threshold=180)\n",
    "test_dataloader = torch.utils.data.DataLoader(\n",
    "    dataset_test,\n",
    "    batch_size=1,\n",
    "    collate_fn=mortality_collator,\n",
    ")\n",
    "sample_weight_value = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the validation table; its 'file' column lists the checkpoints to test.\n",
    "mortality_finetuning_models_list = pd.read_csv(filepath_or_buffer=FINETUNED_MODEL_VAL_SAVE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from torch import nn\n",
    "\n",
    "def compute_metrics_all(i):\n",
    "    \"\"\"Evaluate one finetuned checkpoint on the test set.\n",
    "\n",
    "    Loads the model stored next to the trainer_state.json listed in row ``i`` of\n",
    "    ``mortality_finetuning_models_list['file']``, runs it over ``test_dataloader``\n",
    "    and scores the rounded sigmoid outputs against the batch labels.\n",
    "\n",
    "    Args:\n",
    "        i: Index label of the row in ``mortality_finetuning_models_list``.\n",
    "\n",
    "    Returns:\n",
    "        ``[accuracy, precision, recall, f1]``, where precision, recall and f1\n",
    "        are macro-averaged.\n",
    "    \"\"\"\n",
    "    # torch.manual_seed also seeds the CPU generator, which the CUDA-only calls do not.\n",
    "    torch.manual_seed(0)\n",
    "    torch.cuda.manual_seed(0)\n",
    "    torch.cuda.manual_seed_all(0)\n",
    "    random.seed(0)\n",
    "    np.random.seed(0)\n",
    "\n",
    "    # The table stores the path of trainer_state.json; the model is in its parent directory.\n",
    "    checkpoint_dir = os.path.dirname(mortality_finetuning_models_list['file'][i])\n",
    "    model = BerthaForSequenceClassification.from_pretrained(checkpoint_dir, num_labels=1).to('cuda')\n",
    "    if USE_VSA:\n",
    "        # Store the generated embeddings as a regular weight and switch the VSA flags off.\n",
    "        word_embeddings = model.bertha.embeddings.word_embeddings\n",
    "        word_embeddings.weight = nn.Parameter(word_embeddings.embedding_gen(), requires_grad=True)\n",
    "        word_embeddings.use_vsa = False\n",
    "        word_embeddings.freeze_embeddings = False\n",
    "    model.eval()\n",
    "\n",
    "    all_predictions = []\n",
    "    all_labels = []\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(test_dataloader):\n",
    "            batch = {k: v.to(\"cuda\") for k, v in batch.items()}\n",
    "            logits = model(**batch).get(\"logits\").flatten()\n",
    "            preds = torch.round(torch.sigmoid(logits))\n",
    "\n",
    "            # Collect one plain Python int per sample for both lists so predictions and\n",
    "            # references stay aligned for any batch size and no CUDA tensors reach the metrics.\n",
    "            # NOTE(review): assumes the labels are 0/1 valued - confirm against the collator.\n",
    "            all_predictions.extend(preds.long().cpu().tolist())\n",
    "            all_labels.extend(batch[\"labels\"].flatten().long().cpu().tolist())\n",
    "\n",
    "    accuracy = accuracy_metric.compute(predictions=all_predictions, references=all_labels)\n",
    "    precision = precision_metric.compute(predictions=all_predictions, references=all_labels, average='macro')\n",
    "    recall = recall_metric.compute(predictions=all_predictions, references=all_labels, average='macro')\n",
    "    f1 = f1_metric.compute(predictions=all_predictions, references=all_labels, average='macro')\n",
    "\n",
    "    return [accuracy['accuracy'], precision['precision'], recall['recall'], f1['f1']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "accuracy_list, precision_list, recall_list, f1_list, auroc_list = [], [], [], [], []\n",
    "\n",
    "# compute_metrics_all returns [accuracy, precision, recall, f1] for one model;\n",
    "# distribute the four scores over the matching lists.\n",
    "for i in range(len(mortality_finetuning_models_list)):\n",
    "    metrics = compute_metrics_all(i)\n",
    "    for score_list, score in zip((accuracy_list, precision_list, recall_list, f1_list), metrics):\n",
    "        score_list.append(score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per tested model: its four test scores plus the trainer_state.json path.\n",
    "test_columns = {\n",
    "    'accuracy': accuracy_list,\n",
    "    'precision': precision_list,\n",
    "    'recall': recall_list,\n",
    "    'f1': f1_list,\n",
    "    \"model\": mortality_finetuning_models_list['file'],\n",
    "}\n",
    "df = pd.DataFrame(test_columns)\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the test-set metrics table.\n",
    "df.to_csv(path_or_buf=FINETUNED_MODEL_TEST_SAVE)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
