{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# autoreload\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mimic_iv_path = \"/cis/home/charr165/Documents/physionet.org/mimiciv/2.2\"\n",
    "mm_dir = \"/cis/home/charr165/Documents/multimodal\"\n",
    "\n",
    "output_dir = os.path.join(mm_dir, \"preprocessing\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run configuration.\n",
    "# include_notes / include_cxr / include_ecg: load that modality and attach its\n",
    "# text_* / cxr_* / ecg_* fields to every stay record built by get_stay_list.\n",
    "include_notes = True\n",
    "include_cxr = True\n",
    "include_ecg = True\n",
    "# standard_scale: scale each time-series feature column with a StandardScaler\n",
    "# fit on the training stays only.\n",
    "standard_scale = True\n",
    "# include_missing: True builds the cohort as the union of stays across the\n",
    "# enabled sources, False as their intersection. It also adds \"-missingInd\"\n",
    "# to the output file names.\n",
    "include_missing = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lab + vital time series: the irregular frame and its imputed counterpart.\n",
    "# The vitals-only variants of these files are \"ts_vitals_icu.pkl\" and\n",
    "# \"imputed_ts_vitals_icu.pkl\".\n",
    "ireg_vitals_ts_df = pd.read_pickle(os.path.join(output_dir, \"ts_labs_vitals_icu.pkl\"))\n",
    "imputed_vitals = pd.read_pickle(os.path.join(output_dir, \"imputed_ts_labs_vitals_icu.pkl\"))\n",
    "\n",
    "# Keep only rows whose timedelta is >= 0.\n",
    "ireg_vitals_ts_df = ireg_vitals_ts_df.loc[ireg_vitals_ts_df['timedelta'] >= 0]\n",
    "imputed_vitals = imputed_vitals.loc[imputed_vitals['timedelta'] >= 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each optional modality is loaded only when its flag is set, and is\n",
    "# restricted to rows whose icu_time_delta is >= 0.\n",
    "if include_notes:\n",
    "    notes_df = pd.read_pickle(os.path.join(output_dir, \"icu_notes_text_embeddings.pkl\"))\n",
    "    # Notes additionally need a stay_id to be usable.\n",
    "    has_stay_id = notes_df['stay_id'].notnull()\n",
    "    notes_df = notes_df[has_stay_id & (notes_df['icu_time_delta'] >= 0)]\n",
    "\n",
    "if include_cxr:\n",
    "    cxr_df = pd.read_pickle(os.path.join(output_dir, \"cxr_embeddings_icu.pkl\"))\n",
    "    cxr_df = cxr_df.loc[cxr_df['icu_time_delta'] >= 0]\n",
    "\n",
    "if include_ecg:\n",
    "    ecg_df = pd.read_pickle(os.path.join(output_dir, \"ecg_embeddings_icu.pkl\"))\n",
    "    ecg_df = ecg_df.loc[ecg_df['icu_time_delta'] >= 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "icustays_path = os.path.join(mimic_iv_path, \"icu\", \"icustays.csv\")\n",
    "icustays_df = pd.read_csv(icustays_path, low_memory=False)\n",
    "\n",
    "# Convert the intime/outtime columns to datetimes.\n",
    "for time_col in ('intime', 'outtime'):\n",
    "    icustays_df[time_col] = pd.to_datetime(icustays_df[time_col])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restrict every table to stay_ids that appear in icustays.csv.\n",
    "valid_stay_ids = icustays_df['stay_id'].unique()\n",
    "\n",
    "ireg_vitals_ts_df = ireg_vitals_ts_df.loc[ireg_vitals_ts_df['stay_id'].isin(valid_stay_ids)]\n",
    "imputed_vitals = imputed_vitals.loc[imputed_vitals['stay_id'].isin(valid_stay_ids)]\n",
    "\n",
    "if include_notes:\n",
    "    notes_df = notes_df.loc[notes_df['stay_id'].isin(valid_stay_ids)]\n",
    "\n",
    "if include_cxr:\n",
    "    cxr_df = cxr_df.loc[cxr_df['stay_id'].isin(valid_stay_ids)]\n",
    "\n",
    "if include_ecg:\n",
    "    ecg_df = ecg_df.loc[ecg_df['stay_id'].isin(valid_stay_ids)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The cohort always starts from the stays that have time-series rows.\n",
    "unique_stays = ireg_vitals_ts_df['stay_id'].unique()\n",
    "print(f\"Number of stays with vitals: {len(unique_stays)}\")\n",
    "\n",
    "# include_missing=True keeps a stay found in ANY enabled source (union);\n",
    "# otherwise the stay must be present in EVERY enabled source (intersection).\n",
    "if include_missing:\n",
    "    combine_stays = np.union1d\n",
    "    notes_msg = \"Number of stays with either TS or notes\"\n",
    "    cxr_msg = \"Number of stays with either TS, notes, cxr\"\n",
    "    ecg_msg = \"Number of stays with either TS, notes, cxr, ecg\"\n",
    "else:\n",
    "    combine_stays = np.intersect1d\n",
    "    notes_msg = \"Number of stays with notes\"\n",
    "    cxr_msg = \"Number of stays with cxr\"\n",
    "    ecg_msg = \"Number of stays with ecg\"\n",
    "\n",
    "if include_notes:\n",
    "    unique_stays = combine_stays(unique_stays, notes_df['stay_id'].unique())\n",
    "    print(f\"{notes_msg}: {len(unique_stays)}\")\n",
    "\n",
    "if include_cxr:\n",
    "    unique_stays = combine_stays(unique_stays, cxr_df['stay_id'].unique())\n",
    "    print(f\"{cxr_msg}: {len(unique_stays)}\")\n",
    "\n",
    "if include_ecg:\n",
    "    unique_stays = combine_stays(unique_stays, ecg_df['stay_id'].unique())\n",
    "    print(f\"{ecg_msg}: {len(unique_stays)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create train, val, test splits (70% / 15% / 15% of the stays).\n",
    "#\n",
    "# The shuffle runs on a copy. Shuffling unique_stays in place meant that\n",
    "# re-running this cell started from the already shuffled array and produced a\n",
    "# different split each time, despite the fixed seed.\n",
    "np.random.seed(0)\n",
    "shuffled_stays = unique_stays.copy()\n",
    "np.random.shuffle(shuffled_stays)\n",
    "\n",
    "n_stays = len(shuffled_stays)\n",
    "train_end = int(0.7 * n_stays)\n",
    "val_end = int(0.85 * n_stays)\n",
    "\n",
    "train_stays = shuffled_stays[:train_end]\n",
    "val_stays = shuffled_stays[train_end:val_end]\n",
    "test_stays = shuffled_stays[val_end:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training rows only: the scalers below are fit on these and never on val/test rows.\n",
    "train_ireg_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['stay_id'].isin(train_stays)].copy()\n",
    "train_imputed_df = imputed_vitals[imputed_vitals['stay_id'].isin(train_stays)].copy()\n",
    "\n",
    "# Every column except the identifier / time columns is a feature to scale.\n",
    "non_feature_cols = {'subject_id', 'hadm_id', 'stay_id', 'timedelta'}\n",
    "cols = [col for col in train_ireg_ts_df.columns.tolist() if col not in non_feature_cols]\n",
    "\n",
    "if standard_scale:\n",
    "    # One scaler per column and per frame: fit on train, applied to all rows.\n",
    "    for col in cols:\n",
    "        ireg_scaler = StandardScaler().fit(train_ireg_ts_df[[col]])\n",
    "        ireg_vitals_ts_df[col] = ireg_scaler.transform(ireg_vitals_ts_df[[col]])\n",
    "\n",
    "        imputed_scaler = StandardScaler().fit(train_imputed_df[[col]])\n",
    "        imputed_vitals[col] = imputed_scaler.transform(imputed_vitals[[col]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import yaml\n",
    "\n",
    "# Diagnosis-group definitions, loaded from a path relative to the working\n",
    "# directory. A later cell reads each entry's 'use_in_benchmark' flag and its\n",
    "# list of 'codes'.\n",
    "f_path = \"hcup_ccs_2015_definitions.yaml\"\n",
    "with open(f_path, 'r') as f:\n",
    "    hcup_ccs = yaml.safe_load(f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ICD-9 <-> ICD-10 mapping table, loaded from a path relative to the working\n",
    "# directory. A later cell uses its 'icd9cm' and 'icd10cm' columns.\n",
    "# low_memory=False parses each column in one pass instead of chunk by chunk.\n",
    "f_path = \"icd10cmtoicd9gem.csv\"\n",
    "icd10_to_icd9_df = pd.read_csv(f_path, low_memory=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the label definitions: one entry per hcup_ccs group flagged\n",
    "# use_in_benchmark, holding its ICD-9 codes ('icd9'), the ICD-10 codes the\n",
    "# mapping table associates with them ('icd10'), and a dense integer 'id' that\n",
    "# is the group's position in the label vector.\n",
    "#\n",
    "# Entries are copied instead of edited in place. Popping 'codes' from the\n",
    "# loaded definitions made this cell raise KeyError('codes') when it was re-run\n",
    "# without reloading the YAML.\n",
    "\n",
    "# ICD-9 code -> its ICD-10 codes, in file order. Grouping once replaces a\n",
    "# full scan of the mapping table for every single ICD-9 code.\n",
    "icd9_to_icd10 = {\n",
    "    icd9_code: icd10_group.tolist()\n",
    "    for icd9_code, icd10_group in icd10_to_icd9_df.groupby('icd9cm', sort=False)['icd10cm']\n",
    "}\n",
    "\n",
    "benchmark_diags = {}\n",
    "\n",
    "for key, definition in hcup_ccs.items():\n",
    "    if not definition['use_in_benchmark']:\n",
    "        continue\n",
    "\n",
    "    icd9_codes = list(definition['codes'])\n",
    "    icd10_codes = [\n",
    "        icd10_code\n",
    "        for code in icd9_codes\n",
    "        for icd10_code in icd9_to_icd10.get(code, [])\n",
    "    ]\n",
    "\n",
    "    # Same fields as the source definition, minus 'codes'.\n",
    "    curr_entry = {name: value for name, value in definition.items() if name != 'codes'}\n",
    "    curr_entry['icd9'] = icd9_codes\n",
    "    curr_entry['icd10'] = icd10_codes\n",
    "    # Ids are assigned in definition order, counting only the kept groups.\n",
    "    curr_entry['id'] = len(benchmark_diags)\n",
    "    benchmark_diags[key] = curr_entry"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the admission identifiers plus hospital_expire_flag, renamed to 'died'.\n",
    "admissions_path = os.path.join(mimic_iv_path, \"hosp\", \"admissions.csv\")\n",
    "admissions_df = (\n",
    "    pd.read_csv(admissions_path)\n",
    "    .rename(columns={\"hospital_expire_flag\": \"died\"})\n",
    "    .loc[:, [\"subject_id\", \"hadm_id\", \"died\"]]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read ICD codes as strings. Under type inference a purely numeric run of\n",
    "# codes is parsed as integers, which drops leading zeros (\"0389\" -> 389) and\n",
    "# can leave a mixed int/str column. Either would break the merge key below\n",
    "# and the membership tests against the benchmark code lists\n",
    "# (NOTE(review): those lists are assumed to hold strings - confirm).\n",
    "icd_dtypes = {\"icd_code\": str}\n",
    "\n",
    "d_icd_diagnoses = pd.read_csv(os.path.join(mimic_iv_path, \"hosp\", \"d_icd_diagnoses.csv\"), dtype=icd_dtypes)\n",
    "diagnoses_df = pd.read_csv(os.path.join(mimic_iv_path, \"hosp\", \"diagnoses_icd.csv\"), dtype=icd_dtypes)\n",
    "\n",
    "# Attach the ICD dictionary columns; the left join keeps every diagnosis row.\n",
    "diagnoses_df = diagnoses_df.merge(d_icd_diagnoses, on=[\"icd_code\", 'icd_version'], how=\"left\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _positions_by_key(df, key):\n",
    "    \"\"\"Map each distinct value of df[key] to the positional indices of its rows.\"\"\"\n",
    "    return df.groupby(key, sort=False).indices\n",
    "\n",
    "\n",
    "def _rows_at(df, positions_by_key, key_value):\n",
    "    \"\"\"Return a copy of the rows of df filed under key_value (an empty frame if none).\"\"\"\n",
    "    positions = positions_by_key.get(key_value)\n",
    "    if positions is None:\n",
    "        return df.iloc[0:0].copy()\n",
    "    # Sorted positions keep the rows in their original order.\n",
    "    return df.iloc[np.sort(positions)].copy()\n",
    "\n",
    "\n",
    "def _build_code_lookup(diag_defs):\n",
    "    \"\"\"Map (icd_version, icd_code) to the set of label ids whose definition lists that code.\"\"\"\n",
    "    lookup = {}\n",
    "    for definition in diag_defs.values():\n",
    "        for version, field in ((9, 'icd9'), (10, 'icd10')):\n",
    "            for code in definition[field]:\n",
    "                lookup.setdefault((version, code), set()).add(definition['id'])\n",
    "    return lookup\n",
    "\n",
    "\n",
    "def get_stay_list(stays):\n",
    "    \"\"\"Assemble one record per ICU stay: time series, optional modalities and labels.\n",
    "\n",
    "    Reads the notebook-level frames (ireg_vitals_ts_df, imputed_vitals, notes_df,\n",
    "    cxr_df, ecg_df, diagnoses_df), benchmark_diags and the include_* flags.\n",
    "\n",
    "    Args:\n",
    "        stays: iterable of stay_id values to process.\n",
    "\n",
    "    Returns:\n",
    "        A list of dicts, in the order of `stays`, each holding\n",
    "          name, hadm_id, stay_id : subject_id and hadm_id of the first time-series\n",
    "                                   row, and the stay id itself\n",
    "          ts_tt                  : timedelta of every irregular time-series row\n",
    "          irg_ts, irg_ts_mask    : irregular features with NaN replaced by 0, and\n",
    "                                   an int mask that is 1 where a value was present\n",
    "          reg_ts                 : imputed features\n",
    "          text_*, cxr_*, ecg_*   : payload, times and a *_missing flag (1 when the\n",
    "                                   stay has no rows) for each enabled modality\n",
    "          label                  : float vector with one slot per benchmark_diags\n",
    "                                   entry, 1 where a diagnosis of the admission\n",
    "                                   matches that entry's ICD-9 / ICD-10 codes\n",
    "\n",
    "    Stays without any irregular time-series row are skipped, even when\n",
    "    include_missing admitted them to the cohort through another modality.\n",
    "    \"\"\"\n",
    "    id_cols = ['subject_id', 'hadm_id', 'stay_id', 'timedelta']\n",
    "\n",
    "    # Index every frame once per call. Boolean-masking the full frames inside\n",
    "    # the loop costs a complete scan per stay and per frame.\n",
    "    ireg_positions = _positions_by_key(ireg_vitals_ts_df, 'stay_id')\n",
    "    imputed_positions = _positions_by_key(imputed_vitals, 'stay_id')\n",
    "    notes_positions = _positions_by_key(notes_df, 'stay_id') if include_notes else None\n",
    "    cxr_positions = _positions_by_key(cxr_df, 'stay_id') if include_cxr else None\n",
    "    ecg_positions = _positions_by_key(ecg_df, 'stay_id') if include_ecg else None\n",
    "    diag_positions = _positions_by_key(diagnoses_df, 'hadm_id')\n",
    "\n",
    "    # Hash lookup instead of testing every diagnosis row against every code list.\n",
    "    code_lookup = _build_code_lookup(benchmark_diags)\n",
    "    n_labels = len(benchmark_diags)\n",
    "\n",
    "    stays_list = []\n",
    "\n",
    "    for curr_stay in tqdm(stays):\n",
    "        curr_stay_ireg = _rows_at(ireg_vitals_ts_df, ireg_positions, curr_stay)\n",
    "\n",
    "        # The identifiers below come from the time-series rows, so a stay\n",
    "        # without any cannot be described and is dropped.\n",
    "        if len(curr_stay_ireg) == 0:\n",
    "            continue\n",
    "\n",
    "        curr_stay_imputed = _rows_at(imputed_vitals, imputed_positions, curr_stay)\n",
    "\n",
    "        curr_stay_dict = {}\n",
    "        curr_stay_dict['name'] = curr_stay_ireg['subject_id'].iloc[0]\n",
    "        curr_stay_dict['hadm_id'] = curr_stay_ireg['hadm_id'].iloc[0]\n",
    "        curr_stay_dict['stay_id'] = curr_stay\n",
    "        curr_stay_dict['ts_tt'] = curr_stay_ireg['timedelta'].values\n",
    "\n",
    "        # The mask is taken before filling, so it records which values were observed.\n",
    "        ireg_features = curr_stay_ireg.drop(columns=id_cols)\n",
    "        curr_stay_dict['irg_ts'] = ireg_features.fillna(0).values\n",
    "        curr_stay_dict['irg_ts_mask'] = ireg_features.notnull().values.astype(int)\n",
    "\n",
    "        curr_stay_dict['reg_ts'] = curr_stay_imputed.drop(columns=id_cols).values\n",
    "\n",
    "        if include_notes:\n",
    "            curr_stay_notes = _rows_at(notes_df, notes_positions, curr_stay)\n",
    "            if len(curr_stay_notes) == 0:\n",
    "                curr_stay_dict['text_data'] = []\n",
    "                curr_stay_dict['text_time_to_end'] = []\n",
    "                curr_stay_dict['text_missing'] = 1\n",
    "                curr_stay_dict['text_embeddings'] = []\n",
    "            else:\n",
    "                curr_stay_dict['text_data'] = curr_stay_notes['text'].tolist()\n",
    "                curr_stay_dict['text_time_to_end'] = curr_stay_notes['icu_time_delta'].values\n",
    "                # Keeps element [0][0] of each stored embedding.\n",
    "                curr_stay_dict['text_embeddings'] = [emb[0][0] for emb in curr_stay_notes['biobert_embeddings']]\n",
    "                curr_stay_dict['text_missing'] = 0\n",
    "\n",
    "        if include_cxr:\n",
    "            curr_stay_cxr = _rows_at(cxr_df, cxr_positions, curr_stay)\n",
    "            if len(curr_stay_cxr) == 0:\n",
    "                curr_stay_dict['cxr_feats'] = []\n",
    "                curr_stay_dict['cxr_time'] = []\n",
    "                curr_stay_dict['cxr_missing'] = 1\n",
    "            else:\n",
    "                curr_stay_dict['cxr_feats'] = curr_stay_cxr['densefeatures'].tolist()\n",
    "                curr_stay_dict['cxr_time'] = curr_stay_cxr['icu_time_delta'].values\n",
    "                curr_stay_dict['cxr_missing'] = 0\n",
    "\n",
    "        if include_ecg:\n",
    "            curr_stay_ecg = _rows_at(ecg_df, ecg_positions, curr_stay)\n",
    "            if len(curr_stay_ecg) == 0:\n",
    "                curr_stay_dict['ecg_feats'] = []\n",
    "                curr_stay_dict['ecg_time'] = []\n",
    "                curr_stay_dict['ecg_missing'] = 1\n",
    "            else:\n",
    "                curr_stay_dict['ecg_feats'] = curr_stay_ecg['embeddings'].tolist()\n",
    "                curr_stay_dict['ecg_time'] = curr_stay_ecg['icu_time_delta'].values\n",
    "                curr_stay_dict['ecg_missing'] = 0\n",
    "\n",
    "        # Multi-hot labels from every diagnosis recorded for the stay's admission.\n",
    "        curr_diagnoses = _rows_at(diagnoses_df, diag_positions, curr_stay_dict['hadm_id'])\n",
    "        curr_labels = np.zeros(n_labels)\n",
    "\n",
    "        for version, code in zip(curr_diagnoses['icd_version'], curr_diagnoses['icd_code']):\n",
    "            for label_id in code_lookup.get((version, code), ()):\n",
    "                curr_labels[label_id] = 1\n",
    "\n",
    "        curr_stay_dict['label'] = curr_labels\n",
    "\n",
    "        stays_list.append(curr_stay_dict)\n",
    "\n",
    "    return stays_list\n",
    "\n",
    "train_stays_list = get_stay_list(train_stays)\n",
    "val_stays_list = get_stay_list(val_stays)\n",
    "test_stays_list = get_stay_list(test_stays)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the data\n",
    "import pickle\n",
    "\n",
    "# The file-name tag encodes the configuration so that different runs do not\n",
    "# overwrite each other: pheno-all[-cxr][-notes][-ecg][-missingInd]\n",
    "base_name = \"pheno-all\"\n",
    "\n",
    "if include_cxr and include_notes:\n",
    "    base_name += \"-cxr-notes\"\n",
    "elif include_cxr:\n",
    "    base_name += \"-cxr\"\n",
    "elif include_notes:\n",
    "    # A notes-only run used to get no tag at all, so its files collided with\n",
    "    # those of a run without notes.\n",
    "    base_name += \"-notes\"\n",
    "\n",
    "if include_ecg:\n",
    "    base_name += \"-ecg\"\n",
    "\n",
    "if include_missing:\n",
    "    base_name += \"-missingInd\"\n",
    "\n",
    "# One pickle per split, written to output_dir.\n",
    "for split_name, split_stays_list in (\n",
    "    (\"train\", train_stays_list),\n",
    "    (\"val\", val_stays_list),\n",
    "    (\"test\", test_stays_list),\n",
    "):\n",
    "    f_path = os.path.join(output_dir, f\"{split_name}_{base_name}_stays.pkl\")\n",
    "    with open(f_path, 'wb') as f:\n",
    "        print(f\"Saving {split_name} stays to {f_path}\")\n",
    "        pickle.dump(split_stays_list, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
