{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Learning Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pickle\n",
    "import tqdm as tqdm\n",
    "from scipy.linalg import block_diag\n",
    "\n",
    "from utils.utils import MatrixMDP\n",
    "from core.sepsisSimDiabetes.State import State\n",
    "from core.sepsisSimDiabetes.Action import Action\n",
    "import core.sepsisSimDiabetes.MDP as simulator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following `config` dictionary contains the necessary information for the rest of the notebook:\n",
    "\n",
    "- `prob_diab`: probability of having diabetes, default = 0.2\n",
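    "- `nS`: number of states (not counting the absorbing terminal state added during preprocessing)\n",
    "- `nA`: number of actions\n",
    "- `discount`: MDP discount factor ($\\gamma$)\n",
    "- `epsilon`: softening probability ($\\epsilon$): a soft policy puts `1-epsilon` on its preferred action and spreads `epsilon` equally over the other actions\n",
    "- `mixture_prob`: weight of the optimal policy in the mixed policy; the modified policy receives the remaining `1 - mixture_prob`"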
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {\n",
    "    'prob_diab': 0.2,                # probability of having diabetes\n",
    "    'nS': State.NUM_FULL_STATES,     # number of states (terminal state not included)\n",
    "    'nA': Action.NUM_ACTIONS_TOTAL,  # number of actions\n",
    "    'discount': 0.99,                # MDP discount factor\n",
    "    'epsilon': 0.05,                 # softening probability for soft policies\n",
    "    'mixture_prob': 0.85,            # weight of the optimal policy in the mixed policy\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode a flat action index into its three binary treatment components.\n",
    "# The index is read as a 3-bit number: antibiotic (high bit), ventilation, vasopressors (low bit).\n",
    "# Integer divmod replaces the float division / np.floor / astype round-trip.\n",
    "action_idx = 7\n",
    "antibiotic, remainder = divmod(action_idx, 4)\n",
    "ventilation, vasopressors = divmod(remainder, 2)\n",
    "antibiotic, ventilation, vasopressors"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the per-subpopulation transition (tx_mat) and reward (r_mat) matrices.\n",
    "# pickle.load can execute arbitrary code: only load trusted files.\n",
    "with open(\"data/diab_txr_mats-replication.pkl\", \"rb\") as f:\n",
    "    mdict = pickle.load(f)\n",
    "tx_mat, r_mat = mdict[\"tx_mat\"], mdict[\"r_mat\"]\n",
    "\n",
    "# Subpopulation weights: index 0 gets 1 - prob_diab, index 1 gets prob_diab.\n",
    "p_mixture = np.array([1 - config['prob_diab'], config['prob_diab']])\n",
    "\n",
    "# Merge both subpopulations into a single MDP: for each action, the two\n",
    "# per-subpopulation matrices become the diagonal blocks of an (nS x nS) matrix.\n",
    "full_shape = (config['nA'], config['nS'], config['nS'])\n",
    "tx_mat_full, r_mat_full = np.zeros(full_shape), np.zeros(full_shape)\n",
    "for a in range(config['nA']):\n",
    "    tx_mat_full[a] = block_diag(tx_mat[0, a], tx_mat[1, a])\n",
    "    r_mat_full[a] = block_diag(r_mat[0, a], r_mat[1, a])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Augment the MDP with one extra state at index nS (the terminal state).\n",
    "# Every transition/reward involving it starts at zero; the next step fills them in.\n",
    "n_aug = config['nS'] + 1\n",
    "tx = np.zeros((config['nA'], n_aug, n_aug))\n",
    "tr = np.zeros_like(tx)\n",
    "tx[:, :-1, :-1] = tx_mat_full\n",
    "tr[:, :-1, :-1] = r_mat_full"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Any state entered with a reward of +1 or -1 (from any source state, under any\n",
    "# action) becomes absorbing: from it, every action leads to the terminal state.\n",
    "# Vectorized over (action, source state): equivalent to the triple loop over\n",
    "# (s0, a, s1), since the condition reads only tr, which is never modified here.\n",
    "terminating = ((tr == 1) | (tr == -1)).any(axis=(0, 1))\n",
    "tx[:, terminating, :] = 0  # transition probability to any other state is zero\n",
    "tx[:, terminating, config['nS']] = 1  # transition to the terminal state\n",
    "\n",
    "# The terminal state transitions to itself with prob 1.0 and yields reward 0.\n",
    "tx[:, config['nS'], config['nS']] = 1\n",
    "tr[:, config['nS'], config['nS']] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the augmented transition (tx) and reward (tr) tensors as a single tuple.\n",
    "with open('data/tx_tr.pkl', 'wb') as out_file:\n",
    "    pickle.dump((tx, tr), out_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Optimal Policy and Value Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "MDP = MatrixMDP(tx, tr)\n",
    "\n",
    "# Policy iteration; its result is reduced with argmax over axis 1 to one optimal action per state.\n",
    "policy_scores = MDP.policyIteration(discount=config['discount'], eval_type=1)\n",
    "optimal_policy = policy_scores.argmax(axis=1)\n",
    "\n",
    "# Value iteration for the value function, computed independently of the policy above.\n",
    "# NOTE(review): epsilon is presumably the convergence tolerance and max_iter the iteration cap.\n",
    "V = MDP.valueIteration(discount=config['discount'], epsilon=0.001, max_iter=5000)\n",
    "value_function = np.array(V)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
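    "Encode the optimal policy as a matrix `optimal_policy_st` (rows = states, columns = actions). As written, the next cell produces a one-hot (deterministic) policy. To make it soft instead, assign `1-epsilon` to the optimal action and distribute `epsilon` equally among the other actions."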
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot encode the optimal policy with config['nA'] columns. Using\n",
    "# optimal_policy.max() + 1 instead would make the matrix too narrow whenever the\n",
    "# highest-numbered action is never optimal, breaking later steps that expect nA columns.\n",
    "optimal_policy_st = np.zeros((optimal_policy.size, config['nA']))\n",
    "optimal_policy_st[np.arange(optimal_policy.size), optimal_policy] = 1\n",
    "\n",
    "# Softening weight: 0.0 keeps the policy deterministic (the current setting).\n",
    "# Set it > 0 (e.g. config['epsilon']) to give the optimal action 1 - soft_eps and\n",
    "# spread soft_eps equally over the other actions, e.g. for varied mixture policies.\n",
    "soft_eps = 0.0\n",
    "if soft_eps > 0:\n",
    "    optimal_policy_st = np.where(optimal_policy_st == 1, 1 - soft_eps,\n",
    "                                 soft_eps / (config['nA'] - 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the optimal policy matrix. value_function is computed but not saved here;\n",
    "# dump it to 'data/value_function.pkl' if a downstream step needs it.\n",
    "with open('data/optimal_policy_20_st.pkl', 'wb') as out_file:\n",
    "    pickle.dump(optimal_policy_st, out_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Mixed Policy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Modified policy: swap each optimal action with its pair (0<->1, 2<->3, 4<->5, 6<->7),\n",
    "# i.e. flip the lowest bit of the action index (vasopressors in the action decoding above).\n",
    "mapping = {1:0, 0:1, 2:3, 3:2, 4:5, 5:4, 6:7, 7:6}\n",
    "flip_table = np.array([mapping[a] for a in range(config['nA'])])\n",
    "mod_policy = flip_table[optimal_policy]  # vectorized mapping[...] lookup\n",
    "\n",
    "# Soft modified policy with config['nA'] columns (mod_policy.max() + 1 can be too narrow):\n",
    "# 1 - epsilon on the modified action, epsilon split equally among the other actions.\n",
    "mod_policy_st = np.full((mod_policy.size, config['nA']),\n",
    "                        config['epsilon'] / (config['nA'] - 1))\n",
    "mod_policy_st[np.arange(mod_policy.size), mod_policy] = 1 - config['epsilon']\n",
    "\n",
    "# Mix the two policies, with weight mixture_prob on the optimal policy.\n",
    "mixed_policy = (config['mixture_prob'] * optimal_policy_st\n",
    "                + (1 - config['mixture_prob']) * mod_policy_st)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the mixed (behavior) policy matrix.\n",
    "with open('data/mixed_policy.pkl', 'wb') as out_file:\n",
    "    pickle.dump(mixed_policy, out_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### First-time-step policies (`t0_policy`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two first-time-step variants of the optimal policy, each of shape (nS + 1, nA)\n",
    "# (the + 1 row is the extra terminal state). Actions a and a + 4 differ only in the\n",
    "# antibiotic bit, so each variant puts the combined mass of a pair on one member:\n",
    "#   t0_policy[0]: with antibiotics    -> mass of actions 0-3 moved onto 4-7\n",
    "#   t0_policy[1]: without antibiotics -> mass of actions 4-7 moved onto 0-3\n",
    "t0_policy = np.zeros((2, config['nS'] + 1, config['nA']))\n",
    "pair_mass = optimal_policy_st[:, :4] + optimal_policy_st[:, 4:]\n",
    "t0_policy[0, :, 4:] = pair_mass\n",
    "t0_policy[1, :, :4] = pair_mass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist both first-time-step policy variants as one array.\n",
    "with open('data/t0_policy.pkl', 'wb') as out_file:\n",
    "    pickle.dump(t0_policy, out_file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
