import numpy as np
import scipy
import torch

from .base_ctrl import Controller

class DarkRoomContext(dict):
    """Dict-backed container for the transitions fed to an in-context model.

    Public attribute access is mirrored onto dict keys, so
    ``ctx.context_states`` and ``ctx["context_states"]`` refer to the same
    entry. Names starting with an underscore are private bookkeeping; they are
    stored as ordinary instance attributes so they never appear among the
    dict's keys/values (code that iterates the context only sees data).

    Args:
        state_dim: Dimensionality of a state vector (stored, not enforced).
        action_dim: Dimensionality of an action vector (stored, not enforced).
        device: Device associated with the context (stored, not enforced).
    """

    # Keys that hold per-step sequences and can be grown with ``append``.
    _CONTEXT_KEYS = (
        "context_states",
        "context_actions",
        "context_next_states",
        "context_rewards",
    )

    def __init__(
        self,
        state_dim = 2,
        action_dim = 5,
        device = torch.device("cpu"),
    ):
        super().__init__()
        self._state_dim = state_dim
        self._action_dim = action_dim
        self._device = device

        # Placeholders; the first ``append`` to a key replaces the None.
        for key in self._CONTEXT_KEYS:
            self[key] = None

        # Tracks which keys already hold a tensor, so ``append`` knows whether
        # to replace the placeholder or concatenate onto existing data.
        self._initlized = {key: False for key in self._CONTEXT_KEYS}

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails: fall back to keys.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            ) from None

    def __setattr__(self, name, value):
        # Private names live on the instance; everything else becomes a key.
        if name.startswith("_"):
            object.__setattr__(self, name, value)
        else:
            self[name] = value

    def __delattr__(self, name):
        if name.startswith("_") and name in self.__dict__:
            object.__delattr__(self, name)
            return
        try:
            del self[name]
        except KeyError:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            ) from None

    def append(self, key, value):
        """Append ``value`` to the sequence stored under ``key``.

        The first append to a key replaces its ``None`` placeholder; later
        appends concatenate along dim 1. Callers are expected to pass tensors
        shaped like ``(1, 1, state_dim)`` / ``(1, 1, action_dim)`` etc.; shapes
        are not validated here.

        Args:
            key: One of the context sequence keys.
            value: Tensor to store or concatenate.

        Raises:
            KeyError: If ``key`` is not a context sequence key.
        """
        if key not in self or key not in self._initlized:
            raise KeyError(f"Key {key} not found in context.")
        if self._initlized[key]:
            self[key] = torch.cat((self[key], value), dim=1)
        else:
            self[key] = value
            self._initlized[key] = True

class DarkroomOptimalController(Controller):
    """Oracle controller that delegates every decision to the environment.

    Reads ``env.goal`` at construction and calls ``env.opt_action(state)`` on
    every step; it keeps no state of its own between steps.
    """

    def __init__(self, env):
        super().__init__()
        self._env = env
        self._goal = self._env.goal

    def reset(self):
        """Do nothing: there is no per-episode state to clear."""
        return None

    def act(self, state):
        """Return the environment's optimal action for ``state``."""
        oracle = self._env.opt_action
        return oracle(state)
    
class DarkroomTransformerController(Controller):
    """Controller that queries an in-context transformer for discrete actions.

    Transitions are accumulated in a ``DarkRoomContext`` through ``append``.
    ``act`` sets the current query state and asks the model for a
    distribution over ``action_dim`` discrete actions, then returns a one-hot
    action vector. ``set_context`` / ``context`` come from the ``Controller``
    base class (not shown here).
    """

    def __init__(self, model):
        super().__init__()
        self._model = model
        self._state_dim = model._config.state_dim
        self._action_dim = model._config.action_dim
        self._horizon = model._config.horizon
        self.zeros = None # This is the zero tensor for the transformer model

        # NOTE(review): not used anywhere in this class; kept for callers.
        self._temp = 1.0

        self._device = model._device

        self.set_context(self._new_context())

    def _new_context(self):
        """Build an empty context sized for this model."""
        return DarkRoomContext(
            state_dim=self._state_dim,
            action_dim=self._action_dim,
            device=self._device,
        )

    @staticmethod
    def _to_numpy(values):
        """Return ``values`` as a NumPy array, moving torch tensors to host memory."""
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    def reset(self):
        """Discard all accumulated transitions by installing a fresh context."""
        self.set_context(self._new_context())

    def act(self, state, sample=False):
        """
        Take a state from the environment and return a one-hot action.

        Args:
            state (np.array): shape: [state_dim,]
            sample (bool): if True, draw the action from the model's
                distribution; otherwise take the arg-max.

        Returns:
            np.ndarray: one-hot vector of length ``action_dim``.
        """
        state = torch.tensor(state).float().to(self._device) # (state_dim,)
        self.context["query_states"] = state.unsqueeze(0)
        # Inference only: no autograd graph is needed for the returned action.
        with torch.no_grad():
            model_out = self._model(self.context, test=True)[0] # (action_dim,)
        # The model may return a (possibly GPU / grad-tracking) tensor, which
        # NumPy cannot consume directly.
        action_probs = self._to_numpy(model_out)

        if sample:
            # Renormalize in float64: np.random.choice rejects p whose sum is
            # not 1 within a tight tolerance, which float32 outputs can miss.
            probs = action_probs.astype(np.float64)
            probs = probs / probs.sum()
            action_index = np.random.choice(
                np.arange(self._action_dim), p=probs
            )
        else:
            action_index = np.argmax(action_probs, axis=-1)

        action = np.zeros(self._action_dim)
        action[action_index] = 1.0

        return action

    def append(self, key, value):
        """Convert ``value`` to a float tensor and append it to ``self.context[key]``.

        Inputs with fewer than three dimensions are left-padded with singleton
        axes until they are 3-D: a scalar becomes ``(1, 1, 1)``, a vector of
        length d becomes ``(1, 1, d)``, a matrix ``(a, b)`` becomes
        ``(1, a, b)``. Higher-rank inputs pass through unchanged.
        """
        value = torch.tensor(value).float().to(self._device)
        while value.dim() < 3:
            value = value.unsqueeze(0)
        self.context.append(key, value)
        
class MetaWorldTransformerController(DarkroomTransformerController):
    """Transformer controller for MetaWorld tasks.

    Reuses the context construction, ``reset`` and ``append`` of
    ``DarkroomTransformerController``, but ``act`` returns the model's
    prediction as-is instead of converting it into a one-hot vector.
    """

    def act(self, state, sample=False):
        """Return the model's action prediction for ``state``.

        Args:
            state: current observation, presumably shape ``(state_dim,)``.
            sample: accepted so this override matches the parent's
                ``act(state, sample=False)`` signature; it is ignored.

        Returns:
            The first element of the model output, unchanged. Presumably a
            tensor of shape ``(action_dim,)`` -- TODO confirm against model.
        """
        state = torch.tensor(state).float().to(self._device) # (state_dim,)
        self.context["query_states"] = state.unsqueeze(0)
        action = self._model(self.context, test=True)[0] # (action_dim,)

        return action
        