import random
import numpy
import time
from functools import lru_cache
# from collections import defaultdict

from .SimpleQLearningAgentClass import QLearningAgentSimple
from simple_rl.tasks.gym.GymMDPClass import GymState

import numpy as np
from collections import deque

class DynaQLearningAgent(QLearningAgentSimple):
    """Tabular Dyna-Q agent layered on ``QLearningAgentSimple``.

    On every real update the agent:

    1. applies the parent's direct Q-learning update,
    2. records the observed transition in a learned model, and
    3. runs up to ``num_hallucinations`` simulated ("hallucinated")
       Q-updates drawn from that model.

    The model is two nested dicts keyed first by the *hashed* state (the
    output of ``self.state_hash_func``) and then by action:

    * ``transition_func[state][action] -> next_state`` (hashed)
    * ``reward_func[state][action] -> reward``

    Only the first observed outcome of each (state, action) pair is kept,
    i.e. the model treats the environment as deterministic.
    """

    def __init__(self, actions, name="DynaQ", num_hallucinations=16, **kwargs):
        """
        Args:
            actions: Action set, forwarded to the parent constructor.
            name: Agent name.
            num_hallucinations: Maximum number of simulated updates run after
                each real update.
            **kwargs: Forwarded unchanged to ``QLearningAgentSimple``.
        """
        # Model attributes are created before the parent constructor runs.
        # NOTE(review): the parent constructor is not visible here; keep this
        # order in case it relies on these attributes existing.
        self.transition_func = {}
        self.reward_func = {}
        self.num_hallucinations = num_hallucinations

        super().__init__(actions, name=name, **kwargs)

    def transition_func_cached(self, state, action):
        """Return the modelled next state for hashed ``state`` and ``action``.

        Kept for backward compatibility. This was previously wrapped in an
        unbounded ``functools.lru_cache``. On a method, that cache keys on
        ``self`` (keeping every agent alive), duplicates the whole model, and
        is shared by all instances. A plain dict lookup is already O(1).

        Raises:
            KeyError: if the (state, action) pair has not been observed.
        """
        return self.transition_func[state][action]

    def reward_func_cached(self, state, action):
        """Return the modelled reward for hashed ``state`` and ``action``.

        Kept for backward compatibility; see ``transition_func_cached``.

        Raises:
            KeyError: if the (state, action) pair has not been observed.
        """
        return self.reward_func[state][action]

    def update_models(self, state, action, reward, next_state, pre_hashed=False):
        """Record an observed transition in the learned model.

        Args:
            state: State the action was taken in. ``None`` means there is no
                transition to record.
            action: Action taken.
            reward: Reward received.
            next_state: Resulting state.
            pre_hashed: If True, ``state`` and ``next_state`` are already in
                the extracted-and-hashed form stored in the model and are used
                as-is (this mirrors how ``hallucinate`` passes stored keys with
                ``pre_hashed=True``).
        """
        if state is None:
            # Nothing to model yet; only remember the incoming state.
            self.prev_state = next_state
            return

        if not pre_hashed:
            state = self._pre_extract_state(state)
            next_state = self._pre_extract_state(next_state)

            state = self.state_hash_func(state)
            next_state = self.state_hash_func(next_state)

        # setdefault keeps the first observed outcome for each
        # (state, action) pair (deterministic-model assumption).
        self.transition_func.setdefault(state, {}).setdefault(action, next_state)
        self.reward_func.setdefault(state, {}).setdefault(action, reward)

    def hallucinate(self):
        """Run simulated Q-learning updates drawn from the learned model.

        Up to ``num_hallucinations`` distinct modelled states are sampled
        uniformly without replacement. For each one, an action previously
        observed in that state is chosen uniformly at random. The parent's
        Q-update is then applied to the modelled (reward, next_state).
        """
        states = list(self.transition_func)
        # Clamp so random.sample never raises; a non-positive budget means
        # no planning steps.
        num_samples = max(0, min(self.num_hallucinations, len(states)))
        for state in random.sample(states, num_samples):
            # Draw only from actions actually observed in this state, as in
            # Dyna-Q. Sampling the full action set wasted planning steps on
            # pairs the model knows nothing about.
            action = random.choice(list(self.transition_func[state]))
            next_state = self.transition_func[state][action]
            reward = self.reward_func[state][action]
            # Model keys are already extracted and hashed.
            super().update(state, action, reward, next_state, pre_hashed=True)

    def update(self, state, action, reward, next_state, pre_hashed=False):
        """Perform one Dyna-Q step: direct update, model update, then planning.

        Args:
            state, action, reward, next_state: The observed transition.
            pre_hashed: Forwarded to both the parent update and
                ``update_models`` so the two interpret ``state`` consistently.
        """
        # Direct reinforcement-learning update of the Q-values.
        super().update(state, action, reward, next_state, pre_hashed=pre_hashed)
        # Learn the model from the same real experience.
        self.update_models(state, action, reward, next_state, pre_hashed=pre_hashed)
        # Planning: simulated updates from the model.
        self.hallucinate()

    def reset(self):
        """Reset the parent agent and discard the learned model."""
        super().reset()
        self.transition_func = {}
        self.reward_func = {}
