{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "np.seterr(divide = 'ignore') \n",
    "from scipy.linalg import block_diag\n",
    "import warnings\n",
    "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
    "import matplotlib\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "import cf.counterfactual as cf\n",
    "import networkx as nx\n",
    "import copy\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "class Action(object):\n",
    "    '''\n",
    "    A combination of the three binary treatments: antibiotic (A),\n",
    "    ventilation (E) and vasopressors (V).\n",
    "\n",
    "    An action is identified either by the treatment strings that are\n",
    "    switched on, or by an integer index in [0, NUM_ACTIONS_TOTAL) whose\n",
    "    three bits are (antibiotic, ventilation, vasopressors), most\n",
    "    significant bit first:\n",
    "\n",
    "        0 (000) -> no treatment     4 (100) -> A\n",
    "        1 (001) -> V                5 (101) -> A, V\n",
    "        2 (010) -> E                6 (110) -> A, E\n",
    "        3 (011) -> E, V             7 (111) -> A, E, V\n",
    "\n",
    "    The three flags are stored as plain Python ints (0 or 1), so that\n",
    "    get_action_idx() and __hash__ return a built-in int.\n",
    "    '''\n",
    "    NUM_ACTIONS_TOTAL = 8\n",
    "    ANTIBIOTIC_STRING = \"antibiotic\"\n",
    "    VENT_STRING = \"ventilation\"\n",
    "    VASO_STRING = \"vasopressors\"\n",
    "    ACTION_VEC_SIZE = 3\n",
    "\n",
    "    def __init__(self, selected_actions = None, action_idx = None):\n",
    "        '''\n",
    "        Build the action from exactly one of the two descriptions.\n",
    "\n",
    "        :param selected_actions: container of the treatment strings that are on\n",
    "        :param action_idx: integer index of the action, 0 to 7\n",
    "        '''\n",
    "        # Exactly one of the two arguments must be provided.\n",
    "        assert (selected_actions is None) != (action_idx is None), (\n",
    "            \"must specify either set of action strings or action index\")\n",
    "\n",
    "        if selected_actions is not None:\n",
    "            # One membership test per treatment; int() maps the bool to 0/1.\n",
    "            self.antibiotic = int(Action.ANTIBIOTIC_STRING in selected_actions)\n",
    "            self.ventilation = int(Action.VENT_STRING in selected_actions)\n",
    "            self.vasopressors = int(Action.VASO_STRING in selected_actions)\n",
    "        else:\n",
    "            # Decode the three bits with integer arithmetic. Plain ints\n",
    "            # matter here: a numpy scalar returned from __hash__ raises\n",
    "            # TypeError, which would make the action unusable as a dict key.\n",
    "            idx = int(action_idx)\n",
    "            assert 0 <= idx < Action.NUM_ACTIONS_TOTAL, (\n",
    "                \"action index must be in [0, {})\".format(Action.NUM_ACTIONS_TOTAL))\n",
    "            self.antibiotic = (idx >> 2) & 1\n",
    "            self.ventilation = (idx >> 1) & 1\n",
    "            self.vasopressors = idx & 1\n",
    "\n",
    "    def __eq__(self, other):\n",
    "        '''Two actions are equal iff all three treatment flags agree.'''\n",
    "        return (isinstance(other, self.__class__)\n",
    "                and self.antibiotic == other.antibiotic\n",
    "                and self.ventilation == other.ventilation\n",
    "                and self.vasopressors == other.vasopressors)\n",
    "\n",
    "    def __ne__(self, other):\n",
    "        return not self.__eq__(other)\n",
    "\n",
    "    def get_action_idx(self):\n",
    "        '''\n",
    "        Encode the three flags as an integer in [0, 7] with weights 4, 2, 1.\n",
    "\n",
    "        Examples: antibiotic only -> 4, ventilation only -> 2,\n",
    "        antibiotic and ventilation -> 6, all three -> 7.\n",
    "        '''\n",
    "        assert self.antibiotic in (0, 1)\n",
    "        assert self.ventilation in (0, 1)\n",
    "        assert self.vasopressors in (0, 1)\n",
    "        # int() keeps the result a built-in int even if a flag was\n",
    "        # overwritten with a numpy integer after construction.\n",
    "        return int(4*self.antibiotic + 2*self.ventilation + self.vasopressors)\n",
    "\n",
    "    def __hash__(self):\n",
    "        '''Hash is the action index, consistent with __eq__.'''\n",
    "        return self.get_action_idx()\n",
    "\n",
    "    def get_selected_actions(self):\n",
    "        '''Return the set of treatment strings that are switched on.'''\n",
    "        selected_actions = set()\n",
    "        if self.antibiotic == 1:\n",
    "            selected_actions.add(Action.ANTIBIOTIC_STRING)\n",
    "        if self.ventilation == 1:\n",
    "            selected_actions.add(Action.VENT_STRING)\n",
    "        if self.vasopressors == 1:\n",
    "            selected_actions.add(Action.VASO_STRING)\n",
    "        return selected_actions\n",
    "\n",
    "    def get_abbrev_string(self):\n",
    "        '''\n",
    "        AEV: antibiotics, ventilation, vasopressors\n",
    "        Returns the letters of the active treatments, '' if none is active.\n",
    "        '''\n",
    "        output_str = ''\n",
    "        if self.antibiotic == 1:\n",
    "            output_str += 'A'\n",
    "        if self.ventilation == 1:\n",
    "            output_str += 'E'\n",
    "        if self.vasopressors == 1:\n",
    "            output_str += 'V'\n",
    "        return output_str\n",
    "\n",
    "    def get_action_vec(self):\n",
    "        '''Return the flags as a numpy column vector of shape (3, 1).'''\n",
    "        return np.array([[self.antibiotic], [self.ventilation], [self.vasopressors]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "class State(object):\n",
    "    '''\n",
    "    Discrete patient state.\n",
    "\n",
    "    Seven categorical variables are stored, listed with their number of\n",
    "    categories: hr_state (3), sysbp_state (3), percoxyg_state (2),\n",
    "    glucose_state (5), antibiotic_state (2), vaso_state (2), vent_state (2),\n",
    "    plus the binary diabetic_idx.\n",
    "\n",
    "    A state can be flattened to, or restored from, an integer index that\n",
    "    uses a mixed-radix encoding with the most significant variable first:\n",
    "\n",
    "        'obs'      : the seven variables above       (NUM_OBS_STATES)\n",
    "        'proj_obs' : the same without glucose_state  (NUM_PROJ_OBS_STATES)\n",
    "        'full'     : diabetic_idx followed by 'obs'  (NUM_FULL_STATES)\n",
    "    '''\n",
    "    NUM_OBS_STATES = 720\n",
    "    NUM_HID_STATES = 2  # Binary value of diabetes\n",
    "    NUM_PROJ_OBS_STATES = int(720 / 5)  # Marginalizing over glucose\n",
    "    NUM_FULL_STATES = int(NUM_OBS_STATES * NUM_HID_STATES)\n",
    "\n",
    "    def __init__(self, state_idx = None, idx_type = 'obs', diabetic_idx = None, state_categs = None):\n",
    "        '''\n",
    "        Initialise the state from an index or from explicit categories.\n",
    "\n",
    "        :param state_idx: integer index, decoded according to idx_type;\n",
    "            takes precedence over state_categs when both are given\n",
    "        :param idx_type: 'obs', 'full' or 'proj_obs'\n",
    "        :param diabetic_idx: diabetes flag; required for 'obs' and\n",
    "            'proj_obs' indices, ignored for 'full' indices\n",
    "        :param state_categs: sequence of the 7 category values in the order\n",
    "            hr, sysbp, percoxyg, glucose, antibiotic, vaso, vent\n",
    "        '''\n",
    "        assert state_idx is not None or state_categs is not None\n",
    "        assert idx_type in ['obs', 'full', 'proj_obs']\n",
    "\n",
    "        if state_idx is not None:\n",
    "            self.set_state_by_idx(\n",
    "                    state_idx, idx_type=idx_type, diabetic_idx=diabetic_idx)\n",
    "        elif state_categs is not None:\n",
    "            assert len(state_categs) == 7, \"must specify 7 state variables\"\n",
    "            self.hr_state = state_categs[0]\n",
    "            self.sysbp_state = state_categs[1]\n",
    "            self.percoxyg_state = state_categs[2]\n",
    "            self.glucose_state = state_categs[3]\n",
    "            self.antibiotic_state = state_categs[4]\n",
    "            self.vaso_state = state_categs[5]\n",
    "            self.vent_state = state_categs[6]\n",
    "            self.diabetic_idx = diabetic_idx\n",
    "\n",
    "    def check_absorbing_state(self):\n",
    "        '''\n",
    "        returns True iff the state is absorbing: at least 3 abnormal\n",
    "        conditions, or no abnormal condition and no active treatment\n",
    "        '''\n",
    "        num_abnormal = self.get_num_abnormal()\n",
    "        if num_abnormal >= 3:\n",
    "            return True\n",
    "        elif num_abnormal == 0 and not self.on_treatment():\n",
    "            return True\n",
    "        return False\n",
    "\n",
    "    def state_rewards(self):\n",
    "        '''\n",
    "        returns a reward graded by the number of abnormal conditions:\n",
    "            3 or more abnormal          -> -1000\n",
    "            2 abnormal                  -> -50\n",
    "            1 abnormal                  -> +50\n",
    "            0 abnormal, on treatment    -> +70\n",
    "            0 abnormal, no treatment    -> +1000\n",
    "        '''\n",
    "        num_abnormal = self.get_num_abnormal()\n",
    "        if num_abnormal >= 3:\n",
    "            return (-1000)\n",
    "        elif num_abnormal == 2:\n",
    "            return (-50)\n",
    "        elif num_abnormal == 1:\n",
    "            return (+50)\n",
    "        elif num_abnormal == 0 and self.on_treatment():\n",
    "            return (+70)\n",
    "        elif num_abnormal == 0 and not self.on_treatment():\n",
    "            return (+1000)\n",
    "\n",
    "    def set_state_by_idx(self, state_idx, idx_type, diabetic_idx=None):\n",
    "        '''\n",
    "        Decode an integer index into the categorical state variables.\n",
    "\n",
    "        The state index is determined by using mixed-radix (\"bit\")\n",
    "        arithmetic, with the complication that not every variable is binary.\n",
    "\n",
    "        :param state_idx: Given index\n",
    "        :param idx_type: Index type, either observed (720), projected (144)\n",
    "        or full (1440)\n",
    "        :param diabetic_idx: If full state index not given, this is required\n",
    "        '''\n",
    "        assert idx_type in ['obs', 'full', 'proj_obs']\n",
    "\n",
    "        # term_base is the weight of the most significant variable of the\n",
    "        # chosen encoding; it is divided by the number of categories of each\n",
    "        # following variable as the index is peeled apart.\n",
    "        if idx_type == 'obs':\n",
    "            term_base = State.NUM_OBS_STATES/3 # Starts with heart rate\n",
    "        elif idx_type == 'proj_obs':\n",
    "            term_base = State.NUM_PROJ_OBS_STATES/3\n",
    "        elif idx_type == 'full':\n",
    "            term_base = State.NUM_FULL_STATES/2 # Starts with diab\n",
    "\n",
    "        # Start with the given state index\n",
    "        mod_idx = state_idx\n",
    "\n",
    "        if idx_type == 'full':\n",
    "            # The full index carries the diabetes flag as its leading digit.\n",
    "            self.diabetic_idx = np.floor(mod_idx/term_base).astype(int)\n",
    "            mod_idx %= term_base\n",
    "            term_base /= 3 # This is for heart rate, the next item\n",
    "        else:\n",
    "            assert diabetic_idx is not None\n",
    "            self.diabetic_idx = diabetic_idx\n",
    "\n",
    "        self.hr_state = np.floor(mod_idx/term_base).astype(int)\n",
    "\n",
    "        mod_idx %= term_base\n",
    "        term_base /= 3\n",
    "        self.sysbp_state = np.floor(mod_idx/term_base).astype(int)\n",
    "\n",
    "        mod_idx %= term_base\n",
    "        term_base /= 2\n",
    "        self.percoxyg_state = np.floor(mod_idx/term_base).astype(int)\n",
    "\n",
    "        if idx_type == 'proj_obs':\n",
    "            # The projected index has no glucose digit; glucose is fixed to\n",
    "            # 2, the value get_num_abnormal treats as not abnormal.\n",
    "            self.glucose_state = 2\n",
    "        else:\n",
    "            mod_idx %= term_base\n",
    "            term_base /= 5\n",
    "            self.glucose_state = np.floor(mod_idx/term_base).astype(int)\n",
    "\n",
    "        mod_idx %= term_base\n",
    "        term_base /= 2\n",
    "        self.antibiotic_state = np.floor(mod_idx/term_base).astype(int)\n",
    "\n",
    "        mod_idx %= term_base\n",
    "        term_base /= 2\n",
    "        self.vaso_state = np.floor(mod_idx/term_base).astype(int)\n",
    "\n",
    "        mod_idx %= term_base\n",
    "        term_base /= 2\n",
    "        self.vent_state = np.floor(mod_idx/term_base).astype(int)\n",
    "\n",
    "    def get_state_idx(self, idx_type='obs'):\n",
    "        '''\n",
    "        returns integer index of state: significance order as in categorical array\n",
    "\n",
    "        Inverse of set_state_by_idx for the given idx_type\n",
    "        ('obs', 'proj_obs' or 'full').\n",
    "        '''\n",
    "        assert idx_type in ['obs', 'full', 'proj_obs']\n",
    "\n",
    "        if idx_type == 'obs':\n",
    "            categ_num = np.array([3,3,2,5,2,2,2])\n",
    "            state_categs = [\n",
    "                    self.hr_state,\n",
    "                    self.sysbp_state,\n",
    "                    self.percoxyg_state,\n",
    "                    self.glucose_state,\n",
    "                    self.antibiotic_state,\n",
    "                    self.vaso_state,\n",
    "                    self.vent_state]\n",
    "        elif idx_type == 'proj_obs':\n",
    "            categ_num = np.array([3,3,2,2,2,2])\n",
    "            state_categs = [\n",
    "                    self.hr_state,\n",
    "                    self.sysbp_state,\n",
    "                    self.percoxyg_state,\n",
    "                    self.antibiotic_state,\n",
    "                    self.vaso_state,\n",
    "                    self.vent_state]\n",
    "        elif idx_type == 'full':\n",
    "            categ_num = np.array([2,3,3,2,5,2,2,2])\n",
    "            state_categs = [\n",
    "                    self.diabetic_idx,\n",
    "                    self.hr_state,\n",
    "                    self.sysbp_state,\n",
    "                    self.percoxyg_state,\n",
    "                    self.glucose_state,\n",
    "                    self.antibiotic_state,\n",
    "                    self.vaso_state,\n",
    "                    self.vent_state]\n",
    "\n",
    "        # Accumulate from the least significant (last) variable upwards.\n",
    "        sum_idx = 0\n",
    "        prev_base = 1\n",
    "        for i in range(len(state_categs)):\n",
    "            idx = len(state_categs) - 1 - i\n",
    "            sum_idx += prev_base*state_categs[idx]\n",
    "            prev_base *= categ_num[idx]\n",
    "        return sum_idx\n",
    "\n",
    "    def __eq__(self, other):\n",
    "        '''\n",
    "        override equals: two states equal if the seven stored state\n",
    "        variables are the same (diabetic_idx is not compared)\n",
    "        '''\n",
    "        return (isinstance(other, self.__class__)\n",
    "                and self.hr_state == other.hr_state\n",
    "                and self.sysbp_state == other.sysbp_state\n",
    "                and self.percoxyg_state == other.percoxyg_state\n",
    "                and self.glucose_state == other.glucose_state\n",
    "                and self.antibiotic_state == other.antibiotic_state\n",
    "                and self.vaso_state == other.vaso_state\n",
    "                and self.vent_state == other.vent_state)\n",
    "\n",
    "    def __ne__(self, other):\n",
    "        return not self.__eq__(other)\n",
    "\n",
    "    def __hash__(self):\n",
    "        '''\n",
    "        hash is the 'obs' index, which covers exactly the variables\n",
    "        compared by __eq__\n",
    "        '''\n",
    "        # get_state_idx accumulates against a numpy array and therefore\n",
    "        # yields a numpy integer; __hash__ must return a built-in int,\n",
    "        # otherwise hash() raises TypeError.\n",
    "        return int(self.get_state_idx())\n",
    "\n",
    "    def get_num_abnormal(self):\n",
    "        '''\n",
    "        returns number of abnormal conditions among heart rate, blood\n",
    "        pressure, percent oxygen and glucose\n",
    "        '''\n",
    "        num_abnormal = 0\n",
    "        if self.hr_state != 1:\n",
    "            num_abnormal += 1\n",
    "        if self.sysbp_state != 1:\n",
    "            num_abnormal += 1\n",
    "        if self.percoxyg_state != 1:\n",
    "            num_abnormal += 1\n",
    "        if self.glucose_state != 2:\n",
    "            num_abnormal += 1\n",
    "        return num_abnormal\n",
    "\n",
    "    def on_treatment(self):\n",
    "        '''\n",
    "        returns True iff any of 3 treatments active\n",
    "        '''\n",
    "        if (self.antibiotic_state == 0\n",
    "                and self.vaso_state == 0 and self.vent_state == 0):\n",
    "            return False\n",
    "        return True\n",
    "\n",
    "    def on_antibiotics(self):\n",
    "        '''\n",
    "        returns True iff antibiotics active\n",
    "        '''\n",
    "        return self.antibiotic_state == 1\n",
    "\n",
    "    def on_vasopressors(self):\n",
    "        '''\n",
    "        returns True iff vasopressors active\n",
    "        '''\n",
    "        return self.vaso_state == 1\n",
    "\n",
    "    def on_ventilation(self):\n",
    "        '''\n",
    "        returns True iff ventilation active\n",
    "        '''\n",
    "        return self.vent_state == 1\n",
    "\n",
    "    def copy_state(self):\n",
    "        '''\n",
    "        returns a new State with the same seven variables and diabetic_idx\n",
    "        '''\n",
    "        return State(state_categs = [\n",
    "            self.hr_state,\n",
    "            self.sysbp_state,\n",
    "            self.percoxyg_state,\n",
    "            self.glucose_state,\n",
    "            self.antibiotic_state,\n",
    "            self.vaso_state,\n",
    "            self.vent_state],\n",
    "            diabetic_idx=self.diabetic_idx)\n",
    "\n",
    "    def get_state_vector(self):\n",
    "        '''\n",
    "        returns the seven state variables as an integer numpy array\n",
    "        '''\n",
    "        return np.array([self.hr_state,\n",
    "            self.sysbp_state,\n",
    "            self.percoxyg_state,\n",
    "            self.glucose_state,\n",
    "            self.antibiotic_state,\n",
    "            self.vaso_state,\n",
    "            self.vent_state]).astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "class MDP(object):\n",
    "\n",
    "    def __init__(self, init_state_idx=None, init_state_idx_type='obs', policy_array=None, policy_idx_type='obs', p_diabetes=0.2):\n",
    "        '''\n",
    "        Validate the arguments, create the initial state and store the policy.\n",
    "\n",
    "        :param init_state_idx: index of the initial state; a random state is\n",
    "            generated when it is None\n",
    "        :param init_state_idx_type: index type of init_state_idx, only\n",
    "            relevant when an index is provided\n",
    "        :param policy_array: optional array of shape (states, actions)\n",
    "        :param policy_idx_type: 'obs', 'full' or 'proj_obs'; fixes the\n",
    "            number of rows expected in policy_array\n",
    "        :param p_diabetes: probability used when the diabetes flag has to\n",
    "            be sampled\n",
    "        '''\n",
    "        assert 0 <= p_diabetes <= 1, (\n",
    "                \"Invalid p_diabetes: {}\".format(p_diabetes))\n",
    "        assert policy_idx_type in ['obs', 'full', 'proj_obs']\n",
    "\n",
    "        # Check the policy dimensions (states x actions)\n",
    "        if policy_array is not None:\n",
    "            rows_by_idx_type = {\n",
    "                'obs': State.NUM_OBS_STATES,\n",
    "                'full': State.NUM_HID_STATES * State.NUM_OBS_STATES,\n",
    "                'proj_obs': State.NUM_PROJ_OBS_STATES,\n",
    "            }\n",
    "            assert policy_array.shape[1] == Action.NUM_ACTIONS_TOTAL\n",
    "            assert policy_array.shape[0] == rows_by_idx_type[policy_idx_type]\n",
    "\n",
    "        # p_diabetes must be stored before the state is created, because\n",
    "        # get_new_state may sample the diabetes flag with it.\n",
    "        self.p_diabetes = p_diabetes\n",
    "        self.state = None\n",
    "        self.state = self.get_new_state(init_state_idx, init_state_idx_type)\n",
    "\n",
    "        # policy_idx_type is used for mapping the policy to actions\n",
    "        self.policy_array = policy_array\n",
    "        self.policy_idx_type = policy_idx_type\n",
    "        \n",
    "\n",
    "    def get_new_state(self, state_idx = None, idx_type = 'obs', diabetic_idx = None):\n",
    "        '''\n",
    "        Return a State that is either decoded from state_idx or generated\n",
    "        at random.\n",
    "\n",
    "        Without state_idx a random non-absorbing state is drawn, conditioned\n",
    "        on diabetic_idx when that is given. With state_idx, an 'obs' index\n",
    "        without diabetic_idx gets a sampled diabetes flag, a 'full' index\n",
    "        needs no flag, and a 'proj_obs' index requires diabetic_idx.\n",
    "        '''\n",
    "        assert idx_type in ['obs', 'full', 'proj_obs']\n",
    "\n",
    "        diab_known = diabetic_idx is not None\n",
    "        if state_idx is None:\n",
    "            option = 'random_cond_diab' if diab_known else 'random'\n",
    "        elif idx_type == 'obs':\n",
    "            if diab_known:\n",
    "                option = 'spec_obs'\n",
    "            else:\n",
    "                option = 'spec_obs_no_diab'\n",
    "                diabetic_idx = np.random.binomial(1, self.p_diabetes)\n",
    "        elif idx_type == 'full':\n",
    "            option = 'spec_full'\n",
    "        elif idx_type == 'proj_obs' and diab_known:\n",
    "            option = 'spec_proj_obs'\n",
    "        else:\n",
    "            option = None\n",
    "\n",
    "        assert option is not None, \"Invalid specification of new state\"\n",
    "\n",
    "        if option in ('random', 'random_cond_diab'):\n",
    "            # Do not start in death or discharge state: redraw until the\n",
    "            # candidate is not absorbing.\n",
    "            while True:\n",
    "                candidate = self.generate_random_state(diabetic_idx)\n",
    "                if not candidate.check_absorbing_state():\n",
    "                    return candidate\n",
    "\n",
    "        # Note that diabetic_idx will be ignored if idx_type = 'full'\n",
    "        return State(\n",
    "                state_idx=state_idx, idx_type=idx_type,\n",
    "                diabetic_idx=diabetic_idx)\n",
    "\n",
    "    def generate_random_state(self, diabetic_idx=None):\n",
    "        '''\n",
    "        Draw a random State with all three treatments switched off.\n",
    "\n",
    "        heart rate, sys bp: categories 0/1/2 w.p. [.25, .5, .25]\n",
    "        percent oxygen: categories 0/1 w.p. [.2, .8]\n",
    "        glucose: five categories, with a distribution that depends on the\n",
    "            diabetes flag\n",
    "        :param diabetic_idx: conditioned on if provided, otherwise sampled\n",
    "            with probability self.p_diabetes\n",
    "        '''\n",
    "        if diabetic_idx is None:\n",
    "            diabetic_idx = np.random.binomial(1, self.p_diabetes)\n",
    "\n",
    "        # The draws below are made in a fixed order: hr, sysbp, percoxyg, glucose.\n",
    "        vital_probs = np.array([.25, .5, .25])\n",
    "        hr_state = np.random.choice(np.arange(3), p=vital_probs)\n",
    "        sysbp_state = np.random.choice(np.arange(3), p=vital_probs)\n",
    "        percoxyg_state = np.random.choice(np.arange(2), p=np.array([.2, .8]))\n",
    "\n",
    "        if diabetic_idx == 0:\n",
    "            glucose_probs = np.array([.05, .15, .6, .15, .05])\n",
    "        else:\n",
    "            glucose_probs = np.array([.01, .05, .15, .6, .19])\n",
    "        glucose_state = np.random.choice(np.arange(5), p=glucose_probs)\n",
    "\n",
    "        # antibiotic, vaso and vent all start at 0\n",
    "        return State(\n",
    "            state_categs=[hr_state, sysbp_state, percoxyg_state,\n",
    "                          glucose_state, 0, 0, 0],\n",
    "            diabetic_idx=diabetic_idx)\n",
    "\n",
    "    # transition_antibiotics_on/off: Models the effect of turning antibiotics on/off.\n",
    "    \n",
    "    def transition_antibiotics_on(self):\n",
    "        '''\n",
    "        antibiotics state on\n",
    "        heart rate, sys bp: hi -> normal w.p. .5\n",
    "        '''\n",
    "        patient = self.state\n",
    "        patient.antibiotic_state = 1\n",
    "        # Heart rate first, then blood pressure; a random number is only\n",
    "        # drawn for a variable that is currently high (2).\n",
    "        for vital in ('hr_state', 'sysbp_state'):\n",
    "            if getattr(patient, vital) == 2 and np.random.uniform(0,1) < 0.5:\n",
    "                setattr(patient, vital, 1)\n",
    "\n",
    "    def transition_antibiotics_off(self):\n",
    "        '''\n",
    "        antibiotics state off\n",
    "        if antibiotics was on: heart rate, sys bp: normal -> hi w.p. .1\n",
    "        '''\n",
    "        patient = self.state\n",
    "        if patient.antibiotic_state != 1:\n",
    "            # Nothing to withdraw\n",
    "            return\n",
    "        # Heart rate first, then blood pressure; a random number is only\n",
    "        # drawn for a variable that is currently normal (1).\n",
    "        for vital in ('hr_state', 'sysbp_state'):\n",
    "            if getattr(patient, vital) == 1 and np.random.uniform(0,1) < 0.1:\n",
    "                setattr(patient, vital, 2)\n",
    "        patient.antibiotic_state = 0\n",
    "\n",
    "    # transition_vent_on/off: Models the effect of turning ventilation on/off.\n",
    "\n",
    "    def transition_vent_on(self):\n",
    "        '''\n",
    "        ventilation state on\n",
    "        percent oxygen: low -> normal w.p. .7\n",
    "        '''\n",
    "        patient = self.state\n",
    "        patient.vent_state = 1\n",
    "        if patient.percoxyg_state != 0:\n",
    "            # Oxygen is not low: no random number is drawn\n",
    "            return\n",
    "        if np.random.uniform(0,1) < 0.7:\n",
    "            patient.percoxyg_state = 1\n",
    "\n",
    "    def transition_vent_off(self):\n",
    "        '''\n",
    "        ventilation state off\n",
    "        if ventilation was on: percent oxygen: normal -> lo w.p. .1\n",
    "        '''\n",
    "        patient = self.state\n",
    "        if patient.vent_state != 1:\n",
    "            # Ventilation was not on: nothing changes\n",
    "            return\n",
    "        if patient.percoxyg_state == 1 and np.random.uniform(0,1) < 0.1:\n",
    "            patient.percoxyg_state = 0\n",
    "        patient.vent_state = 0\n",
    "    \n",
    "    # transition_vaso_on/off: Models the effect of turning vasopressors on/off, considering if the patient is diabetic.\n",
    "\n",
    "    def transition_vaso_on(self):\n",
    "        '''\n",
    "        vasopressor state on\n",
    "        for non-diabetic:\n",
    "            sys bp: low -> normal, normal -> hi w.p. .7\n",
    "        for diabetic:\n",
    "            raise blood pressure: normal -> hi w.p. .9,\n",
    "                lo -> normal w.p. .5, lo -> hi w.p. .4\n",
    "            raise blood glucose by 1 w.p. .5\n",
    "        '''\n",
    "        patient = self.state\n",
    "        patient.vaso_state = 1\n",
    "\n",
    "        if patient.diabetic_idx == 0:\n",
    "            # The random number is drawn whatever the current pressure is.\n",
    "            pressure_rises = np.random.uniform(0,1) < 0.7\n",
    "            one_step_up = {0: 1, 1: 2}\n",
    "            if pressure_rises and patient.sysbp_state in one_step_up:\n",
    "                patient.sysbp_state = one_step_up[patient.sysbp_state]\n",
    "            return\n",
    "\n",
    "        # Diabetic patient: blood pressure first, then glucose.\n",
    "        if patient.sysbp_state == 1:\n",
    "            if np.random.uniform(0,1) < 0.9:\n",
    "                patient.sysbp_state = 2\n",
    "        elif patient.sysbp_state == 0:\n",
    "            # A single draw splits into lo -> normal (.5) and lo -> hi (.4)\n",
    "            pressure_draw = np.random.uniform(0,1)\n",
    "            if pressure_draw < 0.5:\n",
    "                patient.sysbp_state = 1\n",
    "            elif pressure_draw < 0.9:\n",
    "                patient.sysbp_state = 2\n",
    "        if np.random.uniform(0,1) < 0.5:\n",
    "            patient.glucose_state = min(4, patient.glucose_state + 1)\n",
    "\n",
    "    def transition_vaso_off(self):\n",
    "        '''\n",
    "        vasopressor state off\n",
    "        if vasopressor was on:\n",
    "            for non-diabetics, sys bp: normal -> low, hi -> normal w.p. .1\n",
    "            for diabetics, blood pressure falls by 1 w.p. .05 instead of .1\n",
    "        '''\n",
    "        patient = self.state\n",
    "        if patient.vaso_state != 1:\n",
    "            # Vasopressors were not on: nothing changes\n",
    "            return\n",
    "        fall_prob = 0.1 if patient.diabetic_idx == 0 else 0.05\n",
    "        if np.random.uniform(0,1) < fall_prob:\n",
    "            patient.sysbp_state = max(0, patient.sysbp_state - 1)\n",
    "        patient.vaso_state = 0\n",
    "\n",
    "    def transition_fluctuate(self, hr_fluctuate, sysbp_fluctuate, percoxyg_fluctuate, glucose_fluctuate):\n",
    "        '''\n",
    "        all (non-treatment) states fluctuate +/- 1 w.p. .1\n",
    "        exception: glucose flucuates +/- 1 w.p. .3 if diabetic\n",
    "\n",
    "        Only the variables whose flag is true are touched; each of them\n",
    "        uses exactly one random number and is clipped to its valid range.\n",
    "        '''\n",
    "        patient = self.state\n",
    "\n",
    "        def drift(level, top, down_below, up_below):\n",
    "            # One uniform draw picks between a step down, a step up and no move.\n",
    "            draw = np.random.uniform(0,1)\n",
    "            if draw < down_below:\n",
    "                return max(0, level - 1)\n",
    "            if draw < up_below:\n",
    "                return min(top, level + 1)\n",
    "            return level\n",
    "\n",
    "        if hr_fluctuate:\n",
    "            patient.hr_state = drift(patient.hr_state, 2, 0.1, 0.2)\n",
    "        if sysbp_fluctuate:\n",
    "            patient.sysbp_state = drift(patient.sysbp_state, 2, 0.1, 0.2)\n",
    "        if percoxyg_fluctuate:\n",
    "            patient.percoxyg_state = drift(patient.percoxyg_state, 1, 0.1, 0.2)\n",
    "        if glucose_fluctuate:\n",
    "            if patient.diabetic_idx == 0:\n",
    "                patient.glucose_state = drift(patient.glucose_state, 4, 0.1, 0.2)\n",
    "            else:\n",
    "                patient.glucose_state = drift(patient.glucose_state, 4, 0.3, 0.6)\n",
    "\n",
    "    def calculateReward(self):\n",
    "        '''\n",
    "        returns the reward of the current state:\n",
    "            -1 if at least 3 conditions are abnormal,\n",
    "            +1 if no condition is abnormal and no treatment is active,\n",
    "             0 otherwise\n",
    "        '''\n",
    "        abnormal_count = self.state.get_num_abnormal()\n",
    "        if abnormal_count >= 3:\n",
    "            return -1\n",
    "        # on_treatment() is only consulted when nothing is abnormal\n",
    "        healthy_untreated = abnormal_count == 0 and not self.state.on_treatment()\n",
    "        return 1 if healthy_untreated else 0\n",
    "\n",
    "    def transition(self, action):\n",
    "        self.state = self.state.copy_state()\n",
    "\n",
    "        if action.antibiotic == 1:\n",
    "            self.transition_antibiotics_on()\n",
    "            hr_fluctuate = False\n",
    "            sysbp_fluctuate = False\n",
    "        elif self.state.antibiotic_state == 1:\n",
    "            self.transition_antibiotics_off()\n",
    "            hr_fluctuate = False\n",
    "            sysbp_fluctuate = False\n",
    "        else:\n",
    "            hr_fluctuate = True\n",
    "            sysbp_fluctuate = True\n",
    "\n",
    "        if action.ventilation == 1:\n",
    "            self.transition_vent_on()\n",
    "            percoxyg_fluctuate = False\n",
    "        elif self.state.vent_state == 1:\n",
    "            self.transition_vent_off()\n",
    "            percoxyg_fluctuate = False\n",
    "        else:\n",
    "            percoxyg_fluctuate = True\n",
    "\n",
    "        glucose_fluctuate = True\n",
    "\n",
    "        if action.vasopressors == 1:\n",
    "            self.transition_vaso_on()\n",
    "            sysbp_fluctuate = False\n",
    "            glucose_fluctuate = False\n",
    "        elif self.state.vaso_state == 1:\n",
    "            self.transition_vaso_off()\n",
    "            sysbp_fluctuate = False\n",
    "\n",
    "        self.transition_fluctuate(hr_fluctuate, sysbp_fluctuate, percoxyg_fluctuate, \\\n",
    "            glucose_fluctuate)\n",
    "\n",
    "        return self.calculateReward()\n",
    "\n",
    "    def select_actions(self):\n",
    "        assert self.policy_array is not None\n",
    "        probs = self.policy_array[\n",
    "                    self.state.get_state_idx(self.policy_idx_type)\n",
    "                ]\n",
    "        aev_idx = np.random.choice(np.arange(Action.NUM_ACTIONS_TOTAL), p=probs)\n",
    "        return Action(action_idx = aev_idx)\n",
    "\n",
    "    def action_idx(self, state_idx):\n",
    "        assert self.policy_array is not None\n",
    "        #print(f'state is {state_idx}')\n",
    "        probs = self.policy_array[state_idx]\n",
    "        #print(f'probs is {probs}')\n",
    "        aev_idx = np.random.choice(np.arange(Action.NUM_ACTIONS_TOTAL), p=probs)\n",
    "        #print(f'aev_idx is {aev_idx}')\n",
    "        return aev_idx\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First, learn the MDP parameters where the diabetes state is hidden.\n",
    "import itertools as it\n",
    "import numpy as np\n",
    "from tqdm import tqdm_notebook as tqdm\n",
    "\n",
    "# Samples per component/state/action pair\n",
    "np.random.seed(1)\n",
    "n_iter = 10000\n",
    "n_actions = Action.NUM_ACTIONS_TOTAL\n",
    "n_states = State.NUM_OBS_STATES\n",
    "n_components = 2\n",
    "\n",
    "states = range(n_states)\n",
    "actions = range(n_actions)\n",
    "print(states)\n",
    "components = [0, 1]\n",
    "\n",
    "## TRANSITION MATRIX\n",
    "tx_mat = np.zeros((n_components, n_actions, n_states, n_states))\n",
    "\n",
    "# Not used, but a required argument\n",
    "dummy_pol = np.ones((n_states, n_actions)) / n_actions\n",
    "\n",
    "# Removed diabetes component - 20% of population have diabetes.\n",
    "# WARNING: This takes about 2 hours to run on my laptop\n",
    "tx_mat = np.zeros((n_actions, n_states, n_states))\n",
    "for (s0, a, _) in tqdm(it.product(states, actions, range(n_iter)), total=n_actions*n_states*n_iter):\n",
    "    this_mdp = MDP(init_state_idx=s0, policy_array=dummy_pol, p_diabetes=0.2)\n",
    "    r = this_mdp.transition(Action(action_idx=a))\n",
    "    s1 = this_mdp.state.get_state_idx()\n",
    "    tx_mat[a, s0, s1] += 1\n",
    "\n",
    "est_tx_mat = tx_mat / n_iter\n",
    "# Extra normalization\n",
    "est_tx_mat /= est_tx_mat.sum(axis=-1, keepdims=True)\n",
    "\n",
    "## REWARD MATRIX\n",
    "np.random.seed(1)\n",
    "\n",
    "# Calculate the reward matrix explicitly, only based on state\n",
    "est_r_mat = np.zeros_like(est_tx_mat)\n",
    "for s1 in states:\n",
    "    this_mdp = MDP(init_state_idx=s1, policy_array=dummy_pol, p_diabetes=1)\n",
    "    r = this_mdp.calculateReward()\n",
    "    est_r_mat[:, :, s1] = r\n",
    "\n",
    "## PRIOR ON INITIAL STATE\n",
    "np.random.seed(1)\n",
    "prior_initial_state = np.zeros((n_components, n_states))\n",
    "\n",
    "for c in components:\n",
    "    this_mdp = MDP(p_diabetes=c)\n",
    "    for _ in range(n_iter):\n",
    "        s = this_mdp.get_new_state().get_state_idx()\n",
    "        prior_initial_state[c, s] += 1\n",
    "    \n",
    "prior_initial_state = prior_initial_state / n_iter\n",
    "# Extra normalization\n",
    "prior_initial_state /= prior_initial_state.sum(axis=-1, keepdims=True)\n",
    "\n",
    "prior_mx_components = np.array([0.8, 0.2]) # Population has 80% non-diabetic and 20% diabetic patients\n",
    "\n",
    "mat_dict = {\"tx_mat\": est_tx_mat,\n",
    "            \"r_mat\": est_r_mat,\n",
    "            \"p_initial_state\": prior_initial_state,\n",
    "            \"p_mixture\": prior_mx_components}\n",
    "\n",
    "with open('../data/diab_txr_mats-replication-unobserved-diabetes.pkl', 'wb') as f:\n",
    "    pickle.dump(mat_dict, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "NSIMSAMPS = 1  # Samples to draw from the simulator (they did a 1000)\n",
    "NSTEPS = 10  # Max length of each trajectory\n",
    "NCFSAMPS = 5  # Counterfactual Samples per observed sample (do i need this? probably not, i just need the model)\n",
    "DISCOUNT_Pol = 0.99 # Used for computing optimal policies\n",
    "DISCOUNT = 1 # Used for computing actual reward\n",
    "PHYS_EPSILON = 0.05 # Used for sampling using physician pol as eps greedy\n",
    "PROB_DIAB = 0.2\n",
    "n_actions = Action.NUM_ACTIONS_TOTAL\n",
    "\n",
    "with open(\"../data/diab_txr_mats-replication-unobserved-diabetes.pkl\", \"rb\") as f:\n",
    "    mdict = pickle.load(f)\n",
    "\n",
    "tx_mat = mdict[\"tx_mat\"]\n",
    "r_mat = mdict[\"r_mat\"]\n",
    "p_mixture = np.array([1 - PROB_DIAB, PROB_DIAB])\n",
    "\n",
    "tx_mat_full = tx_mat\n",
    "r_mat_full = r_mat\n",
    "\n",
    "print(tx_mat_full.shape)\n",
    "print(r_mat_full.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "all_absorbing_states = []\n",
    "all_absorbing_rewards = []\n",
    "non_absorbing_states = []\n",
    "all_rewards = []\n",
    "\n",
    "for s in range(720):\n",
    "    get_states = State(state_idx=s, idx_type = 'obs', diabetic_idx=1)\n",
    "    abs = get_states.check_absorbing_state()\n",
    "    if abs == True: \n",
    "        all_absorbing_states.append(s)\n",
    "        rew = get_states.state_rewards()\n",
    "        all_absorbing_rewards.append(rew)\n",
    "        \n",
    "    if abs == False:\n",
    "        non_absorbing_states.append(s)\n",
    "\n",
    "    rew = get_states.state_rewards()\n",
    "    all_rewards.append(rew)\n",
    "\n",
    "rewards_pi = np.zeros((720, 8)) \n",
    "\n",
    "for s in range(720):\n",
    "    for a in range(8):\n",
    "        # Take the action, new state is property of the MDP\n",
    "        s_p = (np.where(tx_mat_full[a, s, :] == (np.max(tx_mat_full[a, s, :]))))[0][0]\n",
    "        rewards_pi[s, a] = r_mat_full[a, s, s_p]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "fullMDP = cf.MatrixMDP(tx_mat_full, r_mat_full)\n",
    "fullPol = fullMDP.policyIteration(discount=DISCOUNT_Pol, eval_type=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "class DataGen(object):\n",
    "    def __init__(self):\n",
    "        \n",
    "        mdp = MDP(init_state_idx=None, policy_array=fullPol, policy_idx_type='obs', p_diabetes=PROB_DIAB)\n",
    "        self.mdp = mdp\n",
    "\n",
    "    def mdp_sample(self, policy=fullPol, n_obs=1, n_steps=NSTEPS): # trajectory sample for a given policy and MDP\n",
    "        n_state = 4 # Get the number of states [current state, next state, action, reward]\n",
    "\n",
    "        # Initialize the trajectories\n",
    "        trajectories = np.zeros((n_obs, n_steps, n_state))\n",
    "        # observations, trajectories, states\n",
    "\n",
    "        # Loop over the observations\n",
    "        for obs_idx in range(n_obs): # to generate the desired amount of trajectories\n",
    "            current_state = np.random.choice(non_absorbing_states) # initial state can not be absorbant \n",
    "                        \n",
    "            # Go over time steps\n",
    "            for time_idx in range(n_steps): # for each time step in the currently generating (\"observed\") trajectory \n",
    "                \n",
    "                # Get the action\n",
    "                action = self.mdp.action_idx(state_idx=current_state) # pick a action for that initial state according to the given policy\n",
    "                \n",
    "                next_state = np.random.choice(\n",
    "                    720, size=1, p=tx_mat_full[action, current_state, :])[0]  # tx_mat_full is of the shape (actions, state, state)\n",
    "                \n",
    "                reward = r_mat_full[action, current_state, next_state] # matters what state you **actually** get to\n",
    "\n",
    "                trajectories[obs_idx, time_idx, :] = np.array([current_state, next_state, action, reward])\n",
    "            \n",
    "                current_state = next_state\n",
    "\n",
    "        return trajectories "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "def generate_rand_MDP_samples(init_state, n_obs=1, n_steps=NSTEPS, diabetic=0): # trajectory sample for a given policy and MDP\n",
    "    n_state = 4 # Get the number of states [current state, next state, action, reward]\n",
    "\n",
    "    # Initialize the trajectories\n",
    "    trajectories = np.zeros((n_obs, n_steps, n_state))\n",
    "    glucose_trajectories = np.zeros((n_obs, n_steps, 2))\n",
    "    bp_trajectories = np.zeros((n_obs, n_steps, 2))\n",
    "\n",
    "    for obs_idx in range(n_obs):\n",
    "        this_mdp = MDP(init_state_idx=init_state, p_diabetes=diabetic)\n",
    "\n",
    "        for time_idx in range(n_steps):\n",
    "            s = this_mdp.state.get_state_idx()\n",
    "            g = this_mdp.state.glucose_state\n",
    "            bp = this_mdp.state.sysbp_state\n",
    "            a = np.random.choice(n_actions)\n",
    "            r = this_mdp.transition(Action(action_idx=a))\n",
    "            s_prime = this_mdp.state.get_state_idx()\n",
    "            g_prime = this_mdp.state.glucose_state\n",
    "            bp_prime = this_mdp.state.sysbp_state\n",
    "\n",
    "            trajectories[obs_idx, time_idx, :] = np.array([s, s_prime, a, r])\n",
    "            glucose_trajectories[obs_idx, time_idx] = np.array([g, g_prime])\n",
    "            bp_trajectories[obs_idx, time_idx] = np.array([bp, bp_prime])\n",
    "\n",
    "    return trajectories, glucose_trajectories, bp_trajectories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "import pydot\n",
    "\n",
    "def generate_trajectories(n_obs=10):\n",
    "    # Get an initial state\n",
    "    this_mdp = MDP(p_diabetes=0)\n",
    "    initial_state_idx = this_mdp.state.get_state_idx()\n",
    "\n",
    "    whole_population_MDP_samp, whole_population_glucose, whole_population_bp = generate_rand_MDP_samples(initial_state_idx, n_obs=n_obs, diabetic=0.2)\n",
    "    whole_population_MDP_samp = whole_population_MDP_samp.astype(int)\n",
    "    whole_population_glucose = whole_population_glucose.astype(int)\n",
    "    whole_population_bp = whole_population_bp.astype(int)\n",
    "    \n",
    "    diabetes_MDP_samp, diabetes_glucose, diabetes_bp = generate_rand_MDP_samples(initial_state_idx, n_obs=n_obs, diabetic=1)\n",
    "    diabetes_MDP_samp = diabetes_MDP_samp.astype(int)\n",
    "    diabetes_glucose = diabetes_glucose.astype(int)\n",
    "    diabetes_bp = diabetes_bp.astype(int)\n",
    "\n",
    "    return whole_population_MDP_samp, whole_population_glucose, whole_population_bp, diabetes_MDP_samp, diabetes_glucose, diabetes_bp\n",
    "\n",
    "def calculate_usage_frequencies(whole_population_trajectories, diabetes_trajectories):\n",
    "    node_usage = {}\n",
    "    arc_usage = {}\n",
    "\n",
    "    for trajectory in whole_population_trajectories:\n",
    "        for t in range(len(trajectory)):\n",
    "            # Update node usage\n",
    "            transition = trajectory[t]\n",
    "            s, s_prime, a, _ = transition\n",
    "\n",
    "            if (t, s) in node_usage:\n",
    "                node_usage[(t, s)] = (node_usage[(t, s)][0]+1, node_usage[(t, s)][1])\n",
    "            else:\n",
    "                node_usage[(t, s)] = (1, 0)\n",
    "\n",
    "            # Update arc usage\n",
    "            arc = ((t, s), (t+1, s_prime), a)\n",
    "\n",
    "            if arc in arc_usage:\n",
    "                arc_usage[arc] = (arc_usage[arc][0]+1, arc_usage[arc][1])\n",
    "            else:\n",
    "                arc_usage[arc] = (1, 0)\n",
    "            \n",
    "            if t == len(trajectory) - 1:\n",
    "                if (t+1, s_prime) in node_usage:\n",
    "                    node_usage[(t+1, s_prime)] = (node_usage[(t+1, s_prime)][0] + 1, node_usage[(t+1, s_prime)][1])\n",
    "                else:\n",
    "                    node_usage[(t+1, s_prime)] = (1, 0)\n",
    "\n",
    "    \n",
    "    for trajectory in diabetes_trajectories:\n",
    "        for t in range(len(trajectory)):\n",
    "            # Update node usage\n",
    "            transition = trajectory[t]\n",
    "            s, s_prime, a, _ = transition\n",
    "\n",
    "            if (t, s) in node_usage:\n",
    "                node_usage[(t, s)] = (node_usage[(t, s)][0], node_usage[(t, s)][1]+1)\n",
    "            else:\n",
    "                node_usage[(t, s)] = (0, 1)\n",
    "\n",
    "            # Update arc usage\n",
    "            arc = ((t, s), (t+1, s_prime), a)\n",
    "\n",
    "            if arc in arc_usage:\n",
    "                arc_usage[arc] = (arc_usage[arc][0], arc_usage[arc][1]+1)\n",
    "            else:\n",
    "                arc_usage[arc] = (0, 1)\n",
    "            \n",
    "            if t == len(trajectory) - 1:\n",
    "                if (t+1, s_prime) in node_usage:\n",
    "                    node_usage[(t+1, s_prime)] = (node_usage[(t+1, s_prime)][0], node_usage[(t+1, s_prime)][1]+1)\n",
    "                else:\n",
    "                    node_usage[(t+1, s_prime)] = (0, 1)\n",
    "\n",
    "    return node_usage, arc_usage\n",
    "\n",
    "n_obs = 1000\n",
    "whole_population_trajectories, whole_population_glucose, whole_population_bp, diabetes_trajectories, diabetes_glucose, diabetes_bp = generate_trajectories(n_obs=n_obs)\n",
    "\n",
    "def generate_heat_map(whole_population_trajectories, diabetes_trajectories, observation, example_influenced_trajectory, example_optimal_trajectory):\n",
    "    print(observation)\n",
    "    print(example_influenced_trajectory)\n",
    "    print(example_optimal_trajectory)\n",
    "    G = nx.MultiDiGraph()\n",
    "    \n",
    "    node_usage, arc_usage = calculate_usage_frequencies(whole_population_trajectories, diabetes_trajectories)\n",
    "    \n",
    "    # Add nodes with colors based on usage frequency\n",
    "    for node, usage in node_usage.items():\n",
    "        G.add_node(node, usage=usage)\n",
    "\n",
    "    # Add edges with colors based on usage frequency\n",
    "    for arc, usage in arc_usage.items():\n",
    "        G.add_edge(arc[0], arc[1], usage=usage, label=arc[2])\n",
    "\n",
    "    # Create a pydot.Dot object\n",
    "    dot_graph = pydot.Dot(graph_type='digraph', ranksep=0.0, nodesep=0.0, overlap=\"compress\")\n",
    "\n",
    "    # Add nodes to pydot graph\n",
    "    for node, attr in G.nodes(data=True):\n",
    "        node_name = str(node)\n",
    "        usage = attr.get('usage', (0, 0))\n",
    "        red_green_intensity = max(0, 255 - usage[0] * 10)\n",
    "        green_blue_intensity = max(0, 255 - usage[1] * 10)\n",
    "        color_top = \"#{:02x}{:02x}{:02x}\".format(int((255 + red_green_intensity) / (2)), int((green_blue_intensity + red_green_intensity) / (2)), int((green_blue_intensity + 255) / (2)))\n",
    "        dot_node = pydot.Node(node_name, label='', width=0.5, height=0.5, style=\"filled\", color=\"none\", fillcolor=color_top, fontcolor='black')\n",
    "        dot_graph.add_node(dot_node)\n",
    "\n",
    "    # Add edges to pydot graph\n",
    "    for u, v, attr in G.edges(data=True):\n",
    "        dot_edge = pydot.Edge(str(u), str(v), color='black', style=\"invisible\")\n",
    "        dot_edge.set_arrowhead(\"none\")\n",
    "        dot_graph.add_edge(dot_edge)\n",
    "\n",
    "    for t, transition in enumerate(example_influenced_trajectory):\n",
    "        for node in dot_graph.get_nodes():\n",
    "            if node.get_name() == f\"\\\"{(t, transition[0])}\\\"\":\n",
    "                node.set_color(\"red\")\n",
    "\n",
    "            if t == len(example_influenced_trajectory) - 1:\n",
    "                if node.get_name() == f\"\\\"{(t+1, transition[1])}\\\"\":\n",
    "                    node.set_color(\"red\")\n",
    "\n",
    "        edge_found = False\n",
    "\n",
    "        for edge in dot_graph.get_edges():\n",
    "            if edge_found:\n",
    "                break\n",
    "\n",
    "            if edge.get_source() == f\"\\\"{(t, transition[0])}\\\"\" and edge.get_destination() == f\"\\\"{(t+1, transition[1])}\\\"\":\n",
    "                edge_found = True\n",
    "                edge.set_style(\"bold\")\n",
    "                edge.set_arrowhead(\"normal\")\n",
    "                edge.set_color(\"red\")\n",
    "        \n",
    "        if not edge_found:\n",
    "            dot_edge = pydot.Edge(str((t, transition[0])), str((t+1, transition[1])), color='red', style=\"bold\")\n",
    "            dot_edge.set_color(\"red\")\n",
    "            dot_edge.set_arrowhead(\"normal\")\n",
    "            dot_graph.add_edge(dot_edge)\n",
    "\n",
    "    for t, transition in enumerate(example_optimal_trajectory):\n",
    "        for node in dot_graph.get_nodes():\n",
    "            if node.get_name() == f\"\\\"{(t, transition[0])}\\\"\":\n",
    "                node.set_color(\"blue\")\n",
    "\n",
    "            if t == len(example_optimal_trajectory) - 1:\n",
    "                if node.get_name() == f\"\\\"{(t+1, transition[1])}\\\"\":\n",
    "                    node.set_color(\"blue\")\n",
    "\n",
    "        edge_found = False\n",
    "        for edge in dot_graph.get_edges():\n",
    "            if edge_found:\n",
    "                break\n",
    "\n",
    "            if edge.get_source() == f\"\\\"{(t, transition[0])}\\\"\" and edge.get_destination() == f\"\\\"{(t+1, transition[1])}\\\"\":\n",
    "                edge.set_style(\"bold\")\n",
    "                edge.set_arrowhead(\"normal\")\n",
    "                edge.set_color(\"blue\")\n",
    "                edge_found = True\n",
    "\n",
    "        if not edge_found:\n",
    "            dot_edge = pydot.Edge(str((t, transition[0])), str((t+1, transition[1])), color='blue', style=\"bold\")\n",
    "            dot_edge.set_color(\"blue\")\n",
    "            dot_edge.set_arrowhead(\"normal\")\n",
    "            dot_graph.add_edge(dot_edge)\n",
    "\n",
    "\n",
    "    for t, transition in enumerate(observation):\n",
    "        for node in dot_graph.get_nodes():\n",
    "            if node.get_name() == f\"\\\"{(t, transition[0])}\\\"\":\n",
    "                node.set_color(\"black\")\n",
    "\n",
    "            if t == len(observation) - 1:\n",
    "                if node.get_name() == f\"\\\"{(t+1, transition[1])}\\\"\":\n",
    "                    node.set_color(\"black\")\n",
    "\n",
    "        edge_found = False\n",
    "\n",
    "        for edge in dot_graph.get_edges():\n",
    "            if edge_found:\n",
    "                break\n",
    "\n",
    "            if edge.get_source() == f\"\\\"{(t, transition[0])}\\\"\" and edge.get_destination() == f\"\\\"{(t+1, transition[1])}\\\"\":\n",
    "                edge.set_style(\"bold\")\n",
    "                edge.set_arrowhead(\"normal\")\n",
    "                edge.set_color(\"black\")\n",
    "                edge_found = True\n",
    "\n",
    "        if not edge_found:\n",
    "            dot_edge = pydot.Edge(str((t, transition[0])), str((t+1, transition[1])), color='black', style=\"bold\")\n",
    "            dot_edge.set_arrowhead(\"normal\")\n",
    "            dot_graph.add_edge(dot_edge)\n",
    "\n",
    "    # Save the graph to a file or display it\n",
    "    dot_graph.write_svg('graph.svg')  # Save graph to a PNG file\n",
    "    dot_graph.write_png('graph.png')  # Save graph to a PNG file\n",
    "    dot_graph.write_dot('graph.dot')  # Save graph to a DOT file\n",
    "    dot_graph.write_plain('graph.plain')  # Save graph to a plain text file\n",
    "    dot_graph.write_raw('graph.raw')  # Save graph to a raw text file\n",
    "\n",
    "    # Extract node positions and colors\n",
    "    node_info = {}\n",
    "\n",
    "    graphs = pydot.graph_from_dot_file(\"graph.dot\")\n",
    "    parsed_graph = graphs[0]  # pydot can return a list of graphs; we take the first one\n",
    "\n",
    "    for node in parsed_graph.get_nodes():\n",
    "        node_name = node.get_name()\n",
    "        not_nodes = {\"graph\", \"node\", \"\\\"\\\\n\\\"\"}\n",
    "\n",
    "        if node_name not in not_nodes:\n",
    "            node_pos = node.get_pos()\n",
    "            fillcolor = node.get_fillcolor()\n",
    "            node_info[node_name] = [node_pos, fillcolor]\n",
    "\n",
    "    return node_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "dgen = DataGen()\n",
    "MDP_samp = dgen.mdp_sample().astype(int)\n",
    "observation = diabetes_trajectories[0]\n",
    "\n",
    "MDP_samp = np.array([observation]).astype(int)\n",
    "\n",
    "print(MDP_samp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "def truncated_gumbel(logit, truncation):\n",
    "    assert not np.isneginf(logit)\n",
    "\n",
    "    gumbel = np.random.gumbel(size=(truncation.shape[0])) + logit\n",
    "    trunc_g = -np.log(np.exp(-gumbel) + np.exp(-truncation))\n",
    "    return trunc_g\n",
    "\n",
    "def topdown_tracking_influenced_states(obs_logits, obs_state, nsamp=1):\n",
    "    poss_next_states = obs_logits.shape[0]\n",
    "    gumbels = np.zeros((nsamp, poss_next_states))\n",
    "    influenced_states = np.full(shape=poss_next_states, fill_value=False)\n",
    "\n",
    "    # Sample top gumbels\n",
    "    topgumbel = np.random.gumbel(size=(nsamp))\n",
    "\n",
    "    for next_state in range(poss_next_states):\n",
    "        # This is the observed outcome\n",
    "        if (next_state == obs_state) and not(np.isneginf(obs_logits[next_state])):\n",
    "            gumbels[:, obs_state] = topgumbel - obs_logits[next_state]\n",
    "            influenced_states[obs_state] = True\n",
    "        # These were the other feasible options (p > 0)\n",
    "        elif not(np.isneginf(obs_logits[next_state])):\n",
    "            gumbels[:, next_state] = truncated_gumbel(obs_logits[next_state], topgumbel) - obs_logits[next_state]\n",
    "            influenced_states[next_state] = True\n",
    "        # These have zero probability to start with, so are unconstrained\n",
    "        else:\n",
    "            gumbels[:, next_state] = np.random.gumbel(size=nsamp)\n",
    "\n",
    "    return gumbels, influenced_states # list of gumbel noise values derived from the observed trajectory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "from multiprocessing import Process, Manager\n",
    "\n",
    "class CounterfactualSampler(object):\n",
    "\n",
    "    def __init__(self, mdp):\n",
    "        self.mdp = mdp\n",
    "        self.sprtb_theta = 0.9\n",
    "        self.sprtb_delta = 0.05\n",
    "        self.sprtb_r = 0.9\n",
    "\n",
    "    def cf_posterior_tracking_influenced_states(self, obs_prob, intrv_prob, state, n_mc):\n",
    "        obs_logits = np.log(obs_prob);\n",
    "        next_state = state\n",
    "        intrv_logits = np.log(intrv_prob);\n",
    "\n",
    "        # Sample from the gumbel posterior\n",
    "        gumbels, influenced_states = topdown_tracking_influenced_states(obs_logits, next_state, n_mc);\n",
    "\n",
    "        # Get the posterior\n",
    "        posterior = intrv_logits + gumbels\n",
    "        intrv_posterior = np.argmax(posterior, axis=1)\n",
    "\n",
    "        # create the counterfactual transition probabilities\n",
    "        posterior_prob = np.zeros(np.size(intrv_prob, 0))\n",
    "        for i in range(np.size(intrv_prob, 0)):\n",
    "            posterior_prob[i] = np.sum(intrv_posterior == i) / n_mc\n",
    "\n",
    "        return posterior_prob, intrv_posterior, influenced_states\n",
    "    \n",
    "##########################################################################################################################################################\n",
    "\n",
    "    def cf_sample_prob_tracking_influenced_transitions(self, trajectories, a, time_idx, P_cf_save, influenced_transitions_save, n_cf_samps=1): \n",
    "        n_obs = trajectories.shape[0] \n",
    "        n_mc = 1000\n",
    "        \n",
    "        for obs_idx in range(n_obs): # for each given \"observed\" trajectory\n",
    "            P_cf = {}\n",
    "            influenced_transitions = {}\n",
    "\n",
    "            for _ in range(n_cf_samps): # get the desired number of CF trajectories for each given \"observed\" trajectory \n",
    "                    obs_state = trajectories[obs_idx, time_idx, :]\n",
    "                    obs_current_state = int(obs_state[0]) # same as s_real\n",
    "                    obs_next_state = int(obs_state[1]) # same as s_p_real\n",
    "                    obs_action = int(obs_state[2]) # same as a_real\n",
    "\n",
    "                    P_cf[a, time_idx] = np.zeros((int(720),int(720)))\n",
    "                    influenced_transitions[a, time_idx] = np.full(shape=(int(720),int(720)), fill_value=False)\n",
    "\n",
    "                    # A matrix is initialized to zeros to store transition counts.\n",
    "                    for s in range(720):\n",
    "                    \n",
    "                        obs_intrv =  tx_mat_full[obs_action, obs_current_state, :]\n",
    "                        # Get the transition probabilities for the counterfactual state and action:\n",
    "                        cf_intrv =  tx_mat_full[a, s, :]\n",
    "                        \n",
    "                        cf_prob, s_p, influenced_states = self.cf_posterior_tracking_influenced_states(obs_intrv, cf_intrv, obs_next_state, n_mc)\n",
    "                \n",
    "                        for s_p in range(len(cf_prob)):\n",
    "                            P_cf[a,time_idx][s,s_p] = cf_prob[s_p]\n",
    "                            influenced_transitions[a,time_idx][s, s_p] = influenced_states[s_p]\n",
    "\n",
    "        P_cf_save[(a,time_idx)] = P_cf\n",
    "        influenced_transitions_save[(a, time_idx)] = influenced_transitions\n",
    "\n",
    "    def run_sample_tracking_influenced_transitions(self, inp, trajectories, P_cf, influenced_transitions):\n",
    "        P_cf_save = {}\n",
    "        influenced_transitions_save = {}\n",
    "\n",
    "        for i in inp:\n",
    "            self.cf_sample_prob_tracking_influenced_transitions(trajectories, i[0], i[1], P_cf_save, influenced_transitions_save)\n",
    "\n",
    "        for i in inp:\n",
    "            P_cf.update(P_cf_save[i])\n",
    "            influenced_transitions.update(influenced_transitions_save[i])\n",
    "\n",
    "    def run_parallel_sampling_tracking_influenced_transitions(self, trajectories):\n",
    "        n_steps = trajectories.shape[1]\n",
    "        n_actions = 8\n",
    "        \n",
    "        inp = [(a, time_idx) for time_idx in range(n_steps) for a in range(n_actions)]\n",
    "\n",
    "        # Run with n threads\n",
    "        def split(a, n):\n",
    "            k, m = divmod(len(a), n)\n",
    "            return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))\n",
    "        \n",
    "        split_work = split(inp, 32)\n",
    "        processes = []\n",
    "\n",
    "        with Manager() as manager:\n",
    "            P_cf = manager.dict()\n",
    "            influenced_transitions = manager.dict()\n",
    "            \n",
    "            for chunk in split_work:\n",
    "                process = Process(target=self.run_sample_tracking_influenced_transitions, args=(chunk, trajectories, P_cf, influenced_transitions))\n",
    "                processes.append(process)\n",
    "                process.start()\n",
    "\n",
    "            for process in processes:\n",
    "                process.join()\n",
    "\n",
    "            return P_cf.copy(), influenced_transitions.copy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "metadata": {}
   },
   "outputs": [],
   "source": [
    "NUM_ITERATIONS = 10000\n",
    "from collections import deque, defaultdict\n",
    "\n",
    "print(MDP_samp)\n",
    "sampler = CounterfactualSampler(dgen)\n",
    "P_cf, influenced_transitions = sampler.run_parallel_sampling_tracking_influenced_transitions(MDP_samp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class InfluenceMDPPruner:\n",
    "    def __init__(self, mdp, look_ahead_k=11):\n",
    "        self.mdp = mdp\n",
    "        self.optimal_policy = fullPol\n",
    "        self.rewards_pi = rewards_pi\n",
    "        self.sampler = sampler\n",
    "        self.mdp_sample = MDP_samp\n",
    "        self.initial_state = self.mdp_sample[0, 0, 0]\n",
    "        self.states = range(720)\n",
    "        self.actions = range(8)\n",
    "        self.look_ahead_k = look_ahead_k\n",
    "        self.T = len(self.mdp_sample[0])\n",
    "        \n",
    "        # Generate the counterfacutal transition probabilities, keeping track\n",
    "        # of which transitionals have been influenced by the observed path.\n",
    "        self.P_cf = P_cf\n",
    "        self.influenced_transitions = influenced_transitions\n",
    "\n",
    "    def build_graph(self, transition_probs, T, all_states, all_actions):\n",
    "        G = nx.MultiDiGraph()\n",
    "        pos = {}\n",
    "        \n",
    "        for t in range(T):\n",
    "            for s in all_states:\n",
    "                G.add_node((t, s))\n",
    "                pos[(t, s)] = (s, -t)\n",
    "\n",
    "                for a in all_actions:\n",
    "                    for s_prime in all_states:\n",
    "                        if transition_probs[a, s, s_prime] > 0:\n",
    "                            G.add_node(((t+1), s_prime))\n",
    "                            pos[(t+1, s_prime)] = (s_prime, -(t+1)) \n",
    "                            G.add_edge((t, s), ((t+1), s_prime), key=a, label=f\"({a}, {transition_probs[a, s, s_prime]})\")\n",
    "\n",
    "        return G\n",
    "\n",
    "    def build_cf_graph(self, transition_probs, T, all_states, all_actions, k):\n",
    "        G = nx.MultiDiGraph()\n",
    "        pos = {}\n",
    "        \n",
    "        for t in range(T):\n",
    "            for s in all_states:\n",
    "                G.add_node((t, s))\n",
    "                pos[(t, s)] = (s, -t)\n",
    "\n",
    "                for a in all_actions:\n",
    "                    for s_prime in all_states:\n",
    "                        if transition_probs[a, t][s, s_prime] > 0:\n",
    "                            G.add_node(((t+1), s_prime))\n",
    "                            pos[(t+1, s_prime)] = (s_prime, -(t+1)) \n",
    "                            G.add_edge((t, s), ((t+1), s_prime), key=a, label=f\"({a}, {transition_probs[a, t][s, s_prime]})\")\n",
    "\n",
    "                # If node has no outgoing edges, remove it from the graph\n",
    "                if G.has_node((t, s)) and G.out_degree((t, s)) == 0 and t < T-1:\n",
    "                    G.remove_node((t, s))\n",
    "\n",
    "        # Remove unreachable nodes at t>0 with in-degree = 0\n",
    "        unreachable_nodes = {n for n in G if G.in_degree(n) == 0 and n[0]>0}\n",
    "\n",
    "        while len(unreachable_nodes) > 0:\n",
    "            G.remove_nodes_from(unreachable_nodes)\n",
    "            unreachable_nodes = {n for n in G if G.in_degree(n) == 0 and n[0]>0}\n",
    "\n",
    "        return G\n",
    "    \n",
    "    def get_counterfactual_transition_probabilities(self, P_cf, original_G, new_mdp_G, all_states, all_actions, A_real, S_real, T, k):\n",
    "        print(f\"Calculating counterfactual transition probabilities for k={k}\")\n",
    "        # Update the transition probabilities P_cf with the pruned mdp new_mdp_G.\n",
    "        # Remove actions entirely to ensure that the probabilities for each action in\n",
    "        # each state add up to 1. Keep track of which actions are valid choices in\n",
    "        # which states.        \n",
    "        valid_action = np.full((T, len(all_states), len(all_actions)), False)\n",
    "\n",
    "        print(S_real)\n",
    "        print(A_real)\n",
    "\n",
    "        T = self.T\n",
    "\n",
    "        for t in range(T-1, -1, -1):\n",
    "            for s in range(720):\n",
    "                for a in range(8):\n",
    "                    for s_prime in range(720):\n",
    "                        if new_mdp_G.has_node((t, s)) and P_cf[a, t][s, s_prime] > 0.0:\n",
    "                            if not new_mdp_G.has_edge((t, s), (t+1, s_prime), key=a):\n",
    "                                imm_descendants = nx.descendants_at_distance(new_mdp_G, (t, s), 1)\n",
    "\n",
    "                                for imm_descendant in imm_descendants:\n",
    "                                    if new_mdp_G.has_edge((t, s), imm_descendant, key=a):\n",
    "                                        new_mdp_G.remove_edge((t, s), imm_descendant, key=a)\n",
    "\n",
    "                # If node has no outgoing edges, remove it from the graph\n",
    "                if new_mdp_G.has_node((t, s)) and new_mdp_G.out_degree((t, s)) == 0 and t < T-1:\n",
    "                    new_mdp_G.remove_node((t, s))\n",
    "\n",
    "            # Remove unreachable nodes at t>0 with in-degree = 0\n",
    "            unreachable_nodes = {n for n in new_mdp_G if new_mdp_G.in_degree(n) == 0 and n[0]>0}\n",
    "\n",
    "            while len(unreachable_nodes) > 0:\n",
    "                new_mdp_G.remove_nodes_from(unreachable_nodes)\n",
    "                unreachable_nodes = {n for n in new_mdp_G if new_mdp_G.in_degree(n) == 0 and n[0]>0}\n",
    "\n",
    "        for t in range(T-1, -1, -1):\n",
    "            for s in range(720):\n",
    "                for a in range(8):\n",
    "                    for s_prime in range(720):\n",
    "                        if P_cf[a, t][s, s_prime] > 0.0:\n",
    "                            if not new_mdp_G.has_edge((t, s), (t+1, s_prime), key=a):\n",
    "                               P_cf[a, t][s, :] = 0.0\n",
    "                        else:\n",
    "                            assert(P_cf[a, t][s, s_prime] == 0.0)\n",
    "                    \n",
    "                    if sum(P_cf[a, t][s, :]) == 1.0:\n",
    "                        valid_action[t, s, a] = True\n",
    "\n",
    "                    if a == A_real[t] and s == S_real[t]:\n",
    "                        assert(valid_action[t, s, a])\n",
    "\n",
    "        return P_cf, valid_action\n",
    "\n",
    "    def find_wcc(self, G): \n",
    "        s_0 = self.initial_state\n",
    "        t_0 = 0\n",
    "        target_node = (t_0, s_0)\n",
    "        print(f'the initial state is {s_0}')\n",
    "\n",
    "        # Create a subgraph containing only the weakly connected component of the target node\n",
    "        for i, component in enumerate(nx.weakly_connected_components(G)):\n",
    "            if target_node in component:\n",
    "                G_sub = G.subgraph(component).copy()\n",
    "                break\n",
    "        \n",
    "        return G_sub\n",
    "\n",
    "    def get_influence_graph(self, G, k, influenced_transitions):\n",
    "        print(f\"Generating influence graph for k={k}\")\n",
    "\n",
    "        def reverse_bfs(G, start_nodes, k):\n",
    "            distance = defaultdict(lambda: float('inf'))\n",
    "            nodes_to_visit = deque([(node, 0) for node in start_nodes])\n",
    "            within_k_steps = set()\n",
    "\n",
    "            while nodes_to_visit:\n",
    "                curr_node, curr_dist = nodes_to_visit.popleft()\n",
    "\n",
    "                if curr_dist <= k:\n",
    "                    within_k_steps.add(curr_node)\n",
    "\n",
    "                    if distance[curr_node] > curr_dist:\n",
    "                        distance[curr_node] = curr_dist\n",
    "\n",
    "                        for predecessor in G.predecessors(curr_node):\n",
    "                            nodes_to_visit.append((predecessor, curr_dist+1))\n",
    "\n",
    "            return within_k_steps\n",
    "\n",
    "        directly_influenced_nodes = set()\n",
    "\n",
    "        for s in range(720):\n",
    "            for a in range(8):\n",
    "                for s_prime in range(720):\n",
    "                    for t in range(self.T):\n",
    "                        if influenced_transitions[a,t][s, s_prime]:\n",
    "                            directly_influenced_nodes.add((t+1, s_prime))\n",
    "\n",
    "        reachable_nodes = reverse_bfs(G, directly_influenced_nodes, k)\n",
    "        influence_graph = G.subgraph(reachable_nodes).copy()\n",
    "\n",
    "        # If we are between T-k+1 and T, then we want to add all the paths between these layers, as they are all treated as influenced.\n",
    "        for timestep in range(self.T-k+1, self.T):\n",
    "            for s in range(720):\n",
    "                for a in range(8):\n",
    "                    for s_prime in range(720):\n",
    "                        if not influence_graph.has_edge((timestep, s), (timestep+1, s_prime), key=a) and G.has_edge((timestep, s), (timestep+1, s_prime), key=a):\n",
    "                            influence_graph.add_edge((timestep, s), (timestep+1, s_prime), key=a)\n",
    "\n",
    "        # Remove nodes with in-degree = 0 or out-degree = 0\n",
    "        unreachable_nodes = {n for n in influence_graph if (influence_graph.in_degree(n) == 0 and n[0]>0) or (influence_graph.out_degree(n) == 0 and n[0] < self.T)}\n",
    "\n",
    "        while len(unreachable_nodes) > 0:\n",
    "            influence_graph.remove_nodes_from(unreachable_nodes)\n",
    "            unreachable_nodes = {n for n in influence_graph if (influence_graph.in_degree(n) == 0 and n[0]>0) or (influence_graph.out_degree(n) == 0 and n[0] < self.T)}\n",
    "\n",
    "        return influence_graph\n",
    "\n",
    "    def prune_mdp(self): \n",
    "        # Build graph using the original MDP transition probabilities.\n",
    "        G = self.build_graph(tx_mat_full, self.T, self.states, self.actions)\n",
    "\n",
    "        # Build the influence graph for each look-ahead k\n",
    "        influence_graphs = []\n",
    "\n",
    "        # Generate graphs for the pruned MDP.\n",
    "        for k in range(1, self.look_ahead_k+1):\n",
    "            influence_graph = self.get_influence_graph(copy.deepcopy(G), k, self.influenced_transitions)\n",
    "            influence_graphs.append(influence_graph)\n",
    "\n",
    "        cf_transition_probs = []\n",
    "        valid_actions = []\n",
    "\n",
    "        print(self.mdp_sample)\n",
    "\n",
    "        A_real = self.mdp_sample[0, :, 2]\n",
    "        S_real = self.mdp_sample[0, :, 0]\n",
    "\n",
    "        for look_ahead_k in range(1, self.look_ahead_k+1):\n",
    "            new_P_cf, valid_action = self.get_counterfactual_transition_probabilities(copy.deepcopy(self.P_cf), G, influence_graphs[look_ahead_k-1], self.states, self.actions, A_real, S_real, self.T, look_ahead_k)\n",
    "            cf_transition_probs.append(new_P_cf)\n",
    "            valid_actions.append(valid_action)\n",
    "\n",
    "        # Generate graphs for the pruned counterfactual MDP.\n",
    "        cf_graphs = []\n",
    "\n",
    "        for k in range(1, self.look_ahead_k+1):\n",
    "            G = self.build_cf_graph(cf_transition_probs[k-1], self.T, self.states, self.actions, k)\n",
    "            cf_graphs.append(G)\n",
    "\n",
    "        return cf_transition_probs, valid_actions, cf_graphs\n",
    "\n",
    "    def build_graphs(self, cf_transition_probs):\n",
    "        # Generate graphs for the pruned counterfactual MDP.\n",
    "        cf_graphs = []\n",
    "\n",
    "        for k in range(1, self.look_ahead_k+1):\n",
    "            G = self.build_cf_graph(cf_transition_probs[k-1], self.T, self.states, self.actions, k)\n",
    "            cf_graphs.append(G)\n",
    "\n",
    "        return cf_graphs\n",
    "\n",
    "    def get_optimal_policy(self, max_num_actions_changed, transition_probs, valid_action, new_mdp_G, all_states, all_actions, A_real, T, rewards_pi):\n",
    "        h_fun = np.zeros((720, T+1, max_num_actions_changed+1)) \n",
    "        pi = np.zeros((720, max_num_actions_changed+1, T+1), dtype=int) \n",
    "    \n",
    "        for r in range(1, T+1): \n",
    "            for s in range(720): \n",
    "                h_fun[s, r, 0] = rewards_pi[s][(A_real[T-r])]  \n",
    "                for s_p in range(720): # for every singe next state (s') for each state s\n",
    "                    h_fun[s, r, 0] += transition_probs[A_real[T-r], T-r][s, s_p] * h_fun[s_p, r-1, 0]\n",
    "                pi[s, max_num_actions_changed, T-r] = A_real[T-r]\n",
    "\n",
    "        for c in range(1, max_num_actions_changed+1): \n",
    "            for r in range(1, T+1): \n",
    "                for s in range(720):\n",
    "                    pi[s, max_num_actions_changed-c, T-r] = A_real[T-r] # instead let it be the real action\n",
    "                    best_act = A_real[T-r]\n",
    "                    max_val = -np.inf\n",
    "                    \n",
    "                    for a in range(8):\n",
    "                        if valid_action[T-r, s, a]:\n",
    "                            val = rewards_pi[s][a]\n",
    "                            if a != A_real[T-r]: \n",
    "                                for s_p in range(720):\n",
    "                                    if transition_probs[a, T-r][s, s_p] != 0:\n",
    "                                        val += transition_probs[a, T-r][s, s_p] * h_fun[s_p, r-1, c-1] \n",
    "                            elif a == A_real[T-r]:\n",
    "                                for s_p in range(720):\n",
    "                                    if transition_probs[a, T-r][s, s_p] != 0:\n",
    "                                        val += transition_probs[a, T-r][s, s_p] * h_fun[s_p, r-1, c]\n",
    "                                        \n",
    "                            if val > max_val:\n",
    "                                max_val = val\n",
    "                                best_act = a\n",
    "\n",
    "                    h_fun[s, r, c] = max_val\n",
    "                    pi[s, max_num_actions_changed-c, T-r] = best_act\n",
    "\n",
    "        return pi, h_fun\n",
    "\n",
    "    def generate_policies(self, cf_transition_probs, valid_actions, cf_graphs):\n",
    "        # Generate policies for each of the pruned counterfactual MDPs.\n",
    "        policies = []\n",
    "        h_funs = []\n",
    "        k_vals = range(1, self.look_ahead_k+1)\n",
    "        A_real = self.mdp_sample[0, :, 2]\n",
    "\n",
    "        new_all_rewards = np.zeros((self.T, len(self.states), len(self.actions)))\n",
    "\n",
    "        for t in range(self.T):\n",
    "            for s in self.states:\n",
    "                for a in self.actions:                \n",
    "                    new_all_rewards[t, s, a] = self.rewards_pi[s, a]\n",
    "\n",
    "        for look_ahead_k in k_vals:\n",
    "            print(f\"Estimating policy with k={look_ahead_k}\")\n",
    "            policies_k = []\n",
    "            h_funs_k = []\n",
    "\n",
    "            for max_num_actions_changed in k_vals:\n",
    "                # Get the optimal policy\n",
    "                pi, h_fun = self.get_optimal_policy(\n",
    "                    max_num_actions_changed, \n",
    "                    cf_transition_probs[look_ahead_k-1],\n",
    "                    valid_actions[look_ahead_k-1],\n",
    "                    cf_graphs[look_ahead_k-1],\n",
    "                    self.states,\n",
    "                    self.actions,\n",
    "                    A_real,\n",
    "                    self.T,\n",
    "                    self.rewards_pi\n",
    "                )\n",
    "\n",
    "                policies_k.append(pi)\n",
    "                h_funs_k.append(h_fun)\n",
    "\n",
    "            policies.append(policies_k)\n",
    "            h_funs.append(h_funs_k)\n",
    "\n",
    "        return policies, new_all_rewards, h_funs\n",
    "\n",
    "    def generate_random_trajectory(self, MDP_samp, transition_probs, pi, s_0, A_real, all_states, rewards_pi, T):\n",
    "        n_obs=MDP_samp.shape[0]\n",
    "        n_state=MDP_samp.shape[2]\n",
    "        CF_trajectory = np.zeros((n_obs, T, n_state))\n",
    "        \n",
    "        rng = np.random.default_rng()\n",
    "        s = np.zeros(T+1, dtype=int)\n",
    "        s[0] = s_0   # Initial state the same\n",
    "        l = np.zeros(T+1, dtype=int)\n",
    "        l[0] = 0    # Start with 0 changes\n",
    "        a = np.zeros(T, dtype=int)\n",
    "        \n",
    "        for t in range(T):\n",
    "            # Pick actions according to the given policy\n",
    "            a[t] = pi[s[t], l[t], t]\n",
    "            print(f\"t={t}, s={s[t]}, a={l[t]} probs={np.nonzero(transition_probs[a[t], t][s[t]])} pi = {pi[s[t], l[t], t]}\")\n",
    "\n",
    "            # Sample the next state\n",
    "            s[t+1] = (rng.choice(a=all_states, size=1,  p=transition_probs[a[t], t][s[t]]))[0]\n",
    "\n",
    "            # Adjust the number of changes so far\n",
    "            if a[t] != A_real[t]:\n",
    "                l[t+1] = l[t] + 1\n",
    "            else:\n",
    "                l[t+1] = l[t]\n",
    "            \n",
    "            CF_trajectory[0, t, :] = np.array([s[t], s[t+1], a[t], rewards_pi[t, s[t], a[t]]])\n",
    "                    \n",
    "        return CF_trajectory\n",
    "\n",
    "    def generate_cf_trajectories(self, cf_transition_probs, policies, new_all_rewards):\n",
    "        print(f\"Generating CF trajectories\")\n",
    "        all_obs = []\n",
    "        all_cf = []\n",
    "        k_vals = range(1, self.look_ahead_k+1)\n",
    "        A_real = self.mdp_sample[0, :, 2]\n",
    "        s_0 = self.mdp_sample[0, 0, 0]\n",
    "        cf_trajectories = []\n",
    "\n",
    "        for _ in range(1):\n",
    "            obs = np.zeros(shape=(self.look_ahead_k, self.look_ahead_k))\n",
    "            cf = np.zeros(shape=(self.look_ahead_k, self.look_ahead_k))\n",
    "\n",
    "            for look_ahead_k in k_vals:\n",
    "                for max_num_actions_changed in [11]:\n",
    "                    CF_trajectory = self.generate_random_trajectory(\n",
    "                        self.mdp_sample,\n",
    "                        cf_transition_probs[look_ahead_k-1],\n",
    "                        policies[look_ahead_k-1][max_num_actions_changed-1],\n",
    "                        s_0,\n",
    "                        A_real,\n",
    "                        self.states,\n",
    "                        new_all_rewards,\n",
    "                        self.T\n",
    "                    )\n",
    "                    cf_trajectories.append(CF_trajectory)\n",
    "\n",
    "                    obs[look_ahead_k-1][max_num_actions_changed-1] = self.mdp_sample[0, self.T-1, 3] # Immediate reward for obs path at time T\n",
    "                    cf[look_ahead_k-1][max_num_actions_changed-1] = CF_trajectory[0, self.T-1, 3] # Immediate reward for cf path at time T\n",
    "            \n",
    "            all_obs.append(obs)\n",
    "            all_cf.append(cf)\n",
    "\n",
    "        all_obs = np.array(all_obs)\n",
    "        all_cf = np.array(all_cf)\n",
    "\n",
    "        mean_obs = all_obs.mean(axis=0)\n",
    "        mean_cf = all_cf.mean(axis=0)\n",
    "\n",
    "        return mean_obs, mean_cf, k_vals, cf_trajectories"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prune MDP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mdp = MDP(init_state_idx=None, p_diabetes=0.2)\n",
    "influence_pruner = InfluenceMDPPruner(mdp, look_ahead_k=11)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cf_transition_probs, valid_actions, cf_graphs = influence_pruner.prune_mdp()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generating Policies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "policies, new_all_rewards, h_funs = influence_pruner.generate_policies(cf_transition_probs, valid_actions, cf_graphs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Print Value Function of S_0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generating CF Trajectories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_obs, mean_cf, k_vals, CF_trajectories = influence_pruner.generate_cf_trajectories(cf_transition_probs, policies, new_all_rewards)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "observation = diabetes_trajectories[0].astype(int) # Or pick a different diabetes trajectory.\n",
    "example_influenced_trajectory = CF_trajectories[2][0].astype(int) # Or pick a different influenced trajectory.\n",
    "example_optimal_trajectory = CF_trajectories[-1][0].astype(int)\n",
    "node_info = generate_heat_map(whole_population_trajectories, diabetes_trajectories, observation, example_influenced_trajectory, example_optimal_trajectory)\n",
    "\n",
    "node_info = {}\n",
    "\n",
    "graphs = pydot.graph_from_dot_file(\"graph.dot\")\n",
    "parsed_graph = graphs[0]  # pydot can return a list of graphs; we take the first one\n",
    "\n",
    "for node in parsed_graph.get_nodes():\n",
    "    node_name = node.get_name()\n",
    "    not_nodes = {\"graph\", \"node\", \"\\\"\\\\n\\\"\"}\n",
    "\n",
    "    if node_name not in not_nodes:\n",
    "        node_pos = node.get_pos()\n",
    "        color = node.get_color()\n",
    "        fillcolor = node.get_fillcolor()\n",
    "        node_info[node_name] = [node_pos, fillcolor, color]\n",
    "\n",
    "edges = []\n",
    "\n",
    "for edge in parsed_graph.get_edge_list():\n",
    "    # Extract source and target nodes for each edge\n",
    "    source = edge.get_source().strip('\"')\n",
    "    target = edge.get_destination().strip('\"')\n",
    "    style = edge.get_style()\n",
    "    color = edge.get_color()\n",
    "\n",
    "    if style == \"bold\":\n",
    "        edges.append((source, target, color))\n",
    "\n",
    "from plotnine import ggplot, annotate, aes, geom_point, theme, scale_fill_manual, element_blank, element_rect, scale_color_manual, geom_segment, arrow, geom_vline, xlim\n",
    "import pandas as pd\n",
    "\n",
    "# Prepare data for the DataFrame\n",
    "xs = []\n",
    "ys = []\n",
    "var1 = []\n",
    "var2 = []\n",
    "node_colours = []\n",
    "node_borders = []\n",
    "\n",
    "for node in node_info.values():\n",
    "    if node[1] is not None:\n",
    "        pos = node[0].strip('\"')\n",
    "        x, y = pos.split(',')\n",
    "        xs.append(int(x))\n",
    "        ys.append(int(y))\n",
    "        node_colours.append(node[1].strip('\"'))\n",
    "        node_borders.append(node[2])\n",
    "\n",
    "# Create a DataFrame from the extracted data\n",
    "df = pd.DataFrame({\n",
    "    'x': xs,\n",
    "    'y': ys,\n",
    "    'colours': node_colours,\n",
    "    'borders': node_borders\n",
    "})\n",
    "\n",
    "order = ['black', 'blue', 'red', 'none']\n",
    "df = df.sort_values(by=['x', 'y'])\n",
    "df['colours'] = pd.Categorical(df['colours'])\n",
    "df['borders'] = pd.Categorical(df['borders'], categories=order, ordered=True)\n",
    "\n",
    "def get_pos(node):\n",
    "    node = f\"\\\"{node}\\\"\"\n",
    "    info = node_info[node]\n",
    "    pos = info[0].strip('\"')\n",
    "    x, y = pos.split(',')\n",
    "    return (int(x), int(y))\n",
    "\n",
    "# Arrows\n",
    "edge_df = pd.DataFrame(edges, columns=['start_node', 'end_node', 'colour'])\n",
    "edge_df['start_x'] = edge_df['start_node'].map(lambda node: get_pos(node)[0])\n",
    "edge_df['start_y'] = edge_df['start_node'].map(lambda node: get_pos(node)[1])\n",
    "edge_df['end_x'] = edge_df['end_node'].map(lambda node: get_pos(node)[0])\n",
    "edge_df['end_y'] = edge_df['end_node'].map(lambda node: get_pos(node)[1])\n",
    "\n",
    "legend_df = pd.DataFrame({\n",
    "    'x': [0, 0, 0],\n",
    "    'y': [0, 0, 0],\n",
    "    'colour': ['black', 'blue', 'red']  # Add colors for arrows\n",
    "})\n",
    "\n",
    "print(df)\n",
    "print(edge_df)\n",
    "print(legend_df)\n",
    "time_steps = range(0, 11)\n",
    "y_coordinates = [(t*44)+18 for t in time_steps]\n",
    "print(y_coordinates)\n",
    "\n",
    "plot = (\n",
    "    ggplot()\n",
    "    # + xlim(1000, 8000)\n",
    "    + geom_point(data=df, mapping=aes(x='x', y='y', fill='colours', color='borders'), size=3, stroke=0.25, alpha=0.8, show_legend=False) # 'o' shape for nodes\n",
    "    + geom_segment(data=edge_df, mapping=aes(x='start_x', y='start_y', xend='end_x', yend='end_y', color='colour'), size=0.8, arrow=arrow(length=0.08), show_legend=True)  # Add edges\n",
    "    + theme(legend_position='none', legend_key=element_rect(fill='#ffffff'), axis_text=element_blank(), axis_title=element_blank(), panel_background=element_rect(fill='#ffffff'), axis_ticks=element_blank())\n",
    "    + scale_color_manual(values={cat: cat for cat in df['borders'].cat.categories}, name=' ', labels={'black':'Observation','blue':'Naive Counterfactual', 'red':'Influenced Counterfactual', 'none': ' '})  # Manually set border color values\n",
    "    + scale_fill_manual(values={cat: cat for cat in df['colours'].cat.categories})  # Manually set fill color values\n",
    "    + geom_segment(aes(x=1000, y=458, xend=1000, yend=18), size=0.8,\n",
    "                  arrow=arrow(type='open',length=0.08))\n",
    "    + annotate('text', x=1050, y=238, label='t', ha='left')\n",
    ")\n",
    "\n",
    "# Save and display the plot\n",
    "plot.save(\"plot.svg\", width=5, height=5, dpi=300)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
