{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Process the data for the sepsis treatment experiment\n",
    "\n",
    "Prerequisite:\n",
    " - Apply for access to the MIMIC-III dataset here: https://mimic.mit.edu/iii/gettingstarted/\n",
    " - Follow the instructions in the AI Clinician repo to extract MIMICtable.mat (using the script AIClinician_mimic3_dataset_160219.m): https://gitlab.doc.ic.ac.uk/AIClinician/AIClinician\n",
    " - Run fixMIMICTable.m and extract_csv_data.m to preprocess the data and produce the MIMIC_outcome.csv table\n",
    " \n",
    "What this file does:\n",
    " - Generates s45da_mimic_train_episodes, s45da_mimic_val_episodes and s45da_mimic_test_episodes (plus the intermediate `*_cur_pibs.npy` and `*_cur_nn_action_dist.npy` arrays) in project_folder/data for use by the RL code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm.notebook import tqdm\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output directory (relative path) for every array and episode file this notebook saves\n",
    "data_folder = \"../../data\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the per-row outcome flag (last column) into a sparse terminal reward:\n",
    "# 0 on every non-final step of a patient, 100*(1-flag) on the patient's final step.\n",
    "data = pd.read_csv('sepsis_matlab/MIMIC_outcome.csv')\n",
    "ids = data['id'].to_numpy()\n",
    "# A row ends an episode when the next row belongs to another patient (or there is no next row).\n",
    "is_last_step = np.ones(len(ids), dtype=bool)\n",
    "is_last_step[:-1] = ids[:-1] != ids[1:]\n",
    "# Vectorised: avoids one slow scalar .iloc read/write per row of the table.\n",
    "data.iloc[:, -1] = np.where(is_last_step, 100*(1-data.iloc[:, -1].to_numpy()), 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct patient identifiers found in the outcome table\n",
    "patient_ids = data['id'].unique()\n",
    "n_patients = len(patient_ids)\n",
    "print('Total number of patients : {}'.format(n_patients))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): pickle.load can execute arbitrary code; only load a split file from a trusted source.\n",
    "with open('sepsis_matlab/train_val_test_split.pkl', 'rb') as split_file:\n",
    "    split_data = pickle.load(split_file)\n",
    "\n",
    "# Patient ids assigned to each split\n",
    "val_patient_id = split_data['val_patient_id']\n",
    "train_patient_id = split_data['train_patient_id']\n",
    "test_patient_id = split_data['test_patient_id']\n",
    "\n",
    "n_train_patients, n_val_patients, n_test_patients = (\n",
    "    len(ids) for ids in (train_patient_id, val_patient_id, test_patient_id)\n",
    ")\n",
    "print(n_train_patients, n_val_patients, n_test_patients)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class KNN_policy():\n",
    "    '''Behaviour-policy estimate built from the k nearest neighbours of a state.\n",
    "\n",
    "    Distances are the weighted mean of the squared per-feature differences\n",
    "    between a query state and every row of the reference data frame.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, df, columns, k=100, weight=None):\n",
    "        ''' Learning KNN policy with KNN \n",
    "        input\n",
    "        -----\n",
    "        df : dataframe holding the covariate columns and an 'action' column\n",
    "        columns : names of the columns in df that are covariates\n",
    "        k : int number of neighbours\n",
    "        weight : per-covariate weight vector (all ones when None)\n",
    "        '''\n",
    "        # Columns in df that are covariates\n",
    "        self.columns = columns\n",
    "        if weight is None:\n",
    "            weight = np.ones((len(self.columns), 1))\n",
    "        # Turn the weight vector into a Series labelled by covariate name\n",
    "        self.weight = pd.DataFrame(weight).T\n",
    "        self.weight.columns = self.columns\n",
    "        self.weight = self.weight.iloc[0, :]\n",
    "\n",
    "        self.X = df.loc[:, self.columns] # feature data frame\n",
    "        # Cast to int: actions are used as array indices below, and a float\n",
    "        # 'action' column would make that indexing raise IndexError.\n",
    "        self.A = df.loc[:, 'action'].to_numpy().astype(int)\n",
    "        self.nA = 26 # number of action slots : actions are indexed from 1 (Matlab), so slot 0 is unused\n",
    "        self.K = k\n",
    "        # whichA[i, a] is 1 when reference sample i took action a and inf otherwise\n",
    "        self.whichA = np.ones((self.A.shape[0],self.nA))*np.inf\n",
    "        self.whichA[np.arange(self.A.shape[0]),self.A] = 1\n",
    "\n",
    "    def _as_query_row(self, x):\n",
    "        '''Validate a query state and return it as a Series labelled by self.columns.'''\n",
    "        if isinstance(x, pd.DataFrame):\n",
    "            # The state is read from the first row, so the frame needs one column per covariate.\n",
    "            assert x.shape[1] == len(self.columns), 'Shape of iuput doesn\\'t match, got {}'.format(x.shape)\n",
    "        if isinstance(x, np.ndarray):\n",
    "            assert x.shape[0] == len(self.columns), 'Shape of iuput doesn\\'t match, got {}'.format(x.shape)\n",
    "            x = pd.DataFrame(x).T\n",
    "            x.columns = self.columns\n",
    "        return x.iloc[0, :]\n",
    "\n",
    "    def _distances(self, x):\n",
    "        '''Weighted mean squared difference between x and every reference state, as an array.'''\n",
    "        diff = (self.X - self._as_query_row(x))**2\n",
    "        return (diff*self.weight).mean(axis=1).to_numpy()\n",
    "\n",
    "    def get_action_probability(self, x):\n",
    "        '''Return the action frequencies among the K reference states closest to x.\n",
    "\n",
    "        The result has length nA; entry a is the fraction of those neighbours\n",
    "        whose recorded action is a.\n",
    "        '''\n",
    "        distance = self._distances(x)\n",
    "        # indexes of the K closest reference states\n",
    "        idxs = np.argsort(distance)[:self.K]\n",
    "        # action prob: count each neighbour's action (np.add.at accumulates repeated indices)\n",
    "        action_prob = np.zeros(self.nA)\n",
    "        np.add.at(action_prob, self.A[idxs], 1)\n",
    "        return action_prob/np.sum(action_prob)\n",
    "\n",
    "    def get_nearest_same_action(self, x):\n",
    "        '''Return, for every action, the distance from x to the closest reference state that took it.\n",
    "\n",
    "        The result has length nA; actions with no reference sample get inf.\n",
    "        '''\n",
    "        distance = self._distances(x)\n",
    "        # Mask with np.where instead of multiplying by whichA: 0*inf is NaN, which\n",
    "        # produced RuntimeWarnings whenever a reference state matched x exactly.\n",
    "        per_action = np.where(np.isfinite(self.whichA.T), distance, np.inf)\n",
    "        action_dist = np.nanmin(per_action,axis=-1)\n",
    "        return action_dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Covariate names handed to the kNN policies as their feature columns\n",
    "COLUMNS = [\n",
    "    'gender', 're_admission', 'mechvent', 'age', 'Weight_kg', 'GCS',\n",
    "    'HR', 'SysBP', 'MeanBP', 'DiaBP', 'RR', 'Temp_C',\n",
    "    'FiO2_1', 'Potassium', 'Sodium', 'Chloride', 'Glucose', 'Magnesium',\n",
    "    'Calcium', 'Hb', 'WBC_count', 'Platelets_count', 'PTT', 'PT',\n",
    "    'Arterial_pH', 'paO2', 'paCO2', 'Arterial_BE', 'Arterial_lactate', 'HCO3',\n",
    "    'Shock_Index', 'PaO2_FiO2', 'cumulated_balance', 'SOFA', 'SIRS', 'SpO2',\n",
    "    'BUN', 'Creatinine', 'SGOT', 'SGPT', 'Total_bili', 'INR',\n",
    "    'output_total', 'output_4hourly',\n",
    "]\n",
    "\n",
    "# Rows of the outcome table that belong to each split's patients\n",
    "train_df = data[data['id'].isin(train_patient_id)]\n",
    "val_df = data[data['id'].isin(val_patient_id)]\n",
    "test_df = data[data['id'].isin(test_patient_id)]\n",
    "\n",
    "# One kNN policy per split, each fitted on that split's own rows\n",
    "train_policy = KNN_policy(train_df, COLUMNS, k=100)\n",
    "val_policy = KNN_policy(val_df, COLUMNS, k=100)\n",
    "test_policy = KNN_policy(test_df, COLUMNS, k=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Query states for each split: positional columns 2..45 as plain arrays\n",
    "# (assumes these are the COLUMNS covariates in the same order - TODO confirm against the CSV layout)\n",
    "train_da, val_da, test_da = (\n",
    "    split_df.iloc[:, 2:46].to_numpy() for split_df in (train_df, val_df, test_df)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get nearest neighbor distance with same action"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per training state; one column per action slot 1..25 (slot 0 of the policy output is dropped)\n",
    "n_train_rows = len(train_da)\n",
    "train_nearest_same_action = np.zeros((n_train_rows, 25))\n",
    "\n",
    "for row_idx in tqdm(range(n_train_rows)):\n",
    "    per_action_dist = train_policy.get_nearest_same_action(train_da[row_idx])\n",
    "    train_nearest_same_action[row_idx, :] = per_action_dist[1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the per-action nearest-neighbour distances of the training states\n",
    "np.save(f\"{data_folder}/s45da_mimic_train_cur_nn_action_dist.npy\", train_nearest_same_action)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation states are compared against the training states (train_policy), not against each other\n",
    "n_val_rows = len(val_da)\n",
    "val_nearest_same_action = np.zeros((n_val_rows, 25))\n",
    "\n",
    "for row_idx in tqdm(range(n_val_rows)):\n",
    "    per_action_dist = train_policy.get_nearest_same_action(val_da[row_idx])\n",
    "    val_nearest_same_action[row_idx, :] = per_action_dist[1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the per-action nearest-neighbour distances of the validation states\n",
    "np.save(f\"{data_folder}/s45da_mimic_val_cur_nn_action_dist.npy\", val_nearest_same_action)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test states are compared against the training states (train_policy), not against each other\n",
    "n_test_rows = len(test_da)\n",
    "test_nearest_same_action = np.zeros((n_test_rows, 25))\n",
    "\n",
    "for row_idx in tqdm(range(n_test_rows)):\n",
    "    per_action_dist = train_policy.get_nearest_same_action(test_da[row_idx])\n",
    "    test_nearest_same_action[row_idx, :] = per_action_dist[1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the per-action nearest-neighbour distances of the test states\n",
    "np.save(f\"{data_folder}/s45da_mimic_test_cur_nn_action_dist.npy\", test_nearest_same_action)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get kNN probability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute (vectors of) probabilities for all states in dataset\n",
    "def compute_pibs(df, knn):\n",
    "    '''Query knn for the action distribution of every row of df.\n",
    "\n",
    "    Positional columns 2..45 of a row are used as the query state. Entry 0 of\n",
    "    the policy output is dropped because the kNN policy indexes actions from 1,\n",
    "    so the result has shape (number of rows, 25).\n",
    "    '''\n",
    "    rows = np.array(df)\n",
    "    n_rows = len(rows)\n",
    "    pibs = np.zeros((n_rows, 25))\n",
    "\n",
    "    for row_idx in tqdm(range(n_rows)):\n",
    "        state = rows[row_idx][2:46]\n",
    "        action_prob = knn.get_action_probability(state)\n",
    "        pibs[row_idx, :] = action_prob[1:]\n",
    "    return pibs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Action probabilities of the training states under the kNN policy fitted on the training split\n",
    "train_current_pibs = compute_pibs(train_df, train_policy)\n",
    "np.save(f\"{data_folder}/s45da_mimic_train_cur_pibs.npy\", train_current_pibs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation states scored by the kNN policy fitted on the validation split itself\n",
    "val_current_pibs = compute_pibs(val_df, val_policy)\n",
    "# Same states scored by the kNN policy fitted on the training split\n",
    "val_estm_current_pibs = compute_pibs(val_df, train_policy)\n",
    "\n",
    "for file_stem, array in (('s45da_mimic_val_estm_cur_pibs', val_estm_current_pibs),\n",
    "                         ('s45da_mimic_val_cur_pibs', val_current_pibs)):\n",
    "    np.save(f\"{data_folder}/{file_stem}.npy\", array)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test_current_pibs: test states scored by the kNN policy fitted on the test split itself\n",
    "# test_estm_current_pibs: the same states scored by the kNN policy fitted on the training split\n",
    "test_current_pibs = compute_pibs(test_df, test_policy)\n",
    "test_estm_current_pibs = compute_pibs(test_df, train_policy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist both sets of test-split action probabilities\n",
    "np.save(f\"{data_folder}/s45da_mimic_test_cur_pibs.npy\", test_current_pibs)\n",
    "np.save(f\"{data_folder}/s45da_mimic_test_estm_cur_pibs.npy\", test_estm_current_pibs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Construct dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def episodic_dataset(train_data, pibs, estm_pibs, nn_action_dist, obs_dim, n_patient, include_step=True,\n",
    "                     horizon=20, n_actions=25):\n",
    "    '''Pack per-step rows into fixed-length, padded episodes (replay-buffer layout of the BCQ/PQL code).\n",
    "\n",
    "    input\n",
    "    -----\n",
    "    train_data : 2-D array with one row per step. Column 0 is the patient id, columns 1..obs_dim-1\n",
    "                 form the observation, column obs_dim is the 1-indexed action and column obs_dim+1\n",
    "                 the reward. Consecutive rows with the same id form one episode.\n",
    "    pibs, estm_pibs, nn_action_dist : arrays with one row per row of train_data and n_actions columns\n",
    "    obs_dim : int, position of the action column in train_data\n",
    "    n_patient : int, number of episodes to allocate\n",
    "    include_step : bool, when False column 1 (presumably the step counter - TODO confirm) is left\n",
    "                   out of the observation\n",
    "    horizon : int, number of step slots allocated per episode\n",
    "    n_actions : int, width of the per-action arrays\n",
    "\n",
    "    returns\n",
    "    -------\n",
    "    dict of arrays with leading shape (n_patient, horizon): 'observations', 'actions' (0-indexed),\n",
    "    'rewards', 'not_done', 'pibs', 'estm_pibs' and 'nn_action_dist'. Step slots after the end of an\n",
    "    episode keep zeros, except 'pibs'/'estm_pibs' (set to 1.0) and 'nn_action_dist' (left at 1e9).\n",
    "    '''\n",
    "    # The observation starts at column 1 or 2; size the buffer to match so that\n",
    "    # include_step=False no longer fails with a shape mismatch.\n",
    "    first_obs_col = 1 if include_step else 2\n",
    "    # Formatted the training set to be the replay buffer in BCQ/PQL code\n",
    "    formatted_dataset = {'observations': np.zeros((n_patient,horizon,obs_dim-first_obs_col)),\n",
    "                         # builtin int instead of np.int, which was removed in NumPy 1.24\n",
    "                         'actions': np.zeros((n_patient,horizon,1), dtype=int),\n",
    "                         'rewards': np.zeros((n_patient,horizon,1)),\n",
    "                         'not_done': np.zeros((n_patient,horizon,1)),\n",
    "                         'pibs': np.zeros((n_patient,horizon,n_actions)),\n",
    "                         'estm_pibs': np.zeros((n_patient,horizon,n_actions)),\n",
    "                         'nn_action_dist': np.ones((n_patient,horizon,n_actions))*1e9,\n",
    "                        }\n",
    "\n",
    "    n_eps = 0  # index of the episode being filled\n",
    "    t = 0      # step index inside that episode\n",
    "    n_rows = len(train_data)\n",
    "    for i in range(n_rows):\n",
    "        row = train_data[i]\n",
    "        formatted_dataset['observations'][n_eps,t,:] = np.array(row[first_obs_col:obs_dim])\n",
    "        # stored actions are shifted to start at 0\n",
    "        formatted_dataset['actions'][n_eps,t,:] = row[obs_dim] - 1\n",
    "        formatted_dataset['rewards'][n_eps,t,:] = row[obs_dim+1]\n",
    "        formatted_dataset['pibs'][n_eps,t,:] = pibs[i,:]\n",
    "        formatted_dataset['estm_pibs'][n_eps,t,:] = estm_pibs[i,:]\n",
    "        formatted_dataset['nn_action_dist'][n_eps,t,:] = nn_action_dist[i,:]\n",
    "        if i != n_rows-1 and train_data[i+1][0] == row[0]:\n",
    "            # haven't moved to the next patient, same episode\n",
    "            formatted_dataset['not_done'][n_eps,t,:] = 1\n",
    "            t += 1\n",
    "        else:\n",
    "            # last step of this patient: mark it done and fill the padding slots\n",
    "            formatted_dataset['not_done'][n_eps,t:,:] = 0\n",
    "            formatted_dataset['pibs'][n_eps,t+1:,:] = 1.0\n",
    "            formatted_dataset['estm_pibs'][n_eps,t+1:,:] = 1.0\n",
    "            n_eps += 1\n",
    "            t = 0\n",
    "    print(\"Dataset with\",n_eps,\"episodes\")\n",
    "    return formatted_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Passed to episodic_dataset as obs_dim: position of the action column in a data row (the reward column follows it)\n",
    "obs_dimensions = 46"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For the training split the same kNN probabilities serve as both pibs and estm_pibs\n",
    "train_dataset = episodic_dataset(\n",
    "    train_data=train_df.to_numpy(),\n",
    "    pibs=train_current_pibs,\n",
    "    estm_pibs=train_current_pibs,\n",
    "    nn_action_dist=train_nearest_same_action,\n",
    "    obs_dim=obs_dimensions,\n",
    "    n_patient=len(train_patient_id),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pibs come from the validation-fitted kNN policy, estm_pibs from the training-fitted one\n",
    "val_dataset = episodic_dataset(\n",
    "    train_data=val_df.to_numpy(),\n",
    "    pibs=val_current_pibs,\n",
    "    estm_pibs=val_estm_current_pibs,\n",
    "    nn_action_dist=val_nearest_same_action,\n",
    "    obs_dim=obs_dimensions,\n",
    "    n_patient=len(val_patient_id),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pibs come from the test-fitted kNN policy, estm_pibs from the training-fitted one\n",
    "test_dataset = episodic_dataset(\n",
    "    train_data=test_df.to_numpy(),\n",
    "    pibs=test_current_pibs,\n",
    "    estm_pibs=test_estm_current_pibs,\n",
    "    nn_action_dist=test_nearest_same_action,\n",
    "    obs_dim=obs_dimensions,\n",
    "    n_patient=len(test_patient_id),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Serialise the training episodes for the RL code\n",
    "with open(f\"{data_folder}/s45da_mimic_train_episodes\", 'wb') as episode_file:\n",
    "    pickle.dump(train_dataset, episode_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Serialise the validation episodes for the RL code\n",
    "with open(f\"{data_folder}/s45da_mimic_val_episodes\", 'wb') as episode_file:\n",
    "    pickle.dump(val_dataset, episode_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Serialise the test episodes for the RL code\n",
    "with open(f\"{data_folder}/s45da_mimic_test_episodes\", 'wb') as episode_file:\n",
    "    pickle.dump(test_dataset, episode_file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
