{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import all modules \n",
    "import numpy as np\n",
    "from multiprocessing import Pool\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "!pip3 install networkx\n",
    "import networkx as nx\n",
    "import copy\n",
    "!pip3 install pydot\n",
    "np.seterr(divide = 'ignore') \n",
    "from scipy.stats import hypergeom\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constants: the largest number of people tracked in the S and in the I\n",
    "# compartment. The vaccine stock V may go up to twice this value.\n",
    "MAX_POPULATION = 10\n",
    "\n",
    "# State space: every (S, I, V) triple, with S varying slowest and V fastest,\n",
    "# so the flat index of (S, I, V) is\n",
    "# (S * (MAX_POPULATION + 1) + I) * (2 * MAX_POPULATION + 1) + V.\n",
    "state_space = [\n",
    "    (n_sus, n_inf, n_vac)\n",
    "    for n_sus in range(MAX_POPULATION + 1)\n",
    "    for n_inf in range(MAX_POPULATION + 1)\n",
    "    for n_vac in range(2 * MAX_POPULATION + 1)\n",
    "]\n",
    "num_states = len(state_space)\n",
    "state_index = {triple: idx for idx, triple in enumerate(state_space)}\n",
    "state_from_index = dict(enumerate(state_space))\n",
    "\n",
    "# Action space: NIL spends no vaccine; V_I and V_S each spend one vaccine on\n",
    "# an infected or a susceptible individual respectively.\n",
    "actions = [\"NIL\", \"V_I\", \"V_S\"]\n",
    "action_index = {name: idx for idx, name in enumerate(actions)}\n",
    "action_from_index = dict(enumerate(actions))\n",
    "num_actions = len(actions)\n",
    "\n",
    "# Dense transition tensor indexed [action, state, next_state] and reward\n",
    "# matrix indexed [state, action], both filled in by the loop below.\n",
    "transition_matrix = np.zeros((num_actions, num_states, num_states))\n",
    "reward_matrix = np.zeros((num_states, num_actions))\n",
    "\n",
    "# Function to compute transition probabilities\n",
    "def _hypergeom_pmf_table(M, n, N, max_k):\n",
    "    \"\"\"Return P(K = k) for k = 0..max_k, where K ~ Hypergeom(M, n, N).\n",
    "\n",
    "    scipy.stats.hypergeom only accepts a positive population size M and\n",
    "    returns nan for M == 0. Drawing from an empty population yields zero\n",
    "    successes with certainty, so that case is handled explicitly instead of\n",
    "    producing nan probabilities (which were later zeroed by nan_to_num,\n",
    "    leaving empty transition rows).\n",
    "    \"\"\"\n",
    "    ks = np.arange(max_k + 1)\n",
    "    if M == 0:\n",
    "        return (ks == 0).astype(float)\n",
    "    return hypergeom(M, n, N).pmf(ks)\n",
    "\n",
    "\n",
    "def compute_transitions(S, I, V, action, max_population=None):\n",
    "    \"\"\"Return the successor distribution of state (S, I, V) under `action`.\n",
    "\n",
    "    The number k of susceptibles that become infected follows a\n",
    "    hypergeometric distribution over the S + I individuals in the mixing\n",
    "    pool (one fewer when a vaccine is spent on someone).\n",
    "\n",
    "    Args:\n",
    "        S, I, V: susceptible count, infected count and remaining vaccines.\n",
    "        action: one of \"NIL\", \"V_I\", \"V_S\".\n",
    "        max_population: upper bound on the infected count of a successor;\n",
    "            defaults to the notebook constant MAX_POPULATION.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping successor states (S', I', V') to probabilities. It is\n",
    "        empty when the action is not applicable (V_I needs I > 0 and V > 0,\n",
    "        V_S needs S > 0 and V > 0, any other name is ignored). Successors\n",
    "        whose infected count would exceed the bound are dropped, so the\n",
    "        probabilities can sum to less than 1 when S + I exceeds the bound.\n",
    "    \"\"\"\n",
    "    if max_population is None:\n",
    "        max_population = MAX_POPULATION\n",
    "\n",
    "    if action == \"NIL\":\n",
    "        # Nobody is vaccinated; the whole population mixes.\n",
    "        M, n, N = S + I, min(S, I), S\n",
    "        S_base, I_base, V_prime = S, I, V\n",
    "    elif action == \"V_I\" and I > 0 and V > 0:\n",
    "        # One infected individual is vaccinated and leaves the mixing pool.\n",
    "        M, n, N = S + I - 1, min(S, I - 1), S\n",
    "        S_base, I_base, V_prime = S, I - 1, V - 1\n",
    "    elif action == \"V_S\" and S > 0 and V > 0:\n",
    "        # One susceptible individual is vaccinated and leaves the mixing pool.\n",
    "        M, n, N = S + I - 1, min(S - 1, I), S - 1\n",
    "        S_base, I_base, V_prime = S - 1, I, V - 1\n",
    "    else:\n",
    "        # Action not applicable in this state.\n",
    "        return {}\n",
    "\n",
    "    transitions = {}\n",
    "    # k = number of newly infected; at most the susceptibles left in the pool.\n",
    "    for k, prob in enumerate(_hypergeom_pmf_table(M, n, N, S_base)):\n",
    "        S_prime, I_prime = S_base - k, I_base + k\n",
    "        if S_prime >= 0 and I_prime <= max_population:\n",
    "            transitions[(S_prime, I_prime, V_prime)] = prob\n",
    "\n",
    "    return transitions\n",
    "\n",
    "# Fill the transition tensor and the reward matrix one state at a time.\n",
    "for state in state_space:\n",
    "    state_idx = state_index[state]\n",
    "    S, I, V = state\n",
    "    for action_idx, action in enumerate(actions):\n",
    "        # Reward: minus the number of infected people, for every action.\n",
    "        reward_matrix[state_idx, action_idx] = -I\n",
    "\n",
    "        transitions = compute_transitions(S, I, V, action)\n",
    "        for next_state, prob in transitions.items():\n",
    "            next_state_idx = state_index[next_state]\n",
    "            transition_matrix[action_idx, state_idx, next_state_idx] = prob\n",
    "\n",
    "# np.nan_to_num replaces any nan probability with 0.\n",
    "transition_matrix = np.nan_to_num(transition_matrix)\n",
    "\n",
    "# Report the matrix sizes as a quick sanity check.\n",
    "print(transition_matrix.shape)\n",
    "print(reward_matrix.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MDP:\n",
    "    \"\"\"Finite-horizon MDP wrapper around the epidemic model's matrices.\n",
    "\n",
    "    Attributes:\n",
    "        states: list of states; position i holds the state with index i.\n",
    "        state_index: dict mapping each state to its index in `states`.\n",
    "        actions: list of action names; position i holds action index i.\n",
    "        action_index: dict mapping each action name to its index.\n",
    "        transition_probabilities: array indexed [action, state, next_state].\n",
    "        rewards: array indexed [state, action].\n",
    "        initial_state: the start state returned by get_initial_state().\n",
    "        T: horizon; generate_suboptimal_path() samples T + 1 steps.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, states, state_index, actions, action_index, transition_probabilities, rewards):\n",
    "        self.states = states\n",
    "        self.n_states = len(states)\n",
    "        self.state_index = state_index\n",
    "        self.actions = actions\n",
    "        self.n_actions = len(actions)\n",
    "        self.action_index = action_index\n",
    "        self.transition_probabilities = transition_probabilities\n",
    "        self.rewards = rewards\n",
    "        self.initial_state = self.get_initial_state()\n",
    "        self.T = 7\n",
    "\n",
    "    def get_initial_state(self):\n",
    "        \"\"\"Return the start state: 9 susceptible, 1 infected, 20 vaccines.\"\"\"\n",
    "        return (9, 1, 20)\n",
    "\n",
    "    def get_optimal_policy(self):\n",
    "        \"\"\"Return a hand-written vaccination heuristic (action index per state).\n",
    "\n",
    "        NIL is chosen while the vaccine stock V is at or below the threshold\n",
    "        or when nobody is infected; otherwise V_I (vaccinate an infected\n",
    "        individual). Despite the name, this is not the solution of the MDP.\n",
    "        \"\"\"\n",
    "        pi = np.zeros((len(self.states)))\n",
    "        threshold = 10  # vaccine-stock threshold: half of the 20 vaccines at the start\n",
    "\n",
    "        for state in self.states:\n",
    "            (S, I, V) = state\n",
    "\n",
    "            if V <= threshold:\n",
    "                pi[self.state_index[state]] = self.action_index[\"NIL\"]\n",
    "            elif I == 0:\n",
    "                # No one to vaccinate.\n",
    "                pi[self.state_index[state]] = self.action_index[\"NIL\"]\n",
    "            else:\n",
    "                pi[self.state_index[state]] = self.action_index[\"V_I\"]\n",
    "\n",
    "        return pi\n",
    "\n",
    "    def get_suboptimal_policy(self):\n",
    "        \"\"\"Return the policy that never vaccinates (NIL in every state).\"\"\"\n",
    "        pi = np.zeros((len(self.states)))\n",
    "\n",
    "        for state in self.states:\n",
    "            pi[self.state_index[state]] = self.action_index[\"NIL\"]\n",
    "\n",
    "        return pi\n",
    "\n",
    "    def generate_suboptimal_path(self, policy):\n",
    "        \"\"\"Sample one trajectory of T + 1 steps from the initial state.\n",
    "\n",
    "        Args:\n",
    "            policy: array mapping state index -> action index.\n",
    "\n",
    "        Returns:\n",
    "            (path_to_print, samples): path_to_print is a list of\n",
    "            [t, state, action_name, next_state] rows; samples is an array of\n",
    "            shape (1, T + 1, 4) with rows [t, s, a, s_prime] of indices.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if the chosen action's transition row is not a\n",
    "                probability distribution (e.g. the action is not applicable).\n",
    "        \"\"\"\n",
    "        print(f\"Initial State = {self.initial_state}\")\n",
    "        path_to_print = []\n",
    "        samp = []\n",
    "        rng = np.random.default_rng()\n",
    "\n",
    "        s = self.state_index[self.initial_state]\n",
    "\n",
    "        for t in range(self.T+1):\n",
    "            a = int(policy[s])\n",
    "            probs = self.transition_probabilities[a, s]\n",
    "\n",
    "            # Fail with a readable message instead of numpy's generic one.\n",
    "            if not np.isclose(np.sum(probs), 1.0):\n",
    "                raise ValueError(\n",
    "                    f\"Transition row for state {self.states[s]} and action \"\n",
    "                    f\"{self.actions[a]} sums to {np.sum(probs)}, not 1\"\n",
    "                )\n",
    "\n",
    "            s_prime = int(rng.choice(self.n_states, p=probs))\n",
    "            samp.append([t, s, a, s_prime])\n",
    "            # Decode indices with this MDP's own lists, not notebook globals.\n",
    "            path_to_print.append([t, self.states[s], self.actions[a], self.states[s_prime]])\n",
    "            print(path_to_print[-1])\n",
    "            s = s_prime\n",
    "\n",
    "        return path_to_print, np.array([samp])\n",
    "\n",
    "# Build the epidemic MDP and the never-vaccinate baseline policy.\n",
    "epidemic_mdp = MDP(\n",
    "    state_space,\n",
    "    state_index,\n",
    "    actions,\n",
    "    action_index,\n",
    "    transition_matrix,\n",
    "    reward_matrix,\n",
    ")\n",
    "policy = epidemic_mdp.get_suboptimal_policy()\n",
    "\n",
    "# Hard-coded observed trajectory, shape (1, 7, 4). Each row is\n",
    "# [t, state index, action index, next state index]; action 0 is NIL and the\n",
    "# state indices decode (via state_from_index) to\n",
    "#   2120 = (9, 1, 20), 1910 = (8, 2, 20), 1700 = (7, 3, 20),\n",
    "#   1070 = (4, 6, 20),  650 = (2, 8, 20),  440 = (1, 9, 20).\n",
    "MDP_samp = np.array([[\n",
    "    [0, 2120, 0, 1910],\n",
    "    [1, 1910, 0, 1700],\n",
    "    [2, 1700, 0, 1070],\n",
    "    [3, 1070, 0,  650],\n",
    "    [4,  650, 0,  440],\n",
    "    [5,  440, 0,  440],\n",
    "    [6,  440, 0,  440],\n",
    "]])\n",
    "\n",
    "print(MDP_samp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def truncated_gumbel(logit, truncation):\n",
    "    \"\"\"Sample Gumbel(logit) variables conditioned not to exceed `truncation`.\n",
    "\n",
    "    One sample is drawn per entry of the 1-D array `truncation`, computed as\n",
    "    -log(exp(-g) + exp(-truncation)) with g ~ Gumbel(logit). The expression\n",
    "    is evaluated with np.logaddexp so that very negative logits do not\n",
    "    overflow exp() and collapse the sample to -inf.\n",
    "\n",
    "    Args:\n",
    "        logit: finite location of the untruncated Gumbel.\n",
    "        truncation: 1-D array of upper bounds, one per sample.\n",
    "\n",
    "    Returns:\n",
    "        1-D array of samples, each at most its corresponding bound.\n",
    "    \"\"\"\n",
    "    assert not np.isneginf(logit)\n",
    "\n",
    "    gumbel = np.random.gumbel(size=(truncation.shape[0])) + logit\n",
    "    trunc_g = -np.logaddexp(-gumbel, -truncation)\n",
    "    return trunc_g\n",
    "\n",
    "def topdown_tracking_influenced_states(obs_logits, obs_state, nsamp=1):\n",
    "    \"\"\"Sample Gumbel noise consistent with an observed transition outcome.\n",
    "\n",
    "    One top Gumbel is drawn per sample and given to the observed next state;\n",
    "    every other successor with a finite logit receives a Gumbel truncated\n",
    "    below that top value, so logits + noise keeps its maximum at the\n",
    "    observed state. Successors with a logit of -inf get unconstrained noise.\n",
    "\n",
    "    Args:\n",
    "        obs_logits: 1-D array of log-probabilities of the observed transition.\n",
    "        obs_state: index of the observed next state.\n",
    "        nsamp: number of noise samples (rows of the result).\n",
    "\n",
    "    Returns:\n",
    "        (gumbels, influenced_states): gumbels has shape\n",
    "        (nsamp, len(obs_logits)); for finite logits each column stores the\n",
    "        sampled value minus the logit. influenced_states is a float 0/1\n",
    "        vector marking the successors whose noise was constrained.\n",
    "    \"\"\"\n",
    "    n_next = obs_logits.shape[0]\n",
    "    gumbels = np.zeros((nsamp, n_next))\n",
    "    influenced_states = np.zeros(shape=n_next)\n",
    "\n",
    "    # Shared maximum, drawn before any per-state noise.\n",
    "    topgumbel = np.random.gumbel(size=(nsamp))\n",
    "\n",
    "    for j in range(n_next):\n",
    "        logit_j = obs_logits[j]\n",
    "\n",
    "        # Impossible under the observed transition: no constraint applies.\n",
    "        if np.isneginf(logit_j):\n",
    "            gumbels[:, j] = np.random.gumbel(size=nsamp)\n",
    "            continue\n",
    "\n",
    "        if j == obs_state:\n",
    "            # The observed outcome attains exactly the top value.\n",
    "            gumbels[:, j] = topgumbel - logit_j\n",
    "        else:\n",
    "            # Other feasible outcomes must stay below the top value.\n",
    "            gumbels[:, j] = truncated_gumbel(logit_j, topgumbel) - logit_j\n",
    "        influenced_states[j] = 1\n",
    "\n",
    "    return gumbels, influenced_states"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CounterfactualSampler(object):\n",
    "    \"\"\"Monte-Carlo estimator of Gumbel-max counterfactual transition probabilities.\n",
    "\n",
    "    Attributes:\n",
    "        mdp: object exposing n_states, n_actions and transition_probabilities\n",
    "            indexed [action, state, next_state].\n",
    "        sprtb_theta, sprtb_delta, sprtb_r: stored constants that the methods\n",
    "            of this class do not read.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, mdp):\n",
    "        self.mdp = mdp\n",
    "        self.sprtb_theta = 0.9\n",
    "        self.sprtb_delta = 0.05\n",
    "        self.sprtb_r = 0.9\n",
    "\n",
    "    def cf_posterior_tracking_influenced_states(self, obs_prob, intrv_prob, state, n_mc):\n",
    "        \"\"\"Estimate the counterfactual next-state distribution of one transition.\n",
    "\n",
    "        n_mc Gumbel noise vectors consistent with observing `state` under\n",
    "        `obs_prob` are added to the intervention's logits; the frequencies of\n",
    "        the resulting argmax give the counterfactual distribution.\n",
    "\n",
    "        Args:\n",
    "            obs_prob: next-state probabilities of the observed transition.\n",
    "            intrv_prob: next-state probabilities under the intervention.\n",
    "            state: index of the observed next state.\n",
    "            n_mc: number of Monte-Carlo samples.\n",
    "\n",
    "        Returns:\n",
    "            (posterior_prob, intrv_posterior, influenced_states):\n",
    "            posterior_prob is the estimated distribution (all zeros when\n",
    "            intrv_prob has no positive entry), intrv_posterior holds the\n",
    "            sampled next-state index of every draw (empty in the all-zero\n",
    "            case), influenced_states flags successors feasible under obs_prob.\n",
    "        \"\"\"\n",
    "        n_next = np.size(intrv_prob, 0)\n",
    "        # log(0) = -inf is intended: it marks impossible successors.\n",
    "        with np.errstate(divide=\"ignore\"):\n",
    "            obs_logits = np.log(obs_prob)\n",
    "            intrv_logits = np.log(intrv_prob)\n",
    "\n",
    "        if not np.any(intrv_prob > 0):\n",
    "            # No successor at all under the intervention (e.g. an inapplicable\n",
    "            # action). The argmax of an all -inf row is index 0, which would put\n",
    "            # all the mass on state 0, so report an empty distribution instead.\n",
    "            # The flags match what the noise sampler marks: feasible observed successors.\n",
    "            influenced_states = (~np.isneginf(obs_logits)).astype(float)\n",
    "            return np.zeros(n_next), np.zeros(0, dtype=int), influenced_states\n",
    "\n",
    "        gumbels, influenced_states = topdown_tracking_influenced_states(obs_logits, state, n_mc)\n",
    "        intrv_posterior = np.argmax(intrv_logits + gumbels, axis=1)\n",
    "        # Frequency of each next state among the n_mc argmax samples.\n",
    "        posterior_prob = np.bincount(intrv_posterior, minlength=n_next) / n_mc\n",
    "\n",
    "        return posterior_prob, intrv_posterior, influenced_states\n",
    "\n",
    "    def cf_sample_prob_tracking_influenced_transitions(self, trajectories, T, influenced_transitions, n_cf_samps=1):\n",
    "        \"\"\"Estimate counterfactual transition probabilities for every (s, a, t).\n",
    "\n",
    "        Args:\n",
    "            trajectories: array (n_obs, T, 4) of observed rows [t, s, a, s_prime].\n",
    "            T: number of time steps to process.\n",
    "            influenced_transitions: array indexed [s, a, s_prime, t]; its\n",
    "                slices for t < T are overwritten in place.\n",
    "            n_cf_samps: repetitions per observed trajectory.\n",
    "\n",
    "        Returns:\n",
    "            (P_cf, influenced_transitions), P_cf of shape\n",
    "            (n_states, n_actions, n_states, T). With several observed\n",
    "            trajectories or repetitions, later estimates overwrite earlier ones.\n",
    "        \"\"\"\n",
    "        n_obs = trajectories.shape[0]\n",
    "        n_mc = 1000\n",
    "\n",
    "        P_cf = np.zeros(shape=(self.mdp.n_states, self.mdp.n_actions, self.mdp.n_states, T))\n",
    "\n",
    "        for a in range(self.mdp.n_actions):\n",
    "            for t in range(T):\n",
    "                for obs_idx in range(n_obs):  # for each given \"observed\" trajectory\n",
    "                    # The observed step does not depend on the counterfactual state.\n",
    "                    obs_current_state = int(trajectories[obs_idx, t, 1])  # s_real\n",
    "                    obs_action = int(trajectories[obs_idx, t, 2])  # a_real\n",
    "                    obs_next_state = int(trajectories[obs_idx, t, 3])  # s_p_real\n",
    "                    obs_intrv = self.mdp.transition_probabilities[obs_action, obs_current_state, :]\n",
    "\n",
    "                    for _ in range(n_cf_samps):  # desired number of CF estimates per observed trajectory\n",
    "                        for s in range(self.mdp.n_states):\n",
    "                            cf_intrv = self.mdp.transition_probabilities[a, s, :]\n",
    "                            cf_prob, _draws, influenced_states = self.cf_posterior_tracking_influenced_states(\n",
    "                                obs_intrv, cf_intrv, obs_next_state, n_mc)\n",
    "\n",
    "                            P_cf[s, a, :, t] = cf_prob\n",
    "                            influenced_transitions[s, a, :, t] = influenced_states\n",
    "\n",
    "        return P_cf, influenced_transitions\n",
    "\n",
    "    def run_parallel_sampling_tracking_influenced_transitions(self, trajectories, influenced_transitions):\n",
    "        \"\"\"Run the counterfactual estimation over every step of `trajectories`.\n",
    "\n",
    "        Despite the name, the work runs sequentially in this process.\n",
    "        \"\"\"\n",
    "        n_steps = trajectories.shape[1]\n",
    "\n",
    "        P_cf, influenced_transitions = self.cf_sample_prob_tracking_influenced_transitions(trajectories, n_steps, influenced_transitions)\n",
    "\n",
    "        return P_cf, influenced_transitions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import deque, defaultdict\n",
    "NUM_ITERATIONS = 10000  # NOTE(review): not referenced by the visible cells\n",
    "\n",
    "class InfluenceMDPPruner:\n",
    "    \"\"\"Prunes a counterfactual MDP to the part influenced by an observed trajectory.\n",
    "\n",
    "    prune_mdp() estimates counterfactual transition probabilities and restricts\n",
    "    them to influence graphs for each look-ahead k; generate_policies() solves\n",
    "    the restricted MDPs; generate_cf_trajectories() samples from the policies.\n",
    "\n",
    "    Attributes:\n",
    "        mdp: the underlying MDP.\n",
    "        rewards_pi: reward matrix indexed [state, action], taken from mdp.rewards.\n",
    "        sampler: CounterfactualSampler used to estimate P_cf.\n",
    "        mdp_sample: observed trajectories, array (n_obs, T, 4) with rows\n",
    "            [t, state index, action index, next state index].\n",
    "        initial_state: state index of the first observed step.\n",
    "        look_ahead_k: largest look-ahead considered.\n",
    "        T: number of observed steps.\n",
    "        states, actions: index ranges over the MDP's states and actions.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, mdp, mdp_sample, look_ahead_k = 1):\n",
    "        self.mdp = mdp\n",
    "        # Use the MDP's own reward matrix rather than a notebook global.\n",
    "        self.rewards_pi = mdp.rewards\n",
    "        self.sampler = CounterfactualSampler(self.mdp)\n",
    "        self.mdp_sample = mdp_sample\n",
    "        # State index of the first step of the first trajectory\n",
    "        # (mdp_sample[0][1] would be the whole second row).\n",
    "        self.initial_state = mdp_sample[0][0][1]\n",
    "        self.look_ahead_k = look_ahead_k\n",
    "        self.T = len(self.mdp_sample[0])\n",
    "        self.states = range(mdp.n_states)\n",
    "        self.actions = range(mdp.n_actions)\n",
    "\n",
    "    def build_graph(self, transition_probs):\n",
    "        \"\"\"Build the layered graph of the original MDP.\n",
    "\n",
    "        Nodes are (t, s); an edge (t, s) -> (t + 1, s') keyed by action a is\n",
    "        added whenever transition_probs[a, s, s'] > 0, for t = 0..T-1.\n",
    "        \"\"\"\n",
    "        G = nx.MultiDiGraph()\n",
    "        pos = {}  # layout positions; computed but not returned\n",
    "\n",
    "        for t in range(self.T):\n",
    "            for s in range(self.mdp.n_states):\n",
    "                G.add_node((t, s))\n",
    "                pos[(t, s)] = (s, -t)\n",
    "\n",
    "                for a in range(self.mdp.n_actions):\n",
    "                    for s_prime in range(self.mdp.n_states):\n",
    "\n",
    "                        if transition_probs[a, s, s_prime] > 0:\n",
    "                            G.add_node(((t+1), s_prime))\n",
    "                            pos[(t+1, s_prime)] = (s_prime, -(t+1))\n",
    "                            G.add_edge((t, s), ((t+1), s_prime), key=a, label=f\"({a}, {transition_probs[a, s, s_prime]})\")\n",
    "\n",
    "        return G\n",
    "\n",
    "    def build_cf_graph(self, transition_probs, T, all_states, all_actions, k):\n",
    "        \"\"\"Build the layered graph of a time-dependent counterfactual MDP.\n",
    "\n",
    "        Like build_graph but reads transition_probs[s, a, s', t]. Nodes at\n",
    "        t < T - 1 without outgoing edges are dropped, then nodes at t > 0\n",
    "        without incoming edges are removed repeatedly. `k` is unused.\n",
    "        \"\"\"\n",
    "        G = nx.MultiDiGraph()\n",
    "        pos = {}  # layout positions; computed but not returned\n",
    "\n",
    "        for t in range(T):\n",
    "            for s in all_states:\n",
    "                G.add_node((t, s))\n",
    "                pos[(t, s)] = (s, -t)\n",
    "\n",
    "                for a in all_actions:\n",
    "                    for s_prime in all_states:\n",
    "                        if transition_probs[s, a, s_prime, t] > 0:\n",
    "                            G.add_node(((t+1), s_prime))\n",
    "                            pos[(t+1, s_prime)] = (s_prime, -(t+1))\n",
    "                            G.add_edge((t, s), ((t+1), s_prime), key=a, label=f\"({a}, {transition_probs[s, a, s_prime, t]})\")\n",
    "\n",
    "                # If node has no outgoing edges, remove it from the graph\n",
    "                if G.has_node((t, s)) and G.out_degree((t, s)) == 0 and t < T-1:\n",
    "                    G.remove_node((t, s))\n",
    "\n",
    "        # Remove unreachable nodes at t>0 with in-degree = 0\n",
    "        unreachable_nodes = {n for n in G if G.in_degree(n) == 0 and n[0]>0}\n",
    "\n",
    "        while len(unreachable_nodes) > 0:\n",
    "            G.remove_nodes_from(unreachable_nodes)\n",
    "            unreachable_nodes = {n for n in G if G.in_degree(n) == 0 and n[0]>0}\n",
    "\n",
    "        return G\n",
    "\n",
    "    def get_counterfactual_transition_probabilities(self, P_cf, original_G, new_mdp_G, all_states, all_actions, A_real, S_real, T, k):\n",
    "        \"\"\"Restrict P_cf to the transitions kept in the pruned graph new_mdp_G.\n",
    "\n",
    "        First pass (new_mdp_G is modified in place): whenever a positive\n",
    "        counterfactual successor of (t, s, a) has no matching edge, every edge\n",
    "        of action a leaving (t, s) is removed; dead ends at t < T - 1 and\n",
    "        unreachable nodes at t > 0 are then pruned. Second pass: each row\n",
    "        P_cf[s, a, :, t] with a positive entry lacking an edge is zeroed.\n",
    "        original_G, A_real and S_real are unused; k only appears in the log line.\n",
    "\n",
    "        Returns:\n",
    "            (P_cf, valid_action): P_cf is modified in place; valid_action[t, s, a]\n",
    "            is True when the row P_cf[s, a, :, t] sums to 1 (rounded to 10 places).\n",
    "        \"\"\"\n",
    "        print(f\"Calculating counterfactual transition probabilities for k={k}\")\n",
    "\n",
    "        # Update the transition probabilities P_cf with the pruned mdp new_mdp_G.\n",
    "        # Remove actions entirely to ensure that the probabilities for each action in\n",
    "        # each state add up to 1. Keep track of which actions are valid choices in\n",
    "        # which states.\n",
    "        valid_action = np.full((T, len(self.states), len(self.actions)), False)\n",
    "\n",
    "        for t in range(T-1, -1, -1):\n",
    "            for s in self.states:\n",
    "                for a in self.actions:\n",
    "                    for s_prime in self.states:\n",
    "                        if new_mdp_G.has_node((t, s)) and P_cf[s, a, s_prime, t] > 0.0:\n",
    "                            if not new_mdp_G.has_edge((t, s), (t+1, s_prime), key=a):\n",
    "                                imm_descendants = nx.descendants_at_distance(new_mdp_G, (t, s), 1)\n",
    "\n",
    "                                for imm_descendant in imm_descendants:\n",
    "                                    if new_mdp_G.has_edge((t, s), imm_descendant, key=a):\n",
    "                                        new_mdp_G.remove_edge((t, s), imm_descendant, key=a)\n",
    "\n",
    "                # If node has no outgoing edges, remove it from the graph\n",
    "                if new_mdp_G.has_node((t, s)) and new_mdp_G.out_degree((t, s)) == 0 and t < T-1:\n",
    "                    new_mdp_G.remove_node((t, s))\n",
    "\n",
    "            # Remove unreachable nodes at t>0 with in-degree = 0\n",
    "            unreachable_nodes = {n for n in new_mdp_G if new_mdp_G.in_degree(n) == 0 and n[0]>0}\n",
    "\n",
    "            while len(unreachable_nodes) > 0:\n",
    "                new_mdp_G.remove_nodes_from(unreachable_nodes)\n",
    "                unreachable_nodes = {n for n in new_mdp_G if new_mdp_G.in_degree(n) == 0 and n[0]>0}\n",
    "\n",
    "        for t in range(T-1, -1, -1):\n",
    "            for s in all_states:\n",
    "                for a in all_actions:\n",
    "                    for s_prime in all_states:\n",
    "                        if P_cf[s, a, s_prime, t] > 0.0:\n",
    "                            if not new_mdp_G.has_edge((t, s), (t+1, s_prime), key=a):\n",
    "                                P_cf[s, a, :, t] = 0.0\n",
    "                        else:\n",
    "                            assert(P_cf[s, a, s_prime, t] == 0.0)\n",
    "\n",
    "                    if round(sum(P_cf[s, a, :, t]), 10) == 1.0:\n",
    "                        valid_action[t, s, a] = True\n",
    "\n",
    "        return P_cf, valid_action\n",
    "\n",
    "    def get_influence_graph(self, G, k, influenced_transitions):\n",
    "        \"\"\"Return the subgraph of G within k reverse steps of an influenced node.\n",
    "\n",
    "        Seeds are the nodes (t + 1, s') whose incoming transitions are flagged\n",
    "        in influenced_transitions[s, a, s', t]. All edges of G among the last\n",
    "        k layers are added back, then nodes with no predecessor (t > 0) or no\n",
    "        successor (t < T) are removed repeatedly.\n",
    "        \"\"\"\n",
    "        print(f\"Generating influence graph for k={k}\")\n",
    "\n",
    "        def reverse_bfs(G, start_nodes, k):\n",
    "            # Collect every node from which some start node is reachable in <= k steps.\n",
    "            distance = defaultdict(lambda: float('inf'))\n",
    "            nodes_to_visit = deque([(node, 0) for node in start_nodes])\n",
    "            within_k_steps = set()\n",
    "\n",
    "            while nodes_to_visit:\n",
    "                curr_node, curr_dist = nodes_to_visit.popleft()\n",
    "\n",
    "                if curr_dist <= k:\n",
    "                    within_k_steps.add(curr_node)\n",
    "\n",
    "                    if distance[curr_node] > curr_dist:\n",
    "                        distance[curr_node] = curr_dist\n",
    "\n",
    "                        for predecessor in G.predecessors(curr_node):\n",
    "                            nodes_to_visit.append((predecessor, curr_dist+1))\n",
    "\n",
    "            return within_k_steps\n",
    "\n",
    "        directly_influenced_nodes = set()\n",
    "\n",
    "        for s in self.states:\n",
    "            for a in self.actions:\n",
    "                for s_prime in self.states:\n",
    "                    for t in range(self.T):\n",
    "                        if influenced_transitions[s, a, s_prime, t]:\n",
    "                            directly_influenced_nodes.add((t+1, s_prime))\n",
    "\n",
    "        reachable_nodes = reverse_bfs(G, directly_influenced_nodes, k)\n",
    "        influence_graph = G.subgraph(reachable_nodes).copy()\n",
    "\n",
    "        # If we are between T-k+1 and T, then we want to add all the paths between these layers, as they are all treated as influenced.\n",
    "        for timestep in range(self.T-k+1, self.T):\n",
    "            for s in self.states:\n",
    "                for a in self.actions:\n",
    "                    for s_prime in self.states:\n",
    "                        if not influence_graph.has_edge((timestep, s), (timestep+1, s_prime), key=a) and G.has_edge((timestep, s), (timestep+1, s_prime), key=a):\n",
    "                            influence_graph.add_edge((timestep, s), (timestep+1, s_prime), key=a)\n",
    "\n",
    "        # Remove nodes with in-degree = 0 or out-degree = 0\n",
    "        unreachable_nodes = {n for n in influence_graph.nodes if (influence_graph.in_degree(n) == 0 and n[0]>0) or (influence_graph.out_degree(n) == 0 and n[0] < self.T)}\n",
    "\n",
    "        while len(unreachable_nodes) > 0:\n",
    "            influence_graph.remove_nodes_from(unreachable_nodes)\n",
    "            unreachable_nodes = {n for n in influence_graph.nodes if (influence_graph.in_degree(n) == 0 and n[0]>0) or (influence_graph.out_degree(n) == 0 and n[0] < self.T)}\n",
    "\n",
    "        return influence_graph\n",
    "\n",
    "    def prune_mdp(self):\n",
    "        \"\"\"Run the pruning pipeline for every look-ahead k = 1..look_ahead_k.\n",
    "\n",
    "        Returns:\n",
    "            (cf_transition_probs, valid_actions, cf_graphs), one entry per k.\n",
    "        \"\"\"\n",
    "        # Initialise a matrix to keep track of which transitions' probabilities\n",
    "        # are directly influenced by the observed trajectory.\n",
    "        influenced_transitions = np.zeros(shape=(len(self.states), len(self.actions), len(self.states), self.T+1))\n",
    "\n",
    "        # Generate the counterfactual transition probabilities, keeping track\n",
    "        # of which transitions have been influenced by the observed path.\n",
    "        P_cf, influenced_transitions = self.sampler.run_parallel_sampling_tracking_influenced_transitions(self.mdp_sample, influenced_transitions)\n",
    "\n",
    "        # Build graph using the original MDP transition probabilities.\n",
    "        G = self.build_graph(self.mdp.transition_probabilities)\n",
    "\n",
    "        # Build the influence graph for each look-ahead k\n",
    "        influence_graphs = []\n",
    "\n",
    "        # Generate graphs for the pruned MDP.\n",
    "        for k in range(1, self.look_ahead_k+1):\n",
    "            influence_graph = self.get_influence_graph(copy.deepcopy(G), k, influenced_transitions)\n",
    "            influence_graphs.append(influence_graph)\n",
    "\n",
    "        cf_transition_probs = []\n",
    "        valid_actions = []\n",
    "\n",
    "        A_real = self.mdp_sample[0, :, 2]\n",
    "        S_real = self.mdp_sample[0, :, 1]\n",
    "\n",
    "        for look_ahead_k in range(1, self.look_ahead_k+1):\n",
    "            new_P_cf, valid_action = self.get_counterfactual_transition_probabilities(copy.deepcopy(P_cf), G, influence_graphs[look_ahead_k-1], self.states, self.actions, A_real, S_real, self.T, look_ahead_k)\n",
    "            cf_transition_probs.append(new_P_cf)\n",
    "            valid_actions.append(valid_action)\n",
    "\n",
    "        # Generate graphs for the pruned counterfactual MDP.\n",
    "        cf_graphs = []\n",
    "\n",
    "        for k in range(1, self.look_ahead_k+1):\n",
    "            G = self.build_cf_graph(cf_transition_probs[k-1], self.T, self.states, self.actions, k)\n",
    "            cf_graphs.append(G)\n",
    "\n",
    "        return cf_transition_probs, valid_actions, cf_graphs\n",
    "\n",
    "    def get_optimal_policy(self, max_num_actions_changed, P_cf, valid_action, new_mdp_G, all_states, all_actions, S_real, A_real, T, rewards_pi):\n",
    "        \"\"\"Backward dynamic programming over the pruned counterfactual MDP.\n",
    "\n",
    "        h_fun[s, r, c] is the best value over the last r steps from state s\n",
    "        with c deviations from A_real still allowed; with c = 0 the observed\n",
    "        actions are followed. Asserts first that valid_action matches which\n",
    "        rows of P_cf sum to 1 (valid) or 0 (invalid).\n",
    "\n",
    "        Returns:\n",
    "            (pi, h_fun), where pi[s, max_num_actions_changed - c, t] is the\n",
    "            action in state s at time t with c deviations left (the middle\n",
    "            index counts deviations already made); a bare None when\n",
    "            all_states is empty.\n",
    "        \"\"\"\n",
    "        if len(all_states) == 0:\n",
    "            return None\n",
    "\n",
    "        for t in range(T):\n",
    "            for s in all_states:\n",
    "                for a in all_actions:\n",
    "                    if valid_action[t, s, a]:\n",
    "                        assert(round(sum(P_cf[s, a, :, t]), 10) == 1.0)\n",
    "                    else:\n",
    "                        assert(round(sum(P_cf[s, a, :, t]), 10) == 0.0)\n",
    "\n",
    "        h_fun = np.zeros((len(all_states), T+1, max_num_actions_changed+1))\n",
    "        pi = np.zeros((len(all_states), max_num_actions_changed+1, T+1), dtype=int)\n",
    "\n",
    "        for r in range(1, T+1): # last r steps of the decision making process\n",
    "            for s in all_states: # for all possible states\n",
    "                h_fun[s, r, 0] = rewards_pi[(T-r), s, (A_real[T-r])] # for all time steps counting backwards (T-r is T-1, T-2 etc)\n",
    "\n",
    "                for s_p in all_states: # for every single next state (s') for each state s\n",
    "                    h_fun[s, r, 0] += P_cf[s, A_real[T-r], s_p, T-r] * h_fun[s_p, r-1, 0] # P_cf[obs_ind][a,t][s,s']\n",
    "\n",
    "                pi[s, max_num_actions_changed, T-r] = A_real[T-r]\n",
    "\n",
    "        # For t=1,...,T-2 do recursive computations\n",
    "        for c in range(1, max_num_actions_changed+1): # iterates over the number of changes allowed\n",
    "            for r in range(1, T+1): # iterates over the time steps in reverse order\n",
    "                for s in all_states:\n",
    "                    pi[s, max_num_actions_changed-c, T-r] = A_real[T-r] # instead let it be the real action\n",
    "                    best_act = A_real[T-r]\n",
    "                    max_val = -np.inf\n",
    "\n",
    "                    for a in all_actions: # For each state and action, it computes the value based on rewards and future values.\n",
    "                        if valid_action[T-r, s, a]:\n",
    "                            val = rewards_pi[T-r][s][a]\n",
    "\n",
    "                            # If an action differs from the observed action, the number of remaining changes (c) decreases.\n",
    "                            if a != A_real[T-r]:\n",
    "                                for s_p in all_states:\n",
    "                                    if P_cf[s, a, s_p, T-r] != 0:\n",
    "                                        val += P_cf[s, a, s_p, T-r] * h_fun[s_p, r-1, c-1]\n",
    "                            elif a == A_real[T-r]:\n",
    "                                for s_p in all_states:\n",
    "                                    if P_cf[s, a, s_p, T-r] != 0:\n",
    "                                        val += P_cf[s, a, s_p, T-r] * h_fun[s_p, r-1, c]\n",
    "\n",
    "                            if val > max_val:\n",
    "                                max_val = val\n",
    "                                best_act = a\n",
    "\n",
    "                    h_fun[s, r, c] = max_val\n",
    "\n",
    "                    if max_val == -np.inf:\n",
    "                        pi[s, max_num_actions_changed-c, T-r] = A_real[T-r]\n",
    "                    else:\n",
    "                        pi[s, max_num_actions_changed-c, T-r] = best_act\n",
    "\n",
    "        return pi, h_fun\n",
    "\n",
    "    def generate_policies(self, cf_transition_probs, valid_actions, cf_graphs):\n",
    "        \"\"\"Solve get_optimal_policy for every (look-ahead k, max changes c) pair.\n",
    "\n",
    "        rewards_pi (indexed [state, action]) is repeated for every step into\n",
    "        new_all_rewards[t, s, a].\n",
    "\n",
    "        Returns:\n",
    "            (policies, new_all_rewards, h_funs) with policies[k-1][c-1] the\n",
    "            policy for look-ahead k and at most c changed actions.\n",
    "        \"\"\"\n",
    "        # Generate policies for each of the pruned counterfactual MDPs.\n",
    "        policies = []\n",
    "        h_funs = []\n",
    "        k_vals = range(1, self.look_ahead_k+1)\n",
    "        S_real = self.mdp_sample[0, :, 1]\n",
    "        A_real = self.mdp_sample[0, :, 2]\n",
    "\n",
    "        new_all_rewards = np.zeros((self.T, len(self.states), len(self.actions)))\n",
    "\n",
    "        for t in range(self.T):\n",
    "            for s in self.states:\n",
    "                for a in self.actions:\n",
    "                    new_all_rewards[t, s, a] = self.rewards_pi[s, a]\n",
    "\n",
    "        for look_ahead_k in k_vals:\n",
    "            print(f\"Estimating policy with k={look_ahead_k}\")\n",
    "            policies_k = []\n",
    "            h_funs_k = []\n",
    "\n",
    "            for max_num_actions_changed in k_vals:\n",
    "                # Get the optimal policy\n",
    "                pi, h_fun = self.get_optimal_policy(\n",
    "                    max_num_actions_changed,\n",
    "                    cf_transition_probs[look_ahead_k-1],\n",
    "                    valid_actions[look_ahead_k-1],\n",
    "                    cf_graphs[look_ahead_k-1],\n",
    "                    self.states,\n",
    "                    self.actions,\n",
    "                    S_real,\n",
    "                    A_real,\n",
    "                    self.T,\n",
    "                    new_all_rewards\n",
    "                )\n",
    "\n",
    "                policies_k.append(pi)\n",
    "                h_funs_k.append(h_fun)\n",
    "\n",
    "            policies.append(policies_k)\n",
    "            h_funs.append(h_funs_k)\n",
    "\n",
    "        return policies, new_all_rewards, h_funs\n",
    "\n",
    "\n",
    "    def generate_random_trajectory(self, MDP_samp, P_cf, pi, s_0, A_real, all_states, rewards_pi, T):\n",
    "        \"\"\"Sample one counterfactual trajectory by following pi under P_cf.\n",
    "\n",
    "        The policy is indexed pi[state, changes so far, t]. all_states is unused.\n",
    "\n",
    "        Returns:\n",
    "            array (n_obs, T, n_cols) whose first row holds, per step,\n",
    "            [s, s_prime, a, reward] (note: a different column order from\n",
    "            MDP_samp); the remaining rows stay zero.\n",
    "        \"\"\"\n",
    "        n_obs=MDP_samp.shape[0]\n",
    "        n_state=MDP_samp.shape[2]\n",
    "        CF_trajectory = np.zeros((n_obs, T, n_state))\n",
    "\n",
    "        rng = np.random.default_rng()\n",
    "        s = np.zeros(T+1, dtype=int)\n",
    "        s[0] = s_0   # Initial state the same\n",
    "        l = np.zeros(T+1, dtype=int)\n",
    "        l[0] = 0    # Start with 0 changes\n",
    "        a = np.zeros(T, dtype=int)\n",
    "\n",
    "        for t in range(T):\n",
    "            # Pick actions according to the given policy\n",
    "            a[t] = pi[s[t], l[t], t]\n",
    "\n",
    "            # Sample the next state\n",
    "            s[t+1] = (rng.choice(a=self.states, size=1,  p=P_cf[s[t], a[t], :, t]))[0]\n",
    "\n",
    "            # Adjust the number of changes so far\n",
    "            if a[t] != A_real[t]:\n",
    "                l[t+1] = l[t] + 1\n",
    "            else:\n",
    "                l[t+1] = l[t]\n",
    "\n",
    "            CF_trajectory[0, t, :] = np.array([s[t], s[t+1], a[t], rewards_pi[t, s[t], a[t]]])\n",
    "\n",
    "        return CF_trajectory\n",
    "\n",
    "    def generate_cf_trajectories(self, cf_transition_probs, policies, new_all_rewards, n_samples=1000):\n",
    "        \"\"\"Average the observed vs counterfactual immediate reward at the last step.\n",
    "\n",
    "        For every (look-ahead k, max changes c) pair, n_samples counterfactual\n",
    "        trajectories are sampled and their reward at step T - 1 is averaged.\n",
    "        The observed value is the reward of the observed (state, action) at\n",
    "        step T - 1 under new_all_rewards, so both sides use the same rewards.\n",
    "\n",
    "        Args:\n",
    "            cf_transition_probs, policies, new_all_rewards: outputs of\n",
    "                prune_mdp() and generate_policies().\n",
    "            n_samples: number of Monte-Carlo repetitions (default 1000).\n",
    "\n",
    "        Returns:\n",
    "            (mean_obs, mean_cf, k_vals); the means have shape\n",
    "            (look_ahead_k, look_ahead_k) and are indexed [k - 1][c - 1].\n",
    "        \"\"\"\n",
    "        print(f\"Generating CF trajectories\")\n",
    "        all_obs = []\n",
    "        all_cf = []\n",
    "        k_vals = range(1, self.look_ahead_k+1)\n",
    "        S_real = self.mdp_sample[0, :, 1]\n",
    "        A_real = self.mdp_sample[0, :, 2]\n",
    "        s_0 = self.mdp_sample[0, 0, 1]\n",
    "\n",
    "        # Immediate reward of the observed path at its last step. Column 3 of the\n",
    "        # observed sample is the next-state index, not a reward, so look it up.\n",
    "        obs_reward = new_all_rewards[self.T-1, S_real[self.T-1], A_real[self.T-1]]\n",
    "\n",
    "        for _ in range(n_samples):\n",
    "            obs = np.zeros(shape=(self.look_ahead_k, self.look_ahead_k))\n",
    "            cf = np.zeros(shape=(self.look_ahead_k, self.look_ahead_k))\n",
    "\n",
    "            for look_ahead_k in k_vals:\n",
    "                for max_num_actions_changed in k_vals:\n",
    "                    CF_trajectory = self.generate_random_trajectory(\n",
    "                        self.mdp_sample,\n",
    "                        cf_transition_probs[look_ahead_k-1],\n",
    "                        policies[look_ahead_k-1][max_num_actions_changed-1],\n",
    "                        s_0,\n",
    "                        A_real,\n",
    "                        self.states,\n",
    "                        new_all_rewards,\n",
    "                        self.T\n",
    "                    )\n",
    "\n",
    "                    obs[look_ahead_k-1][max_num_actions_changed-1] = obs_reward # Immediate reward for obs path at time T\n",
    "                    cf[look_ahead_k-1][max_num_actions_changed-1] = CF_trajectory[0, self.T-1, 3] # Immediate reward for cf path at time T\n",
    "\n",
    "            all_obs.append(obs)\n",
    "            all_cf.append(cf)\n",
    "\n",
    "        all_obs = np.array(all_obs)\n",
    "        all_cf = np.array(all_cf)\n",
    "\n",
    "        mean_obs = all_obs.mean(axis=0)\n",
    "        mean_cf = all_cf.mean(axis=0)\n",
    "\n",
    "        return mean_obs, mean_cf, k_vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pruner over the sampled trajectories with a look-ahead of T + 1 steps\n",
    "influence_pruner = InfluenceMDPPruner(\n",
    "    epidemic_mdp,\n",
    "    MDP_samp,\n",
    "    look_ahead_k=epidemic_mdp.T + 1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack the pruned counterfactual MDP components\n",
    "(cf_transition_probs, valid_actions, cf_graphs) = influence_pruner.prune_mdp()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# policies and h_funs are indexed [look_ahead_k - 1][max_num_actions_changed - 1]\n",
    "(policies, new_all_rewards, h_funs) = influence_pruner.generate_policies(\n",
    "    cf_transition_probs, valid_actions, cf_graphs\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First sampled trajectory, shown via the cell's rich display instead of print()\n",
    "MDP_samp[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Monte-Carlo estimate of final-step rewards (observed vs. counterfactual)\n",
    "mean_obs, mean_cf, k_vals = influence_pruner.generate_cf_trajectories(\n",
    "    cf_transition_probs,\n",
    "    policies,\n",
    "    new_all_rewards,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Value of the initial state for each look-ahead K (one series per K) against the\n",
    "# maximum number of actions changed k, with the zero-change value as a reference.\n",
    "# NOTE(review): 2120 is a hard-coded state index for s_0; presumably it matches the\n",
    "# first state of MDP_samp[0] - confirm. h_fun's axis meanings are not visible here.\n",
    "INITIAL_STATE = 2120\n",
    "\n",
    "values = []\n",
    "obs_values = []\n",
    "for look_ahead_k in k_vals:\n",
    "    h_funs_k = h_funs[look_ahead_k - 1]\n",
    "    values.append([h_funs_k[k - 1][INITIAL_STATE, -1, k] for k in k_vals])\n",
    "    # Zero-change reference; obs_values ends up holding the last look-ahead K's values.\n",
    "    obs_values = [h_funs_k[k - 1][INITIAL_STATE, -1, 0] for k in k_vals]\n",
    "\n",
    "colors = ['darkblue', 'darkblue', 'darkviolet', 'red', 'green', 'orange', 'blue', 'grey']\n",
    "markers = ['x', 'x', 'd', '+', '.', '^', 's', 'x']\n",
    "\n",
    "def k_label(k):\n",
    "    \"\"\"Legend label for look-ahead K; the largest K is the full horizon T + 1.\"\"\"\n",
    "    if k == k_vals[-1]:\n",
    "        return 'K=T+1'\n",
    "    if k == 2:\n",
    "        return 'K=1 to K=2'\n",
    "    return f'K={k}'\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(8, 8))\n",
    "ax.set_title('Initial-State Value vs. Action-Change Budget', fontsize=14)\n",
    "ax.set_xlabel('Maximum Number of Actions Changed', fontsize=14)\n",
    "ax.set_ylabel('Value of Initial State', fontsize=14)\n",
    "ax.grid(which='both')\n",
    "\n",
    "ax.scatter(k_vals, obs_values, color='lightpink', label='Observed Path', marker='o', s=100)\n",
    "# K=1 is not drawn separately: the K=2 series is labelled 'K=1 to K=2'.\n",
    "for look_ahead_k in k_vals[1:]:\n",
    "    style = (look_ahead_k - 1) % len(colors)  # cycle styles if T + 1 exceeds the palette\n",
    "    ax.scatter(k_vals, values[look_ahead_k - 1], color=colors[style],\n",
    "               marker=markers[style], s=50, label=k_label(look_ahead_k))\n",
    "\n",
    "ax.legend(loc=0, frameon=True, fontsize=12)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean final-step reward (averaged over Monte-Carlo samples) of the observed path and\n",
    "# of the counterfactual paths, for each look-ahead K and action-change budget k.\n",
    "colors = ['orange', 'red', 'aqua', 'yellow', 'darkblue', 'deeppink', 'darkviolet', 'silver', 'teal', 'blue']\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 10))\n",
    "ax.set_title('Final-Step Reward: Observed vs. Counterfactual')\n",
    "ax.set_xlabel('Maximum Number of Actions Changed')\n",
    "ax.set_ylabel('Final State Reward')\n",
    "ax.grid(which='both')\n",
    "\n",
    "# generate_cf_trajectories fills every cell of mean_obs with the same observed reward,\n",
    "# so any row can serve as the reference series.\n",
    "ax.scatter(k_vals, mean_obs[-1], color='lightpink', label='Observed Path', marker='o', s=100)\n",
    "for look_ahead_k in k_vals:\n",
    "    k_name = 'T+1' if look_ahead_k == k_vals[-1] else look_ahead_k\n",
    "    # Cycle the palette so horizons longer than len(colors) do not raise IndexError.\n",
    "    ax.scatter(k_vals, mean_cf[look_ahead_k - 1], color=colors[(look_ahead_k - 1) % len(colors)],\n",
    "               label=f'Look-Ahead K={k_name}', marker='d', s=50)\n",
    "\n",
    "ax.legend(loc=0, frameon=True)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
