{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# autoreload: re-import modified modules automatically before each cell runs,\n",
    "# so edits to local .py files are picked up without restarting the kernel.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root directories for the raw MIMIC-IV 2.2 files and for the multimodal project.\n",
    "# The defaults are the original machine-specific locations; set the MIMIC_IV_PATH /\n",
    "# MM_DIR environment variables to run the notebook elsewhere without editing code.\n",
    "mimic_iv_path = os.environ.get(\"MIMIC_IV_PATH\", \"/cis/home/charr165/Documents/physionet.org/mimiciv/2.2\")\n",
    "mm_dir = os.environ.get(\"MM_DIR\", \"/cis/home/charr165/Documents/multimodal\")\n",
    "\n",
    "# Directory the preprocessed pickles are read from and the final splits are written to.\n",
    "output_dir = os.path.join(mm_dir, \"preprocessing\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration flags read by the cells below.\n",
    "# include_notes: load the note embeddings and attach text_* fields to every sample.\n",
    "include_notes = True\n",
    "# include_cxr: load the chest X-ray embeddings and attach cxr_* fields to every sample.\n",
    "include_cxr = True\n",
    "# standard_scale: standardise each time-series feature with a scaler fit on the training stays.\n",
    "standard_scale = True\n",
    "# include_missing: False -> keep only stays/samples that have every enabled modality;\n",
    "# True -> keep stays that have any modality and mark absent ones via *_missing = 1.\n",
    "include_missing = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the two labs+vitals time-series pickles and keep only rows whose\n",
    "# timedelta is non-negative.\n",
    "# (Vitals-only alternatives: \"ts_vitals_icu.pkl\" / \"imputed_ts_vitals_icu.pkl\".)\n",
    "ireg_vitals_ts_df = (\n",
    "    pd.read_pickle(os.path.join(output_dir, \"ts_labs_vitals_icu.pkl\"))\n",
    "    .loc[lambda df: df['timedelta'] >= 0]\n",
    ")\n",
    "\n",
    "imputed_vitals = (\n",
    "    pd.read_pickle(os.path.join(output_dir, \"imputed_ts_labs_vitals_icu.pkl\"))\n",
    "    .loc[lambda df: df['timedelta'] >= 0]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if include_notes:\n",
    "    # Note embeddings: keep rows linked to a stay, then rows with a\n",
    "    # non-negative icu_time_delta (filters applied in that order).\n",
    "    # (Raw-text alternative: \"notes_text.pkl\".)\n",
    "    notes_df = (\n",
    "        pd.read_pickle(os.path.join(output_dir, \"icu_notes_text_embeddings.pkl\"))\n",
    "        .loc[lambda df: df['stay_id'].notnull()]\n",
    "        .loc[lambda df: df['icu_time_delta'] >= 0]\n",
    "    )\n",
    "\n",
    "if include_cxr:\n",
    "    # CXR embeddings: keep rows with a non-negative icu_time_delta.\n",
    "    cxr_df = (\n",
    "        pd.read_pickle(os.path.join(output_dir, \"cxr_embeddings_icu.pkl\"))\n",
    "        .loc[lambda df: df['icu_time_delta'] >= 0]\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ICU stay table with intime/outtime parsed to datetimes.\n",
    "icustays_df = pd.read_csv(\n",
    "    os.path.join(mimic_iv_path, \"icu\", \"icustays.csv\"), low_memory=False\n",
    ").assign(\n",
    "    intime=lambda df: pd.to_datetime(df['intime']),\n",
    "    outtime=lambda df: pd.to_datetime(df['outtime']),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stay ids present in icustays.csv; rows belonging to any other stay are dropped\n",
    "# from every loaded modality.\n",
    "valid_stay_ids = icustays_df['stay_id'].unique()\n",
    "\n",
    "ireg_vitals_ts_df = ireg_vitals_ts_df.loc[lambda df: df['stay_id'].isin(valid_stay_ids)]\n",
    "imputed_vitals = imputed_vitals.loc[lambda df: df['stay_id'].isin(valid_stay_ids)]\n",
    "\n",
    "if include_notes:\n",
    "    notes_df = notes_df.loc[lambda df: df['stay_id'].isin(valid_stay_ids)]\n",
    "\n",
    "if include_cxr:\n",
    "    cxr_df = cxr_df.loc[lambda df: df['stay_id'].isin(valid_stay_ids)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Admissions table reduced to ids, admit/discharge times and the in-hospital\n",
    "# death flag (hospital_expire_flag, renamed to `died`).\n",
    "admissions_df = (\n",
    "    pd.read_csv(os.path.join(mimic_iv_path, \"hosp\", \"admissions.csv\"))\n",
    "    .rename(columns={\"hospital_expire_flag\": \"died\"})\n",
    "    .loc[:, [\"subject_id\", \"admittime\", \"dischtime\", \"hadm_id\", \"died\"]]\n",
    "    .assign(\n",
    "        admittime=lambda df: pd.to_datetime(df['admittime']),\n",
    "        dischtime=lambda df: pd.to_datetime(df['dischtime']),\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Start from every stay that has time-series rows, then combine with the stays\n",
    "# of each enabled modality.\n",
    "unique_stays = ireg_vitals_ts_df['stay_id'].unique()\n",
    "print(f\"Number of stays with vitals: {len(unique_stays)}\")\n",
    "\n",
    "if include_missing:\n",
    "    # Union: keep stays that appear in at least one modality.\n",
    "    if include_notes:\n",
    "        unique_stays = np.union1d(unique_stays, notes_df['stay_id'].unique())\n",
    "        print(f\"Number of stays with either TS or notes: {len(unique_stays)}\")\n",
    "\n",
    "    if include_cxr:\n",
    "        unique_stays = np.union1d(unique_stays, cxr_df['stay_id'].unique())\n",
    "        print(f\"Number of stays with either TS, notes, cxr: {len(unique_stays)}\")\n",
    "else:\n",
    "    # Intersection: keep only stays present in every enabled modality.\n",
    "    if include_notes:\n",
    "        unique_stays = np.intersect1d(unique_stays, notes_df['stay_id'].unique())\n",
    "        print(f\"Number of stays with notes: {len(unique_stays)}\")\n",
    "\n",
    "    if include_cxr:\n",
    "        unique_stays = np.intersect1d(unique_stays, cxr_df['stay_id'].unique())\n",
    "        print(f\"Number of stays with cxr: {len(unique_stays)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffle the stay ids with a fixed seed and cut them 70% / 15% / 15%\n",
    "# into train / val / test.\n",
    "# NOTE(review): the split is over stay ids, so a subject with several stays could\n",
    "# land in more than one split - confirm this is intended.\n",
    "np.random.seed(0)\n",
    "np.random.shuffle(unique_stays)\n",
    "\n",
    "n_stays = len(unique_stays)\n",
    "train_end = int(0.7 * n_stays)\n",
    "val_end = int(0.85 * n_stays)\n",
    "\n",
    "train_stays = unique_stays[:train_end]\n",
    "val_stays = unique_stays[train_end:val_end]\n",
    "test_stays = unique_stays[val_end:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training-stay subsets, used only to fit the scalers so that val/test rows do\n",
    "# not influence the scaling statistics.\n",
    "train_ireg_ts_df = ireg_vitals_ts_df.loc[ireg_vitals_ts_df['stay_id'].isin(train_stays)].copy()\n",
    "train_imputed_df = imputed_vitals.loc[imputed_vitals['stay_id'].isin(train_stays)].copy()\n",
    "\n",
    "# Feature columns = everything except the identifier / time columns.\n",
    "cols = [\n",
    "    col for col in train_ireg_ts_df.columns\n",
    "    if col not in ('subject_id', 'hadm_id', 'stay_id', 'timedelta')\n",
    "]\n",
    "\n",
    "if standard_scale:\n",
    "    for col in cols:\n",
    "        # Each frame gets its own per-column scaler, fit on its training subset\n",
    "        # and applied to all of its rows.\n",
    "        for full_df, fit_df in ((ireg_vitals_ts_df, train_ireg_ts_df), (imputed_vitals, train_imputed_df)):\n",
    "            scaler = StandardScaler()\n",
    "            scaler.fit(fit_df[[col]])\n",
    "            full_df[col] = scaler.transform(full_df[[col]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start_hr = 4  # first hour (on the timedelta axis) for which a sample is emitted\n",
    "def get_stay_list(stays):\n",
    "    \"\"\"Build one sample dict per (stay, hour) pair.\n",
    "\n",
    "    For every stay id in `stays` and every integer hour `hr` from `start_hr` to the\n",
    "    stay's largest `timedelta`, a dict is emitted that holds the data with\n",
    "    timedelta / icu_time_delta <= hr:\n",
    "\n",
    "    - 'name', 'hadm_id', 'stay_id': subject, admission and stay identifiers\n",
    "    - 'ts_tt': timedelta values of the time-series rows\n",
    "    - 'irg_ts', 'irg_ts_mask': feature values of `ireg_vitals_ts_df` with NaN\n",
    "      replaced by 0, and a 0/1 mask marking the non-null entries\n",
    "    - 'reg_ts': feature values of `imputed_vitals`\n",
    "    - 'text_data', 'text_time_to_end', 'text_embeddings', 'text_missing'\n",
    "      (only when include_notes)\n",
    "    - 'cxr_feats', 'cxr_time', 'cxr_missing' (only when include_cxr)\n",
    "    - 'label': 1 when the ICU outtime is at most 24 hours after intime + hr and\n",
    "      the admission's `died` flag is 1, otherwise 0\n",
    "\n",
    "    Reads the notebook-level frames (ireg_vitals_ts_df, imputed_vitals, notes_df,\n",
    "    cxr_df, admissions_df, icustays_df) and flags (include_notes, include_cxr,\n",
    "    include_missing, start_hr).\n",
    "\n",
    "    Args:\n",
    "        stays: iterable of stay ids to process.\n",
    "\n",
    "    Returns:\n",
    "        list of sample dicts, ordered by stay and then by hour.\n",
    "\n",
    "    Raises:\n",
    "        IndexError: if a stay has no row in icustays_df or its hadm_id has no row\n",
    "            in admissions_df.\n",
    "    \"\"\"\n",
    "    id_cols = ['subject_id', 'hadm_id', 'stay_id', 'timedelta']\n",
    "    stays_list = []\n",
    "    died_by_hadm = {}  # memoised admissions lookup: hadm_id -> died flag\n",
    "\n",
    "    for curr_stay in tqdm(stays):\n",
    "        # Slice every modality once per stay; re-filtering the full frames for\n",
    "        # each hour is loop-invariant work.\n",
    "        stay_ireg_all = ireg_vitals_ts_df[ireg_vitals_ts_df['stay_id'] == curr_stay]\n",
    "\n",
    "        # With include_missing=True the stay list is a union over modalities, so a\n",
    "        # stay can have notes/CXR but no time-series rows. Such a stay can never\n",
    "        # yield a sample, and int(max) of an empty column would raise.\n",
    "        if len(stay_ireg_all) == 0:\n",
    "            continue\n",
    "\n",
    "        max_time = int(np.max(stay_ireg_all['timedelta']))\n",
    "        if max_time < start_hr:\n",
    "            continue\n",
    "\n",
    "        stay_imputed_all = imputed_vitals[imputed_vitals['stay_id'] == curr_stay]\n",
    "        if include_notes:\n",
    "            stay_notes_all = notes_df[notes_df['stay_id'] == curr_stay]\n",
    "        if include_cxr:\n",
    "            stay_cxr_all = cxr_df[cxr_df['stay_id'] == curr_stay]\n",
    "\n",
    "        # ICU in/out times depend only on the stay, not on the hour.\n",
    "        curr_icu_df = icustays_df[icustays_df['stay_id'] == curr_stay].iloc[0]\n",
    "\n",
    "        for hr in np.arange(start_hr, max_time + 1):\n",
    "            curr_stay_ireg = stay_ireg_all[stay_ireg_all['timedelta'] <= hr]\n",
    "            if len(curr_stay_ireg) == 0:\n",
    "                continue\n",
    "\n",
    "            # The modality slices are only built, and only checked for emptiness,\n",
    "            # when that modality is enabled; touching them otherwise would raise\n",
    "            # UnboundLocalError.\n",
    "            if include_notes:\n",
    "                curr_stay_notes = stay_notes_all[stay_notes_all['icu_time_delta'] <= hr]\n",
    "                if (len(curr_stay_notes) == 0) and not include_missing:\n",
    "                    continue\n",
    "\n",
    "            if include_cxr:\n",
    "                curr_stay_cxr = stay_cxr_all[stay_cxr_all['icu_time_delta'] <= hr]\n",
    "                if (len(curr_stay_cxr) == 0) and not include_missing:\n",
    "                    continue\n",
    "\n",
    "            curr_stay_imputed = stay_imputed_all[stay_imputed_all['timedelta'] <= hr]\n",
    "\n",
    "            hadm_id = curr_stay_ireg['hadm_id'].iloc[0]\n",
    "            if hadm_id not in died_by_hadm:\n",
    "                died_by_hadm[hadm_id] = admissions_df[admissions_df['hadm_id'] == hadm_id].iloc[0]['died']\n",
    "\n",
    "            # Hours between the prediction time (intime + hr) and ICU outtime.\n",
    "            curr_time_to_end = (curr_icu_df['outtime'] - (curr_icu_df['intime'] + pd.Timedelta(hours=hr)))\n",
    "            time_to_end = curr_time_to_end.total_seconds() / 3600\n",
    "            label = 1 if (time_to_end <= 24) and (died_by_hadm[hadm_id] == 1) else 0\n",
    "\n",
    "            # Non-inplace drop/fillna: the slices above are derived frames, so\n",
    "            # inplace edits on them would trigger SettingWithCopy warnings.\n",
    "            ireg_features = curr_stay_ireg.drop(columns=id_cols)\n",
    "\n",
    "            curr_stay_dict = {}\n",
    "            curr_stay_dict['name'] = curr_stay_ireg['subject_id'].iloc[0]\n",
    "            curr_stay_dict['hadm_id'] = hadm_id\n",
    "            curr_stay_dict['stay_id'] = curr_stay\n",
    "            curr_stay_dict['ts_tt'] = curr_stay_ireg['timedelta'].values\n",
    "            curr_stay_dict['irg_ts'] = ireg_features.fillna(0).values\n",
    "            curr_stay_dict['irg_ts_mask'] = ireg_features.notnull().values.astype(int)\n",
    "            curr_stay_dict['reg_ts'] = curr_stay_imputed.drop(columns=id_cols).values\n",
    "\n",
    "            if include_notes:\n",
    "                if len(curr_stay_notes) == 0:\n",
    "                    curr_stay_dict['text_data'] = []\n",
    "                    curr_stay_dict['text_time_to_end'] = []\n",
    "                    curr_stay_dict['text_embeddings'] = []\n",
    "                    curr_stay_dict['text_missing'] = 1\n",
    "                else:\n",
    "                    curr_stay_dict['text_data'] = curr_stay_notes['text'].tolist()\n",
    "                    # Despite the key name, this stores the notes' icu_time_delta values.\n",
    "                    curr_stay_dict['text_time_to_end'] = curr_stay_notes['icu_time_delta'].values\n",
    "                    curr_stay_dict['text_embeddings'] = [emb[0][0] for emb in curr_stay_notes['biobert_embeddings']]\n",
    "                    curr_stay_dict['text_missing'] = 0\n",
    "\n",
    "            if include_cxr:\n",
    "                if len(curr_stay_cxr) == 0:\n",
    "                    curr_stay_dict['cxr_feats'] = []\n",
    "                    curr_stay_dict['cxr_time'] = []\n",
    "                    curr_stay_dict['cxr_missing'] = 1\n",
    "                else:\n",
    "                    curr_stay_dict['cxr_feats'] = curr_stay_cxr['densefeatures'].tolist()\n",
    "                    curr_stay_dict['cxr_time'] = curr_stay_cxr['icu_time_delta'].values\n",
    "                    curr_stay_dict['cxr_missing'] = 0\n",
    "\n",
    "            curr_stay_dict['label'] = label\n",
    "\n",
    "            stays_list.append(curr_stay_dict)\n",
    "\n",
    "    return stays_list\n",
    "\n",
    "# Build the per-hour sample lists for each split (train, then val, then test).\n",
    "train_stays_list, val_stays_list, test_stays_list = (\n",
    "    get_stay_list(split_stays) for split_stays in (train_stays, val_stays, test_stays)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of notes attached to each training sample.\n",
    "n_notes = [len(item['text_data']) for item in train_stays_list]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the data\n",
    "import pickle\n",
    "\n",
    "# `restrict_48_hours` is not assigned anywhere in this notebook, so referencing it\n",
    "# directly raises NameError on a fresh kernel. Honour it if it was defined\n",
    "# beforehand, otherwise default to False (get_stay_list emits a sample for every\n",
    "# hour of a stay, which corresponds to the \"-all\" suffix).\n",
    "restrict_48_hours = globals().get(\"restrict_48_hours\", False)\n",
    "\n",
    "# Output file stem: task name, time window, enabled modalities, missingness flag.\n",
    "base_name = \"decompensation\"\n",
    "if restrict_48_hours:\n",
    "    base_name += \"-48\"\n",
    "else:\n",
    "    base_name += \"-all\"\n",
    "\n",
    "# NOTE(review): a notes-only run (include_notes without include_cxr) gets no\n",
    "# modality suffix and therefore shares file names with a time-series-only run.\n",
    "# Confirm whether a \"-notes\" suffix is wanted before changing the naming.\n",
    "if include_cxr:\n",
    "    if include_notes:\n",
    "        base_name += \"-cxr-notes\"\n",
    "    else:\n",
    "        base_name += \"-cxr\"\n",
    "\n",
    "if include_missing:\n",
    "    base_name += \"-missingInd\"\n",
    "\n",
    "# One pickle per split, written in train / val / test order.\n",
    "for split_name, split_list in (\n",
    "    (\"train\", train_stays_list),\n",
    "    (\"val\", val_stays_list),\n",
    "    (\"test\", test_stays_list),\n",
    "):\n",
    "    f_path = os.path.join(output_dir, f\"{split_name}_{base_name}_stays.pkl\")\n",
    "    with open(f_path, 'wb') as f:\n",
    "        print(f\"Saving {split_name} stays to {f_path}\")\n",
    "        pickle.dump(split_list, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "# Read a previously pickled training list back in.\n",
    "# NOTE(review): this loads \"train_ihm-cxr_stays.pkl\", not the decompensation file\n",
    "# written by this notebook, and it overwrites `train_stays_list` - confirm intent.\n",
    "# Only unpickle files you trust: pickle.load can execute arbitrary code.\n",
    "f_name = os.path.join(output_dir, \"train_ihm-cxr_stays.pkl\")\n",
    "\n",
    "with open(f_name, 'rb') as f:\n",
    "    train_stays_list = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot check: largest time-series timedelta stored in the fifth loaded sample.\n",
    "train_stays_list[4]['ts_tt'].max()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
