{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# autoreload\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "mimic_iv_path = \"/cis/home/charr165/Documents/physionet.org/mimiciv/2.2\"\n",
    "mm_dir = \"/cis/home/charr165/Documents/multimodal\"\n",
    "\n",
    "output_dir = os.path.join(mm_dir, \"preprocessing\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "include_notes = True\n",
    "include_cxr = True\n",
    "standard_scale = True\n",
    "include_missing = True\n",
    "include_ecg = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ireg_vitals_ts_df = pd.read_pickle(os.path.join(output_dir, \"ts_vitals_icu.pkl\"))\n",
    "# imputed_vitals = pd.read_pickle(os.path.join(output_dir, \"imputed_ts_vitals_icu.pkl\"))\n",
    "\n",
    "ireg_vitals_ts_df = pd.read_pickle(os.path.join(output_dir, \"ts_labs_vitals_icu.pkl\"))\n",
    "imputed_vitals = pd.read_pickle(os.path.join(output_dir, \"imputed_ts_labs_vitals_icu.pkl\"))\n",
    "\n",
    "ireg_vitals_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['timedelta'] >= 0]\n",
    "imputed_vitals = imputed_vitals[imputed_vitals['timedelta'] >= 0]\n",
    "\n",
    "ireg_vitals_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['timedelta'] <= 48]\n",
    "imputed_vitals = imputed_vitals[imputed_vitals['timedelta'] <= 48]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "if include_notes:\n",
    "    notes_df = pd.read_pickle(os.path.join(output_dir, \"icu_notes_text_embeddings.pkl\"))\n",
    "    # notes_df = pd.read_pickle(os.path.join(output_dir, \"notes_text.pkl\"))\n",
    "    notes_df = notes_df[notes_df['stay_id'].notnull()]\n",
    "\n",
    "    notes_df = notes_df[notes_df['icu_time_delta'] >= 0]\n",
    "    notes_df = notes_df[notes_df['icu_time_delta'] <= 48]\n",
    "\n",
    "if include_cxr:\n",
    "    cxr_df = pd.read_pickle(os.path.join(output_dir, \"cxr_embeddings_icu.pkl\"))\n",
    "    cxr_df = cxr_df[cxr_df['icu_time_delta'] >= 0]\n",
    "    cxr_df = cxr_df[cxr_df['icu_time_delta'] <= 48]\n",
    "\n",
    "if include_ecg:\n",
    "    ecg_df = pd.read_pickle(os.path.join(output_dir, \"ecg_embeddings_icu.pkl\"))\n",
    "    ecg_df = ecg_df[ecg_df['icu_time_delta'] >= 0]\n",
    "    ecg_df = ecg_df[ecg_df['icu_time_delta'] <= 48]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "icustays_df = pd.read_csv(os.path.join(mimic_iv_path, \"icu\", \"icustays.csv\"), low_memory=False)\n",
    "icustays_df['intime'] = pd.to_datetime(icustays_df['intime'])\n",
    "icustays_df['outtime'] = pd.to_datetime(icustays_df['outtime'])\n",
    "\n",
    "icustays_df = icustays_df[icustays_df['los'] >= 2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "valid_stay_ids = icustays_df['stay_id'].unique()\n",
    "\n",
    "ireg_vitals_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['stay_id'].isin(valid_stay_ids)]\n",
    "imputed_vitals = imputed_vitals[imputed_vitals['stay_id'].isin(valid_stay_ids)]\n",
    "\n",
    "if include_notes:\n",
    "    notes_df = notes_df[notes_df['stay_id'].isin(valid_stay_ids)]\n",
    "\n",
    "if include_cxr:\n",
    "    cxr_df = cxr_df[cxr_df['stay_id'].isin(valid_stay_ids)]\n",
    "\n",
    "if include_ecg:\n",
    "    ecg_df = ecg_df[ecg_df['stay_id'].isin(valid_stay_ids)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "admissions_df = pd.read_csv(os.path.join(mimic_iv_path, \"hosp\", \"admissions.csv\"))\n",
    "admissions_df = admissions_df.rename(columns={\"hospital_expire_flag\": \"died\"})\n",
    "admissions_df = admissions_df[[\"subject_id\", \"hadm_id\", \"died\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of stays with vitals: 35129\n",
      "Number of stays with either TS or notes: 35131\n",
      "Number of stays with either TS, notes, cxr: 35131\n",
      "Number of stays with either TS, notes, cxr, ecg: 35131\n"
     ]
    }
   ],
   "source": [
    "if not include_missing:\n",
    "    unique_stays = ireg_vitals_ts_df['stay_id'].unique()\n",
    "    print(f\"Number of stays with vitals: {len(unique_stays)}\")\n",
    "\n",
    "    if include_notes:\n",
    "        unique_stays = np.intersect1d(unique_stays, notes_df['stay_id'].unique())\n",
    "        print(f\"Number of stays with notes: {len(unique_stays)}\")\n",
    "\n",
    "    if include_cxr:\n",
    "        unique_stays = np.intersect1d(unique_stays, cxr_df['stay_id'].unique())\n",
    "        print(f\"Number of stays with cxr: {len(unique_stays)}\")\n",
    "\n",
    "    if include_ecg:\n",
    "        unique_stays = np.intersect1d(unique_stays, ecg_df['stay_id'].unique())\n",
    "        print(f\"Number of stays with ecg: {len(unique_stays)}\")        \n",
    "else:\n",
    "    unique_stays = ireg_vitals_ts_df['stay_id'].unique()\n",
    "    print(f\"Number of stays with vitals: {len(unique_stays)}\")\n",
    "\n",
    "    if include_notes:\n",
    "        # Get stays with either TS or notes\n",
    "        unique_stays = np.union1d(unique_stays, notes_df['stay_id'].unique())\n",
    "        print(f\"Number of stays with either TS or notes: {len(unique_stays)}\")\n",
    "    \n",
    "    if include_cxr:\n",
    "        unique_stays = np.union1d(unique_stays, cxr_df['stay_id'].unique())\n",
    "        print(f\"Number of stays with either TS, notes, cxr: {len(unique_stays)}\")\n",
    "\n",
    "    if include_ecg:\n",
    "        unique_stays = np.union1d(unique_stays, ecg_df['stay_id'].unique())\n",
    "        print(f\"Number of stays with either TS, notes, cxr, ecg: {len(unique_stays)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create train, val, test splits\n",
    "np.random.seed(0)\n",
    "np.random.shuffle(unique_stays)\n",
    "train_stays = unique_stays[:int(0.7*len(unique_stays))]\n",
    "val_stays = unique_stays[int(0.7*len(unique_stays)):int(0.85*len(unique_stays))]\n",
    "test_stays = unique_stays[int(0.85*len(unique_stays)):]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_ireg_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['stay_id'].isin(train_stays)].copy()\n",
    "train_imputed_df = imputed_vitals[imputed_vitals['stay_id'].isin(train_stays)].copy()\n",
    "\n",
    "cols = train_ireg_ts_df.columns.tolist()\n",
    "cols = [col for col in cols if col not in ['subject_id', 'hadm_id', 'stay_id', 'timedelta']]\n",
    "\n",
    "if standard_scale:\n",
    "    for col in cols:\n",
    "        scaler = StandardScaler()\n",
    "        scaler.fit(train_ireg_ts_df[[col]])\n",
    "        ireg_vitals_ts_df[col] = scaler.transform(ireg_vitals_ts_df[[col]])\n",
    "\n",
    "        scaler = StandardScaler()\n",
    "        scaler.fit(train_imputed_df[[col]])\n",
    "        imputed_vitals[col] = scaler.transform(imputed_vitals[[col]])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  9%|▉         | 2194/24591 [00:24<04:07, 90.48it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "error!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 24591/24591 [04:27<00:00, 92.09it/s]\n",
      "100%|██████████| 5270/5270 [00:57<00:00, 92.00it/s]\n",
      " 32%|███▏      | 1686/5270 [00:18<00:38, 94.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "error!\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 5270/5270 [00:57<00:00, 92.22it/s]\n"
     ]
    }
   ],
   "source": [
    "def get_stay_list(stays):\n",
    "    stays_list = []\n",
    "\n",
    "    for curr_stay in tqdm(stays):\n",
    "        curr_stay_ireg = ireg_vitals_ts_df[ireg_vitals_ts_df['stay_id'] == curr_stay].copy()\n",
    "        curr_stay_imputed = imputed_vitals[imputed_vitals['stay_id'] == curr_stay].copy()\n",
    "        curr_icustay = icustays_df[icustays_df['stay_id'] == curr_stay].copy()\n",
    "\n",
    "        try:\n",
    "            curr_hadm_id = curr_stay_ireg['hadm_id'].iloc[0]\n",
    "            died = admissions_df[admissions_df['hadm_id'] == curr_hadm_id]['died'].iloc[0]\n",
    "        except:\n",
    "            print(\"error!\")\n",
    "            continue\n",
    "\n",
    "        intime = icustays_df[icustays_df['stay_id'] == curr_stay]['intime'].iloc[0]\n",
    "        outtime = icustays_df[icustays_df['stay_id'] == curr_stay]['outtime'].iloc[0]\n",
    "        icu_time_delta = (outtime - intime).total_seconds() / 3600\n",
    "\n",
    "        if died == 1:\n",
    "            pass\n",
    "\n",
    "        if (icu_time_delta < 96) & (died == 0):\n",
    "            label = 1\n",
    "        else:\n",
    "            label = 0\n",
    "\n",
    "        if len(curr_stay_ireg) == 0:\n",
    "            continue\n",
    "\n",
    "        if include_notes:\n",
    "            curr_stay_notes = notes_df[notes_df['stay_id'] == curr_stay].copy()\n",
    "\n",
    "        if include_cxr:\n",
    "            curr_stay_cxr = cxr_df[cxr_df['stay_id'] == curr_stay].copy()\n",
    "\n",
    "        curr_stay_dict = {}\n",
    "        curr_stay_dict['name'] = curr_stay_ireg['subject_id'].iloc[0]\n",
    "        curr_stay_dict['hadm_id'] = curr_stay_ireg['hadm_id'].iloc[0]\n",
    "        curr_stay_dict['stay_id'] = curr_stay\n",
    "        curr_stay_dict['ts_tt'] = curr_stay_ireg['timedelta'].values\n",
    "\n",
    "        curr_stay_ireg.drop(columns=['subject_id', 'hadm_id', 'stay_id', 'timedelta'], inplace=True)\n",
    "        ireg_ts_mask = curr_stay_ireg.notnull()\n",
    "        curr_stay_ireg.fillna(0, inplace=True)\n",
    "        curr_stay_dict['irg_ts'] = curr_stay_ireg.values\n",
    "        curr_stay_dict['irg_ts_mask'] = ireg_ts_mask.values.astype(int)\n",
    "\n",
    "        curr_stay_imputed.drop(columns=['subject_id', 'hadm_id', 'stay_id', 'timedelta'], inplace=True)\n",
    "        curr_stay_dict['reg_ts'] = curr_stay_imputed.values\n",
    "\n",
    "        if include_notes:\n",
    "            if len(curr_stay_notes) == 0:\n",
    "                curr_stay_dict['text_data'] = []\n",
    "                curr_stay_dict['text_time_to_end'] = []\n",
    "                curr_stay_dict['text_embeddings'] = []\n",
    "                curr_stay_dict['text_missing'] = 1\n",
    "            else:\n",
    "                curr_stay_dict['text_data'] = curr_stay_notes['text'].tolist()\n",
    "                curr_stay_dict['text_time_to_end'] = curr_stay_notes['icu_time_delta'].values\n",
    "                curr_stay_dict['text_embeddings'] = [emb[0][0] for emb in curr_stay_notes['biobert_embeddings']]\n",
    "                curr_stay_dict['text_missing'] = 0\n",
    "\n",
    "        if include_cxr:\n",
    "            if len(curr_stay_cxr) == 0:\n",
    "                curr_stay_dict['cxr_feats'] = []\n",
    "                curr_stay_dict['cxr_time'] = []\n",
    "                curr_stay_dict['cxr_missing'] = 1\n",
    "            else:\n",
    "                curr_stay_dict['cxr_feats'] = curr_stay_cxr['densefeatures'].tolist()\n",
    "                curr_stay_dict['cxr_time'] = curr_stay_cxr['icu_time_delta'].values\n",
    "                curr_stay_dict['cxr_missing'] = 0\n",
    "\n",
    "        if include_ecg:\n",
    "            curr_stay_ecg = ecg_df[ecg_df['stay_id'] == curr_stay].copy()\n",
    "            if len(curr_stay_ecg) == 0:\n",
    "                curr_stay_dict['ecg_feats'] = []\n",
    "                curr_stay_dict['ecg_time'] = []\n",
    "                curr_stay_dict['ecg_missing'] = 1\n",
    "            else:\n",
    "                curr_stay_dict['ecg_feats'] = curr_stay_ecg['embeddings'].tolist()\n",
    "                curr_stay_dict['ecg_time'] = curr_stay_ecg['icu_time_delta'].values\n",
    "                curr_stay_dict['ecg_missing'] = 0\n",
    "\n",
    "        curr_stay_dict['label'] = label\n",
    "        stays_list.append(curr_stay_dict)\n",
    "\n",
    "    return stays_list\n",
    "\n",
    "train_stays_list = get_stay_list(train_stays)\n",
    "val_stays_list = get_stay_list(val_stays)\n",
    "test_stays_list = get_stay_list(test_stays)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saving train stays to /cis/home/charr165/Documents/multimodal/preprocessing/train_los-48-cxr-notes-ecg-missingInd_stays.pkl\n",
      "Saving val stays to /cis/home/charr165/Documents/multimodal/preprocessing/val_los-48-cxr-notes-ecg-missingInd_stays.pkl\n",
      "Saving test stays to /cis/home/charr165/Documents/multimodal/preprocessing/test_los-48-cxr-notes-ecg-missingInd_stays.pkl\n"
     ]
    }
   ],
   "source": [
    "# Save the data\n",
    "import pickle\n",
    "\n",
    "restrict_48_hours = True\n",
    "\n",
    "base_name = \"los\"\n",
    "if restrict_48_hours:\n",
    "    base_name += \"-48\"\n",
    "else:\n",
    "    base_name += \"-all\"\n",
    "\n",
    "if include_cxr:\n",
    "    if include_notes:\n",
    "        base_name += \"-cxr-notes\"\n",
    "    else:\n",
    "        base_name += \"-cxr\"\n",
    "\n",
    "if include_ecg:\n",
    "    base_name += \"-ecg\"\n",
    "\n",
    "if include_missing:\n",
    "    base_name += \"-missingInd\"\n",
    "\n",
    "f_path = os.path.join(output_dir, f\"train_{base_name}_stays.pkl\")\n",
    "with open(f_path, 'wb') as f:\n",
    "    print(f\"Saving train stays to {f_path}\")\n",
    "    pickle.dump(train_stays_list, f)\n",
    "\n",
    "f_path = os.path.join(output_dir, f\"val_{base_name}_stays.pkl\")\n",
    "with open(f_path, 'wb') as f:\n",
    "    print(f\"Saving val stays to {f_path}\")\n",
    "    pickle.dump(val_stays_list, f)\n",
    "\n",
    "f_path = os.path.join(output_dir, f\"test_{base_name}_stays.pkl\")\n",
    "with open(f_path, 'wb') as f:\n",
    "    print(f\"Saving test stays to {f_path}\")\n",
    "    pickle.dump(test_stays_list, f)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
