{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import wfdb\n",
    "import os\n",
    "from sklearn.model_selection import train_test_split\n",
    "from matplotlib import pyplot as plt\n",
    "import seaborn as snss\n",
    "from pprint import pprint\n",
    "from tqdm import tqdm\n",
    "import sys\n",
    "sys.path.append(\"../finetune/\")\n",
    "sys.path.append(\"../utils\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set your meta path of mimic-ecg\n",
    "meta_path = 'your_path/mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0'\n",
    "report_csv = pd.read_csv(f'{meta_path}/machine_measurements.csv', low_memory=False)\n",
    "record_csv = pd.read_csv(f'{meta_path}/record_list.csv', low_memory=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_report(row):\n",
    "    # Select the relevant columns and filter out NaNs\n",
    "    report = row[['report_0', 'report_1', 'report_2', 'report_3', 'report_4', \n",
    "                  'report_5', 'report_6', 'report_7', 'report_8', 'report_9', \n",
    "                  'report_10', 'report_11', 'report_12', 'report_13', 'report_14', \n",
    "                  'report_15', 'report_16', 'report_17']].dropna()\n",
    "    # Concatenate the report\n",
    "    report = '. '.join(report)\n",
    "    # Replace and preprocess text\n",
    "    report = report.replace('EKG', 'ECG').replace('ekg', 'ecg')\n",
    "    report = report.strip(' ***').strip('*** ').strip('***').strip('=-').strip('=')\n",
    "    # Convert to lowercase\n",
    "    report = report.lower()\n",
    "\n",
    "    # concatenate the report if the report length is not 0\n",
    "    total_report = ''\n",
    "    if len(report.split()) != 0:\n",
    "        total_report = report\n",
    "        total_report = total_report.replace('\\n', ' ')\n",
    "        total_report = total_report.replace('\\r', ' ')\n",
    "        total_report = total_report.replace('\\t', ' ')\n",
    "        total_report += '.'\n",
    "    if len(report.split()) == 0:\n",
    "        total_report = 'empty'\n",
    "    # Calculate the length of the report in words\n",
    "    return len(report.split()), total_report\n",
    "\n",
    "tqdm.pandas()\n",
    "report_csv['report_length'], report_csv['total_report'] = zip(*report_csv.progress_apply(process_report, axis=1))\n",
    "# Filter out reports with less than 4 words\n",
    "report_csv = report_csv[report_csv['report_length'] >= 4]\n",
    "\n",
    "# you should get 771693 here\n",
    "print(report_csv.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "report_csv.reset_index(drop=True, inplace=True)\n",
    "record_csv = record_csv[record_csv['study_id'].isin(report_csv['study_id'])]\n",
    "record_csv.reset_index(drop=True, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# build an empty numpy array to store the data, we use int16 to save the space\n",
    "temp_npy = np.zeros((len(record_csv), 12, 5000), dtype=np.int16)\n",
    "\n",
    "for p in tqdm(record_csv['path']):\n",
    "    # read the data\n",
    "    ecg_path = os.path.join(meta_path, p)\n",
    "    record = wfdb.rdsamp(ecg_path)[0]\n",
    "    record = record.T\n",
    "    # replace the nan with the neighbor 5 value mean\n",
    "    # detect nan in each lead\n",
    "    if np.isnan(record).sum() == 0 and np.isinf(record).sum() == 0:\n",
    "        # normalize to 0-1\n",
    "        record = (record - record.min()) / (record.max() - record.min())\n",
    "        # scale the data\n",
    "        record *= 1000\n",
    "        # convert to int16\n",
    "        record = record.astype(np.int16)\n",
    "        # store the data\n",
    "        temp_npy[record_csv[record_csv['path'] == p].index[0]] = record[:, :5000]\n",
    "\n",
    "    else:\n",
    "        if np.isinf(record).sum() == 0:\n",
    "            for i in range(record.shape[0]):\n",
    "                nan_idx = np.where(np.isnan(record[:, i]))[0]\n",
    "                for idx in nan_idx:\n",
    "                    record[idx, i] = np.mean(record[max(0, idx-6):min(idx+6, record.shape[0]), i])\n",
    "        if np.isnan(record).sum() == 0:\n",
    "            for i in range(record.shape[0]):\n",
    "                inf_idx = np.where(np.isinf(record[:, i]))[0]\n",
    "                for idx in inf_idx:\n",
    "                    record[idx, i] = np.mean(record[max(0, idx-6):min(idx+6, record.shape[0]), i])\n",
    "\n",
    "        # normalize to 0-1\n",
    "        record = (record - record.min()) / (record.max() - record.min())\n",
    "        # scale the data\n",
    "        record *= 1000\n",
    "        # convert to int16\n",
    "        record = record.astype(np.int16)\n",
    "        # store the data\n",
    "        temp_npy[record_csv[record_csv['path'] == p].index[0]] = record[:, :5000]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# split to train and val\n",
    "train_npy, train_csv, val_npy, val_csv = train_test_split(temp_npy, report_csv, test_size=0.02, random_state=42)\n",
    "\n",
    "train_csv.reset_index(drop=True, inplace=True)\n",
    "val_csv.reset_index(drop=True, inplace=True)\n",
    "\n",
    "# save to your path\n",
    "np.save(\"your_path_train.npy\", train_npy)\n",
    "np.save(\"your_path_val.npy\", val_npy)\n",
    "train_csv.to_csv(\"your_path_train.csv\", index=False)\n",
    "val_csv.to_csv(\"your_path_val.csv\", index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chen",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
