{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "np.seterr(divide = 'ignore') \n",
    "from scipy.linalg import block_diag\n",
    "import warnings\n",
    "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
    "import matplotlib\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "import cf.counterfactual as cf\n",
    "import networkx as nx\n",
    "import copy\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Action(object):\n",
    "    NUM_ACTIONS_TOTAL = 8\n",
    "    ANTIBIOTIC_STRING = \"antibiotic\"\n",
    "    VENT_STRING = \"ventilation\"\n",
    "    VASO_STRING = \"vasopressors\"\n",
    "    ACTION_VEC_SIZE = 3\n",
    "\n",
    "    def __init__(self, selected_actions = None, action_idx = None):\n",
    "        # This method sets up the action object. \n",
    "        # Actions can be specified in two ways: by providing a list of selected actions (as strings) or by an action index.\n",
    "        \n",
    "        assert (selected_actions is not None and action_idx is None) \\\n",
    "            or (selected_actions is None and action_idx is not None), \\\n",
    "            \"must specify either set of action strings or action index\"\n",
    "            \n",
    "        if selected_actions is not None:\n",
    "            # For each of the three treatments (ANTIBIOTIC_STRING, VENT_STRING, VASO_STRING), the code checks if its corresponding string is present in the selected_actions. \n",
    "            # If it is, the relevant attribute (e.g., self.antibiotic) is set to 1, indicating that treatment is selected. Otherwise, it's set to 0, indicating the treatment is not selected.\n",
    "            if Action.ANTIBIOTIC_STRING in selected_actions:\n",
    "                self.antibiotic = 1\n",
    "            else:\n",
    "                self.antibiotic = 0\n",
    "            if Action.VENT_STRING in selected_actions:\n",
    "                self.ventilation = 1\n",
    "            else:\n",
    "                self.ventilation = 0\n",
    "            if Action.VASO_STRING in selected_actions:\n",
    "                self.vasopressors = 1\n",
    "            else:\n",
    "                self.vasopressors = 0\n",
    "                \n",
    "        else:\n",
    "            # This block decomposes the action_idx (from 0 to 7) into the three binary treatment values (0 or 1). \n",
    "            # This process assumes a specific order and numbering scheme for the action index and treatments.\n",
    "            mod_idx = action_idx\n",
    "            term_base = Action.NUM_ACTIONS_TOTAL/2\n",
    "            self.antibiotic = np.floor(mod_idx/term_base).astype(int)\n",
    "            mod_idx %= term_base\n",
    "            term_base /= 2\n",
    "            self.ventilation = np.floor(mod_idx/term_base).astype(int)\n",
    "            mod_idx %= term_base\n",
    "            term_base /= 2\n",
    "            self.vasopressors = np.floor(mod_idx/term_base).astype(int)\n",
    "            \n",
    "            '''\n",
    "            There are three treatments (A, E, V) and thus 2^3 = 8 possible action combinations. \n",
    "            The binary representation of action_idx from 0 to 7 can be thought of as the action combinations:\n",
    "\n",
    "                000 -> No treatments\n",
    "                001 -> V\n",
    "                010 -> E\n",
    "                011 -> E, V\n",
    "                100 -> A\n",
    "                101 -> A, V\n",
    "                110 -> A, E\n",
    "                111 -> A, E, V\n",
    "                \n",
    "            The code block breaks down action_idx to understand which treatments are being used and initializes the three attributes (self.antibiotic, self.ventilation, self.vasopressors) accordingly.\n",
    "            '''\n",
    "            \n",
    "    # Equality and Inequality (__eq__ and __ne__ methods): These are to check the equality or inequality of two Action objects.\n",
    "\n",
    "    def __eq__(self, other):\n",
    "        return isinstance(other, self.__class__) and \\\n",
    "            self.antibiotic == other.antibiotic and \\\n",
    "            self.ventilation == other.ventilation and \\\n",
    "            self.vasopressors == other.vasopressors\n",
    "\n",
    "    def __ne__(self, other):\n",
    "        return not self.__eq__(other)\n",
    "\n",
    "    # Get Action Index (get_action_idx method): This method converts the selected actions into an integer index.\n",
    "    \n",
    "    def get_action_idx(self):\n",
    "        assert self.antibiotic in (0, 1)\n",
    "        assert self.ventilation in (0, 1)\n",
    "        assert self.vasopressors in (0, 1)\n",
    "        return 4*self.antibiotic + 2*self.ventilation + self.vasopressors\n",
    "    '''\n",
    "    The weighted sum effectively encodes the three binary values into a single integer (form 0 to 7; NUM_ACTIONS_TOTAL = 8 in total). \n",
    "    The weights (4, 2, and 1) were chosen to uniquely identify each combination of the three treatments.\n",
    "    \n",
    "    For example:\n",
    "\n",
    "        If only antibiotic is used: action_idx = 4*1 + 2*0 + 0*1 = 4.\n",
    "        If only ventilation is used: action_idx = 4*0 + 2*1 + 0*1 = 2.\n",
    "        If antibiotic and ventilation are used: action_idx = 4*1 + 2*1 + 0*1 = 6.\n",
    "        If all three are used: action_idx = 4*1 + 2*1 + 1*1 = 7.\n",
    "    '''\n",
    "\n",
    "    # Hash (__hash__ method): Provides a unique hash for the action object. This is important if you want to use Action objects as keys in a dictionary.\n",
    "\n",
    "    def __hash__(self):\n",
    "        return self.get_action_idx()\n",
    "    \n",
    "    # Get Selected Actions (get_selected_actions method): Returns a set of selected actions for the object.\n",
    "\n",
    "    def get_selected_actions(self):\n",
    "        selected_actions = set()\n",
    "        if self.antibiotic == 1:\n",
    "            selected_actions.add(Action.ANTIBIOTIC_STRING)\n",
    "        if self.ventilation == 1:\n",
    "            selected_actions.add(Action.VENT_STRING)\n",
    "        if self.vasopressors == 1:\n",
    "            selected_actions.add(Action.VASO_STRING)\n",
    "        return selected_actions\n",
    "    \n",
    "    # Abbreviated String (get_abbrev_string method): Returns a short string representation of the actions (A for antibiotic, E for ventilation, V for vasopressors).\n",
    "\n",
    "    def get_abbrev_string(self):\n",
    "        '''\n",
    "        AEV: antibiotics, ventilation, vasopressors\n",
    "        '''\n",
    "        output_str = ''\n",
    "        if self.antibiotic == 1:\n",
    "            output_str += 'A'\n",
    "        if self.ventilation == 1:\n",
    "            output_str += 'E'\n",
    "        if self.vasopressors == 1:\n",
    "            output_str += 'V'\n",
    "        return output_str\n",
    "\n",
    "    # Action Vector (get_action_vec method): Returns a numpy array representation of the action, with a shape of (3,1).\n",
    "    \n",
    "    def get_action_vec(self):\n",
    "        return np.array([[self.antibiotic], [self.ventilation], [self.vasopressors]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class State(object):\n",
    "\n",
    "    NUM_OBS_STATES = 720\n",
    "    NUM_HID_STATES = 2  # Binary value of diabetes\n",
    "    NUM_PROJ_OBS_STATES = int(720 / 5)  # Marginalizing over glucose\n",
    "    NUM_FULL_STATES = int(NUM_OBS_STATES * NUM_HID_STATES)\n",
    "\n",
    "    def __init__(self, state_idx = None, idx_type = 'obs', diabetic_idx = None, state_categs = None):\n",
    "    # __init__: Constructor method to initialize the state either by its index or by passing specific categories for each state variable.\n",
    "\n",
    "        assert state_idx is not None or state_categs is not None\n",
    "        assert ((diabetic_idx is not None and diabetic_idx in [0, 1]) or\n",
    "                (state_idx is not None and idx_type == 'full'))\n",
    "\n",
    "        assert idx_type in ['obs', 'full', 'proj_obs']\n",
    "\n",
    "        if state_idx is not None:\n",
    "            self.set_state_by_idx(\n",
    "                    state_idx, idx_type=idx_type, diabetic_idx=diabetic_idx)\n",
    "        elif state_categs is not None:\n",
    "            assert len(state_categs) == 7, \"must specify 7 state variables\"\n",
    "            self.hr_state = state_categs[0]\n",
    "            self.sysbp_state = state_categs[1]\n",
    "            self.percoxyg_state = state_categs[2]\n",
    "            self.glucose_state = state_categs[3]\n",
    "            self.antibiotic_state = state_categs[4]\n",
    "            self.vaso_state = state_categs[5]\n",
    "            self.vent_state = state_categs[6]\n",
    "            self.diabetic_idx = diabetic_idx\n",
    "\n",
    "    def check_absorbing_state(self):\n",
    "        # check_absorbing_state: Checks if the state is \"absorbing\" which means it has a certain \n",
    "        # number of abnormal conditions or it is a normal state with no ongoing treatment.\n",
    "        num_abnormal = self.get_num_abnormal()\n",
    "        if num_abnormal >= 3:\n",
    "            return True\n",
    "        elif num_abnormal == 0 and not self.on_treatment():\n",
    "            return True\n",
    "        return False\n",
    "    \n",
    "    def state_rewards(self):\n",
    "        # check_absorbing_state: Checks if the state is \"absorbing\" which means it has a certain \n",
    "        # number of abnormal conditions or it is a normal state with no ongoing treatment.\n",
    "        num_abnormal = self.get_num_abnormal()\n",
    "        if num_abnormal >= 3:\n",
    "            return (-1000)\n",
    "        elif num_abnormal == 2:\n",
    "            return (-50)\n",
    "        elif num_abnormal == 1:\n",
    "            return (+50)\n",
    "        elif num_abnormal == 0 and self.on_treatment():\n",
    "            return (+70)\n",
    "        elif num_abnormal == 0 and not self.on_treatment():\n",
    "            return (+1000)\n",
    "\n",
    "    def set_state_by_idx(self, state_idx, idx_type, diabetic_idx=None):\n",
    "        \n",
    "        # set_state_by_idx: interprets the state index into its respective categorical variables. \n",
    "        # Depending on the index type (observable, full, or projected observable), the function decodes the index and sets the member variables. \n",
    "        # This method employs a form of \"bit\" arithmetic, even though not all states are binary.\n",
    "        \"\"\"set_state_by_idx\n",
    "\n",
    "        The state index is determined by using \"bit\" arithmetic, with the\n",
    "        complication that not every state is binary\n",
    "\n",
    "        :param state_idx: Given index\n",
    "        :param idx_type: Index type, either observed (720), projected (144) or\n",
    "        full (1440)\n",
    "        :param diabetic_idx: If full state index not given, this is required\n",
    "        \"\"\"\n",
    "        \n",
    "        # Determine Base for Arithmetic: Depending on the idx_type, the method calculates the term_base. \n",
    "        # This base will be used for extracting individual state information from the given index. \n",
    "        # The choice of this base reflects the number of categories available for the primary state variables.\n",
    "        \n",
    "        if idx_type == 'obs':\n",
    "            term_base = State.NUM_OBS_STATES/3 # Starts with heart rate\n",
    "        elif idx_type == 'proj_obs':\n",
    "            term_base = State.NUM_PROJ_OBS_STATES/3\n",
    "        elif idx_type == 'full':\n",
    "            term_base = State.NUM_FULL_STATES/2 # Starts with diab\n",
    "        \n",
    "\n",
    "        # Start with the given state index\n",
    "        mod_idx = state_idx\n",
    "\n",
    "        if idx_type == 'full':\n",
    "            # If the idx_type is 'full', the function first extracts the diabetes status (diabetic_idx) \n",
    "            # and then adjusts the base for the next state variable (heart rate).\n",
    "            \n",
    "            self.diabetic_idx = np.floor(mod_idx/term_base).astype(int)\n",
    "            mod_idx %= term_base\n",
    "            term_base /= 3 # This is for heart rate, the next item\n",
    "        else:\n",
    "            assert diabetic_idx is not None\n",
    "            self.diabetic_idx = diabetic_idx\n",
    "\n",
    "        self.hr_state = np.floor(mod_idx/term_base).astype(int)\n",
    "\n",
    "        mod_idx %= term_base\n",
    "        term_base /= 3\n",
    "        self.sysbp_state = np.floor(mod_idx/term_base).astype(int)\n",
    "\n",
    "        mod_idx %= term_base\n",
    "        term_base /= 2\n",
    "        self.percoxyg_state = np.floor(mod_idx/term_base).astype(int)\n",
    "\n",
    "        if idx_type == 'proj_obs':\n",
    "            self.glucose_state = 2\n",
    "        else:\n",
    "            mod_idx %= term_base\n",
    "            term_base /= 5\n",
    "            self.glucose_state = np.floor(mod_idx/term_base).astype(int)\n",
    "\n",
    "        mod_idx %= term_base\n",
    "        term_base /= 2\n",
    "        self.antibiotic_state = np.floor(mod_idx/term_base).astype(int)\n",
    "\n",
    "        mod_idx %= term_base\n",
    "        term_base /= 2\n",
    "        self.vaso_state = np.floor(mod_idx/term_base).astype(int)\n",
    "\n",
    "        mod_idx %= term_base\n",
    "        term_base /= 2\n",
    "        self.vent_state = np.floor(mod_idx/term_base).astype(int)\n",
    "\n",
    "\n",
    "    def get_state_idx(self, idx_type='obs'):\n",
    "        # Opposite of set_state_by_idx. It takes the categorical variables of the state and returns its index. \n",
    "        # It constructs the index using the \"bit\" arithmetic approach.\n",
    "        '''\n",
    "        returns integer index of state: significance order as in categorical array\n",
    "        '''\n",
    "        \n",
    "        if idx_type == 'obs':\n",
    "            categ_num = np.array([3,3,2,5,2,2,2])\n",
    "            state_categs = [\n",
    "                    self.hr_state,\n",
    "                    self.sysbp_state,\n",
    "                    self.percoxyg_state,\n",
    "                    self.glucose_state,\n",
    "                    self.antibiotic_state,\n",
    "                    self.vaso_state,\n",
    "                    self.vent_state]\n",
    "        elif idx_type == 'proj_obs':\n",
    "            categ_num = np.array([3,3,2,2,2,2])\n",
    "            state_categs = [\n",
    "                    self.hr_state,\n",
    "                    self.sysbp_state,\n",
    "                    self.percoxyg_state,\n",
    "                    self.antibiotic_state,\n",
    "                    self.vaso_state,\n",
    "                    self.vent_state]\n",
    "        elif idx_type == 'full':\n",
    "            categ_num = np.array([2,3,3,2,5,2,2,2])\n",
    "            state_categs = [\n",
    "                    self.diabetic_idx,\n",
    "                    self.hr_state,\n",
    "                    self.sysbp_state,\n",
    "                    self.percoxyg_state,\n",
    "                    self.glucose_state,\n",
    "                    self.antibiotic_state,\n",
    "                    self.vaso_state,\n",
    "                    self.vent_state]\n",
    "\n",
    "        sum_idx = 0\n",
    "        prev_base = 1\n",
    "        for i in range(len(state_categs)):\n",
    "            idx = len(state_categs) - 1 - i\n",
    "            sum_idx += prev_base*state_categs[idx]\n",
    "            prev_base *= categ_num[idx]\n",
    "        return sum_idx\n",
    "    \n",
    "    # __eq__, __ne__, __hash__: Overridden methods to check for equality, inequality, and to generate a hash value respectively.\n",
    "\n",
    "    def __eq__(self, other):\n",
    "        '''\n",
    "        override equals: two states equal if all internal states same\n",
    "        '''\n",
    "        return isinstance(other, self.__class__) and \\\n",
    "            self.hr_state == other.hr_state and \\\n",
    "            self.sysbp_state == other.sysbp_state and \\\n",
    "            self.percoxyg_state == other.percoxyg_state and \\\n",
    "            self.glucose_state == other.glucose_state and \\\n",
    "            self.antibiotic_state == other.antibiotic_state and \\\n",
    "            self.vaso_state == other.vaso_state and \\\n",
    "            self.vent_state == other.vent_state\n",
    "\n",
    "    def __ne__(self, other):\n",
    "        return not self.__eq__(other)\n",
    "\n",
    "    def __hash__(self):\n",
    "        return self.get_state_idx()\n",
    "\n",
    "    def get_num_abnormal(self):\n",
    "        # get_num_abnormal: Counts and returns the number of abnormal conditions present in the current state.\n",
    "        '''\n",
    "        returns number of abnormal conditions\n",
    "        '''\n",
    "        num_abnormal = 0\n",
    "        if self.hr_state != 1:\n",
    "            num_abnormal += 1\n",
    "        if self.sysbp_state != 1:\n",
    "            num_abnormal += 1\n",
    "        if self.percoxyg_state != 1:\n",
    "            num_abnormal += 1\n",
    "        if self.glucose_state != 2:\n",
    "            num_abnormal += 1\n",
    "        return num_abnormal\n",
    "\n",
    "    # on_treatment, on_antibiotics, on_vasopressors, on_ventilation: These methods check if certain treatments are active.\n",
    "    \n",
    "    def on_treatment(self):\n",
    "        '''\n",
    "        returns True iff any of 3 treatments active\n",
    "        '''\n",
    "        if self.antibiotic_state == 0 and \\\n",
    "            self.vaso_state == 0 and self.vent_state == 0:\n",
    "            return False\n",
    "        return True\n",
    "\n",
    "    def on_antibiotics(self):\n",
    "        '''\n",
    "        returns True iff antibiotics active\n",
    "        '''\n",
    "        return self.antibiotic_state == 1\n",
    "\n",
    "    def on_vasopressors(self):\n",
    "        '''\n",
    "        returns True iff vasopressors active\n",
    "        '''\n",
    "        return self.vaso_state == 1\n",
    "\n",
    "    def on_ventilation(self):\n",
    "        '''\n",
    "        returns True iff ventilation active\n",
    "        '''\n",
    "        return self.vent_state == 1\n",
    "\n",
    "    def copy_state(self):\n",
    "        return State(state_categs = [\n",
    "            self.hr_state,\n",
    "            self.sysbp_state,\n",
    "            self.percoxyg_state,\n",
    "            self.glucose_state,\n",
    "            self.antibiotic_state,\n",
    "            self.vaso_state,\n",
    "            self.vent_state],\n",
    "            diabetic_idx=self.diabetic_idx)\n",
    "\n",
    "    def get_state_vector(self):\n",
    "        # get_state_vector: Returns the state as a vector (numpy array).\n",
    "        return np.array([self.hr_state,\n",
    "            self.sysbp_state,\n",
    "            self.percoxyg_state,\n",
    "            self.glucose_state,\n",
    "            self.antibiotic_state,\n",
    "            self.vaso_state,\n",
    "            self.vent_state]).astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MDP(object):\n",
    "\n",
    "    def __init__(self, init_state_idx=None, init_state_idx_type='obs', policy_array=None, policy_idx_type='obs', p_diabetes=0.2):\n",
    "        '''Create a simulator with an initial patient state and optional policy.\n",
    "\n",
    "        :param init_state_idx: index of the initial state; None samples a\n",
    "            random non-absorbing state (see get_new_state).\n",
    "        :param init_state_idx_type: index type of init_state_idx, one of\n",
    "            obs, full or proj_obs; only used when init_state_idx is given.\n",
    "        :param policy_array: optional (states x NUM_ACTIONS_TOTAL) array of\n",
    "            action probabilities whose rows are indexed by policy_idx_type.\n",
    "        :param policy_idx_type: index type used to look up policy rows.\n",
    "        :param p_diabetes: probability in [0, 1] of a diabetic patient\n",
    "            whenever diabetes status must be sampled.\n",
    "        '''\n",
    "        assert 0 <= p_diabetes <= 1, \\\n",
    "                \"Invalid p_diabetes: {}\".format(p_diabetes)\n",
    "        assert policy_idx_type in ['obs', 'full', 'proj_obs']\n",
    "\n",
    "        if policy_array is not None:\n",
    "            # Required number of policy rows for each state index type.\n",
    "            expected_rows = {\n",
    "                'obs': State.NUM_OBS_STATES,\n",
    "                'full': State.NUM_HID_STATES * State.NUM_OBS_STATES,\n",
    "                'proj_obs': State.NUM_PROJ_OBS_STATES,\n",
    "            }\n",
    "            assert policy_array.shape[1] == Action.NUM_ACTIONS_TOTAL\n",
    "            assert policy_array.shape[0] == expected_rows[policy_idx_type]\n",
    "\n",
    "        self.p_diabetes = p_diabetes\n",
    "        # Placeholder so the attribute exists while the first state is built.\n",
    "        self.state = None\n",
    "        self.state = self.get_new_state(init_state_idx, init_state_idx_type)\n",
    "\n",
    "        self.policy_array = policy_array\n",
    "        self.policy_idx_type = policy_idx_type  # Used for mapping the policy to actions\n",
    "\n",
    "    def get_new_state(self, state_idx = None, idx_type = 'obs', diabetic_idx = None):\n",
    "        '''Create a new State from an index or by random sampling.\n",
    "\n",
    "        :param state_idx: index of the state to build; None samples a random\n",
    "            non-absorbing state, conditioned on diabetic_idx if given.\n",
    "        :param idx_type: obs, full or proj_obs; only used with state_idx.\n",
    "        :param diabetic_idx: diabetes status (0/1). Ignored for full indices,\n",
    "            which encode it; for obs and proj_obs indices it is drawn from\n",
    "            Bernoulli(p_diabetes) when omitted.\n",
    "        :returns: the new State; self.state is not modified.\n",
    "        '''\n",
    "        assert idx_type in ['obs', 'full', 'proj_obs']\n",
    "\n",
    "        if state_idx is None:\n",
    "            init_state = self.generate_random_state(diabetic_idx)\n",
    "            # Do not start in death or discharge state\n",
    "            while init_state.check_absorbing_state():\n",
    "                init_state = self.generate_random_state(diabetic_idx)\n",
    "            return init_state\n",
    "\n",
    "        if idx_type != 'full' and diabetic_idx is None:\n",
    "            # Neither the obs nor the proj_obs index encodes diabetes, so draw\n",
    "            # it from the prior (proj_obs is treated exactly like obs here).\n",
    "            diabetic_idx = np.random.binomial(1, self.p_diabetes)\n",
    "\n",
    "        # diabetic_idx is ignored by State when idx_type == 'full'\n",
    "        return State(\n",
    "                state_idx=state_idx, idx_type=idx_type,\n",
    "                diabetic_idx=diabetic_idx)\n",
    "\n",
    "    def generate_random_state(self, diabetic_idx=None):\n",
    "        '''Sample a random patient state with all treatments off.\n",
    "\n",
    "        :param diabetic_idx: condition on this diabetes status (0/1); drawn\n",
    "            from Bernoulli(p_diabetes) when None.\n",
    "        :returns: a State; it may be absorbing (no filtering is done here).\n",
    "        '''\n",
    "        if diabetic_idx is None:\n",
    "            diabetic_idx = np.random.binomial(1, self.p_diabetes)\n",
    "\n",
    "        # hr and sys_bp w.p. [.25, .5, .25]; percoxyg w.p. [.2, .8]\n",
    "        low_norm_high = np.array([.25, .5, .25])\n",
    "        hr_state = np.random.choice(np.arange(3), p=low_norm_high)\n",
    "        sysbp_state = np.random.choice(np.arange(3), p=low_norm_high)\n",
    "        percoxyg_state = np.random.choice(np.arange(2), p=np.array([.2, .8]))\n",
    "\n",
    "        # Diabetics are skewed towards high glucose.\n",
    "        if diabetic_idx == 0:\n",
    "            glucose_probs = np.array([.05, .15, .6, .15, .05])\n",
    "        else:\n",
    "            glucose_probs = np.array([.01, .05, .15, .6, .19])\n",
    "        glucose_state = np.random.choice(np.arange(5), p=glucose_probs)\n",
    "\n",
    "        # antibiotic, vaso and vent states all start at 0 (off).\n",
    "        return State(state_categs=[hr_state, sysbp_state, percoxyg_state,\n",
    "                                   glucose_state, 0, 0, 0],\n",
    "                     diabetic_idx=diabetic_idx)\n",
    "\n",
    "    # transition_antibiotics_on/off: Models the effect of turning antibiotics on/off.\n",
    "    \n",
    "    def transition_antibiotics_on(self):\n",
    "        \n",
    "        '''\n",
    "        antibiotics state on\n",
    "        heart rate, sys bp: hi -> normal w.p. .5\n",
    "        '''\n",
    "        self.state.antibiotic_state = 1\n",
    "        if self.state.hr_state == 2 and np.random.uniform(0,1) < 0.5:\n",
    "            self.state.hr_state = 1\n",
    "        if self.state.sysbp_state == 2 and np.random.uniform(0,1) < 0.5:\n",
    "            self.state.sysbp_state = 1\n",
    "\n",
    "    def transition_antibiotics_off(self):\n",
    "        '''\n",
    "        antibiotics state off\n",
    "        if antibiotics was on: heart rate, sys bp: normal -> hi w.p. .1\n",
    "        '''\n",
    "        if self.state.antibiotic_state == 1:\n",
    "            if self.state.hr_state == 1 and np.random.uniform(0,1) < 0.1:\n",
    "                self.state.hr_state = 2\n",
    "            if self.state.sysbp_state == 1 and np.random.uniform(0,1) < 0.1:\n",
    "                self.state.sysbp_state = 2\n",
    "            self.state.antibiotic_state = 0\n",
    "\n",
    "    # transition_vent_on/off: Models the effect of turning ventilation on/off.\n",
    "\n",
    "    def transition_vent_on(self):\n",
    "        '''\n",
    "        ventilation state on\n",
    "        percent oxygen: low -> normal w.p. .7\n",
    "        '''\n",
    "        self.state.vent_state = 1\n",
    "        if self.state.percoxyg_state == 0 and np.random.uniform(0,1) < 0.7:\n",
    "            self.state.percoxyg_state = 1\n",
    "\n",
    "    def transition_vent_off(self):\n",
    "        '''\n",
    "        ventilation state off\n",
    "        if ventilation was on: percent oxygen: normal -> lo w.p. .1\n",
    "        '''\n",
    "        if self.state.vent_state == 1:\n",
    "            if self.state.percoxyg_state == 1 and np.random.uniform(0,1) < 0.1:\n",
    "                self.state.percoxyg_state = 0\n",
    "            self.state.vent_state = 0\n",
    "    \n",
    "    # transition_vaso_on/off: Models the effect of turning vasopressors on/off, considering if the patient is diabetic.\n",
    "\n",
    "    def transition_vaso_on(self):\n",
    "        '''\n",
    "        vasopressor state on\n",
    "        for non-diabetic:\n",
    "            sys bp: low -> normal, normal -> hi w.p. .7\n",
    "        for diabetic:\n",
    "            raise blood pressure: normal -> hi w.p. .9,\n",
    "                lo -> normal w.p. .5, lo -> hi w.p. .4\n",
    "            raise blood glucose by 1 w.p. .5\n",
    "        '''\n",
    "        self.state.vaso_state = 1\n",
    "        if self.state.diabetic_idx == 0:\n",
    "            if np.random.uniform(0,1) < 0.7:\n",
    "                if self.state.sysbp_state == 0:\n",
    "                    self.state.sysbp_state = 1\n",
    "                elif self.state.sysbp_state == 1:\n",
    "                    self.state.sysbp_state = 2\n",
    "        else:\n",
    "            if self.state.sysbp_state == 1:\n",
    "                if np.random.uniform(0,1) < 0.9:\n",
    "                    self.state.sysbp_state = 2\n",
    "            elif self.state.sysbp_state == 0:\n",
    "                up_prob = np.random.uniform(0,1)\n",
    "                if up_prob < 0.5:\n",
    "                    self.state.sysbp_state = 1\n",
    "                elif up_prob < 0.9:\n",
    "                    self.state.sysbp_state = 2\n",
    "            if np.random.uniform(0,1) < 0.5:\n",
    "                self.state.glucose_state = min(4, self.state.glucose_state + 1)\n",
    "\n",
    "    def transition_vaso_off(self):\n",
    "        '''\n",
    "        vasopressor state off\n",
    "        if vasopressor was on:\n",
    "            for non-diabetics, sys bp: normal -> low, hi -> normal w.p. .1\n",
    "            for diabetics, blood pressure falls by 1 w.p. .05 instead of .1\n",
    "        '''\n",
    "        if self.state.vaso_state == 1:\n",
    "            if self.state.diabetic_idx == 0:\n",
    "                if np.random.uniform(0,1) < 0.1:\n",
    "                    self.state.sysbp_state = max(0, self.state.sysbp_state - 1)\n",
    "            else:\n",
    "                if np.random.uniform(0,1) < 0.05:\n",
    "                    self.state.sysbp_state = max(0, self.state.sysbp_state - 1)\n",
    "            self.state.vaso_state = 0\n",
    "\n",
    "    def transition_fluctuate(self, hr_fluctuate, sysbp_fluctuate, percoxyg_fluctuate, glucose_fluctuate):\n",
    "        \n",
    "        # transition_fluctuate: Captures the random fluctuations in the patient's state variables.\n",
    "        \n",
    "        '''\n",
    "        all (non-treatment) states fluctuate +/- 1 w.p. .1\n",
    "        exception: glucose flucuates +/- 1 w.p. .3 if diabetic\n",
    "        '''\n",
    "        if hr_fluctuate:\n",
    "            hr_prob = np.random.uniform(0,1)\n",
    "            if hr_prob < 0.1:\n",
    "                self.state.hr_state = max(0, self.state.hr_state - 1)\n",
    "            elif hr_prob < 0.2:\n",
    "                self.state.hr_state = min(2, self.state.hr_state + 1)\n",
    "        if sysbp_fluctuate:\n",
    "            sysbp_prob = np.random.uniform(0,1)\n",
    "            if sysbp_prob < 0.1:\n",
    "                self.state.sysbp_state = max(0, self.state.sysbp_state - 1)\n",
    "            elif sysbp_prob < 0.2:\n",
    "                self.state.sysbp_state = min(2, self.state.sysbp_state + 1)\n",
    "        if percoxyg_fluctuate:\n",
    "            percoxyg_prob = np.random.uniform(0,1)\n",
    "            if percoxyg_prob < 0.1:\n",
    "                self.state.percoxyg_state = max(0, self.state.percoxyg_state - 1)\n",
    "            elif percoxyg_prob < 0.2:\n",
    "                self.state.percoxyg_state = min(1, self.state.percoxyg_state + 1)\n",
    "        if glucose_fluctuate:\n",
    "            glucose_prob = np.random.uniform(0,1)\n",
    "            if self.state.diabetic_idx == 0:\n",
    "                if glucose_prob < 0.1:\n",
    "                    self.state.glucose_state = max(0, self.state.glucose_state - 1)\n",
    "                elif glucose_prob < 0.2:\n",
    "                    self.state.glucose_state = min(4, self.state.glucose_state + 1)\n",
    "            else:\n",
    "                if glucose_prob < 0.3:\n",
    "                    self.state.glucose_state = max(0, self.state.glucose_state - 1)\n",
    "                elif glucose_prob < 0.6:\n",
    "                    self.state.glucose_state = min(4, self.state.glucose_state + 1)\n",
    "\n",
    "    def calculateReward(self):\n",
    "        \"\"\"Return the scalar reward for the current patient state.\n",
    "\n",
    "        -1 when state.get_num_abnormal() reports 3 or more abnormal components,\n",
    "        +1 when none are abnormal and state.on_treatment() is False, 0 otherwise.\n",
    "        \"\"\"\n",
    "        abnormal_count = self.state.get_num_abnormal()\n",
    "        if abnormal_count >= 3:\n",
    "            return -1\n",
    "        # on_treatment() is only consulted when nothing is abnormal.\n",
    "        healthy_and_untreated = abnormal_count == 0 and not self.state.on_treatment()\n",
    "        return 1 if healthy_and_untreated else 0\n",
    "\n",
    "    def transition(self, action):\n",
    "        \"\"\"Apply `action` to a copy of the current state and return the resulting reward.\n",
    "\n",
    "        Each treatment is switched on when selected, or switched off when it is currently\n",
    "        active but not selected (antibiotics, then ventilation, then vasopressors). The\n",
    "        vitals governed by a treatment being switched are excluded from the random\n",
    "        fluctuation step (transition_fluctuate) for this transition.\n",
    "        \"\"\"\n",
    "        self.state = self.state.copy_state()\n",
    "\n",
    "        # Antibiotics govern heart rate and systolic BP.\n",
    "        antibiotic_acting = True\n",
    "        if action.antibiotic == 1:\n",
    "            self.transition_antibiotics_on()\n",
    "        elif self.state.antibiotic_state == 1:\n",
    "            self.transition_antibiotics_off()\n",
    "        else:\n",
    "            antibiotic_acting = False\n",
    "        hr_fluctuate = not antibiotic_acting\n",
    "        sysbp_fluctuate = not antibiotic_acting\n",
    "\n",
    "        # Ventilation governs percoxyg.\n",
    "        vent_acting = True\n",
    "        if action.ventilation == 1:\n",
    "            self.transition_vent_on()\n",
    "        elif self.state.vent_state == 1:\n",
    "            self.transition_vent_off()\n",
    "        else:\n",
    "            vent_acting = False\n",
    "        percoxyg_fluctuate = not vent_acting\n",
    "\n",
    "        # Vasopressors: switching on fixes systolic BP and glucose, switching off fixes systolic BP only.\n",
    "        glucose_fluctuate = True\n",
    "        if action.vasopressors == 1:\n",
    "            self.transition_vaso_on()\n",
    "            sysbp_fluctuate = False\n",
    "            glucose_fluctuate = False\n",
    "        elif self.state.vaso_state == 1:\n",
    "            self.transition_vaso_off()\n",
    "            sysbp_fluctuate = False\n",
    "\n",
    "        self.transition_fluctuate(hr_fluctuate, sysbp_fluctuate, percoxyg_fluctuate, glucose_fluctuate)\n",
    "        return self.calculateReward()\n",
    "\n",
    "    def select_actions(self):\n",
    "        \"\"\"Sample an Action from the policy row of the current state.\n",
    "\n",
    "        The state is indexed with self.policy_idx_type; the sampling itself is done by\n",
    "        action_idx.\n",
    "        \"\"\"\n",
    "        assert self.policy_array is not None\n",
    "        current_idx = self.state.get_state_idx(self.policy_idx_type)\n",
    "        return Action(action_idx = self.action_idx(current_idx))\n",
    "\n",
    "    def action_idx(self, state_idx):\n",
    "        \"\"\"Sample and return an action index for the given state index.\n",
    "\n",
    "        policy_array[state_idx] is used as the probability vector over the\n",
    "        Action.NUM_ACTIONS_TOTAL actions (global NumPy RNG).\n",
    "        \"\"\"\n",
    "        assert self.policy_array is not None\n",
    "        action_probs = self.policy_array[state_idx]\n",
    "        candidates = np.arange(Action.NUM_ACTIONS_TOTAL)\n",
    "        return np.random.choice(candidates, p=action_probs)\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NSIMSAMPS = 1  # Samples to draw from the simulator (1000 in the original setup)\n",
    "NSTEPS = 10  # Max length of each trajectory\n",
    "NCFSAMPS = 5  # Counterfactual samples per observed sample (TODO: confirm this is still needed)\n",
    "DISCOUNT_Pol = 0.99 # Used for computing optimal policies\n",
    "DISCOUNT = 1 # Used for computing actual reward\n",
    "PHYS_EPSILON = 0.05 # Used for sampling using physician pol as eps greedy\n",
    "PROB_DIAB = 0.2\n",
    "n_actions = Action.NUM_ACTIONS_TOTAL\n",
    "\n",
    "# NOTE: pickle.load can execute arbitrary code; only load this file from a trusted source.\n",
    "with open(\"../data/diab_txr_mats-replication.pkl\", \"rb\") as f:\n",
    "    mdict = pickle.load(f)\n",
    "\n",
    "tx_mat = mdict[\"tx_mat\"]\n",
    "r_mat = mdict[\"r_mat\"]\n",
    "p_mixture = np.array([1 - PROB_DIAB, PROB_DIAB])\n",
    "\n",
    "# Stack the two sub-MDPs (first axis of tx_mat / r_mat; presumably 0 = non-diabetic,\n",
    "# 1 = diabetic, matching p_mixture) into one block-diagonal MDP over the full state space.\n",
    "# tx_mat_full / r_mat_full are indexed as (action, state, next_state).\n",
    "tx_mat_full = np.zeros((n_actions, State.NUM_FULL_STATES, State.NUM_FULL_STATES))\n",
    "r_mat_full = np.zeros((n_actions, State.NUM_FULL_STATES, State.NUM_FULL_STATES))\n",
    "for a in range(n_actions):\n",
    "    tx_mat_full[a, ...] = block_diag(tx_mat[0, a, ...], tx_mat[1, a, ...])\n",
    "    r_mat_full[a, ...] = block_diag(r_mat[0, a, ...], r_mat[1, a, ...])\n",
    "\n",
    "# Summarise rather than dumping the full (actions x states x states) arrays.\n",
    "print(f\"tx_mat_full: shape {tx_mat_full.shape}, nonzero entries {np.count_nonzero(tx_mat_full)}\")\n",
    "print(f\"r_mat_full: shape {r_mat_full.shape}, nonzero entries {np.count_nonzero(r_mat_full)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_absorbing_states = []\n",
    "all_absorbing_rewards = []\n",
    "non_absorbing_states = []\n",
    "all_rewards = []\n",
    "\n",
    "# Classify every full state and record its reward. The flag is deliberately not named\n",
    "# `abs`: that would shadow the builtin abs() for every later cell of the notebook.\n",
    "for s in range(State.NUM_FULL_STATES):\n",
    "    get_states = State(state_idx=s, idx_type = 'full')\n",
    "    is_absorbing = get_states.check_absorbing_state()\n",
    "    rew = get_states.state_rewards()\n",
    "    if is_absorbing:\n",
    "        all_absorbing_states.append(s)\n",
    "        all_absorbing_rewards.append(rew)\n",
    "    else:\n",
    "        non_absorbing_states.append(s)\n",
    "    all_rewards.append(rew)\n",
    "\n",
    "print(f'winning states are {all_absorbing_states[208]} and {all_absorbing_states[625]}')\n",
    "\n",
    "# Absorbing states loop back to themselves under every action and pay their own reward\n",
    "# on every transition. Iterating the state lists directly (and using a dict for the\n",
    "# reward) replaces O(n) `s in list` / list.index() lookups inside state x action loops.\n",
    "absorbing_reward = dict(zip(all_absorbing_states, all_absorbing_rewards))\n",
    "for s in all_absorbing_states:\n",
    "    tx_mat_full[:, s, :] = 0.0\n",
    "    tx_mat_full[:, s, s] = 1.0\n",
    "    r_mat_full[:, s, :] = absorbing_reward[s]\n",
    "\n",
    "# Non-absorbing states: a feasible transition s -> s_p pays the reward of s_p.\n",
    "all_rewards_arr = np.asarray(all_rewards)\n",
    "for s in non_absorbing_states:\n",
    "    for a in range(n_actions):\n",
    "        reachable = np.flatnonzero(tx_mat_full[a, s, :] != 0)\n",
    "        r_mat_full[a, s, reachable] = all_rewards_arr[reachable]\n",
    "\n",
    "# rewards_pi[s, a]: reward of the most likely next state (first one on ties) for action a in s.\n",
    "rewards_pi = np.zeros((State.NUM_FULL_STATES, n_actions))\n",
    "for s in range(State.NUM_FULL_STATES):\n",
    "    for a in range(n_actions):\n",
    "        s_p = np.argmax(tx_mat_full[a, s, :])\n",
    "        rewards_pi[s, a] = r_mat_full[a, s, s_p]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Solve the block-diagonal MDP with policy iteration (discount DISCOUNT_Pol).\n",
    "fullMDP = cf.MatrixMDP(tx_mat_full, r_mat_full)\n",
    "fullPol = fullMDP.policyIteration(eval_type=1, discount=DISCOUNT_Pol)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DataGen(object):\n",
    "    \"\"\"Samples observed trajectories from the simulator MDP.\n",
    "\n",
    "    Transitions and rewards come from the notebook-level tx_mat_full / r_mat_full;\n",
    "    the wrapped MDP object (built with fullPol) supplies the default policy.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.mdp = MDP(init_state_idx=None, policy_array=fullPol, policy_idx_type='full', p_diabetes=PROB_DIAB)\n",
    "\n",
    "    def mdp_sample(self, policy=None, n_obs=1, n_steps=NSTEPS):\n",
    "        \"\"\"Sample n_obs trajectories of n_steps transitions each.\n",
    "\n",
    "        Args:\n",
    "            policy: array indexed as policy[state] -> probabilities over the\n",
    "                Action.NUM_ACTIONS_TOTAL actions. Defaults to the wrapped MDP's\n",
    "                policy_array (fullPol). Previously this argument was silently ignored.\n",
    "            n_obs: number of trajectories.\n",
    "            n_steps: transitions per trajectory.\n",
    "\n",
    "        Returns:\n",
    "            Float array of shape (n_obs, n_steps, 4); each row is\n",
    "            [current_state, next_state, action, reward].\n",
    "        \"\"\"\n",
    "        if policy is None:\n",
    "            policy = self.mdp.policy_array\n",
    "        n_fields = 4  # current state, next state, action, reward\n",
    "        n_states = tx_mat_full.shape[2]\n",
    "        trajectories = np.zeros((n_obs, n_steps, n_fields))\n",
    "\n",
    "        for obs_idx in range(n_obs):\n",
    "            # Trajectories never start in an absorbing state.\n",
    "            current_state = np.random.choice(non_absorbing_states)\n",
    "            for time_idx in range(n_steps):\n",
    "                action = np.random.choice(np.arange(Action.NUM_ACTIONS_TOTAL), p=policy[current_state])\n",
    "                # tx_mat_full is indexed as (action, state, next_state).\n",
    "                next_state = np.random.choice(n_states, size=1, p=tx_mat_full[action, current_state, :])[0]\n",
    "                # The reward depends on the state actually reached.\n",
    "                reward = r_mat_full[action, current_state, next_state]\n",
    "                trajectories[obs_idx, time_idx, :] = np.array([current_state, next_state, action, reward])\n",
    "                current_state = next_state\n",
    "\n",
    "        return trajectories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dgen = DataGen()\n",
    "# Sample a trajectory from the simulator. It is replaced below by a hand-picked\n",
    "# observed path, but the sampling still advances NumPy's global RNG.\n",
    "MDP_samp = dgen.mdp_sample().astype(int)\n",
    "\n",
    "# Alternative observed trajectory (suboptimal path), kept for reference:\n",
    "# MDP_samp = np.array([[[777., 939.,   3., -50.],\n",
    "#   [939., 869.,   6., -50.],\n",
    "#   [869., 869.,   6., -50.],\n",
    "#   [869., 861.,   6.,  50.],\n",
    "#   [861., 861.,   6.,  50.],\n",
    "#   [861., 869.,   6., -50.],\n",
    "#   [869., 861.,   6.,  50.],\n",
    "#   [861., 861.,   6.,  50.],\n",
    "#   [861., 853.,   6., -50.],\n",
    "#   [853., 853.,   6., -50.]]]).astype(int)\n",
    "\n",
    "# Catastrophic path: the observed trajectory used for the counterfactual analysis.\n",
    "# Rows are [current_state, next_state, action, reward]; from step 3 onward it stays\n",
    "# in state 949 under action 0.\n",
    "MDP_samp = np.array([\n",
    "    [[777, 939, 3, -50],\n",
    "     [939, 941, 6, -50],\n",
    "     [941, 949, 6, -1000]]\n",
    "    + [[949, 949, 0, -1000]] * 7\n",
    "])\n",
    "\n",
    "print(MDP_samp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def truncated_gumbel(logit, truncation):\n",
    "    \"\"\"Sample Gumbel(logit) variables truncated from above at `truncation`.\n",
    "\n",
    "    One sample is drawn per entry of the 1-D array `truncation`. Computed as\n",
    "    -logaddexp(-g, -truncation), which equals -log(exp(-g) + exp(-truncation)) but\n",
    "    does not overflow to -inf when exp(-g) leaves the float range (logit below ~-700).\n",
    "    \"\"\"\n",
    "    assert not np.isneginf(logit)\n",
    "\n",
    "    gumbel = np.random.gumbel(size=(truncation.shape[0])) + logit\n",
    "    trunc_g = -np.logaddexp(-gumbel, -truncation)\n",
    "    return trunc_g\n",
    "\n",
    "def topdown_tracking_influenced_states(obs_logits, obs_state, nsamp=1):\n",
    "    \"\"\"Sample Gumbel noise from its posterior given an observed categorical outcome.\n",
    "\n",
    "    Top-down Gumbel-max posterior: the observed outcome gets the top Gumbel value,\n",
    "    every other feasible outcome a Gumbel truncated below it, and zero-probability\n",
    "    outcomes unconstrained noise. When obs_logits[obs_state] is finite,\n",
    "    argmax(obs_logits + gumbels) recovers obs_state in every sample (up to\n",
    "    floating-point ties).\n",
    "\n",
    "    Args:\n",
    "        obs_logits: 1-D log-probabilities of the observed transition (-inf = impossible).\n",
    "        obs_state: index of the observed outcome.\n",
    "        nsamp: number of Monte Carlo noise samples.\n",
    "\n",
    "    Returns:\n",
    "        gumbels: (nsamp, len(obs_logits)) noise array.\n",
    "        influenced_states: bool array, True where the noise was constrained by the\n",
    "            observation (outcomes with a finite logit).\n",
    "    \"\"\"\n",
    "    poss_next_states = obs_logits.shape[0]\n",
    "    gumbels = np.zeros((nsamp, poss_next_states))\n",
    "    influenced_states = np.full(shape=poss_next_states, fill_value=False)\n",
    "\n",
    "    # One top Gumbel value per Monte Carlo sample.\n",
    "    topgumbel = np.random.gumbel(size=(nsamp))\n",
    "\n",
    "    for next_state in range(poss_next_states):\n",
    "        logit = obs_logits[next_state]\n",
    "        if np.isneginf(logit):\n",
    "            # Zero probability under the observation: the noise is unconstrained.\n",
    "            gumbels[:, next_state] = np.random.gumbel(size=nsamp)\n",
    "            continue\n",
    "        influenced_states[next_state] = True\n",
    "        if next_state == obs_state:\n",
    "            # The observed outcome attains the maximum.\n",
    "            gumbels[:, next_state] = topgumbel - logit\n",
    "        else:\n",
    "            # Other feasible outcomes must stay below the maximum.\n",
    "            gumbels[:, next_state] = truncated_gumbel(logit, topgumbel) - logit\n",
    "\n",
    "    return gumbels, influenced_states"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from multiprocessing import Process, Manager\n",
    "\n",
    "class CounterfactualSampler(object):\n",
    "    \"\"\"Gumbel-max counterfactual transition estimates for an observed trajectory.\n",
    "\n",
    "    For every (action, time step) pair it estimates, by Monte Carlo, the counterfactual\n",
    "    next-state distribution from every state given the transition observed at that time\n",
    "    step, and records which next states were constrained by the observation.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, mdp):\n",
    "        self.mdp = mdp\n",
    "        # NOTE(review): the sprtb_* settings are not read by any method of this class.\n",
    "        self.sprtb_theta = 0.9\n",
    "        self.sprtb_delta = 0.05\n",
    "        self.sprtb_r = 0.9\n",
    "\n",
    "    def cf_posterior_tracking_influenced_states(self, obs_prob, intrv_prob, state, n_mc):\n",
    "        \"\"\"Estimate one counterfactual next-state distribution.\n",
    "\n",
    "        Args:\n",
    "            obs_prob: next-state probabilities of the observed (state, action).\n",
    "            intrv_prob: next-state probabilities of the counterfactual (state, action).\n",
    "            state: observed next state.\n",
    "            n_mc: number of Monte Carlo posterior noise samples.\n",
    "\n",
    "        Returns:\n",
    "            posterior_prob: empirical distribution of the counterfactual next state.\n",
    "            intrv_posterior: the n_mc sampled counterfactual next states.\n",
    "            influenced_states: bool mask of next states constrained by the observation.\n",
    "        \"\"\"\n",
    "        obs_logits = np.log(obs_prob)\n",
    "        intrv_logits = np.log(intrv_prob)\n",
    "\n",
    "        # Sample the Gumbel posterior and push it through the counterfactual logits.\n",
    "        gumbels, influenced_states = topdown_tracking_influenced_states(obs_logits, state, n_mc)\n",
    "        intrv_posterior = np.argmax(intrv_logits + gumbels, axis=1)\n",
    "\n",
    "        # Empirical frequencies; bincount replaces a per-outcome `==` scan over all\n",
    "        # n_mc samples (O(n_states * n_mc) work per call).\n",
    "        posterior_prob = np.bincount(intrv_posterior, minlength=np.size(intrv_prob, 0)) / n_mc\n",
    "\n",
    "        return posterior_prob, intrv_posterior, influenced_states\n",
    "\n",
    "    def cf_sample_prob_tracking_influenced_transitions(self, trajectories, a, time_idx, P_cf_save, influenced_transitions_save, n_cf_samps=1):\n",
    "        \"\"\"Fill the counterfactual transition matrix for one (action, time step) pair.\n",
    "\n",
    "        P_cf_save[(a, time_idx)] receives {(a, time_idx): P}, where P[s, s'] estimates the\n",
    "        probability of reaching s' had action `a` been taken in state s at time_idx;\n",
    "        influenced_transitions_save[(a, time_idx)] receives the matching bool mask.\n",
    "\n",
    "        NOTE(review): results are keyed only by (a, time_idx), so with several observed\n",
    "        trajectories only the last one is kept; the notebook passes a single trajectory.\n",
    "        \"\"\"\n",
    "        n_obs = trajectories.shape[0]\n",
    "        n_mc = 1000\n",
    "        n_states = tx_mat_full.shape[1]\n",
    "\n",
    "        for obs_idx in range(n_obs):\n",
    "            P_cf = {}\n",
    "            influenced_transitions = {}\n",
    "\n",
    "            for _ in range(n_cf_samps):\n",
    "                obs_state = trajectories[obs_idx, time_idx, :]\n",
    "                obs_current_state = int(obs_state[0])  # s_real\n",
    "                obs_next_state = int(obs_state[1])  # s_p_real\n",
    "                obs_action = int(obs_state[2])  # a_real\n",
    "                # The observed next-state distribution does not depend on s: look it up once.\n",
    "                obs_intrv = tx_mat_full[obs_action, obs_current_state, :]\n",
    "\n",
    "                P_cf[a, time_idx] = np.zeros((n_states, n_states))\n",
    "                influenced_transitions[a, time_idx] = np.full(shape=(n_states, n_states), fill_value=False)\n",
    "\n",
    "                for s in range(n_states):\n",
    "                    # Counterfactual: action a taken in state s.\n",
    "                    cf_intrv = tx_mat_full[a, s, :]\n",
    "                    cf_prob, _, influenced_states = self.cf_posterior_tracking_influenced_states(obs_intrv, cf_intrv, obs_next_state, n_mc)\n",
    "                    P_cf[a, time_idx][s, :] = cf_prob\n",
    "                    influenced_transitions[a, time_idx][s, :] = influenced_states\n",
    "\n",
    "        P_cf_save[(a, time_idx)] = P_cf\n",
    "        influenced_transitions_save[(a, time_idx)] = influenced_transitions\n",
    "\n",
    "    def run_sample_tracking_influenced_transitions(self, inp, trajectories, P_cf, influenced_transitions):\n",
    "        \"\"\"Worker body: process every (a, time_idx) pair in `inp`, then merge the results\n",
    "        into the shared P_cf / influenced_transitions dicts.\"\"\"\n",
    "        P_cf_save = {}\n",
    "        influenced_transitions_save = {}\n",
    "\n",
    "        for i in inp:\n",
    "            self.cf_sample_prob_tracking_influenced_transitions(trajectories, i[0], i[1], P_cf_save, influenced_transitions_save)\n",
    "\n",
    "        for i in inp:\n",
    "            P_cf.update(P_cf_save[i])\n",
    "            influenced_transitions.update(influenced_transitions_save[i])\n",
    "\n",
    "    def run_parallel_sampling_tracking_influenced_transitions(self, trajectories, n_workers=32):\n",
    "        \"\"\"Compute counterfactual transitions for all (action, time step) pairs in parallel.\n",
    "\n",
    "        Args:\n",
    "            trajectories: array of shape (n_obs, n_steps, 4), rows [s, s', a, r].\n",
    "            n_workers: maximum number of worker processes (default 32, as before); never\n",
    "                more than the number of (action, time step) pairs, so no worker is idle.\n",
    "\n",
    "        Returns:\n",
    "            (P_cf, influenced_transitions): plain dicts keyed by (a, time_idx).\n",
    "\n",
    "        NOTE(review): with the 'fork' start method every worker inherits the parent's\n",
    "        NumPy global RNG state, so workers draw identical noise streams.\n",
    "        \"\"\"\n",
    "        n_steps = trajectories.shape[1]\n",
    "        n_actions = Action.NUM_ACTIONS_TOTAL\n",
    "\n",
    "        inp = [(a, time_idx) for time_idx in range(n_steps) for a in range(n_actions)]\n",
    "\n",
    "        def split(a, n):\n",
    "            # Split list `a` into n contiguous chunks whose sizes differ by at most one.\n",
    "            k, m = divmod(len(a), n)\n",
    "            return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))\n",
    "\n",
    "        n_workers = max(1, min(n_workers, len(inp)))\n",
    "        split_work = split(inp, n_workers)\n",
    "        processes = []\n",
    "\n",
    "        with Manager() as manager:\n",
    "            P_cf = manager.dict()\n",
    "            influenced_transitions = manager.dict()\n",
    "\n",
    "            for chunk in split_work:\n",
    "                process = Process(target=self.run_sample_tracking_influenced_transitions, args=(chunk, trajectories, P_cf, influenced_transitions))\n",
    "                processes.append(process)\n",
    "                process.start()\n",
    "\n",
    "            for process in processes:\n",
    "                process.join()\n",
    "\n",
    "            return P_cf.copy(), influenced_transitions.copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_ITERATIONS = 10000\n",
    "from collections import deque, defaultdict  # used by InfluenceMDPPruner.get_influence_graph\n",
    "\n",
    "print(MDP_samp)\n",
    "\n",
    "# Estimate counterfactual transitions for every (action, time step) of the observed path.\n",
    "sampler = CounterfactualSampler(dgen)\n",
    "P_cf, influenced_transitions = sampler.run_parallel_sampling_tracking_influenced_transitions(\n",
    "    trajectories=MDP_samp\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class InfluenceMDPPruner:\n",
    "    def __init__(self, mdp, look_ahead_k=11):\n",
    "        \"\"\"Gather the inputs for counterfactual MDP pruning.\n",
    "\n",
    "        Apart from `mdp` and `look_ahead_k`, everything is read from notebook-level\n",
    "        globals produced by earlier cells (fullPol, rewards_pi, sampler, MDP_samp,\n",
    "        P_cf, influenced_transitions).\n",
    "\n",
    "        Args:\n",
    "            mdp: stored as self.mdp.\n",
    "            look_ahead_k: largest look-ahead; the methods iterate k = 1..look_ahead_k.\n",
    "        \"\"\"\n",
    "        self.mdp = mdp\n",
    "        self.look_ahead_k = look_ahead_k\n",
    "\n",
    "        # Results of earlier cells.\n",
    "        self.optimal_policy = fullPol\n",
    "        self.rewards_pi = rewards_pi\n",
    "        self.sampler = sampler\n",
    "\n",
    "        # Observed trajectory; rows are [current_state, next_state, action, reward].\n",
    "        self.mdp_sample = MDP_samp\n",
    "        self.initial_state = self.mdp_sample[0, 0, 0]\n",
    "        self.T = len(self.mdp_sample[0])\n",
    "\n",
    "        # Full state and action index spaces.\n",
    "        self.states, self.actions = range(1440), range(8)\n",
    "\n",
    "        # Counterfactual transition probabilities, plus which transitions were\n",
    "        # influenced (constrained) by the observed path.\n",
    "        self.P_cf = P_cf\n",
    "        self.influenced_transitions = influenced_transitions\n",
    "\n",
    "    def build_graph(self, transition_probs, T, all_states, all_actions):\n",
    "        \"\"\"Unroll an MDP into a layered, time-expanded multigraph.\n",
    "\n",
    "        Node (t, s) is added for every t in 0..T-1 and state s. An edge\n",
    "        (t, s) -> (t+1, s') keyed by action a, labelled \"(a, p)\", is added for every\n",
    "        p = transition_probs[a, s, s'] > 0 with s' in all_states.\n",
    "\n",
    "        Args:\n",
    "            transition_probs: array indexed as [action, state, next_state].\n",
    "            T: number of time steps.\n",
    "            all_states, all_actions: state / action indices.\n",
    "\n",
    "        Returns:\n",
    "            nx.MultiDiGraph.\n",
    "        \"\"\"\n",
    "        G = nx.MultiDiGraph()\n",
    "        state_list = list(all_states)\n",
    "        state_idx = np.asarray(state_list, dtype=int)\n",
    "\n",
    "        for t in range(T):\n",
    "            for s in state_list:\n",
    "                node = (t, s)\n",
    "                G.add_node(node)\n",
    "\n",
    "                for a in all_actions:\n",
    "                    # Locate the positive entries with numpy instead of testing every\n",
    "                    # candidate next state in Python.\n",
    "                    row = transition_probs[a, s, state_idx]\n",
    "                    for j in np.flatnonzero(row > 0):\n",
    "                        s_prime = state_list[j]\n",
    "                        G.add_node((t+1, s_prime))\n",
    "                        G.add_edge(node, (t+1, s_prime), key=a, label=f\"({a}, {row[j]})\")\n",
    "\n",
    "        return G\n",
    "\n",
    "    def build_cf_graph(self, transition_probs, T, all_states, all_actions, k):\n",
    "        \"\"\"Unroll a counterfactual MDP into a layered graph and prune dead ends.\n",
    "\n",
    "        Adds an edge (t, s) -> (t+1, s') keyed by action a, labelled \"(a, p)\", for every\n",
    "        p = transition_probs[a, t][s, s'] > 0. A node before the last layer without\n",
    "        outgoing edges is removed; afterwards nodes at t > 0 without predecessors are\n",
    "        removed repeatedly.\n",
    "\n",
    "        Args:\n",
    "            transition_probs: dict {(a, t): (n_states, n_states) array}.\n",
    "            T: number of time steps.\n",
    "            all_states, all_actions: state / action indices.\n",
    "            k: look-ahead (unused).\n",
    "\n",
    "        Returns:\n",
    "            nx.MultiDiGraph.\n",
    "        \"\"\"\n",
    "        G = nx.MultiDiGraph()\n",
    "        state_list = list(all_states)\n",
    "        state_idx = np.asarray(state_list, dtype=int)\n",
    "\n",
    "        for t in range(T):\n",
    "            for s in state_list:\n",
    "                node = (t, s)\n",
    "                G.add_node(node)\n",
    "\n",
    "                for a in all_actions:\n",
    "                    # Vectorised search for positive entries instead of a Python scan\n",
    "                    # over every candidate next state.\n",
    "                    row = transition_probs[a, t][s, state_idx]\n",
    "                    for j in np.flatnonzero(row > 0):\n",
    "                        s_prime = state_list[j]\n",
    "                        G.add_node((t+1, s_prime))\n",
    "                        G.add_edge(node, (t+1, s_prime), key=a, label=f\"({a}, {row[j]})\")\n",
    "\n",
    "                # If node has no outgoing edges, remove it from the graph\n",
    "                if G.has_node(node) and G.out_degree(node) == 0 and t < T-1:\n",
    "                    G.remove_node(node)\n",
    "\n",
    "        # Remove unreachable nodes at t>0 with in-degree = 0\n",
    "        unreachable_nodes = {n for n in G if G.in_degree(n) == 0 and n[0]>0}\n",
    "\n",
    "        while len(unreachable_nodes) > 0:\n",
    "            G.remove_nodes_from(unreachable_nodes)\n",
    "            unreachable_nodes = {n for n in G if G.in_degree(n) == 0 and n[0]>0}\n",
    "\n",
    "        return G\n",
    "    \n",
    "    def get_counterfactual_transition_probabilities(self, P_cf, original_G, new_mdp_G, all_states, all_actions, A_real, S_real, T, k):\n",
    "        \"\"\"Restrict the counterfactual transition probabilities to the pruned graph.\n",
    "\n",
    "        Action a stays available in (t, s) only if every next state it reaches with\n",
    "        positive probability in P_cf[a, t][s, :] is still an a-keyed edge of new_mdp_G.\n",
    "        Otherwise all a-keyed edges out of (t, s) are removed and the row is zeroed, so\n",
    "        every remaining row is still a complete probability distribution.\n",
    "\n",
    "        Args:\n",
    "            P_cf: dict {(a, t): (n_states, n_states) array}; modified in place.\n",
    "            original_G: unused.\n",
    "            new_mdp_G: pruned (influence) graph; modified in place.\n",
    "            all_states, all_actions: state / action indices.\n",
    "            A_real, S_real: observed action / state at each time step.\n",
    "            T: only sizes valid_action; the loops run over self.T steps.\n",
    "            k: look-ahead, only used in the log message.\n",
    "\n",
    "        Returns:\n",
    "            (P_cf, valid_action): valid_action[t, s, a] is True iff row P_cf[a, t][s, :]\n",
    "            survived pruning, i.e. still sums to 1 within floating-point tolerance.\n",
    "\n",
    "        Raises:\n",
    "            AssertionError: on a negative / NaN probability, or if the observed action at\n",
    "            the observed state gets pruned.\n",
    "        \"\"\"\n",
    "        print(f\"Calculating counterfactual transition probabilities for k={k}\")\n",
    "        valid_action = np.full((T, len(all_states), len(all_actions)), False)\n",
    "\n",
    "        T = self.T\n",
    "\n",
    "        def action_fully_supported(node, t, a, row):\n",
    "            # True iff every next state reachable under `row` is an a-keyed edge out of node.\n",
    "            return all(new_mdp_G.has_edge(node, (t+1, int(s_prime)), key=a)\n",
    "                       for s_prime in np.flatnonzero(row > 0.0))\n",
    "\n",
    "        # Pass 1: prune the graph, starting from the last layer.\n",
    "        for t in range(T-1, -1, -1):\n",
    "            for s in all_states:\n",
    "                node = (t, s)\n",
    "                if new_mdp_G.has_node(node):\n",
    "                    for a in all_actions:\n",
    "                        if not action_fully_supported(node, t, a, P_cf[a, t][s, :]):\n",
    "                            a_edges = [(u, v) for u, v, key in new_mdp_G.out_edges(node, keys=True) if key == a]\n",
    "                            for u, v in a_edges:\n",
    "                                new_mdp_G.remove_edge(u, v, key=a)\n",
    "\n",
    "                # If node has no outgoing edges, remove it from the graph\n",
    "                if new_mdp_G.has_node(node) and new_mdp_G.out_degree(node) == 0 and t < T-1:\n",
    "                    new_mdp_G.remove_node(node)\n",
    "\n",
    "            # Remove unreachable nodes at t>0 with in-degree = 0\n",
    "            unreachable_nodes = {n for n in new_mdp_G if new_mdp_G.in_degree(n) == 0 and n[0]>0}\n",
    "            while len(unreachable_nodes) > 0:\n",
    "                new_mdp_G.remove_nodes_from(unreachable_nodes)\n",
    "                unreachable_nodes = {n for n in new_mdp_G if new_mdp_G.in_degree(n) == 0 and n[0]>0}\n",
    "\n",
    "        # Pass 2: zero the rows of pruned actions and record which actions remain valid.\n",
    "        for t in range(T-1, -1, -1):\n",
    "            for s in all_states:\n",
    "                node = (t, s)\n",
    "                for a in all_actions:\n",
    "                    row = P_cf[a, t][s, :]  # a view: writes go straight into P_cf\n",
    "                    assert np.all((row > 0.0) | (row == 0.0)), \"negative or NaN transition probability\"\n",
    "                    if not action_fully_supported(node, t, a, row):\n",
    "                        row[:] = 0.0\n",
    "\n",
    "                    # Monte Carlo rows are sums of count / n_mc terms that need not add up to\n",
    "                    # exactly 1.0 in floating point; an exact `== 1.0` test could wrongly\n",
    "                    # invalidate a surviving action (and trip the assertion below).\n",
    "                    if np.isclose(row.sum(), 1.0):\n",
    "                        valid_action[t, s, a] = True\n",
    "\n",
    "                    if a == A_real[t] and s == S_real[t]:\n",
    "                        assert(valid_action[t, s, a])\n",
    "\n",
    "        return P_cf, valid_action\n",
    "\n",
    "    def find_wcc(self, G):\n",
    "        \"\"\"Return a copy of the weakly connected component of G containing (0, initial_state).\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if (0, initial_state) is not a node of G (previously this\n",
    "                surfaced as an UnboundLocalError on the result variable).\n",
    "        \"\"\"\n",
    "        s_0 = self.initial_state\n",
    "        t_0 = 0\n",
    "        target_node = (t_0, s_0)\n",
    "        print(f'the initial state is {s_0}')\n",
    "\n",
    "        for component in nx.weakly_connected_components(G):\n",
    "            if target_node in component:\n",
    "                return G.subgraph(component).copy()\n",
    "\n",
    "        raise ValueError(f\"initial node {target_node} is not in the graph\")\n",
    "\n",
    "    def get_influence_graph(self, G, k, influenced_transitions):\n",
    "        \"\"\"Restrict G to the nodes within k steps upstream of an influenced transition.\n",
    "\n",
    "        Args:\n",
    "            G: layered graph from build_graph (read only).\n",
    "            k: look-ahead horizon.\n",
    "            influenced_transitions: dict {(a, t): bool (n_states, n_states) array}.\n",
    "\n",
    "        Returns:\n",
    "            A new MultiDiGraph: the nodes that reach an influenced next-state node\n",
    "            (t+1, s') in at most k steps, plus every edge of G leaving layers\n",
    "            T-k+1 .. T-1, with dead ends and predecessor-less nodes pruned repeatedly.\n",
    "        \"\"\"\n",
    "        print(f\"Generating influence graph for k={k}\")\n",
    "\n",
    "        def reverse_bfs(G, start_nodes, k):\n",
    "            # Multi-source BFS along predecessor edges; returns every node whose\n",
    "            # distance to some start node is at most k.\n",
    "            distance = defaultdict(lambda: float('inf'))\n",
    "            nodes_to_visit = deque([(node, 0) for node in start_nodes])\n",
    "            within_k_steps = set()\n",
    "\n",
    "            while nodes_to_visit:\n",
    "                curr_node, curr_dist = nodes_to_visit.popleft()\n",
    "\n",
    "                if curr_dist <= k:\n",
    "                    within_k_steps.add(curr_node)\n",
    "\n",
    "                    if distance[curr_node] > curr_dist:\n",
    "                        distance[curr_node] = curr_dist\n",
    "\n",
    "                        for predecessor in G.predecessors(curr_node):\n",
    "                            nodes_to_visit.append((predecessor, curr_dist+1))\n",
    "\n",
    "            return within_k_steps\n",
    "\n",
    "        # Next-state nodes (t+1, s') targeted by at least one influenced transition.\n",
    "        # np.any over the source-state axis replaces a Python scan of every (s, s') pair.\n",
    "        directly_influenced_nodes = set()\n",
    "        for t in range(self.T):\n",
    "            for a in range(8):\n",
    "                influenced_targets = np.flatnonzero(np.any(influenced_transitions[a, t], axis=0))\n",
    "                directly_influenced_nodes.update((t+1, int(s_prime)) for s_prime in influenced_targets)\n",
    "\n",
    "        reachable_nodes = reverse_bfs(G, directly_influenced_nodes, k)\n",
    "\n",
    "        influence_graph = G.subgraph(reachable_nodes).copy()\n",
    "\n",
    "        # Between layers T-k+1 and T every transition is treated as influenced: copy all of\n",
    "        # G's edges leaving those layers, in the same (action, next state) order as a\n",
    "        # nested scan over a and s'.\n",
    "        for timestep in range(self.T-k+1, self.T):\n",
    "            for s in range(1440):\n",
    "                node = (timestep, s)\n",
    "                if not G.has_node(node):\n",
    "                    continue\n",
    "                out_edges = sorted(G.out_edges(node, keys=True), key=lambda edge: (edge[2], edge[1][1]))\n",
    "                for _, target, a in out_edges:\n",
    "                    if not influence_graph.has_edge(node, target, key=a):\n",
    "                        influence_graph.add_edge(node, target, key=a)\n",
    "\n",
    "        # Remove nodes with in-degree = 0 (t > 0) or out-degree = 0 (t < T), repeatedly.\n",
    "        unreachable_nodes = {n for n in influence_graph if (influence_graph.in_degree(n) == 0 and n[0]>0) or (influence_graph.out_degree(n) == 0 and n[0] < self.T)}\n",
    "\n",
    "        while len(unreachable_nodes) > 0:\n",
    "            influence_graph.remove_nodes_from(unreachable_nodes)\n",
    "            unreachable_nodes = {n for n in influence_graph if (influence_graph.in_degree(n) == 0 and n[0]>0) or (influence_graph.out_degree(n) == 0 and n[0] < self.T)}\n",
    "\n",
    "        return influence_graph\n",
    "\n",
    "    def prune_mdp(self):\n",
    "        \"\"\"Build the pruned counterfactual MDP for every look-ahead k = 1..look_ahead_k.\n",
    "\n",
    "        Returns:\n",
    "            (cf_transition_probs, valid_actions, cf_graphs): lists whose entry k-1 holds\n",
    "            the pruned P_cf dict, the valid_action mask and the pruned graph for k.\n",
    "        \"\"\"\n",
    "        k_values = range(1, self.look_ahead_k+1)\n",
    "\n",
    "        # Time-expanded graph of the original MDP.\n",
    "        base_graph = self.build_graph(tx_mat_full, self.T, self.states, self.actions)\n",
    "\n",
    "        # One influence graph per k, each computed from its own copy of base_graph.\n",
    "        influence_graphs = [\n",
    "            self.get_influence_graph(copy.deepcopy(base_graph), k, self.influenced_transitions)\n",
    "            for k in k_values\n",
    "        ]\n",
    "\n",
    "        print(self.mdp_sample)\n",
    "\n",
    "        # Observed actions and states of the first observed trajectory.\n",
    "        A_real = self.mdp_sample[0, :, 2]\n",
    "        S_real = self.mdp_sample[0, :, 0]\n",
    "\n",
    "        cf_transition_probs = []\n",
    "        valid_actions = []\n",
    "        for k, influence_graph in zip(k_values, influence_graphs):\n",
    "            new_P_cf, valid_action = self.get_counterfactual_transition_probabilities(\n",
    "                copy.deepcopy(self.P_cf), base_graph, influence_graph, self.states, self.actions,\n",
    "                A_real, S_real, self.T, k)\n",
    "            cf_transition_probs.append(new_P_cf)\n",
    "            valid_actions.append(valid_action)\n",
    "\n",
    "        # Graphs of the pruned counterfactual MDPs (same loop as build_graphs).\n",
    "        cf_graphs = self.build_graphs(cf_transition_probs)\n",
    "\n",
    "        return cf_transition_probs, valid_actions, cf_graphs\n",
    "\n",
    "    def build_graphs(self, cf_transition_probs):\n",
    "        \"\"\"Build the pruned counterfactual graph for every look-ahead k.\n",
    "\n",
    "        Args:\n",
    "            cf_transition_probs: list whose entry k-1 is the P_cf dict for look-ahead k.\n",
    "\n",
    "        Returns:\n",
    "            list of build_cf_graph results for k = 1..look_ahead_k, in that order.\n",
    "        \"\"\"\n",
    "        return [\n",
    "            self.build_cf_graph(cf_transition_probs[k-1], self.T, self.states, self.actions, k)\n",
    "            for k in range(1, self.look_ahead_k+1)\n",
    "        ]\n",
    "\n",
    "    def get_optimal_policy(self, max_num_actions_changed, transition_probs, valid_action, new_mdp_G, all_states, all_actions, A_real, T, rewards_pi):\n",
    "        \"\"\"Backward induction for the best policy that changes at most\n",
    "        `max_num_actions_changed` of the observed actions A_real.\n",
    "\n",
    "        h_fun[s, r, c] is the best undiscounted sum of rewards_pi[s][a] plus expected\n",
    "        continuation value from state s with r steps left (time T-r) and at most c\n",
    "        remaining deviations from A_real; h_fun[:, 0, :] stays 0. For c = 0 the observed\n",
    "        action is always taken (valid_action is not consulted). For c >= 1 only actions\n",
    "        with valid_action[T-r, s, a] are considered, and deviating from A_real[T-r]\n",
    "        consumes one change. If no action is valid, h_fun is -inf and the policy falls\n",
    "        back to A_real[T-r]. Ties keep the lowest action index (strict `>`).\n",
    "\n",
    "        pi[s, max_num_actions_changed - c, t] is the action chosen in state s at time t\n",
    "        with c changes remaining, so the middle index counts changes already used.\n",
    "\n",
    "        new_mdp_G, all_states and all_actions are unused; the state and action counts\n",
    "        are hard-coded to 1440 and 8.\n",
    "\n",
    "        Returns:\n",
    "            (pi, h_fun) with shapes (1440, max_num_actions_changed+1, T+1) and\n",
    "            (1440, T+1, max_num_actions_changed+1).\n",
    "        \"\"\"\n",
    "        h_fun = np.zeros((1440, T+1, max_num_actions_changed+1)) \n",
    "        pi = np.zeros((1440, max_num_actions_changed+1, T+1), dtype=int) \n",
    "    \n",
    "        # c = 0: no changes left, so the observed actions are followed exactly.\n",
    "        for r in range(1, T+1): \n",
    "            for s in range(1440): \n",
    "                h_fun[s, r, 0] = rewards_pi[s][(A_real[T-r])]  \n",
    "                for s_p in range(1440): # for every single next state (s') for each state s\n",
    "                    h_fun[s, r, 0] += transition_probs[A_real[T-r], T-r][s, s_p] * h_fun[s_p, r-1, 0]\n",
    "                pi[s, max_num_actions_changed, T-r] = A_real[T-r]\n",
    "\n",
    "        # c >= 1: maximise over valid actions; a deviation continues with budget c-1.\n",
    "        for c in range(1, max_num_actions_changed+1): \n",
    "            for r in range(1, T+1): \n",
    "                for s in range(1440):\n",
    "                    # Default to the observed action; always overwritten with best_act below.\n",
    "                    pi[s, max_num_actions_changed-c, T-r] = A_real[T-r] # instead let it be the real action\n",
    "                    best_act = A_real[T-r]\n",
    "                    max_val = -np.inf\n",
    "                    \n",
    "                    for a in range(8):\n",
    "                        if valid_action[T-r, s, a]:\n",
    "                            val = rewards_pi[s][a]\n",
    "                            # Zero-probability successors are skipped so that 0 * -inf\n",
    "                            # (from an infeasible continuation) cannot produce NaN.\n",
    "                            if a != A_real[T-r]: \n",
    "                                # Deviation: continue with one change fewer.\n",
    "                                for s_p in range(1440):\n",
    "                                    if transition_probs[a, T-r][s, s_p] != 0:\n",
    "                                        val += transition_probs[a, T-r][s, s_p] * h_fun[s_p, r-1, c-1] \n",
    "                            elif a == A_real[T-r]:\n",
    "                                # Observed action: the change budget is unchanged.\n",
    "                                for s_p in range(1440):\n",
    "                                    if transition_probs[a, T-r][s, s_p] != 0:\n",
    "                                        val += transition_probs[a, T-r][s, s_p] * h_fun[s_p, r-1, c]\n",
    "                                        \n",
    "                            if val > max_val:\n",
    "                                max_val = val\n",
    "                                best_act = a\n",
    "\n",
    "                    h_fun[s, r, c] = max_val\n",
    "                    pi[s, max_num_actions_changed-c, T-r] = best_act\n",
    "\n",
    "        return pi, h_fun\n",
    "\n",
    "    def generate_policies(self, cf_transition_probs, valid_actions, cf_graphs):\n",
    "        \"\"\"Solve every pruned counterfactual MDP for every budget of changed actions.\n",
    "\n",
    "        Returns:\n",
    "            policies: policies[k-1][c-1] is the pi from get_optimal_policy for look-ahead\n",
    "                k and at most c changed actions, both ranging over 1..look_ahead_k.\n",
    "            new_all_rewards: (T, n_states, n_actions) array holding rewards_pi[s, a]\n",
    "                at every time step.\n",
    "            h_funs: the matching value tables, nested like `policies`.\n",
    "        \"\"\"\n",
    "        k_vals = range(1, self.look_ahead_k+1)\n",
    "        A_real = self.mdp_sample[0, :, 2]\n",
    "\n",
    "        # Time-invariant rewards: the same rewards_pi[s, a] slice repeated for every t.\n",
    "        reward_slice = np.asarray(self.rewards_pi, dtype=float)[np.ix_(list(self.states), list(self.actions))]\n",
    "        new_all_rewards = np.repeat(reward_slice[np.newaxis, :, :], self.T, axis=0)\n",
    "\n",
    "        policies, h_funs = [], []\n",
    "        for look_ahead_k in k_vals:\n",
    "            print(f\"Estimating policy with k={look_ahead_k}\")\n",
    "            policies_k, h_funs_k = [], []\n",
    "\n",
    "            for max_num_actions_changed in k_vals:\n",
    "                pi, h_fun = self.get_optimal_policy(\n",
    "                    max_num_actions_changed,\n",
    "                    cf_transition_probs[look_ahead_k-1],\n",
    "                    valid_actions[look_ahead_k-1],\n",
    "                    cf_graphs[look_ahead_k-1],\n",
    "                    self.states,\n",
    "                    self.actions,\n",
    "                    A_real,\n",
    "                    self.T,\n",
    "                    self.rewards_pi\n",
    "                )\n",
    "                policies_k.append(pi)\n",
    "                h_funs_k.append(h_fun)\n",
    "\n",
    "            policies.append(policies_k)\n",
    "            h_funs.append(h_funs_k)\n",
    "\n",
    "        return policies, new_all_rewards, h_funs\n",
    "\n",
    "    def generate_random_trajectory(self, MDP_samp, transition_probs, pi, s_0, A_real, all_states, rewards_pi, T, rng=None):\n",
    "        \"\"\"Sample one counterfactual trajectory by following the policy ``pi``.\n",
    "\n",
    "        Args:\n",
    "            MDP_samp: observed sample; only its shape is used (``shape[0]`` rows and\n",
    "                ``shape[2]`` columns per step in the output, expected to be 4).\n",
    "            transition_probs: indexed as ``transition_probs[a, t][s]`` to get the\n",
    "                next-state distribution for action ``a`` at step ``t``; it is passed as\n",
    "                ``p`` to ``Generator.choice`` and therefore must sum to 1.\n",
    "            pi: policy indexed as ``pi[s, l, t]``, where ``l`` counts the actions\n",
    "                changed so far.\n",
    "            s_0: initial state (shared with the observed path).\n",
    "            A_real: observed actions; any ``a_t != A_real[t]`` uses up one change.\n",
    "            all_states: candidate next states passed to ``Generator.choice``.\n",
    "            rewards_pi: time-indexed rewards, read as ``rewards_pi[t, s, a]``.\n",
    "            T: number of steps to simulate.\n",
    "            rng: optional ``numpy.random.Generator``. When omitted a fresh, unseeded\n",
    "                generator is created (the previous behaviour); pass a seeded or shared\n",
    "                generator for reproducible sampling.\n",
    "\n",
    "        Returns:\n",
    "            Array of shape ``(MDP_samp.shape[0], T, MDP_samp.shape[2])``; row 0 holds\n",
    "            ``[s_t, s_{t+1}, a_t, r_t]`` for each step, all other rows stay zero.\n",
    "        \"\"\"\n",
    "        if rng is None:\n",
    "            rng = np.random.default_rng()\n",
    "\n",
    "        CF_trajectory = np.zeros((MDP_samp.shape[0], T, MDP_samp.shape[2]))\n",
    "        s = np.zeros(T + 1, dtype=int)\n",
    "        s[0] = s_0  # Initial state the same as the observed path\n",
    "        l = np.zeros(T + 1, dtype=int)  # actions changed before step t; starts at 0\n",
    "        a = np.zeros(T, dtype=int)\n",
    "\n",
    "        for t in range(T):\n",
    "            # Pick the action the policy prescribes for this state and change count.\n",
    "            a[t] = pi[s[t], l[t], t]\n",
    "\n",
    "            # Draw a single next state (scalar form, no size=1 array to unwrap).\n",
    "            s[t + 1] = rng.choice(all_states, p=transition_probs[a[t], t][s[t]])\n",
    "\n",
    "            # Deviating from the observed action consumes one change.\n",
    "            l[t + 1] = l[t] + int(a[t] != A_real[t])\n",
    "\n",
    "            CF_trajectory[0, t, :] = [s[t], s[t + 1], a[t], rewards_pi[t, s[t], a[t]]]\n",
    "\n",
    "        return CF_trajectory\n",
    "\n",
    "    def generate_cf_trajectories(self, cf_transition_probs, policies, new_all_rewards, num_iterations=1000):\n",
    "        \"\"\"Monte Carlo estimate of the final-step reward of observed vs. counterfactual paths.\n",
    "\n",
    "        For every look-ahead ``k`` and change budget ``m`` in ``1..self.look_ahead_k``,\n",
    "        ``num_iterations`` trajectories are sampled with ``policies[k-1][m-1]`` and the\n",
    "        immediate reward at the last step (index ``self.T - 1``) is averaged.\n",
    "\n",
    "        Args:\n",
    "            cf_transition_probs: sequence indexed by ``k - 1`` (output of prune_mdp).\n",
    "            policies: nested list ``policies[k-1][m-1]`` from generate_policies.\n",
    "            new_all_rewards: time-indexed reward table passed to every rollout.\n",
    "            num_iterations: number of Monte Carlo repetitions; defaults to 1000,\n",
    "                the value that used to be hard-coded.\n",
    "\n",
    "        Returns:\n",
    "            (mean_obs, mean_cf, k_vals): two ``(look_ahead_k, look_ahead_k)`` float arrays\n",
    "            indexed ``[k-1, m-1]`` and the ``range`` of k values.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if ``num_iterations`` is smaller than 1.\n",
    "        \"\"\"\n",
    "        if num_iterations < 1:\n",
    "            raise ValueError(f'num_iterations must be >= 1, got {num_iterations}')\n",
    "\n",
    "        print('Generating CF trajectories')\n",
    "        k_vals = range(1, self.look_ahead_k + 1)\n",
    "        A_real = self.mdp_sample[0, :, 2]\n",
    "        s_0 = self.mdp_sample[0, 0, 0]\n",
    "        grid_shape = (self.look_ahead_k, self.look_ahead_k)\n",
    "\n",
    "        # The observed path is fixed, so its final-step reward is identical in every\n",
    "        # (k, m) cell and every iteration; compute it once.\n",
    "        mean_obs = np.full(grid_shape, self.mdp_sample[0, self.T - 1, 3], dtype=float)\n",
    "\n",
    "        # Named cf_final (not cf) so it does not shadow the notebook's cf module alias.\n",
    "        all_cf = []\n",
    "        for _ in range(num_iterations):\n",
    "            cf_final = np.zeros(grid_shape)\n",
    "            for look_ahead_k in k_vals:\n",
    "                for max_num_actions_changed in k_vals:\n",
    "                    CF_trajectory = self.generate_random_trajectory(\n",
    "                        self.mdp_sample,\n",
    "                        cf_transition_probs[look_ahead_k-1],\n",
    "                        policies[look_ahead_k-1][max_num_actions_changed-1],\n",
    "                        s_0,\n",
    "                        A_real,\n",
    "                        self.states,\n",
    "                        new_all_rewards,\n",
    "                        self.T\n",
    "                    )\n",
    "                    # Immediate reward of the counterfactual path at the last step.\n",
    "                    cf_final[look_ahead_k-1, max_num_actions_changed-1] = CF_trajectory[0, self.T-1, 3]\n",
    "            all_cf.append(cf_final)\n",
    "\n",
    "        mean_cf = np.mean(all_cf, axis=0)\n",
    "\n",
    "        return mean_obs, mean_cf, k_vals"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prune MDP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the MDP from the full policy array (policy_idx_type='full').\n",
    "mdp = MDP(\n",
    "    init_state_idx=None,\n",
    "    policy_array=fullPol,\n",
    "    policy_idx_type='full',\n",
    "    p_diabetes=PROB_DIAB,\n",
    ")\n",
    "\n",
    "# The pruner's methods sweep look-ahead k (and change budget m) over 1..look_ahead_k.\n",
    "influence_pruner = InfluenceMDPPruner(mdp, look_ahead_k=11)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prune the counterfactual MDP; later cells index each returned sequence by look-ahead k - 1.\n",
    "cf_transition_probs, valid_actions, cf_graphs = influence_pruner.prune_mdp()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generating Policies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# policies[k-1][m-1] and h_funs[k-1][m-1]: solution for look-ahead k with at most m changed actions.\n",
    "policies, new_all_rewards, h_funs = influence_pruner.generate_policies(cf_transition_probs, valid_actions, cf_graphs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Value Function of the Initial State $S_0$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(MDP_samp)\n",
    "\n",
    "# NOTE(review): 777 is the hard-coded initial state s_0 of the observed sample; confirm it\n",
    "# matches influence_pruner.mdp_sample[0, 0, 0] if the sample changes.\n",
    "S0 = 777\n",
    "k_vals = range(1, influence_pruner.look_ahead_k + 1)\n",
    "\n",
    "# values[k-1][m-1] = V(S0) for look-ahead k with at most m changed actions.\n",
    "# h_fun is indexed [state, steps remaining, changes left]; -1 presumably selects the full horizon.\n",
    "values = [\n",
    "    [h_funs[k-1][m-1][S0, -1, m] for m in k_vals]\n",
    "    for k in k_vals\n",
    "]\n",
    "\n",
    "# Observed-path value (0 changes left), taken from the largest look-ahead k.\n",
    "obs_values = [h_funs[k_vals[-1]-1][m-1][S0, -1, 0] for m in k_vals]\n",
    "\n",
    "# One printed row per look-ahead k, including the largest.\n",
    "for row in values:\n",
    "    print(row)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(4, 6))\n",
    "ax.set_title('Value of Initial State Given Influence')\n",
    "ax.set_xlabel('Maximum Number of Actions Changed')\n",
    "ax.set_ylabel('V(S0)')\n",
    "ax.grid(which='both')\n",
    "\n",
    "# Labels are attached to each series so legend entries cannot drift out of order.\n",
    "# NOTE(review): the first CF series is only k = n_k - 2 but is labelled as standing for\n",
    "# k = 1..n_k-2, and k = n_k - 1 is labelled as the horizon T; check the printed rows.\n",
    "n_k = len(values)\n",
    "ax.scatter(k_vals, obs_values, color='lightpink', label='Observed Path', marker='o', s=100)\n",
    "ax.scatter(k_vals, values[-3], color='blue', label=f'Look-Ahead K=1 to {n_k - 2}', marker='d', s=50)\n",
    "ax.scatter(k_vals, values[-2], color='deeppink', label=f'Look-Ahead K={n_k - 1} (T)', marker='d', s=50)\n",
    "ax.scatter(k_vals, values[-1], color='green', label='Look-Ahead K=∞', marker='d', s=50)\n",
    "ax.legend(loc=0, frameon=True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generating CF Trajectories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Monte Carlo comparison of final-step rewards; mean_obs and mean_cf are indexed [k-1, m-1].\n",
    "mean_obs, mean_cf, k_vals = influence_pruner.generate_cf_trajectories(cf_transition_probs, policies, new_all_rewards)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep in sync with the Monte Carlo iteration count used by generate_cf_trajectories\n",
    "# (1000 for the call above).\n",
    "NUM_ITERATIONS = 1000\n",
    "colors = ['orange', 'red', 'aqua', 'yellow', 'darkblue', 'deeppink', 'darkviolet', 'silver', 'teal', 'blue']\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(20, 20))\n",
    "ax.set_title(f'Final State Reward of Observed vs Counterfactual Paths After Pruning MDP, Averaged Over {NUM_ITERATIONS} Iterations')\n",
    "ax.set_xlabel('Maximum Number of Actions Changed')\n",
    "ax.set_ylabel('Final State Reward')\n",
    "ax.grid(which='both')\n",
    "\n",
    "print(mean_obs)\n",
    "print(mean_cf)\n",
    "\n",
    "# Every row of mean_obs holds the same observed final reward, so row 0 stands for all k.\n",
    "ax.scatter(k_vals, mean_obs[0], color='lightpink', label='Observed Path', marker='o', s=100)\n",
    "\n",
    "# One labelled series per look-ahead k; colours cycle if there are more k values than colours.\n",
    "for look_ahead_k in k_vals[:-1]:\n",
    "    ax.scatter(k_vals, mean_cf[look_ahead_k-1], color=colors[(look_ahead_k-1) % len(colors)],\n",
    "               label=f'Look-Ahead K={look_ahead_k}', marker='d', s=50)\n",
    "\n",
    "# Largest k: a colour outside the palette above (so it is distinguishable) and its own legend entry.\n",
    "ax.scatter(k_vals, mean_cf[-1], color='green', label=f'Look-Ahead K={k_vals[-1]}', marker='d', s=30)\n",
    "\n",
    "ax.legend(loc=0, frameon=True)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
