{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#libraries\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from sqlalchemy import create_engine\n",
    "\n",
    "import collections\n",
    "# import getpass\n",
    "from datetime import datetime as dt\n",
    "import os,sys,re\n",
    "from collections import Counter\n",
    "#import seaborn as sns\n",
    "# import random\n",
    "from datetime import timedelta\n",
    "from pathlib import Path\n",
    "import importlib\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "import multiprocessing as mp\n",
    "\n",
    "import time\n",
    "import pickle\n",
    "\n",
    "import json\n",
    "\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "root = '../DataProcessing'\n",
    "current_date = datetime.date.today()\n",
    "date = current_date.strftime('%y%m%d')\n",
    "main_dir = os.path.join(root, 'benchmark_data', date)\n",
    "chartlab_dir = os.path.join(main_dir, 'chartlab')\n",
    "treatments_dir = os.path.join(main_dir, 'treatments')\n",
    "merged_dir = os.path.join(main_dir, 'merged')\n",
    "\n",
    "\n",
    "os.makedirs(main_dir, exist_ok=True)\n",
    "os.makedirs(chartlab_dir, exist_ok=True)\n",
    "os.makedirs(treatments_dir, exist_ok=True)\n",
    "os.makedirs(merged_dir, exist_ok=True)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Connection settings come from the environment so that no credential is\n",
    "# stored in the notebook. Any password that was previously committed in this\n",
    "# cell should be treated as exposed and rotated.\n",
    "from getpass import getpass\n",
    "from urllib.parse import quote_plus\n",
    "\n",
    "DBNAME = os.environ.get('DBNAME', 'mimiciv22')\n",
    "DBUSER = os.environ.get('DBUSER', 'postgres')\n",
    "DBHOST = os.environ.get('DBHOST', 'localhost')\n",
    "DBPORT = os.environ.get('DBPORT', '5432')\n",
    "\n",
    "# No built-in default for the password: use $DBPASS when set, otherwise prompt.\n",
    "DBPASS = os.environ.get('DBPASS')\n",
    "if DBPASS is None:\n",
    "    DBPASS = getpass(f'Password for {DBUSER}@{DBHOST}:{DBPORT}/{DBNAME}: ')\n",
    "\n",
    "# Percent-encode user and password so characters such as '@', ':' or '/'\n",
    "# cannot break the connection URI.\n",
    "db_uri = f\"postgresql://{quote_plus(DBUSER)}:{quote_plus(DBPASS)}@{DBHOST}:{DBPORT}/{DBNAME}\"\n",
    "engine = create_engine(db_uri)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sqlalchemy import create_engine, MetaData\n",
    "from sqlalchemy import inspect as sa_inspect\n",
    "from sqlalchemy.engine import reflection  # kept: later cells may still reference it\n",
    "\n",
    "metadata = MetaData()\n",
    "# sqlalchemy.inspect(engine) is the supported way to obtain an Inspector;\n",
    "# Inspector.from_engine() is deprecated.\n",
    "insp = sa_inspect(engine)\n",
    "\n",
    "# Get list of all schemas\n",
    "schemas = insp.get_schema_names()\n",
    "\n",
    "for schema in schemas:\n",
    "    print(f\"Schema: {schema}\")\n",
    "\n",
    "    # Get list of all tables in each schema\n",
    "    tables = insp.get_table_names(schema=schema)\n",
    "    for table in tables:\n",
    "        print(f\"  Table: {table}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Base patient\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql = \"\"\"     \n",
    "with hosp_base as (select subject_id, hadm_id, admittime, dischtime, deathtime, hospital_expire_flag, race, admission_type from mimiciv_hosp.admissions),\n",
    "     icu_base as (select subject_id, hadm_id, stay_id, intime, outtime, los from mimiciv_icu.icustays)\n",
    "\n",
    "select hosp_base.*, icu_base.stay_id, icu_base.intime, icu_base.outtime, icu_base.los\n",
    "     from hosp_base left join icu_base on hosp_base.hadm_id = icu_base.hadm_id\n",
    "     where stay_id is not null;\n",
    "\"\"\"\n",
    "patient_base = pd.read_sql(sql, engine)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "patient_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.height;\n",
    "\"\"\"\n",
    "height = pd.read_sql(sql, engine)\n",
    "height = height[height['stay_id'].isin(patient_base['stay_id'])]\n",
    "average_height = height.groupby('stay_id')['height'].mean().reset_index()\n",
    "average_height.columns = ['stay_id', 'average_height']\n",
    "patient_base = pd.merge(patient_base, average_height, on='stay_id', how='left')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Filter data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql = \"\"\"\n",
    "select * from mimiciv_hosp.patients;\n",
    "\"\"\"\n",
    "patient_meta = pd.read_sql(sql, engine)\n",
    "patient_meta = patient_meta[patient_meta['subject_id'].isin(patient_base['subject_id'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.age;\n",
    "\"\"\"\n",
    "age = pd.read_sql(sql, engine)\n",
    "age = age[age['hadm_id'].isin(patient_base['hadm_id'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "patient_base = patient_base.merge(patient_meta[['subject_id', 'gender']], on='subject_id', how='left')\n",
    "patient_base = patient_base.merge(age[['hadm_id', 'age']], on='hadm_id', how='left')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "patient_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql = \"\"\"     \n",
    "select subject_id, hadm_id, admittime, dischtime, deathtime, hospital_expire_flag from mimiciv_hosp.admissions;\n",
    "\n",
    "\"\"\"\n",
    "admissions_base = pd.read_sql(sql, engine)\n",
    "\n",
    "sql = \"\"\"     \n",
    "select subject_id, hadm_id, stay_id, intime, outtime, los from mimiciv_icu.icustays;\n",
    "\n",
    "\"\"\"\n",
    "icu_base = pd.read_sql(sql, engine)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "icu_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql = \"\"\"\n",
    "select diagnoses_icd.* , long_title\n",
    "    from mimiciv_hosp.diagnoses_icd\n",
    "    left join mimiciv_hosp.d_icd_diagnoses on (diagnoses_icd.icd_code=d_icd_diagnoses.icd_code and diagnoses_icd.icd_version=d_icd_diagnoses.icd_version)\n",
    "    where mimiciv_hosp.diagnoses_icd.hadm_id in (select hadm_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "diagonases = pd.read_sql(sql, engine)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql = \"\"\"\n",
    "select * \n",
    "    from mimiciv_note.discharge\n",
    "    where mimiciv_note.discharge.hadm_id in (select hadm_id from mimiciv_icu.icustays);\n",
    "\"\"\"\n",
    "discharge = pd.read_sql(sql, engine)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "patient_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unified Basic Check - Check the number changes in all filtering steps\n",
    "# NOTE(review): every filter step in this cell reassigns `patient_base`, and the\n",
    "# first count taken is len(patient_base). Running the cell a second time therefore\n",
    "# reports the already-filtered size as 'Original Data', and the standalone filter\n",
    "# cells later in the notebook that repeat the same conditions receive an\n",
    "# already-filtered frame.\n",
    "print(\"=\" * 60)\n",
    "print(\"MIMIC-IV Data Filtering Process Unified Check\")\n",
    "print(\"=\" * 60)\n",
    "\n",
    "# Store filtering step information\n",
    "filter_steps = []\n",
    "initial_count = len(patient_base)\n",
    "filter_steps.append((\"Original Data\", initial_count, 0))\n",
    "\n",
    "print(f\"1. Original Data: {initial_count:,} samples\")\n",
    "\n",
    "# Data quality check\n",
    "print(\"\\n--- Data Quality Check ---\")\n",
    "invalid_age = patient_base[patient_base['age'] < 18]\n",
    "print(f\"   Age < 18 years: {len(invalid_age)} cases\")\n",
    "\n",
    "patient_base['admittime'] = pd.to_datetime(patient_base['admittime'])\n",
    "patient_base['dischtime'] = pd.to_datetime(patient_base['dischtime'])\n",
    "invalid_admissions = patient_base[patient_base['admittime'] > patient_base['dischtime']]\n",
    "print(f\"   Admission time > Discharge time: {len(invalid_admissions)} cases\")\n",
    "\n",
    "patient_base['intime'] = pd.to_datetime(patient_base['intime'])\n",
    "patient_base['outtime'] = pd.to_datetime(patient_base['outtime'])\n",
    "invalid_stay = patient_base[patient_base['intime'] > patient_base['outtime']]\n",
    "print(f\"   ICU admission time > Discharge time: {len(invalid_stay)} cases\")\n",
    "\n",
    "invalid_stay2 = patient_base[patient_base['admittime'] > patient_base['intime']]\n",
    "print(f\"   Admission time > ICU admission time: {len(invalid_stay2)} cases\")\n",
    "\n",
    "invalid_los = patient_base[patient_base['los'] <= 0.25]\n",
    "print(f\"   Length of stay ≤ 0.25 days: {len(invalid_los)} cases\")\n",
    "\n",
    "missing_diagonases = patient_base[~(patient_base['hadm_id'].isin(diagonases['hadm_id']))]\n",
    "print(f\"   Missing diagnosis information: {len(missing_diagonases)} cases\")\n",
    "\n",
    "missing_discharge = patient_base[~(patient_base['hadm_id'].isin(discharge['hadm_id']))]\n",
    "death_patient_missing_discharge = patient_base[~(patient_base['hadm_id'].isin(discharge['hadm_id'])) & (patient_base['hospital_expire_flag']==1)]\n",
    "print(f\"   Missing discharge information: {len(missing_discharge)} cases (including {len(death_patient_missing_discharge)} deceased patients)\")\n",
    "\n",
    "print(f\"   Duplicate hadm_id: {admissions_base['hadm_id'].duplicated().any()}\")\n",
    "print(f\"   Duplicate stay_id: {icu_base['stay_id'].duplicated().any()}\")\n",
    "print(f\"   Duplicate hadm_id in ICU: {icu_base['hadm_id'].duplicated().any()}\")\n",
    "\n",
    "# Execute filtering steps and record quantity changes\n",
    "print(\"\\n--- Execute Filtering Steps ---\")\n",
    "\n",
    "# Filter 1: Has diagnosis information\n",
    "before_count = len(patient_base)\n",
    "patient_base = patient_base[patient_base['hadm_id'].isin(diagonases['hadm_id'])]\n",
    "after_count = len(patient_base)\n",
    "filtered = before_count - after_count\n",
    "filter_steps.append((\"Record without diagnoses code\", after_count, filtered))\n",
    "print(f\"2. Has Diagnosis Information: {after_count:,} samples (filtered out {filtered:,})\")\n",
    "\n",
    "# Filter 2: Has discharge information\n",
    "before_count = len(patient_base)\n",
    "patient_base = patient_base[patient_base['hadm_id'].isin(discharge['hadm_id'])]\n",
    "after_count = len(patient_base)\n",
    "filtered = before_count - after_count\n",
    "filter_steps.append((\"Record without discharge note\", after_count, filtered))\n",
    "print(f\"3. Has Discharge Information: {after_count:,} samples (filtered out {filtered:,})\")\n",
    "\n",
    "# Filter 3: Length of stay > 0.25 days\n",
    "before_count = len(patient_base)\n",
    "patient_base = patient_base[patient_base['los'] > 0.25]\n",
    "after_count = len(patient_base)\n",
    "filtered = before_count - after_count\n",
    "filter_steps.append((\"Length of Stay > 0.25 Days\", after_count, filtered))\n",
    "print(f\"4. Length of Stay > 0.25 Days: {after_count:,} samples (filtered out {filtered:,})\")\n",
    "\n",
    "# Filter 4: Time logic correct (Admission ≤ ICU admission)\n",
    "before_count = len(patient_base)\n",
    "patient_base = patient_base[~(patient_base['admittime'] > patient_base['intime'])]\n",
    "after_count = len(patient_base)\n",
    "filtered = before_count - after_count\n",
    "filter_steps.append((\"Admission time > ICU admission time\", after_count, filtered))\n",
    "print(f\"5. Time Logic Correct (Admission ≤ ICU Admission): {after_count:,} samples (filtered out {filtered:,})\")\n",
    "\n",
    "# Filter 5: Time logic correct (Admission ≤ Discharge)\n",
    "before_count = len(patient_base)\n",
    "patient_base = patient_base[~(patient_base['admittime'] > patient_base['dischtime'])]\n",
    "after_count = len(patient_base)\n",
    "filtered = before_count - after_count\n",
    "filter_steps.append((\"Admission time > Discharge time\", after_count, filtered))\n",
    "print(f\"6. Time Logic Correct (Admission ≤ Discharge): {after_count:,} samples (filtered out {filtered:,})\")\n",
    "\n",
    "# Filter 6: Remove duplicates\n",
    "before_count = len(patient_base)\n",
    "patient_base = patient_base.drop_duplicates()\n",
    "after_count = len(patient_base)\n",
    "filtered = before_count - after_count\n",
    "filter_steps.append((\"Remove Duplicates\", after_count, filtered))\n",
    "print(f\"7. Remove Duplicates: {after_count:,} samples (filtered out {filtered:,})\")\n",
    "\n",
    "# Filter 7: Specific admission types\n",
    "before_count = len(patient_base)\n",
    "patient_base = patient_base[patient_base['admission_type'].isin(['URGENT', 'DIRECT EMER.', 'EW EMER.', 'OBSERVATION ADMIT'])]\n",
    "after_count = len(patient_base)\n",
    "filtered = before_count - after_count\n",
    "filter_steps.append((\"Specific Admission Types\", after_count, filtered))\n",
    "print(f\"8. Specific Admission Types: {after_count:,} samples (filtered out {filtered:,})\")\n",
    "\n",
    "# Filter 8: First ICU admission\n",
    "patient_base = patient_base.sort_values(by=['subject_id', 'intime'])\n",
    "patient_base['episode_number'] = patient_base.groupby('hadm_id').cumcount()+1\n",
    "\n",
    "before_count = len(patient_base)\n",
    "patient_base = patient_base[patient_base['episode_number']==1]\n",
    "after_count = len(patient_base)\n",
    "filtered = before_count - after_count\n",
    "filter_steps.append((\"First ICU Admission\", after_count, filtered))\n",
    "print(f\"9. First ICU Admission: {after_count:,} samples (filtered out {filtered:,})\")\n",
    "\n",
    "# Filter 9: ICU stay longer than 48 hours (final filter).\n",
    "# The result is stored in `cohort`; `patient_base` itself is not reassigned here.\n",
    "# The comparison is strict (> 48 h), so the labels say '>' rather than '≥'.\n",
    "before_count = len(patient_base)\n",
    "icu_stay_hours = (patient_base['outtime'] - patient_base['intime']).dt.total_seconds() / 3600\n",
    "cohort = patient_base[icu_stay_hours > 48]\n",
    "after_count = len(cohort)\n",
    "filtered = before_count - after_count\n",
    "filter_steps.append((\"ICU Stay > 48 Hours\", after_count, filtered))\n",
    "print(f\"10. ICU Stay > 48 Hours: {after_count:,} samples (filtered out {filtered:,})\")\n",
    "\n",
    "# Summary statistics\n",
    "def _pct(part, whole):\n",
    "    '''Return part / whole as a percentage, or 0.0 when whole is 0 (empty input).'''\n",
    "    return part / whole * 100 if whole else 0.0\n",
    "\n",
    "total_filtered = initial_count - after_count\n",
    "print(\"\\n\" + \"=\" * 60)\n",
    "print(\"Filtering Process Summary\")\n",
    "print(\"=\" * 60)\n",
    "print(f\"Original sample count: {initial_count:,}\")\n",
    "print(f\"Final sample count: {after_count:,}\")\n",
    "print(f\"Total filtered count: {total_filtered:,}\")\n",
    "print(f\"Total filtering rate: {_pct(total_filtered, initial_count):.1f}%\")\n",
    "print(f\"Final retention rate: {_pct(after_count, initial_count):.1f}%\")\n",
    "\n",
    "# Create filtering process table\n",
    "# Size the first column from the longest step label (at least 30 characters) so\n",
    "# that long labels do not push the numeric columns out of alignment.\n",
    "step_width = max(30, max(len(entry[0]) for entry in filter_steps) + 1)\n",
    "print(\"\\n--- Detailed Filtering Process Table ---\")\n",
    "print(f\"{'Step':<{step_width}} {'Sample Count':<15} {'Filtered Count':<15} {'Retention Rate':<15}\")\n",
    "print(\"-\" * (step_width + 50))\n",
    "previous_count = None\n",
    "for step, count, filtered in filter_steps:\n",
    "    # Retention is relative to the previous step; the first row is the baseline.\n",
    "    retention_rate = 100.0 if previous_count is None else _pct(count, previous_count)\n",
    "    print(f\"{step:<{step_width}} {count:<15,} {filtered:<15,} {retention_rate:<15.1f}%\")\n",
    "    previous_count = count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Basic check\n",
    "print(\"Base record amount: \", len(patient_base))\n",
    "print(\"Duplicate hadm_id in admission:\", admissions_base['hadm_id'].duplicated().any())\n",
    "print(\"Duplicate stay_id in icu:\", icu_base['stay_id'].duplicated().any())\n",
    "print(\"Duplicate hadm_id in icu:\", icu_base['hadm_id'].duplicated().any())\n",
    "\n",
    "invalid_age = patient_base[patient_base['age'] < 18]\n",
    "print(\"Invalid rows where age is smaller than 18:\", len(invalid_age))\n",
    "\n",
    "patient_base['admittime'] = pd.to_datetime(patient_base['admittime'])\n",
    "patient_base['dischtime'] = pd.to_datetime(patient_base['dischtime'])\n",
    "invalid_admissions = patient_base[patient_base['admittime'] > patient_base['dischtime']]\n",
    "print(\"Invalid rows where admittime is greater than dischtime:\", len(invalid_admissions))\n",
    "\n",
    "patient_base['intime'] = pd.to_datetime(patient_base['intime'])\n",
    "patient_base['outtime'] = pd.to_datetime(patient_base['outtime'])\n",
    "invalid_stay = patient_base[patient_base['intime'] > patient_base['outtime']]\n",
    "print(\"Invalid rows where intime is greater than outtime:\", len(invalid_stay))\n",
    "\n",
    "patient_base['admittime'] = pd.to_datetime(patient_base['admittime'])\n",
    "patient_base['intime'] = pd.to_datetime(patient_base['intime'])\n",
    "invalid_stay = patient_base[patient_base['admittime'] > patient_base['intime']]\n",
    "print(\"Invalid rows where admittime is greater than intime:\", len(invalid_stay))\n",
    "\n",
    "invalid_los = patient_base[patient_base['los'] <= 0.25] \n",
    "print(\"Invalid rows where length of stay is less than 0.25: \", len(invalid_los))\n",
    "\n",
    "missing_diagonases = patient_base[~(patient_base['hadm_id'].isin(diagonases['hadm_id']))]\n",
    "print(\"Missing diagonases: \", len(missing_diagonases))\n",
    "\n",
    "missing_discharge = patient_base[~(patient_base['hadm_id'].isin(discharge['hadm_id']))]\n",
    "death_patient_missing_discharge = patient_base[~(patient_base['hadm_id'].isin(discharge['hadm_id'])) & (patient_base['hospital_expire_flag']==1)]\n",
    "print(\"Missing discharge: \", len(missing_discharge), len(death_patient_missing_discharge))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Original record amount: \", len(patient_base))\n",
    "patient_base = patient_base[patient_base['hadm_id'].isin(diagonases['hadm_id'])]\n",
    "print(\"New record amount: \", len(patient_base))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Original record amount: \", len(patient_base))\n",
    "patient_base = patient_base[patient_base['hadm_id'].isin(discharge['hadm_id'])]\n",
    "print(\"New record amount: \", len(patient_base))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Original record amount: \", len(patient_base))\n",
    "patient_base = patient_base[patient_base['los'] > 0.25]\n",
    "print(\"New record amount: \", len(patient_base))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Original record amount: \", len(patient_base))\n",
    "patient_base = patient_base[~(patient_base['admittime'] > patient_base['intime'])]\n",
    "print(\"New record amount: \", len(patient_base))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Original record amount: \", len(patient_base))\n",
    "patient_base = patient_base[~(patient_base['admittime'] > patient_base['dischtime'])]\n",
    "print(\"New record amount: \", len(patient_base))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "patient_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "patient_base = patient_base.drop_duplicates()\n",
    "print(\"New record amount: \", len(patient_base))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "patient_base['subject_id'].nunique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "icu_base"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Extract metadata of ICU stays and CXR availabilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number the ICU stays of each hospital admission in (subject_id, intime) order\n",
    "patient_base = patient_base.sort_values(by=['subject_id', 'intime'])\n",
    "patient_base['episode_number'] = patient_base.groupby('hadm_id').cumcount() + 1\n",
    "\n",
    "death = patient_base['deathtime']\n",
    "\n",
    "# 1 when intime < deathtime <= outtime, else 0 (rows without a deathtime get 0)\n",
    "died_in_icu = death.gt(patient_base['intime']) & death.le(patient_base['outtime'])\n",
    "patient_base['icu_mortality'] = died_in_icu.astype('int64')\n",
    "\n",
    "# 1 when admittime < deathtime <= dischtime, else 0\n",
    "died_in_hadm = death.gt(patient_base['admittime']) & death.le(patient_base['dischtime'])\n",
    "patient_base['hadm_mortality'] = died_in_hadm.astype('int64')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "patient_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "patient_base.groupby('admission_type').agg({\n",
    "    'icu_mortality': 'mean',\n",
    "    'hadm_mortality': 'mean',\n",
    "    'hadm_id': 'nunique'\n",
    "}).sort_values('hadm_mortality', ascending=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the selected admission types only. The explicit .copy() makes the result\n",
    "# an independent frame, so adding `hadm_los` below does not write into a slice\n",
    "# of the previous frame (avoids pandas' SettingWithCopyWarning).\n",
    "kept_admission_types = ['URGENT', 'DIRECT EMER.', 'EW EMER.', 'OBSERVATION ADMIT']\n",
    "patient_base = patient_base[patient_base['admission_type'].isin(kept_admission_types)].copy()\n",
    "\n",
    "# Hospital length of stay in days (dischtime - admittime)\n",
    "patient_base['hadm_los'] = (patient_base['dischtime'] - patient_base['admittime']).dt.total_seconds()/3600/24"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "patient_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "patient_base[patient_base['stay_id']==33555982][['intime', 'outtime', 'deathtime', 'icu_mortality']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "patient_base['icu_mortality'].sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get CXRs\n",
    "# The metadata location can be overridden with $MIMIC_CXR_METADATA; the default\n",
    "# is the path this notebook was originally written against.\n",
    "cxr_meta_path = os.environ.get(\n",
    "    'MIMIC_CXR_METADATA',\n",
    "    '/hdd/datasets/mimic-cxr-jpg/2.0.0/mimic-cxr-2.0.0-metadata.csv')\n",
    "cxr_meta = pd.read_csv(cxr_meta_path, dtype={'StudyDate': str, 'StudyTime': str})\n",
    "\n",
    "# StudyTime is read as text: keep the part before any '.', then left-pad to six\n",
    "# digits for the %H%M%S format. The vectorised .str accessor propagates missing\n",
    "# values (giving NaT below) instead of raising the way str.split does on NaN.\n",
    "study_clock = cxr_meta['StudyTime'].str.split('.').str[0].str.zfill(6)\n",
    "study_time = cxr_meta['StudyDate'] + ' ' + study_clock\n",
    "cxr_meta['CXRTime'] = pd.to_datetime(study_time, format='%Y%m%d %H%M%S')\n",
    "\n",
    "# Use AP view only\n",
    "cxr_meta = cxr_meta[cxr_meta['ViewPosition'] == 'AP'].sort_values(['subject_id', 'CXRTime'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index the CXR metadata by patient once, so that each row lookup below does not\n",
    "# have to rescan the whole metadata table.\n",
    "cxr_by_subject = {sid: grp for sid, grp in cxr_meta.groupby('subject_id')}\n",
    "\n",
    "def get_cxr_dicom_ids(row):\n",
    "    '''\n",
    "    Collect the CXR studies that belong to one row of patient_base.\n",
    "\n",
    "    row must provide subject_id, intime, outtime, admittime and dischtime.\n",
    "    Returns two lists of (dicom_id, CXRTime) tuples:\n",
    "      1. studies with intime < CXRTime <= outtime\n",
    "      2. studies with admittime < CXRTime <= dischtime\n",
    "    Both lists are empty when the subject has no rows in the CXR metadata.\n",
    "    '''\n",
    "    cxr_meta_pt = cxr_by_subject.get(row.subject_id)\n",
    "    if cxr_meta_pt is None:\n",
    "        return [], []\n",
    "    cxr_time = cxr_meta_pt['CXRTime']\n",
    "\n",
    "    def _studies_between(start, end):\n",
    "        # list(zip(...)) yields [] for an empty selection; a row-wise\n",
    "        # DataFrame.apply on an empty frame does not reliably return a Series.\n",
    "        hits = cxr_meta_pt.loc[(cxr_time > start) & (cxr_time <= end), ['dicom_id', 'CXRTime']]\n",
    "        return list(zip(hits['dicom_id'], hits['CXRTime']))\n",
    "\n",
    "    cxr_within_stays = _studies_between(row.intime, row.outtime)\n",
    "    cxr_within_hadm = _studies_between(row.admittime, row.dischtime)\n",
    "    return cxr_within_stays, cxr_within_hadm\n",
    "\n",
    "cxr_dicom_ids = patient_base.apply(get_cxr_dicom_ids, axis=1)\n",
    "patient_base['cxr_within_icu'], patient_base['cxr_within_hadm'] = zip(*cxr_dicom_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "patient_base['n_cxr_within_icu'] = patient_base['cxr_within_icu'].apply(len)\n",
    "patient_base['n_cxr_within_hadm'] = patient_base['cxr_within_hadm'].apply(len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# use the first ICU stay for each hospital admission\n",
    "patient_base = patient_base[patient_base['episode_number']==1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "patient_base"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "patient_base['hadm_id'].unique().shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lookup tables for `intime`, keyed by stay_id and by hadm_id\n",
    "stayid_intime = patient_base.set_index('stay_id')['intime'].to_dict()\n",
    "hadmid_intime = patient_base.set_index('hadm_id')['intime'].to_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lookup tables for `admittime`, keyed by stay_id and by hadm_id\n",
    "stayid_admittime = patient_base.set_index('stay_id')['admittime'].to_dict()\n",
    "hadmid_admittime = patient_base.set_index('hadm_id')['admittime'].to_dict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Extract Phenotype labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import yaml\n",
    "\n",
    "# Load the phenotype definitions (one entry per phenotype, each with a 'codes'\n",
    "# list and a 'use_in_benchmark' value)\n",
    "with open('icd_9_10_definitions.yaml', 'r') as f:\n",
    "    phenotype_definitions = yaml.safe_load(f)\n",
    "\n",
    "# Map each stripped code to (phenotype name, use_in_benchmark)\n",
    "code_phenotype_mapping = dict(\n",
    "    (code.strip(), (dx, dx_def['use_in_benchmark']))\n",
    "    for dx, dx_def in phenotype_definitions.items()\n",
    "    for code in dx_def['codes']\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# extract diagnosis\n",
    "# d_icd_diagnoses is joined on BOTH icd_code and icd_version, like the earlier\n",
    "# diagnoses query in this notebook. Joining on icd_code alone pairs a diagnosis\n",
    "# with the title of every version in which that code string exists.\n",
    "sql = \"\"\"SELECT subject_id, hadm_id, seq_num, trim(diagnoses_icd.icd_code) AS icd_code,\n",
    "                diagnoses_icd.icd_version, d_icd_diagnoses.long_title\n",
    "         FROM mimiciv_hosp.diagnoses_icd\n",
    "         LEFT JOIN mimiciv_hosp.d_icd_diagnoses \n",
    "         ON diagnoses_icd.icd_code = d_icd_diagnoses.icd_code\n",
    "         AND diagnoses_icd.icd_version = d_icd_diagnoses.icd_version\"\"\"\n",
    "diagnoses = pd.read_sql(sql, con=engine)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Codes that are not part of any phenotype definition map to (None, None).\n",
    "# The default must not reference `_`: in IPython that name holds the previous\n",
    "# cell output, and outside IPython it is undefined.\n",
    "NO_PHENOTYPE = (None, None)\n",
    "diagnoses['hcup_ccs_class'] = diagnoses['icd_code'].map(lambda x: code_phenotype_mapping.get(x, NO_PHENOTYPE)[0])\n",
    "diagnoses['use_in_benchmark'] = diagnoses['icd_code'].map(lambda x: code_phenotype_mapping.get(x, NO_PHENOTYPE)[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "phenotype_labels = diagnoses[diagnoses['use_in_benchmark']==True][['hadm_id', 'hcup_ccs_class']].drop_duplicates()\n",
    "phenotype_labels['label'] = 1\n",
    "phenotype_labels = phenotype_labels.pivot(index='hadm_id', columns='hcup_ccs_class', values='label').fillna(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "phenotype_classes = phenotype_labels.columns.tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "phenotype_classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "patient_base = patient_base.merge(phenotype_labels, how='left', on='hadm_id')\n",
    "patient_base[phenotype_classes] = patient_base[phenotype_classes].fillna(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "patient_base"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Support table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import psutil\n",
    "import os\n",
    "import gc \n",
    "process = psutil.Process(os.getpid())\n",
    "print(f\"Memory Usage: {process.memory_info().rss / 1024 ** 2:.2f} MB\")\n",
    "gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql = \"\"\"\n",
    "select * from mimiciv_hosp.d_labitems\n",
    "\"\"\"\n",
    "labitem_table = pd.read_sql(sql, engine)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labitem_table[labitem_table['itemid']==51221]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql = \"\"\"\n",
    "select * from mimiciv_icu.d_items\n",
    "\"\"\"\n",
    "d_items_table = pd.read_sql(sql, engine)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "d_items_table[d_items_table['itemid']==50816]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql = \"\"\"\n",
    "select subject_id, hadm_id, specimen_id, itemid, charttime, storetime, value as specimen, valuenum, valueuom from mimiciv_hosp.labevents where itemid in(52033)\n",
    "\"\"\"\n",
    "labevents = pd.read_sql(sql, engine)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labevents[labevents['specimen_id']==68252610]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Support function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def melt_pivot(old_table, id_name, time_name, output_var_name='item', output_value_name='value'):\n",
    "    '''\n",
    "    Convert a pivoted (wide) table into long form: one row per id, time, item, value.\n",
    "    Rows whose value is missing are dropped; the result is sorted and de-duplicated\n",
    "    on (id, time, item) via clean_table.\n",
    "    '''\n",
    "    id_columns = combine_to_list(id_name, time_name)\n",
    "    # wide -> long\n",
    "    long_table = pd.melt(\n",
    "        old_table.reset_index(drop=True),\n",
    "        id_vars=id_columns,\n",
    "        var_name=output_var_name,\n",
    "        value_name=output_value_name,\n",
    "    )\n",
    "    # drop missing values, then order the rows by the sort key\n",
    "    sort_columns = combine_to_list(id_name, time_name, output_var_name)\n",
    "    has_value = ~(long_table[output_value_name].isna())\n",
    "    long_table = long_table[has_value].sort_values(by=sort_columns).reset_index(drop=True)\n",
    "    return clean_table(long_table, sort_method=sort_columns)\n",
    "\n",
    "def clean_table(table, sort_method=None):\n",
    "    '''\n",
    "    Sort `table` by `sort_method` and keep the first row of every duplicated\n",
    "    `sort_method` key. With sort_method=None the table is returned unchanged.\n",
    "    '''\n",
    "    if sort_method is None:\n",
    "        return table\n",
    "    ordered = table.sort_values(by=sort_method)\n",
    "    return ordered.drop_duplicates(subset=sort_method, keep='first')\n",
    "\n",
    "def generate_stat(table, group_name, save_path, base=None, merge_id=None, drop=0):\n",
    "    '''\n",
    "    Summarise the 'value' column per `group_name` and write the summary to CSV.\n",
    "\n",
    "    table     : frame with `group_name` and 'value' columns. When it has no\n",
    "                'subject_id' column and both `base` and `merge_id` are given,\n",
    "                'subject_id' is merged in from `base` on `merge_id`.\n",
    "    save_path : destination of the statistics CSV.\n",
    "    drop      : when non-zero, keep only the rows of the `drop` groups with the\n",
    "                largest pid_count.\n",
    "    Returns (table, stats): the possibly merged / reduced table and the statistics\n",
    "    frame sorted by pid_count in descending order.\n",
    "    '''\n",
    "    # add the subject id back\n",
    "    if base is not None and merge_id is not None and 'subject_id' not in table.columns:\n",
    "        subject_lookup = base[[merge_id, 'subject_id']]\n",
    "        table = table.merge(subject_lookup, on=merge_id, how='left')\n",
    "\n",
    "    def _quantile(q):\n",
    "        return lambda values: values.quantile(q)\n",
    "\n",
    "    # generate statistic of each feature\n",
    "    agg_spec = {\n",
    "        'pid_count': ('subject_id', 'nunique'),\n",
    "        'Min': ('value', 'min'),\n",
    "        'Max': ('value', 'max'),\n",
    "        'Mean': ('value', 'mean'),\n",
    "        'Std': ('value', 'std'),\n",
    "        'Median': ('value', 'median'),\n",
    "        'Q1': ('value', _quantile(0.25)),\n",
    "        'Q3': ('value', _quantile(0.75)),\n",
    "        'IQR': ('value', lambda values: values.quantile(0.75) - values.quantile(0.25)),\n",
    "        'lowest_1': ('value', _quantile(0.01)),\n",
    "        'highest_1': ('value', _quantile(0.99)),\n",
    "    }\n",
    "    stats = table.groupby(group_name).agg(**agg_spec).reset_index()\n",
    "    stats = stats.sort_values(by='pid_count', ascending=False)\n",
    "\n",
    "    # save the result\n",
    "    stats.to_csv(save_path, index=False)\n",
    "\n",
    "    # pid_count is the number of distinct subject_id values per group_name\n",
    "    if drop != 0:\n",
    "        top_items = stats.nlargest(drop, 'pid_count')[group_name].tolist()\n",
    "        table = table[table[group_name].isin(top_items)]\n",
    "\n",
    "    return table, stats  \n",
    "\n",
    "def generate_stat_cate(table, group_name, save_path, base=None, merge_id=None, drop=0):\n",
    "    # add the subject id back\n",
    "    if base is not None and merge_id is not None and 'subject_id' not in table.columns:\n",
    "        table = table.merge(base[[merge_id, 'subject_id']], on=merge_id, how='left')\n",
    "\n",
    "    # generate statistic of each feature\n",
    "    stats = table.groupby(group_name).agg(\n",
    "        pid_count=('subject_id', 'nunique'), \n",
    "        value_count=('value', 'nunique')\n",
    "    ).reset_index()\n",
    "    \n",
    "    # save the result\n",
    "    stats.sort_values(by='pid_count', ascending=False).to_csv(save_path, index=False)\n",
    "    \n",
    "    if drop != 0:\n",
    "        top_items = stats.nlargest(drop, 'pid_count')[group_name].tolist()\n",
    "        table = table[table[group_name].isin(top_items)]\n",
    "      \n",
    "    return table \n",
    "    \n",
    "def generate_dict(table, group_name, key_name, save_path):\n",
    "    table = table[group_name].drop_duplicates()\n",
    "    event_values = table.groupby(key_name).agg({'value': list})['value'].to_dict()\n",
    "    event_dict = save_path\n",
    "    with open(event_dict, 'w') as f:\n",
    "        json.dump(event_values, f, indent=2)\n",
    "\n",
    "\n",
    "def combine_to_list(*args):\n",
    "    return [item for var in args for item in ([var] if isinstance(var, str) else var)]\n",
    "\n",
    "def turn_binary(table):\n",
    "    table = table[~table['value'].isna()]\n",
    "    table['item'] = table['item'] + '_' + table['value']\n",
    "    table['value'] = 1\n",
    "    return table\n",
    "\n",
    "def repeat_rows(df):\n",
    "\n",
    "    counts = (df['end_time_step'] - df['start_time_step']).astype(int)\n",
    "    print(\"Finished generating count!\")\n",
    "    \n",
    "    repeated_df = df.loc[df.index.repeat(counts)]\n",
    "    print(\"Finished repeating rows!\")\n",
    "    \n",
    "    repeated_df['timestep'] = (\n",
    "        repeated_df.groupby(level=0).cumcount() + repeated_df['start_time_step']\n",
    "    ).astype(float)\n",
    "    print(\"Finished calculating timestep!\")\n",
    "\n",
    "    repeated_df.reset_index(drop=True, inplace=True)\n",
    "    \n",
    "    new_table = repeated_df.drop(['starttime', 'endtime'], axis=1)\n",
    "    \n",
    "    return new_table\n",
    "\n",
    "\n",
    "def add_timestep(table, base, base_timename, timename, timestep_name='time_step', merge_id='hadm_id', sort_list=['subject_id', 'charttime']):\n",
    "    '''\n",
    "    Adding a timestep value for the time series\n",
    "    '''\n",
    "    if base_timename not in table.columns:\n",
    "        table = table.sort_values(by=sort_list).merge(base[[merge_id, base_timename]], on=[merge_id], how='left')\n",
    "    \n",
    "    # # keep data within 48 hours of ICU intime\n",
    "    table = table[table[timename] >= table[base_timename]]\n",
    "    table = table[table[timename] <= table[base_timename]+pd.Timedelta('2 days')]\n",
    "    # table = table[table[timename] <= table[base_timename]+pd.Timedelta('12 hours')]\n",
    "    table[timestep_name] = (table[timename] - table[base_timename]).dt.total_seconds()//3600\n",
    "    return table\n",
    "\n",
    "def count_duplicated(table, subset):\n",
    "    duplicates = table.duplicated(subset=subset, keep=False)\n",
    "    duplicate_count = duplicates.sum()\n",
    "    print(f\"Number of duplicate records: {duplicate_count}\")\n",
    "    print('Total row', len(table))\n",
    "\n",
    "def remove_outliner(table, stats, ignore_feature=None):\n",
    "    for feature in stats['item']:\n",
    "        if ignore_feature is not None and feature not in ignore_feature:\n",
    "            lower = stats[stats['item']==feature]['lowest_1'].item()\n",
    "            higher = stats[stats['item']==feature]['highest_1'].item()\n",
    "            q1 = stats[stats['item']==feature]['Q1'].item()\n",
    "            q3 = stats[stats['item']==feature]['Q3'].item()\n",
    "            if lower != higher and lower != q1 and higher != q3:\n",
    "                # filter out the row where table['item']==feature and outside the lower bound and higher bound\n",
    "                table = table[~((table['item']==feature) & ((table['value'] < lower) | (table['value'] > higher)))]\n",
    "    return table\n",
    "\n",
    "def rename_feature(table, name):\n",
    "    table['item'] = name +'_' + table['item']\n",
    "    return table\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 1: bg\n",
    "hadm_id<br>\n",
    "use the 25 features with the highest patient count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.bg\n",
    "where mimiciv_derived.bg.hadm_id in (select hadm_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "bg = pd.read_sql(sql, engine)\n",
    "bg = bg[bg['hadm_id'].isin(patient_base['hadm_id'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bg.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "count_duplicated(bg, subset=['subject_id', 'hadm_id', 'charttime'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pivot on specimen so that every measurement becomes one column per specimen type.\n",
    "# Rows without a specimen are dropped; aggfunc='first' keeps the first non-null value\n",
    "# when several rows share (subject_id, hadm_id, charttime, specimen).\n",
    "bg = bg.pivot_table(\n",
    "    index=['subject_id','hadm_id', 'charttime'],\n",
    "    columns='specimen',\n",
    "    values=['so2', 'po2', 'pco2',\n",
    "    'fio2_chartevents', 'fio2', 'aado2', 'aado2_calc', 'pao2fio2ratio',\n",
    "    'ph', 'baseexcess', 'bicarbonate', 'totalco2', 'hematocrit',\n",
    "    'hemoglobin', 'carboxyhemoglobin', 'methemoglobin', 'chloride',\n",
    "    'calcium', 'temperature', 'potassium', 'sodium', 'lactate', 'glucose'],\n",
    "    aggfunc='first'\n",
    ")\n",
    "# flatten the (measurement, specimen) column levels into '<measurement>_<specimen>' names\n",
    "bg.columns = ['_'.join(col).strip() for col in bg.columns.values]\n",
    "bg = bg.reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bg = melt_pivot(bg.drop(columns=['subject_id']), id_name='hadm_id', time_name='charttime')\n",
    "bg, bg_stat = generate_stat(bg, group_name='item', save_path='../DataProcessing/benchmark_stat/bg.csv', base=patient_base, merge_id='hadm_id', drop=25)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hourly time step relative to intime; add_timestep also restricts rows to its time window\n",
    "bg = add_timestep(\n",
    "    table=bg, base=patient_base, base_timename='intime', timename='charttime',\n",
    "    timestep_name='timestep', merge_id='hadm_id',\n",
    "    sort_list=['subject_id', 'hadm_id', 'charttime'],\n",
    ")\n",
    "\n",
    "# collapse to one mean value per (subject, admission, item, time step)\n",
    "bg = (\n",
    "    bg.reset_index(drop=True)\n",
    "    .drop(columns=['charttime', 'intime'])\n",
    "    .groupby(['subject_id', 'hadm_id', 'item', 'timestep'], as_index=False)\n",
    "    .mean()\n",
    ")\n",
    "\n",
    "# the listed items are passed as ignore_feature, i.e. excluded from outlier removal\n",
    "bg = remove_outliner(\n",
    "    bg, bg_stat,\n",
    "    ['glucose_ART.', 'hematocrit_ART.', 'lactate_ART.', 'lactate_VEN.',\n",
    "     'fio2_ART.', 'fio2_chartevents_ART.'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bg = rename_feature(bg, 'bg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bg"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 2: urine_output\n",
    "stay_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.urine_output\n",
    "where mimiciv_derived.urine_output.stay_id in (select stay_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "urine_output = pd.read_sql(sql, engine)\n",
    "urine_output = urine_output[urine_output['stay_id'].isin(patient_base['stay_id'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# keep measured rows, remove exact duplicate rows, then move to the long (item, value) layout.\n",
    "# drop_duplicates() returns a new frame, so its result has to be kept (it was discarded before).\n",
    "urine_output = (\n",
    "    urine_output[urine_output['urineoutput'].notna()]\n",
    "    .drop_duplicates()\n",
    "    .assign(item='urine_output')\n",
    "    .rename(columns={'urineoutput': 'value'})\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "count_duplicated(urine_output, subset=['stay_id', 'charttime', 'item'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hourly time step relative to intime; add_timestep also restricts rows to its time window\n",
    "urine_output = add_timestep(\n",
    "    table=urine_output, base=patient_base, base_timename='intime', timename='charttime',\n",
    "    timestep_name='timestep', merge_id='stay_id',\n",
    "    sort_list=['stay_id', 'charttime'],\n",
    ")\n",
    "urine_output = urine_output.reset_index(drop=True).drop(columns=['charttime', 'intime'])\n",
    "\n",
    "# per-item statistics, written to CSV; subject_id is attached from patient_base if missing\n",
    "urine_output, urine_output_stat = generate_stat(\n",
    "    urine_output, group_name='item',\n",
    "    save_path='../DataProcessing/benchmark_stat/urine_output.csv',\n",
    "    base=patient_base, merge_id='stay_id',\n",
    ")\n",
    "\n",
    "# collapse to one mean value per (subject, stay, item, time step)\n",
    "urine_output = urine_output.groupby(\n",
    "    ['subject_id', 'stay_id', 'item', 'timestep'], as_index=False\n",
    ").mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "urine_output = remove_outliner(urine_output, urine_output_stat)\n",
    "urine_output = rename_feature(urine_output, 'urine_output')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "urine_output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 3: kdigo_uo\n",
    "stay_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#cumulative_urine_output\n",
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.kdigo_uo\n",
    "where mimiciv_derived.kdigo_uo.stay_id in (select stay_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "kdigo_uo = pd.read_sql(sql, engine)\n",
    "kdigo_uo = kdigo_uo[kdigo_uo['stay_id'].isin(patient_base['stay_id'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "kdigo_uo.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "count_duplicated(kdigo_uo, subset=['stay_id', 'charttime'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "kdigo_uo = melt_pivot(kdigo_uo, id_name='stay_id', time_name='charttime')\n",
    "kdigo_uo, kdigo_uo_stat = generate_stat(kdigo_uo, group_name='item', save_path='../DataProcessing/benchmark_stat/kdigo_uo.csv', base=patient_base, merge_id='stay_id')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hourly time step relative to intime; add_timestep also restricts rows to its time window\n",
    "kdigo_uo = add_timestep(\n",
    "    table=kdigo_uo, base=patient_base, base_timename='intime', timename='charttime',\n",
    "    timestep_name='timestep', merge_id='stay_id',\n",
    "    sort_list=['subject_id', 'stay_id', 'charttime'],\n",
    ")\n",
    "\n",
    "# collapse to one mean value per (subject, stay, time step, item)\n",
    "kdigo_uo = (\n",
    "    kdigo_uo.reset_index(drop=True)\n",
    "    .drop(columns=['charttime', 'intime'])\n",
    "    .groupby(['subject_id', 'stay_id', 'timestep', 'item'], as_index=False)\n",
    "    .mean()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# filter out feature\n",
    "kdigo_uo = kdigo_uo[~kdigo_uo['item'].isin(['urineoutput_12hr', 'urineoutput_24hr', 'urineoutput_6hr', 'weight'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "kdigo_uo = remove_outliner(kdigo_uo, kdigo_uo_stat, ['uo_tm_12hr', 'uo_tm_24hr', 'uo_tm_6hr'])\n",
    "kdigo_uo = rename_feature(kdigo_uo, 'kdigo_uo')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "kdigo_uo"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 4: blood_differential\n",
    "hadm_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.blood_differential\n",
    "where mimiciv_derived.blood_differential.hadm_id in (select hadm_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "blood_differential = pd.read_sql(sql, engine)\n",
    "blood_differential = blood_differential[blood_differential['hadm_id'].isin(patient_base['hadm_id'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "count_duplicated(blood_differential, subset=['subject_id', 'hadm_id', 'charttime'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "blood_differential = melt_pivot(blood_differential.drop(columns=['subject_id', 'specimen_id']), id_name='hadm_id', time_name='charttime')\n",
    "blood_differential, blood_differential_stat = generate_stat(blood_differential, group_name='item', save_path='../DataProcessing/benchmark_stat/blood_differential.csv', base=patient_base, merge_id='hadm_id')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hourly time step relative to intime; add_timestep also restricts rows to its time window\n",
    "blood_differential = add_timestep(\n",
    "    table=blood_differential, base=patient_base, base_timename='intime', timename='charttime',\n",
    "    timestep_name='timestep', merge_id='hadm_id',\n",
    "    sort_list=['subject_id', 'hadm_id', 'charttime'],\n",
    ")\n",
    "\n",
    "# collapse to one mean value per (subject, admission, time step, item)\n",
    "blood_differential = (\n",
    "    blood_differential.reset_index(drop=True)\n",
    "    .drop(columns=['charttime', 'intime'])\n",
    "    .groupby(['subject_id', 'hadm_id', 'timestep', 'item'], as_index=False)\n",
    "    .mean()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "blood_differential = remove_outliner(blood_differential, blood_differential_stat)\n",
    "blood_differential = rename_feature(blood_differential, 'blood_differential')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "blood_differential"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 5: cardiac_marker\n",
    "hadm_id<br>\n",
    "remark: the table in the database differs from the one extracted directly with the script from the source code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.cardiac_marker\n",
    "where mimiciv_derived.cardiac_marker.hadm_id in (select hadm_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "cardiac_marker2 = pd.read_sql(sql, engine)\n",
    "cardiac_marker2 = cardiac_marker2[cardiac_marker2['hadm_id'].isin(patient_base['hadm_id'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cardiac_marker2['troponin_t'].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql = \"\"\"\n",
    "SELECT\n",
    "  MAX(subject_id) AS subject_id,\n",
    "  MAX(hadm_id) AS hadm_id,\n",
    "  MAX(charttime) AS charttime,\n",
    "  le.specimen_id, /* convert from itemid into a meaningful column */\n",
    "  MAX(CASE WHEN itemid = 51003 THEN valuenum ELSE NULL END) AS troponin_t,\n",
    "  MAX(CASE WHEN itemid = 50911 THEN valuenum ELSE NULL END) AS ck_mb,\n",
    "  MAX(CASE WHEN itemid = 50963 THEN valuenum ELSE NULL END) AS ntprobnp\n",
    "FROM mimiciv_hosp.labevents AS le\n",
    "WHERE\n",
    "  le.itemid IN (51003 /* 51002, -- Troponin I (troponin-I is not measured in MIMIC-IV) */ /* 52598, -- Troponin I, point of care, rare/poor quality */ /* Troponin T */, 50911 /* Creatinine Kinase, MB isoenzyme */, 50963 /* N-terminal (NT)-pro hormone BNP (NT-proBNP) */)\n",
    "  AND NOT valuenum IS NULL\n",
    "GROUP BY\n",
    "  le.specimen_id\n",
    "\"\"\"\n",
    "cardiac_marker = pd.read_sql(sql, engine)\n",
    "cardiac_marker = cardiac_marker[cardiac_marker['hadm_id'].isin(patient_base['hadm_id'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cardiac_marker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "count_duplicated(cardiac_marker, subset=['subject_id', 'hadm_id', 'charttime'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cardiac_marker = melt_pivot(cardiac_marker.drop(columns=['subject_id', 'specimen_id']), id_name='hadm_id', time_name='charttime')\n",
    "cardiac_marker, cardiac_marker_stat = generate_stat(cardiac_marker, group_name='item', save_path='../DataProcessing/benchmark_stat/cardiac_marker.csv', base=patient_base, merge_id='hadm_id')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hourly time step relative to intime; add_timestep also restricts rows to its time window\n",
    "cardiac_marker = add_timestep(\n",
    "    table=cardiac_marker, base=patient_base, base_timename='intime', timename='charttime',\n",
    "    timestep_name='timestep', merge_id='hadm_id',\n",
    "    sort_list=['subject_id', 'hadm_id', 'charttime'],\n",
    ")\n",
    "\n",
    "# collapse to one mean value per (subject, admission, time step, item)\n",
    "cardiac_marker = (\n",
    "    cardiac_marker.reset_index(drop=True)\n",
    "    .drop(columns=['charttime', 'intime'])\n",
    "    .groupby(['subject_id', 'hadm_id', 'timestep', 'item'], as_index=False)\n",
    "    .mean()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cardiac_marker = cardiac_marker[~cardiac_marker['item'].isin(['ck_mb'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cardiac_marker = remove_outliner(cardiac_marker, cardiac_marker_stat)\n",
    "cardiac_marker = rename_feature(cardiac_marker, 'cardiac_marker')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cardiac_marker"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 6: chemistry\n",
    "hadm_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.chemistry\n",
    "where mimiciv_derived.chemistry.hadm_id in (select hadm_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "chemistry = pd.read_sql(sql, engine)\n",
    "chemistry = chemistry[chemistry['hadm_id'].isin(patient_base['hadm_id'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "count_duplicated(chemistry, subset=['subject_id', 'hadm_id', 'charttime'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chemistry = melt_pivot(chemistry.drop(columns=['subject_id', 'specimen_id']), id_name='hadm_id', time_name='charttime')\n",
    "chemistry, chemistry_stat = generate_stat(chemistry, group_name='item', save_path='../DataProcessing/benchmark_stat/chemistry.csv', base=patient_base, merge_id='hadm_id')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hourly time step relative to intime; add_timestep also restricts rows to its time window\n",
    "chemistry = add_timestep(\n",
    "    table=chemistry, base=patient_base, base_timename='intime', timename='charttime',\n",
    "    timestep_name='timestep', merge_id='hadm_id',\n",
    "    sort_list=['subject_id', 'hadm_id', 'charttime'],\n",
    ")\n",
    "\n",
    "# collapse to one mean value per (subject, admission, time step, item)\n",
    "chemistry = (\n",
    "    chemistry.reset_index(drop=True)\n",
    "    .drop(columns=['charttime', 'intime'])\n",
    "    .groupby(['subject_id', 'hadm_id', 'timestep', 'item'], as_index=False)\n",
    "    .mean()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chemistry = chemistry[~chemistry['item'].isin(['globulin', 'total_protein'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chemistry = rename_feature(chemistry, 'chemistry')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chemistry"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 7: coagulation\n",
    "hadm_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.coagulation\n",
    "where mimiciv_derived.coagulation.hadm_id in (select hadm_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "coagulation = pd.read_sql(sql, engine)\n",
    "coagulation = coagulation[coagulation['hadm_id'].isin(patient_base['hadm_id'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "count_duplicated(coagulation, subset=['subject_id', 'hadm_id', 'charttime'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coagulation = melt_pivot(coagulation.drop(columns=['subject_id', 'specimen_id']), id_name='hadm_id', time_name='charttime')\n",
    "coagulation, coagulation_stat = generate_stat(coagulation, group_name='item', save_path='../DataProcessing/benchmark_stat/coagulation.csv', base=patient_base, merge_id='hadm_id')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hourly time step relative to intime; add_timestep also restricts rows to its time window\n",
    "coagulation = add_timestep(\n",
    "    table=coagulation, base=patient_base, base_timename='intime', timename='charttime',\n",
    "    timestep_name='timestep', merge_id='hadm_id',\n",
    "    sort_list=['subject_id', 'hadm_id', 'charttime'],\n",
    ")\n",
    "\n",
    "# collapse to one mean value per (subject, admission, time step, item)\n",
    "coagulation = (\n",
    "    coagulation.reset_index(drop=True)\n",
    "    .drop(columns=['charttime', 'intime'])\n",
    "    .groupby(['subject_id', 'hadm_id', 'timestep', 'item'], as_index=False)\n",
    "    .mean()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coagulation = coagulation[~coagulation['item'].isin(['d_dimer', 'thrombin'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coagulation = remove_outliner(coagulation, coagulation_stat)\n",
    "coagulation = rename_feature(coagulation, 'coagulation')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coagulation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 8: complete_blood_count\n",
    "hadm_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.complete_blood_count\n",
    "where mimiciv_derived.complete_blood_count.hadm_id in (select hadm_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "complete_blood_count = pd.read_sql(sql, engine)\n",
    "complete_blood_count = complete_blood_count[complete_blood_count['hadm_id'].isin(patient_base['hadm_id'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "count_duplicated(complete_blood_count, subset=['subject_id', 'hadm_id', 'charttime'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "complete_blood_count = melt_pivot(complete_blood_count.drop(columns=['subject_id', 'specimen_id']), id_name='hadm_id', time_name='charttime')\n",
    "complete_blood_count, complete_blood_count_stat = generate_stat(complete_blood_count, group_name='item', save_path='../DataProcessing/benchmark_stat/complete_blood_count.csv', base=patient_base, merge_id='hadm_id')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "complete_blood_count = add_timestep(table=complete_blood_count, base=patient_base, base_timename='intime',\n",
    "                      timename='charttime', timestep_name='timestep',\n",
    "                      merge_id='hadm_id', sort_list=['subject_id', 'hadm_id', 'charttime'])\n",
    "\n",
    "complete_blood_count = complete_blood_count.reset_index(drop=True)\n",
    "complete_blood_count = complete_blood_count.drop(['charttime', 'intime'], axis=1)\n",
    "\n",
    "complete_blood_count = complete_blood_count.groupby(['subject_id', 'hadm_id', 'timestep', 'item'], as_index=False).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "complete_blood_count = complete_blood_count[~complete_blood_count['item'].isin(['wbc'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "complete_blood_count = remove_outliner(complete_blood_count, complete_blood_count_stat)\n",
    "complete_blood_count = rename_feature(complete_blood_count, 'complete_blood_count')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "complete_blood_count"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 9: enzyme\n",
    "hadm_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.enzyme\n",
    "where mimiciv_derived.enzyme.hadm_id in (select hadm_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "enzyme = pd.read_sql(sql, engine)\n",
    "enzyme = enzyme[enzyme['hadm_id'].isin(patient_base['hadm_id'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "count_duplicated(enzyme, subset=['subject_id', 'hadm_id', 'charttime'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "enzyme = melt_pivot(enzyme.drop(columns=['subject_id', 'specimen_id']), id_name='hadm_id', time_name='charttime')\n",
    "enzyme, enzyme_stat = generate_stat(enzyme, group_name='item', save_path='../DataProcessing/benchmark_stat/enzyme.csv', base=patient_base, merge_id='hadm_id')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hourly time step relative to intime; add_timestep also restricts rows to its time window\n",
    "enzyme = add_timestep(\n",
    "    table=enzyme, base=patient_base, base_timename='intime', timename='charttime',\n",
    "    timestep_name='timestep', merge_id='hadm_id',\n",
    "    sort_list=['subject_id', 'hadm_id', 'charttime'],\n",
    ")\n",
    "\n",
    "# collapse to one mean value per (subject, admission, time step, item)\n",
    "enzyme = (\n",
    "    enzyme.reset_index(drop=True)\n",
    "    .drop(columns=['charttime', 'intime'])\n",
    "    .groupby(['subject_id', 'hadm_id', 'timestep', 'item'], as_index=False)\n",
    "    .mean()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "enzyme = enzyme[~enzyme['item'].isin(['ggt'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "enzyme = remove_outliner(enzyme, enzyme_stat)\n",
    "enzyme = rename_feature(enzyme, 'enzyme')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "enzyme"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 10: gcs\n",
    "stay_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.gcs\n",
    "where mimiciv_derived.gcs.stay_id in (select stay_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "gcs = pd.read_sql(sql, engine)\n",
    "gcs = gcs[gcs['stay_id'].isin(patient_base['stay_id'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "count_duplicated(gcs, subset=['subject_id', 'stay_id', 'charttime'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gcs = melt_pivot(gcs.drop(columns=['subject_id']), id_name='stay_id', time_name='charttime')\n",
    "gcs, gcs_stat = generate_stat(gcs, group_name='item', save_path='../DataProcessing/benchmark_stat/gcs.csv', base=patient_base, merge_id='stay_id')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# hourly time step relative to intime; add_timestep also restricts rows to its time window\n",
    "gcs = add_timestep(\n",
    "    table=gcs, base=patient_base, base_timename='intime', timename='charttime',\n",
    "    timestep_name='timestep', merge_id='stay_id',\n",
    "    sort_list=['subject_id', 'stay_id', 'charttime'],\n",
    ")\n",
    "\n",
    "# collapse to one mean value per (subject, stay, time step, item)\n",
    "gcs = (\n",
    "    gcs.reset_index(drop=True)\n",
    "    .drop(columns=['charttime', 'intime'])\n",
    "    .groupby(['subject_id', 'stay_id', 'timestep', 'item'], as_index=False)\n",
    "    .mean()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exclude the 'gcs_unable' item, then pass the table through rename_feature with tag 'gcs'.\n",
    "gcs = rename_feature(\n",
    "    gcs.loc[~gcs['item'].isin(['gcs_unable'])],\n",
    "    'gcs',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "gcs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 11: icp\n",
    "Join key: `stay_id`<br>\n",
    "Remark: this measurement is very rare."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the derived ICP table, then keep only rows whose stay_id appears in patient_base.\n",
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.icp\n",
    "where mimiciv_derived.icp.stay_id in (select stay_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "icp = pd.read_sql(sql, engine).loc[\n",
    "    lambda frame: frame['stay_id'].isin(patient_base['stay_id'])\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop rows without an ICP reading and exact duplicates, then move to the\n",
    "# item/value layout: constant item label 'icp', measurement column renamed to 'value'.\n",
    "icp = (\n",
    "    icp\n",
    "    .dropna(subset=['icp'])\n",
    "    .drop_duplicates()\n",
    "    .assign(item='icp')\n",
    "    .rename(columns={'icp': 'value'})\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "icp['subject_id'].nunique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# add_timestep yields the 'timestep' column (charttime against patient_base 'intime').\n",
    "icp = add_timestep(\n",
    "    table=icp,\n",
    "    base=patient_base,\n",
    "    base_timename='intime',\n",
    "    timename='charttime',\n",
    "    timestep_name='timestep',\n",
    "    merge_id='stay_id',\n",
    "    sort_list=['subject_id', 'stay_id', 'charttime'],\n",
    ")\n",
    "\n",
    "# The raw time columns are dropped now that the timestep column exists.\n",
    "icp = icp.reset_index(drop=True).drop(columns=['charttime', 'intime'])\n",
    "\n",
    "# Average rows that share the same subject, stay, timestep and item.\n",
    "icp = icp.groupby(['subject_id', 'stay_id', 'timestep', 'item'], as_index=False).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "icp, icp_stat = generate_stat(icp, group_name='item', save_path='../DataProcessing/benchmark_stat/icp.csv', base=patient_base, merge_id='stay_id')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "icp = rename_feature(icp, 'icp')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "icp"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 12: vasoactive_agent\n",
    "Note: this table is processed here but is not going to be used downstream."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the derived vasoactive_agent table, then keep only rows whose stay_id appears in patient_base.\n",
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.vasoactive_agent\n",
    "where mimiciv_derived.vasoactive_agent.stay_id in (select stay_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "vasoactive_agent = pd.read_sql(sql, engine).loc[\n",
    "    lambda frame: frame['stay_id'].isin(patient_base['stay_id'])\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "count_duplicated(vasoactive_agent, subset=['stay_id', 'starttime', 'endtime'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vasoactive_agent = melt_pivot(vasoactive_agent, id_name='stay_id', time_name=['starttime', 'endtime'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Derive start/end timestep columns from starttime/endtime (base time column: 'intime').\n",
    "vasoactive_agent = add_timestep(\n",
    "    table=vasoactive_agent,\n",
    "    base=patient_base,\n",
    "    base_timename='intime',\n",
    "    timename='starttime',\n",
    "    timestep_name='start_time_step',\n",
    "    merge_id='stay_id',\n",
    "    sort_list=['stay_id', 'starttime'],\n",
    ")\n",
    "vasoactive_agent = add_timestep(\n",
    "    table=vasoactive_agent,\n",
    "    base=patient_base,\n",
    "    base_timename='intime',\n",
    "    timename='endtime',\n",
    "    timestep_name='end_time_step',\n",
    "    merge_id='stay_id',\n",
    "    sort_list=['stay_id', 'starttime'],\n",
    ")\n",
    "\n",
    "# repeat_rows presumably expands each interval into one row per timestep - confirm in helper.\n",
    "vasoactive_agent = (\n",
    "    repeat_rows(vasoactive_agent)\n",
    "    .reset_index(drop=True)\n",
    "    .drop(columns=['start_time_step', 'end_time_step', 'intime'])\n",
    ")\n",
    "\n",
    "vasoactive_agent, vasoactive_agent_stat = generate_stat(\n",
    "    vasoactive_agent,\n",
    "    group_name='item',\n",
    "    save_path='../DataProcessing/benchmark_stat/vasoactive_agent.csv',\n",
    "    base=patient_base,\n",
    "    merge_id='stay_id',\n",
    ")\n",
    "\n",
    "# Average rows that share the same subject, stay, timestep and item.\n",
    "vasoactive_agent = vasoactive_agent.groupby(['subject_id', 'stay_id', 'timestep', 'item'], as_index=False).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vasoactive_agent"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 13: inflammation\n",
    "Join key: `hadm_id`<br>\n",
    "Remark: this measurement is very rare."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the derived inflammation table, then keep only rows whose hadm_id appears in patient_base.\n",
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.inflammation\n",
    "where mimiciv_derived.inflammation.hadm_id in (select hadm_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "inflammation = pd.read_sql(sql, engine).loc[\n",
    "    lambda frame: frame['hadm_id'].isin(patient_base['hadm_id'])\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "count_duplicated(inflammation, subset=['hadm_id','charttime'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only rows with a CRP value and move to the item/value layout:\n",
    "# specimen_id removed, constant item label 'crp', measurement column renamed to 'value'.\n",
    "# Written as one chain of non-mutating calls so nothing is dropped in place from,\n",
    "# or assigned into, a boolean-filtered slice (avoids SettingWithCopyWarning).\n",
    "inflammation = (\n",
    "    inflammation\n",
    "    .dropna(subset=['crp'])\n",
    "    .drop(columns=['specimen_id'])\n",
    "    .assign(item='crp')\n",
    "    .rename(columns={'crp': 'value'})\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# add_timestep yields the 'timestep' column; this table is joined on hadm_id, not stay_id.\n",
    "inflammation = add_timestep(\n",
    "    table=inflammation,\n",
    "    base=patient_base,\n",
    "    base_timename='intime',\n",
    "    timename='charttime',\n",
    "    timestep_name='timestep',\n",
    "    merge_id='hadm_id',\n",
    "    sort_list=['subject_id', 'hadm_id', 'charttime'],\n",
    ")\n",
    "\n",
    "# The raw time columns are dropped now that the timestep column exists.\n",
    "inflammation = inflammation.reset_index(drop=True).drop(columns=['charttime', 'intime'])\n",
    "\n",
    "# Average rows that share the same subject, admission, timestep and item.\n",
    "inflammation = inflammation.groupby(['subject_id', 'hadm_id', 'timestep', 'item'], as_index=False).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inflammation, inflammation_stat = generate_stat(inflammation, group_name='item', save_path='../DataProcessing/benchmark_stat/inflammation.csv', base=patient_base, merge_id='hadm_id')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Outlier removal (driven by inflammation_stat) followed by rename_feature with tag 'inflammation'.\n",
    "inflammation = rename_feature(\n",
    "    remove_outliner(inflammation, inflammation_stat),\n",
    "    'inflammation',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inflammation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 14: oxygen_delivery\n",
    "Join key: `stay_id`<br>\n",
    "Only the `o2_flow` item is used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the derived oxygen_delivery table, then keep only rows whose stay_id appears in patient_base.\n",
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.oxygen_delivery\n",
    "where mimiciv_derived.oxygen_delivery.stay_id in (select stay_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "oxygen_delivery = pd.read_sql(sql, engine).loc[\n",
    "    lambda frame: frame['stay_id'].isin(patient_base['stay_id'])\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "oxygen_delivery = melt_pivot(oxygen_delivery, id_name=['subject_id', 'stay_id'], time_name='charttime')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# add_timestep yields the 'timestep' column (charttime against patient_base 'intime').\n",
    "oxygen_delivery = add_timestep(\n",
    "    table=oxygen_delivery,\n",
    "    base=patient_base,\n",
    "    base_timename='intime',\n",
    "    timename='charttime',\n",
    "    timestep_name='timestep',\n",
    "    merge_id='stay_id',\n",
    "    sort_list=['subject_id', 'stay_id', 'charttime'],\n",
    ")\n",
    "\n",
    "# The raw time columns are dropped now that the timestep column exists.\n",
    "oxygen_delivery = oxygen_delivery.reset_index(drop=True).drop(columns=['charttime', 'intime'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# o2_flow rows are averaged per subject/stay/timestep/item.\n",
    "oxygen_delivery_o2_flow = (\n",
    "    oxygen_delivery.loc[oxygen_delivery['item'] == 'o2_flow']\n",
    "    .groupby(['subject_id', 'stay_id', 'timestep', 'item'], as_index=False)\n",
    "    .mean()\n",
    ")\n",
    "\n",
    "# Every other item keeps only its first row per subject/stay/timestep/item.\n",
    "oxygen_delivery_others = (\n",
    "    oxygen_delivery.loc[oxygen_delivery['item'] != 'o2_flow']\n",
    "    .drop_duplicates(subset=['subject_id', 'stay_id', 'timestep', 'item'], keep='first')\n",
    ")\n",
    "\n",
    "# Recombine both parts into a single table with a fresh index.\n",
    "oxygen_delivery = pd.concat(\n",
    "    [oxygen_delivery_o2_flow, oxygen_delivery_others],\n",
    "    ignore_index=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Numeric stats for o2_flow, categorical stats plus a value dictionary for the other items.\n",
    "oxygen_delivery_o2_flow, oxygen_delivery_o2_flow_stat = generate_stat(\n",
    "    oxygen_delivery_o2_flow,\n",
    "    group_name='item',\n",
    "    save_path='../DataProcessing/benchmark_stat/oxygen_delivery_o2_flow.csv',\n",
    "    base=patient_base,\n",
    "    merge_id='stay_id',\n",
    ")\n",
    "oxygen_delivery_others = generate_stat_cate(\n",
    "    oxygen_delivery_others,\n",
    "    group_name='item',\n",
    "    save_path='../DataProcessing/benchmark_stat/oxygen_delivery_others.csv',\n",
    "    base=patient_base,\n",
    "    merge_id='stay_id',\n",
    ")\n",
    "generate_dict(\n",
    "    oxygen_delivery_others,\n",
    "    group_name=['item', 'value'],\n",
    "    key_name='item',\n",
    "    save_path='../DataProcessing/benchmark_stat/oxygen_delivery_others.json',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Order the combined table by subject, stay, timestep and item, then renumber the index.\n",
    "oxygen_delivery = (\n",
    "    oxygen_delivery\n",
    "    .sort_values(by=['subject_id', 'stay_id', 'timestep', 'item'])\n",
    "    .reset_index(drop=True)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Outlier removal (driven by oxygen_delivery_o2_flow_stat) followed by rename_feature with tag 'oxygen_delivery'.\n",
    "oxygen_delivery_o2_flow = rename_feature(\n",
    "    remove_outliner(oxygen_delivery_o2_flow, oxygen_delivery_o2_flow_stat),\n",
    "    'oxygen_delivery',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "oxygen_delivery_o2_flow"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 15: rhythm\n",
    "stay_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per (subject_id, stay_id, charttime) with the five rhythm/ectopy chart items\n",
    "# pivoted into columns; afterwards only stays present in patient_base are kept.\n",
    "sql = \"\"\"\n",
    "SELECT\n",
    "  ce.subject_id,\n",
    "  ce.stay_id,\n",
    "  ce.charttime,\n",
    "  MAX(CASE WHEN itemid = 220048 THEN value ELSE NULL END) AS heart_rhythm,\n",
    "  MAX(CASE WHEN itemid = 224650 THEN value ELSE NULL END) AS ectopy_type,\n",
    "  MAX(CASE WHEN itemid = 224651 THEN value ELSE NULL END) AS ectopy_frequency,\n",
    "  MAX(CASE WHEN itemid = 226479 THEN value ELSE NULL END) AS ectopy_type_secondary,\n",
    "  MAX(CASE WHEN itemid = 226480 THEN value ELSE NULL END) AS ectopy_frequency_secondary\n",
    "FROM mimiciv_icu.chartevents AS ce\n",
    "WHERE\n",
    "  NOT ce.stay_id IS NULL\n",
    "  AND ce.itemid IN (220048 /* Heart Rhythm */, 224650 /* Ectopy Type 1 */, 224651 /* Ectopy Frequency 1 */, 226479 /* Ectopy Type 2 */, 226480 /* Ectopy Frequency 2 */)\n",
    "GROUP BY\n",
    "  ce.subject_id,\n",
    "  ce.stay_id,\n",
    "  ce.charttime\n",
    "\"\"\"\n",
    "rhythm = pd.read_sql(sql, engine).loc[\n",
    "    lambda frame: frame['stay_id'].isin(patient_base['stay_id'])\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rhythm = melt_pivot(\n",
    "    rhythm,\n",
    "    id_name=['subject_id', 'stay_id'],\n",
    "    time_name='charttime',\n",
    ")\n",
    "\n",
    "# add_timestep yields the 'timestep' column (charttime against patient_base 'intime').\n",
    "rhythm = add_timestep(\n",
    "    table=rhythm,\n",
    "    base=patient_base,\n",
    "    base_timename='intime',\n",
    "    timename='charttime',\n",
    "    timestep_name='timestep',\n",
    "    merge_id='stay_id',\n",
    "    sort_list=['subject_id', 'stay_id', 'charttime'],\n",
    ")\n",
    "\n",
    "# The raw time columns are dropped now that the timestep column exists.\n",
    "rhythm = rhythm.reset_index(drop=True).drop(columns=['charttime', 'intime'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Categorical stats for the rhythm items, plus a dictionary saved as JSON.\n",
    "rhythm = generate_stat_cate(\n",
    "    rhythm,\n",
    "    group_name='item',\n",
    "    save_path='../DataProcessing/benchmark_stat/rhythm.csv',\n",
    "    base=patient_base,\n",
    "    merge_id='stay_id',\n",
    ")\n",
    "generate_dict(\n",
    "    rhythm,\n",
    "    group_name=['item', 'value'],\n",
    "    key_name='item',\n",
    "    save_path='../DataProcessing/benchmark_stat/rhythm.json',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rhythm = turn_binary(rhythm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sort by subject, stay and timestep, drop exact duplicate rows, then rename_feature with tag 'rhythm'.\n",
    "rhythm = rename_feature(\n",
    "    rhythm.sort_values(by=['subject_id', 'stay_id', 'timestep']).drop_duplicates(),\n",
    "    'rhythm',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rhythm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 16: vitalsign\n",
    "stay_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the selected vital-sign columns, then keep only rows whose stay_id appears in patient_base.\n",
    "sql = \"\"\"\n",
    "select subject_id, stay_id, charttime, heart_rate, sbp, dbp, mbp, resp_rate, temperature, spo2, glucose\n",
    "from mimiciv_derived.vitalsign\n",
    "where mimiciv_derived.vitalsign.stay_id in (select stay_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "vitalsign = pd.read_sql(sql, engine).loc[\n",
    "    lambda frame: frame['stay_id'].isin(patient_base['stay_id'])\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "count_duplicated(vitalsign, subset=['stay_id','charttime'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# melt_pivot keyed by subject_id/stay_id and charttime; generate_stat also writes vitalsign.csv.\n",
    "vitalsign = melt_pivot(\n",
    "    vitalsign,\n",
    "    id_name=['subject_id', 'stay_id'],\n",
    "    time_name='charttime',\n",
    ")\n",
    "vitalsign, vitalsign_stat = generate_stat(\n",
    "    vitalsign,\n",
    "    group_name='item',\n",
    "    save_path='../DataProcessing/benchmark_stat/vitalsign.csv',\n",
    "    base=patient_base,\n",
    "    merge_id='stay_id',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# add_timestep yields the 'timestep' column (charttime against patient_base 'intime').\n",
    "vitalsign = add_timestep(\n",
    "    table=vitalsign,\n",
    "    base=patient_base,\n",
    "    base_timename='intime',\n",
    "    timename='charttime',\n",
    "    timestep_name='timestep',\n",
    "    merge_id='stay_id',\n",
    "    sort_list=['subject_id', 'stay_id', 'charttime'],\n",
    ")\n",
    "\n",
    "# The raw time columns are dropped now that the timestep column exists.\n",
    "vitalsign = vitalsign.reset_index(drop=True).drop(columns=['charttime', 'intime'])\n",
    "\n",
    "# Average rows that share the same subject, stay, timestep and item.\n",
    "vitalsign = vitalsign.groupby(['subject_id', 'stay_id', 'timestep', 'item'], as_index=False).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vitalsign = rename_feature(vitalsign, 'vitalsign')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vitalsign"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 17: weight_durations\n",
    "stay_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the derived weight_durations table, then keep only rows whose stay_id appears in patient_base.\n",
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.weight_durations\n",
    "where mimiciv_derived.weight_durations.stay_id in (select stay_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "weight_durations = pd.read_sql(sql, engine).loc[\n",
    "    lambda frame: frame['stay_id'].isin(patient_base['stay_id'])\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spread weight_type into columns: one 'weight_<weight_type>' column per type,\n",
    "# keyed by (stay_id, starttime, endtime). aggfunc='first' keeps the first weight\n",
    "# when the same key and type occur more than once.\n",
    "weight_durations = pd.pivot_table(\n",
    "    weight_durations,\n",
    "    index=['stay_id','starttime', 'endtime'],\n",
    "    columns='weight_type',\n",
    "    values=['weight'],\n",
    "    aggfunc='first',\n",
    ")\n",
    "# Flatten the two-level column index into single underscore-joined names.\n",
    "weight_durations.columns = [\n",
    "    '_'.join(level_pair).strip()\n",
    "    for level_pair in weight_durations.columns.values\n",
    "]\n",
    "weight_durations = weight_durations.reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "weight_durations = melt_pivot(weight_durations, id_name='stay_id', time_name=['starttime', 'endtime'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "weight = weight_durations[weight_durations['item']=='weight_daily']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "weight = weight[~(weight['starttime']>weight['endtime'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Derive start/end timestep columns from starttime/endtime (base time column: 'intime').\n",
    "weight = add_timestep(\n",
    "    table=weight,\n",
    "    base=patient_base,\n",
    "    base_timename='intime',\n",
    "    timename='starttime',\n",
    "    timestep_name='start_time_step',\n",
    "    merge_id='stay_id',\n",
    "    sort_list=['stay_id', 'starttime'],\n",
    ")\n",
    "weight = add_timestep(\n",
    "    table=weight,\n",
    "    base=patient_base,\n",
    "    base_timename='intime',\n",
    "    timename='endtime',\n",
    "    timestep_name='end_time_step',\n",
    "    merge_id='stay_id',\n",
    "    sort_list=['stay_id', 'starttime'],\n",
    ")\n",
    "\n",
    "# repeat_rows presumably expands each interval into one row per timestep - confirm in helper.\n",
    "weight = (\n",
    "    repeat_rows(weight)\n",
    "    .reset_index(drop=True)\n",
    "    .drop(columns=['start_time_step', 'end_time_step', 'intime'])\n",
    ")\n",
    "\n",
    "weight, weight_stat = generate_stat(\n",
    "    weight,\n",
    "    group_name='item',\n",
    "    save_path='../DataProcessing/benchmark_stat/weight.csv',\n",
    "    base=patient_base,\n",
    "    merge_id='stay_id',\n",
    ")\n",
    "\n",
    "# Average rows that share the same subject, stay, timestep and item.\n",
    "weight = weight.groupby(['subject_id', 'stay_id', 'timestep', 'item'], as_index=False).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "weight = remove_outliner(weight, weight_stat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "weight"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 18: kdigo_creatinine\n",
    "stay_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the derived kdigo_creatinine table, then keep only rows whose stay_id appears in patient_base.\n",
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.kdigo_creatinine\n",
    "where mimiciv_derived.kdigo_creatinine.stay_id in (select stay_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "kdigo_creatinine = pd.read_sql(sql, engine).loc[\n",
    "    lambda frame: frame['stay_id'].isin(patient_base['stay_id'])\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# melt_pivot keyed by hadm_id/stay_id and charttime; generate_stat also writes kdigo_creatinine.csv.\n",
    "kdigo_creatinine = melt_pivot(\n",
    "    kdigo_creatinine,\n",
    "    id_name=['hadm_id', 'stay_id'],\n",
    "    time_name='charttime',\n",
    ")\n",
    "kdigo_creatinine, kdigo_creatinine_stat = generate_stat(\n",
    "    kdigo_creatinine,\n",
    "    group_name='item',\n",
    "    save_path='../DataProcessing/benchmark_stat/kdigo_creatinine.csv',\n",
    "    base=patient_base,\n",
    "    merge_id='stay_id',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# add_timestep yields the 'timestep' column (charttime against patient_base 'intime').\n",
    "kdigo_creatinine = add_timestep(\n",
    "    table=kdigo_creatinine,\n",
    "    base=patient_base,\n",
    "    base_timename='intime',\n",
    "    timename='charttime',\n",
    "    timestep_name='timestep',\n",
    "    merge_id='stay_id',\n",
    "    sort_list=['subject_id', 'stay_id', 'hadm_id', 'charttime'],\n",
    ")\n",
    "\n",
    "# The raw time columns and hadm_id are dropped before aggregating.\n",
    "kdigo_creatinine = kdigo_creatinine.reset_index(drop=True).drop(columns=['charttime', 'intime', 'hadm_id'])\n",
    "\n",
    "# Average rows that share the same subject, stay, timestep and item.\n",
    "kdigo_creatinine = kdigo_creatinine.groupby(['subject_id', 'stay_id', 'timestep', 'item'], as_index=False).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exclude the 'creat' item, then pass the table through rename_feature with tag 'kdigo_creatinine'.\n",
    "kdigo_creatinine = rename_feature(\n",
    "    kdigo_creatinine.loc[~kdigo_creatinine['item'].isin(['creat'])],\n",
    "    'kdigo_creatinine',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "kdigo_creatinine"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 19: norepinephrine_equivalent_dose\n",
    "stay_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the derived norepinephrine_equivalent_dose table, then keep only rows whose stay_id appears in patient_base.\n",
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.norepinephrine_equivalent_dose\n",
    "where mimiciv_derived.norepinephrine_equivalent_dose.stay_id in (select stay_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "norepinephrine_equivalent_dose = pd.read_sql(sql, engine).loc[\n",
    "    lambda frame: frame['stay_id'].isin(patient_base['stay_id'])\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "norepinephrine_equivalent_dose = melt_pivot(norepinephrine_equivalent_dose, id_name='stay_id', time_name=['starttime', 'endtime'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Derive start/end timestep columns from starttime/endtime (base time column: 'intime').\n",
    "norepinephrine_equivalent_dose = add_timestep(\n",
    "    table=norepinephrine_equivalent_dose,\n",
    "    base=patient_base,\n",
    "    base_timename='intime',\n",
    "    timename='starttime',\n",
    "    timestep_name='start_time_step',\n",
    "    merge_id='stay_id',\n",
    "    sort_list=['stay_id', 'starttime'],\n",
    ")\n",
    "norepinephrine_equivalent_dose = add_timestep(\n",
    "    table=norepinephrine_equivalent_dose,\n",
    "    base=patient_base,\n",
    "    base_timename='intime',\n",
    "    timename='endtime',\n",
    "    timestep_name='end_time_step',\n",
    "    merge_id='stay_id',\n",
    "    sort_list=['stay_id', 'starttime'],\n",
    ")\n",
    "\n",
    "# repeat_rows presumably expands each interval into one row per timestep - confirm in helper.\n",
    "norepinephrine_equivalent_dose = (\n",
    "    repeat_rows(norepinephrine_equivalent_dose)\n",
    "    .reset_index(drop=True)\n",
    "    .drop(columns=['start_time_step', 'end_time_step', 'intime'])\n",
    ")\n",
    "\n",
    "norepinephrine_equivalent_dose, norepinephrine_equivalent_dose_stat = generate_stat(\n",
    "    norepinephrine_equivalent_dose,\n",
    "    group_name='item',\n",
    "    save_path='../DataProcessing/benchmark_stat/norepinephrine_equivalent_dose.csv',\n",
    "    base=patient_base,\n",
    "    merge_id='stay_id',\n",
    ")\n",
    "\n",
    "# Average rows that share the same subject, stay, timestep and item.\n",
    "norepinephrine_equivalent_dose = norepinephrine_equivalent_dose.groupby(['subject_id', 'stay_id', 'timestep', 'item'], as_index=False).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "norepinephrine_equivalent_dose"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 20: ventilator_setting\n",
    "Join key: `stay_id`<br>\n",
    "Only the continuous features are used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the derived ventilator_setting table, then keep only rows whose stay_id appears in patient_base.\n",
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.ventilator_setting\n",
    "where mimiciv_derived.ventilator_setting.stay_id in (select stay_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "ventilator_setting = pd.read_sql(sql, engine).loc[\n",
    "    lambda frame: frame['stay_id'].isin(patient_base['stay_id'])\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "count_duplicated(ventilator_setting, subset=['subject_id', 'stay_id', 'charttime'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ventilator_setting = melt_pivot(\n",
    "    ventilator_setting,\n",
    "    id_name=['subject_id', 'stay_id'],\n",
    "    time_name='charttime',\n",
    ")\n",
    "\n",
    "# add_timestep yields the 'timestep' column (charttime against patient_base 'intime').\n",
    "ventilator_setting = add_timestep(\n",
    "    table=ventilator_setting,\n",
    "    base=patient_base,\n",
    "    base_timename='intime',\n",
    "    timename='charttime',\n",
    "    timestep_name='timestep',\n",
    "    merge_id='stay_id',\n",
    "    sort_list=['subject_id', 'stay_id', 'charttime'],\n",
    ")\n",
    "\n",
    "# The raw time columns are dropped now that the timestep column exists.\n",
    "ventilator_setting = ventilator_setting.reset_index(drop=True).drop(columns=['charttime', 'intime'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Continuous ventilator settings: average repeated readings per subject/stay/timestep/item.\n",
    "# Fix: the derived ventilator_setting column is named 'tidal_volume_observed'. The old entry\n",
    "# 'Tidal Volume (observed)' is the chart label, so observed tidal volume was presumably\n",
    "# being dropped here. The label stays in the list only for backward compatibility.\n",
    "ventilator_setting_cont = ventilator_setting[ventilator_setting['item'].isin(\n",
    "    ['respiratory_rate_set',\n",
    "     'respiratory_rate_total',\n",
    "     'respiratory_rate_spontaneous',\n",
    "     'minute_volume',\n",
    "     'tidal_volume_set',\n",
    "     'tidal_volume_observed',\n",
    "     'Tidal Volume (observed)',\n",
    "     'tidal_volume_spontaneous',\n",
    "     'plateau_pressure',\n",
    "     'peep',\n",
    "     'fio2',\n",
    "     'flow_rate'])].groupby(['subject_id', 'stay_id', 'timestep', 'item'], as_index=False).mean()\n",
    "\n",
    "# Categorical ventilator settings: keep the first value per subject/stay/timestep/item.\n",
    "ventilator_setting_cate = ventilator_setting[ventilator_setting['item'].isin(\n",
    "    ['ventilator_mode', 'ventilator_mode_hamilton', 'ventilator_type']\n",
    ")].drop_duplicates(subset=['subject_id', 'stay_id', 'timestep', 'item'], keep='first')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Numeric stats for the continuous items, categorical stats plus a value dictionary for the rest.\n",
    "ventilator_setting_cont, ventilator_setting_cont_stat = generate_stat(\n",
    "    ventilator_setting_cont,\n",
    "    group_name='item',\n",
    "    save_path='../DataProcessing/benchmark_stat/ventilator_setting_cont.csv',\n",
    "    base=patient_base,\n",
    "    merge_id='stay_id',\n",
    ")\n",
    "ventilator_setting_cate = generate_stat_cate(\n",
    "    ventilator_setting_cate,\n",
    "    group_name='item',\n",
    "    save_path='../DataProcessing/benchmark_stat/ventilator_setting_cate.csv',\n",
    "    base=patient_base,\n",
    "    merge_id='stay_id',\n",
    ")\n",
    "generate_dict(\n",
    "    ventilator_setting_cate,\n",
    "    group_name=['item', 'value'],\n",
    "    key_name='item',\n",
    "    save_path='../DataProcessing/benchmark_stat/ventilator_setting_cate.json',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# remove_outliner gets ['fio2', 'peep'] as its third argument (presumably items handled\n",
    "# specially - confirm in helper); the result goes through rename_feature with tag 'ventilator_setting'.\n",
    "ventilator_setting_cont = rename_feature(\n",
    "    remove_outliner(ventilator_setting_cont, ventilator_setting_cont_stat, ['fio2', 'peep']),\n",
    "    'ventilator_setting',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ventilator_setting_cont"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 21: crrt\n",
    "stay_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the derived CRRT table, then keep only rows whose stay_id appears in patient_base.\n",
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.crrt\n",
    "where mimiciv_derived.crrt.stay_id in (select stay_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "crrt = pd.read_sql(sql, engine).loc[\n",
    "    lambda frame: frame['stay_id'].isin(patient_base['stay_id'])\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dialysate_fluid, heparin_concentration, replacement_fluid, system_active, clots, clots_increasing, clotted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "crrt = melt_pivot(\n",
    "    crrt,\n",
    "    id_name=['stay_id'],\n",
    "    time_name='charttime',\n",
    ")\n",
    "\n",
    "# add_timestep yields the 'timestep' column (charttime against patient_base 'intime').\n",
    "crrt = add_timestep(\n",
    "    table=crrt,\n",
    "    base=patient_base,\n",
    "    base_timename='intime',\n",
    "    timename='charttime',\n",
    "    timestep_name='timestep',\n",
    "    merge_id='stay_id',\n",
    "    sort_list=['stay_id', 'charttime'],\n",
    ")\n",
    "\n",
    "# The raw time columns are dropped now that the timestep column exists.\n",
    "crrt = crrt.reset_index(drop=True).drop(columns=['charttime', 'intime'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Items averaged per stay/timestep/item (the clot and system_active items are averaged too).\n",
    "crrt_cont = (\n",
    "    crrt.loc[crrt['item'].isin([\n",
    "        'access_pressure',\n",
    "        'blood_flow',\n",
    "        'citrate',\n",
    "        'current_goal',\n",
    "        'filter_pressure',\n",
    "        'dialysate_rate',\n",
    "        'effluent_pressure',\n",
    "        'heparin_dose',\n",
    "        'hourly_patient_fluid_removal',\n",
    "        'prefilter_replacement_rate',\n",
    "        'postfilter_replacement_rate',\n",
    "        'replacement_rate',\n",
    "        'return_pressure',\n",
    "        'ultrafiltrate_output',\n",
    "        'clots',\n",
    "        'clots_increasing',\n",
    "        'clotted',\n",
    "        'system_active',\n",
    "    ])]\n",
    "    .groupby(['stay_id', 'timestep', 'item'], as_index=False)\n",
    "    .mean()\n",
    ")\n",
    "\n",
    "# Items treated as categorical: keep the first value per stay/timestep/item.\n",
    "crrt_cate = (\n",
    "    crrt.loc[crrt['item'].isin([\n",
    "        'crrt_mode',\n",
    "        'dialysate_fluid',\n",
    "        'heparin_concentration',\n",
    "        'replacement_fluid',\n",
    "    ])]\n",
    "    .drop_duplicates(subset=['stay_id', 'timestep', 'item'], keep='first')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Numeric stats for the continuous items, categorical stats plus a value dictionary for the rest.\n",
    "crrt_cont, crrt_cont_stat = generate_stat(\n",
    "    crrt_cont,\n",
    "    group_name='item',\n",
    "    save_path='../DataProcessing/benchmark_stat/crrt_cont.csv',\n",
    "    base=patient_base,\n",
    "    merge_id='stay_id',\n",
    ")\n",
    "crrt_cate = generate_stat_cate(\n",
    "    crrt_cate,\n",
    "    group_name='item',\n",
    "    save_path='../DataProcessing/benchmark_stat/crrt_cate.csv',\n",
    "    base=patient_base,\n",
    "    merge_id='stay_id',\n",
    ")\n",
    "generate_dict(\n",
    "    crrt_cate,\n",
    "    group_name=['item', 'value'],\n",
    "    key_name='item',\n",
    "    save_path='../DataProcessing/benchmark_stat/crrt_cate.json',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "crrt_cate = turn_binary(crrt_cate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# remove_outliner gets the four flag-like items as its third argument (presumably items\n",
    "# handled specially - confirm in helper); the result goes through rename_feature with tag 'crrt'.\n",
    "crrt_cont = rename_feature(\n",
    "    remove_outliner(\n",
    "        crrt_cont,\n",
    "        crrt_cont_stat,\n",
    "        ['clots', 'clots_increasing', 'system_active', 'clotted'],\n",
    "    ),\n",
    "    'crrt',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "crrt_cont"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sort by stay, subject and timestep, drop exact duplicate rows, rename_feature with tag 'crrt', then display.\n",
    "crrt_cate = rename_feature(\n",
    "    crrt_cate.sort_values(by=['stay_id', 'subject_id', 'timestep']).drop_duplicates(),\n",
    "    'crrt',\n",
    ")\n",
    "crrt_cate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 22: invasive_line\n",
    "Note: this table is processed here but is not going to be used downstream."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the derived invasive_line table, then keep only rows whose stay_id appears in patient_base.\n",
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.invasive_line\n",
    "where mimiciv_derived.invasive_line.stay_id in (select stay_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "invasive_line = pd.read_sql(sql, engine).loc[\n",
    "    lambda frame: frame['stay_id'].isin(patient_base['stay_id'])\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "count_duplicated(invasive_line, subset=['stay_id', 'line_site','line_type', 'starttime'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "invasive_line = melt_pivot(invasive_line, id_name='stay_id', time_name=['starttime', 'endtime'])\n",
    "invasive_line = add_timestep(table=invasive_line, base=patient_base, base_timename='intime',\n",
    "                      timename='starttime', timestep_name='start_time_step',\n",
    "                      merge_id='stay_id', sort_list=['stay_id', 'starttime'])\n",
    "\n",
    "invasive_line = add_timestep(table=invasive_line, base=patient_base, base_timename='intime',\n",
    "                      timename='endtime', timestep_name='end_time_step',\n",
    "                      merge_id='stay_id', sort_list=['stay_id', 'starttime'])\n",
    "\n",
    "invasive_line = repeat_rows(invasive_line)\n",
    "print('finished generate row')\n",
    "invasive_line.reset_index(drop=True, inplace=True)\n",
    "print('finish reset index')\n",
    "invasive_line = invasive_line.drop(['start_time_step', 'end_time_step', 'intime'], axis=1)\n",
    "print('finish dropping column')\n",
    "invasive_line = generate_stat_cate(invasive_line, group_name='item', save_path='../DataProcessing/benchmark_stat/invasive_line.csv', base=patient_base, merge_id='stay_id')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "generate_dict(invasive_line, group_name=['item', 'value'], key_name='item', save_path='../DataProcessing/benchmark_stat/invasive_line.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "invasive_line"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 23: rrt\n",
    "stay_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.rrt\n",
    "where mimiciv_derived.rrt.stay_id in (select stay_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "rrt = pd.read_sql(sql, engine)\n",
    "rrt = rrt[rrt['stay_id'].isin(patient_base['stay_id'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rrt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rrt_agg(series):\n",
    "    if 1 in series.values:\n",
    "        return 1\n",
    "    elif 0 in series.values:\n",
    "        return 0\n",
    "    return np.nan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# rrt = rrt.pivot_table(\n",
    "#     index=['stay_id','charttime'],\n",
    "#     columns='dialysis_type',\n",
    "#     values=['dialysis_present', 'dialysis_active'],\n",
    "#     aggfunc=rrt_agg\n",
    "# )\n",
    "# rrt.columns = ['_'.join(col).strip() for col in rrt.columns.values]\n",
    "# rrt = rrt.reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rrt = melt_pivot(rrt, id_name=['stay_id'], time_name='charttime')\n",
    "rrt = add_timestep(table=rrt, base=patient_base, base_timename='intime',\n",
    "                      timename='charttime', timestep_name='timestep',\n",
    "                      merge_id='stay_id', sort_list=['stay_id', 'charttime'])\n",
    "\n",
    "rrt.reset_index(drop=True, inplace=True)\n",
    "rrt.drop(['charttime', 'intime'], axis=1, inplace=True)\n",
    "rrt = rrt.sort_values(by=['stay_id', 'timestep', 'item', 'value'], ascending=False)\n",
    "rrt = rrt.drop_duplicates(subset=['stay_id', 'timestep', 'item', 'value'], keep='first')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rrt = generate_stat_cate(rrt, group_name='item', save_path='../DataProcessing/benchmark_stat/rrt.csv', base=patient_base, merge_id='stay_id')\n",
    "generate_dict(rrt, group_name=['item', 'value'], key_name='item', save_path='../DataProcessing/benchmark_stat/rrt.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rrt = rrt[rrt['item'].isin(['dialysis_active'])]\n",
    "def modify_values(row):\n",
    "    if row['item'] == 'dialysis_active' and row['value']!= None:\n",
    "        return 'active' if row['value'] == 1 else 'inactive'\n",
    "    return row['value']\n",
    "\n",
    "rrt['value'] = rrt.apply(modify_values, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rrt['item']=rrt['item'] + '_' + rrt['value']\n",
    "rrt['value']=1\n",
    "rrt = rename_feature(rrt, 'rrt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rrt = rrt.sort_values(by=['subject_id', 'stay_id', 'timestep'])\n",
    "rrt.reset_index(drop=True, inplace=True)\n",
    "rrt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 24: ventilation\n",
    "stay_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql = \"\"\"\n",
    "select * from mimiciv_derived.ventilation\n",
    "where mimiciv_derived.ventilation.stay_id in (select stay_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "ventilation = pd.read_sql(sql, engine)\n",
    "ventilation = ventilation[ventilation['stay_id'].isin(patient_base['stay_id'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ventilation['item']='ventilation_status'\n",
    "ventilation.rename(columns={'ventilation_status':'value'}, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ventilation = add_timestep(table=ventilation, base=patient_base, base_timename='intime',\n",
    "                      timename='starttime', timestep_name='start_time_step',\n",
    "                      merge_id='stay_id', sort_list=['stay_id', 'starttime'])\n",
    "\n",
    "ventilation = add_timestep(table=ventilation, base=patient_base, base_timename='intime',\n",
    "                      timename='endtime', timestep_name='end_time_step',\n",
    "                      merge_id='stay_id', sort_list=['stay_id', 'starttime'])\n",
    "\n",
    "ventilation = repeat_rows(ventilation)\n",
    "\n",
    "ventilation.reset_index(drop=True, inplace=True)\n",
    "ventilation.drop(['start_time_step', 'end_time_step', 'intime'], axis=1, inplace=True)\n",
    "\n",
    "ventilation = generate_stat_cate(ventilation, group_name='item', save_path='../DataProcessing/benchmark_stat/ventilation.csv', base=patient_base, merge_id='stay_id')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ventilation = turn_binary(ventilation)\n",
    "ventilation = rename_feature(ventilation, 'ventilation')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ventilation = ventilation.sort_values(by=['stay_id', 'subject_id', 'timestep'])\n",
    "ventilation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table 25: emar\n",
    "hadm_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import yaml\n",
    "\n",
    "with open('icd_9_10_definitions.yaml', 'r') as f:\n",
    "    phenotype_definitions = yaml.safe_load(f)\n",
    "\n",
    "code_phenotype_mapping = {\n",
    "    code.strip(): (dx, phenotype_definitions[dx]['use_in_benchmark'])\n",
    "    for dx, dx_def in phenotype_definitions.items()\n",
    "    for code in dx_def['codes']\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sql=\"\"\"\n",
    "select  emar.subject_id, emar.hadm_id, emar.charttime, emar.event_txt, emar.medication as emar_medication\n",
    "    from mimiciv_hosp.emar\n",
    "    where mimiciv_hosp.emar.hadm_id in (select hadm_id from mimiciv_icu.icustays)\n",
    "\"\"\"\n",
    "emar = pd.read_sql(sql, engine)\n",
    "emar = emar[emar['hadm_id'].isin(patient_base['hadm_id'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "emar = emar[(emar['event_txt']=='Administered') &\n",
    "            (~emar['emar_medication'].isna())] # Filter out the subject that is not our target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "emar.rename(columns={'event_txt':'value', 'emar_medication':'item'}, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "emar = add_timestep(table=emar, base=patient_base, base_timename='intime',\n",
    "                      timename='charttime', timestep_name='timestep',\n",
    "                      merge_id='hadm_id', sort_list=['hadm_id', 'charttime'])\n",
    "\n",
    "emar.reset_index(drop=True, inplace=True)\n",
    "emar.drop(['charttime', 'intime'], axis=1, inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "emar['item'] = emar['item'].str.lower().str.replace(' ', '', regex=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "emar = generate_stat_cate(emar, group_name='item', save_path='../DataProcessing/benchmark_stat/emar.csv', base=patient_base, merge_id='hadm_id', drop=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "emar['value'] = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "emar"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Combine feature together"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_feature_list(table_list):\n",
    "    feature_set = set()\n",
    "    for current_table in table_list:\n",
    "        current_feature = set(current_table['item'].unique())\n",
    "        feature_set.update(current_feature)\n",
    "\n",
    "    feature_list = sorted(feature_set)\n",
    "    return feature_list \n",
    "\n",
    "''' chartlab '''\n",
    "bg = bg.reset_index(drop=True) # chartlab\n",
    "urine_output = urine_output.reset_index(drop=True) # chartlab\n",
    "kdigo_uo = kdigo_uo.reset_index(drop=True) # chartlab\n",
    "blood_differential = blood_differential.reset_index(drop=True) # chartlab\n",
    "cardiac_marker = cardiac_marker.reset_index(drop=True) # chartlab\n",
    "chemistry = chemistry.reset_index(drop=True) # chartlab\n",
    "coagulation = coagulation.reset_index(drop=True) # chartlab\n",
    "complete_blood_count = complete_blood_count.reset_index(drop=True) # chartlab\n",
    "enzyme = enzyme.reset_index(drop=True) # chartlab\n",
    "gcs = gcs.reset_index(drop=True) # chartlab\n",
    "icp = icp.reset_index(drop=True) # chartlab\n",
    "inflammation = inflammation.reset_index(drop=True) # chartlab\n",
    "oxygen_delivery_o2_flow = oxygen_delivery_o2_flow.reset_index(drop=True) # chartlab\n",
    "rhythm = rhythm.reset_index(drop=True) # chartlab\n",
    "vitalsign = vitalsign.reset_index(drop=True) # chartlab\n",
    "weight = weight.reset_index(drop=True) # chartlab\n",
    "kdigo_creatinine = kdigo_creatinine.reset_index(drop=True) # chartlab\n",
    "ventilator_setting_cont = ventilator_setting_cont.reset_index(drop=True) # chartlab\n",
    "\n",
    "chartlab_table = [bg, urine_output, kdigo_uo, blood_differential, cardiac_marker, chemistry, \n",
    "                  coagulation, complete_blood_count,enzyme, gcs, icp, inflammation, \n",
    "                  oxygen_delivery_o2_flow, rhythm, vitalsign, weight, kdigo_creatinine, ventilator_setting_cont]\n",
    "\n",
    "chartlab_feature = generate_feature_list(chartlab_table)\n",
    "\n",
    "''' treatments '''\n",
    "norepinephrine_equivalent_dose = norepinephrine_equivalent_dose.reset_index(drop=True) # treatment\n",
    "crrt_cont = crrt_cont.reset_index(drop=True) # treatment\n",
    "crrt_cate = crrt_cate.reset_index(drop=True) # treatment\n",
    "rrt = rrt.reset_index(drop=True) # treatment\n",
    "ventilation = ventilation.reset_index(drop=True) # treatment\n",
    "emar = emar.reset_index(drop=True) # treatment\n",
    "\n",
    "treatment_table = [norepinephrine_equivalent_dose, crrt_cont, crrt_cate, rrt, ventilation, emar]\n",
    "treatment_feature = generate_feature_list(treatment_table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import warnings\n",
    "from concurrent.futures import ProcessPoolExecutor\n",
    "import numpy as np\n",
    "\n",
    "# Ignore FutureWarning\n",
    "warnings.filterwarnings('ignore', category=FutureWarning)\n",
    "\n",
    "# Define tables and their ID types\n",
    "chartlab_tables = {\n",
    "    'bg': 'hadm_id',\n",
    "    'urine_output': 'stay_id',\n",
    "    'blood_differential': 'hadm_id',\n",
    "    'cardiac_marker': 'hadm_id',\n",
    "    'chemistry': 'hadm_id',\n",
    "    'coagulation': 'hadm_id',\n",
    "    'complete_blood_count': 'hadm_id',\n",
    "    'enzyme': 'hadm_id',\n",
    "    'gcs': 'stay_id',\n",
    "    'icp': 'stay_id',\n",
    "    'inflammation': 'hadm_id',\n",
    "    'oxygen_delivery_o2_flow': 'stay_id',\n",
    "    'rhythm': 'stay_id',\n",
    "    'vitalsign': 'stay_id',\n",
    "    'weight': 'stay_id',\n",
    "    'kdigo_creatinine': 'stay_id',\n",
    "    'ventilator_setting_cont': 'stay_id',\n",
    "    'kdigo_uo': 'stay_id'\n",
    "}\n",
    "\n",
    "treatment_tables = {\n",
    "    'norepinephrine_equivalent_dose': 'stay_id',\n",
    "    'crrt_cont': 'stay_id',\n",
    "    'crrt_cate': 'stay_id',\n",
    "    'rrt': 'stay_id',\n",
    "    'ventilation': 'stay_id',\n",
    "    'emar': 'hadm_id'\n",
    "}\n",
    "\n",
    "# Extract time series data in batch\n",
    "def extract_time_series_batch(table, rows, id_name):\n",
    "    # Get all ID values that need to be queried\n",
    "    id_values = rows[id_name].unique()\n",
    "    \n",
    "    # Query all matching rows at once\n",
    "    mask = table[id_name].isin(id_values)\n",
    "    all_data = table[mask].copy()\n",
    "    \n",
    "    # Create ID to stay_id mapping\n",
    "    if id_name == 'hadm_id':\n",
    "        id_to_stay = dict(zip(rows[id_name], rows['stay_id']))\n",
    "        all_data['stay_id'] = all_data[id_name].map(id_to_stay)\n",
    "        all_data = all_data.drop(columns=id_name)\n",
    "    \n",
    "    return all_data\n",
    "\n",
    "# Process a batch of patient data\n",
    "def process_batch(patient_batch, all_tables):\n",
    "    batch_size = len(patient_batch)\n",
    "    batch_chartlab = {}\n",
    "    batch_treatments = {}\n",
    "    \n",
    "    # Extract data from each table for this batch\n",
    "    chartlab_batch_data = {}\n",
    "    treatment_batch_data = {}\n",
    "    \n",
    "    # Extract chartlab data\n",
    "    for table_name, id_name in chartlab_tables.items():\n",
    "        table = all_tables[table_name]\n",
    "        batch_data = extract_time_series_batch(table, patient_batch, id_name)\n",
    "        if not batch_data.empty:\n",
    "            chartlab_batch_data[table_name] = batch_data\n",
    "    \n",
    "    # # Extract treatment data\n",
    "    # for table_name, id_name in treatment_tables.items():\n",
    "    #     table = all_tables[table_name]\n",
    "    #     batch_data = extract_time_series_batch(table, patient_batch, id_name)\n",
    "    #     if not batch_data.empty:\n",
    "    #         treatment_batch_data[table_name] = batch_data\n",
    "    \n",
    "    # Organize data for each patient\n",
    "    for _, row in patient_batch.iterrows():\n",
    "        stay_id = row.stay_id\n",
    "        \n",
    "        # Filter data for this patient from each table\n",
    "        chartlab_data = []\n",
    "        for table_name, data in chartlab_batch_data.items():\n",
    "            patient_data = data[data['stay_id'] == stay_id]\n",
    "            if not patient_data.empty:\n",
    "                chartlab_data.append(patient_data)\n",
    "        \n",
    "        treatment_data = []\n",
    "        for table_name, data in treatment_batch_data.items():\n",
    "            patient_data = data[data['stay_id'] == stay_id]\n",
    "            if not patient_data.empty:\n",
    "                treatment_data.append(patient_data)\n",
    "        \n",
    "        # Process chartlab data\n",
    "        if chartlab_data:\n",
    "            try:\n",
    "                chartlab_df = pd.concat(chartlab_data, ignore_index=True)\n",
    "                chartlab_df = chartlab_df.sort_values(by=['timestep'])\n",
    "                chartlab_df = chartlab_df.drop_duplicates(subset=['stay_id', 'item', 'timestep'])\n",
    "                batch_chartlab[stay_id] = chartlab_df\n",
    "            except Exception as e:\n",
    "                print(f\"Error processing chartlab for stay_id {stay_id}: {e}\")\n",
    "        \n",
    "        # # Process treatment data\n",
    "        # if treatment_data:\n",
    "        #     try:\n",
    "        #         treatment_df = pd.concat(treatment_data, ignore_index=True)\n",
    "        #         treatment_df = treatment_df.sort_values(by=['timestep'])\n",
    "        #         treatment_df = treatment_df.drop_duplicates(subset=['stay_id', 'item', 'timestep'])\n",
    "        #         batch_treatments[stay_id] = treatment_df\n",
    "        #     except Exception as e:\n",
    "        #         print(f\"Error processing treatment for stay_id {stay_id}: {e}\")\n",
    "    \n",
    "    return batch_chartlab, batch_treatments\n",
    "\n",
    "# Main function: Process all patient data\n",
    "def process_all_patients(patient_base, all_tables, batch_size=500):\n",
    "    total_patients = len(patient_base)\n",
    "    chartlab = {}\n",
    "    treatments = {}\n",
    "    \n",
    "    for i in tqdm(range(0, total_patients, batch_size), desc=\"Processing patients\"):\n",
    "        # Extract a batch of patient data\n",
    "        batch = patient_base.iloc[i:min(i+batch_size, total_patients)]\n",
    "        \n",
    "        # Process this batch\n",
    "        batch_chartlab, batch_treatments = process_batch(batch, all_tables)\n",
    "        \n",
    "        # Update result dictionaries\n",
    "        chartlab.update(batch_chartlab)\n",
    "        treatments.update(batch_treatments)\n",
    "    \n",
    "    return chartlab, treatments\n",
    "\n",
    "# Main processing function\n",
    "def main():\n",
    "    # Prepare all tables\n",
    "    # all_tables = {\n",
    "    #     # Chart and lab data\n",
    "    #     'bg': bg,\n",
    "    #     'urine_output': urine_output,\n",
    "    #     'blood_differential': blood_differential,\n",
    "    #     'cardiac_marker': cardiac_marker,\n",
    "    #     'chemistry': chemistry,\n",
    "    #     'coagulation': coagulation,\n",
    "    #     'complete_blood_count': complete_blood_count,\n",
    "    #     'enzyme': enzyme,\n",
    "    #     'gcs': gcs,\n",
    "    #     'icp': icp,\n",
    "    #     'inflammation': inflammation,\n",
    "    #     'oxygen_delivery_o2_flow': oxygen_delivery_o2_flow,\n",
    "    #     'rhythm': rhythm,\n",
    "    #     'vitalsign': vitalsign,\n",
    "    #     'weight': weight,\n",
    "    #     'kdigo_creatinine': kdigo_creatinine,\n",
    "    #     'ventilator_setting_cont': ventilator_setting_cont,\n",
    "    #     'kdigo_uo': kdigo_uo,\n",
    "        \n",
    "    #     # Treatment data\n",
    "    #     'norepinephrine_equivalent_dose': norepinephrine_equivalent_dose,\n",
    "    #     'crrt_cont': crrt_cont,\n",
    "    #     'crrt_cate': crrt_cate,\n",
    "    #     'rrt': rrt,\n",
    "    #     'ventilation': ventilation,\n",
    "    #     'emar': emar\n",
    "    # }\n",
    "    all_tables = {\n",
    "    # Chart and lab data\n",
    "    'bg': bg,\n",
    "    'urine_output': urine_output,\n",
    "    'blood_differential': blood_differential,\n",
    "    'cardiac_marker': cardiac_marker,\n",
    "    'chemistry': chemistry,\n",
    "    'coagulation': coagulation,\n",
    "    'complete_blood_count': complete_blood_count,\n",
    "    'enzyme': enzyme,\n",
    "    'gcs': gcs,\n",
    "    'icp': icp,\n",
    "    'inflammation': inflammation,\n",
    "    'oxygen_delivery_o2_flow': oxygen_delivery_o2_flow,\n",
    "    'rhythm': rhythm,\n",
    "    'vitalsign': vitalsign,\n",
    "    'weight': weight,\n",
    "    'kdigo_creatinine': kdigo_creatinine,\n",
    "    'ventilator_setting_cont': ventilator_setting_cont,\n",
    "    'kdigo_uo': kdigo_uo,\n",
    "    }\n",
    "    \n",
    "    # Process all patient data\n",
    "    chartlab, treatments = process_all_patients(patient_base, all_tables, batch_size=500)\n",
    "    \n",
    "    return chartlab, treatments\n",
    "\n",
    "chartlab, treatments = main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# frequent_procedures_icu.sort_values(by='counting', ascending=False).to_csv('./DataProcessing/benchmark_stat/procedures_icu.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# frequent_procedures_icu.sort_values(by='counting', ascending=False).to_csv('./DataProcessing/benchmark_stat/procedures_icu.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def extract_time_series(table, row, id_name):\n",
    "#     data = table.loc[table[id_name]==row[id_name]].copy()\n",
    "#     if id_name=='hadm_id':\n",
    "#         data['stay_id']=row.stay_id\n",
    "#         data = data.drop(columns=id_name)\n",
    "#     return data\n",
    "\n",
    "# chartlab = {}\n",
    "# treatments = {}\n",
    "# for i, (_, row) in tqdm(enumerate(patient_base.iterrows()), total=patient_base.shape[0]):\n",
    "#     chartlab_data = []\n",
    "#     treatments_data = []\n",
    "#     \"\"\"======================================chartlab==============================================\"\"\"\n",
    "#     # 1. bg: hadm_id \n",
    "#     chartlab_data.append(extract_time_series(bg, row, 'hadm_id'))\n",
    "#     # 2. urine_output: stay_id\n",
    "#     chartlab_data.append(extract_time_series(urine_output, row, 'stay_id'))\n",
    "#     # 3. blood_differential: hadm_id\n",
    "#     chartlab_data.append(extract_time_series(blood_differential, row, 'hadm_id'))\n",
    "#     # 4. cardiac_marker: hadm_id\n",
    "#     chartlab_data.append(extract_time_series(cardiac_marker, row, 'hadm_id'))\n",
    "#     # 5. chemistry: hadm_id\n",
    "#     chartlab_data.append(extract_time_series(chemistry, row, 'hadm_id'))\n",
    "#     # 6. coagulation: hadm_id\n",
    "#     chartlab_data.append(extract_time_series(coagulation, row, 'hadm_id'))\n",
    "#     # 7. complete_blood_count: hadm_id\n",
    "#     chartlab_data.append(extract_time_series(complete_blood_count, row, 'hadm_id'))\n",
    "#     # 8. enzyme: hadm_id\n",
    "#     chartlab_data.append(extract_time_series(enzyme, row, 'hadm_id'))\n",
    "#     # 9. gcs: stay_id\n",
    "#     chartlab_data.append(extract_time_series(gcs, row, 'stay_id'))\n",
    "#     # 10. icp: stay_id\n",
    "#     chartlab_data.append(extract_time_series(icp, row, 'stay_id'))\n",
    "#     # 11. inflammation: hadm_id\n",
    "#     chartlab_data.append(extract_time_series(inflammation, row, 'hadm_id'))\n",
    "#     # 12. oxygen_delivery_o2_flow: stay_id\n",
    "#     chartlab_data.append(extract_time_series(oxygen_delivery_o2_flow, row, 'stay_id'))\n",
    "#     # 13. rhythm: stay_id\n",
    "#     chartlab_data.append(extract_time_series(rhythm, row, 'stay_id'))\n",
    "#     # 14. vitalsign: stay_id\n",
    "#     chartlab_data.append(extract_time_series(vitalsign, row, 'stay_id'))\n",
    "#     # 15: weight: stay_id\n",
    "#     chartlab_data.append(extract_time_series(weight, row, 'stay_id'))\n",
    "#     # 16. kdigo_creatinine: stay_id\n",
    "#     chartlab_data.append(extract_time_series(kdigo_creatinine, row, 'stay_id'))\n",
    "#     # 17. ventilator_setting_cont: stay_id\n",
    "#     chartlab_data.append(extract_time_series(ventilator_setting_cont, row, 'stay_id'))\n",
    "#     # 18. kdigo_uo: stay_id\n",
    "#     chartlab_data.append(extract_time_series(kdigo_uo, row, 'stay_id'))\n",
    "    \n",
    "#     # chartlab[row.stay_id] = pd.concat(chartlab_data, ignore_index=True).sort_values(by=['timestep']).drop_duplicates(subset=['stay_id', 'item', 'timestep'])\n",
    "#     # 合并chartlab数据并添加mask\n",
    "#     chartlab_df = pd.concat(chartlab_data, ignore_index=True).sort_values(by=['timestep']).drop_duplicates(subset=['stay_id', 'item', 'timestep'])\n",
    "#     # 添加mask列\n",
    "#     feat_cols = [x for x in chartlab_df.columns if x not in ['stay_id', 'item', 'timestep']]\n",
    "#     mask_dict = {f\"{feat}_mask\": (~chartlab_df[feat].isna()).astype(float) for feat in feat_cols}\n",
    "#     chartlab_df = pd.concat([chartlab_df, pd.DataFrame(mask_dict, index=chartlab_df.index)], axis=1)\n",
    "#     chartlab[row.stay_id] = chartlab_df\n",
    "\n",
    "#     \"\"\"======================================treatments==============================================\"\"\"\n",
    "#     # 1. norepinephrine_equivalent_dose: stay_id\n",
    "#     treatments_data.append(extract_time_series(norepinephrine_equivalent_dose, row, 'stay_id'))\n",
    "#     # 2. crrt_cont: stay_id\n",
    "#     treatments_data.append(extract_time_series(crrt_cont, row, 'stay_id'))\n",
    "#     # 3. crrt_cate: stay_id\n",
    "#     treatments_data.append(extract_time_series(crrt_cate, row, 'stay_id'))\n",
    "#     # 4. rrt: stay_id\n",
    "#     treatments_data.append(extract_time_series(rrt, row, 'stay_id'))\n",
    "#     # 5. ventilation: stay_id\n",
    "#     treatments_data.append(extract_time_series(ventilation, row, 'stay_id'))\n",
    "#     # 6. emar: hadm_id\n",
    "#     treatments_data.append(extract_time_series(emar, row, 'hadm_id'))\n",
    "\n",
    "#     treatments[row.stay_id] = pd.concat(treatments_data, ignore_index=True).sort_values(by=['timestep']).drop_duplicates(subset=['stay_id', 'item', 'timestep'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "age = patient_base[['subject_id', 'hadm_id', 'stay_id', 'age']]\n",
    "height = patient_base[['subject_id', 'hadm_id', 'stay_id', 'average_height']]\n",
    "\n",
    "gender = patient_base[['subject_id', 'hadm_id', 'stay_id', 'gender']]\n",
    "gender = gender.pivot_table(index=['subject_id', 'hadm_id', 'stay_id'], \n",
    "                        columns='gender', \n",
    "                        aggfunc=lambda x: 1).reset_index()\n",
    "gender = gender.reset_index(drop=True)\n",
    "gender.columns.name = None\n",
    "\n",
    "\n",
    "race = patient_base[['subject_id', 'hadm_id', 'stay_id', 'race']]\n",
    "race = race.pivot_table(index=['subject_id', 'hadm_id', 'stay_id'], \n",
    "                        columns='race', \n",
    "                        aggfunc=lambda x: 1).reset_index()\n",
    "race = race.reset_index(drop=True)\n",
    "race.columns.name = None\n",
    "\n",
    "from functools import reduce\n",
    "\n",
    "demographics_list = [age, height, gender, race]\n",
    "demographics = reduce(lambda left, right: pd.merge(left, right, on=['subject_id', 'hadm_id', 'stay_id']), demographics_list)\n",
    "\n",
    "# demographics = demographics.fillna(0)\n",
    "demographics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def pivot_table_custom(data, feature_list):\n",
    "#     table = data.pivot_table(index=['subject_id','stay_id','timestep'], columns='item', values='value', aggfunc='first')\n",
    "#     table = table.reindex(columns=feature_list)\n",
    "#     # table = table.fillna(0)\n",
    "#     table = table.reset_index()\n",
    "#     table.columns.name = None\n",
    "#     return table\n",
    "\n",
    "# # Function to add mask columns to DataFrame\n",
    "# def add_mask_columns(df):\n",
    "#     # Identify feature columns\n",
    "#     feat_cols = [x for x in df.columns if x not in ['subject_id', 'stay_id', 'timestep']]\n",
    "    \n",
    "#     # Create all mask columns at once\n",
    "#     mask_dict = {f\"{feat}_mask\": (~df[feat].isna()).astype(float) for feat in feat_cols}\n",
    "#     mask_df = pd.DataFrame(mask_dict, index=df.index)\n",
    "    \n",
    "#     return pd.concat([df, mask_df], axis=1)\n",
    "\n",
    "# # Fill missing timesteps with NaN\n",
    "# # def fill_timesteps(df, subject_id, stay_id):\n",
    "# #     if df.empty:\n",
    "# #         return df\n",
    "# #     min_ts = int(df['timestep'].min())\n",
    "# #     max_ts = int(df['timestep'].max())\n",
    "# #     all_timesteps = np.arange(min_ts, max_ts + 1)\n",
    "# #     full_index = pd.DataFrame({\n",
    "# #         'subject_id': subject_id,\n",
    "# #         'stay_id': stay_id,\n",
    "# #         'timestep': all_timesteps\n",
    "# #     })\n",
    "# #     df_filled = pd.merge(full_index, df, on=['subject_id', 'stay_id', 'timestep'], how='left')\n",
    "# #     return df_filled\n",
    "\n",
    "# def fill_timesteps(df, subject_id, stay_id, target_length=12):\n",
    "#     \"\"\"\n",
    "#     填充timesteps\n",
    "    \n",
    "#     Parameters:\n",
    "#     df: DataFrame for a single patient\n",
    "#     subject_id: subject ID\n",
    "#     stay_id: stay ID\n",
    "#     target_length: 目标timestep长度，如果为None则使用原始数据的最大长度\n",
    "    \n",
    "#     Returns:\n",
    "#     DataFrame with filled timesteps\n",
    "#     \"\"\"\n",
    "#     if df.empty:\n",
    "#         return df\n",
    "    \n",
    "#     if target_length is None:\n",
    "#         # 使用原始数据的最大长度\n",
    "#         min_ts = int(df['timestep'].min())\n",
    "#         max_ts = int(df['timestep'].max())\n",
    "#         all_timesteps = np.arange(min_ts, max_ts + 1)\n",
    "#     else:\n",
    "#         # 使用指定的目标长度\n",
    "#         all_timesteps = np.arange(target_length)\n",
    "    \n",
    "#     full_index = pd.DataFrame({\n",
    "#         'subject_id': subject_id,\n",
    "#         'stay_id': stay_id,\n",
    "#         'timestep': all_timesteps\n",
    "#     })\n",
    "    \n",
    "#     df_filled = pd.merge(full_index, df, on=['subject_id', 'stay_id', 'timestep'], how='left')\n",
    "#     return df_filled\n",
    "\n",
    "\n",
    "# all_patient_data = []\n",
    "# valid_indices = []\n",
    "\n",
    "# for i, (idx, row) in tqdm(enumerate(patient_base.iterrows()), total=patient_base.shape[0]):\n",
    "#     chartlab_data = chartlab.get(row.stay_id)\n",
    "#     # treatments_data = treatments.get(row.stay_id)  # 不再需要treatment数据\n",
    "\n",
    "#     if chartlab_data is None:\n",
    "#         continue  \n",
    "#     if len(chartlab_data) < 3:\n",
    "#         continue\n",
    "    \n",
    "#     # Record valid index\n",
    "#     valid_indices.append(idx)\n",
    "\n",
    "#     chartlab_table = pivot_table_custom(chartlab_data, chartlab_feature)\n",
    "#     chartlab_table = add_mask_columns(chartlab_table)\n",
    "#     chartlab_table = fill_timesteps(chartlab_table, row.subject_id, row.stay_id, target_length=12)\n",
    "#     chartlab_table = chartlab_table.sort_values('timestep')\n",
    "    \n",
    "#     # 现在只使用chartlab数据作为最终数据\n",
    "#     all_patient_data.append(chartlab_table)\n",
    "    \n",
    "#     csv_name = str(row.stay_id) + '.csv'\n",
    "#     # 只保存chartlab数据\n",
    "#     chartlab_table.to_csv(os.path.join(chartlab_dir, csv_name), index=False)\n",
    "#     # 如果您还想保存到merged目录，使用chartlab_table\n",
    "#     chartlab_table.to_csv(os.path.join(merged_dir, csv_name), index=False)\n",
    "#     # 删除treatment相关的保存操作\n",
    "\n",
    "# # Merge all patient data\n",
    "# final_table = pd.concat(all_patient_data, axis=0, ignore_index=True)\n",
    "# final_table = final_table.sort_values(['subject_id', 'stay_id', 'timestep'])\n",
    "# final_table.to_csv(os.path.join(main_dir, 'merged_all.csv'), index=False)\n",
    "\n",
    "# # Update patient_base and demographics to keep only valid entries\n",
    "# patient_base = patient_base.loc[valid_indices].reset_index(drop=True)\n",
    "# patient_base.to_csv(os.path.join(main_dir, 'stays_meta_with_labels.csv'), index=False)\n",
    "# demographics = demographics.loc[valid_indices].reset_index(drop=True)\n",
    "# demographics.to_csv(os.path.join(main_dir, 'demographics.csv'), index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pivot_table_custom(data, feature_list):\n",
    "    table = data.pivot_table(index=['subject_id','stay_id','timestep'], columns='item', values='value', aggfunc='first')\n",
    "    table = table.reindex(columns=feature_list)\n",
    "    # table = table.fillna(0)\n",
    "    table = table.reset_index()\n",
    "    table.columns.name = None\n",
    "    return table\n",
    "\n",
    "# Function to add mask columns to DataFrame\n",
    "def add_mask_columns(df):\n",
    "    # Identify feature columns\n",
    "    feat_cols = [x for x in df.columns if x not in ['subject_id', 'stay_id', 'timestep']]\n",
    "    \n",
    "    # Create all mask columns at once\n",
    "    mask_dict = {f\"{feat}_mask\": (~df[feat].isna()).astype(float) for feat in feat_cols}\n",
    "    mask_df = pd.DataFrame(mask_dict, index=df.index)\n",
    "    \n",
    "    return pd.concat([df, mask_df], axis=1)\n",
    "\n",
    "def fill_timesteps(df, subject_id, stay_id, target_length=48):\n",
    "    \"\"\"\n",
    "    Fill timesteps for a patient's data.\n",
    "    \n",
    "    Parameters:\n",
    "    df: DataFrame for a single patient\n",
    "    subject_id: subject ID\n",
    "    stay_id: stay ID\n",
    "    target_length: Target timestep length. If None, use the max length in the original data.\n",
    "    \n",
    "    Returns:\n",
    "    DataFrame with filled timesteps\n",
    "    \"\"\"\n",
    "    if df.empty:\n",
    "        return df\n",
    "    \n",
    "    if target_length is None:\n",
    "        # Use the max length in the original data\n",
    "        min_ts = int(df['timestep'].min())\n",
    "        max_ts = int(df['timestep'].max())\n",
    "        all_timesteps = np.arange(min_ts, max_ts + 1)\n",
    "    else:\n",
    "        # Use the specified target length\n",
    "        all_timesteps = np.arange(target_length)\n",
    "    \n",
    "    full_index = pd.DataFrame({\n",
    "        'subject_id': subject_id,\n",
    "        'stay_id': stay_id,\n",
    "        'timestep': all_timesteps\n",
    "    })\n",
    "    \n",
    "    df_filled = pd.merge(full_index, df, on=['subject_id', 'stay_id', 'timestep'], how='left')\n",
    "    return df_filled\n",
    "\n",
    "def filter_high_missing_features(df, threshold=90):\n",
    "    \"\"\"\n",
    "    Filter out features (and their mask columns) with missing ratio above the threshold.\n",
    "    \n",
    "    Parameters:\n",
    "    df: DataFrame with time series data\n",
    "    threshold: Missing ratio threshold, default 90%\n",
    "    \n",
    "    Returns:\n",
    "    DataFrame with filtered features and list of removed features\n",
    "    \"\"\"\n",
    "    # Get feature columns (exclude metadata and mask columns)\n",
    "    feature_cols = [col for col in df.columns \n",
    "                   if col not in ['subject_id', 'stay_id', 'timestep'] and not col.endswith('_mask')]\n",
    "    \n",
    "    # Calculate missing ratio for each feature\n",
    "    missing_ratios = df[feature_cols].isnull().sum() / len(df) * 100\n",
    "    \n",
    "    # Find features with missing ratio below the threshold\n",
    "    valid_features = missing_ratios[missing_ratios < threshold].index.tolist()\n",
    "    removed_features = missing_ratios[missing_ratios >= threshold].index.tolist()\n",
    "    \n",
    "    # Build the list of columns to keep\n",
    "    columns_to_keep = ['subject_id', 'stay_id', 'timestep']\n",
    "    \n",
    "    # Add valid features\n",
    "    columns_to_keep.extend(valid_features)\n",
    "    \n",
    "    # Add corresponding mask columns (only keep masks for valid features)\n",
    "    for feat in valid_features:\n",
    "        mask_col = f\"{feat}_mask\"\n",
    "        if mask_col in df.columns:\n",
    "            columns_to_keep.append(mask_col)\n",
    "    \n",
    "    # Filter DataFrame\n",
    "    filtered_df = df[columns_to_keep].copy()\n",
    "    \n",
    "    # Validate filtering result\n",
    "    remaining_features = [col for col in filtered_df.columns \n",
    "                         if col not in ['subject_id', 'stay_id', 'timestep'] and not col.endswith('_mask')]\n",
    "    remaining_masks = [col for col in filtered_df.columns if col.endswith('_mask')]\n",
    "    \n",
    "    print(f\"Number of features kept after filtering: {len(remaining_features)}\")\n",
    "    print(f\"Number of mask columns kept after filtering: {len(remaining_masks)}\")\n",
    "    \n",
    "    # Validate correspondence between mask columns and feature columns\n",
    "    expected_masks = [f\"{feat}_mask\" for feat in remaining_features]\n",
    "    actual_masks = [col for col in filtered_df.columns if col.endswith('_mask')]\n",
    "    \n",
    "    if set(expected_masks) == set(actual_masks):\n",
    "        print(\"Mask columns correspond correctly to feature columns\")\n",
    "    else:\n",
    "        print(\"Warning: There may be a mismatch between mask columns and feature columns\")\n",
    "        missing_masks = set(expected_masks) - set(actual_masks)\n",
    "        extra_masks = set(actual_masks) - set(expected_masks)\n",
    "        if missing_masks:\n",
    "            print(f\"   Missing mask columns: {missing_masks}\")\n",
    "        if extra_masks:\n",
    "            print(f\"   Extra mask columns: {extra_masks}\")\n",
    "    \n",
    "    return filtered_df, valid_features, removed_features\n",
    "\n",
    "# Process all patient data\n",
    "all_patient_data = []\n",
    "valid_indices = []\n",
    "\n",
    "for i, (idx, row) in tqdm(enumerate(patient_base.iterrows()), total=patient_base.shape[0]):\n",
    "    chartlab_data = chartlab.get(row.stay_id)\n",
    "\n",
    "    if chartlab_data is None:\n",
    "        continue  \n",
    "    if len(chartlab_data) < 3:\n",
    "        continue\n",
    "    \n",
    "    # Record valid index\n",
    "    valid_indices.append(idx)\n",
    "\n",
    "    chartlab_table = pivot_table_custom(chartlab_data, chartlab_feature)\n",
    "    chartlab_table = add_mask_columns(chartlab_table)\n",
    "    # 12 hours\n",
    "    # chartlab_table = fill_timesteps(chartlab_table, row.subject_id, row.stay_id, target_length=48)\n",
    "    chartlab_table = chartlab_table.sort_values('timestep')\n",
    "    \n",
    "    # Only use chartlab data as the final data\n",
    "    all_patient_data.append(chartlab_table)\n",
    "    \n",
    "    csv_name = str(row.stay_id) + '.csv'\n",
    "    # Only save chartlab data\n",
    "    chartlab_table.to_csv(os.path.join(chartlab_dir, csv_name), index=False)\n",
    "    # If you also want to save to the merged directory, use chartlab_table\n",
    "    chartlab_table.to_csv(os.path.join(merged_dir, csv_name), index=False)\n",
    "\n",
    "# Merge all patient data\n",
    "final_table = pd.concat(all_patient_data, axis=0, ignore_index=True)\n",
    "final_table = final_table.sort_values(['subject_id', 'stay_id', 'timestep'])\n",
    "\n",
    "# Filter features with missing ratio > 90%\n",
    "print(\"Filtering features with high missing ratio...\")\n",
    "final_table_filtered, valid_features, removed_features = filter_high_missing_features(final_table, threshold=90)\n",
    "\n",
    "print(f\"Original number of features: {len([col for col in final_table.columns if col not in ['subject_id', 'stay_id', 'timestep'] and not col.endswith('_mask')])}\")\n",
    "print(f\"Number of features after filtering: {len(valid_features)}\")\n",
    "print(f\"Number of removed features: {len(removed_features)}\")\n",
    "\n",
    "# Show removed features\n",
    "if removed_features:\n",
    "    print(f\"\\nRemoved features (missing ratio >= 90%):\")\n",
    "    for i, feat in enumerate(removed_features[:10]):  # Show only the first 10\n",
    "        print(f\"  {i+1}. {feat}\")\n",
    "    if len(removed_features) > 10:\n",
    "        print(f\"  ... {len(removed_features) - 10} more features\")\n",
    "\n",
    "# Update chartlab_feature list\n",
    "chartlab_feature_filtered = [feat for feat in chartlab_feature if feat in valid_features]\n",
    "print(f\"\\nchartlab_feature updated: {len(chartlab_feature)} -> {len(chartlab_feature_filtered)}\")\n",
    "chartlab_feature = chartlab_feature_filtered\n",
    "\n",
    "# Save filtered data (replace original merged_all.csv)\n",
    "final_table_filtered.to_csv(os.path.join(main_dir, 'merged_all.csv'), index=False)\n",
    "\n",
    "# # Update features.yaml file (replace original features.yaml)\n",
    "# features_dict_filtered = {\n",
    "#     'chartlab_feature': chartlab_feature_filtered,\n",
    "#     'treatment_feature': treatment_feature\n",
    "# }\n",
    "\n",
    "# with open(os.path.join(split_dir, 'features.yaml'), 'w') as f:\n",
    "#     yaml.dump(features_dict_filtered, f, default_flow_style=False)\n",
    "\n",
    "print(\"Filtering complete!\")\n",
    "print(\"Filtered data saved to: merged_all.csv\")\n",
    "print(\"Filtered features configuration saved to: features.yaml\")\n",
    "\n",
    "# Update patient_base and demographics to keep only valid entries\n",
    "patient_base = patient_base.loc[valid_indices].reset_index(drop=True)\n",
    "patient_base.to_csv(os.path.join(main_dir, 'stays_meta_with_labels.csv'), index=False)\n",
    "demographics = demographics.loc[valid_indices].reset_index(drop=True)\n",
    "demographics.to_csv(os.path.join(main_dir, 'demographics.csv'), index=False)\n",
    "\n",
    "# Optional: Show statistics after filtering\n",
    "print(f\"\\n=== Statistics after filtering ===\")\n",
    "print(f\"Total rows: {len(final_table_filtered)}\")\n",
    "print(f\"Number of patients: {final_table_filtered['stay_id'].nunique()}\")\n",
    "print(f\"Average missing ratio: {final_table_filtered[valid_features].isnull().sum().sum() / (len(final_table_filtered) * len(valid_features)) * 100:.2f}%\")\n",
    "\n",
    "# Count filtered features by category\n",
    "def categorize_features(features):\n",
    "    \"\"\"Group features by category\"\"\"\n",
    "    categories = {}\n",
    "    for feat in features:\n",
    "        if feat.startswith('bg_'):\n",
    "            categories.setdefault('bg', []).append(feat)\n",
    "        elif feat.startswith('blood_differential_'):\n",
    "            categories.setdefault('blood_differential', []).append(feat)\n",
    "        elif feat.startswith('cardiac_marker_'):\n",
    "            categories.setdefault('cardiac_marker', []).append(feat)\n",
    "        elif feat.startswith('chemistry_'):\n",
    "            categories.setdefault('chemistry', []).append(feat)\n",
    "        elif feat.startswith('coagulation_'):\n",
    "            categories.setdefault('coagulation', []).append(feat)\n",
    "        elif feat.startswith('complete_blood_count_'):\n",
    "            categories.setdefault('complete_blood_count', []).append(feat)\n",
    "        elif feat.startswith('enzyme_'):\n",
    "            categories.setdefault('enzyme', []).append(feat)\n",
    "        elif feat.startswith('gcs_'):\n",
    "            categories.setdefault('gcs', []).append(feat)\n",
    "        elif feat.startswith('icp_'):\n",
    "            categories.setdefault('icp', []).append(feat)\n",
    "        elif feat.startswith('inflammation_'):\n",
    "            categories.setdefault('inflammation', []).append(feat)\n",
    "        elif feat.startswith('kdigo_'):\n",
    "            categories.setdefault('kdigo', []).append(feat)\n",
    "        elif feat.startswith('oxygen_delivery_'):\n",
    "            categories.setdefault('oxygen_delivery', []).append(feat)\n",
    "        elif feat.startswith('rhythm_'):\n",
    "            categories.setdefault('rhythm', []).append(feat)\n",
    "        elif feat.startswith('urine_output_'):\n",
    "            categories.setdefault('urine_output', []).append(feat)\n",
    "        elif feat.startswith('ventilator_setting_'):\n",
    "            categories.setdefault('ventilator_setting', []).append(feat)\n",
    "        elif feat.startswith('vitalsign_'):\n",
    "            categories.setdefault('vitalsign', []).append(feat)\n",
    "        elif feat.startswith('weight_'):\n",
    "            categories.setdefault('weight', []).append(feat)\n",
    "        else:\n",
    "            categories.setdefault('other', []).append(feat)\n",
    "    return categories\n",
    "\n",
    "feature_categories = categorize_features(valid_features)\n",
    "print(f\"\\n=== Filtered features by category ===\")\n",
    "for category, features in feature_categories.items():\n",
    "    print(f\"{category}: {len(features)} features\")\n",
    "\n",
    "# Validate the integrity of the final data\n",
    "print(f\"\\n=== Final data validation ===\")\n",
    "print(f\"Number of columns in final data: {len(final_table_filtered.columns)}\")\n",
    "print(f\"Number of rows in final data: {len(final_table_filtered)}\")\n",
    "print(f\"Number of features in final data: {len([col for col in final_table_filtered.columns if col not in ['subject_id', 'stay_id', 'timestep'] and not col.endswith('_mask')])}\")\n",
    "print(f\"Number of mask columns in final data: {len([col for col in final_table_filtered.columns if col.endswith('_mask')])}\")\n",
    "\n",
    "# Check for isolated mask columns (without corresponding feature)\n",
    "isolated_masks = []\n",
    "for col in final_table_filtered.columns:\n",
    "    if col.endswith('_mask'):\n",
    "        feature_name = col[:-5]  # Remove '_mask' suffix\n",
    "        if feature_name not in final_table_filtered.columns:\n",
    "            isolated_masks.append(col)\n",
    "\n",
    "if isolated_masks:\n",
    "    print(f\"Warning: Found isolated mask columns: {isolated_masks}\")\n",
    "else:\n",
    "    print(\"No isolated mask columns found\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import yaml\n",
    "\n",
    "with open('icd_9_10_definitions.yaml', 'r') as f:\n",
    "    phenotype_definitions = yaml.safe_load(f)\n",
    "\n",
    "code_phenotype_mapping = {\n",
    "    code.strip(): (dx, phenotype_definitions[dx]['use_in_benchmark'])\n",
    "    for dx, dx_def in phenotype_definitions.items()\n",
    "    for code in dx_def['codes']\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(patient_base)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "final_table.head(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cohort = patient_base[(patient_base['outtime']-patient_base['intime']).dt.total_seconds()/3600 > 2]\n",
    "cohort = patient_base[(patient_base['outtime']-patient_base['intime']).dt.total_seconds()/3600 > 48]\n",
    "\n",
    "print(f'Number of stays having LoS ≥ 48 hours:', cohort.shape[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_cxrs(row):\n",
    "    start_time = row.intime - pd.Timedelta('24 hours')\n",
    "    # end_time = row.intime + pd.Timedelta('12 hours')\n",
    "    end_time = row.intime + pd.Timedelta('48 hours')\n",
    "    return [x for x in row.cxr_within_hadm if start_time <= x[1] <= end_time]\n",
    "            \n",
    "cohort['valid_cxrs'] = patient_base.apply(filter_cxrs, axis=1)\n",
    "print(f\"Number of stays with valid CXRs: {(cohort['valid_cxrs'].apply(len)>0).sum()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "split_dir = os.path.join(main_dir, 'splits')\n",
    "os.makedirs(split_dir, exist_ok=True)\n",
    "\n",
    "cohort.to_csv(os.path.join(split_dir, 'data_cohort.csv'), index=False)\n",
    "\n",
    "features_dict = {\n",
    "    'chartlab_feature': chartlab_feature,\n",
    "    'treatment_feature': treatment_feature\n",
    "}\n",
    "\n",
    "with open(os.path.join(split_dir, 'features.yaml'), 'w') as f:\n",
    "    yaml.dump(features_dict, f, default_flow_style=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for seed in range(5):\n",
    "    print(f'Fold {seed+1}:')\n",
    "    df_train, df_test, y_train, y_test = train_test_split(cohort,\n",
    "                                                          cohort['icu_mortality'].values,\n",
    "                                                          test_size=0.2,\n",
    "                                                          random_state=seed)\n",
    "    df_train, df_val, y_train, y_val = train_test_split(df_train,\n",
    "                                                        df_train['icu_mortality'].values,\n",
    "                                                        test_size=1/8,\n",
    "                                                        random_state=10+seed)\n",
    "    \n",
    "    outpath = os.path.join(split_dir, f'fold{seed+1}')\n",
    "    os.makedirs(outpath, exist_ok=True)\n",
    "    \n",
    "    df_train.to_csv(os.path.join(outpath, 'stays_train.csv'), index=False)\n",
    "    df_val.to_csv(os.path.join(outpath, 'stays_val.csv'), index=False)\n",
    "    df_test.to_csv(os.path.join(outpath, 'stays_test.csv'), index=False)\n",
    "    \n",
    "    # 只获取训练集的stay_id对应的时间序列数据\n",
    "    train_stay_ids = df_train['stay_id'].unique().tolist()\n",
    "    features = chartlab_feature\n",
    "    data = final_table[final_table['stay_id'].isin(train_stay_ids)][features]\n",
    "\n",
    "    # get statistics from training data\n",
    "    stats = {\n",
    "        'median': data.median(),\n",
    "        'iqr': data.quantile(0.75) - data.quantile(0.25),\n",
    "        'mean': data.mean(),\n",
    "        'std': data.std(),\n",
    "        'min': data.min(),\n",
    "        'max': data.max(),\n",
    "    }\n",
    "    # stats = {\n",
    "    # 'median': data.median(),\n",
    "    # 'iqr': data.quantile(0.75) - data.quantile(0.25),\n",
    "    # 'mean': data.mean(),\n",
    "    # 'std': data.std(),\n",
    "    # 'min': data.min(),\n",
    "    # 'max': data.max(),\n",
    "    # 'mode': data.mode().iloc[0] if not data.empty else np.nan,  # 添加众数计算\n",
    "    # }\n",
    "\n",
    "    stats_per_feature = {}\n",
    "    for feature in features:\n",
    "        feat_stats = {}\n",
    "        for stat_var in stats:\n",
    "            value = stats[stat_var][feature]\n",
    "            # 处理None的情况\n",
    "            if value is None:\n",
    "                value = np.nan\n",
    "            feat_stats[stat_var] = float(value)\n",
    "        feat_stats['normalize'] = ('gcs' not in feature) and ('rhythm' not in feature)\n",
    "        stats_per_feature[feature] = feat_stats\n",
    "\n",
    "    with open(os.path.join(outpath, 'train_stats.yaml'), 'w') as f:\n",
    "        yaml.dump(stats_per_feature, f, default_flow_style=False)\n",
    "        \n",
    "    print(f\"Ratio of stays with valid CXRs in training subset: {(df_train['valid_cxrs'].apply(len)>0).sum()/df_train.shape[0]:.2%}\")\n",
    "    print(f\"Ratio of stays with valid CXRs in validation subset: {(df_val['valid_cxrs'].apply(len)>0).sum()/df_val.shape[0]:.2%}\")\n",
    "    print(f\"Ratio of stays with valid CXRs in test subset: {(df_test['valid_cxrs'].apply(len)>0).sum()/df_test.shape[0]:.2%}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "benchmark",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
