###SIMPLE MDP ENVIRONMENT
import numpy as np
from numpy.random import default_rng
import itertools
class MDP:
    """Infinite-horizon discounted MDP with tabular dynamics and random rewards.

    Attributes:
        Ns: number of states.
        Na: number of actions.
        P: transition probabilities, indexed P[s, a, s'].
        R: reward parameters. For 'Bernoulli', R[s, a] is the success
            probability; for 'Beta', R[s, a, 0] and R[s, a, 1] are the two
            shape parameters handed to the Beta sampler.
        reward_family: name of the reward distribution ('Bernoulli' or 'Beta').
        gamma: discount factor (stored only; not used by the methods here).
        state: current state of the online model.
        rng: numpy Generator used for every random draw made by this object.
    """

    def __init__(self,Ns,Na,P,R, reward_family ='Bernoulli', gamma=0.9):
        """Store the model and draw a uniformly random initial state.

        P and R are stored by reference, not copied.
        """
        self.rng = default_rng()
        self.Ns = Ns
        self.Na = Na
        assert P.shape == (Ns,Na,Ns) # P[s,a,s']
        self.P = P
        # R's shape is not checked because it depends on the reward family.
        self.R = R
        self.reward_family = reward_family
        self.gamma = gamma
        # Initial state, uniform over {0, ..., Ns-1}.
        self.state = self.rng.integers(0, self.Ns)

    def _unsupported_family(self):
        """Build the error raised when reward_family has no sampler."""
        return ValueError(
            "Unsupported reward_family %r; expected 'Bernoulli' or 'Beta'"
            % (self.reward_family,)
        )

    def _sample_reward(self, s, a):
        """Draw a single reward for the pair (s, a)."""
        if self.reward_family == 'Bernoulli':
            return self.rng.binomial(1, self.R[s, a])
        if self.reward_family == 'Beta':
            return self.rng.beta(self.R[s, a, 0], self.R[s, a, 1])
        raise self._unsupported_family()

    def _step(self, s, a):
        """Sample a one-hot next-state vector from P[s, a] and move there."""
        one_hot = self.rng.multinomial(1, self.P[s, a])
        # Exactly one entry equals 1, so argmax is the sampled state index.
        self.state = np.argmax(one_hot)
        return one_hot

    def copy(self):
        """Return an independent MDP with the same model and current state.

        P and R are duplicated so that in-place edits of one object's tables
        never leak into the other. The copy gets its own random generator.
        """
        clone = MDP(
            self.Ns,
            self.Na,
            np.array(self.P),
            np.array(self.R),
            self.reward_family,
            self.gamma,
        )
        clone.state = self.state
        return clone

    def current(self):
        """Return the current state."""
        return self.state

    def query(self,s,a): #query a transition from a generative model
        """Sample one reward and one next state for an arbitrary pair (s, a).

        Note that the current state is overwritten with the sampled next state.

        Returns:
            (reward, next_state)

        Raises:
            ValueError: if reward_family is not 'Bernoulli' or 'Beta'.
        """
        reward = self._sample_reward(s, a)
        self._step(s, a)
        return reward, self.state

    def play(self,a): #for online model
        """Take action `a` from the current state and advance the model.

        Returns:
            (rewards, transition, next_state), where `rewards` is a
            one-element array holding the sampled reward and `transition` is
            the one-hot vector of the sampled next state.

        Raises:
            ValueError: if reward_family is not 'Bernoulli' or 'Beta'.
        """
        reward = self._sample_reward(self.state, a)
        rewards = np.array([reward])
        transition = self._step(self.state, a)
        return rewards, transition, self.state

    def multiple_samples(self,s,a,N): #for generative model
        """Draw N independent samples of (s, a) without moving the state.

        Returns:
            (rewards, transitions), where `rewards` is the sum of the N
            sampled rewards and `transitions[s']` counts how many of the N
            samples landed in s'.

        Raises:
            ValueError: if reward_family is not 'Bernoulli' or 'Beta'.
        """
        if self.reward_family == 'Bernoulli':
            # A Binomial(N, p) draw is the sum of N Bernoulli(p) rewards.
            rewards = self.rng.binomial(N, self.R[s, a])
        elif self.reward_family == 'Beta':
            rewards = self.rng.beta(
                self.R[s, a, 0], self.R[s, a, 1], size=N
            ).sum()
        else:
            raise self._unsupported_family()
        transitions = self.rng.multinomial(N, self.P[s, a])
        return rewards, transitions

    def reset(self,visits=None):
        """Re-draw the current state uniformly at random and return it.

        `visits` is accepted but currently ignored.
        """
        self.state = self.rng.integers(0, self.Ns)
        return self.state
    
##FUNCTIONS FOR GENERATING RANDOM MDPs 
def Random_transitions(Ns,Na): 
    """Sample a random transition tensor indexed P[s, a, s'].

    Every row P[s, a, :] is an independent draw from a Dirichlet distribution
    with all Ns concentration parameters equal to 1, using a fresh default
    generator.
    """
    generator = default_rng()
    concentration = np.ones(Ns)
    # Filled action-major first, then re-ordered to state-major on return.
    by_action = np.zeros((Na, Ns, Ns))
    for state in range(Ns):
        for action in range(Na):
            by_action[action, state] = generator.dirichlet(concentration)
    return by_action.transpose(1, 0, 2) # P[s,a,s']
