{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f407150e-d967-4dfc-943c-65a9506a71dd",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6bacb236-523c-44ab-8a32-d663f1d5d747",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import pickle\n",
    "from typing import List, Dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "697ff8d3-ec8e-4ebe-aeda-f9eabda3f915",
   "metadata": {},
   "outputs": [],
   "source": [
    "# These pickles are the outputs of the medbert repo: https://github.com/ZhiGroup/Med-BERT/blob/master/Pretraining%20Code/Readme.md\n",
    "\n",
    "MEDBERT_CODES_DICT_PATH = 'eicu_crd.types'\n",
    "PRETRAINED_TRAIN_PICKLE_PATH = 'eicu_crd.bencs.train'\n",
    "PRETRAINED_VALIDATION_PICKLE_PATH = 'eicu_crd.bencs.valid'\n",
    "PRETRAINED_TEST_PICKLE_PATH = 'eicu_crd.bencs.test'\n",
    "\n",
    "MEDBERT_OUTPUT_PICKLES_DIR = 'eicu_crd'\n",
    "TARGET_DISEASE_IDS = ['51881', '51882', '51883', '51884', '7991'] # icd9 codes (according CCS category) for Adlt resp that found in eicu-crd data.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "de36c57e-abbe-41cb-a96a-ae95ec19cc5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(PRETRAINED_TRAIN_PICKLE_PATH, 'rb') as f:\n",
    "    medbert_train_data = pickle.load(f)\n",
    "with open(PRETRAINED_VALIDATION_PICKLE_PATH, 'rb') as f:\n",
    "    medbert_validation_data = pickle.load(f)\n",
    "with open(PRETRAINED_TEST_PICKLE_PATH, 'rb') as f:\n",
    "    medbert_test_data = pickle.load(f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3658eb5-2fe8-4c02-bb56-af03a117313f",
   "metadata": {},
   "outputs": [],
   "source": [
    "len(medbert_train_data), len(medbert_validation_data), len(medbert_test_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "568b598a-6aff-4ccd-9c72-2d4bdc61e056",
   "metadata": {},
   "source": [
    "### Convert pickle to df\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd28ae15-a397-43d0-9066-bd718e66ba13",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df = pd.DataFrame(medbert_train_data, columns= ['person_id', 'los', 'time_not_used', 'code', 'visits'])\n",
    "validation_df = pd.DataFrame(medbert_validation_data, columns= ['person_id', 'los', 'time_not_used', 'code', 'visits'])\n",
    "test_df = pd.DataFrame(medbert_test_data, columns= ['person_id', 'los', 'time_not_used', 'code', 'visits'])\n",
    "for x in (train_df, validation_df, test_df):\n",
    "    x.drop(columns=['los', 'time_not_used'], inplace=True)\n",
    "    x.drop(x[x['code'].apply(lambda x: len(x) <= 1)].index, inplace=True) # remove patients with only one diagnosis\n",
    "train_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d6f96c9-30e7-4db9-bd7e-f17451474c7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(MEDBERT_CODES_DICT_PATH, 'rb') as f:\n",
    "          code_to_id_dict = pickle.load(f)\n",
    "print(code_to_id_dict)\n",
    "\n",
    "def convert_codes_to_ids(codes: List[str], code_to_id_dict: Dict[str, int]):\n",
    "    converted_codes = []\n",
    "    for code in codes: \n",
    "        converted_codes.append(code_to_id_dict[str(code)])\n",
    "    return converted_codes\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8c02ffe-2a59-4afd-92b6-2eb8eba80003",
   "metadata": {},
   "outputs": [],
   "source": [
    "id_to_code_dict = {v:k for k, v in code_to_id_dict.items()}\n",
    "file_path = 'eicu_crd_id_to_code.types'\n",
    "with open(file_path, 'wb') as file:\n",
    "    pickle.dump(id_to_code_dict, file)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "12b37a75-8620-4644-bc74-3bdf2c2a7a41",
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_vocab_to_with_icd_dot():\n",
    "    id_to_icd9_dot = {v:k for k, v in code_to_id_dict.items()}\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e8142cff-2794-4309-a1cf-96ff1fd73c03",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[14, 10, 131, 239, 122]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "TARGET_DISEASE_IDS = convert_codes_to_ids(TARGET_DISEASE_IDS, code_to_id_dict)\n",
    "TARGET_DISEASE_IDS"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b66e1d3-9eb7-499b-b495-5b7ef1209002",
   "metadata": {},
   "source": [
    "\n",
    "## Convert to medbert format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a626213-1465-4e21-9118-4f0e5c8959c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6e5b5c4f-cf87-4094-9155-4be277267ae2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_sep_between_visits(row):\n",
    "    codes, visits = row.code, row.visits\n",
    "    new_codes = []\n",
    "    new_visits = []\n",
    "    \n",
    "    for i in range(len(codes)):\n",
    "        new_codes.append(codes[i])\n",
    "        new_visits.append(visits[i])\n",
    "        if i < len(codes) - 1 and visits[i] != visits[i + 1]:\n",
    "            new_codes.append('SEP')\n",
    "            new_visits.append('SEP')\n",
    "    new_codes.append('SEP')\n",
    "    new_visits.append('SEP')\n",
    "    assert len(new_codes) == len(new_visits)\n",
    "    return new_codes, new_visits\n",
    "\n",
    "for x in (train_df, validation_df, test_df):\n",
    "    x[['code', 'visits']] = x.apply(add_sep_between_visits, axis=1, result_type='expand')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0764fa47-0c05-415c-9452-aa43b23493d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def has_first_dignosis_target(inner_list):\n",
    "    first_diag = inner_list[0]\n",
    "    return first_diag in TARGET_DISEASE_IDS\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "595fb128-f933-434f-9fd3-eacbeef978f4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "has_first_dignosis_target([19, 29, 31, 102, 238, 8, 'SEP'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eaa9f7ec-80db-4f56-a79b-f05a2da56b63",
   "metadata": {},
   "outputs": [],
   "source": [
    "def target_disease_in_first_diagnosis(row):\n",
    "    codes = row.code\n",
    "    return codes[0] in TARGET_DISEASE_IDS\n",
    "    \n",
    "for x in (train_df, validation_df, test_df):\n",
    "    mask = x.apply(target_disease_in_first_diagnosis, axis=1)\n",
    "    x.drop(index=x[mask].index, inplace=True)\n",
    "\n",
    "train_df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8943795-30c0-450a-84d8-91c3cbe88d1f",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df.shape, validation_df.shape, test_df.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61563d8d-7b6f-4100-8bf4-44e5c2c83b0b",
   "metadata": {},
   "source": [
    "### To medbert pickle format for fine-tuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "cee913f8-a220-4f21-80cc-dd643aa7b210",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "import random\n",
    "\n",
    "def has_target_disease(target_disease_ids: List[str], codes: List[str]):\n",
    "    # return True if at least one from target_disease_ids can be found in codes and its index in the codes.\n",
    "    for index, code in enumerate(codes):\n",
    "        if code in target_disease_ids:\n",
    "            if index == 0:\n",
    "                print('Error!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')\n",
    "                print(codes)\n",
    "            return True\n",
    "    return False\n",
    "\n",
    "def random_negative_index(codes: List[str]):\n",
    "    # get a random index according the size of codes, and return the list until this index (included this index). \n",
    "    if not codes:\n",
    "        print('Empty list of codes!!!!!!!!!!!!!!!!!!!!!')\n",
    "        return []\n",
    "    \n",
    "    random_index = random.randint(0, len(codes) - 1)\n",
    "    return codes[:random_index+1]\n",
    "\n",
    "def get_positive_index(codes: List[str]): \n",
    "    if not codes:\n",
    "        print('Empty list of codes!!!!!!!!!!!!!!!!!!!!!')\n",
    "        return []\n",
    "    for index, code in enumerate(codes):\n",
    "        if code in TARGET_DISEASE_IDS:\n",
    "            return codes[:index]\n",
    "    print('Index was not found!')\n",
    "    return []\n",
    "\n",
    "def preprocess_patient_records(patient_list_data: List[str], was_target_found: bool):\n",
    "    # patient_list_data can be codes, ages, years. \n",
    "    if was_target_found:\n",
    "        return get_positive_index(patient_list_data)\n",
    "    return random_negative_index(patient_list_data)\n",
    "\n",
    "def preprocess_patient_data(row):\n",
    "    person_id = row['person_id']\n",
    "    codes = row['code']\n",
    "    visits = row['visits']\n",
    "    assert len(codes) == len(visits)\n",
    "    classification_binary_label = has_target_disease(TARGET_DISEASE_IDS, codes)\n",
    "    codes = preprocess_patient_records(codes, classification_binary_label)\n",
    "    visits = visits[:len(codes)]\n",
    "    assert len(codes) == len(visits)\n",
    "\n",
    "    return codes, visits, 1 if classification_binary_label else 0\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "427a9bf1-3c5d-4d07-8523-639834309a6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6607c8f-4fb8-4308-a07d-d518efa52ff3",
   "metadata": {},
   "outputs": [],
   "source": [
    "for x in (train_df, validation_df, test_df):\n",
    "    x[['code', 'visits', 'label']] = x.apply(preprocess_patient_data, axis=1, result_type='expand')\n",
    "train_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "420a8430-4bdf-4f87-914e-03052af38840",
   "metadata": {},
   "outputs": [],
   "source": [
    "def filter_sep(row):\n",
    "    codes = row.code\n",
    "    visits_num = row.visits\n",
    "    indexes_to_remove = []\n",
    "    assert len(codes) == len(visits_num)\n",
    "    for index, code in enumerate(codes):\n",
    "        if code == 'SEP':\n",
    "            indexes_to_remove.append(index)\n",
    "    assert len(codes) == len(visits_num)\n",
    "    codes = [code for i, code in enumerate(codes) if i not in indexes_to_remove]\n",
    "    visits_num = [num for i, num in enumerate(visits_num) if i not in indexes_to_remove]\n",
    "    assert len(codes) == len(visits_num)\n",
    "    return codes, visits_num\n",
    "\n",
    "for x in (train_df, validation_df, test_df):\n",
    "    x[['code', 'visits']] = x.apply(filter_sep, axis=1, result_type='expand')\n",
    "\n",
    "train_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cab0ea7a-b9af-4e63-a8ff-ae3c45030224",
   "metadata": {},
   "source": [
    "### To pickles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "c286362e-f43d-4ab0-b4e8-d2afcedd4e0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def write_df_to_pickle(df: pd.DataFrame, pickle_output_dir: str, df_type: str, disease_name: str):\n",
    "    # df with columns: person_id, code\n",
    "    # Create a list to store patient records\n",
    "    patient_records = []\n",
    "\n",
    "    # Iterate over each row in the DataFrame\n",
    "    for index, row in df.iterrows():\n",
    "        # Extract the necessary information from the row\n",
    "        pt_id = row['person_id']\n",
    "        label = row['label']\n",
    "        seq_list = row['code']\n",
    "        segment_list = row['visits']\n",
    "        assert len(seq_list) == len(segment_list)\n",
    "        \n",
    "        # Create a patient record as a sublist\n",
    "        patient_record = [pt_id, label, seq_list, segment_list]\n",
    "        # Append the patient record to the list of patient records\n",
    "        patient_records.append(patient_record)\n",
    "\n",
    "    # Write the list of patient records to a pickle file\n",
    "    output_pickle_path = f'{pickle_output_dir}/{disease_name}_{df_type}.pickle'\n",
    "    with open(output_pickle_path, 'wb') as file:\n",
    "        pickle.dump(patient_records, file)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "88fa0c77-637c-417c-94ef-bb0d59e7fcc7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/sise/home/benshoho/projects/Med-BERT/Fine-Tunning-Tutorials/data/eicu_crd'"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "MEDBERT_OUTPUT_PICKLES_DIR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "590eaca9-8caa-4046-a8f6-8ae030314d79",
   "metadata": {},
   "outputs": [],
   "source": [
    "for current_df, current_df_type in zip([train_df, validation_df, test_df], ['train', 'validation', 'test']):\n",
    "    write_df_to_pickle(current_df, MEDBERT_OUTPUT_PICKLES_DIR, current_df_type, disease_name='eicu_crd_adult_respiratory_failure')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "ab7202dc-efc4-4988-b10b-7da481827005",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'/sise/home/benshoho/projects/Med-BERT/Fine-Tunning-Tutorials/data/eicu_crd'"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "MEDBERT_OUTPUT_PICKLES_DIR"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}