import numpy as np
import random

class q_learning:
    """Tabular Q-learning agent with epsilon-greedy action selection.

    The agent keeps a ``(state_size, action_size)`` table of Q-values,
    initialised to zero, that is indexed by integer state and action ids.
    """

    def __init__(self, state_size, action_size, *, learning_rate=0.1,
                 discount_rate=1.0, epsilon=1.0, decay_rate=0.005):
        """Create a zero-initialised Q-table and store the hyperparameters.

        Args:
            state_size: Number of rows (states) in the Q-table.
            action_size: Number of columns (actions) in the Q-table.
            learning_rate: Step size applied to each temporal-difference
                update.
            discount_rate: Weight given to the best Q-value of the next
                state; ``1.0`` means future value is not discounted.
            epsilon: Initial probability of taking the exploratory action.
            decay_rate: Exponential decay constant used by
                ``decrease_epsilon``.
        """
        self.qtable = np.zeros((state_size, action_size))

        # hyperparameters
        self.learning_rate = learning_rate
        self.discount_rate = discount_rate
        self.epsilon = epsilon
        self.decay_rate = decay_rate

    def take_action(self, action_sample, state):
        """Choose an action for ``state`` using the epsilon-greedy rule.

        Args:
            action_sample: Action returned unchanged when the agent explores.
                The caller supplies it (presumably a random sample from the
                environment's action space - confirm against the caller).
            state: Row index of the current state in the Q-table.

        Returns:
            ``action_sample`` with probability ``epsilon``, otherwise the
            greedy action for ``state``.
        """
        # exploration-exploitation tradeoff: random.random() is in [0, 1),
        # so epsilon == 1.0 always explores and epsilon == 0.0 never does.
        if random.random() < self.epsilon:
            return action_sample
        return self.get_max_action(state)

    def get_max_action(self, state):
        """Return the index of the highest Q-value in row ``state``.

        ``np.argmax`` resolves ties in favour of the lowest action index.
        """
        return np.argmax(self.qtable[state, :])

    def update_qtable(self, state, new_state, action, reward):
        """Apply one Q-learning update for an observed transition.

        Args:
            state: Row index of the state the action was taken in.
            new_state: Row index of the state reached afterwards.
            action: Column index of the action that was taken.
            reward: Reward received for the transition.
        """
        current = self.qtable[state, action]
        # TD target bootstraps from the best Q-value of the next state.
        target = reward + self.discount_rate * np.max(self.qtable[new_state, :])
        self.qtable[state, action] = current + self.learning_rate * (target - current)

    def get_qtable(self):
        """Return the Q-table array itself (not a copy)."""
        return self.qtable

    def decrease_epsilon(self, episode):
        """Set ``epsilon`` to ``exp(-decay_rate * episode)``.

        The value is recomputed from ``episode`` alone, so it does not
        depend on the previous ``epsilon``.
        """
        self.epsilon = np.exp(-self.decay_rate * episode)