import torch
import torch.nn as nn

from transformers import GPT2Config, GPT2Model

class RewardTransformer(nn.Module):
    """
    Transformer that predicts a scalar reward for a query (state, action) pair.

    The prediction is conditioned on two context trajectories of transitions
    (state, action, next_state). ``traj_1`` is assumed to be preferred over
    ``traj_2``. The context rewards are never read by the model.
    """

    def __init__(self, config, device=None):
        """
        Build the GPT-2 backbone, the input embeddings and the reward head.

        Args:
            config: Object exposing the attributes ``horizon``, ``n_embd``,
                ``n_layer``, ``n_head``, ``state_dim``, ``action_dim`` and
                ``dropout``.
            device: Device the module is moved to. When ``None`` (default),
                CUDA is used if it is available and the CPU otherwise, so the
                model can also be constructed on CPU-only machines.
        """
        super().__init__()

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Cache the hyper-parameters read from the user config.
        self._config = config
        self._horizon = config.horizon
        self._n_embd = config.n_embd
        self._n_layer = config.n_layer
        self._n_head = config.n_head
        self._state_dim = config.state_dim
        self._action_dim = config.action_dim
        self._dropout = config.dropout

        # A context token is the concatenation of the two per-trajectory
        # embeddings (n_embd each), hence the backbone width of 2 * n_embd.
        hidden_dim = 2 * self._n_embd
        transition_dim = 2 * self._state_dim + self._action_dim

        # Named gpt2_config so that it does not shadow the `config` argument.
        gpt2_config = GPT2Config(
            # forward() feeds 1 + context_length tokens, so this bound leaves
            # headroom whenever context_length <= horizon.
            n_positions=4 * (1 + self._horizon),
            n_embd=hidden_dim,
            n_layer=self._n_layer,
            n_head=self._n_head,
            resid_pdrop=self._dropout,
            embd_pdrop=self._dropout,
            attn_pdrop=self._dropout,
            use_cache=False,
        )

        self._transformer = GPT2Model(gpt2_config)
        # Embeds the transitions of the preferred trajectory.
        self._embed_transition_1 = nn.Linear(transition_dim, self._n_embd)
        # Embeds the transitions of the non-preferred trajectory.
        self._embed_transition_2 = nn.Linear(transition_dim, self._n_embd)
        # Embeds the query (state, action) directly at the backbone width.
        self._embed_state_action = nn.Linear(
            self._state_dim + self._action_dim, hidden_dim
        )
        # Maps every hidden state of the backbone to a scalar reward.
        self._pred_rewards = nn.Linear(hidden_dim, 1)
        # NOTE: auxiliary heads (actions, states, next states) might be a
        # helpful extension; none are implemented.

        self._device = device
        self.to(self._device)

    @staticmethod
    def _concat_transitions(traj):
        """Concatenate states, actions and next states along the feature axis."""
        return torch.cat(
            [
                traj["context_states"],
                traj["context_actions"],
                traj["context_next_states"],
            ],
            dim=2,
        )

    def forward(self, x, query_states, query_actions, test=False):
        """
        Take as input preference trajectory information and predict the reward.

        Args:
            x (dict):
                A dictionary with the keys ``"traj_1"`` (preferred trajectory)
                and ``"traj_2"`` (non-preferred trajectory). Each maps to a
                dictionary with:
                    "context_states": Tensor of shape (batch_size, context_length, state_dim)
                    "context_actions": Tensor of shape (batch_size, context_length, action_dim)
                    "context_next_states": Tensor of shape (batch_size, context_length, state_dim)
                    "context_rewards": Tensor of shape (batch_size, context_length, 1)
                ``"context_rewards"`` is not read by this method.
            query_states (Tensor): Tensor of shape (batch_size, state_dim)
            query_actions (Tensor): Tensor of shape (batch_size, action_dim)
            test (bool): If True, return only the prediction at the last
                sequence position.

        Returns:
            Tensor of shape (batch_size, 1) when ``test`` is True. Otherwise a
            tensor of shape (batch_size, context_length, 1) holding the
            predictions at every position after the query token.
        """
        # Each trajectory has its own embedding layer; the two embeddings are
        # joined along the feature axis, giving one token per context step:
        # (batch_size, context_length, 2 * n_embd).
        traj_1_embedding = self._embed_transition_1(
            self._concat_transitions(x["traj_1"])
        )
        traj_2_embedding = self._embed_transition_2(
            self._concat_transitions(x["traj_2"])
        )
        traj_embedding = torch.cat([traj_1_embedding, traj_2_embedding], dim=2)

        # Query token: (batch_size, 1, 2 * n_embd).
        query = torch.cat([query_states, query_actions], dim=1)
        query_embedding = self._embed_state_action(query)[:, None, :]

        # The query token comes first, followed by the context tokens:
        # (batch_size, 1 + context_length, 2 * n_embd).
        stacked_inputs = torch.cat([query_embedding, traj_embedding], dim=1)
        transformer_outputs = self._transformer(inputs_embeds=stacked_inputs)

        # One reward prediction per sequence position.
        preds = self._pred_rewards(transformer_outputs["last_hidden_state"])

        if test:
            return preds[:, -1, :]
        # Drop position 0, which belongs to the query token itself.
        return preds[:, 1:, :]

        