{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload edited project modules automatically before each cell executes.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root of the local MIMIC-IV v2.2 download and the project working directory.\n",
    "# NOTE(review): both are machine-specific absolute paths; adjust them before running elsewhere.\n",
    "mimic_iv_path = \"/cis/home/charr165/Documents/physionet.org/mimiciv/2.2\"\n",
    "mm_dir = \"/cis/home/charr165/Documents/multimodal\"\n",
    "\n",
    "# The pickles produced at the end of the notebook are written under <mm_dir>/preprocessing.\n",
    "output_dir = os.path.join(mm_dir, \"preprocessing\")\n",
    "os.makedirs(output_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hospital admissions, with the admit/discharge timestamps parsed to datetimes.\n",
    "f_path = os.path.join(mimic_iv_path, \"hosp\", \"admissions.csv\")\n",
    "admissions_df = pd.read_csv(f_path, low_memory=False)\n",
    "for time_col in (\"admittime\", \"dischtime\"):\n",
    "    admissions_df[time_col] = pd.to_datetime(admissions_df[time_col])\n",
    "\n",
    "# The remaining tables of this cell are all read from the icu/ folder.\n",
    "icu_dir = os.path.join(mimic_iv_path, \"icu\")\n",
    "\n",
    "icustays_df = pd.read_csv(os.path.join(icu_dir, \"icustays.csv\"), low_memory=False)\n",
    "for time_col in (\"intime\", \"outtime\"):\n",
    "    icustays_df[time_col] = pd.to_datetime(icustays_df[time_col])\n",
    "\n",
    "procedureevents_df = pd.read_csv(os.path.join(icu_dir, \"procedureevents.csv\"), low_memory=False)\n",
    "for time_col in (\"starttime\", \"endtime\"):\n",
    "    procedureevents_df[time_col] = pd.to_datetime(procedureevents_df[time_col])\n",
    "# format='mixed' makes pandas infer the timestamp format value by value for this column.\n",
    "procedureevents_df[\"storetime\"] = pd.to_datetime(procedureevents_df[\"storetime\"], format=\"mixed\")\n",
    "\n",
    "chartevents_df = pd.read_csv(os.path.join(icu_dir, \"chartevents.csv\"), low_memory=False)\n",
    "for time_col in (\"charttime\", \"storetime\"):\n",
    "    chartevents_df[time_col] = pd.to_datetime(chartevents_df[time_col])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hospital-wide lab results (hosp/labevents.csv) with both timestamps parsed to datetimes.\n",
    "labevents_csv = os.path.join(mimic_iv_path, \"hosp\", \"labevents.csv\")\n",
    "hosp_lab_events = pd.read_csv(labevents_csv, low_memory=False)\n",
    "for time_col in (\"charttime\", \"storetime\"):\n",
    "    hosp_lab_events[time_col] = pd.to_datetime(hosp_lab_events[time_col])\n",
    "\n",
    "# Keep only the lab results that are tied to a hospital admission (non-null hadm_id).\n",
    "hosp_lab_events = hosp_lab_events.dropna(subset=[\"hadm_id\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      itemid                 label                fluid    category\n",
      "7      50809               Glucose                Blood   Blood Gas\n",
      "40     50842      Glucose, Ascites              Ascites   Chemistry\n",
      "129    50931               Glucose                Blood   Chemistry\n",
      "210    51022  Glucose, Joint Fluid          Joint Fluid   Chemistry\n",
      "222    51034   Glucose, Body Fluid     Other Body Fluid   Chemistry\n",
      "241    51053      Glucose, Pleural              Pleural   Chemistry\n",
      "272    51084        Glucose, Urine                Urine   Chemistry\n",
      "638    51478               Glucose                Urine  Hematology\n",
      "908    51790          Glucose, CSF  Cerebrospinal Fluid   Chemistry\n",
      "1034   51941        Glucose, Stool                Stool   Chemistry\n",
      "1074   51981               Glucose                Urine   Chemistry\n",
      "1120   52027  Glucose, Whole Blood                Blood   Blood Gas\n",
      "1528   52569               Glucose                Blood   Chemistry\n"
     ]
    }
   ],
   "source": [
    "d_lab_items_df = pd.read_csv(os.path.join(mimic_iv_path, \"hosp\", \"d_labitems.csv\"), low_memory=False)\n",
    "\n",
    "# Drop dictionary rows that have a missing value in any column\n",
    "# (a NaN label would otherwise put NaN into the string mask below).\n",
    "d_lab_items_df = d_lab_items_df.dropna()\n",
    "\n",
    "# List every lab item whose label mentions glucose (case-insensitive) to see which\n",
    "# itemids share the label. The name used to be `ph_labels`, left over from a pH search.\n",
    "glucose_labels = d_lab_items_df[d_lab_items_df['label'].str.contains('Glucose', case=False)]\n",
    "print(glucose_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ICU item dictionary; the functions below use its `label` and `itemid` columns.\n",
    "d_items_df = pd.read_csv(os.path.join(mimic_iv_path, \"icu\", \"d_items.csv\"), low_memory=False)\n",
    "# Disabled exploration snippet: restrict to the \"Labs\" category and list labels containing 'pressure'.\n",
    "# d_items_df = d_items_df[d_items_df['category'] == \"Labs\"]\n",
    "# ph_labels = d_items_df[d_items_df['label'].str.contains('pressure', case=False)]\n",
    "# print(ph_labels)\n",
    "\n",
    "def get_procedures_of_interest(df):\n",
    "    \"\"\"Keep the procedure events of interest and tag each row with its event label.\n",
    "\n",
    "    Every label in the fixed list below is resolved to an itemid through the\n",
    "    notebook-level ``d_items_df`` dictionary (first exact label match).\n",
    "\n",
    "    Args:\n",
    "        df: Events table with an ``itemid`` column.\n",
    "\n",
    "    Returns:\n",
    "        A new DataFrame with only the rows whose ``itemid`` belongs to a listed\n",
    "        procedure; ``itemid`` is replaced by an ``event`` column holding the label.\n",
    "        The input frame is not modified.\n",
    "\n",
    "    Raises:\n",
    "        IndexError: If a listed label is absent from ``d_items_df``.\n",
    "    \"\"\"\n",
    "    event_list = ['Foley Catheter', 'PICC Line', 'Intubation', 'Peritoneal Dialysis',\n",
    "                  'Bronchoscopy', 'EEG', 'Dialysis - CRRT', 'Dialysis Catheter',\n",
    "                  'Chest Tube Removed', 'Hemodialysis']\n",
    "\n",
    "    # Collect the label -> itemid pairs first and build the link table once,\n",
    "    # instead of growing a DataFrame with pd.concat inside the loop.\n",
    "    records = []\n",
    "    for event in event_list:\n",
    "        matches = d_items_df.loc[d_items_df[\"label\"] == event, \"itemid\"]\n",
    "        if matches.empty:\n",
    "            raise IndexError(f\"label {event!r} not found in d_items_df\")\n",
    "        records.append({\"event\": event, \"itemid\": matches.iloc[0]})\n",
    "    event_links_df = pd.DataFrame(records)\n",
    "\n",
    "    # Boolean filtering already returns a new frame, so no defensive copy of the\n",
    "    # (potentially very large) input table is needed.\n",
    "    selected = df[df[\"itemid\"].isin(event_links_df[\"itemid\"])]\n",
    "    # validate guards against a duplicated itemid silently multiplying event rows.\n",
    "    selected = selected.merge(event_links_df, on=\"itemid\", how=\"left\", validate=\"many_to_one\")\n",
    "    return selected.drop(columns=[\"itemid\"])\n",
    "\n",
    "def get_labs_of_interest(df):\n",
    "    \"\"\"Keep the lab events of interest and tag each row with its event label.\n",
    "\n",
    "    Every label in the fixed list below is resolved to an itemid through the\n",
    "    notebook-level ``d_lab_items_df`` dictionary (first exact label match).\n",
    "\n",
    "    Args:\n",
    "        df: Lab events table with an ``itemid`` column whose ids come from the\n",
    "            same dictionary as ``d_lab_items_df``.\n",
    "\n",
    "    Returns:\n",
    "        A new DataFrame with only the rows whose ``itemid`` belongs to a listed\n",
    "        lab; ``itemid`` is replaced by an ``event`` column holding the label.\n",
    "        The input frame is not modified.\n",
    "\n",
    "    Raises:\n",
    "        IndexError: If a listed label is absent from ``d_lab_items_df``.\n",
    "    \"\"\"\n",
    "    # 'Platelet Count' used to be listed twice, which put its itemid into the link\n",
    "    # table twice and made the left merge below emit every Platelet Count row twice.\n",
    "    event_list = ['Glucose', 'Potassium', 'Sodium', 'Chloride', 'Creatinine',\n",
    "                  'Urea Nitrogen', 'Bicarbonate', 'Anion Gap', 'Hemoglobin', 'Hematocrit',\n",
    "                  'Magnesium', 'Platelet Count', 'Phosphate', 'White Blood Cells',\n",
    "                  'Calcium, Total', 'MCH', 'Red Blood Cells', 'MCHC', 'MCV', 'RDW',\n",
    "                  'Neutrophils', 'Vancomycin']\n",
    "\n",
    "    # NOTE(review): several dictionary items can share one label (the 'Glucose' search\n",
    "    # earlier in the notebook lists blood-gas, chemistry and urine items); only the first\n",
    "    # match is used here, as before. Confirm that this is the intended item.\n",
    "    records = []\n",
    "    for event in event_list:\n",
    "        matches = d_lab_items_df.loc[d_lab_items_df[\"label\"] == event, \"itemid\"]\n",
    "        if matches.empty:\n",
    "            raise IndexError(f\"label {event!r} not found in d_lab_items_df\")\n",
    "        records.append({\"event\": event, \"itemid\": matches.iloc[0]})\n",
    "    # Build the link table once instead of growing it with pd.concat inside the loop.\n",
    "    event_links_df = pd.DataFrame(records)\n",
    "\n",
    "    # Boolean filtering already returns a new frame, so no defensive copy of the\n",
    "    # (potentially very large) input table is needed.\n",
    "    selected = df[df[\"itemid\"].isin(event_links_df[\"itemid\"])]\n",
    "    # validate guards against a duplicated itemid silently multiplying event rows.\n",
    "    selected = selected.merge(event_links_df, on=\"itemid\", how=\"left\", validate=\"many_to_one\")\n",
    "    return selected.drop(columns=[\"itemid\"])\n",
    "\n",
    "def get_vitals_of_interest(df):\n",
    "    \"\"\"Keep the vital-sign chart events of interest and tag each row with a short event name.\n",
    "\n",
    "    Every label in the fixed list below is resolved to an itemid through the\n",
    "    notebook-level ``d_items_df`` dictionary (first exact label match). The blood\n",
    "    pressure and oxygen saturation labels are then shortened (see ``rename_dict``).\n",
    "\n",
    "    Args:\n",
    "        df: Chart events table with an ``itemid`` column.\n",
    "\n",
    "    Returns:\n",
    "        A new DataFrame with only the rows whose ``itemid`` belongs to a listed\n",
    "        vital sign; ``itemid`` is replaced by an ``event`` column.\n",
    "        The input frame is not modified.\n",
    "\n",
    "    Raises:\n",
    "        IndexError: If a listed label is absent from ``d_items_df``.\n",
    "    \"\"\"\n",
    "    event_list = ['Heart Rate', 'Non Invasive Blood Pressure systolic',\n",
    "                  'Non Invasive Blood Pressure diastolic', 'Non Invasive Blood Pressure mean',\n",
    "                  'Respiratory Rate', 'O2 saturation pulseoxymetry',\n",
    "                  'GCS - Verbal Response', 'GCS - Eye Opening', 'GCS - Motor Response']\n",
    "\n",
    "    # Collect the label -> itemid pairs first and build the link table once,\n",
    "    # instead of growing a DataFrame with pd.concat inside the loop.\n",
    "    records = []\n",
    "    for event in event_list:\n",
    "        matches = d_items_df.loc[d_items_df[\"label\"] == event, \"itemid\"]\n",
    "        if matches.empty:\n",
    "            raise IndexError(f\"label {event!r} not found in d_items_df\")\n",
    "        records.append({\"event\": event, \"itemid\": matches.iloc[0]})\n",
    "    event_links_df = pd.DataFrame(records)\n",
    "\n",
    "    # Boolean filtering already returns a new frame, so no defensive copy of the\n",
    "    # (potentially very large) chart events table is needed.\n",
    "    selected = df[df[\"itemid\"].isin(event_links_df[\"itemid\"])]\n",
    "    # validate guards against a duplicated itemid silently multiplying event rows.\n",
    "    selected = selected.merge(event_links_df, on=\"itemid\", how=\"left\", validate=\"many_to_one\")\n",
    "    selected = selected.drop(columns=[\"itemid\"])\n",
    "\n",
    "    rename_dict = {\n",
    "        'Non Invasive Blood Pressure systolic': 'Systolic BP',\n",
    "        'Non Invasive Blood Pressure diastolic': 'Diastolic BP',\n",
    "        'Non Invasive Blood Pressure mean': 'Mean BP',\n",
    "        'O2 saturation pulseoxymetry': 'O2 Saturation'\n",
    "    }\n",
    "    selected['event'] = selected['event'].replace(rename_dict)\n",
    "\n",
    "    return selected\n",
    "\n",
    "\n",
    "# procedureevents_df = get_procedures_of_interest(procedureevents_df)\n",
    "# get_labs_of_interest resolves its itemids through d_lab_items_df (hosp/d_labitems.csv), so it has\n",
    "# to be fed the hospital lab events; chartevents itemids are described by icu/d_items.csv instead.\n",
    "labevents_df = get_labs_of_interest(hosp_lab_events)\n",
    "vitals_df = get_vitals_of_interest(chartevents_df)\n",
    "# NOTE(review): hosp/labevents.csv is assumed to have no stay_id column (confirm against the\n",
    "# MIMIC-IV schema); add_time_delta adds one from the ICU stay that contains each charttime.\n",
    "labevents_df = labevents_df[['subject_id', 'hadm_id', 'charttime', 'event', 'valuenum']]\n",
    "vitals_df = vitals_df[['subject_id', 'hadm_id', 'stay_id', 'charttime', 'event', 'valuenum']]\n",
    "# procedureevents_df = procedureevents_df[['subject_id', 'hadm_id', 'stay_id', 'starttime', 'endtime', 'storetime', 'value', 'event']]\n",
    "\n",
    "# labs_df = get_labs_of_interest(hosp_lab_events)\n",
    "# vitals_df = get_vitals_of_interest(chartevents_df)\n",
    "# labs_vitals_df = labs_vitals_df[['subject_id', 'hadm_id', 'stay_id', 'charttime', 'event', 'valuenum']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Free the large raw event tables; the cells below work on the filtered labevents_df / vitals_df.\n",
    "del chartevents_df, hosp_lab_events"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# d_items_df = pd.read_csv(os.path.join(mimic_iv_path, \"icu\", \"d_items.csv\"), low_memory=False)\n",
    "# # d_items_df = d_items_df[d_items_df['category'] == \"Labs\"]\n",
    "\n",
    "# def get_procedures_of_interest(df):\n",
    "#     df = df.copy()\n",
    "\n",
    "#     event_list = ['Foley Catheter', 'PICC Line', 'Intubation', 'Peritoneal Dialysis', \n",
    "#                             'Bronchoscopy', 'EEG', 'Dialysis - CRRT', 'Dialysis Catheter', \n",
    "#                             'Chest Tube Removed', 'Hemodialysis']\n",
    "#     event_links_df = pd.DataFrame()\n",
    "#     for event in event_list:\n",
    "#         curr_event_item_id = d_items_df[d_items_df[\"label\"] == event][\"itemid\"].values[0]\n",
    "\n",
    "#         tmp_dict = {\"event\": event, \"itemid\": curr_event_item_id}\n",
    "#         event_links_df = pd.concat([event_links_df, pd.DataFrame(tmp_dict, index=[0])], axis=0, ignore_index=True)\n",
    "\n",
    "#     df = df[df[\"itemid\"].isin(event_links_df['itemid'])]\n",
    "#     df = df.merge(event_links_df, on=\"itemid\", how=\"left\")\n",
    "#     df.drop(columns=[\"itemid\"], inplace=True)\n",
    "#     return df\n",
    "\n",
    "# def get_labs_of_interest(df):\n",
    "#     df = df.copy()\n",
    "\n",
    "#     event_list = [  #LAB EVENTS\n",
    "#                   'Glucose (serum)', 'Glucose (whole blood)',\n",
    "#                   'Potassium (serum)', 'Potassium (whole blood)', \n",
    "#                   'Sodium (serum)', 'Sodium (whole blood)',\n",
    "#                   'Chloride (serum)', 'Chloride (whole blood)',\n",
    "#                   'Creatinine (serum)', 'Creatinine (whole blood)',\n",
    "#                   'BUN', #   'Urea Nitrogen', \n",
    "#                   'HCO3 (serum)', #   'Bicarbonate', \n",
    "#                   'Anion gap', \n",
    "#                   'Hemoglobin', \n",
    "#                   'Hematocrit (serum)', 'Hematocrit (whole blood - calc)',\n",
    "#                   'Magnesium', \n",
    "#                   'Platelet Count', \n",
    "#                   'Alkaline Phosphate', \n",
    "#                   'WBC', #'White Blood Cells',\n",
    "#                   'Calcium non-ionized', 'Ionized Calcium', #'Calcium, Total', \n",
    "#                 #   'MCH', \n",
    "#                 #   'Red Blood Cells', \n",
    "#                 #   'MCHC', \n",
    "#                 #   'MCV', \n",
    "#                 #   'RDW', \n",
    "#                   'Absolute Neutrophil Count', #  'Neutrophils', \n",
    "#                   'Vancomycin (Peak)', 'Vancomycin (Random)', 'Vancomycin (Trough)',\n",
    "#                   # NEW\n",
    "#                   'PH (Arterial)', 'PH (dipstick)', 'PH (SOFT)', 'PH (Venous)',\n",
    "#                   'Capillary Refill R', 'Capillary Refill L',\n",
    "#                   'Temperature Celsius',\n",
    "#                   'Daily Weight', 'Admission Weight (Kg)',\n",
    "#                   'Inspired O2 Fraction'\n",
    "#                   ]\n",
    "\n",
    "#     event_links_df = pd.DataFrame()\n",
    "#     for event in event_list:\n",
    "#         # print(event)\n",
    "#         curr_event_item_id = d_items_df[d_items_df[\"label\"] == event][\"itemid\"].values[0]\n",
    "\n",
    "#         tmp_dict = {\"event\": event, \"itemid\": curr_event_item_id}\n",
    "#         event_links_df = pd.concat([event_links_df, pd.DataFrame(tmp_dict, index=[0])], axis=0, ignore_index=True)\n",
    "\n",
    "#     df = df[df[\"itemid\"].isin(event_links_df['itemid'])]\n",
    "#     df = df.merge(event_links_df, on=\"itemid\", how=\"left\")\n",
    "#     df.drop(columns=[\"itemid\"], inplace=True)\n",
    "\n",
    "#     rename_dict = {\n",
    "#         'Glucose (serum)': 'Glucose',\n",
    "#         'Glucose (whole blood)': 'Glucose',\n",
    "#         'Potassium (serum)': 'Potassium',\n",
    "#         'Potassium (whole blood)': 'Potassium',\n",
    "#         'Sodium (serum)': 'Sodium',\n",
    "#         'Sodium (whole blood)': 'Sodium',\n",
    "#         'Chloride (serum)': 'Chloride',\n",
    "#         'Chloride (whole blood)': 'Chloride',\n",
    "#         'Creatinine (serum)': 'Creatinine',\n",
    "#         'Creatinine (whole blood)': 'Creatinine',\n",
    "#         'BUN': 'Urea Nitrogen',\n",
    "#         'HCO3 (serum)': 'Bicarbonate',\n",
    "#         'Hematocrit (serum)': 'Hematocrit',\n",
    "#         'Hematocrit (whole blood - calc)': 'Hematocrit',\n",
    "#         'Calcium non-ionized': 'Calcium',\n",
    "#         'Ionized Calcium': 'Calcium',\n",
    "#         'Vancomycin (Peak)': 'Vancomycin',\n",
    "#         'Vancomycin (Random)': 'Vancomycin',\n",
    "#         'Vancomycin (Trough)': 'Vancomycin',\n",
    "#         'PH (Arterial)': 'PH',\n",
    "#         'PH (dipstick)': 'PH',\n",
    "#         'PH (SOFT)': 'PH',\n",
    "#         'PH (Venous)': 'PH',\n",
    "#         'Capillary Refill R': 'Capillary Refill',\n",
    "#         'Capillary Refill L': 'Capillary Refill',\n",
    "#         'Temperature Celsius': 'Temperature',\n",
    "#         'Daily Weight': 'Weight',\n",
    "#         'Admission Weight (Kg)': 'Weight',\n",
    "#         'Inspired O2 Fraction': 'Inspired O2 Fraction'\n",
    "#     }\n",
    "\n",
    "#     df['event'] = df['event'].replace(rename_dict)\n",
    "\n",
    "#     return df\n",
    "\n",
    "# def get_vitals_of_interest(df):\n",
    "#     df = df.copy()\n",
    "\n",
    "#     event_list = [ #CHART EVENTS\n",
    "#                   'Heart Rate','Non Invasive Blood Pressure systolic',\n",
    "#                     'Non Invasive Blood Pressure diastolic', 'Non Invasive Blood Pressure mean', \n",
    "#                     'Respiratory Rate','O2 saturation pulseoxymetry', \n",
    "#                     'GCS - Verbal Response', 'GCS - Eye Opening', 'GCS - Motor Response']\n",
    "\n",
    "#     event_links_df = pd.DataFrame()\n",
    "#     for event in event_list:\n",
    "#         # print(event)\n",
    "#         curr_event_item_id = d_items_df[d_items_df[\"label\"] == event][\"itemid\"].values[0]\n",
    "\n",
    "#         tmp_dict = {\"event\": event, \"itemid\": curr_event_item_id}\n",
    "#         event_links_df = pd.concat([event_links_df, pd.DataFrame(tmp_dict, index=[0])], axis=0, ignore_index=True)\n",
    "\n",
    "#     df = df[df[\"itemid\"].isin(event_links_df['itemid'])]\n",
    "#     df = df.merge(event_links_df, on=\"itemid\", how=\"left\")\n",
    "#     df.drop(columns=[\"itemid\"], inplace=True)\n",
    "\n",
    "#     rename_dict = {\n",
    "#         'Non Invasive Blood Pressure systolic': 'Systolic BP',\n",
    "#         'Non Invasive Blood Pressure diastolic': 'Diastolic BP',\n",
    "#         'Non Invasive Blood Pressure mean': 'Mean BP',\n",
    "#         'O2 saturation pulseoxymetry': 'O2 Saturation'\n",
    "#     }\n",
    "\n",
    "#     df['event'] = df['event'].replace(rename_dict)\n",
    "    \n",
    "#     return df\n",
    "\n",
    "# def get_labs_vitals(df):\n",
    "#     df = df.copy()\n",
    "\n",
    "#     event_list = [ #LAB EVENTS\n",
    "#                   'Glucose (serum)', 'Glucose (whole blood)',\n",
    "#                   'Potassium (serum)', 'Potassium (whole blood)', \n",
    "#                   'Sodium (serum)', 'Sodium (whole blood)',\n",
    "#                   'Chloride (serum)', 'Chloride (whole blood)',\n",
    "#                   'Creatinine (serum)', 'Creatinine (whole blood)',\n",
    "#                   'BUN', #   'Urea Nitrogen', \n",
    "#                   'HCO3 (serum)', #   'Bicarbonate', \n",
    "#                   'Anion gap', \n",
    "#                   'Hemoglobin', \n",
    "#                   'Hematocrit (serum)', 'Hematocrit (whole blood - calc)',\n",
    "#                   'Magnesium', \n",
    "#                   'Platelet Count', \n",
    "#                   'Alkaline Phosphate', \n",
    "#                   'WBC', #'White Blood Cells',\n",
    "#                   'Calcium non-ionized', 'Ionized Calcium', #'Calcium, Total', \n",
    "#                 #   'MCH', \n",
    "#                 #   'Red Blood Cells', \n",
    "#                 #   'MCHC', \n",
    "#                 #   'MCV', \n",
    "#                 #   'RDW', \n",
    "#                   'Absolute Neutrophil Count', #  'Neutrophils', \n",
    "#                   'Vancomycin (Peak)', 'Vancomycin (Random)', 'Vancomycin (Trough)',\n",
    "                  \n",
    "#                   # NEW\n",
    "#                   'PH (Arterial)', 'PH (dipstick)', 'PH (SOFT)', 'PH (Venous)',\n",
    "#                   'Capillary Refill R', 'Capillary Refill L',\n",
    "#                   'Temperature Celsius',\n",
    "#                   'Daily Weight', 'Admission Weight (Kg)',\n",
    "#                   'Inspired O2 Fraction',\n",
    "\n",
    "#                   #CHART EVENTS\n",
    "#                   'Heart Rate','Non Invasive Blood Pressure systolic',\n",
    "#                     'Non Invasive Blood Pressure diastolic', 'Non Invasive Blood Pressure mean', \n",
    "#                     'Respiratory Rate','O2 saturation pulseoxymetry', \n",
    "#                     'GCS - Verbal Response', 'GCS - Eye Opening', 'GCS - Motor Response'\n",
    "#                     ]\n",
    "\n",
    "#     event_links_df = pd.DataFrame()\n",
    "#     for event in event_list:\n",
    "#         curr_event_item_id = d_items_df[d_items_df[\"label\"] == event][\"itemid\"].values[0]\n",
    "\n",
    "#         tmp_dict = {\"event\": event, \"itemid\": curr_event_item_id}\n",
    "#         event_links_df = pd.concat([event_links_df, pd.DataFrame(tmp_dict, index=[0])], axis=0, ignore_index=True)\n",
    "\n",
    "#     df = df[df[\"itemid\"].isin(event_links_df['itemid'])]\n",
    "#     df = df.merge(event_links_df, on=\"itemid\", how=\"left\")\n",
    "#     df.drop(columns=[\"itemid\"], inplace=True)\n",
    "\n",
    "#     rename_dict = {\n",
    "#         'Glucose (serum)': 'Glucose',\n",
    "#         'Glucose (whole blood)': 'Glucose',\n",
    "#         'Potassium (serum)': 'Potassium',\n",
    "#         'Potassium (whole blood)': 'Potassium',\n",
    "#         'Sodium (serum)': 'Sodium',\n",
    "#         'Sodium (whole blood)': 'Sodium',\n",
    "#         'Chloride (serum)': 'Chloride',\n",
    "#         'Chloride (whole blood)': 'Chloride',\n",
    "#         'Creatinine (serum)': 'Creatinine',\n",
    "#         'Creatinine (whole blood)': 'Creatinine',\n",
    "#         'BUN': 'Urea Nitrogen',\n",
    "#         'HCO3 (serum)': 'Bicarbonate',\n",
    "#         'Hematocrit (serum)': 'Hematocrit',\n",
    "#         'Hematocrit (whole blood - calc)': 'Hematocrit',\n",
    "#         'Calcium non-ionized': 'Calcium',\n",
    "#         'Ionized Calcium': 'Calcium',\n",
    "#         'Vancomycin (Peak)': 'Vancomycin',\n",
    "#         'Vancomycin (Random)': 'Vancomycin',\n",
    "#         'Vancomycin (Trough)': 'Vancomycin',\n",
    "#         'PH (Arterial)': 'PH',\n",
    "#         'PH (dipstick)': 'PH',\n",
    "#         'PH (SOFT)': 'PH',\n",
    "#         'PH (Venous)': 'PH',\n",
    "#         'Capillary Refill R': 'Capillary Refill',\n",
    "#         'Capillary Refill L': 'Capillary Refill',\n",
    "#         'Temperature Celsius': 'Temperature',\n",
    "#         'Daily Weight': 'Weight',\n",
    "#         'Admission Weight (Kg)': 'Weight',\n",
    "#         'Inspired O2 Fraction': 'Inspired O2 Fraction',\n",
    "\n",
    "#         'Non Invasive Blood Pressure systolic': 'Systolic BP',\n",
    "#         'Non Invasive Blood Pressure diastolic': 'Diastolic BP',\n",
    "#         'Non Invasive Blood Pressure mean': 'Mean BP',\n",
    "#         'O2 saturation pulseoxymetry': 'O2 Saturation'\n",
    "#     }\n",
    "\n",
    "#     df['event'] = df['event'].replace(rename_dict)\n",
    "\n",
    "#     return df\n",
    "\n",
    "\n",
    "\n",
    "# # procedureevents_df = get_procedures_of_interest(procedureevents_df)\n",
    "# labevents_df = get_labs_of_interest(chartevents_df)\n",
    "# vitals_df = get_vitals_of_interest(chartevents_df)\n",
    "# # labevents_df = labevents_df[['subject_id', 'hadm_id', 'stay_id', 'charttime', 'event', 'valuenum']]\n",
    "# # vitals_df = vitals_df[['subject_id', 'hadm_id', 'stay_id', 'charttime', 'event', 'valuenum']]\n",
    "# # procedureevents_df = procedureevents_df[['subject_id', 'hadm_id', 'stay_id', 'starttime', 'endtime', 'storetime', 'value', 'event']]\n",
    "\n",
    "\n",
    "\n",
    "# # labs_vitals_df = get_labs_vitals(chartevents_df)\n",
    "# # labs_vitals_df = labs_vitals_df[['subject_id', 'hadm_id', 'stay_id', 'charttime', 'event', 'valuenum']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 33094639/33094639 [12:17:31<00:00, 747.87it/s]  \n",
      "100%|██████████| 36259441/36259441 [11:34:27<00:00, 870.21it/s]  \n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "def calc_time_delta_hrs(icu_intime, charttime):\n",
    "    return (charttime - icu_intime).total_seconds() / 3600\n",
    "\n",
    "\n",
    "\n",
    "def add_time_delta(df):\n",
    "    df = df.copy()\n",
    "\n",
    "    if 'stay_id' in df.columns:\n",
    "        stay_id_in_cols = True\n",
    "    else:\n",
    "        stay_id_in_cols = False\n",
    "        df['stay_id'] = None\n",
    "        \n",
    "    df['icu_time_delta'] = None\n",
    "    df['hosp_time_delta'] = None\n",
    "\n",
    "    for index, row in tqdm(df.iterrows(), total=df.shape[0]):\n",
    "        if 'charttime' in row:\n",
    "            ref_time = row['charttime']\n",
    "        elif 'storetime' in row:\n",
    "            ref_time = row['storetime']\n",
    "\n",
    "        curr_admission = admissions_df[(admissions_df['subject_id'] == row['subject_id']) & (admissions_df['hadm_id'] == row['hadm_id'])]\n",
    "\n",
    "        df.loc[index, 'hosp_time_delta'] = calc_time_delta_hrs(curr_admission['admittime'].iloc[0], ref_time)\n",
    "\n",
    "        if stay_id_in_cols:\n",
    "            curr_icu_stay = icustays_df[(icustays_df['subject_id'] == row['subject_id']) & (icustays_df['stay_id'] == row['stay_id'])]\n",
    "            df.loc[index, 'icu_time_delta'] = calc_time_delta_hrs(curr_icu_stay['intime'].iloc[0], ref_time)\n",
    "        else:\n",
    "            curr_pts_icustays = icustays_df[icustays_df['subject_id'] == row['subject_id']]\n",
    "\n",
    "            for icu_index, icu_row in curr_pts_icustays.iterrows():\n",
    "                if icu_row['intime'] <= ref_time <= icu_row['outtime']:\n",
    "                    df.loc[index, 'stay_id'] = icu_row['stay_id']\n",
    "                    df.loc[index, 'icu_time_delta'] = calc_time_delta_hrs(icu_row['intime'], ref_time)\n",
    "            \n",
    "\n",
    "    df = df.sort_values(by=['subject_id', 'hadm_id', 'stay_id', 'hosp_time_delta'])\n",
    "    return df\n",
    "\n",
    "\n",
    "# procedureevents_df = add_time_delta(procedureevents_df)\n",
    "# add_time_delta takes the events frame only; admissions_df and icustays_df are read from the\n",
    "# notebook namespace (the former two-argument calls raised a TypeError).\n",
    "labevents_df = add_time_delta(labevents_df)\n",
    "vitals_df = add_time_delta(vitals_df)\n",
    "\n",
    "# labs_df = add_time_delta(labs_df)\n",
    "# labs_df = labs_df[['subject_id', 'hadm_id', 'stay_id', 'hosp_time_delta', 'icu_time_delta', 'charttime', 'storetime', 'event', 'valuenum']]\n",
    "# labs_df.sort_values(by=['subject_id', 'hadm_id', 'stay_id', 'hosp_time_delta'], inplace=True)\n",
    "# vitals_df = add_time_delta(vitals_df)\n",
    "# vitals_df = vitals_df[['subject_id', 'hadm_id', 'stay_id', 'hosp_time_delta', 'icu_time_delta', 'charttime', 'storetime', 'event', 'valuenum']]\n",
    "# vitals_df.sort_values(by=['subject_id', 'hadm_id', 'stay_id', 'hosp_time_delta'], inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the lab and vital event rows into one long-format table with a fresh 0..n-1 index.\n",
    "concat_df = pd.concat([labevents_df, vitals_df], axis=0, ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_events_table_to_ts_array(df):\n",
    "    # Ensure 'valuenum' or 'value' columns exist\n",
    "    value_column = 'valuenum' if 'valuenum' in df.columns else 'value'\n",
    "\n",
    "    # Create a pivot table\n",
    "    pivot_df = df.pivot_table(index=['hadm_id', 'hosp_time_delta'], \n",
    "                              columns='event', \n",
    "                              values=value_column, \n",
    "                              aggfunc='first').reset_index()\n",
    "\n",
    "    # Join with the original DataFrame to get other required columns\n",
    "    keys = ['subject_id', 'hadm_id', 'stay_id', 'hosp_time_delta', 'icu_time_delta']\n",
    "    merged_df = pd.merge(df[keys].drop_duplicates(), pivot_df, on=['hadm_id', 'hosp_time_delta'])\n",
    "\n",
    "    # Reorder the columns\n",
    "    cols = merged_df.columns.tolist()\n",
    "    cols = [col for col in keys if col in cols] + [col for col in cols if col not in keys]\n",
    "    merged_df = merged_df[cols]\n",
    "\n",
    "    # Sort the DataFrame\n",
    "    merged_df.sort_values(by=['subject_id', 'hadm_id', 'stay_id', 'hosp_time_delta'], inplace=True)\n",
    "\n",
    "    return merged_df\n",
    "\n",
    "# procedureevents_ts_df = convert_events_table_to_ts_array(procedureevents_df)\n",
    "# Wide time-series tables for the labs and the vitals separately.\n",
    "labevents_ts_df = convert_events_table_to_ts_array(labevents_df)\n",
    "vitals_ts_df = convert_events_table_to_ts_array(vitals_df)\n",
    "\n",
    "# NOTE: concat_df is overwritten with its wide form here, so re-running this cell alone feeds the\n",
    "# already-pivoted table back in; re-run the concat cell above first.\n",
    "concat_df = convert_events_table_to_ts_array(concat_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# output_dir is defined (and created) in the configuration cell near the top of the notebook;\n",
    "# it is not redefined here so the save location has a single source of truth.\n",
    "\n",
    "# procedureevents_ts_df.to_pickle(os.path.join(output_dir, \"ts_procedureevents_icu.pkl\"))\n",
    "labevents_ts_df.to_pickle(os.path.join(output_dir, \"ts_labs_icu.pkl\"))\n",
    "vitals_ts_df.to_pickle(os.path.join(output_dir, \"ts_vitals_icu.pkl\"))\n",
    "\n",
    "concat_df.to_pickle(os.path.join(output_dir, \"ts_labs_vitals.pkl\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