def Random_rewards(Ns,Na, reward_family='Bernoulli'):
    """Sample random reward parameters for every state-action pair.

    Args:
        Ns: number of states.
        Na: number of actions.
        reward_family: reward distribution name; only 'Bernoulli' has a
            generator here.

    Returns:
        An (Ns, Na) array of Bernoulli success probabilities drawn uniformly
        from [0, 1).

    Raises:
        ValueError: if reward_family is anything other than 'Bernoulli'.
    """
    # Fail with a clear message instead of reaching the return with no array.
    if reward_family != 'Bernoulli':
        raise ValueError(
            "Random_rewards only supports reward_family='Bernoulli', got %r"
            % (reward_family,)
        )
    rng = default_rng()
    return rng.uniform(0, 1, (Ns, Na))
def generate_mdp(Ns,Na,reward_family='Bernoulli', gamma=0.9):
    """Build an MDP with random transitions and random reward parameters.

    Args:
        Ns: number of states.
        Na: number of actions.
        reward_family: reward distribution name, used both to generate the
            reward parameters and to tag the resulting MDP.
        gamma: discount factor stored on the MDP.

    Returns:
        A new MDP instance.
    """
    P = Random_transitions(Ns,Na)
    # Forward the family so the reward array matches the family the MDP will
    # sample from; a family Random_rewards cannot generate then fails here
    # rather than at the first sampled transition.
    R = Random_rewards(Ns,Na, reward_family)
    return MDP(Ns,Na,P,R,reward_family, gamma)


def initial_estimate_online(mdp):
    """Build a starting model estimate for the online setting.

    The estimate has the same sizes, reward family and discount as `mdp`,
    randomly generated reward parameters, and a uniform next-state
    distribution for every state-action pair. Nothing is sampled from `mdp`.
    """
    n_states, n_actions = mdp.Ns, mdp.Na
    family = mdp.reward_family
    discount = mdp.gamma
    reward_guess = Random_rewards(n_states, n_actions, family)
    # Broadcast one uniform row over every (s, a) pair.
    transition_guess = np.zeros((n_states, n_actions, n_states))
    transition_guess[:, :] = np.ones(n_states) / n_states
    return MDP(n_states, n_actions, transition_guess, reward_guess, family, discount)

#In the case of a generative model
#we draw N samples from every state-action pair to construct an initial empirical mdp
def uniform_initial_estimate(mdp,N):
    """Build an empirical MDP from N generative samples of every (s, a) pair.

    For each pair, the reward entry is the sample total returned by
    `mdp.multiple_samples` divided by N, and the transition row is the vector
    of next-state counts divided by N.

    Args:
        mdp: the MDP to sample from.
        N: number of samples per state-action pair; must be at least 1.

    Returns:
        A new MDP with the same sizes, reward family and discount as `mdp`,
        holding the empirical estimates. The reward table always has shape
        (Ns, Na).

    Raises:
        ValueError: if N is smaller than 1.
    """
    # N == 0 would divide by zero and yield an all-NaN model.
    if N < 1:
        raise ValueError("N must be at least 1, got %r" % (N,))
    Ns = mdp.Ns
    Na = mdp.Na
    R = np.zeros((Ns, Na))
    P = np.zeros((Ns,Na,Ns))
    for s,a in itertools.product(range(Ns),range(Na)):
        rewards, transitions = mdp.multiple_samples(s,a,N)
        R[s,a] = rewards/N
        P[s,a] = transitions/N
    return MDP(Ns,Na,P,R,mdp.reward_family, mdp.gamma)

#UPDATING EMPIRICAL MDP FROM SAMPLES
def update(mdp,s,a,rewards,transitions,N,visits):
    """Fold N new samples of the pair (s, a) into an empirical MDP.

    The stored reward and transition row for (s, a) are treated as averages
    over `visits` earlier samples and replaced by the average over
    `visits + N` samples.

    The arrays `mdp.R` and `mdp.P` are modified in place, and the returned
    MDP is a new object built on those same arrays.

    Args:
        mdp: empirical MDP to update.
        s, a: state and action the samples belong to.
        rewards: observed rewards; only their `.sum()` is used.
        transitions: next-state counts, added to the row P[s, a, :].
        N: number of new samples.
        visits: number of samples already reflected in the estimate.
    """
    reward_table = mdp.R
    transition_table = mdp.P
    total = visits + N
    # Running-average update: old mean times old count, plus the new sum.
    reward_table[s, a] = (reward_table[s, a] * visits + rewards.sum()) / total
    transition_table[s, a, :] = (transition_table[s, a, :] * visits + transitions) / total
    return MDP(
        mdp.Ns,
        mdp.Na,
        transition_table,
        reward_table,
        mdp.reward_family,
        mdp.gamma,
    )