{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "from glob import glob\n",
    "# Show full cell contents (e.g. long checkpoint paths) instead of truncating them.\n",
    "pd.options.display.max_colwidth = None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set the path for the finetuned models\n",
    "\n",
    "Place the finetuned models for each configuration in a separate folder, then set `FINETUNED_MODEL_LOCATION` in the cell below to a glob pattern that matches the `trainer_state.json` of every checkpoint to evaluate.\n",
    "\n",
    "The notebook automatically evaluates every checkpoint that matches this pattern. To test only the best saved checkpoint of each model, first delete each model's last checkpoint from this location so that only its best checkpoint remains.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Configuration: edit these before running the notebook ---\n",
    "# If True, compute_metrics_all replaces the word embeddings with the output of embedding_gen().\n",
    "USE_VSA = True\n",
    "TOKENIZER_LOCATION = './tokenizer-mimic-iv-icd-final'\n",
    "TESTSET_LOCATION = './data/Finetuning_Testset'\n",
    "# Glob pattern matching the trainer_state.json of every checkpoint to evaluate.\n",
    "# '[...]' is a glob character class, so the placeholder must be replaced with a real path.\n",
    "FINETUNED_MODEL_LOCATION = '[SET PATH HERE]/*/*/checkpoint-*/trainer_state.json'\n",
    "# Output CSVs for the validation summary and the test-set metrics.\n",
    "FINETUNED_MODEL_VAL_SAVE = './val_results.csv'\n",
    "FINETUNED_MODEL_TEST_SAVE = './test_results.csv'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "file_name = FINETUNED_MODEL_LOCATION\n",
    "\n",
    "# Sort so the row order of the result tables is reproducible across runs.\n",
    "filepaths = sorted(glob(file_name))\n",
    "if not filepaths:\n",
    "    raise FileNotFoundError(f'No trainer_state.json file matches {file_name!r}; set FINETUNED_MODEL_LOCATION first.')\n",
    "\n",
    "# Component 5 of the configured path is stored as 'vsa_type' for every checkpoint\n",
    "# (presumably the directory naming the model/VSA variant -- TODO confirm). Patterns with\n",
    "# fewer components (such as the template) fall back to None instead of raising IndexError.\n",
    "path_parts = file_name.split('/')\n",
    "file_model_type = path_parts[5] if len(path_parts) > 5 else None\n",
    "filepaths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _load_json(path):\n",
    "    \"\"\"Parse and return the JSON document stored at `path`.\"\"\"\n",
    "    with open(path) as handle:\n",
    "        return json.load(handle)\n",
    "\n",
    "\n",
    "# Map each trainer_state.json path to its parsed contents plus the shared model-type label.\n",
    "trainer_state_dict = {\n",
    "    path: {'json': _load_json(path), 'vsa_type': file_model_type}\n",
    "    for path in filepaths\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameter names of interest (not read by the loop below).\n",
    "params = [\"attention_probs_dropout_prob\", \"hidden_dropout_prob\", \"learning_rate\", \"num_train_epochs\", \"per_device_train_batch_size\", \"weight_decay\"]\n",
    "metrics = [\"eval_accuracy\", \"eval_loss\"]\n",
    "\n",
    "\n",
    "def _latest_log_value(log_history, key):\n",
    "    \"\"\"Return `key` from the most recent log entry where it is set, or None if no entry has it.\"\"\"\n",
    "    for log in reversed(log_history):\n",
    "        if log.get(key) is not None:\n",
    "            return log[key]\n",
    "    return None\n",
    "\n",
    "\n",
    "# Create every column up front so all lists stay the same length (and the\n",
    "# columns exist even when no checkpoint was loaded).\n",
    "results = {column: [] for column in ['file', *metrics, 'train_loss', 'current_epoch']}\n",
    "\n",
    "for file in trainer_state_dict:\n",
    "    state = trainer_state_dict[file]['json']\n",
    "    log_history = state['log_history']\n",
    "    results['file'].append(file)\n",
    "    # The final log entry may be a training entry (it need not contain the eval metrics),\n",
    "    # so take the latest entry that actually contains each metric.\n",
    "    for metric in metrics:\n",
    "        results[metric].append(_latest_log_value(log_history, metric))\n",
    "    # Last logged training loss; None keeps the columns aligned if it was never logged.\n",
    "    results['train_loss'].append(_latest_log_value(log_history, 'loss'))\n",
    "    results['current_epoch'].append(state['epoch'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "column_order = ['file', 'eval_accuracy', 'train_loss', 'current_epoch', 'eval_loss']\n",
    "# Build the validation summary table with its columns in display order.\n",
    "results_df = pd.DataFrame(results).loc[:, column_order]\n",
    "results_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Omit the RangeIndex so reading the CSV back does not add an 'Unnamed: 0' column.\n",
    "results_df.to_csv(FINETUNED_MODEL_VAL_SAVE, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Expose only GPU 0 to this process; this is set before `import torch` below.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "import sys\n",
    "import argparse\n",
    "import time\n",
    "\n",
    "from datasets import Dataset\n",
    "from models.medair_models.bertha.modelling_bertha import BerthaForSequenceClassification\n",
    "from models.medair_models.bertha.configuration_bertha import BerthaConfig\n",
    "from torch.optim import Adam\n",
    "from transformers import get_constant_schedule, Trainer, TrainingArguments, AutoTokenizer\n",
    "from evaluate import load\n",
    "from collators import DataCollatorForMultiLabelsDiseasePrediction\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import random\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle\n",
    "import random\n",
    "# Seed Python's `random` module for reproducibility.\n",
    "random.seed(a=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hugging Face `evaluate` metrics used by compute_metrics_all to score each model.\n",
    "accuracy_metric, precision_metric, recall_metric, f1_metric = (\n",
    "    load(metric_name) for metric_name in ('accuracy', 'precision', 'recall', 'f1')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_LOCATION)\n",
    "dataset_test = Dataset.load_from_disk(TESTSET_LOCATION)\n",
    "test_collator = DataCollatorForMultiLabelsDiseasePrediction(tokenizer)\n",
    "# One example per batch (compute_metrics_all averages the loss over batches).\n",
    "test_dataloader = torch.utils.data.DataLoader(\n",
    "    dataset_test,\n",
    "    batch_size=1,\n",
    "    collate_fn=test_collator,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop the 'Unnamed: 0' column that appears when the CSV was written with its index\n",
    "# (errors='ignore' keeps this working for CSVs written without one).\n",
    "disease_finetuning_models_list = (\n",
    "    pd.read_csv(FINETUNED_MODEL_VAL_SAVE)\n",
    "    .drop(columns=['Unnamed: 0'], errors='ignore')\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "pd.set_option('display.max_colwidth', None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "def compute_metrics_all(i, num_labels=22, device='cuda'):\n",
    "    \"\"\"Evaluate the i-th model of `disease_finetuning_models_list` on the test set.\n",
    "\n",
    "    Args:\n",
    "        i: Row index into `disease_finetuning_models_list`; its `file` column holds the\n",
    "            path to a checkpoint's trainer_state.json.\n",
    "        num_labels: Number of labels of the multi-label classification head.\n",
    "        device: Torch device the model and batches are moved to.\n",
    "\n",
    "    Returns:\n",
    "        [accuracy, precision, recall, f1, mean_loss]. The first four are computed over\n",
    "        all flattened (example, label) predictions, precision/recall/f1 with\n",
    "        average='macro'; mean_loss is the mean of the per-batch losses.\n",
    "    \"\"\"\n",
    "    # Seed every RNG in use so that repeated evaluations are reproducible.\n",
    "    torch.manual_seed(0)\n",
    "    torch.cuda.manual_seed_all(0)\n",
    "    random.seed(0)\n",
    "    np.random.seed(0)\n",
    "\n",
    "    # The checkpoint directory is the folder containing trainer_state.json.\n",
    "    checkpoint_dir = os.path.dirname(disease_finetuning_models_list['file'][i])\n",
    "    model = BerthaForSequenceClassification.from_pretrained(\n",
    "        checkpoint_dir, problem_type='multi_label_classification', num_labels=num_labels\n",
    "    ).to(device)\n",
    "    if USE_VSA:\n",
    "        # Replace the word embeddings with the fixed matrix produced by embedding_gen() and\n",
    "        # switch VSA mode off (presumably so inference uses a plain lookup table -- confirm\n",
    "        # against modelling_bertha).\n",
    "        word_embeddings = model.bertha.embeddings.word_embeddings\n",
    "        word_embeddings.weight = nn.Parameter(word_embeddings.embedding_gen(), requires_grad=True)\n",
    "        word_embeddings.use_vsa = False\n",
    "        word_embeddings.freeze_embeddings = False\n",
    "\n",
    "    all_predictions = []\n",
    "    all_labels = []\n",
    "    all_loss = []\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(test_dataloader):\n",
    "            inputs = {k: v.to(device) for k, v in batch.items() if k != 'days_until_death'}\n",
    "            # Collator outputs that are not passed to the model.\n",
    "            inputs.pop('special_tokens_mask')\n",
    "            inputs.pop('subject_id')\n",
    "            outputs = model(**inputs)\n",
    "\n",
    "            # A positive logit means sigmoid probability > 0.5, i.e. the label is predicted.\n",
    "            preds = (outputs['logits'] > 0).int()\n",
    "            # Collect plain Python ints on the CPU rather than one 0-d device tensor per label.\n",
    "            all_predictions.extend(preds.flatten().cpu().tolist())\n",
    "            # Assumes 0/1 label values; cast so references match the int predictions.\n",
    "            all_labels.extend(inputs['labels'].flatten().int().cpu().tolist())\n",
    "            all_loss.append(outputs['loss'].item())\n",
    "\n",
    "    accuracy = accuracy_metric.compute(predictions=all_predictions, references=all_labels)\n",
    "    precision = precision_metric.compute(predictions=all_predictions, references=all_labels, average='macro')\n",
    "    recall = recall_metric.compute(predictions=all_predictions, references=all_labels, average='macro')\n",
    "    f1 = f1_metric.compute(predictions=all_predictions, references=all_labels, average='macro')\n",
    "    mean_loss = sum(all_loss) / len(all_loss)\n",
    "\n",
    "    return [accuracy['accuracy'], precision['precision'], recall['recall'], f1['f1'], mean_loss]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One list per reported metric, filled by the evaluation loop (auroc_list is never filled).\n",
    "accuracy_list, precision_list, recall_list, f1_list, auroc_list, loss_list = [[] for _ in range(6)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from empty lists so that re-running this cell does not duplicate rows.\n",
    "accuracy_list, precision_list, recall_list, f1_list, loss_list = [], [], [], [], []\n",
    "\n",
    "for i in range(len(disease_finetuning_models_list)):\n",
    "    # Named `model_scores` so the `metrics` list of metric names defined earlier is not overwritten.\n",
    "    model_scores = compute_metrics_all(i)\n",
    "    accuracy_list.append(model_scores[0])\n",
    "    precision_list.append(model_scores[1])\n",
    "    recall_list.append(model_scores[2])\n",
    "    f1_list.append(model_scores[3])\n",
    "    loss_list.append(model_scores[4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-model test-set metrics, one row per evaluated checkpoint.\n",
    "df = pd.DataFrame(dict(\n",
    "    accuracy=accuracy_list,\n",
    "    precision=precision_list,\n",
    "    recall=recall_list,\n",
    "    f1=f1_list,\n",
    "    model=disease_finetuning_models_list['file'],\n",
    "    loss=loss_list,\n",
    "))\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the per-model test metrics.\n",
    "df.to_csv(path_or_buf=FINETUNED_MODEL_TEST_SAVE)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
