#!/usr/bin/env python
# coding: utf-8

# ### README
# This code is the main file for Experiment 2 of the paper.
# 
# Ensure the current working directory contains an empty subfolder called `output_for_plot`. Results are stored there as `.pickle` files, which `experiment2_plot.py` then reads to plot the graphs.

# In[1]:


import numpy as np
import pprint
import matplotlib.pyplot as plt
from itertools import product
from random import shuffle 
import random
import statistics
import time
import pickle


# In[2]:


from config import params


# In[3]:


# Variable domains, read once from the experiment configuration (config.py).
_var_ranges = params['var_ranges']
range_C0 = _var_ranges['range_C0']  # C0 takes values 0 .. range_C0-1
range_C1 = _var_ranges['range_C1']  # C1 takes values 0 .. range_C1-1
range_x = _var_ranges['range_x']    # X takes values 0 .. range_x-1
domain_y = _var_ranges['domain_y']  # the values Y can take
range_y = len(domain_y)             # number of distinct Y values


# In[4]:


class CausalGraph:
    ''' Directed causal graph stored as an adjacency list.

    ``self.g`` maps each source node to the list of its destination nodes.
    Despite their names, ``add_node``/``add_nodes``/``remove_node`` operate on
    directed edges (source, destination).
    '''

    def __init__(self):
        print('Creating empty graph..')
        self.g = dict()  # source node -> list of destination nodes

    def add_node(self, a, b):
        ''' Add the directed edge a -> b (a repeated edge is stored again). '''
        self.g.setdefault(a, []).append(b)

    def add_nodes(self, list_ab):
        ''' Add every (source, destination) edge from ``list_ab``. '''
        for src, dst in list_ab:
            self.add_node(src, dst)

    def remove_node(self, a, b):
        ''' Remove one occurrence of edge a -> b; drop ``a`` once it has no outgoing edges.

        Prints a message (and changes nothing) when the source or the edge is missing.
        '''
        targets = self.g.get(a)
        if targets is None:
            print('Source node not found')
            return
        if b not in targets:
            print('Destination node not found')
            return
        targets.remove(b)
        if not targets:
            del self.g[a]

    def get_nodes(self):
        ''' Return every node appearing as a source or destination (no duplicates, arbitrary order). '''
        seen = []
        for src, dsts in self.g.items():
            seen.append(src)
            seen.extend(dsts)
        return list(set(seen))

    def get_edges(self):
        ''' Return every distinct (source, destination) edge (arbitrary order). '''
        return list({(src, dst) for src, dsts in self.g.items() for dst in dsts})

    def get_parents(self, dest):
        ''' Return the distinct sources that have an edge into ``dest`` (arbitrary order). '''
        return list({src for src, dst in self.get_edges() if dst == dest})

    def print_graph(self):
        ''' Print each source node followed by its destination list. '''
        for src, dsts in self.g.items():
            print(src, '-->', dsts)
    


# In[ ]:





# In[5]:


def get_topological_order(graph):
    ''' Return a topological ordering of the experiment's nodes.

    The order is hard-coded for the graph C1 -> C0, C0 -> X, C0 -> Y, X -> Y;
    ``graph`` is accepted for API compatibility but is not inspected.
    A new list is returned on every call.
    '''
    # C1 is the only root; C0 depends on C1; Y depends on both X and C0.
    order = ['C1', 'C0']
    order.extend(['X', 'Y'])
    return order


# In[6]:


class SCM_Simple:
    ''' Structural causal model over the graph C1 -> C0, C0 -> X, C0 -> Y, X -> Y.

    P(C1) and the cutoff defining C0 = [C1 >= cutoff] are drawn at random in
    ``__init__``; P(Y | X, C0) is read from ``params['graph']['y_distr']``.
    '''

    def __init__(self, graph):
        self.graph = graph

        # Random parameters, as discussed in the paper: the C1 values below
        # ``cutoff`` share 20% of the probability mass, the remaining values 80%.
        self.cutoff = random.randint(1, int(range_C1/2))
        prob_weights = [random.randint(1, 11) for _ in range(0, self.cutoff)]
        total = sum(prob_weights)
        self.ctar_p1 = [w/total*0.2 for w in prob_weights]
        prob_weights = [random.randint(1, 11) for _ in range(self.cutoff, range_C1)]
        total = sum(prob_weights)
        self.ctar_p2 = [w/total*0.8 for w in prob_weights]
        self.ctar_p = self.ctar_p1 + self.ctar_p2  # P(C1 = i) for i in 0 .. range_C1-1

    def print_SCM(self):
        ''' Print the underlying causal graph to the console. '''
        print('Printing graph...')
        self.graph.print_graph()

    def get_range(self, v):
        ''' Return the list of values variable ``v`` can take ('Not found!' for an unknown name). '''
        return {
         'C1':list(np.arange(range_C1)), 
         'C0':list(np.arange(range_C0)),
         'X':list(np.arange(range_x)),
         'Y': domain_y
        }.get(v, 'Not found!')

    def get_index_of_y(self, y):
        ''' Return the index of value ``y`` in domain_y. '''
        return dict(zip(self.get_range('Y'), range(0,len(self.get_range('Y')))))[y]

    def sample_node(self, v, joint_sample, intervention):
        ''' Sample a value for variable ``v``.

        If ``v`` is intervened on, the intervention value is returned; otherwise
        the value is drawn from v's CPD given the parent values already present
        in ``joint_sample``.

        Raises ValueError for an unknown variable name.
        '''
        if v in intervention:
            return intervention[v]
        if v == 'C1':  # P(C1) as per the random parameters drawn in __init__
            return np.random.choice(np.arange(0, range_C1), p=self.ctar_p)
        if v == 'C0':  # Deterministic CPD P(C0|C1): 0 if C1 < cutoff, else 1
            return 0 if joint_sample['C1'] < self.cutoff else 1
        if v == 'X':  # Uniform over X when not intervened on
            return np.random.randint(range_x)
        if v == 'Y':  # P(Y|X, C0) is as per the CPD in config.py
            probs = params['graph']['y_distr'][joint_sample['C0']][joint_sample['X']]
            return domain_y[np.random.choice(np.arange(0, range_y), p=probs)]
        raise ValueError(f'Unknown variable: {v}')

    def get_sample(self, intervention=None):
        ''' Return a joint sample (dict variable -> value) under the given intervention. '''
        if intervention is None:
            intervention = {}
        joint_sample = dict()

        for v in get_topological_order(self.graph):
            if len(self.graph.get_parents(v)) == 0:
                joint_sample[v] = self.sample_node(v, {}, intervention)
            else:
                joint_sample[v] = self.sample_node(v, joint_sample, intervention)

        return joint_sample

    def get_sample_conditional(self, context, intervention=None, max_tries=10000):
        ''' Return a targeted-intervention sample whose variables match ``context``.

        Uses rejection sampling: joint samples are drawn under ``intervention``
        until one matches every (variable, value) pair in ``context``. After
        ``max_tries`` + 1 failed draws a message is printed and None is returned.
        '''
        if intervention is None:
            intervention = {}
        condition_vars = list(context.keys())

        attempts = 0
        while True:
            sample = self.get_sample(intervention)
            if all(sample[c] == context[c] for c in condition_vars):
                return sample
            if attempts == max_tries:
                print('Unable to do sample conditional with context=',context,' .. stopping..')
                return None
            attempts += 1

    def get_optimal(self, ctar):
        ''' Return the optimal action X for the given C_tar value.

        Hard-coded: 2 when ctar < cutoff (i.e. C0 = 0), else 3.
        NOTE(review): presumably these are the best arms under
        params['graph']['y_distr'] in config.py -- confirm.
        '''
        if ctar < self.cutoff:
            return 2
        else:
            return 3


