import random
import numpy
import time
from functools import lru_cache
# from collections import defaultdict

from .DynaQLearningAgentClass import DynaQLearningAgent
from simple_rl.tasks.gym.GymMDPClass import GymState
from rlang import VectorState

import numpy as np
from collections import deque

class RLangDynaQLearningAgent(DynaQLearningAgent):
    def __init__(
        self,
        actions,
        get_knowledge,
        use_policy=False,
        use_plan=False,
        use_effects=False,
        name="RLang-Dyna-Q",
        policy_epsilon=1.0,
        num_hallucinations=16,
        **kwargs
    ):
        """Store the RLang configuration, then run the parent constructor.

        Args:
            actions: Action set, forwarded to the parent constructor.
            get_knowledge: Callable that ``act`` invokes with the extracted
                initial state; it must return a
                ``(knowledge, state_featurizer)`` pair.
            use_policy: Consult ``knowledge.policy`` when choosing actions.
            use_plan: Consult ``knowledge.plan`` until the plan is used up.
            use_effects: Use the RLang reward/transition functions.
            name: Agent name, forwarded to the parent constructor.
            policy_epsilon: Threshold compared against a uniform random draw
                to decide whether RLang advice is consulted for a step.
            num_hallucinations: Forwarded to the parent constructor.
            **kwargs: Any further parent-constructor arguments.
        """
        # Knowledge is built lazily on the first act() after a reset.
        self.get_knowledge = get_knowledge
        self.knowledge = None
        self.state_featurizer = None
        self.was_reset = True

        # Feature switches and the advice probability.
        self.use_policy, self.use_plan, self.use_effects = use_policy, use_plan, use_effects
        self.policy_epsilon = policy_epsilon

        # Plan bookkeeping and the hash -> extracted-state lookup filled in update().
        self.plan_used_up = False
        self.reverse_state_hash_map = {}

        super().__init__(actions, name=name, num_hallucinations=num_hallucinations, **kwargs)
    
    def rlang_state_to_state(self, rlang_state):
        return (rlang_state[0].view(np.ndarray), tuple(rlang_state[1].view(np.ndarray)), tuple(rlang_state[2].view(np.ndarray)))

    @lru_cache(maxsize=None)
    def memoized_plan(self, rlang_state):
        return self.knowledge.plan(state=rlang_state)

    def act(self, state, reward, learning=True):
        """Pick the next action, first seeding the model from RLang rewards.

        On the first call after construction or ``reset()``, the RLang
        knowledge and state featurizer are rebuilt from the extracted initial
        state via ``self.get_knowledge``, and the knowledge's memoized
        reward/transition caches are cleared.

        When ``learning`` and ``use_effects`` are set and knowledge is
        available, a state with no entry in ``self.reward_func`` has one
        randomly chosen action probed in the RLang reward function. A
        non-zero answer is recorded in ``self.reward_func`` and, if that
        action's Q-value is still at ``self.default_q``, folded into a single
        Q-value update.

        Args:
            state: Current state as handed over by the caller.
            reward: Reward for the previous step; forwarded to the parent.
            learning: Whether the agent is learning; forwarded to the parent.

        Returns:
            Whatever the parent class's ``act`` returns.
        """
        if self.was_reset:
            init_state = self._pre_extract_state(state)
            self.knowledge, self.state_featurizer = self.get_knowledge(init_state)
            # The checks below allow for missing knowledge, so only touch the
            # caches when there is a knowledge object to clear them on.
            if self.knowledge is not None:
                self.knowledge.memoized_reward_function.cache_clear()
                self.knowledge.memoized_transition_function.cache_clear()
            self.was_reset = False

        # Before choosing an action, use RLang's reward knowledge to pre-fill
        # the reward model and Q-values for states not yet seen.
        if learning and self.use_effects and self.knowledge:
            hashed_state = self.state_hash_func(state)
            if hashed_state not in self.reward_func:
                rlang_state = VectorState(state, dtype=object)
                # Probe one random action rather than sweeping all of them.
                a = random.choice(self.actions)
                potential_reward = int(self.knowledge.memoized_reward_function(state=rlang_state, action=a))
                if potential_reward != 0:
                    # The state is known to be absent, so start its reward table here.
                    self.reward_func[hashed_state] = {a: potential_reward}

                    partial_q = self.q_func[hashed_state]
                    if partial_q[a] == self.default_q:
                        # NOTE: this bootstraps from the *current* state's max
                        # Q-value; the correct target would use the next
                        # state's, which is not available at this point.
                        max_q_curr_state = self.get_max_q_value(state)
                        partial_q[a] = (1 - self.alpha) * self.default_q + self.alpha * (potential_reward + self.gamma * max_q_curr_state)

        return super().act(state, reward, learning=learning)
    
    def epsilon_greedy_q_policy(self, state):
        """Select an action, preferring RLang advice over the learned Q-values.

        RLang is consulted only when a policy is enabled, or a plan is enabled
        and not yet used up, and a uniform random draw falls below
        ``policy_epsilon``. An unexhausted plan takes precedence over the
        policy. In every other case, or when the advice is missing or vetoed,
        the parent's epsilon-greedy choice is returned.

        With ``use_effects`` on, advice is vetoed (and the plan marked as used
        up) when RLang's reward for it is negative or a negative reward has
        already been recorded for it in ``self.reward_func``.

        Args:
            state: Raw state; it is passed through ``_pre_extract_state`` first.

        Returns:
            An int action when RLang advice is followed, otherwise whatever
            the parent's ``epsilon_greedy_q_policy`` returns.
        """
        state = self._pre_extract_state(state)

        plan_available = self.use_plan and not self.plan_used_up
        consult_rlang = (self.use_policy or plan_available) and np.random.random() < self.policy_epsilon
        if not consult_rlang:
            return super().epsilon_greedy_q_policy(state)

        query = VectorState(state, dtype=object)
        if plan_available:
            advice = self.memoized_plan(rlang_state=query)
        else:
            advice = self.knowledge.policy(state=query)

        has_advice = advice is not None and advice != {}
        if not has_advice:
            # No suggestion: treat the plan as exhausted from here on.
            self.plan_used_up = True
            return super().epsilon_greedy_q_policy(state)

        if not isinstance(advice, int):
            # Advice arrives as a mapping keyed by action; take its first key.
            advice = int(list(advice.keys())[0])

        if not self.use_effects:
            return advice

        rlang_reward = self.knowledge.memoized_reward_function(state=VectorState(state, dtype=object), action=advice)
        if rlang_reward >= 0:
            state_key = self.state_hash_func(state)
            unseen = state_key not in self.reward_func or advice not in self.reward_func[state_key]
            if unseen or self.reward_func[state_key][advice] >= 0:
                return advice

        # Kill the plan if it has us make a bad move.
        self.plan_used_up = True
        return super().epsilon_greedy_q_policy(state)

    def hallucinate(self):
        states = list(self.transition_func.keys())

        hallucination_counter = 0

        while hallucination_counter < self.num_hallucinations:
            if states == []:
                break
            hashed_state = random.choice(states)
            
            if self.use_effects and hashed_state in self.reverse_state_hash_map:
                rlang_state = VectorState(self.reverse_state_hash_map[hashed_state], dtype=object)
            else:
                rlang_state = None

            if self.use_policy and hashed_state in self.reverse_state_hash_map and False:
                rlang_state_for_policy = VectorState(self.reverse_state_hash_map[hashed_state], dtype=object)
                action = self.knowledge.policy(state=rlang_state_for_policy)
                if action is None or action == {}:
                    action = random.choice(self.actions)
            else:
                action = random.choice(self.actions)

            next_state = None
            reward = None

            # print("Attempting to hallucinate")
            hallucination_counter += 1

            if action not in self.transition_func[hashed_state]:    # If it's never tried this action before.
                # Check to see if this state-action sequence is covered in an RLang effect or prediction. (This is for getting state)
                if self.use_effects and self.knowledge and rlang_state is not None:
                    # print("RLang transition function checked!")
                    # print(rlang_state, action)
                    potential_next_state = self.knowledge.memoized_transition_function(state=rlang_state, action=action)
                    # if potential_next_state, a dictionary, is empty, continue
                    # print(rlang_state, action)
                    if potential_next_state == {}:
                        # print("RLang transition function does not have this state. Checking for predictions.")

                        predictions = self.knowledge.predictions(state=rlang_state, action=action)
                        if predictions == {}:
                            # print("No predictions")
                            continue
                        # print("There are predications, don't know how to use them, though")
                        # next_state = self.state_featurizer.get_state_from_complete_predictions(self.reverse_state_hash_map[hashed_state], predictions)
                        # if next_state is None:
                        #     print("State transition was unpredictable from RLang")
                        #     continue
                        continue
                    else:
                        # print("Rlang transition function has this state!")
                        potential_next_state = list(potential_next_state.keys())[0]
                        # print(potential_next_state)
                                               
                        hashed_next_state = self.state_hash_func(potential_next_state)
                        # print("RLang transition function used!")
                else:
                    continue
            else:
                # print("Already took this action")
                hashed_next_state = self.transition_func[hashed_state][action]

            # print(self.reward_func[hashed_state], action, action not in self.reward_func[hashed_state])
            if action not in self.reward_func[hashed_state]:
                # Check if it's in the RLang objects, if not, continue
                if self.use_effects and self.knowledge and rlang_state is not None:
                    # print("RLang reward function checked!")
                    potential_reward = int(self.knowledge.memoized_reward_function(state=rlang_state, action=action))
                    if potential_reward == None:
                        # print("No reward")
                        continue
                    else:   # RLang reward function used!
                        reward = potential_reward
                        # print("RLang reward function used!")
                else:
                    continue
            else:
                reward = self.reward_func[hashed_state][action]

            super(DynaQLearningAgent, self).update(hashed_state, action, reward, hashed_next_state, pre_hashed=True)
            states.remove(hashed_state)
            
    
    def update(self, state, action, reward, next_state, pre_hashed=False):
        """Remember the raw states behind their hashes, then run the parent update.

        When ``use_effects`` is on, the states are not already hashed, and
        ``state`` is not ``None``, both ``state`` and ``next_state`` are
        extracted, hashed, and stored in ``self.reverse_state_hash_map``
        (hash -> extracted state).

        Args:
            state: State the action was taken from (or its hash if ``pre_hashed``).
            action: Action that was taken.
            reward: Reward that was received.
            next_state: Resulting state (or its hash if ``pre_hashed``).
            pre_hashed: Whether ``state``/``next_state`` are already hashes.
        """
        should_record = self.use_effects and not pre_hashed and state is not None
        if should_record:
            extracted = [self._pre_extract_state(raw) for raw in (state, next_state)]
            hashes = [self.state_hash_func(item) for item in extracted]
            self.reverse_state_hash_map.update(zip(hashes, extracted))

        # Update Q-values
        super().update(state, action, reward, next_state, pre_hashed=pre_hashed)

    def reset(self):
        """Reset the agent and drop everything derived from RLang knowledge.

        After the parent's reset this:

        * flags the agent so the next ``act`` rebuilds the knowledge,
        * empties ``self.reverse_state_hash_map``,
        * when ``use_plan`` is on, resets the plan, clears the
          ``memoized_plan`` cache and re-arms the plan,
        * clears the knowledge's memoized transition and reward caches.

        ``self.knowledge`` is ``None`` until the first ``act``, so the steps
        that touch it are skipped when no knowledge has been built yet.
        """
        super().reset()
        self.was_reset = True
        # Forget the hash -> state records collected by update().
        self.reverse_state_hash_map = {}

        has_knowledge = self.knowledge is not None

        if self.use_plan:
            if has_knowledge:
                self.knowledge.plan.reset()
            self.memoized_plan.cache_clear()
            self.plan_used_up = False

        if has_knowledge:
            self.knowledge.memoized_transition_function.cache_clear()
            self.knowledge.memoized_reward_function.cache_clear()

    def end_of_episode(self):
        """Finish the episode in the parent, then re-arm the plan if one is in use.

        With ``use_plan`` on, the knowledge's plan is reset and
        ``plan_used_up`` is cleared so the plan is followed again from the
        start of the next episode.
        """
        super().end_of_episode()
        if not self.use_plan:
            return
        self.knowledge.plan.reset()
        self.plan_used_up = False
