# The Agent
import numpy as np
from utils import *

class Agent:
    """
    Tabular agent that accumulates transition counts and reward sums over
    2*d state-action indices and turns them into Q-values on request.

    A (state, action) pair is stored at flat index state*2 + action.
    NOTE(review): this layout suggests d states with two actions each --
    confirm against the environment.
    """

    def __init__(self, N, d, gamma):
        """
        Store the configuration and start from empty statistics

        N     -- kept on the instance; not read anywhere inside this class
        d     -- sizes every table (each has 2*d entries per axis)
        gamma -- handed to compute_value_function in get_Q_values
                 (presumably the discount factor -- confirm in utils)
        """
        self.N, self.d, self.gamma = N, d, gamma
        self.reset()

    def reset(self):
        """
        Zero the transition counts, the reward sums and the stored Q-values
        """
        size = 2 * self.d
        self.Transition = np.zeros((size, size))
        self.Reward = np.zeros(size)
        self.Q_values = np.zeros(size)

    def update(self, state, action, reward, next_state, next_action):
        """
        Record one observed (s,a,r,s',a') sample: add one to the count of
        the (s,a) -> (s',a') transition and add r to the reward sum of (s,a)
        """
        src = state * 2 + action
        # Fetch the row first and increment inside it, so the row index is
        # resolved before the column index (two-step indexing, not a tuple).
        outgoing = self.Transition[src]
        outgoing[next_state * 2 + next_action] += 1
        self.Reward[src] += reward

    def get_Q_values(self):
        """
        Build the empirical model from the recorded samples, compute the
        Q-values from it, store them in self.Q_values and return them
        """
        n_pairs = 2 * self.d
        counts = self.Transition.sum(axis=1)
        visited = counts != 0

        # Rows that were never visited keep a uniform distribution over all
        # state-action indices; visited rows become normalised counts.
        probs = np.ones((n_pairs, n_pairs)) / n_pairs
        probs[visited] = self.Transition[visited] / counts[visited, np.newaxis]

        # Average observed reward per visited row; unvisited rows stay 0.
        mean_reward = np.zeros(n_pairs)
        mean_reward[visited] = self.Reward[visited] / counts[visited]

        # compute_value_function is not defined in this file (presumably
        # provided by the `utils` star-import).
        self.Q_values = compute_value_function(probs, mean_reward, self.gamma)
        return self.Q_values
    