# In[ ]:





# In[7]:


def get_conditionals(scm):
    ''' Return a dict mapping each node to the ordered list of its parents.

    The structure is hard-coded for this experiment; ``scm`` is not inspected.
    The parent order fixes how the belief tables are keyed (e.g. Y's beliefs
    are keyed by (x, c0) tuples).
    '''
    return {
        'C1': [],
        'C0': ['C1'],
        'X': ['C0'],
        'Y': ['X', 'C0'],
    }


# #### Create environment (SCM) instance

# In[8]:


# Build the causal graph C1 -> C0, C0 -> X, C0 -> Y, X -> Y
g = CausalGraph()
g.add_nodes((
    ('C1', 'C0'),
    ('C0', 'X'),
    ('C0', 'Y'),
    ('X', 'Y'),
))


# In[9]:


# Create the SCM instance (its parameters are drawn at random) from the causal graph
scm = SCM_Simple(g)


# In[10]:


scm.print_SCM()


# # Training

# In[11]:


def calc_Ey_sampled_params(beliefs, conditionals):
    ''' Thompson-sampling estimate of E[Y | do(x), C1] for every (x, c1).

    beliefs      : dict node -> {parent-value tuple -> Dirichlet counts}
    conditionals : dict node -> ordered list of parents (see get_conditionals)

    One probability vector is drawn from Dirichlet(counts) for every belief
    table entry, in dict order. The expectation is then computed from these
    draws. Returns a dict keyed by (x, c1) tuples.
    '''
    probs = {}
    for V in beliefs.keys():
        probs[V] = {}
        for v_prime in beliefs[V].keys():
            probs[V][v_prime] = np.random.dirichlet(beliefs[V][v_prime])

    return _calc_Ey_from_probs(probs, conditionals)


# In[12]:


def calc_Ey_exp_params(beliefs, conditionals):
    ''' Posterior-mean estimate of E[Y | do(x), C1] for every (x, c1).

    Each Dirichlet count vector in ``beliefs`` is normalised to its mean
    probability vector before computing the expectation. Returns a dict
    keyed by (x, c1) tuples.
    '''
    probs = {}
    for V, table in beliefs.items():
        probs[V] = {}
        for v_prime, counts in table.items():
            denom = np.sum(counts)
            probs[V][v_prime] = [b/denom for b in counts]

    return _calc_Ey_from_probs(probs, conditionals)


def _calc_Ey_from_probs(probs, conditionals):
    ''' Compute E[Y | do(x), C_tar] for every (x, c_tar) from explicit CPD tables.

    probs        : dict node -> {parent-value tuple -> probability vector over the node's values}
    conditionals : dict node -> ordered list of parents

    The factorisation is read from the module-level ``formula`` /
    ``formula_indices`` dicts and variable ranges from the module-level ``scm``.
    With the current ``formula`` this evaluates
        E[Y | do(x), c1] = sum_{c0} sum_y  y * P(y | x, c0) * P(c0 | c1).
    '''
    Ctar = ['C1']
    Ctar_indices = {v: i for i, v in enumerate(Ctar)}

    ranges = [scm.get_range('X')] + [scm.get_range(v) for v in Ctar]
    cartesian_x_ctar = list(product(*ranges))
    E_y = {key: 0 for key in cartesian_x_ctar}

    ranges = [scm.get_range(v) for v in formula['sum_over']]
    cartesian_c_sumover = list(product(*ranges))

    # (index, value) for every y; hoisted because get_index_of_y rebuilds a dict on each call
    y_terms = [(scm.get_index_of_y(y), y) for y in scm.get_range('Y')]

    for x_ctar in cartesian_x_ctar:
        for c in cartesian_c_sumover:
            for y_index, y in y_terms:
                temp = 1
                for comp in formula['numerator']:
                    # Values of comp's parents, in the order used to key its belief table
                    parent_values = []
                    for p in conditionals[comp]:
                        if p == 'X':
                            parent_values.append(x_ctar[0])
                        elif p in Ctar:
                            parent_values.append(x_ctar[Ctar_indices[p] + 1])
                        else:
                            parent_values.append(c[formula_indices['sum_over'][p]])

                    # Value of comp itself: y for Y, otherwise its summed-over value
                    # (numerator positions after Y are offset by one relative to c)
                    if comp == 'Y':
                        value_index = y_index
                    else:
                        value_index = c[formula_indices['numerator'][comp] - 1]

                    temp = temp * probs[comp][tuple(parent_values)][value_index]

                E_y[x_ctar] += y * temp

    return E_y
    


# ### Algorithm A: Non-causal TS, given context

# In[13]:


def non_causal_TS_given_context(scm, T):
    '''Implementation of Algorithm A baseline: non-causal contextual Thompson sampling.

    One independent Dirichlet belief over Y's domain is kept per (x, c1) cell,
    stored flat at index x*range_C1 + c1. Each round a categorical distribution
    is drawn for every cell, C1 is observed from the environment, and the x with
    the largest sampled E[Y] for that c1 is pulled (ties go to the smallest x).

    Returns (rewards, last chosen x per c1, sampled E[Y] list of the final round,
    Dirichlet counts).
    '''
    n_cells = range_x * range_C1
    # Dirichlet pseudo-counts with a uniform prior of 1 per Y value
    belief_y_dox_c = [[1] * range_y for _ in range(n_cells)]

    rewards = []
    argmax_x_c = [0] * range_C1
    y_domain = scm.get_range('Y')

    for _round in range(T):
        # Thompson step: one distribution per cell, drawn in cell order
        sampled_probs = [np.random.dirichlet(belief_y_dox_c[cell]) for cell in range(n_cells)]

        # Expected reward of each cell under the sampled distributions
        E_y = []
        for probs in sampled_probs:
            expectation = 0
            for i_y, y in enumerate(y_domain):
                expectation += probs[i_y] * y
            E_y.append(expectation)

        # Observe the context from the environment
        c = scm.get_sample()['C1']

        # Greedy action for this context; max() keeps the first (smallest) x on ties
        best_x = max(range(range_x), key=lambda x: E_y[x * range_C1 + c], default=0)

        # Pull arm x in context c and record the reward
        sample = scm.get_sample_conditional(context={'C1': c}, intervention={'X': best_x})
        y_obs = sample['Y']
        rewards.append(y_obs)

        belief_y_dox_c[best_x * range_C1 + c][scm.get_index_of_y(y_obs)] += 1
        argmax_x_c[c] = best_x

    return rewards, argmax_x_c, E_y, belief_y_dox_c


