from itertools import product, combinations
import numpy as np 
from tqdm import tqdm
from functools import reduce
import mdptoolbox


class CombMDP():
    """Joint ("combined") MDP assembled from several per-arm MDPs.

    ``make_mdp`` enumerates every joint state and every joint action that
    satisfies the budget/fairness constraints and builds the joint transition
    and reward arrays.  ``value_iteration`` then solves the joint MDP with
    ``mdptoolbox`` and ``get_action`` / ``enumerate_policy`` read the
    resulting policy back out.
    """

    # Slack applied to the budget comparisons so that floating point
    # round-off in the summed costs does not reject an on-budget action.
    _BUDGET_TOL = 1e-6

    def __init__(self):
        # Joint transition array, one matrix per feasible joint action.
        self.T = None
        # Joint reward vector, one entry per joint state.
        self.R = None
        # List of joint states (tuples holding one state index per arm).
        self.states = None
        # List of feasible joint actions (tuples, one action index per arm).
        self.actions = None
        # Maps a joint state tuple to its row index in T / R.
        self.state_to_ind = None
        # Index into self.actions for every joint state; set by value_iteration.
        self.policy = None

        # make_mdp refuses to build a model whose joint state count or
        # feasible joint action count reaches this value.
        self.STATE_LIMIT = 1e4

    def make_mdp(self, T, R, C, HB, LB, fairness_epsilon=None):
        """Build the joint MDP and store it on the instance.

        Args:
            T: 4-D array of per-arm transition probabilities.  Axes 1 and 2
                are swapped before use, after which the array is read as
                ``[arm, action, :, :]`` -- i.e. the input layout is
                ``(N, S, A, S)`` (presumably ``[arm, state, action,
                next_state]``; confirm against the caller).
            R: per-arm rewards, indexed as ``R[arm, state]``.
            C: numpy array of action costs, indexed as ``C[arm, action]``.
            HB: upper bound on the total cost charged to each action other
                than action 0 (scalar or broadcastable against the
                per-action totals).
            LB: lower bound on the same per-action totals.
            fairness_epsilon: largest allowed gap between the highest and
                lowest per-action total (action 0 excluded).  Defaults to
                ``C.max()``.

        Returns:
            ``(T, R, states, actions)``: the joint transition array, the
            joint reward vector, the list of joint states and the list of
            feasible joint actions (the same objects stored on ``self``).

        Raises:
            ValueError: if the number of joint states or of feasible joint
                actions reaches ``self.STATE_LIMIT``.
        """
        # After the swap T[arm, action] is that arm's square transition matrix.
        T = np.swapaxes(T, 1, 2)

        N = T.shape[0]
        num_actions = T.shape[1]
        num_states = T.shape[2]

        if fairness_epsilon is None:
            fairness_epsilon = C.max()

        # Check the size *before* enumerating anything.
        comb_num_states = num_states**N
        if comb_num_states >= self.STATE_LIMIT:
            raise ValueError('State size is too big (%s states) -- will not run value iteration'%comb_num_states)

        combined_state_space = list(product(np.arange(num_states), repeat=N))

        # Stream the candidate joint actions instead of materialising all
        # num_actions**N of them; only the feasible ones are kept.
        feasible_action_space = [
            actions
            for actions in product(np.arange(num_actions), repeat=N)
            if self._is_feasible(actions, num_actions, C, HB, LB, fairness_epsilon)
        ]

        comb_num_actions = len(feasible_action_space)
        if comb_num_actions >= self.STATE_LIMIT:
            raise ValueError('Action size is too big (%s actions) -- will not run value iteration'%comb_num_actions)

        # The Kronecker product of the per-arm matrices is ordered exactly
        # like itertools.product over the per-arm states (first arm is the
        # most significant index), so rows/columns line up with
        # combined_state_space.
        T_matrices = [
            reduce(np.kron, [T[arm, arm_action] for arm, arm_action in enumerate(actions)])
            for actions in feasible_action_space
        ]

        # Joint reward is the sum of the per-arm rewards of the joint state.
        combined_rewards = [
            sum(R[arm, s] for arm, s in enumerate(states))
            for states in combined_state_space
        ]

        self.T = np.array(T_matrices)
        self.R = np.array(combined_rewards)
        self.states = combined_state_space
        self.actions = feasible_action_space
        self.state_to_ind = {tup: i for i, tup in enumerate(combined_state_space)}

        return self.T, self.R, self.states, self.actions

    def _is_feasible(self, actions, num_actions, C, HB, LB, fairness_epsilon):
        """Return whether a joint action meets the budget and fairness limits."""
        # payment[a] accumulates the cost of every arm that was given action a.
        payment = np.zeros(num_actions)
        for arm, arm_action in enumerate(actions):
            payment[arm_action] += C[arm, arm_action]

        # Action 0 is excluded from every constraint.
        paid = payment[1:]
        tol = self._BUDGET_TOL
        if (paid - tol > HB).any() or (paid + tol < LB).any():
            return False

        # With a single action per arm there is nothing to compare, and
        # max()/min() of an empty array would raise.
        if paid.size == 0:
            return True
        return paid.max() - paid.min() <= fairness_epsilon

    def value_iteration(self, gamma):
        """Solve the joint MDP and store the resulting policy in ``self.policy``.

        ``make_mdp`` must have been called first.

        Args:
            gamma: discount factor handed to ``mdptoolbox``.
        """
        # The solver is given rewards with the same shape as T, but self.R
        # only depends on the current state (axis 1), so broadcast it across
        # the other two axes.
        R_expanded = np.zeros(self.T.shape)
        R_expanded += np.asarray(self.R)[np.newaxis, :, np.newaxis]

        mdp = mdptoolbox.mdp.ValueIteration(self.T, R_expanded, discount=gamma)
        mdp.run()
        self.policy = np.array(mdp.policy)

    def get_action(self, states):
        """Return the policy's joint action for a joint state.

        Args:
            states: list or 1-d array with one state index per arm.

        Returns:
            1-d numpy array with one action index per arm.

        Requires ``make_mdp`` and ``value_iteration`` to have been run.
        """
        state_ind = self.state_to_ind[tuple(states)]
        action_ind = self.policy[state_ind]
        return np.array(self.actions[action_ind])

    def enumerate_policy(self):
        """Print the chosen joint action for every joint state."""
        for ind, state in enumerate(self.states):
            print('state', state, 'action', self.actions[self.policy[ind]])


