{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# autoreload\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.preprocessing import StandardScaler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "restrict_48_hours = True\n",
    "include_text_embeddings = False\n",
    "include_cxr = True\n",
    "include_notes = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mimic_iv_path = \"/cis/home/charr165/Documents/physionet.org/mimiciv/2.2\"\n",
    "mm_dir = \"/cis/home/charr165/Documents/multimodal\"\n",
    "\n",
    "output_dir = os.path.join(mm_dir, \"preprocessing\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ireg_vitals_ts_df = pd.read_pickle(os.path.join(output_dir, \"ts_vitals_icu.pkl\"))\n",
    "# imputed_vitals = pd.read_pickle(os.path.join(output_dir, \"imputed_ts_vitals_icu.pkl\"))\n",
    "\n",
    "ireg_vitals_ts_df = pd.read_pickle(os.path.join(output_dir, \"ts_labs_vitals_icu.pkl\"))\n",
    "\n",
    "if restrict_48_hours:\n",
    "    ireg_vitals_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['timedelta'] <= 48]\n",
    "    ireg_vitals_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['timedelta'] >= 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ireg_vitals_ts_df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ireg_vitals_ts_df = pd.read_pickle(os.path.join(output_dir, \"ts_labs_vitals.pkl\"))\n",
    "\n",
    "ireg_vitals_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['icu_time_delta'] >= 0]\n",
    "ireg_vitals_ts_df.rename(columns={'icu_time_delta': 'timedelta'}, inplace=True)\n",
    "ireg_vitals_ts_df.drop(columns=['hosp_time_delta'], inplace=True)\n",
    "\n",
    "if restrict_48_hours:\n",
    "    ireg_vitals_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['timedelta'] <= 48]\n",
    "    ireg_vitals_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['timedelta'] >= 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# notes_df = pd.read_pickle(os.path.join(output_dir, \"notes_text.pkl\"))\n",
    "notes_df = pd.read_pickle(os.path.join(output_dir, \"icu_notes_text_embeddings.pkl\"))\n",
    "notes_df = notes_df[notes_df['stay_id'].notnull()]\n",
    "notes_df = notes_df[notes_df['icu_time_delta'] >= 0]\n",
    "\n",
    "if restrict_48_hours:\n",
    "    notes_df = notes_df[notes_df['icu_time_delta'] <= 48]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cxr_df = pd.read_pickle(os.path.join(output_dir, \"cxr_embeddings_icu.pkl\"))\n",
    "cxr_df = cxr_df[cxr_df['icu_time_delta'] >= 0]\n",
    "\n",
    "if restrict_48_hours:\n",
    "    cxr_df = cxr_df[cxr_df['icu_time_delta'] <= 48]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "icustays_df = pd.read_csv(os.path.join(mimic_iv_path, \"icu\", \"icustays.csv\"), low_memory=False)\n",
    "icustays_df['intime'] = pd.to_datetime(icustays_df['intime'])\n",
    "icustays_df['outtime'] = pd.to_datetime(icustays_df['outtime'])\n",
    "\n",
    "if restrict_48_hours:\n",
    "    icustays_df = icustays_df[icustays_df['los'] >= 2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "valid_stay_ids = icustays_df['stay_id'].unique()\n",
    "\n",
    "ireg_vitals_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['stay_id'].isin(valid_stay_ids)]\n",
    "notes_df = notes_df[notes_df['stay_id'].isin(valid_stay_ids)]\n",
    "\n",
    "if include_cxr:\n",
    "    cxr_df = cxr_df[cxr_df['stay_id'].isin(valid_stay_ids)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def calculate_avg_embedding(notes_df):\n",
    "    # For each stay, calculate the average embedding of all notes\n",
    "    avg_embedding = notes_df.groupby('stay_id')['biobert_embeddings'].apply(lambda x: np.mean(np.vstack(x), axis=0))\n",
    "    avg_embedding = np.stack(avg_embedding.values, axis=0)\n",
    "\n",
    "    # Reshape the array to 2 dimensions\n",
    "    avg_embedding = avg_embedding.reshape(avg_embedding.shape[0], -1)\n",
    "\n",
    "    embedding_cols = ['emb_{}'.format(i) for i in range(avg_embedding.shape[1])]\n",
    "    avg_embedding_df = pd.DataFrame(avg_embedding, columns=embedding_cols)\n",
    "\n",
    "    # notes_df = pd.concat([notes_df.drop('biobert_embeddings', axis=1), avg_embedding_df], axis=1)\n",
    "    stays = notes_df['stay_id'].unique()\n",
    "    notes_df = pd.DataFrame({'stay_id': stays})\n",
    "\n",
    "    notes_df = pd.merge(notes_df, avg_embedding_df, left_index=True, right_index=True)\n",
    "    return notes_df, embedding_cols\n",
    "\n",
    "notes_df, embedding_cols = calculate_avg_embedding(notes_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "admissions_df = pd.read_csv(os.path.join(mimic_iv_path, \"hosp\", \"admissions.csv\"))\n",
    "admissions_df = admissions_df.rename(columns={\"hospital_expire_flag\": \"died\"})\n",
    "admissions_df = admissions_df[[\"subject_id\", \"hadm_id\", \"died\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "unique_stays = ireg_vitals_ts_df['stay_id'].unique()\n",
    "print(f\"Number of stays with vitals: {len(unique_stays)}\")\n",
    "\n",
    "include_notes = True\n",
    "if include_notes:\n",
    "    unique_stays = np.intersect1d(unique_stays, notes_df['stay_id'].unique())\n",
    "    print(f\"Number of stays with notes: {len(unique_stays)}\")\n",
    "\n",
    "if include_cxr:\n",
    "    unique_stays = np.intersect1d(unique_stays, cxr_df['stay_id'].unique())\n",
    "    print(f\"Number of stays with cxr: {len(unique_stays)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ireg_vitals_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['stay_id'].isin(unique_stays)].copy()\n",
    "notes_df = notes_df[notes_df['stay_id'].isin(unique_stays)].copy()\n",
    "\n",
    "event_list = ireg_vitals_ts_df.columns\n",
    "event_list = event_list.drop(['subject_id', 'hadm_id', 'stay_id', 'timedelta'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "pkl_list = ireg_vitals_ts_df['stay_id'].unique().tolist()\n",
    "\n",
    "seed = 0\n",
    "train_id, test_id = train_test_split(pkl_list, test_size=0.3, random_state=seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_ireg_ts_df = ireg_vitals_ts_df[ireg_vitals_ts_df['stay_id'].isin(train_id)].copy()\n",
    "\n",
    "cols = train_ireg_ts_df.columns.tolist()\n",
    "cols = [col for col in cols if col not in ['subject_id', 'hadm_id', 'stay_id', 'timedelta']]\n",
    "\n",
    "for col in cols:\n",
    "    scaler = StandardScaler()\n",
    "    scaler.fit(train_ireg_ts_df[[col]])\n",
    "    ireg_vitals_ts_df[col] = scaler.transform(ireg_vitals_ts_df[[col]])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.signal import find_peaks\n",
    "from tqdm import tqdm\n",
    "\n",
    "unique_stays = ireg_vitals_ts_df['stay_id'].unique()\n",
    "\n",
    "curr_stay = unique_stays[0]\n",
    "\n",
    "haim_ts_embeddings = pd.DataFrame(columns=['subject_id', 'hadm_id', 'stay_id'])\n",
    "\n",
    "died = np.zeros(len(unique_stays))\n",
    "for i in tqdm(range(len(unique_stays)), desc=\"Embedding TS\", total=len(unique_stays)):\n",
    "    stay = unique_stays[i]\n",
    "    curr_ts = ireg_vitals_ts_df[ireg_vitals_ts_df['stay_id'] == stay]\n",
    "\n",
    "    curr_subject = curr_ts['subject_id'].iloc[0]\n",
    "    curr_hadm = curr_ts['hadm_id'].iloc[0]\n",
    "\n",
    "    row_data = {'subject_id': curr_subject, 'hadm_id': curr_hadm, 'stay_id': stay}\n",
    "\n",
    "    for event in event_list:\n",
    "        series = curr_ts[event].dropna() #dropna rows\n",
    "        if len(series) >0: #if there is any event\n",
    "            row_data[event+'_max'] = series.max()\n",
    "            row_data[event+'_min'] = series.min()\n",
    "            row_data[event+'_mean'] = series.mean(skipna=True)\n",
    "            row_data[event+'_variance'] = series.var(skipna=True)\n",
    "            row_data[event+'_meandiff'] = series.diff().mean() #average change\n",
    "            row_data[event+'_meanabsdiff'] =series.diff().abs().mean()\n",
    "            row_data[event+'_maxdiff'] = series.diff().abs().max()\n",
    "            row_data[event+'_sumabsdiff'] =series.diff().abs().sum()\n",
    "            row_data[event+'_diff'] = series.iloc[-1]-series.iloc[0]\n",
    "            \n",
    "            #Compute the n_peaks\n",
    "            peaks,_ = find_peaks(series)\n",
    "            row_data[event+'_npeaks'] = len(peaks)\n",
    "            \n",
    "            #Compute the trend (linear slope)\n",
    "            if len(series)>1:\n",
    "                row_data[event+'_trend']= np.polyfit(np.arange(len(series)), series, 1)[0] #fit deg-1 poly\n",
    "            else:\n",
    "                row_data[event+'_trend'] = 0\n",
    "\n",
    "    haim_ts_embeddings = pd.concat([haim_ts_embeddings, pd.DataFrame(row_data, index=[0])], ignore_index=True)\n",
    "\n",
    "    died[i] = admissions_df[admissions_df['hadm_id'] == curr_hadm]['died'].iloc[0]\n",
    "\n",
    "haim_ts_embeddings.fillna(0, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if include_text_embeddings:\n",
    "    haim_ts_embeddings = haim_ts_embeddings.merge(notes_df[['stay_id'] + embedding_cols], on='stay_id', how='left')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "pkl_list = haim_ts_embeddings['stay_id'].unique().tolist()\n",
    "\n",
    "df = haim_ts_embeddings.copy()\n",
    "df['died'] = died\n",
    "\n",
    "seed = 0\n",
    "df = df[~df.isna().any(axis=1)]\n",
    "\n",
    "# non_numeric_cols = df.select_dtypes(exclude=[np.number]).columns.tolist()\n",
    "# df.drop(non_numeric_cols, axis=1, inplace=True)\n",
    "\n",
    "train_idx = df[df['stay_id'].isin(train_id)]['stay_id'].tolist()\n",
    "test_idx = df[df['stay_id'].isin(test_id)]['stay_id'].tolist()\n",
    "\n",
    "y_train = df[df['stay_id'].isin(train_idx)]['died']\n",
    "y_test = df[df['stay_id'].isin(test_idx)]['died']\n",
    "\n",
    "x_train = df[df['stay_id'].isin(train_idx)]\n",
    "x_test = df[df['stay_id'].isin(test_idx)]\n",
    "\n",
    "x_train.drop(columns=['subject_id', 'hadm_id', 'stay_id', 'died'], inplace=True)\n",
    "x_test.drop(columns=['subject_id', 'hadm_id', 'stay_id', 'died'], inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from prediction_util import run_xgb\n",
    "\n",
    "y_pred_test, y_pred_prob_test, y_pred_train, y_pred_prob_train, xgb = run_xgb(x_train, y_train, x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn import metrics\n",
    "\n",
    "f1_train = metrics.f1_score(y_train, y_pred_train)\n",
    "accu_train = metrics.accuracy_score(y_train, y_pred_train)\n",
    "accu_bl_train = metrics.balanced_accuracy_score(y_train, y_pred_train)\n",
    "auc_train =  metrics.roc_auc_score(y_train, y_pred_prob_train)\n",
    "conf_matrix_train = metrics.confusion_matrix(y_train, y_pred_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f'F1 Score for Training Set is: {f1_train}')\n",
    "print(f'Accuracy for Training Set is: {accu_train}')\n",
    "print(f'Balanced Accuracy for Training Set is: {accu_bl_train}')\n",
    "print(f'AUC for Training Set is: {auc_train}')\n",
    "print(f'Confusion Matrix for Training Set is: {conf_matrix_train}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "f1_test = metrics.f1_score(y_test, y_pred_test)\n",
    "accu_test = metrics.accuracy_score(y_test, y_pred_test)\n",
    "accu_bl_test = metrics.balanced_accuracy_score(y_test, y_pred_test)\n",
    "auc_test =  metrics.roc_auc_score(y_test, y_pred_prob_test)\n",
    "auprc_test = metrics.average_precision_score(y_test, y_pred_prob_test)\n",
    "conf_matrix_test = metrics.confusion_matrix(y_test, y_pred_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f'F1 Score for Testing Set is: {f1_test}')\n",
    "print(f'Accuracy for Testing Set is: {accu_test}')\n",
    "print(f'Balanced Accuracy for Testing Set is: {accu_bl_test}')\n",
    "print(f'AUC for Testing Set is: {auc_test}')\n",
    "print(f'AUPRC for Testing Set is: {auprc_test}')\n",
    "print(f'Confusion Matrix for Testing Set is: {conf_matrix_test}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find the most important features\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "feature_importance = pd.DataFrame({'feature': x_train.columns, 'importance': xgb.feature_importances_})\n",
    "feature_importance = feature_importance.sort_values(by='importance', ascending=False)\n",
    "feature_importance = feature_importance.reset_index(drop=True)\n",
    "\n",
    "plt.figure(figsize=(10, 10))\n",
    "sns.barplot(x=\"importance\", y=\"feature\", data=feature_importance)\n",
    "plt.title('Feature Importance')\n",
    "plt.tight_layout()\n",
    "plt.savefig(os.path.join(output_dir, \"feature_importance.png\"), dpi=300)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_importance"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