# ### Algorithm A2: Non-causal uniform exploration, given context

# In[14]:


def non_causal_UE_given_context(scm, T):
    '''Implementation of Algorithm A2 baseline: non-causal uniform exploration with observed context.

    Each round C1 is observed from the environment and X is chosen round-robin
    (t % range_x), independently of the context. The Dirichlet counts of the
    (x, c1) cell, stored flat at index x*range_C1 + c1, are updated with the
    observed Y.

    Returns (rewards, last x tried per c1, [], Dirichlet counts).
    '''
    # Dirichlet pseudo-counts over Y for every (x, c1) cell (uniform prior).
    # No per-round posterior sampling is needed: the action never depends on the beliefs.
    belief_y_dox_c = [[1 for _ in range(range_y)] for __ in range(range_x * range_C1)]

    rewards = []
    argmax_x_c = [0]*range_C1

    for t in range(T):
        # Observe the context from the environment
        c = scm.get_sample()['C1']

        # Uniformly explore actions, irrespective of context
        x = t % range_x

        # Pull arm x in context c and record the reward
        sample = scm.get_sample_conditional(context={'C1': c}, intervention={'X': x})
        rewards.append(sample['Y'])

        belief_y_dox_c[x*range_C1 + c][scm.get_index_of_y(sample['Y'])] += 1
        argmax_x_c[c] = x

    return rewards, argmax_x_c, [], belief_y_dox_c


# ### Algorithm A3: Non-causal uniform exploration, chosen context

# In[15]:


def non_causal_UE_chosen_context(scm, T):
    '''Implementation of Algorithm A3 baseline: non-causal uniform exploration over targeted interventions.

    All (x, c1) pairs are shuffled once and then visited round-robin. Each round
    performs do(X=x) conditioned on C1=c1 and updates that cell's Dirichlet
    counts, stored flat at index x*range_C1 + c1.

    Returns (rewards, last x tried per c1, [], Dirichlet counts).
    '''
    belief_y_dox_c = [[1] * range_y for _ in range(range_x * range_C1)]
    rewards = []
    argmax_x_c = [0] * range_C1

    # Fixed random visiting order over every targeted intervention (x, c1)
    schedule = list(product(scm.get_range('X'), scm.get_range('C1')))
    random.shuffle(schedule)
    n_pairs = len(schedule)

    for t in range(T):
        x, c1 = schedule[t % n_pairs]

        sample = scm.get_sample_conditional(context={'C1': c1}, intervention={'X': x})
        y_obs = sample['Y']
        rewards.append(y_obs)

        belief_y_dox_c[x * range_C1 + c1][scm.get_index_of_y(y_obs)] += 1
        argmax_x_c[c1] = x

    return rewards, argmax_x_c, [], belief_y_dox_c


# ### Algorithm B - Causal TS, context from environment

# In[16]:


def causal_TS_given_context(scm, T):
    '''Implementation of Algorithm B: causal Thompson sampling, context from the environment.

    Dirichlet beliefs are kept for every CPD of the graph (see get_conditionals).
    Each round a set of CPDs is drawn from the beliefs and C1 is observed. The x
    maximising the sampled E[Y | do(x), c1] is then pulled; X candidates are
    visited in shuffled order and ties go to the first one visited. Finally the
    beliefs of Y, C0 and C1 are updated from the resulting sample.

    Returns (rewards, last chosen x per c1, [], beliefs).
    '''
    conditionals = get_conditionals(scm)
    # Uniform Dirichlet prior: one pseudo-count per value, for every parent assignment
    beliefs = {
        node: {
            parent_vals: [1] * len(scm.get_range(node))
            for parent_vals in product(*[scm.get_range(p) for p in parents])
        }
        for node, parents in conditionals.items()
    }

    rewards = []
    argmax_x_ctar = [0] * range_C1

    for _round in range(T):
        E_y = calc_Ey_sampled_params(beliefs, conditionals)

        # Observe the context from the environment
        observed_c1 = scm.get_sample()['C1']

        # Greedy x for this context, scanning candidates in random order
        candidates = list(range(range_x))
        shuffle(candidates)
        best_x = candidates[0]
        for cand in candidates[1:]:
            if E_y[(cand, observed_c1)] > E_y[(best_x, observed_c1)]:
                best_x = cand

        # Pull arm x in the observed context
        sample = scm.get_sample_conditional(context={'C1': observed_c1}, intervention={'X': best_x})
        y_obs = sample['Y']
        rewards.append(y_obs)
        c0, c1 = sample['C0'], sample['C1']

        beliefs['Y'][(best_x, c0)][scm.get_index_of_y(y_obs)] += 1
        beliefs['C0'][(c1,)][c0] += 1
        beliefs['C1'][tuple()][c1] += 1

        argmax_x_ctar[c1] = best_x

    return rewards, argmax_x_ctar, [], beliefs


# ### Algorithm C -  Causal UE given context

# In[17]:




def causal_UE_given_context(scm, T):
    '''Implementation of Algorithm C: causal uniform exploration, context from the environment.

    X values are shuffled once. Each round C1 is observed and the next X in
    that shuffled list is pulled; a separate position is tracked for each c1
    value. The Dirichlet beliefs of Y, C0 and C1 are updated from the sample.

    Returns (rewards, last x tried per c1, [], beliefs).
    '''
    conditionals = get_conditionals(scm)
    # Uniform Dirichlet prior: one pseudo-count per value, for every parent assignment
    beliefs = {
        node: {
            parent_vals: [1] * len(scm.get_range(node))
            for parent_vals in product(*[scm.get_range(p) for p in parents])
        }
        for node, parents in conditionals.items()
    }

    rewards = []
    argmax_x_ctar = [0] * range_C1

    visits = [0] * range_C1        # how many times each c1 has been seen so far
    x_order = list(range(range_x))
    random.shuffle(x_order)

    for _round in range(T):
        # Observe the context from the environment
        observed_c1 = scm.get_sample()['C1']

        # Next x in the shuffled list for this context
        chosen_x = x_order[visits[observed_c1] % range_x]
        visits[observed_c1] += 1

        sample = scm.get_sample_conditional(context={'C1': observed_c1}, intervention={'X': chosen_x})
        y_obs = sample['Y']
        rewards.append(y_obs)
        c0, c1 = sample['C0'], sample['C1']

        beliefs['Y'][(chosen_x, c0)][scm.get_index_of_y(y_obs)] += 1
        beliefs['C0'][(c1,)][c0] += 1
        beliefs['C1'][tuple()][c1] += 1

        argmax_x_ctar[c1] = chosen_x

    return rewards, argmax_x_ctar, [], beliefs


