import numpy as np

class VMBPO:
    """Tabular Q-learning agent with an epsilon-greedy behaviour policy.

    The agent keeps a ``(num_states, num_actions)`` table of action values,
    initialised to zero, and updates it one transition at a time with the
    standard one-step Q-learning (TD(0)) rule.

    The environment object must provide:

    * ``index_to_ij(s)`` -- maps a state index to a ``(i, j)`` pair, and
    * ``WALLS`` -- a 2-D array-like indexable as ``WALLS[i, j]`` whose truthy
      entries mark wall cells.

    Presumably the environment is a grid world; confirm against the env class.
    """

    def __init__(self, num_states, num_actions, env, epsilon=0.5, alpha=0.01,
                 gamma=0.9, terminal_states=(9,)):
        """Create an agent with an all-zero Q-table.

        Args:
            num_states: Number of discrete states (rows of ``Q``).
            num_actions: Number of discrete actions (columns of ``Q``).
            env: Environment exposing ``index_to_ij`` and ``WALLS`` (see class
                docstring).
            epsilon: Exploration probability used by ``epsilon_greedy``.
            alpha: Learning rate of the TD update.
            gamma: Discount factor applied to the bootstrapped next-state value.
            terminal_states: State indices whose updates do not bootstrap from
                the next state. Defaults to ``(9,)`` to preserve the previously
                hard-coded terminal state.
        """
        self.num_states = num_states
        self.num_actions = num_actions
        self.Q = np.zeros((num_states, num_actions))
        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma
        self.env = env
        # frozenset gives O(1) membership tests and protects against callers
        # mutating the collection they passed in.
        self.terminal_states = frozenset(terminal_states)

    def epsilon_greedy(self, s):
        """Pick an action for state ``s``.

        With probability ``epsilon`` a uniformly random action is returned;
        otherwise the action with the highest Q-value is returned.

        Note: ``np.argmax`` returns the first maximal index. While a row of
        ``Q`` is still all zeros, the greedy choice is therefore always
        action 0.
        """
        if np.random.uniform() < self.epsilon:
            return np.random.randint(self.num_actions)
        return np.argmax(self.Q[s])

    def update_step(self, s, a, next_s, r):
        """Apply one Q-learning update for the transition ``(s, a, r, next_s)``.

        If ``s`` is a wall cell or one of ``terminal_states``, the target is
        just the reward ``r`` (no bootstrapping). Otherwise the target is
        ``r + gamma * max_a' Q[next_s, a']``. ``Q[s, a]`` is then moved a
        fraction ``alpha`` of the way toward that target.
        """
        i, j = self.env.index_to_ij(s)
        # NOTE(review): transitions from wall cells are treated like terminal
        # ones; why walls are handled this way is not visible here.
        if self.env.WALLS[i, j] or s in self.terminal_states:
            target = r
        else:
            target = r + self.gamma * np.max(self.Q[next_s])
        self.Q[s, a] += self.alpha * (target - self.Q[s, a])