{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Learning Sepsis Environment MDP\n",
    "\n",
    "Because the sepsis simulator is governed by fairly complex state transition dynamics, in this notebook we will approximate it by sampling. The approach is similar to one utilized in M. Oberst et al., “[Counterfactual Off-Policy Evaluation with Gumbel-Max Structural Causal Models](https://arxiv.org/abs/1905.05824)” paper.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tqdm\n",
    "import pickle\n",
    "import itertools\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from typing import Tuple"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ced.actors.sepsis import SepsisAction\n",
    "from ced.envs.sepsis import Sepsis, State"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To configure our approximation, we introduce a couple of variables:\n",
    "- `SEED`: random seed used for reproducibility\n",
    "- `NUM_ITERATIONS`: number of samples to draw from the simulator to approximate a single transition\n",
    "- `NUM_ACTIONS`: total number of agent's actions\n",
    "- `NUM_STATES`: total number of environment states\n",
    "- `SAVE_PATH`: path to directory where resulting matrices will be saved"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 5586\n",
    "NUM_SAMPLES = 10000\n",
    "NUM_ACTIONS = SepsisAction.NUM_TOTAL\n",
    "NUM_STATES = State.NUM_TOTAL\n",
    "SAVE_PATH = Path(\"./results/sepsis\")\n",
    "SAVE_PATH_ORIGINAL = SAVE_PATH / \"mdp_original.pkl\"\n",
    "SAVE_PATH_AI = SAVE_PATH / \"mdp_ai.pkl\"\n",
    "TRANSITIONS_ORIGINAL = Path(\"./assets/sepsis/sepsis_transition_probs_original.json\")\n",
    "TRANSITIONS_AI = Path(\"./assets/sepsis/sepsis_transition_probs_ai.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "SAVE_PATH.mkdir(exist_ok=True, parents=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rng = np.random.default_rng(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def learn_mdp(env: Sepsis) -> Tuple[np.ndarray, ...]:\n",
    "    transition_matrix = np.zeros((NUM_ACTIONS, NUM_STATES, NUM_STATES))\n",
    "    reward_matrix = np.zeros((NUM_ACTIONS, NUM_STATES, NUM_STATES))\n",
    "    initial_state_distribution = np.zeros((NUM_STATES, ))\n",
    "\n",
    "    states = range(NUM_STATES)\n",
    "    actions = range(NUM_ACTIONS)\n",
    "    iterations = range(NUM_SAMPLES)\n",
    "\n",
    "    # learn transition matrix\n",
    "    for s_t, a_t, _ in tqdm.tqdm(itertools.product(states, actions, iterations), total=NUM_STATES * NUM_ACTIONS * NUM_SAMPLES, desc=\"Learning MDP transition matrix\"):\n",
    "        s_curr = State.from_index(s_t)\n",
    "        s_next = env.step(state=s_curr, actions=[a_t], rng=rng)\n",
    "        transition_matrix[a_t, s_curr.index, s_next.index] += 1\n",
    "\n",
    "    # normalize transition matrix\n",
    "    transition_matrix /= NUM_SAMPLES\n",
    "    transition_matrix /= transition_matrix.sum(axis=-1, keepdims=True)\n",
    "        \n",
    "    # learn reward matrix\n",
    "    for s_t in tqdm.tqdm(states, desc=\"Learning MDP reward matrix\"):\n",
    "        reward_matrix[:, :, s_t] = State.from_index(s_t).reward\n",
    "\n",
    "    # learn initial state distribution\n",
    "    for _ in tqdm.tqdm(range(NUM_SAMPLES), desc=\"Learning MDP initial state distribution\"):\n",
    "        state = env.reset(rng=rng)\n",
    "        initial_state_distribution[state.index] += 1\n",
    "    \n",
    "    # normalize initial state distribution\n",
    "    initial_state_distribution /= NUM_SAMPLES\n",
    "    initial_state_distribution /= initial_state_distribution.sum(axis=-1, keepdims=True)\n",
    "\n",
    "    return transition_matrix, reward_matrix, initial_state_distribution"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ground-Truth Transition & Reward Matrices\n",
    "\n",
    "We start our approximation with the transition and reward matrices, using the ground-truth underlying transition probabilities. Throughout this project, we will rely on [PyMDP Toolbox package](https://pymdptoolbox.readthedocs.io/en/latest/index.html) for tasks such as efficiently running policy iteration algorithm. The [format of the transition and reward matrices](https://pymdptoolbox.readthedocs.io/en/latest/api/mdp.html#mdptoolbox.mdp.MDP) expected by the PyMDP is `(A, S, S)`, where `A` is the number of actions and `S` is the number of states."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not SAVE_PATH_ORIGINAL.exists():\n",
    "    env = Sepsis(transition_probabilities=TRANSITIONS_ORIGINAL)\n",
    "    transition_matrix, reward_matrix, initial_state_distribution = learn_mdp(env)\n",
    "\n",
    "    with open(SAVE_PATH_ORIGINAL, \"wb\") as f:\n",
    "        pickle.dump({\n",
    "            \"transition_matrix\": transition_matrix,\n",
    "            \"reward_matrix\": reward_matrix,\n",
    "            \"initial_state_distribution\": initial_state_distribution}, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "if SAVE_PATH_ORIGINAL.exists():\n",
    "    with open(SAVE_PATH_ORIGINAL, \"rb\") as f: data = pickle.load(f)\n",
    "    transition_matrix, reward_matrix = data[\"transition_matrix\"], data[\"reward_matrix\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We perform some sanity checks to ensure learned matrices make sense."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "states = [State.from_index(i) for i in range(State.NUM_TOTAL)]\n",
    "s_diabetic = [s for s in states if s.diabetes == 1]\n",
    "s_non_diabetic = [s for s in states if s.diabetes == 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "for s1, s2 in itertools.product(s_diabetic, s_non_diabetic):\n",
    "    # ensures we cannot transition between diabetic and non-diabetic states\n",
    "    assert (transition_matrix[:, s1.index, s2.index] == 0).all()\n",
    "    assert (transition_matrix[:, s2.index, s1.index] == 0).all()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "for s_index in range(State.NUM_TOTAL):\n",
    "    # ensures states are correctly encoded and decoded\n",
    "    assert s_index == State.from_index(s_index).index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ensures we have proper probabilities\n",
    "assert np.allclose(transition_matrix.sum(axis=-1), 1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ensures rewards are in the expected range\n",
    "assert {-1.0, 0.0, 1.0} == set(np.unique(reward_matrix).tolist())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## AI Agent MDP\n",
    "\n",
    "In our experiments, it is crucial that policy of the AI is (at least partially) different than the policy of the human. To achieve this, we will modify the underlying probabilities of the environment to obtain an updated MDP, which will then be used to learn the AI policy.\n",
    "\n",
    "### Modifying Probabilities\n",
    "\n",
    "The goal of this approach is to make AI policy **generally give higher doses of medications** than a clinician policy. To achieve this, we will increase the following probabilities:\n",
    "\n",
    "- Probability of successful medication effect. This implies that, whenever AI policy prescribes a medication, that medication is expected to work with higher probability. \n",
    "- Probability of diverting from normal when removing medication. When taking patient off a medication, we increase the probability that the patient's state diverts from normal (i.e., becomes lower or higher).\n",
    "\n",
    "Jointly, these two points should incentivise AI policy to overall prescribe more medications than the human policy. The probabilities are stored in `assets/sepsis_transition_probs.ai.json` file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not SAVE_PATH_AI.exists():\n",
    "    env = Sepsis(transition_probabilities=TRANSITIONS_AI)\n",
    "    transition_matrix, reward_matrix, initial_state_distribution = learn_mdp(env)\n",
    "\n",
    "    with open(SAVE_PATH_AI, \"wb\") as f:\n",
    "        pickle.dump({\n",
    "            \"transition_matrix\": transition_matrix,\n",
    "            \"reward_matrix\": reward_matrix,\n",
    "            \"initial_state_distribution\": initial_state_distribution}, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mpi-ase",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