# ### Algorithm D - Causal TS over space of targeted interventions

# In[18]:


def causal_TS_chosen_context(scm, T):
    '''Implementation of Algorithm D'''
    
    # Initialize beliefs about CPDs
    beliefs = {}
    conditionals = get_conditionals(scm)
    for V, v_prime in conditionals.items():
        ranges_v_prime = [scm.get_range(_) for _ in v_prime]
        cartesian_v_prime = list(product(*ranges_v_prime))
        beliefs[V] = {__ : [1 for _ in scm.get_range(V)] for __ in cartesian_v_prime}
    
    # Initialize rewards array
    rewards = []
    
    # Initialize array that holds argmax over x for any ctar
    argmax_x_ctar = [0]*range_C1

    # Main part of the algorithm
    for t in range(T):
        # Estimate E[y] based on current beliefs
        E_y = calc_Ey_sampled_params(beliefs, conditionals)

    
        # Argmax over x, c1
        x_values = list(range(range_x))
        shuffle(x_values)
        C1_values = list(range(range_C1))
        shuffle(C1_values)
        argmax_x = x_values[0]
        chosen_c1 = C1_values[0]
        for c1 in C1_values[1:]:
            for x in x_values[1:]:
                if E_y[(x,c1)] > E_y[(argmax_x,chosen_c1)]:
                    argmax_x = x
                    chosen_c1 = c1
        
        # Do targeted intervention x under c1 and obtain sample
        sample = scm.get_sample_conditional(context={'C1':chosen_c1}, intervention={'X':argmax_x})
        rewards += [sample['Y']]
        c0 = sample['C0']
        
        # Update beliefs
        beliefs['Y'][(argmax_x, c0)][scm.get_index_of_y(sample['Y'])] += 1
        beliefs['C0'][(chosen_c1,)][c0] += 1
        
        # Update saved argmax list
        argmax_x_ctar[chosen_c1] = argmax_x
        

    return rewards, argmax_x_ctar, [], beliefs


# ### Algorithm E - Causal pure-exploration over space of targeted interventions

# In[19]:




def causal_PureExp_chosen_context(scm, T, context={'C1'}):
    '''Implementation of Algorithm E: causal pure exploration over targeted interventions.

    All (x, c1) pairs are shuffled once and then visited round-robin. Each round
    performs do(X=x) conditioned on C1=c1 and updates the Y and C0 Dirichlet
    beliefs. No reward estimates are needed to choose actions, so none are
    computed during training.

    ``context`` is not used (it is never mutated); it is kept for API compatibility.

    Returns (rewards, last x tried per c1, [], beliefs).
    '''
    # Initialize beliefs about CPDs (uniform Dirichlet prior)
    beliefs = {}
    conditionals = get_conditionals(scm)
    for V, v_prime in conditionals.items():
        ranges_v_prime = [scm.get_range(_) for _ in v_prime]
        cartesian_v_prime = list(product(*ranges_v_prime))
        beliefs[V] = {__ : [1 for _ in scm.get_range(V)] for __ in cartesian_v_prime}

    rewards = []
    argmax_x_ctar = [0]*range_C1

    # Fixed random visiting order over every targeted intervention (x, c1)
    cartesian_x_ctar = list(product(scm.get_range('X'), scm.get_range('C1')))
    random.shuffle(cartesian_x_ctar)

    for t in range(T):
        argmax_x, chosen_c1 = cartesian_x_ctar[t % len(cartesian_x_ctar)]

        # Do targeted intervention x under c1 and obtain sample
        sample = scm.get_sample_conditional(context={'C1':chosen_c1}, intervention={'X':argmax_x})
        rewards += [sample['Y']]
        c0 = sample['C0']

        # Update beliefs. C1 is chosen by the agent, so its belief is left untouched.
        beliefs['Y'][(argmax_x, c0)][scm.get_index_of_y(sample['Y'])] += 1
        beliefs['C0'][(chosen_c1,)][c0] += 1

        argmax_x_ctar[chosen_c1] = argmax_x

    return rewards, argmax_x_ctar, [], beliefs


# ### Algorithm TargInt_e_greedy - Causal TS over space of targeted interventions with epsilon-greedy

# In[20]:


def TargInt_epsilon_greedy(scm, T, epsilon):
    '''Implementation of Algorithm TargInt_e_greedy.

    This is causal Thompson sampling over targeted interventions with epsilon-greedy exploration:
    - Each round a set of CPDs is drawn from the Dirichlet beliefs.
    - The (x, c1) pair with the largest sampled E[Y | do(x), c1] over all pairs is selected.
    - With probability ``epsilon`` that pair is replaced by a uniformly random x and c1.
    - The chosen targeted intervention is performed and the Y and C0 beliefs are updated.

    Returns (rewards, last chosen x per c1, [], beliefs).
    '''
    # Initialize beliefs about CPDs (uniform Dirichlet prior)
    beliefs = {}
    conditionals = get_conditionals(scm)
    for V, v_prime in conditionals.items():
        ranges_v_prime = [scm.get_range(_) for _ in v_prime]
        cartesian_v_prime = list(product(*ranges_v_prime))
        beliefs[V] = {__ : [1 for _ in scm.get_range(V)] for __ in cartesian_v_prime}

    rewards = []
    argmax_x_ctar = [0]*range_C1

    for t in range(T):
        # Estimate E[y] based on current beliefs
        E_y = calc_Ey_sampled_params(beliefs, conditionals)

        # Argmax over every (x, c1) pair, visited in a random order so ties are
        # broken randomly; strict '>' keeps the first maximum visited.
        x_values = list(range(range_x))
        shuffle(x_values)
        C1_values = list(range(range_C1))
        shuffle(C1_values)
        argmax_x, chosen_c1 = x_values[0], C1_values[0]
        for c1 in C1_values:
            for x in x_values:
                if E_y[(x, c1)] > E_y[(argmax_x, chosen_c1)]:
                    argmax_x, chosen_c1 = x, c1

        # Coin toss for epsilon-greedy: with probability epsilon, choose a random targeted intervention
        if np.random.uniform() < epsilon:
            argmax_x = random.choice(x_values)
            chosen_c1 = random.choice(C1_values)

        # Do targeted intervention x under c1 and obtain sample
        sample = scm.get_sample_conditional(context={'C1':chosen_c1}, intervention={'X':argmax_x})
        rewards += [sample['Y']]
        c0 = sample['C0']

        # Update beliefs. C1 is chosen by the agent, so its belief is left untouched.
        beliefs['Y'][(argmax_x, c0)][scm.get_index_of_y(sample['Y'])] += 1
        beliefs['C0'][(chosen_c1,)][c0] += 1

        argmax_x_ctar[chosen_c1] = argmax_x

    return rewards, argmax_x_ctar, [], beliefs


