{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9d5a74b0-0747-4768-a4b5-b79a15e83e42",
   "metadata": {},
   "source": [
    "This notebook is designed to produce saliency maps for ECG transformer classification models in the [fairseq-signals](https://github.com/Jwoo5/fairseq-signals) repository.\n",
    "\n",
    "Before running this notebook, there are some precursor steps to be taken:\n",
    "1. Compute `saliency_{split}.npy` files using `fairseq-hydra-validate` with the `common_eval.extract=[saliency]` command-line argument\n",
    "2. Run the `saliency.py` script to generate a `attn_max_{split}.npy` file\n",
    "\n",
    "Here is an example command for step 1, assuming [this preprocessing procedure](https://github.com/Jwoo5/fairseq-signals/blob/master/scripts/preprocess/ecg/README.md) was followed:\n",
    "```\n",
    "FAIRSEQ_ROOT=\"TODO\"\n",
    "MANIFEST_DIR=\"TODO\"\n",
    "LABEL_DIR=\"TODO\"\n",
    "OUTPUT_DIR=\"TODO\"\n",
    "CHECKPOINT_NUM=\"TODO\"\n",
    "\n",
    "CHECKPOINT=\"$OUTPUT_DIR/checkpoint$CHECKPOINT_NUM.pt\"\n",
    "NUM_LABELS=$(($(wc -l < \"$LABEL_DIR/label_def.csv\") - 1))\n",
    "\n",
    "fairseq-hydra-validate \\\n",
    "    task.data=$MANIFEST_DIR \\\n",
    "    common_eval.path=$CHECKPOINT \\\n",
    "    common_eval.extract=[saliency] \\\n",
    "    common_eval.results_path=$OUTPUT_DIR \\\n",
    "    model.num_labels=$NUM_LABELS \\\n",
    "    dataset.valid_subset=test \\\n",
    "    dataset.batch_size=256 \\\n",
    "    dataset.num_workers=10 \\\n",
    "    dataset.disable_validation=false \\\n",
    "    distributed_training.distributed_world_size=1 \\\n",
    "    distributed_training.find_unused_parameters=True \\\n",
    "    +task.label_file=$LABEL_DIR/y.npy \\\n",
    "    --config-dir $FAIRSEQ_ROOT/examples/w2v_cmsc/config/finetuning/ecg_transformer \\\n",
    "    --config-name diagnosis\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56c38d98-7334-48f5-8832-063277b266de",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9a5f570-2f9a-4eff-81e2-036db50debf8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Tuple\n",
    "import os\n",
    "import yaml\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from scipy.io import loadmat\n",
    "from scipy.ndimage import map_coordinates\n",
    "\n",
    "from fairseq_signals.utils.file import extract_filename\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "def blend_colors_hex(start_color: str, end_color: str, activations: np.ndarray) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Blends between two colors based on an array of blend factors.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    start_color : str\n",
    "        Hexadecimal color code for the start color.\n",
    "    end_color : str\n",
    "        Hexadecimal color code for the end color.\n",
    "    activations : np.ndarray\n",
    "        An array of blend factors where 0 corresponds to the start color and 1 to the end color.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        An array of hexadecimal color codes resulting from the blends.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If any of the input blend factors are not within the range [0, 1].\n",
    "    \"\"\"\n",
    "    if np.any((activations < 0) | (activations > 1)):\n",
    "        raise ValueError(\"All blend factors must be between 0 and 1.\")\n",
    "\n",
    "    # Convert hexadecimal to RGB\n",
    "    def hex_to_rgb(hex_color: str) -> Tuple[int]:\n",
    "        return tuple(int(hex_color[i: i+2], 16) for i in (1, 3, 5))\n",
    "\n",
    "    # Get RGB tuples\n",
    "    start_rgb = np.array(hex_to_rgb(start_color))\n",
    "    end_rgb = np.array(hex_to_rgb(end_color))\n",
    "\n",
    "    # Blend RGB values\n",
    "    blended_rgb = np.outer(1 - activations, start_rgb) + np.outer(activations, end_rgb)\n",
    "\n",
    "    # Convert blended RGB back to hex codes\n",
    "    return blended_rgb / 255\n",
    "\n",
    "def colored_line_segments(data: np.ndarray, colors: np.ndarray, ax=None, **kwargs):\n",
    "    \"\"\"\n",
    "    Plots line segments based on the provided data points, with each segment\n",
    "    colored according to the corresponding color specification in `colors`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data : np.ndarray\n",
    "        Array of y-values for the line segments.\n",
    "    colors : np.ndarray\n",
    "        Array of colors, each color applied to the corresponding line segment\n",
    "        between points i and i+1.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If the `colors` array does not have exactly one less element than the `data` array,\n",
    "        as each segment needs a unique color.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None\n",
    "    \"\"\"\n",
    "    if len(colors) != len(data) - 1:\n",
    "        raise ValueError(\"Colors array must have one fewer elements than data array.\")\n",
    "\n",
    "    if ax is None:\n",
    "        for i in range(len(data) - 1):\n",
    "            plt.plot([i, i + 1], [data[i], data[i + 1]], color=colors[i], **kwargs)\n",
    "    else:\n",
    "        for i in range(len(data) - 1):\n",
    "            ax.plot([i, i + 1], [data[i], data[i + 1]], color=colors[i], **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06de5192-a2bb-47ca-bcd9-15a0f0ceed2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "manifest_path = '...' # Multi-source 'manifest.csv' filepath\n",
    "run_directory = '...' # Directory with 'config.yaml' from training, as well as 'attn_max_{split}.npy' file\n",
    "segmented_dir = '...' # Directory of segmented files (raw signal values over which attention coloring is laid)\n",
    "\n",
    "split = 'test'\n",
    "sample_size = 2500\n",
    "sample_rate = 500\n",
    "lead = 'II'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4758c399-6cb1-4a87-8d86-d2b049a77b69",
   "metadata": {},
   "source": [
    "# Load"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebf26da3-6d03-4ab8-94ec-831b8d0c1603",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load attention output weight max values\n",
    "attn_max = np.load(os.path.join(run_directory, f'attn_max_{split}.npy'))\n",
    "attn_max.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4b810d4-168c-48d4-905f-88c3f8647714",
   "metadata": {},
   "outputs": [],
   "source": [
    "manifest = pd.read_csv(manifest_path, low_memory=False)\n",
    "manifest.rename(columns={\n",
    "    'sample_rate': 'sample_rate_org',\n",
    "    'sample_size': 'sample_size_org',\n",
    "}, inplace=True)\n",
    "manifest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a7cca05-5347-46bb-9442-0c2e68cc1bbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(os.path.join(run_directory, 'config.yaml'), \"r\") as f:\n",
    "    config = yaml.safe_load(f)\n",
    "\n",
    "manifest_dir = config['task']['data']\n",
    "label_file = config['task']['label_file']\n",
    "\n",
    "# Incorporate sample index and original sampling sizes/rates\n",
    "meta = pd.read_csv(os.path.join(manifest_dir, f'{split}.tsv'), sep='\\t', index_col='Unnamed: 0')\n",
    "meta = meta[meta.columns[0]].rename('sample_size')\n",
    "meta.index.name = 'file'\n",
    "meta = meta.reset_index()\n",
    "meta['save_file'] = extract_filename(meta['file']).replace(\"_\\d+\\.mat$\", '.mat', regex=True)\n",
    "meta = meta.merge(\n",
    "    manifest[['save_file', 'idx', 'sample_size_org', 'sample_rate_org']],\n",
    "    on='save_file',\n",
    "    how='left',\n",
    ")\n",
    "\n",
    "# Incorporate attn_max\n",
    "meta['attn_max'] = list(attn_max)\n",
    "\n",
    "# Incorporate labels\n",
    "if config['task']['label_file'] is not None:\n",
    "    label_dir = os.path.dirname(config['task']['label_file'])\n",
    "    label_def = pd.read_csv(os.path.join(label_dir, \"label_def.csv\"), index_col='name')\n",
    "    y = np.load(config['task']['label_file'])\n",
    "\n",
    "    # Align labels with manifest\n",
    "    labels = y[meta[\"idx\"].values]\n",
    "\n",
    "    # Convert into DataFrame format\n",
    "    labels_pd = pd.DataFrame(\n",
    "        labels,\n",
    "        columns=label_def.index,\n",
    "    ).astype(bool)\n",
    "    labels_pd.index.name = 'idx'\n",
    "    meta = pd.concat([meta, labels_pd], axis=1)\n",
    "\n",
    "meta"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c5824e6-6068-48d7-934b-21c5427965dc",
   "metadata": {},
   "source": [
    "# Filter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db454265-bf00-4990-a3ac-3c8c2d3c7d67",
   "metadata": {},
   "outputs": [],
   "source": [
    "meta_filtered = meta[meta['Sinus rhythm']].sample(3).copy()\n",
    "meta_filtered"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e443182-661f-4ccd-8694-6ff122526f4e",
   "metadata": {},
   "source": [
    "# Prepare plots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cefa5b7d-dc24-4da8-b5e2-112dd9e5e17b",
   "metadata": {},
   "outputs": [],
   "source": [
    "meta_filtered['lead'] = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'].index(lead)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55eda8cb-bd09-43c4-8624-e01dcfb65bbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the original signal values\n",
    "meta_filtered['seg_path'] = segmented_dir.rstrip('/') + '/' + meta_filtered['file']\n",
    "assert meta_filtered['seg_path'].apply(os.path.isfile).all()\n",
    "\n",
    "meta_filtered['feats'] = meta_filtered.apply(\n",
    "    lambda row: loadmat(row['seg_path'])['feats'][row['lead']],\n",
    "    axis=1,\n",
    ")\n",
    "meta_filtered['sample_size_extracted'] = meta_filtered['feats'].apply(\n",
    "    lambda feats: feats.shape[0]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa2bde98-7053-45f3-be1c-6a6dce0a379b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prep_saliency_values(row):\n",
    "    attn_max = row['attn_max']\n",
    "\n",
    "    # Resample to original sample size\n",
    "    new_dims = [\n",
    "        np.linspace(0, original_length-1, new_length) \\\n",
    "        for original_length, new_length in \\\n",
    "        zip(attn_max.shape, (row['sample_size_extracted'] - 1,))\n",
    "    ]\n",
    "    coords = np.meshgrid(*new_dims, indexing='ij')\n",
    "    attn_max = map_coordinates(attn_max, coords)\n",
    "\n",
    "    # Min-max normalization\n",
    "    attn_max = attn_max - attn_max.min()\n",
    "    attn_max = attn_max/attn_max.max()\n",
    "\n",
    "    return attn_max\n",
    "\n",
    "meta_filtered['saliency_prepped'] = meta_filtered.apply(prep_saliency_values, axis=1)\n",
    "meta_filtered['colors'] = meta_filtered['saliency_prepped'].apply(lambda sal: blend_colors_hex('#0047AB', '#DC143C', sal))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4607fc08-8c27-46f4-bc9e-0c3e94226688",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot samples\n",
    "for i, (_, row) in enumerate(meta_filtered.iterrows()):\n",
    "    fig = plt.figure(i, figsize=(20, 2))\n",
    "    fig.tight_layout()\n",
    "    plt.axis('off')\n",
    "    colored_line_segments(row['feats'], row['colors'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d755a745-2631-454d-a0ea-f620e55d3ed1",
   "metadata": {},
   "outputs": [],
   "source": [
    "label_str = pd.DataFrame(\n",
    "    np.argwhere(meta_filtered[label_def.index]).tolist()\n",
    ").set_index(0)[1].map({i: val for i, val in enumerate(label_def.index)}).groupby(\n",
    "    level=0,\n",
    ").agg('\\n'.join)\n",
    "label_str.index = label_str.index.map({i: ind for i, ind in enumerate(meta_filtered.index)})\n",
    "meta_filtered['label_str'] = label_str\n",
    "meta_filtered['label_str']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9c02cbf-89da-4305-b845-6bba78386f1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot with true labels on the right-hand side\n",
    "for i, (_, row) in enumerate(meta_filtered.iterrows()):\n",
    "    fig = plt.figure(i, figsize=(20, 2))\n",
    "    fig.tight_layout()\n",
    "    plt.axis('off')\n",
    "    plt.subplots_adjust(right=0.9)\n",
    "    plt.figtext(\n",
    "        0.9,\n",
    "        0.5,\n",
    "        row['label_str'],\n",
    "        verticalalignment='center',\n",
    "        horizontalalignment='left',\n",
    "    )\n",
    "    colored_line_segments(row['feats'], row['colors'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30f39b94-3ad5-4dfa-b446-32c761b1d567",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ecg_env",
   "language": "python",
   "name": "ecg_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
