from collections import defaultdict
import random

from simple_rl.agents import QLearningAgent
from simple_rl.tasks.gym.GymMDPClass import GymState
import numpy as np


class QLearningAgentSimple(QLearningAgent):
    """Tabular Q-learning agent whose Q-table is keyed by hashed states.

    Before a state indexes the Q-table it is unwrapped from ``GymState`` (if it
    is one) and passed through ``state_hash_func``; presumably this lets raw
    observations that are not hashable themselves (e.g. Minigrid grids) be
    used as table keys. Callers that already hold a key can pass
    ``pre_hashed=True`` to skip the hashing step.

    The Q-table (``q_func``), ``alpha``, ``gamma``, ``actions`` and
    ``prev_state`` are expected to be provided by ``QLearningAgent``.
    """

    def __init__(self, actions, state_hash_func=None, name="Q-Learning-hashed", available_action_function=None, *args, **kwargs):
        """
        Args:
            actions (list): Contains strings denoting the actions.
            state_hash_func (callable): Maps an unwrapped state to the hashable
                key used in the Q-table. If None, the state itself is the key.
            name (str): Denotes the name of the agent.
            available_action_function (callable): Stored on the agent; no method
                of this class consults it yet.
            *args, **kwargs: Forwarded unchanged to QLearningAgent.__init__
                (e.g. alpha, gamma, epsilon, explore, default_q).
        """
        self.state_hash_func = state_hash_func
        self.available_action_function = available_action_function
        self.hashed_available_actions = dict()

        # `name` is passed positionally on purpose: with `name=name` placed
        # ahead of `*args`, any extra positional argument would also bind to
        # the parent's second parameter and raise "got multiple values for
        # argument 'name'". Assumes `name` is the second positional parameter
        # of QLearningAgent.__init__ (as in simple_rl).
        super().__init__(actions, name, *args, **kwargs)

    def _pre_extract_state(self, state):
        """Return the payload of a GymState, or @state itself for anything else."""
        if isinstance(state, GymState):
            state = state.data
        return state

    def _state_key(self, state, pre_hashed=False):
        """Return the Q-table key for an already unwrapped @state.

        Args:
            state: An unwrapped state, or a ready-made key if @pre_hashed.
            pre_hashed (bool): True if @state is already a Q-table key.

        Returns:
            The key to index self.q_func with. @state is returned untouched when
            it is pre-hashed or when no hash function was configured.
        """
        if pre_hashed or self.state_hash_func is None:
            return state
        return self.state_hash_func(state)

    def update(self, state, action, reward, next_state, pre_hashed=False):
        """
        Args:
            state (State)
            action (str)
            reward (float)
            next_state (State)
            pre_hashed (bool): True if @state and @next_state are already
                Q-table keys.

        Summary:
            Updates the internal Q Function according to the Bellman Equation. (Classic Q Learning update)
        """
        # If this is the first state, just return.
        if state is None:
            self.prev_state = next_state
            return

        state = self._pre_extract_state(state)
        next_state = self._pre_extract_state(next_state)

        # Bootstrap from the best action value of the successor state.
        max_q_next_state = self.get_max_q_value(next_state, pre_hashed=pre_hashed)

        # Hash the current state once and reuse the key for both read and write.
        state_key = self._state_key(state, pre_hashed)
        prev_q_val = self.get_q_value(state_key, action, pre_hashed=True)

        self.q_func[state_key][action] = \
            (1 - self.alpha) * prev_q_val + self.alpha * (reward + self.gamma * max_q_next_state)

    # TODO: per-state action masking is not implemented. `available_action_function`
    # and `hashed_available_actions` are stored in __init__ for that purpose, but the
    # greedy max below still iterates over every action in self.actions. Looking up
    # the available actions of a pre-hashed state would need either a cache filled
    # when the state is first hashed or an inverse of the hash function.

    def get_q_value(self, state, action, pre_hashed=False):
        """
        Args:
            state (State)
            action (str)
            pre_hashed (bool): True if @state is already a Q-table key.

        Returns:
            (float): denoting the q value of the (@state, @action) pair.
        """
        state = self._pre_extract_state(state)
        return self.q_func[self._state_key(state, pre_hashed)][action]

    def get_max_q_value(self, state, pre_hashed=False):
        """
        Args:
            state (State)
            pre_hashed (bool): True if @state is already a Q-table key.

        Returns:
            (float): denoting the max q value in the given @state.
        """
        return self._compute_max_qval_action_pair(state, pre_hashed=pre_hashed)[0]

    def _compute_max_qval_action_pair(self, state, pre_hashed=False):
        """
        Args:
            state (State)
            pre_hashed (bool): True if @state is already a Q-table key.

        Returns:
            (tuple) --> (float, str): where the float is the Qval, str is the action.
        """
        # Fallback action, returned only if no Q-value compares greater than -inf.
        best_action = random.choice(self.actions)
        max_q_val = float("-inf")

        # Copy into a fresh list so the shuffle can never reorder self.actions
        # (a slice of a tuple is immutable, and a slice of an array is a view).
        shuffled_action_list = list(self.actions)
        random.shuffle(shuffled_action_list)

        # The key is the same for every action, so unwrap and hash only once.
        state_key = self._state_key(self._pre_extract_state(state), pre_hashed)

        # Find best action (action w/ current max predicted Q value). Because
        # the comparison is strict, the shuffled order breaks ties at random.
        for action in shuffled_action_list:
            q_s_a = self.get_q_value(state_key, action, pre_hashed=True)
            if q_s_a > max_q_val:
                max_q_val = q_s_a
                best_action = action

        return max_q_val, best_action