# ### Algorithm Z - Our algorithm

# #### Entropy functions

# In[21]:


def entropy(beliefs):
    ''' Compute Ent() as defined in paper.

    This is the Shannon entropy (natural log) of the mean distribution
    counts / sum(counts) of a Dirichlet count vector. Zero-probability entries
    contribute 0, by the usual 0*log(0) = 0 convention, instead of yielding NaN.
    '''
    denom = np.sum(beliefs)
    probs = [b/denom for b in beliefs]
    return np.sum([-1 * p * np.log(p) for p in probs if p > 0])


def new_entropy(beliefs):
    ''' Compute Ent_new() as defined in paper.

    For each entry i, one pseudo-count is added to entry i and the entropy of
    the resulting vector is computed. The mean of these entropies is returned.
    ``beliefs`` itself is not modified.
    '''
    new_ents = []
    for i in range(len(beliefs)):
        bumped = list(beliefs)
        bumped[i] += 1
        new_ents.append(entropy(bumped))

    return np.mean(new_ents)


# In[22]:


# Dicts that facilitate factorized computation of E[y]

# Factorisation used by calc_Ey_* to compute E[Y | do(x), C1]:
#   E[Y | do(x), c1] = sum_{c0} sum_y  y * P(y | x, c0) * P(c0 | c1)
# 'numerator' lists the CPD factors multiplied together; 'sum_over' lists the
# variables that are marginalised out.
formula = dict(
    numerator=['Y', 'C0'],
    denominator=[],
    sum_over=['C0'],
)

# Position of each variable inside the value tuples built for each formula entry
formula_indices = dict(
    numerator=dict(Y=0, C0=1),
    denominator=dict(),
    sum_over=dict(C0=0),
)


# In[ ]:





# #### Algorithm for training

# In[23]:


def train_Causal_source_v2(scm, T, final_clusters, context, alpha, verbosity='high'):
    '''Implements our proposed algorithm, called Algorithm Z.

    Phase 1 (the first int(T*alpha) rounds):
        - Observe C1 from the environment.
        - Intervene with X chosen round-robin from a shuffled list; the position in the list is tracked per C1 value.
        - Update the Dirichlet beliefs of C1, C0 and Y.
    Phase 2 (the remaining rounds):
        - Score every candidate targeted intervention (x, c1), where c1 ranges over the keys of ``final_clusters``
          (see _targeted_intervention_uncertainty).
        - Perform the argmin intervention.
        - Update the C0 and Y beliefs.

    Parameters
    ----------
    scm : environment exposing get_range / get_sample / get_sample_conditional /
        get_index_of_y (SCM_Simple in this file)
    T : total number of rounds across both phases
    final_clusters : dict whose keys are 1-tuples (c1,); only these c1 values
        are candidates in Phase 2 (the dict's values are only printed)
    context : unused; kept for API compatibility
    alpha : fraction of T spent in Phase 1
    verbosity : 'high' prints the beliefs and clusters after Phase 1

    Returns
    -------
    (rewards, argmin_x_ctar, [], beliefs), where ``rewards`` holds Phase-2
    rewards only and argmin_x_ctar[c1] is the last x intervened on for c1 in Phase 2.
    '''
    # Initialize beliefs about CPDs (uniform Dirichlet prior)
    beliefs = {}
    conditionals = get_conditionals(scm)
    for V, v_prime in conditionals.items():
        ranges_v_prime = [scm.get_range(_) for _ in v_prime]
        cartesian_v_prime = list(product(*ranges_v_prime))
        beliefs[V] = {__ : [1 for _ in scm.get_range(V)] for __ in cartesian_v_prime}

    rewards = []
    argmin_x_ctar = [0]*range_C1  # last x chosen for each c1 in Phase 2

    n_phase1 = int(T*alpha)

    ## Phase 1: context from the environment, round-robin over a shuffled list of X values
    last_x_choice = [0]*range_C1
    x_choices = list(range(range_x))
    random.shuffle(x_choices)

    for t in range(n_phase1):
        # Observe C1, then intervene on the next x of the shuffled list for that C1
        c1 = scm.get_sample()['C1']
        x = x_choices[last_x_choice[c1] % range_x]
        last_x_choice[c1] += 1

        sample = scm.get_sample_conditional(context={'C1': c1}, intervention={'X': x})

        beliefs['C1'][tuple()][sample['C1']] += 1
        beliefs['C0'][(sample['C1'],)][sample['C0']] += 1
        beliefs['Y'][(x, sample['C0'])][scm.get_index_of_y(sample['Y'])] += 1

    if verbosity == 'high':
        print('Starting beliefs..')
        print(beliefs)
        print('Final clusters = ', final_clusters)

    ## Phase 2: uncertainty-driven targeted interventions
    c0_values = scm.get_range('C0')

    for t in range(T - n_phase1):
        candidate_c1s = [c1 for c1, in final_clusters.keys()]

        # Estimates of E[y | do(x), c1] under the posterior-mean CPDs
        E_y = calc_Ey_exp_params(beliefs, conditionals)

        Ent = _targeted_intervention_uncertainty(beliefs, E_y, candidate_c1s, c0_values)

        # Argmin over candidate (x, c1); strict '<' keeps the first minimum visited
        chosen_c1, chosen_x = candidate_c1s[0], 0
        for c1 in candidate_c1s:
            for x in range(range_x):
                if Ent[x*range_C1 + c1] < Ent[chosen_x*range_C1 + chosen_c1]:
                    chosen_c1, chosen_x = c1, x

        # Perform targeted intervention (x, c1) and obtain sample
        sample = scm.get_sample_conditional(context={'C1': chosen_c1}, intervention={'X': chosen_x})
        rewards += [sample['Y']]
        c0 = sample['C0']

        # Update beliefs. C1 is chosen by the agent here, so its belief is left untouched.
        beliefs['C0'][(chosen_c1,)][c0] += 1
        beliefs['Y'][(chosen_x, c0)][scm.get_index_of_y(sample['Y'])] += 1

        argmin_x_ctar[chosen_c1] = chosen_x

    return rewards, argmin_x_ctar, [], beliefs


