{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import all modules \n",
    "import numpy as np\n",
    "from multiprocessing import Pool\n",
    "import matplotlib.pyplot as plt\n",
    "!pip3 install networkx\n",
    "import networkx as nx\n",
    "import copy\n",
    "!pip3 install pydot\n",
    "np.seterr(divide = 'ignore') "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MDP:\n",
    "    def __init__(self, states, actions, transition_probabilities, rewards):\n",
    "        self.states = states\n",
    "        self.n_states = len(states)\n",
    "        self.actions = actions\n",
    "        self.n_actions = len(actions)\n",
    "        self.transition_probabilities = transition_probabilities\n",
    "        self.rewards = rewards\n",
    "        self.optimal_policy = np.zeros(len(states), dtype=int) # array with zeros \n",
    "        self.values = np.zeros(len(states))\n",
    "        self.discount = 0.9\n",
    "\n",
    "    def value_iteration(self): # to find optimal policy by repeatedly updating the value function.\n",
    "        # Initialize the value function to zero\n",
    "        self.values = np.zeros(len(self.states))\n",
    "        # Iterate until convergence\n",
    "        while True:\n",
    "            # Calculate the new value function\n",
    "            new_values = np.zeros(len(self.states))\n",
    "            for state in self.states:\n",
    "                # Calculate the value of performing each action in this state\n",
    "                values = np.zeros(len(self.actions))\n",
    "                for action in self.actions:\n",
    "                    # Calculate the expected value of performing this action\n",
    "                    for next_state in self.states:\n",
    "                        values[action] += self.transition_probabilities[state, action, next_state] * (self.rewards[state, action] + self.discount * self.values[next_state])\n",
    "                \n",
    "                # Select the action with the highest value\n",
    "                new_values[state] = np.max(values) # updates optimal reward values\n",
    "                self.optimal_policy[state] = np.argmax(values) # updates the optimal policy for the specified MDP\n",
    "\n",
    "            # Check for convergence\n",
    "            if np.sum(np.abs(new_values - self.values)) < 1e-4:\n",
    "                break\n",
    "\n",
    "            # Update the value function\n",
    "            self.values = new_values # corresponding optimal poicy reward values \n",
    "            \n",
    "    def policy_with_randomization(self, policy, randomization_probability):\n",
    "        policy_matrix = self.translating_policy_to_matrix(policy)\n",
    "        random_policy = randomization_probability*np.ones((len(self.states), len(self.actions)))\n",
    "        random_policy = random_policy + policy_matrix\n",
    "        random_policy = random_policy / (randomization_probability*len(self.actions) + 1)\n",
    "        assert np.allclose(np.sum(random_policy, axis=1), 1)\n",
    "        return random_policy\n",
    "\n",
    "    def translating_policy_to_matrix(self, policy):\n",
    "        policy_matrix = np.zeros((len(self.states), len(self.actions))) # policy is 2x2 matrix with states and actions\n",
    "        for i in range(len(policy)):\n",
    "            policy_matrix[i][policy[i]] = 1\n",
    "        return policy_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class createMDP:\n",
    "    def create_gridworld_mdp():\n",
    "        states = range(16)\n",
    "        actions = range(4)\n",
    "        transition_probabilities = np.zeros((len(states), len(actions), len(states))) # 3x3\n",
    "        rewards = np.zeros((len(states), len(actions))) # 2x2\n",
    "\n",
    "        # Set the transition probabilities and rewards in deterministic gridworld\n",
    "        for state in states:\n",
    "            for action in actions:\n",
    "                # Calculate the next state for deterministic gridworld\n",
    "                next_state = state\n",
    "                if action == 0:  # up\n",
    "                    if state not in [0, 1, 2, 3]:\n",
    "                        next_state = state - 4;\n",
    "                elif action == 1:  # right\n",
    "                    if state not in [3, 7, 11, 15]:\n",
    "                        next_state = state + 1;\n",
    "                elif action == 2:  # down\n",
    "                    if state not in [12, 13, 14, 15]:\n",
    "                        next_state = state + 4;\n",
    "                elif action == 3:  # left\n",
    "                    if state not in [0, 4, 8, 12]:\n",
    "                        next_state = state - 1;\n",
    "                # Set the transition probabilities and rewards\n",
    "                transition_probabilities[state, action, next_state] = 1\n",
    "                # we change the the transition probabilities and turn the unsafe states and terminal states into absorbing states\n",
    "                if state in [6]:\n",
    "                    transition_probabilities[state, action, :] = 0.0;\n",
    "                    transition_probabilities[state, action, state] = 1.0;\n",
    "                    rewards[state, action] = -100;\n",
    "                elif state == 15:\n",
    "                    transition_probabilities[state, action, :] = 0.0;\n",
    "                    transition_probabilities[state, action, state] = 1.0;\n",
    "                    rewards[state, action] = 100;\n",
    "                elif state in [1, 4, 5]:\n",
    "                    rewards[state, action] = 1.0;\n",
    "                elif state in [2, 8, 9]:\n",
    "                    rewards[state, action] = 2.0;\n",
    "                elif state in [3, 10, 12]:\n",
    "                    rewards[state, action] = 3.0;\n",
    "                elif state in [7, 13]:\n",
    "                    rewards[state, action] = 4.0;\n",
    "                elif state in [11, 14]:\n",
    "                    rewards[state, action] = 5.0;      \n",
    "        # create the stochastic gridworld MDP based on the deterministic gridworld MDP\n",
    "        s_transition_probabilities = np.zeros((len(states), len(actions), len(states)));\n",
    "        p_r = 0.9;\n",
    "        for action in actions:\n",
    "            other_actions = [a for a in actions if a != action];\n",
    "            for a in other_actions:\n",
    "                s_transition_probabilities[:, action, :] += (1 - p_r)/3 * transition_probabilities[:, a, :];\n",
    "            s_transition_probabilities[:, action, :] += p_r * transition_probabilities[:, action, :];\n",
    "        \n",
    "        # assert if the transition probabilities are correct\n",
    "        assert np.allclose(np.sum(s_transition_probabilities, axis=2), 1.0);\n",
    "        \n",
    "        # Create the MDP\n",
    "        mdp = MDP(states, actions, s_transition_probabilities, rewards)\n",
    "        mdp.value_iteration() # updating the optimal policy for the given MDP \n",
    "        \n",
    "        # Solve the MDP\n",
    "        optimal_policy = mdp.optimal_policy\n",
    "        random_policy = optimal_policy.copy() # we choose the optimal policy as the random policy except for states 2, 8\n",
    "        random_policy[0] = 1 # we choose the right action for state 0\n",
    "        random_policy[1] = 1 # we choose the right action for state 1\n",
    "        random_policy[4] = 1 # we choose the right action for state 4\n",
    "        random_policy[8] = 1 # we choose the right action for state 8\n",
    "        random_policy[9] = 1 # we choose the right action for state 9\n",
    "        optimal_policy = mdp.policy_with_randomization(optimal_policy, 0) # (randomization probability is zero)\n",
    "        random_policy = mdp.policy_with_randomization(random_policy, (0.1/(1-4*0.1))) # matrix representation of the policy\n",
    "\n",
    "        rewards_pi = [[   0,    1,    1,    0],[   1,    2,    1,    0],[   2,    3,    -100,    1],[   3,    3,    4,    2],[   0,    1,    2,    1],[   1,    -100,    2,    1],[-100, -100, -100, -100],[   3,    4,    5,    -100],[   1,    2,    3,    2],[   1,    3,    4,    2],[   -100,    5,    5,    2],[   4,    5,    100,    3],[   2,    4,    3,    3],[   2,    5,    4,    3],[   3,    100,    5,    4],[ 100,  100,  100,  100]]\n",
    "        rewards_pi_np = np.array(rewards_pi)\n",
    "        rewards_pi = rewards_pi_np.reshape(16, 4)\n",
    "\n",
    "        return mdp, optimal_policy, random_policy, rewards_pi\n",
    "    \n",
    "    def mdp_info():\n",
    "        states = range(16);\n",
    "        actions = range(4);\n",
    "        transition_probabilities = np.zeros((len(states), len(actions), len(states))); # 3x3\n",
    "        rewards = np.zeros((len(states), len(actions))); # 2x2\n",
    "\n",
    "        # Set the transition probabilities and rewards in deterministic gridworld\n",
    "        for state in states:\n",
    "            for action in actions:\n",
    "                # Calculate the next state for deterministic gridworld\n",
    "                next_state = state;\n",
    "                if action == 0:  # up\n",
    "                    if state not in [0, 1, 2, 3]:\n",
    "                        next_state = state - 4;\n",
    "                elif action == 1:  # right\n",
    "                    if state not in [3, 7, 11, 15]:\n",
    "                        next_state = state + 1;\n",
    "                elif action == 2:  # down\n",
    "                    if state not in [12, 13, 14, 15]:\n",
    "                        next_state = state + 4;\n",
    "                elif action == 3:  # left\n",
    "                    if state not in [0, 4, 8, 12]:\n",
    "                        next_state = state - 1;\n",
    "                # Set the transition probabilities and rewards\n",
    "                transition_probabilities[state, action, next_state] = 1\n",
    "                # we change the the transition probabilities and turn the unsafe states and terminal states into absorbing states\n",
    "                if state in [6]:\n",
    "                    transition_probabilities[state, action, :] = 0.0;\n",
    "                    transition_probabilities[state, action, state] = 1.0;\n",
    "                    rewards[state, action] = -100;\n",
    "                elif state == 15:\n",
    "                    transition_probabilities[state, action, :] = 0.0;\n",
    "                    transition_probabilities[state, action, state] = 1.0;\n",
    "                    rewards[state, action] = 100;\n",
    "                elif state in [1, 4, 5]:\n",
    "                    rewards[state, action] = 1.0;\n",
    "                elif state in [2, 8, 9]:\n",
    "                    rewards[state, action] = 2.0;\n",
    "                elif state in [3, 10, 12]:\n",
    "                    rewards[state, action] = 3.0;\n",
    "                elif state in [7, 13]:\n",
    "                    rewards[state, action] = 4.0;\n",
    "                elif state in [11, 14]:\n",
    "                    rewards[state, action] = 5.0;   \n",
    "        \n",
    "        # create the stochastic gridworld MDP based on the deterministic gridworld MDP\n",
    "        s_transition_probabilities = np.zeros((len(states), len(actions), len(states)));\n",
    "        p_r = 0.9;\n",
    "        for action in actions:\n",
    "            other_actions = [a for a in actions if a != action];\n",
    "            for a in other_actions:\n",
    "                s_transition_probabilities[:, action, :] += (1 - p_r)/3 * transition_probabilities[:, a, :];\n",
    "            s_transition_probabilities[:, action, :] += p_r * transition_probabilities[:, action, :];\n",
    "        \n",
    "        # assert if the transition probabilities are correct\n",
    "        assert np.allclose(np.sum(s_transition_probabilities, axis=2), 1.0);\n",
    "        \n",
    "        # Create the MDP\n",
    "        mdp = MDP(states, actions, s_transition_probabilities, rewards)\n",
    "        mdp.value_iteration() # updating the optimal policy for the given MDP \n",
    "\n",
    "        return states, actions, rewards # all posible states, all possible actions, all rewards matrix of the form np.zeros((len(states), len(actions)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def truncated_gumbel(logit, truncation):\n",
    "    assert not np.isneginf(logit)\n",
    "\n",
    "    gumbel = np.random.gumbel(size=(truncation.shape[0])) + logit\n",
    "    trunc_g = -np.log(np.exp(-gumbel) + np.exp(-truncation))\n",
    "    return trunc_g\n",
    "\n",
    "def topdown_tracking_influenced_states(obs_logits, obs_state, nsamp=1): # is there only 1 sample? \n",
    "    poss_next_states = obs_logits.shape[0]\n",
    "    gumbels = np.zeros((nsamp, poss_next_states))\n",
    "    influenced_states = np.zeros(shape=poss_next_states)\n",
    "\n",
    "    # Sample top gumbels\n",
    "    topgumbel = np.random.gumbel(size=(nsamp))\n",
    "\n",
    "    for next_state in range(poss_next_states):\n",
    "        # This is the observed outcome\n",
    "        if (next_state == obs_state) and not(np.isneginf(obs_logits[next_state])):\n",
    "            gumbels[:, obs_state] = topgumbel - obs_logits[next_state]\n",
    "            influenced_states[obs_state] = 1\n",
    "        # These were the other feasible options (p > 0)\n",
    "        elif not(np.isneginf(obs_logits[next_state])):\n",
    "            gumbels[:, next_state] = truncated_gumbel(obs_logits[next_state], topgumbel) - obs_logits[next_state]\n",
    "            influenced_states[next_state] = 1\n",
    "        # These have zero probability to start with, so are unconstrained\n",
    "        else:\n",
    "            gumbels[:, next_state] = np.random.gumbel(size=nsamp)\n",
    "\n",
    "    return gumbels, influenced_states # list of gumbel noise values derived from the observed trajectory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_actions = 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CounterfactualSampler(object):\n",
    "    def __init__(self, mdp):\n",
    "        self.mdp = mdp\n",
    "        self.sprtb_theta = 0.9\n",
    "        self.sprtb_delta = 0.05\n",
    "        self.sprtb_r = 0.9\n",
    "\n",
    "    def mdp_sample(self, policy, initial_state, n_obs=2, n_steps=20):\n",
    "        n_state = 4\n",
    "        trajectories = np.zeros((n_obs, n_steps, n_state))\n",
    "        for obs_idx in range(n_obs):\n",
    "            current_state = np.random.choice(initial_state, size=1)[0] \n",
    "            for time_idx in range(n_steps): \n",
    "                action = np.random.choice(\n",
    "                    self.mdp.n_actions, size=1, p=policy[current_state, :])[0] \n",
    "                next_state = np.random.choice(\n",
    "                    self.mdp.n_states, size=1, p=self.mdp.transition_probabilities[current_state, action, :])[0] \n",
    "                reward = self.mdp.rewards[current_state, action]\n",
    "                trajectories[obs_idx, time_idx, :] = np.array([current_state, next_state, action, reward])\n",
    "                current_state = next_state\n",
    "        return trajectories \n",
    "    \n",
    "    def cf_posterior_tracking_influenced_states(self, obs_prob, intrv_prob, state, n_mc):\n",
    "        obs_logits = np.log(obs_prob);\n",
    "        next_state = state\n",
    "        intrv_logits = np.log(intrv_prob);\n",
    "        gumbels, influenced_states = topdown_tracking_influenced_states(obs_logits, next_state, n_mc)\n",
    "        posterior = intrv_logits + gumbels\n",
    "        intrv_posterior = np.argmax(posterior, axis=1)\n",
    "        posterior_prob = np.zeros(np.size(intrv_prob, 0))\n",
    "        \n",
    "        for i in range(np.size(intrv_prob, 0)):\n",
    "            posterior_prob[i] = np.sum(intrv_posterior == i) / n_mc\n",
    "\n",
    "        return posterior_prob, intrv_posterior, influenced_states\n",
    "\n",
    "    def cf_sample_prob_tracking_influenced_transitions(self, trajectories, all_actions, T, influenced_transitions, n_cf_samps=1): \n",
    "        n_obs = trajectories.shape[0] \n",
    "        n_mc = 1000\n",
    "\n",
    "        P_cf = np.zeros(shape=(self.mdp.n_states, all_actions, self.mdp.n_states, T))\n",
    "        \n",
    "        for a in range(all_actions):\n",
    "            for t in range(T):\n",
    "                for obs_idx in range(n_obs):\n",
    "                    # Get the observed trajectory\n",
    "                    for _ in range(n_cf_samps): # get the desired number of CF trajectories for each given \"observed\" trajectory \n",
    "                            obs_state = trajectories[obs_idx, t, :]\n",
    "                            obs_current_state = int(obs_state[0]) # same as s_real\n",
    "                            obs_next_state = int(obs_state[1]) # same as s_p_real\n",
    "                            obs_action = int(obs_state[2]) # same as a_real\n",
    "\n",
    "                            for s in range(self.mdp.n_states):\n",
    "                                obs_intrv = self.mdp.transition_probabilities[obs_current_state, obs_action, :]\n",
    "                                cf_intrv = self.mdp.transition_probabilities[s, a, :]\n",
    "                                cf_prob, s_p, influenced_states = self.cf_posterior_tracking_influenced_states(obs_intrv, cf_intrv, obs_next_state, n_mc)\n",
    "                                \n",
    "                                for s_p in range(len(cf_prob)):\n",
    "                                    P_cf[s, a, s_p, t] = cf_prob[s_p]\n",
    "                                \n",
    "                                influenced_transitions[s, a, :, t] = influenced_states\n",
    "        \n",
    "        return P_cf, influenced_transitions\n",
    "\n",
    "    def run_parallel_sampling_tracking_influenced_transitions(self, trajectories, influenced_transitions):\n",
    "        n_steps = trajectories.shape[1]\n",
    "        n_actions = 4\n",
    "\n",
    "        P_cf, influenced_transitions = self.cf_sample_prob_tracking_influenced_transitions(trajectories, n_actions, n_steps, influenced_transitions)\n",
    "\n",
    "        return P_cf, influenced_transitions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_ITERATIONS = 10000\n",
    "from collections import deque, defaultdict\n",
    "\n",
    "class InfluenceMDPPruner:\n",
    "    def __init__(self, mdpBuilder, look_ahead_k = 1):\n",
    "        self.mdp, self.optimal_policy, self.random_policy, self.rewards_pi = mdpBuilder.create_gridworld_mdp()\n",
    "        self.sampler = CounterfactualSampler(self.mdp)\n",
    "        self.mdp_sample = self.use_saved_mdp_sample()\n",
    "        self.initial_state = self.mdp_sample\n",
    "        self.states, self.actions, _ = mdpBuilder.mdp_info()\n",
    "        self.look_ahead_k = look_ahead_k\n",
    "        self.T = len(self.mdp_sample[0])\n",
    "\n",
    "    def generate_random_mdp_sample(self):\n",
    "        return (np.array((self.sampler.mdp_sample(policy=self.random_policy, initial_state=[0], n_obs=1, n_steps=11)))).reshape(1, 11, 4)\n",
    "\n",
    "    def use_saved_mdp_sample(self):\n",
    "        MDP_samp = [[[0, 1, 1, 1],  [1, 2, 1, 2], [2, 6, 2, -100],  [6, 6, 0, -100], [6, 6, 0, -100], [6, 6, 0, -100], [6, 6, 0, -100], [6, 6, 0, -100], [6, 6, 0, -100], [6, 6, 0, -100], [6, 6, 0, -100]]]\n",
    "        MDP_samp_np = np.array(MDP_samp)\n",
    "        return MDP_samp_np.reshape(1, 11, 4)\n",
    "\n",
    "    def build_graph(self, transition_probs, T, all_states, all_actions):\n",
    "        G = nx.MultiDiGraph()\n",
    "        pos = {}\n",
    "\n",
    "        print(transition_probs)\n",
    "        \n",
    "        for t in range(T):\n",
    "            for s in all_states:\n",
    "                G.add_node((t, s))\n",
    "                pos[(t, s)] = (s, -t)\n",
    "\n",
    "                for a in all_actions:\n",
    "                    for s_prime in all_states:\n",
    "                        if transition_probs[s, a, s_prime] > 0:\n",
    "                            G.add_node(((t+1), s_prime))\n",
    "                            pos[(t+1, s_prime)] = (s_prime, -(t+1)) \n",
    "                            G.add_edge((t, s), ((t+1), s_prime), key=a, label=f\"({a}, {transition_probs[s, a, s_prime]})\")\n",
    "\n",
    "        return G\n",
    "\n",
    "    def build_cf_graph(self, transition_probs, T, all_states, all_actions, k):\n",
    "        G = nx.MultiDiGraph()\n",
    "        pos = {}\n",
    "        \n",
    "        for t in range(T):\n",
    "            for s in all_states:\n",
    "                G.add_node((t, s))\n",
    "                pos[(t, s)] = (s, -t)\n",
    "\n",
    "                for a in all_actions:\n",
    "                    for s_prime in all_states:\n",
    "                        if transition_probs[s, a, s_prime, t] > 0:\n",
    "                            G.add_node(((t+1), s_prime))\n",
    "                            pos[(t+1, s_prime)] = (s_prime, -(t+1)) \n",
    "                            G.add_edge((t, s), ((t+1), s_prime), key=a, label=f\"({a}, {transition_probs[s, a, s_prime, t]})\")\n",
    "\n",
    "                # If node has no outgoing edges, remove it from the graph\n",
    "                if G.has_node((t, s)) and G.out_degree((t, s)) == 0 and t < T-1:\n",
    "                    G.remove_node((t, s))\n",
    "\n",
    "        # Remove unreachable nodes at t>0 with in-degree = 0\n",
    "        unreachable_nodes = {n for n in G if G.in_degree(n) == 0 and n[0]>0}\n",
    "\n",
    "        while len(unreachable_nodes) > 0:\n",
    "            G.remove_nodes_from(unreachable_nodes)\n",
    "            unreachable_nodes = {n for n in G if G.in_degree(n) == 0 and n[0]>0}\n",
    "\n",
    "        return G\n",
    "    \n",
    "    def get_counterfactual_transition_probabilities(self, P_cf, original_G, new_mdp_G, all_states, all_actions, A_real, S_real, T, k):\n",
    "        print(f\"Calculating counterfactual transition probabilities for k={k}\")\n",
    "\n",
    "        # Update the transition probabilities P_cf with the pruned mdp new_mdp_G.\n",
    "        # Remove actions entirely to ensure that the probabilities for each action in\n",
    "        # each state add up to 1. Keep track of which actions are valid choices in\n",
    "        # which states.        \n",
    "        valid_action = np.full((T, len(self.states), len(self.actions)), False)\n",
    "\n",
    "        T = self.T\n",
    "\n",
    "        for t in range(T-1, -1, -1):\n",
    "            for s in self.states:\n",
    "                for a in self.actions:\n",
    "                    for s_prime in self.states:\n",
    "                        if new_mdp_G.has_node((t, s)) and P_cf[s, a, s_prime, t] > 0.0:\n",
    "                            if not new_mdp_G.has_edge((t, s), (t+1, s_prime), key=a):\n",
    "                                imm_descendants = nx.descendants_at_distance(new_mdp_G, (t, s), 1)\n",
    "\n",
    "                                for imm_descendant in imm_descendants:\n",
    "                                    if new_mdp_G.has_edge((t, s), imm_descendant, key=a):\n",
    "                                        new_mdp_G.remove_edge((t, s), imm_descendant, key=a)\n",
    "\n",
    "                # If node has no outgoing edges, remove it from the graph\n",
    "                if new_mdp_G.has_node((t, s)) and new_mdp_G.out_degree((t, s)) == 0 and t < T-1:\n",
    "                    new_mdp_G.remove_node((t, s))\n",
    "\n",
    "            # Remove unreachable nodes at t>0 with in-degree = 0\n",
    "            unreachable_nodes = {n for n in new_mdp_G if new_mdp_G.in_degree(n) == 0 and n[0]>0}\n",
    "\n",
    "            while len(unreachable_nodes) > 0:\n",
    "                new_mdp_G.remove_nodes_from(unreachable_nodes)\n",
    "                unreachable_nodes = {n for n in new_mdp_G if new_mdp_G.in_degree(n) == 0 and n[0]>0}\n",
    "\n",
    "        for t in range(T-1, -1, -1):\n",
    "            for s in all_states:\n",
    "                for a in all_actions:\n",
    "                    for s_prime in all_states:\n",
    "                        if P_cf[s, a, s_prime, t] > 0.0:\n",
    "                            if not new_mdp_G.has_edge((t, s), (t+1, s_prime), key=a):\n",
    "                               P_cf[s, a, :, t] = 0.0\n",
    "                        else:\n",
    "                            assert(P_cf[s, a, s_prime, t] == 0.0)\n",
    "                    \n",
    "                    if sum(P_cf[s, a, :, t]) == 1.0:\n",
    "                        valid_action[t, s, a] = True\n",
    "\n",
    "                    if a == A_real[t] and s == S_real[t]:\n",
    "                        assert(valid_action[t, s, a])\n",
    "\n",
    "        return P_cf, valid_action\n",
    "\n",
    "    def get_influence_graph(self, G, k, influenced_transitions):\n",
    "        print(f\"Generating influence graph for k={k}\")\n",
    "\n",
    "        def reverse_bfs(G, start_nodes, k):\n",
    "            distance = defaultdict(lambda: float('inf'))\n",
    "            nodes_to_visit = deque([(node, 0) for node in start_nodes])\n",
    "            within_k_steps = set()\n",
    "\n",
    "            while nodes_to_visit:\n",
    "                curr_node, curr_dist = nodes_to_visit.popleft()\n",
    "\n",
    "                if curr_dist <= k:\n",
    "                    within_k_steps.add(curr_node)\n",
    "\n",
    "                    if distance[curr_node] > curr_dist:\n",
    "                        distance[curr_node] = curr_dist\n",
    "\n",
    "                        for predecessor in G.predecessors(curr_node):\n",
    "                            nodes_to_visit.append((predecessor, curr_dist+1))\n",
    "\n",
    "            return within_k_steps\n",
    "        \n",
    "        directly_influenced_nodes = set()\n",
    "\n",
    "        for s in self.states:\n",
    "            for a in self.actions:\n",
    "                for s_prime in self.states:\n",
    "                    for t in range(self.T):\n",
    "                        if influenced_transitions[s, a, s_prime, t]:\n",
    "                            directly_influenced_nodes.add((t+1, s_prime))\n",
    "\n",
    "        reachable_nodes = reverse_bfs(G, directly_influenced_nodes, k)\n",
    "\n",
    "        influence_graph = G.subgraph(reachable_nodes).copy()\n",
    "\n",
    "        # If we are between T-k+1 and T, then we want to add all the paths between these layers, as they are all treated as influenced.\n",
    "        for timestep in range(self.T-k+1, self.T):\n",
    "            for s in self.states:\n",
    "                for a in self.actions:\n",
    "                    for s_prime in self.states:\n",
    "                        if not influence_graph.has_edge((timestep, s), (timestep+1, s_prime), key=a) and G.has_edge((timestep, s), (timestep+1, s_prime), key=a):\n",
    "                            influence_graph.add_edge((timestep, s), (timestep+1, s_prime), key=a)\n",
    "\n",
    "        # Remove nodes with in-degree = 0 or out-degree = 0\n",
    "        unreachable_nodes = {n for n in influence_graph if (influence_graph.in_degree(n) == 0 and n[0]>0) or (influence_graph.out_degree(n) == 0 and n[0] < self.T)}\n",
    "\n",
    "        while len(unreachable_nodes) > 0:\n",
    "            influence_graph.remove_nodes_from(unreachable_nodes)\n",
    "            unreachable_nodes = {n for n in influence_graph if (influence_graph.in_degree(n) == 0 and n[0]>0) or (influence_graph.out_degree(n) == 0 and n[0] < self.T)}\n",
    "\n",
    "        return influence_graph\n",
    "\n",
    "    def prune_mdp(self):\n",
    "        # Initialise a matrix to keep track of which transitions' probabilities\n",
    "        # are directly influenced by the observed trajectory.\n",
    "        influenced_transitions = np.zeros(shape=(len(self.states), len(self.actions), len(self.states), self.T+1))\n",
    "\n",
    "        # Generate the counterfacutal transition probabilities, keeping track\n",
    "        # of which transitionals have been influenced by the observed path.\n",
    "        P_cf, influenced_transitions = self.sampler.run_parallel_sampling_tracking_influenced_transitions(self.mdp_sample, influenced_transitions)\n",
    " \n",
    "        # Build graph using the original MDP transition probabilities.\n",
    "        G = self.build_graph(self.mdp.transition_probabilities, self.T, self.states, self.actions)\n",
    "\n",
    "        # Build the influence graph for each look-ahead k\n",
    "        influence_graphs = []\n",
    "\n",
    "        # Generate graphs for the pruned MDP.\n",
    "        for k in range(1, self.look_ahead_k+1):\n",
    "            influence_graph = self.get_influence_graph(copy.deepcopy(G), k, influenced_transitions)\n",
    "            influence_graphs.append(influence_graph)\n",
    "\n",
    "        cf_transition_probs = []\n",
    "        valid_actions = []\n",
    "\n",
    "        A_real = self.mdp_sample[0, :, 2]\n",
    "        S_real = self.mdp_sample[0, :, 0]\n",
    "\n",
    "        for look_ahead_k in range(1, self.look_ahead_k+1):\n",
    "            new_P_cf, valid_action = self.get_counterfactual_transition_probabilities(copy.deepcopy(P_cf), G, influence_graphs[look_ahead_k-1], self.states, self.actions, A_real, S_real, self.T, look_ahead_k)\n",
    "            cf_transition_probs.append(new_P_cf)\n",
    "            valid_actions.append(valid_action)\n",
    "\n",
    "        # Generate graphs for the pruned counterfactual MDP.\n",
    "        cf_graphs = []\n",
    "\n",
    "        for k in range(1, self.look_ahead_k+1):\n",
    "            G = self.build_cf_graph(cf_transition_probs[k-1], self.T, self.states, self.actions, k)\n",
    "            cf_graphs.append(G)\n",
    "\n",
    "        return cf_transition_probs, valid_actions, cf_graphs\n",
    "    \n",
    "    def get_optimal_policy(self, max_num_actions_changed, P_cf, valid_action, new_mdp_G, all_states, all_actions, S_real, A_real, T, rewards_pi):\n",
    "        \"\"\"Find the best policy that deviates from the observed actions at most a fixed number of times.\n",
    "\n",
    "        Backward dynamic programme: h_fun[s, r, c] is the best expected return from\n",
    "        state s with r steps remaining and at most c deviations from A_real still allowed.\n",
    "\n",
    "        Args:\n",
    "            max_num_actions_changed: maximum number of steps whose action may differ from A_real.\n",
    "            P_cf: transition probabilities, indexed as P_cf[s, a, s_next, t].\n",
    "            valid_action: truth table indexed as valid_action[t, s, a]; only truthy actions are considered.\n",
    "            new_mdp_G: not used by this method; kept so existing callers keep working.\n",
    "            all_states: iterable of integer state indices.\n",
    "            all_actions: iterable of integer action indices.\n",
    "            S_real: not used by this method; kept so existing callers keep working.\n",
    "            A_real: observed action for each time step 0..T-1 (used as an array index).\n",
    "            T: number of decision steps.\n",
    "            rewards_pi: rewards indexed as rewards_pi[t, s, a].\n",
    "\n",
    "        Returns:\n",
    "            None when all_states is empty. Otherwise (pi, h_fun), where pi[s, l, t] is the\n",
    "            action for state s at time t after l deviations have already been made, and\n",
    "            h_fun has shape (len(all_states), T+1, max_num_actions_changed+1).\n",
    "\n",
    "        Raises:\n",
    "            AssertionError: if a valid action's next-state distribution does not sum to 1\n",
    "                (within floating-point tolerance).\n",
    "        \"\"\"\n",
    "        if len(all_states) == 0:\n",
    "            return None\n",
    "\n",
    "        h_fun = np.zeros((len(all_states), T+1, max_num_actions_changed+1))\n",
    "        pi = np.zeros((len(all_states), max_num_actions_changed+1, T+1), dtype=int)\n",
    "\n",
    "        # No deviations left (c = 0): the observed action is forced at every step.\n",
    "        for r in range(1, T+1):  # r = number of steps remaining, so the time index is T-r\n",
    "            t = T - r\n",
    "            a_obs = A_real[t]\n",
    "            for s in all_states:\n",
    "                h_fun[s, r, 0] = rewards_pi[t, s, a_obs]\n",
    "                for s_p in all_states:\n",
    "                    h_fun[s, r, 0] += P_cf[s, a_obs, s_p, t] * h_fun[s_p, r-1, 0]\n",
    "                # Having already used every allowed change means following the observed action.\n",
    "                pi[s, max_num_actions_changed, t] = a_obs\n",
    "\n",
    "        # c deviations still available: choose the best valid action at each step.\n",
    "        for c in range(1, max_num_actions_changed+1):\n",
    "            for r in range(1, T+1):\n",
    "                t = T - r\n",
    "                a_obs = A_real[t]\n",
    "                for s in all_states:\n",
    "                    best_act = a_obs  # kept when no action is valid in (t, s)\n",
    "                    max_val = -np.inf\n",
    "\n",
    "                    for a in all_actions:\n",
    "                        if not valid_action[t, s, a]:\n",
    "                            continue\n",
    "\n",
    "                        # Compare with a tolerance: an exact == 1.0 test rejects valid\n",
    "                        # distributions whose float sum is off by one rounding step.\n",
    "                        assert np.isclose(np.sum(P_cf[s, a, :, t]), 1.0), \"transition probabilities must sum to 1\"\n",
    "\n",
    "                        # Deviating from the observed action uses up one of the remaining changes.\n",
    "                        c_next = c if a == a_obs else c - 1\n",
    "                        val = rewards_pi[t][s][a]\n",
    "                        for s_p in all_states:\n",
    "                            # Zero-probability successors are skipped so that a -inf value\n",
    "                            # (state with no valid action) cannot turn the sum into nan.\n",
    "                            if P_cf[s, a, s_p, t] != 0:\n",
    "                                val += P_cf[s, a, s_p, t] * h_fun[s_p, r-1, c_next]\n",
    "\n",
    "                        if val > max_val:\n",
    "                            max_val = val\n",
    "                            best_act = a\n",
    "\n",
    "                    h_fun[s, r, c] = max_val\n",
    "                    # c changes remaining corresponds to (max - c) changes already made.\n",
    "                    pi[s, max_num_actions_changed-c, t] = best_act\n",
    "\n",
    "        return pi, h_fun\n",
    "\n",
    "    def generate_policies(self, cf_transition_probs, valid_actions, cf_graphs):\n",
    "        \"\"\"Solve each pruned counterfactual MDP for each allowed number of changed actions.\n",
    "\n",
    "        Returns:\n",
    "            (policies, new_all_rewards, h_funs). policies[i][j] and h_funs[i][j] are the two\n",
    "            outputs of get_optimal_policy for look-ahead i+1 with at most j+1 changed actions.\n",
    "            new_all_rewards copies self.rewards_pi[s, a] into every time step t < self.T.\n",
    "        \"\"\"\n",
    "        budgets = range(1, self.look_ahead_k+1)\n",
    "        observed_states = self.mdp_sample[0, :, 0]\n",
    "        observed_actions = self.mdp_sample[0, :, 2]\n",
    "\n",
    "        # Time-expanded reward table: the same (s, a) reward is written at every time step.\n",
    "        new_all_rewards = np.zeros((self.T, len(self.states), len(self.actions)))\n",
    "        for s in self.states:\n",
    "            for a in self.actions:\n",
    "                new_all_rewards[:, s, a] = self.rewards_pi[s, a]\n",
    "\n",
    "        policies, h_funs = [], []\n",
    "        for look_ahead_k in budgets:\n",
    "            print(f\"Estimating policy with k={look_ahead_k}\")\n",
    "            idx = look_ahead_k - 1\n",
    "\n",
    "            # One (pi, h_fun) pair per change budget, all on the same pruned MDP.\n",
    "            solved = [\n",
    "                self.get_optimal_policy(\n",
    "                    n_changes,\n",
    "                    cf_transition_probs[idx],\n",
    "                    valid_actions[idx],\n",
    "                    cf_graphs[idx],\n",
    "                    self.states,\n",
    "                    self.actions,\n",
    "                    observed_states,\n",
    "                    observed_actions,\n",
    "                    self.T,\n",
    "                    new_all_rewards\n",
    "                )\n",
    "                for n_changes in budgets\n",
    "            ]\n",
    "\n",
    "            policies.append([pi for pi, _ in solved])\n",
    "            h_funs.append([h_fun for _, h_fun in solved])\n",
    "\n",
    "        return policies, new_all_rewards, h_funs\n",
    "\n",
    "\n",
    "    def generate_random_trajectory(self, MDP_samp, P_cf, pi, s_0, A_real, all_states, rewards_pi, T, rng=None):\n",
    "        \"\"\"Sample one counterfactual trajectory by following policy pi under P_cf.\n",
    "\n",
    "        Args:\n",
    "            MDP_samp: observed sample array; only its shape is used, to size the output.\n",
    "            P_cf: P_cf[s, a, :, t] is the next-state distribution for (s, a) at time t.\n",
    "            pi: pi[s, l, t] is the action for state s at time t after l changes so far.\n",
    "            s_0: initial state of the trajectory.\n",
    "            A_real: observed action for each time step; used to count changes.\n",
    "            all_states: candidate next states, in the order used by axis 2 of P_cf.\n",
    "            rewards_pi: rewards indexed as rewards_pi[t, s, a].\n",
    "            T: number of steps to simulate.\n",
    "            rng: optional numpy random Generator. Pass a seeded one for reproducible\n",
    "                trajectories; a fresh unseeded Generator is created when omitted.\n",
    "\n",
    "        Returns:\n",
    "            Float array of shape (MDP_samp.shape[0], T, MDP_samp.shape[2]). Only row 0 is\n",
    "            filled; each step holds (state, next_state, action, reward).\n",
    "        \"\"\"\n",
    "        if rng is None:\n",
    "            rng = np.random.default_rng()\n",
    "\n",
    "        n_obs = MDP_samp.shape[0]\n",
    "        n_fields = MDP_samp.shape[2]\n",
    "        CF_trajectory = np.zeros((n_obs, T, n_fields))\n",
    "\n",
    "        s = np.zeros(T+1, dtype=int)\n",
    "        s[0] = s_0   # Same initial state as the observed path\n",
    "        l = np.zeros(T+1, dtype=int)   # l[t] = number of changed actions before step t\n",
    "        a = np.zeros(T, dtype=int)\n",
    "\n",
    "        for t in range(T):\n",
    "            # Pick the action prescribed by the policy for the current change count\n",
    "            a[t] = pi[s[t], l[t], t]\n",
    "\n",
    "            # Sample the next state from the states passed in by the caller\n",
    "            s[t+1] = rng.choice(a=all_states, size=1, p=P_cf[s[t], a[t], :, t])[0]\n",
    "\n",
    "            # A step counts as a change when the action differs from the observed one\n",
    "            l[t+1] = l[t] + (1 if a[t] != A_real[t] else 0)\n",
    "\n",
    "            CF_trajectory[0, t, :] = np.array([s[t], s[t+1], a[t], rewards_pi[t, s[t], a[t]]])\n",
    "\n",
    "        return CF_trajectory\n",
    "\n",
    "    def generate_cf_trajectories(self, cf_transition_probs, policies, new_all_rewards, n_samples=1000):\n",
    "        \"\"\"Estimate the mean final-step reward of counterfactual rollouts.\n",
    "\n",
    "        For every (look-ahead, change budget) pair, n_samples trajectories are drawn with\n",
    "        generate_random_trajectory and the reward of the last step is averaged.\n",
    "\n",
    "        Args:\n",
    "            cf_transition_probs: one transition array per look-ahead value.\n",
    "            policies: policies[i][j] is the policy for look-ahead i+1 and budget j+1.\n",
    "            new_all_rewards: rewards indexed as [t, s, a], passed through to the rollouts.\n",
    "            n_samples: number of trajectories drawn per (look-ahead, budget) pair.\n",
    "\n",
    "        Returns:\n",
    "            (mean_obs, mean_cf, k_vals). Both arrays have shape (look_ahead_k, look_ahead_k):\n",
    "            mean_obs holds the observed path's last-step reward in every cell, mean_cf the\n",
    "            sampled average. k_vals is range(1, look_ahead_k+1).\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if n_samples is smaller than 1.\n",
    "        \"\"\"\n",
    "        if n_samples < 1:\n",
    "            raise ValueError(\"n_samples must be at least 1\")\n",
    "\n",
    "        print(\"Generating CF trajectories\")\n",
    "        k_vals = range(1, self.look_ahead_k+1)\n",
    "        A_real = self.mdp_sample[0, :, 2]\n",
    "        s_0 = self.mdp_sample[0, 0, 0]\n",
    "        last_step = self.T - 1\n",
    "\n",
    "        # The observed last-step reward does not depend on the sample or on the\n",
    "        # (look-ahead, budget) pair, so it is filled in once instead of averaged.\n",
    "        mean_obs = np.full((self.look_ahead_k, self.look_ahead_k), self.mdp_sample[0, last_step, 3], dtype=float)\n",
    "\n",
    "        all_cf = np.zeros((n_samples, self.look_ahead_k, self.look_ahead_k))\n",
    "\n",
    "        for sample in range(n_samples):\n",
    "            for look_ahead_k in k_vals:\n",
    "                for max_num_actions_changed in k_vals:\n",
    "                    CF_trajectory = self.generate_random_trajectory(\n",
    "                        self.mdp_sample,\n",
    "                        cf_transition_probs[look_ahead_k-1],\n",
    "                        policies[look_ahead_k-1][max_num_actions_changed-1],\n",
    "                        s_0,\n",
    "                        A_real,\n",
    "                        self.states,\n",
    "                        new_all_rewards,\n",
    "                        self.T\n",
    "                    )\n",
    "\n",
    "                    # Immediate reward of the counterfactual path at its last step\n",
    "                    all_cf[sample, look_ahead_k-1, max_num_actions_changed-1] = CF_trajectory[0, last_step, 3]\n",
    "\n",
    "        mean_cf = all_cf.mean(axis=0)\n",
    "\n",
    "        return mean_obs, mean_cf, k_vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the pruner, then run the pipeline: prune -> solve policies -> sample rollouts.\n",
    "influence_pruner = InfluenceMDPPruner(createMDP, look_ahead_k=12)\n",
    "\n",
    "# One pruned counterfactual MDP (transitions, valid-action table, graph) per look-ahead value.\n",
    "cf_transition_probs, valid_actions, cf_graphs = influence_pruner.prune_mdp()\n",
    "\n",
    "# Optimal policies and value tables for every (look-ahead, change budget) pair.\n",
    "policies, new_all_rewards, h_funs = influence_pruner.generate_policies(\n",
    "    cf_transition_probs=cf_transition_probs,\n",
    "    valid_actions=valid_actions,\n",
    "    cf_graphs=cf_graphs,\n",
    ")\n",
    "\n",
    "# Mean last-step rewards of the observed path and of the sampled counterfactual paths.\n",
    "mean_obs, mean_cf, k_vals = influence_pruner.generate_cf_trajectories(\n",
    "    cf_transition_probs=cf_transition_probs,\n",
    "    policies=policies,\n",
    "    new_all_rewards=new_all_rewards,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initial state of the observed trajectory. The values below are read at this state\n",
    "# rather than at a hard-coded state 0, so the plot stays correct for any start state.\n",
    "s_0 = int(influence_pruner.mdp_sample[0, 0, 0])\n",
    "\n",
    "values = []        # values[i][j]: V(s_0) for look-ahead i+1 with at most j+1 changed actions\n",
    "obs_values = None  # V(s_0) with no changes allowed, taken from the last look-ahead\n",
    "\n",
    "for look_ahead_k in k_vals:\n",
    "    h_funs_k = h_funs[look_ahead_k-1]\n",
    "    # Index -1 on the second axis selects the full horizon (all steps remaining).\n",
    "    values.append([h_funs_k[m-1][s_0, -1, m] for m in k_vals])\n",
    "    obs_values = [h_funs_k[m-1][s_0, -1, 0] for m in k_vals]\n",
    "\n",
    "for row in values:\n",
    "    print(row)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(5, 3))\n",
    "\n",
    "ax.set_xlabel('Maximum Number of Actions Changed', fontsize=14)\n",
    "ax.set_ylabel('V(s0)', fontsize=14)\n",
    "ax.grid(which='both')\n",
    "\n",
    "ax.scatter(k_vals, obs_values, color='lightpink', label='Observed reward', marker=\"o\", s=100)\n",
    "\n",
    "# Only the smallest and the largest look-ahead are drawn.\n",
    "ax.scatter(k_vals, values[0], color='blue', label='CF reward', marker=\"x\", s=50)\n",
    "ax.scatter(k_vals, values[-1], color='green', label='CF reward', marker=\"d\", s=50)\n",
    "\n",
    "# The explicit label list overrides the per-series label= values above.\n",
    "ax.legend([\"Observed Path\", \"K=1 to K=6\", \"K=7 to K=T+1\"], loc=0, frameon=True, fontsize=12)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean last-step reward of the counterfactual rollouts, one series per look-ahead value.\n",
    "fig, ax = plt.subplots(figsize=(10, 10))\n",
    "\n",
    "ax.set_xlabel('Maximum Number of Actions Changed')\n",
    "ax.set_ylabel('Final State Reward')\n",
    "ax.grid(which='both')\n",
    "\n",
    "ax.scatter(k_vals, mean_obs[0], color='lightpink', label='Observed Path', marker=\"o\", s=100)\n",
    "colors = ['orange', 'red', 'aqua', 'yellow', 'deeppink', 'darkblue', 'darkviolet', 'silver', 'teal', 'blue']\n",
    "\n",
    "# Every look-ahead except the largest gets its own colour and its own legend entry.\n",
    "# Series and labels are derived from k_vals so they cannot drift out of step, and the\n",
    "# colour list is cycled so any number of look-ahead values can be drawn.\n",
    "for row, look_ahead_k in enumerate(list(k_vals)[:-1]):\n",
    "    ax.scatter(k_vals, mean_cf[row], color=colors[row % len(colors)], label=f\"Look-Ahead K={look_ahead_k}\", marker=\"d\", s=50)\n",
    "\n",
    "# The largest look-ahead is drawn last, in black, and labelled as the unbounded case.\n",
    "ax.scatter(k_vals, mean_cf[-1], color='black', label=\"Look-Ahead K=∞\", marker=\"d\", s=30)\n",
    "\n",
    "ax.legend(loc=0, frameon=True)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
