{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Builds the Hugging Face `datasets` (pretraining and fine-tuning splits) for the OOD experiments from MIMIC-IV diagnosis codes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import random\n",
    "from collections import Counter\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "SEED = 42\n",
    "\n",
    "pd.set_option('display.max_colwidth', None)\n",
    "random.seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set the local path to the MIMIC-IV `hosp` directory below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Directory holding the MIMIC-IV hosp CSVs. Set the MIMIC_IV_HOSP environment\n",
    "# variable to override the default without editing this cell.\n",
    "PATH_TO_MIMIC = os.environ.get('MIMIC_IV_HOSP', '/root/data/mimic-iv_data/hosp')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _load_ids(stem):\n",
    "    \"\"\"Unpickle and return ./data/<stem>.pkl.\n",
    "\n",
    "    pickle.load can execute arbitrary code, so only point this at trusted files.\n",
    "    \"\"\"\n",
    "    with open(f'./data/{stem}.pkl', 'rb') as f:\n",
    "        return pickle.load(f)\n",
    "\n",
    "\n",
    "finetune_d_t_ids = _load_ids('finetune_disease_test_ids')\n",
    "finetune_d_tr_ids = _load_ids('finetune_disease_train_ids')\n",
    "finetune_m_t_ids = _load_ids('finetune_mortality_test_ids')\n",
    "finetune_m_tr_ids = _load_ids('finetune_mortality_train_ids')\n",
    "pretrain_ids = _load_ids('pretrain_ids')\n",
    "finetune_t_ids = _load_ids('finetune_testset_ids')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path_to_data = PATH_TO_MIMIC\n",
    "\n",
    "# Load the MIMIC-IV hosp tables, keeping only the columns used below.\n",
    "dfd = pd.read_csv(f'{path_to_data}/diagnoses_icd.csv')\n",
    "dfa = pd.read_csv(f'{path_to_data}/admissions.csv')[['subject_id', 'hadm_id', 'admittime']]\n",
    "dfp = pd.read_csv(f'{path_to_data}/patients.csv')[['subject_id', 'dod']]\n",
    "\n",
    "# Suffix each code with its ICD version so equal code strings from different versions stay distinct.\n",
    "dfd['icd_code'] = dfd['icd_code'] + '-' + dfd['icd_version'].astype(str)\n",
    "dfd = dfd.drop('icd_version', axis=1)\n",
    "\n",
    "# Explicit join keys, so an unexpected shared column cannot silently become a key.\n",
    "df = (\n",
    "    dfd\n",
    "    .merge(dfa, how='inner', on=['subject_id', 'hadm_id'])\n",
    "    .merge(dfp, how='inner', on='subject_id')\n",
    ")\n",
    "df['admittime'] = pd.to_datetime(df['admittime'])\n",
    "df['dod'] = pd.to_datetime(df['dod'])\n",
    "# Whole days from admission to death; NaN where no dod is recorded.\n",
    "df['days_until_death'] = (df['dod'] - df['admittime']).dt.days\n",
    "df = df.sort_values(by=['subject_id', 'admittime', 'seq_num'], ascending=True)\n",
    "\n",
    "# Keep only codes that have an entry in icd_to_vsa_data.\n",
    "with open('./data/icd_to_vsa_data.pkl', 'rb') as f:\n",
    "    icd_to_vsa_data = pickle.load(f)\n",
    "clean_codes = np.array(list(icd_to_vsa_data.keys()))\n",
    "df_clean = df[df['icd_code'].isin(clean_codes)]\n",
    "print(df_clean['icd_code'].unique().shape[0])\n",
    "df = df_clean\n",
    "\n",
    "# Number each subject's admissions 1, 2, ... in order of first appearance. Since df is\n",
    "# sorted by subject_id then admittime, that is chronological order.\n",
    "dfv = df[['subject_id', 'hadm_id']].drop_duplicates()\n",
    "dfv = dfv.assign(visit_order=dfv.groupby('subject_id').cumcount() + 1)[['hadm_id', 'visit_order']]\n",
    "assert dfv['hadm_id'].is_unique, 'an hadm_id belongs to more than one subject'\n",
    "df = pd.merge(df, dfv, on='hadm_id')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Narrow to the fields the per-subject dictionaries are built from.\n",
    "KEEP_COLUMNS = ['subject_id', 'visit_order', 'seq_num', 'icd_code', 'days_until_death']\n",
    "df = df[KEEP_COLUMNS]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "def createDict(df, savename):\n",
    "    \"\"\"Group `df` rows by subject and pickle the result to `savename`.\n",
    "\n",
    "    The pickled object maps subject_id -> {column: [one value per row of that subject]}\n",
    "    for the columns listed below, with rows kept in `df` order. Columns are selected by\n",
    "    name, so `df` may carry extra columns or a different column order.\n",
    "    \"\"\"\n",
    "    # Also the key order of each subject's dict, which becomes the dataset column order.\n",
    "    columns = ['subject_id', 'icd_code', 'visit_order', 'seq_num', 'days_until_death']\n",
    "    dd = {}\n",
    "    rows = df[columns].itertuples(index=False, name=None)\n",
    "    for row in tqdm(rows, total=len(df)):\n",
    "        subject_id = row[0]\n",
    "        if subject_id not in dd:\n",
    "            dd[subject_id] = {col: [] for col in columns}\n",
    "        for col, value in zip(columns, row):\n",
    "            dd[subject_id][col].append(value)\n",
    "    with open(savename, 'wb') as f:\n",
    "        pickle.dump(dd, f)\n",
    "\n",
    "\n",
    "# (id list, output pickle) pairs, written in this order.\n",
    "splits = [\n",
    "    (finetune_t_ids, './data/ft_dict_with_subject_id.pkl'),\n",
    "    (pretrain_ids, './data/pretrain_dict_with_subject_id.pkl'),\n",
    "    (finetune_d_t_ids, './data/fd_t_dict_with_subject_id.pkl'),\n",
    "    (finetune_d_tr_ids, './data/fd_tr_dict_with_subject_id.pkl'),\n",
    "    (finetune_m_t_ids, './data/fm_t_dict_with_subject_id.pkl'),\n",
    "    (finetune_m_tr_ids, './data/fm_tr_dict_with_subject_id.pkl'),\n",
    "]\n",
    "for ids, savename in splits:\n",
    "    createDict(df[df['subject_id'].isin(ids)], savename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import Dataset\n",
    "from transformers import AutoTokenizer\n",
    "import pickle\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained('./tokenizer-mimic-iv-icd-final/')\n",
    "\n",
    "\n",
    "def createData(filename, savelocation, testset_split=False, test_size=0.1, seed=42, max_length=128, num_proc=4):\n",
    "    \"\"\"Tokenize a pickled per-subject dict (as written by createDict) and save it as an HF dataset.\n",
    "\n",
    "    Args:\n",
    "        filename: pickle of {subject_id: {column: [values...]}}.\n",
    "        savelocation: directory passed to `save_to_disk`.\n",
    "        testset_split: if True, split into 'train'/'test' before saving.\n",
    "        test_size: test fraction, used only when testset_split is True.\n",
    "        seed: seed for the train/test split so reruns give the same split.\n",
    "        max_length: padded/truncated sequence length, including special tokens.\n",
    "        num_proc: worker processes for `Dataset.map`.\n",
    "    \"\"\"\n",
    "    # NOTE(review): pickle.load executes arbitrary code; only use on trusted files.\n",
    "    with open(filename, 'rb') as f:\n",
    "        datadict = pickle.load(f)\n",
    "\n",
    "    tokenization_params = {\n",
    "        'max_length': max_length,\n",
    "        'truncation': True,\n",
    "        'padding': 'max_length',\n",
    "        'is_split_into_words': True,\n",
    "        'return_special_tokens_mask': True,\n",
    "    }\n",
    "\n",
    "    # {subject: {col: [...]}} -> column-oriented {col: [one list per subject, ...]}.\n",
    "    dd = {}\n",
    "    for d in datadict.values():\n",
    "        for k, v in d.items():\n",
    "            dd.setdefault(k, []).append(v)\n",
    "\n",
    "    ds = Dataset.from_dict(dd)\n",
    "\n",
    "    def process_function(entry, tokenization_params):\n",
    "        codes = entry['icd_code']\n",
    "        encoding = tokenizer(codes, **tokenization_params)\n",
    "        # Overwrite token_type_ids with each code's visit number. Writing from index 1 and\n",
    "        # reserving max_length - 2 slots assumes one leading and one trailing special token,\n",
    "        # and one token per ICD code (TODO confirm for this tokenizer).\n",
    "        num_codes_to_keep = min(tokenization_params['max_length'] - 2, len(codes))\n",
    "        encoding['token_type_ids'][1:num_codes_to_keep + 1] = entry['visit_order'][:num_codes_to_keep]\n",
    "        entry.update(encoding)\n",
    "        return entry\n",
    "\n",
    "    ds = ds.map(\n",
    "        process_function,\n",
    "        fn_kwargs={'tokenization_params': tokenization_params},\n",
    "        num_proc=num_proc,\n",
    "        remove_columns=['icd_code', 'visit_order', 'seq_num'],\n",
    "    )\n",
    "    print(ds)\n",
    "    if testset_split:\n",
    "        # Seeded: an unseeded split would change on every run of the notebook.\n",
    "        ds = ds.train_test_split(test_size=test_size, seed=seed)\n",
    "    ds.save_to_disk(savelocation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fd_t / fm_t pickles come from the *test* id lists and fd_tr / fm_tr from the *train*\n",
    "# id lists (see the createDict cell), so each goes to the matching directory.\n",
    "createData('./data/ft_dict_with_subject_id.pkl', './data/Finetuning_Testset/')\n",
    "createData('./data/fd_tr_dict_with_subject_id.pkl', './data/Finetuning_Disease_Prediction/train/')\n",
    "createData('./data/fd_t_dict_with_subject_id.pkl', './data/Finetuning_Disease_Prediction/test/')\n",
    "createData('./data/fm_tr_dict_with_subject_id.pkl', './data/Finetuning_Mortality_Prediction/train/')\n",
    "createData('./data/fm_t_dict_with_subject_id.pkl', './data/Finetuning_Mortality_Prediction/test/')\n",
    "createData('./data/pretrain_dict_with_subject_id.pkl', './data/Pretraining/', testset_split=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