def _targeted_intervention_uncertainty(beliefs, E_y, candidate_c1s, c0_values):
    ''' Score every candidate targeted intervention (x, c1) for Phase 2 of Algorithm Z.

    For a candidate (x, c1) the score is

        sum_{c1', x'} P(c1') * E_y[(x', c1')]
                      * sum_{c0} P(c0 | c1') * (H_Y(x', c0) + H_C0(c1'))

    where H_Y uses new_entropy() of the Y counts when x' == x and entropy()
    otherwise, and H_C0 uses new_entropy() of the C0 counts when c1' == c1 and
    entropy() otherwise. P(.) are posterior means of the Dirichlet counts in
    ``beliefs``; c1 and c1' range over ``candidate_c1s``.

    Returns a flat list indexed by x*range_C1 + c1 (non-candidate entries stay 0).
    '''
    Ent = [0 for _ in range(range_x * range_C1)]

    # Entropies depend only on the count vector, so memoise them by its tuple
    new_entropies = {}
    current_entropies = {}

    def ent_new(counts):
        key = tuple(counts)
        if key not in new_entropies:
            new_entropies[key] = new_entropy(counts)
        return new_entropies[key]

    def ent_current(counts):
        key = tuple(counts)
        if key not in current_entropies:
            current_entropies[key] = entropy(counts)
        return current_entropies[key]

    c1_counts = beliefs['C1'][tuple()]

    for c1 in candidate_c1s:                 # candidate c1
        for x in range(range_x):             # candidate x
            for c1_prime in candidate_c1s:
                c0_counts = beliefs['C0'][(c1_prime,)]
                # Intervening under c1 adds an observation to P(C0 | c1') only when c1' == c1
                c0_ent = ent_new(c0_counts) if c1 == c1_prime else ent_current(c0_counts)
                prob_ctar = (c1_counts[c1_prime]/sum(c1_counts))

                for x_prime in range(range_x):
                    # Unc(E[y | do(x'), c1'] | x, c1)
                    temp = 0
                    for c0 in c0_values:
                        y_counts = beliefs['Y'][(x_prime, c0)]
                        # ...and adds an observation to P(Y | x', c0) only when x' == x
                        y_ent = ent_new(y_counts) if x == x_prime else ent_current(y_counts)
                        temp += (y_ent + c0_ent) * (c0_counts[c0]/sum(c0_counts))

                    Ent[x*range_C1 + c1] += temp * prob_ctar * E_y[(x_prime, c1_prime)]

    return Ent


# # Evaluation

# ### Function definitions

# In[24]:


def do_evaluate_NC(scm, T, belief_y_dox_c):
    ''' Evaluate a non-causal algorithm by Thompson sampling from its learned beliefs.

        scm            : environment used to draw contexts and rewards
        T              : number of evaluation rounds
        belief_y_dox_c : beliefs learnt by the agent during training; a flat sequence
                         indexed by x*range_C1 + c whose entries are the Dirichlet
                         parameters over the values of Y

        Returns (rewards, argmax_x_c, regrets): the per-round rewards, the last action
        chosen for each context value, and the per-round regrets against the action
        given by scm.get_optimal(c).
    '''
    rewards, regrets = [], []
    argmax_x_c = [0] * range_C1
    n_cells = range_x * range_C1

    for _round in range(T):
        # Thompson step: draw one distribution over Y for every (x, c) cell,
        # visiting cells in index order x*range_C1 + c.
        sampled = [np.random.dirichlet(belief_y_dox_c[cell]) for cell in range(n_cells)]

        # Expected reward E[Y | do(x), c] under the drawn distributions.
        expected = []
        for dist in sampled:
            acc = 0
            for i_y, y in enumerate(scm.get_range('Y')):
                acc += dist[i_y] * y
            expected.append(acc)

        # Context drawn by the environment.
        ctx = scm.get_sample()['C1']

        # Greedy action for this context; max() keeps the first (smallest) x on ties.
        best_x = max(range(range_x), key=lambda x: expected[x * range_C1 + ctx], default=0)

        # Reward of the chosen action, and of the optimal action for the regret.
        taken = scm.get_sample_conditional(context={'C1':ctx}, intervention={'X':best_x})
        rewards.append(taken['Y'])
        oracle = scm.get_sample_conditional(context={'C1':ctx}, intervention={'X':scm.get_optimal(ctx)})
        regrets.append(oracle['Y'] - taken['Y'])

        argmax_x_c[ctx] = best_x

    return rewards, argmax_x_c, regrets


# In[ ]:





# In[25]:


def do_evaluate_Ca(scm, T, beliefs):
    ''' Run a causal baseline's learned (greedy) policy in the environment for T rounds.

        scm     : environment used to draw contexts and rewards
        T       : number of evaluation rounds
        beliefs : beliefs learnt by the agent during training; they are passed together
                  with get_conditionals(scm) to calc_Ey_exp_params to estimate E[Y | do(x), c1]

        Returns (rewards, argmax_x_ctar, regrets): the per-round rewards, the last action
        taken in each context, and the per-round regrets against scm.get_optimal(c1).
    '''
    argmax_x_ctar = [0] * range_C1

    # Point estimate of E[Y | do(x), c1], keyed by (x, c1). It is computed once,
    # so it stays fixed for the whole evaluation.
    E_y = calc_Ey_exp_params(beliefs, get_conditionals(scm))

    rewards, regrets = [], []

    for _round in range(T):
        c1 = scm.get_sample()['C1']

        # Greedy action; a candidate must be strictly better to replace the current best.
        best_x = 0
        for cand in range(1, range_x):
            if E_y[(cand, c1)] > E_y[(best_x, c1)]:
                best_x = cand

        taken = scm.get_sample_conditional(context={'C1':c1}, intervention={'X':best_x})
        rewards.append(taken['Y'])
        oracle = scm.get_sample_conditional(context={'C1':c1}, intervention={'X':scm.get_optimal(c1)})
        regrets.append(oracle['Y'] - taken['Y'])

        argmax_x_ctar[c1] = best_x

    return rewards, argmax_x_ctar, regrets


# In[26]:


def do_evaluate_MyAlgo_v2(scm, T, beliefs, final_clusters):
    ''' Evaluate our algorithm. Same as do_evaluate_Ca, except that the action for an
        observed context is the one learned for the centre of the cluster containing it.

        scm            : environment used to draw contexts and rewards
        T              : number of evaluation rounds
        beliefs        : beliefs learnt by the agent during training
        final_clusters : dict mapping each cluster centre, a 1-tuple (c1,), to the list of
                         member contexts, also stored as 1-tuples

        Returns (rewards, argmax_x_ctar, regrets): the per-round rewards, the last action
        taken in each observed context, and the per-round regrets against
        scm.get_optimal(c1) in that observed context.
    '''
    argmax_x_ctar = [0]*range_C1
    conditionals = get_conditionals(scm)
    rewards = []
    regrets = []

    # Estimate E[y] based on beliefs learned by the agent during training
    E_y = calc_Ey_exp_params(beliefs, conditionals)

    for t in range(T):
        # Get context from environment
        c1 = scm.get_sample()['C1']

        # Cluster members are 1-tuples, while C1 is used elsewhere as a bare integer index.
        # Testing the bare value against a list of tuples presumably only matched through
        # numpy-scalar comparison semantics, and never matches a plain int (ctar[0] then
        # raised TypeError). Wrap the value in a tuple first. A context found in no
        # cluster is treated as its own centre.
        observed = c1 if isinstance(c1, tuple) else (c1,)
        centre = next((cc for cc, members in final_clusters.items() if observed in members), observed)
        c_obs, c_centre = observed[0], centre[0]

        # Argmax over x, using the estimates learned for the cluster centre
        argmax_x = 0
        for x in range(1, range_x):
            if E_y[(x, c_centre)] > E_y[(argmax_x, c_centre)]:
                argmax_x = x

        # Sample reward from (y | do(argmax_x), c1) in the context that was actually
        # observed; the cluster centre only selects the action.
        sample = scm.get_sample_conditional(context={'C1':c_obs}, intervention={'X':argmax_x})
        rewards += [sample['Y']]

        # Sample reward from optimal action (y | do(x*), c1)
        opt_sample = scm.get_sample_conditional(context={'C1':c_obs}, intervention={'X':scm.get_optimal(c_obs)})
        regrets += [opt_sample['Y'] - sample['Y']]

        # Record the action applied to this observed context
        argmax_x_ctar[c_obs] = argmax_x

    return rewards, argmax_x_ctar, regrets


