from collections import defaultdict
import random
from functools import lru_cache

from simple_rl.agents import QLearningAgent
from simple_rl.tasks.gym.GymMDPClass import GymState
from simple_rl.mdp.StateClass import State as Simple_rl_State
import numpy as np


class VHQLearningAgentSimple(QLearningAgent):
    """Implementation for a Q Learning agent that utilizes RLang hints, amenable to VirtualHome.

    Unlike the base ``QLearningAgent`` the action set is not fixed: it is queried per state
    through ``available_action_function``. States are turned into Q-table keys through
    ``state_hash_func``; most methods accept ``pre_hashed=True`` to signal that the ``state``
    argument already is such a key.
    """

    def __init__(self, state_hash_func=None, name="Q-Learning", available_action_function=None, *args, **kwargs):
        """
        Args:
            state_hash_func (callable): Maps a raw state to the hashable key used in the Q-table.
                If None, the state itself is used as the key.
            name (str): Denotes the name of the agent.
            available_action_function (callable): Maps a raw state to the actions available in it.
                If None, ``self.actions`` (initialised to ``[]``) is used for every state.
            *args, **kwargs: Forwarded to ``QLearningAgent.__init__`` after ``actions`` and ``name``
                (e.g. alpha, gamma, epsilon, explore, default_q). Softmax exploration is not supported!
        """
        self.state_hash_func = state_hash_func
        self.available_action_function = available_action_function
        # Hashed state -> available actions. Filled whenever actions are looked up from a raw
        # state, because a pre-hashed key cannot be fed to available_action_function.
        self.hashed_available_actions = dict()

        if args:
            # Extra positionals must follow actions/name positionally: combined with the
            # ``actions=``/``name=`` keywords Python binds them to ``actions`` and raises
            # "got multiple values". Assumes QLearningAgent.__init__(actions, name, ...).
            super().__init__([], name, *args, **kwargs)
        else:
            super().__init__(actions=[], name=name, **kwargs)

    def _pre_extract_state(self, state):
        """Hook for unwrapping a state before use; currently returns ``state`` unchanged."""
        return state

    def _hash_state(self, state):
        """Return the Q-table key of a raw ``state`` (the state itself if no hash function is set)."""
        if self.state_hash_func is None:
            return state
        return self.state_hash_func(state)

    def _q_key(self, state, pre_hashed):
        """Return the Q-table key for ``state``, hashing it unless it is already ``pre_hashed``."""
        return state if pre_hashed else self._hash_state(state)

    def update(self, state, action, reward, next_state, pre_hashed=False):
        '''
        Args:
            state (State): previous state (or its key if ``pre_hashed``); None on the first step.
            action (str)
            reward (float)
            next_state (State): resulting state (or its key if ``pre_hashed``).
            pre_hashed (bool): whether ``state``/``next_state`` are already Q-table keys.

        Summary:
            Updates the internal Q Function according to the Bellman Equation. (Classic Q Learning update)
        '''
        # If this is the first state, just return.
        if state is None:
            self.prev_state = next_state
            return

        state = self._pre_extract_state(state)
        next_state = self._pre_extract_state(next_state)

        # Q(s, a) <- (1 - alpha) * Q(s, a) + alpha * (r + gamma * max_a' Q(s', a'))
        max_q_next_state = self.get_max_q_value(next_state, pre_hashed=pre_hashed)
        prev_q_val = self.get_q_value(state, action, pre_hashed=pre_hashed)
        target = reward + self.gamma * max_q_next_state
        self.q_func[self._q_key(state, pre_hashed)][action] = (1 - self.alpha) * prev_q_val + self.alpha * target

    # NOTE(review): lru_cache on a method keys on ``self`` and keeps every agent and every state
    # seen alive until cache_clear(); reset() clears it for *all* instances. Kept because callers
    # may rely on ``get_available_actions.cache_clear``. ``state`` must be hashable (not a dict).
    @lru_cache(maxsize=None)
    def get_available_actions(self, state, pre_hashed=False):
        """Return the actions available in ``state``.

        Args:
            state: raw state, or its Q-table key when ``pre_hashed`` is True.
            pre_hashed (bool): whether ``state`` is already a Q-table key.

        Returns:
            The result of ``available_action_function(state)``, or ``self.actions`` when no such
            function was given. The (cached) object is shared, so callers must not mutate it.

        Raises:
            ValueError: ``pre_hashed`` is True but the key was never looked up from a raw state.
        """
        if self.available_action_function is None:
            return self.actions

        if pre_hashed:
            # Only keys recorded by an earlier raw-state lookup can be resolved.
            try:
                return self.hashed_available_actions[state]
            except KeyError:
                raise ValueError("The state not found, it was likely not entered properly the first time.") from None

        actions = self.available_action_function(state)
        self.hashed_available_actions[self._hash_state(state)] = actions
        return actions

    def get_q_value(self, state, action, pre_hashed=False):
        '''
        Args:
            state (State): raw state, or its Q-table key if ``pre_hashed``.
            action (str)
            pre_hashed (bool): whether ``state`` is already a Q-table key.

        Returns:
            (float): denoting the q value of the (@state, @action) pair.
        '''
        state = self._pre_extract_state(state)
        return self.q_func[self._q_key(state, pre_hashed)][action]

    def get_max_q_value(self, state, pre_hashed=False):
        '''
        Args:
            state (State): raw state, or its Q-table key if ``pre_hashed``.

        Returns:
            (float): denoting the max q value in the given @state.
        '''
        return self._compute_max_qval_action_pair(state, pre_hashed=pre_hashed)[0]

    def _compute_max_qval_action_pair(self, state, pre_hashed=False):
        '''
        Args:
            state (State): raw state, or its Q-table key if ``pre_hashed``.

        Returns:
            (tuple) --> (float, str): where the float is the Qval, str is the action.

        Note:
            With no available actions, ``random.choice`` raises IndexError.
        '''
        available_actions = self.get_available_actions(state, pre_hashed)
        # Random initial action in case all Q values are equal.
        best_action = random.choice(available_actions)
        max_q_val = float("-inf")
        # Copy into a list: the cached sequence must not be shuffled in place, and tuples
        # cannot be shuffled at all.
        shuffled_action_list = list(available_actions)
        random.shuffle(shuffled_action_list)

        # Find best action (action w/ current max predicted Q value); shuffling breaks ties randomly.
        for action in shuffled_action_list:
            q_s_a = self.get_q_value(state, action, pre_hashed=pre_hashed)
            if q_s_a > max_q_val:
                max_q_val = q_s_a
                best_action = action

        return max_q_val, best_action

    def epsilon_greedy_q_policy(self, state):
        '''
        Args:
            state (State)

        Returns:
            (str): action.
        '''
        # Policy: Epsilon of the time explore, otherwise, greedyQ.
        if np.random.random() > self.epsilon:
            # Exploit.
            return self.get_max_q_action(state)

        # Explore: draw a uniform index from numpy's global RNG and index the list, so the
        # action keeps its original Python type. np.random.choice(actions) would return numpy
        # scalars and rejects actions that are tuples/lists ("a must be 1-dimensional").
        actions = self.get_available_actions(state)
        return actions[np.random.randint(len(actions))]

    def reset(self):
        """Reset the base agent and drop all cached per-state action information."""
        super().reset()
        self.hashed_available_actions = dict()
        self.get_available_actions.cache_clear()
