{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ea081933-7e2e-40cc-b83f-b22dcbde4966",
   "metadata": {},
   "source": [
    "# Electrocardiogram Analysis using ECG-FM\n",
    "\n",
    "The electrocardiogram (ECG) is a low-cost, non-invasive diagnostic test that has been ubiquitous in the assessment and management of cardiovascular disease for decades. ECG-FM is a pretrained, open foundation model for ECG analysis.\n",
    "\n",
    "In this tutorial, we will introduce how to perform inference for multi-label classification using a finetuned ECG-FM model. Specifically, we will take a model finetuned on the [PhysioNet 2021 v1.0.3 dataset](https://physionet.org/content/challenge-2021/1.0.3/) and perform inference on a sample of the [CODE-15% v1.0.0 dataset](https://zenodo.org/records/4916206/) to show how to adapt the predictions to a new set of labels.\n",
    "\n",
    "## Overview\n",
    "0. Installation\n",
    "1. Prepare checkpoints\n",
    "2. Prepare data\n",
    "3. Run inference\n",
    "4. Interpret results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab1e91ad-e3de-45ac-be76-50551cced4e1",
   "metadata": {},
   "source": [
    "## 0. Installation\n",
    "\n",
    "ECG-FM was developed in collaboration with the [fairseq_signals](https://github.com/Jwoo5/fairseq-signals) framework, which implements a collection of deep learning methods for ECG analysis.\n",
    "\n",
    "Clone [fairseq_signals](https://github.com/Jwoo5/fairseq-signals) and refer to the requirements and installation section in the top-level README. After following those steps, install `pandas` and make the environment accessible within this notebook by running:\n",
    "```\n",
    "python3 -m pip install --user pandas\n",
    "python3 -m pip install --user --upgrade jupyterlab ipywidgets ipykernel\n",
    "python3 -m ipykernel install --user --name ecg_fm\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43635fe2-aa31-4aed-8f19-da356d3c0177",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "from fairseq_signals.utils.store import MemmapReader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ca90c90-9bf6-40fb-b2e6-0bb4821c5e27",
   "metadata": {},
   "outputs": [],
   "source": [
    "root = os.getcwd()\n",
    "root"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86de85f5-8266-4e41-85d3-d2dace48e35b",
   "metadata": {},
   "outputs": [],
   "source": [
    "fairseq_signals_root = # TODO\n",
    "fairseq_signals_root = fairseq_signals_root.rstrip('/')\n",
    "fairseq_signals_root"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "79e95866-bf84-4ebf-9197-ab11e013d87a",
   "metadata": {},
   "source": [
    "## 1. Prepare checkpoints"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e8ac5b8-0a81-4b84-b589-7d756b21c780",
   "metadata": {},
   "source": [
    "### Download checkpoints\n",
    "\n",
    "The checkpoints are available on [HuggingFace](https://huggingface.co/wanglab/ecg-fm-preprint). Alternatively, they can be downloaded using the below commands.\n",
    "\n",
    "**Disclaimer: These models are different from those reported in our arXiv paper.** These BERT-Base sized models were trained purely on public data sources due to privacy concerns surrounding UHN-ECG data and patient identification. Validation for the final models will be available upon full publication."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "898c9c4e-e908-4637-8b03-9103417d5201",
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import hf_hub_download\n",
    "\n",
    "_ = hf_hub_download(\n",
    "    repo_id='wanglab/ecg-fm-preprint',\n",
    "    filename='physionet_finetuned.pt',\n",
    "    local_dir=os.path.join(root, 'ckpts'),\n",
    ")\n",
    "_ = hf_hub_download(\n",
    "    repo_id='wanglab/ecg-fm-preprint',\n",
    "    filename='physionet_finetuned.yaml',\n",
    "    local_dir=os.path.join(root, 'ckpts'),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29bae135-55dd-4775-b24e-68aa9ee7fd73",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert os.path.isfile(os.path.join(root, 'ckpts/physionet_finetuned.pt'))\n",
    "assert os.path.isfile(os.path.join(root, 'ckpts/physionet_finetuned.yaml'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad3c12fa-6197-41bd-9d95-a6a02893fc8b",
   "metadata": {},
   "source": [
    "## 2. Prepare data\n",
    "\n",
    "The model being used was finetuned on the [PhysioNet 2021 v1.0.3 dataset](https://physionet.org/content/challenge-2021/1.0.3/). To simplify this tutorial, we have processed a sample of 10 ECGs (14 5s segments) from the [CODE-15% v1.0.0 dataset](https://zenodo.org/records/4916206/) so that we may demonstrate how to adapt the predictions to a new set of labels.\n",
    "\n",
    "If looking to perform inference on a full dataset (or using your own dataset), refer to the flexible, end-to-end, multi-source data preprocessing pipeline described [here](https://github.com/Jwoo5/fairseq-signals/tree/master/scripts/preprocess/ecg). Its README is useful for understanding how the data is organized. There are preprocessing scripts implemented for several datasets."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a6e9a4e-c23b-4f9b-aa82-0a7831042dc3",
   "metadata": {},
   "source": [
    "### Update manifest"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75624573-d1c5-4ab6-8dab-f682bb1c349f",
   "metadata": {},
   "source": [
    "The segmented split must be saved with absolute file paths, so we will update the current relative file paths accordingly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5b99a8f-6eea-4c60-9c27-c2ef388c4a95",
   "metadata": {},
   "outputs": [],
   "source": [
    "segmented_split = pd.read_csv(\n",
    "    os.path.join(root, 'data/code_15/segmented_split_incomplete.csv'),\n",
    "    index_col='idx',\n",
    ")\n",
    "segmented_split['path'] = (root + '/data/code_15/segmented/') + segmented_split['path']\n",
    "segmented_split.to_csv(os.path.join(root, 'data/code_15/segmented_split.csv'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61710580-87fc-43fe-9d7a-0465ebed46ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert os.path.isfile(os.path.join(root, 'data/code_15/segmented_split.csv'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b71fabac-3150-433d-9fe8-c559ec71022a",
   "metadata": {},
   "source": [
    "Run the follow commands togenerate the `test.tsv` file used for inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30575700-2503-4052-a7a9-58adc26e634c",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"\"\"cd {fairseq_signals_root}/scripts/preprocess\n",
    "python manifests.py \\\\\n",
    "    --split_file_paths \"{root}/data/code_15/segmented_split.csv\" \\\\\n",
    "    --save_dir \"{root}/data/manifests/code_15_subset10/\"\n",
    "\"\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1f13589-c5ff-4a04-bd42-73bdfc807f72",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert os.path.isfile(os.path.join(root, 'data/manifests/code_15_subset10/test.tsv'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dfb153be-9d87-47f3-8855-3f5b947e53a3",
   "metadata": {},
   "source": [
    "## 3. Run inference\n",
    "\n",
    "Inside our environment, we can run the following command using hydra's command line interface to extract the logits for each segment. There must be an available GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fe363b7-8d1b-4831-91eb-9b326fed1593",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"\"\"fairseq-hydra-inference \\\\\n",
    "    task.data=\"{root}/data/manifests/code_15_subset10/\" \\\\\n",
    "    common_eval.path=\"{root}/ckpts/physionet_finetuned.pt\" \\\\\n",
    "    common_eval.results_path=\"{root}/outputs\" \\\\\n",
    "    model.num_labels=26 \\\\\n",
    "    dataset.valid_subset=\"test\" \\\\\n",
    "    dataset.batch_size=10 \\\\\n",
    "    dataset.num_workers=3 \\\\\n",
    "    dataset.disable_validation=false \\\\\n",
    "    distributed_training.distributed_world_size=1 \\\\\n",
    "    distributed_training.find_unused_parameters=True \\\\\n",
    "    --config-dir \"{root}/ckpts\" \\\\\n",
    "    --config-name physionet_finetuned\n",
    "\"\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "033dfa89-d249-410b-b5ab-39737686b4b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert os.path.isfile(os.path.join(root, 'outputs/outputs_test.npy'))\n",
    "assert os.path.isfile(os.path.join(root, 'outputs/outputs_test_header.pkl'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3bb473cb-87e9-4a30-8a49-5eda954d9519",
   "metadata": {},
   "source": [
    "## 4. Interpret results\n",
    "\n",
    "The logits are ordered same as the samples in the manifest and labels in the label definition."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "635f2f4b-fc76-4b09-87f4-82265e13b6a1",
   "metadata": {},
   "source": [
    "### Get predictions on PhysioNet 2021 labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9b2b3fb-9563-425e-9b7e-abf374ef0483",
   "metadata": {},
   "outputs": [],
   "source": [
    "physionet2021_label_def = pd.read_csv(\n",
    "    os.path.join(root, 'data/physionet2021/labels/label_def.csv'),\n",
    "     index_col='name',\n",
    ")\n",
    "physionet2021_label_names = physionet2021_label_def.index\n",
    "physionet2021_label_def"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83e228ae-a72b-492d-ba0c-be2a667c6a7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the array of computed logits\n",
    "logits = MemmapReader.from_header(\n",
    "    os.path.join(root, 'outputs/outputs_test.npy')\n",
    ")[:]\n",
    "logits.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2630a917-ec9c-41f3-8587-9d94625c4d7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Construct predictions from logits\n",
    "pred = pd.DataFrame(\n",
    "    torch.sigmoid(torch.tensor(logits)).numpy(),\n",
    "    columns=physionet2021_label_names,\n",
    ")\n",
    "\n",
    "# Join in sample information\n",
    "pred = segmented_split.reset_index().join(pred, how='left').set_index('idx')\n",
    "pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3079ec7e-a1a1-409e-87e6-8a36a6e63dc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Perform a (crude) thresholding of 0.5 for all labels\n",
    "pred_thresh = pred.copy()\n",
    "pred_thresh[physionet2021_label_names] = pred_thresh[physionet2021_label_names] > 0.5\n",
    "\n",
    "# Construct a readable column of predicted labels for each sample\n",
    "pred_thresh['labels'] = pred_thresh[physionet2021_label_names].apply(\n",
    "    lambda row: ', '.join(row.index[row]),\n",
    "    axis=1,\n",
    ")\n",
    "pred_thresh['labels']"
   ]
  },
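  {
   "cell_type": "markdown",
   "id": "5d2f8a1c-3b7e-4c90-a6f4-1e9b0c7d2a35",
   "metadata": {},
   "source": [
    "A single 0.5 cutoff treats every label alike, although labels may differ in prevalence and score calibration. As a minimal sketch (not part of the ECG-FM or fairseq_signals code), the snippet below applies a separate cutoff per label to a small synthetic probability table. The helper `apply_thresholds`, the example probabilities, and the threshold values are all assumptions made for illustration.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "def apply_thresholds(probs, thresholds, default=0.5):\n",
    "    # Hypothetical helper: binarize each probability column with its own cutoff.\n",
    "    # Labels missing from `thresholds` fall back to `default`.\n",
    "    cutoffs = pd.Series({name: thresholds.get(name, default) for name in probs.columns})\n",
    "    # axis=1 aligns the cutoffs with the columns of `probs`\n",
    "    return probs.gt(cutoffs, axis=1)\n",
    "\n",
    "# Synthetic probabilities for two segments and three labels (illustrative values only)\n",
    "probs_demo = pd.DataFrame(\n",
    "    {'AF': [0.62, 0.10], 'SB': [0.35, 0.80], 'STach': [0.45, 0.05]},\n",
    "    index=pd.Index([0, 1], name='idx'),\n",
    ")\n",
    "\n",
    "# Assumed cutoffs; 'STach' is omitted and therefore uses the 0.5 default\n",
    "binary_demo = apply_thresholds(probs_demo, {'AF': 0.7, 'SB': 0.3})\n",
    "print(binary_demo)\n",
    "```\n",
    "\n",
    "In practice, per-label cutoffs would be chosen on held-out labeled data rather than set by hand."
   ]
  },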
  {
   "cell_type": "markdown",
   "id": "7cfbe91f-6b85-47dd-9cd5-65a484e75074",
   "metadata": {},
   "source": [
    "### Map predictions to CODE-15 labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f64888e-01a9-4995-b034-b1923715e1f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "code_15_label_def = pd.read_csv(\n",
    "    os.path.join(root, 'data/code_15/labels/label_def.csv'),\n",
    "     index_col='name',\n",
    ")\n",
    "code_15_label_names = code_15_label_def.index\n",
    "code_15_label_def"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58b7e20f-e2bb-4b00-89c4-ec01cc542464",
   "metadata": {},
   "outputs": [],
   "source": [
    "label_mapping = {\n",
    "    'CRBBB|RBBB': 'RBBB',\n",
    "    'CLBBB|LBBB': 'LBBB',\n",
    "    'SB': 'SB',\n",
    "    'STach': 'ST',\n",
    "    'AF': 'AF',\n",
    "}\n",
    "\n",
    "physionet2021_label_def['name_mapped'] = physionet2021_label_def.index.map(label_mapping)\n",
    "physionet2021_label_def"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc33563b-48e7-46c9-a5e0-0007820381e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_mapped = pred.copy()\n",
    "pred_mapped.drop(set(physionet2021_label_names) - set(label_mapping.keys()), axis=1, inplace=True)\n",
    "pred_mapped.rename(label_mapping, axis=1, inplace=True)\n",
    "pred_mapped"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8b6d8b3-5ccd-4564-aab9-31e0617c021e",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_thresh_mapped = pred_thresh.copy()\n",
    "pred_thresh_mapped.drop(set(physionet2021_label_names) - set(label_mapping.keys()), axis=1, inplace=True)\n",
    "pred_thresh_mapped.rename(label_mapping, axis=1, inplace=True)\n",
    "pred_thresh_mapped['predicted'] = pred_thresh_mapped[label_mapping.values()].apply(\n",
    "    lambda row: ', '.join(row.index[row]),\n",
    "    axis=1,\n",
    ")\n",
    "pred_thresh_mapped"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6031ae3-7ea0-4d9a-8859-8ca9a52020ea",
   "metadata": {},
   "source": [
    "### Compare predicted CODE-15 to actual"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b917e367-f36b-4cde-889d-8f9baff8436f",
   "metadata": {},
   "outputs": [],
   "source": [
    "code_15_labels = pd.read_csv(os.path.join(root, 'data/code_15/labels/labels.csv'), index_col='idx')\n",
    "code_15_labels['actual'] = code_15_labels[label_mapping.values()].apply(\n",
    "    lambda row: ', '.join(row.index[row]),\n",
    "    axis=1,\n",
    ")\n",
    "code_15_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "45505b74-0fd0-4898-9dca-61ce256277ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize predicted and actual labels side-by-side\n",
    "pred_thresh_mapped[['predicted']].join(code_15_labels[['actual']], how='left')"
   ]
  },
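  {
   "cell_type": "markdown",
   "id": "9a4c1e7b-6d20-4f58-b3c1-7e5a2d8f0b46",
   "metadata": {},
   "source": [
    "The side-by-side table can also be summarized numerically. The following is a rough sketch, using made-up boolean tables rather than the tutorial outputs, of counting true positives, false positives, and false negatives per label. The helper `per_label_counts` and the example values are hypothetical.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "def per_label_counts(predicted, actual):\n",
    "    # Hypothetical helper: count TP/FP/FN per label from two boolean DataFrames.\n",
    "    # Keep only the rows and columns present in both tables.\n",
    "    predicted, actual = predicted.align(actual, join='inner')\n",
    "    return pd.DataFrame({\n",
    "        'tp': (predicted & actual).sum(),\n",
    "        'fp': (predicted & ~actual).sum(),\n",
    "        'fn': (~predicted & actual).sum(),\n",
    "    })\n",
    "\n",
    "# Made-up predictions and ground truth for three samples and two labels\n",
    "predicted_demo = pd.DataFrame({'AF': [True, False, True], 'SB': [False, False, True]})\n",
    "actual_demo = pd.DataFrame({'AF': [True, False, False], 'SB': [False, True, True]})\n",
    "\n",
    "counts_demo = per_label_counts(predicted_demo, actual_demo)\n",
    "print(counts_demo)\n",
    "```\n",
    "\n",
    "With a sample of only 10 ECGs, such counts are illustrative rather than a meaningful evaluation."
   ]
  },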
  {
   "cell_type": "markdown",
   "id": "d7c6b21a-2d1e-453e-ab8f-dc2015dd80a0",
   "metadata": {},
   "source": [
    "# 5. Extra - Load models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed6cece6-a6e0-46c5-b022-7bdeb9ce68ec",
   "metadata": {},
   "source": [
    "Outside of the scripts/hydra client, models can be easily loaded as shown below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8ce3df7-d9cd-4211-b85a-71436b1987af",
   "metadata": {},
   "outputs": [],
   "source": [
    "from fairseq_signals.models import build_model_from_checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38b6ee04-cceb-427b-9aee-53cb622feb7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_finetuned = build_model_from_checkpoint(\n",
    "    checkpoint_path=os.path.join(root, 'ckpts/physionet_finetuned.pt')\n",
    ")\n",
    "model_finetuned"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2a04486-b151-4b06-9b71-9fa9dd3d0c70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run if the pretrained model hasn't already been downloaded\n",
    "from huggingface_hub import hf_hub_download\n",
    "\n",
    "_ = hf_hub_download(\n",
    "    repo_id='wanglab/ecg-fm-preprint',\n",
    "    filename='mimic_iv_ecg_physionet_pretrained.pt',\n",
    "    local_dir=os.path.join(root, 'ckpts'),\n",
    ")\n",
    "_ = hf_hub_download(\n",
    "    repo_id='wanglab/ecg-fm-preprint',\n",
    "    filename='mimic_iv_ecg_physionet_pretrained.yaml',\n",
    "    local_dir=os.path.join(root, 'ckpts'),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8afcc57b-cf99-4538-bac9-82f3381bcef4",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_pretrained = build_model_from_checkpoint(\n",
    "    checkpoint_path=os.path.join(root, 'ckpts/mimic_iv_ecg_physionet_pretrained.pt')\n",
    ")\n",
    "model_pretrained"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84f6b766-ed35-4bc2-9f7a-d6c98d6bc6aa",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ecg_fm",
   "language": "python",
   "name": "ecg_fm"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