# In[27]:


def train_all(train_T, final_clusters, alpha, verbosity):
    ''' Train every algorithm on the global `scm` and return what each one learned.

        train_T        : number of training rounds per algorithm
        final_clusters : context clusters, forwarded only to train_Causal_source_v2
        alpha          : alpha hyper-parameter, forwarded only to train_Causal_source_v2
        verbosity      : 'high' prints the final beliefs of every algorithm

        Returns (beliefs_dict, train_rewards_dict), both keyed by algorithm name.
        beliefs_dict[name] is [belief_y_dox_c] for the non-causal algorithms and
        [beliefs['Y'], beliefs['C0']] for the causal ones. train_rewards_dict[name]
        is the reward list returned by that algorithm's trainer.
        Trainers run in the order A, A2, A3, B, C, D, E, F, v2.
    '''
    verbose = verbosity == 'high'
    beliefs_dict = {}
    train_rewards_dict = {}

    # Non-causal baselines: each trainer returns (rewards, policy, E_y, belief_y_dox_c).
    # NOTE: A2/A3 use uniform exploration but are stored under the historical 'NC_TS_*'
    # keys, because test_all and train_plot look them up by these names.
    non_causal = [
        ('NC_TS_Obs_A', 'TS', 'belief_y_dox_c_NCTSObs_A', non_causal_TS_given_context),
        ('NC_TS_Obs_A2', 'UE (A2)', 'belief_y_dox_c_NCUEObs_A2', non_causal_UE_given_context),
        ('NC_TS_Choose_A3', 'UE (A3)', 'belief_y_dox_c_NCUEChoose_A3', non_causal_UE_chosen_context),
    ]
    for key, label, belief_name, trainer in non_causal:
        rewards, _policy, _E_y, belief_y_dox_c = trainer(scm, train_T)
        if verbose:
            print('{}: Final beliefs..'.format(label))
            print('{} : '.format(belief_name), belief_y_dox_c)
        beliefs_dict[key] = [belief_y_dox_c]
        train_rewards_dict[key] = rewards

    # Causal baselines B-F: each trainer returns (rewards, policy, _, beliefs), where
    # beliefs['Y'] is the Y | X, C0 belief table and beliefs['C0'] the C0 | C1 one.
    causal = [
        ('Ca_TS_Obs_B', lambda: causal_TS_given_context(scm, train_T)),
        ('Ca_UE_Obs_C', lambda: causal_UE_given_context(scm, train_T)),
        ('Ca_TS_Choose_D', lambda: causal_TS_chosen_context(scm, train_T)),
        ('Ca_PE_Choose_E', lambda: causal_PureExp_chosen_context(scm, train_T)),
        ('Ca_TS_Choose_F', lambda: TargInt_epsilon_greedy(scm, train_T, epsilon=0.5)),
    ]
    for key, train in causal:
        rewards, _policy, _, beliefs = train()
        if verbose:
            print('{}: Final beliefs..'.format(key))
            print('belief_y_x_c0 : ', beliefs['Y'])
            print('belief_c0_c1 : ', beliefs['C0'])
        beliefs_dict[key] = [beliefs['Y'], beliefs['C0']]
        train_rewards_dict[key] = rewards

    # Our algorithm (v2) is the only one that uses the clusters and alpha.
    reward_per_run_v2, _policy_v2, _, beliefs_v2 = train_Causal_source_v2(scm, train_T, final_clusters, context=['C1'], alpha=alpha, verbosity=verbosity)

    if verbose:
        print('MyAlgo v2: Final beliefs..')
        print('Beliefs Y|X,C0 : ', beliefs_v2['Y'])
        print('Beliefs C0|C1 :', beliefs_v2['C0'])
        print('Beliefs C1 :', beliefs_v2['C1'])

    beliefs_dict['v2'] = [beliefs_v2['Y'], beliefs_v2['C0']]
    train_rewards_dict['v2'] = reward_per_run_v2

    return beliefs_dict, train_rewards_dict


# In[28]:


def train_plot(train_rewards_dict):
    ''' Print the total training reward of each algorithm and plot its per-round rewards.

        train_rewards_dict : dict returned by train_all, mapping an algorithm key to the
                             list of per-round training rewards
    '''
    # (dict key, legend label, colour). Each colour is distinct; D and F previously
    # shared 'orange', which made their curves indistinguishable.
    series = [
        ('NC_TS_Obs_A', 'NC_TS_Obs_A', 'black'),
        ('NC_TS_Obs_A2', 'NC_UE_Obs_A2', 'green'),
        ('NC_TS_Choose_A3', 'NC_UE_Choose_A3', 'cyan'),
        ('Ca_TS_Obs_B', 'Ca_TS_Obs_B', 'yellow'),
        ('Ca_UE_Obs_C', 'Ca_UE_Obs_C', 'brown'),
        ('Ca_TS_Choose_D', 'Ca_TS_Choose_D', 'orange'),
        ('Ca_PE_Choose_E', 'Ca_PE_Choose_E', 'blue'),
        ('Ca_TS_Choose_F', 'Ca_TS_Choose_F', 'purple'),
        ('v2', 'MyAlgo v2', 'red'),
    ]

    for key, _label, _colour in series:
        print('Total rewards for {} = '.format(key), np.sum(train_rewards_dict[key]))

    fig, ax = plt.subplots(figsize=(14,7))
    for key, label, colour in series:
        ax.plot(train_rewards_dict[key], marker='o', color=colour, label=label)

    ax.legend()
    ax.set_title('Train: Rewards vs. $t$')
    # Raw string: '\m' is an invalid escape sequence in a normal string literal.
    ax.set_ylabel(r'Expected reward $\mathbb{E}[Y]$')
    ax.set_xlabel('Round $t$')
    plt.show()


# In[29]:


