'''
Q-learning implementation.
The algorithm can be run as either an online or an offline learning method.
'''


import numpy as np 
import random 

class qLearningAgent(): # training algorithm made offline
    """Tabular Q-learning agent that keeps one Q-table per candidate lamda.

    A grid of ``num_lamda`` values (``lamdaSet``) is maintained.  For every
    state the agent tracks the index of the lamda for which the Q-values of
    the two actions are closest (``init_lamda_index``); that lamda is what
    ``_getLamda`` returns, and it is also the Q-table slice that gets updated
    when a transition from that state is observed.

    Typical usage per step: call ``_getLamda(state)`` (which also records the
    current state), then ``_takeAction(action, episode)``.
    """

    def __init__(self, env, stateTable, use_epsilon_decay=False):
        """Set up hyper-parameters and zero-initialised tables.

        Args:
            env: environment object exposing ``step(action)`` that returns a
                4-tuple ``(next_state, reward, done, info)``.
            stateTable: table of all states; the first axis enumerates states.
            use_epsilon_decay: when False (default, the historical behaviour)
                epsilon is pinned to ``min_epsilon`` after every step.  When
                True, epsilon follows the exponential schedule
                ``min + (max - min) * exp(-decay * episode)``.
        """
        self.learningRate = 0.001
        self.numStates = np.shape(stateTable)[0]
        self.env = env
        self.epsilon = 1  # exploration probability; starts fully random
        self.beta = 0.999  # discount factor applied to the next-state value
        self.max_epsilon = 1
        self.min_epsilon = 0.01
        self.decay = 0.01
        self.use_epsilon_decay = use_epsilon_decay
        self.num_lamda = 100
        self.stateTable = stateTable  # all states stored in tabular form
        self.lamdaSet = np.linspace(0, 10, num=self.num_lamda)

        # Axes: (lamda index, state index, action); two actions (0 and 1).
        self.lamda_qTable = np.zeros((self.num_lamda, self.numStates, 2))
        # Per state: index into lamdaSet currently associated with that state.
        self.init_lamda_index = np.zeros(self.numStates, dtype=np.uint32)

    def _getLamda(self, state):
        """Record ``state`` as the current state and return its lamda value.

        Side effect: sets ``self.currentState`` to the row index of ``state``
        in ``stateTable``; ``_takeAction`` relies on this having been called.
        """
        self.currentState = self._findStateIndex(state)
        return self.lamdaSet[self.init_lamda_index[self.currentState]]

    def _takeAction(self, action, episode):
        """Step the environment epsilon-greedily and learn from the result.

        Args:
            action: the action proposed by the caller (0 or 1).
            episode: episode counter; only used by the epsilon schedule when
                ``use_epsilon_decay`` is enabled.

        Returns:
            ``(nextState, reward)`` as reported by ``env.step``.

        Requires ``_getLamda`` to have been called for the current state so
        that ``self.currentState`` is set.
        """
        self.exploration_explotation_tradeoff = random.uniform(0, 1)

        if self.exploration_explotation_tradeoff > self.epsilon:
            chosenAction = action
        else:
            # Explore: pick one of the two actions uniformly at random.
            chosenAction = np.random.choice([0, 1], 1)[0]

        nextState, reward, _done, _info = self.env.step(chosenAction)

        # The reward and next state were produced by chosenAction, so that is
        # the Q entry that must be updated (not the originally proposed one).
        self._updateQTable(self.currentState, nextState, reward, chosenAction)

        if self.use_epsilon_decay:
            self.epsilon = (
                self.min_epsilon
                + (self.max_epsilon - self.min_epsilon)
                * np.exp(-self.decay * episode)
            )
        else:
            # Historical behaviour: exploration rate drops straight to the
            # minimum after the first step.
            self.epsilon = self.min_epsilon

        return nextState, reward

    def _findStateIndex(self, state):
        """Return the row index of ``state`` in ``stateTable``.

        Vector states are matched against whole rows; scalar states are
        matched element-wise and the first matching row index is returned.

        Raises:
            IndexError: if ``state`` does not occur in ``stateTable``.
        """
        table = np.asarray(self.stateTable)

        if np.size(state) > 1:  # state is a vector: compare full rows
            matches = np.where((table == state).all(axis=1))[0]
        else:
            matches = np.where(table == state)[0]

        if matches.size == 0:
            raise IndexError(
                "state {!r} is not present in stateTable".format(state)
            )
        return matches[0]

    def _updateQTable(self, state, new_state, reward, action):
        """Apply one Q-learning update and refresh the per-state lamda index.

        Args:
            state: row index of the state the action was taken in.
            new_state: the resulting state itself (not its index); it is
                looked up in ``stateTable`` here.
            reward: reward returned by the environment.
            action: the action that was actually executed (0 or 1).
        """
        new_state = self._findStateIndex(new_state)

        qTable = self.lamda_qTable
        lamdaIdx = self.init_lamda_index[state]
        oldQ = qTable[lamdaIdx, state, action]

        # Reward is reduced by lamda whenever action 1 is taken
        # (action * lamda is zero for action 0).
        target = (
            reward
            - action * self.lamdaSet[lamdaIdx]
            + self.beta * np.max(qTable[lamdaIdx, new_state, :])
        )
        qTable[lamdaIdx, state, action] = (
            oldQ + self.learningRate * target - self.learningRate * oldQ
        )

        # For every state, select the lamda whose two action values are
        # closest.  argmin returns the first minimum, as the per-state loop did.
        gap = np.absolute(qTable[:, :, 1] - qTable[:, :, 0])
        self.init_lamda_index[:] = np.argmin(gap, axis=0)
