from function import TabularFunction

from collections import deque


class QFunction(TabularFunction):
    """Tabular Q-function with cycle penalties and reward back-propagation.

    ``batch`` arguments are indexable sequences of transitions; each
    transition is indexed as ``[0]`` state, ``[1]`` action, ``[2]`` reward,
    ``[3]`` next state (only these four positions are read).

    NOTE(review): ``self.qtable`` (with ``get_value`` / ``set_value``) is not
    created here and is presumably provided by ``TabularFunction.__init__``
    -- confirm.
    """

    # Amount subtracted from the (rho-scaled) reward of any transition whose
    # next state equals the detected cycle state. Subclasses may override.
    CYCLE_PENALTY = 1.0

    def __init__(self, num_actions):
        super().__init__(num_actions)
        # NOTE(review): neither buffer is read or written anywhere in this
        # class (both were marked TODO); kept for external users/subclasses.
        self.success_buffer = deque(maxlen=5)
        self.success_traject = []

        # Cycle bookkeeping written by calculate_cycle() and cleared by
        # reset_hyperparameters() at the end of every update().
        self.update_index = -1   # batch index of the transition that closed the cycle
        self.cycle = False       # whether the last calculate_cycle() found a cycle
        self.cycle_state = None  # the revisited state, or None

    def reset_hyperparameters(self):
        """Clear the cycle bookkeeping (update_index, cycle, cycle_state)."""
        self.update_index = -1
        self.cycle = False
        self.cycle_state = None

    def backup(self, batch, i, alpha, gamma, rho, ex_reward):
        """Apply one TD update to the Q-value of transition ``batch[i]``.

        The target is ``rho * reward + gamma * Q(next_state, next_action)``,
        where ``next_action`` is the action of ``batch[i + 1]`` or ``-1`` for
        the final transition. If ``next_state`` equals the currently recorded
        cycle state, ``CYCLE_PENALTY`` is subtracted from the scaled reward.

        Args:
            batch: sequence of transitions (see class docstring).
            i: index of the transition to update.
            alpha: learning rate; new = (1 - alpha) * old + alpha * target.
            gamma: discount factor applied to the next Q-value.
            rho: scale factor applied to the stored reward.
            ex_reward: accepted for interface compatibility but currently has
                no effect. (A disabled variant additionally updated a separate
                transfer-learning table, ``self.ttable``, with the same target
                when this flag was true.)
        """
        transition = batch[i]
        state = transition[0]
        action = transition[1]
        next_state = transition[3]

        # Use the action actually taken next as the bootstrap action; the
        # last transition has no successor and passes -1 instead.
        # NOTE(review): how qtable.get_value interprets action -1 is defined
        # in TabularFunction (possibly a max over actions) -- confirm.
        next_action = batch[i + 1][1] if i < len(batch) - 1 else -1

        target_reward = rho * transition[2]
        # Penalise transitions that lead back into the detected cycle state.
        if self.cycle_state is not None and next_state == self.cycle_state:
            target_reward -= self.CYCLE_PENALTY

        # Read the current value before the bootstrap value, as before, in
        # case get_value has side effects (e.g. lazily inserting entries).
        old_q_value = self.qtable.get_value(state, action)
        next_q_value = self.qtable.get_value(next_state, next_action)

        new_value = (1 - alpha) * old_q_value + alpha * (
            target_reward + gamma * next_q_value
        )
        self.qtable.set_value(state, action, new_value)

    def update(self, batch, alpha, gamma, rho, reward_alpha=0.9):
        """Update the Q-table from one batch of transitions.

        1. If the final state of the batch was already visited earlier in the
           batch, penalise the transition that closed the cycle (using
           ``alpha``).
        2. If any transition has a positive reward, back the values up from
           the first such transition down to index 0 (using
           ``reward_alpha``).

        The cycle bookkeeping is always cleared afterwards, even if an update
        raises, so it cannot leak into the next batch.

        Args:
            batch: sequence of transitions (see class docstring).
            alpha: learning rate for the cycle-penalty update.
            gamma: discount factor.
            rho: reward scale factor.
            reward_alpha: learning rate for the reward back-propagation;
                defaults to 0.9, the value previously hard-coded here.
        """
        try:
            if self.calculate_cycle(batch):
                self.backup(batch, self.update_index, alpha, gamma, rho,
                            ex_reward=False)

            # Only the first rewarding transition is used as the start point.
            first_reward = next(
                (idx for idx, transition in enumerate(batch) if transition[2] > 0),
                None,
            )
            if first_reward is not None:
                # Walk backwards so each update bootstraps from an
                # already-updated successor value.
                for idx in range(first_reward, -1, -1):
                    self.backup(batch, idx, alpha=reward_alpha, gamma=gamma,
                                rho=rho, ex_reward=True)
        finally:
            self.reset_hyperparameters()

    def calculate_cycle(self, batch):
        """Detect whether the last state of ``batch`` occurs earlier in it.

        On a hit, records ``cycle = True``, ``cycle_state`` (the repeated
        state) and ``update_index = len(batch) - 2``. Otherwise clears that
        bookkeeping, so a result from an earlier call is never reported again.

        Args:
            batch: sequence of transitions (see class docstring); batches with
                fewer than two transitions never contain a cycle.

        Returns:
            True if a cycle was detected, else False.
        """
        last = len(batch) - 1
        revisited = last > 0 and any(
            batch[j][0] == batch[last][0] for j in range(last)
        )

        if not revisited:
            self.reset_hyperparameters()
            return False

        # Index of the transition just before the final one.
        self.update_index = last - 1
        self.cycle_state = batch[last][0]
        self.cycle = True
        return True