def test_all(train_runs, test_runs_per_train_run, train_T, test_T, alpha, verbosity='high'):
    ''' 
        Train and then test all algorithms.
        train_runs : number of training runs
        test_runs_per_train_run : number of test runs for each training run; just set to 1 here.
        train_T : number of training rounds
        test_T : number of testing rounds (> 1 to get lower-variance estimates)
        alpha : alpha, forwarded to train_all
        verbosity : controls level of printing on console ('high', 'mid'; anything else is quiet)

        Returns (rewards_dict_for_plot, regrets_dict_for_plot). Both are dicts keyed by
        (train_run, test_run, tag), with tag one of 'A', 'A2', 'A3', 'B', 'C', 'D', 'E',
        'F', 'v2'. Each value is a numpy array of that test run's per-round rewards /
        regrets.
    '''
    start_time = time.time()

    rewards_dict_for_plot = {}
    regrets_dict_for_plot = {}

    # (plot tag, key in train_all's beliefs_dict, display name, evaluator kind).
    # Algorithms are evaluated in this order, which is the same order as before.
    algos = [
        ('A', 'NC_TS_Obs_A', 'NC_TS_Obs_A', 'NC'),
        ('A2', 'NC_TS_Obs_A2', 'NC_UE_Obs_A2', 'NC'),
        ('A3', 'NC_TS_Choose_A3', 'NC_UE_Choose_A3', 'NC'),
        ('B', 'Ca_TS_Obs_B', 'Ca_TS_Obs_B', 'Ca'),
        ('C', 'Ca_UE_Obs_C', 'Ca_UE_Obs_C', 'Ca'),
        ('D', 'Ca_TS_Choose_D', 'Ca_TS_Choose_D', 'Ca'),
        ('E', 'Ca_PE_Choose_E', 'Ca_PE_Choose_E', 'Ca'),
        ('F', 'Ca_TS_Choose_F', 'Ca_TS_Choose_F', 'Ca'),
        ('v2', 'v2', 'v2', 'v2'),
    ]

    # Last policy returned by each evaluator. It stays None if nothing was evaluated
    # (e.g. train_runs == 0), instead of the old NameError.
    policies = {tag: None for tag, _key, _name, _kind in algos}

    def evaluate(kind, beliefs, final_clusters):
        ''' Dispatch to the evaluator for `kind`; returns (rewards, policy, regrets). '''
        if kind == 'NC':
            return do_evaluate_NC(scm, test_T, beliefs[0])
        causal_beliefs = {'Y': beliefs[0], 'C0': beliefs[1]}
        if kind == 'Ca':
            return do_evaluate_Ca(scm, test_T, causal_beliefs)
        return do_evaluate_MyAlgo_v2(scm, test_T, causal_beliefs, final_clusters)

    def print_policies():
        ''' Print the most recent policy of every algorithm. '''
        for tag, _key, name, _kind in algos:
            print('{} Policy = '.format(name), policies[tag])

    for run in range(train_runs):
        print('\nTrain run number = {}'.format(run))

        # Every context is its own singleton cluster.
        final_clusters = {(c1,) : [(c1,)]  for c1 in range(range_C1)}

        if verbosity == 'high':
            print('Final clusters')
            print(final_clusters)

        # train_rewards_dict can be passed to train_plot to inspect training curves.
        beliefs_dict, train_rewards_dict = train_all(train_T, final_clusters, alpha, verbosity)

        for test_run in range(test_runs_per_train_run):
            if verbosity in ['high', 'mid']:
                print('\nTest run number = {}'.format(test_run))

            run_results = []
            for tag, key, _name, kind in algos:
                rewards, policy, regrets = evaluate(kind, beliefs_dict[key], final_clusters)
                policies[tag] = policy
                rewards_dict_for_plot[(run, test_run, tag)] = np.array(rewards)
                regrets_dict_for_plot[(run, test_run, tag)] = np.array(regrets)
                run_results.append((rewards, regrets))

            if verbosity in ['high', 'mid']:
                print('Rewards')
                for rewards, _regrets in run_results:
                    print(np.mean(rewards))
                print('Regrets')
                for _rewards, regrets in run_results:
                    print(np.mean(regrets))

        if verbosity in ['high', 'mid']:
            print_policies()

    print("\nTime taken: --- %s seconds ---" % (round(time.time() - start_time, 2)))  
    print("\nTime taken: --- %s minutes ---" % ((time.time() - start_time)/60))  
    print()

    print_policies()

    return rewards_dict_for_plot, regrets_dict_for_plot


# In[ ]:





# In[ ]:





# #### Execution (without plotting) - for multiple settings (T, alpha, ...)

# In[30]:


# Load the evaluation hyper-parameters. Each entry is a list; index i of every
# list together describes setting #i of the run loop below.
(train_runs_list,
 test_runs_per_train_run_list,
 train_T_list,
 test_T_list,
 alpha_list) = (params['eval'][name] for name in (
    'train_runs_list',
    'test_runs_per_train_run_list',
    'train_T_list',
    'test_T_list',
    'alpha_list',
))


# In[31]:


# Run all algorithms for every setting, and save the regrets for plotting.
import os

output_dir = 'output_for_plot'
# Create the output folder up front instead of failing after a long run.
os.makedirs(output_dir, exist_ok=True)
run_stamp = time.strftime('%Y%m%d-%H%M%S')

for i in range(len(train_runs_list)):
    print('\nSetting #', i,'\n------')
    
    
    # Run all algorithms and get rewards/regrets
    rewards_dict_for_plot, regrets_dict_for_plot = test_all(
        train_runs = train_runs_list[i], 
        test_runs_per_train_run = test_runs_per_train_run_list[i], 
        train_T = train_T_list[i], 
        test_T = test_T_list[i], 
        alpha=alpha_list[i], 
        verbosity='low'
    )
    
    # Save outputs. The old random 0-999 suffix could collide, silently overwriting
    # another setting's results, and it also drew from the global numpy RNG.
    # Timestamp + setting index is unique within this run, and the counter avoids
    # clobbering any file that already exists.
    base = os.path.join(output_dir, 'regrets_dict_for_plot_{}_{}'.format(run_stamp, i))
    filename, suffix = base + '.pickle', 1
    while os.path.exists(filename):
        filename = '{}_{}.pickle'.format(base, suffix)
        suffix += 1

    for_pickle = {}
    for_pickle['config'] = {
        'train_runs' : train_runs_list[i], 
        'test_runs_per_train_run' : test_runs_per_train_run_list[i], 
        'train_T' : train_T_list[i],
        'test_T' : test_T_list[i],
        'alpha' : alpha_list[i]
    }
    for_pickle['params'] = params
    for_pickle['regrets'] = regrets_dict_for_plot

    with open(filename, 'wb') as handle:
        pickle.dump(for_pickle, handle, protocol=pickle.HIGHEST_PROTOCOL)
        
    # Read the file back to confirm that it was written and can be loaded.
    with open(filename, 'rb') as handle:
        for_pickle = pickle.load(handle)
    
    print(for_pickle['config'])


# In[ ]:





# In[ ]:


# scm.get_index_of_y(2)


# In[ ]:




