import torch
import torch.nn as nn

from transformers import GPT2Config, GPT2Model


class Transformer(nn.Module):
    """
    General GPT-2 based model that predicts actions from in-context transitions.

    Every input token is one transition ``(state, action, next_state, reward)``
    mapped to ``n_embd`` features by a single linear layer. Token 0 is the
    query state with its action / next-state / reward slots zero-filled;
    tokens ``1..T`` are the context transitions.
    """

    def __init__(self, config, device=torch.device("cuda")):
        super().__init__()

        # Hyper-parameters read from the experiment config.
        self._config = config
        self._horizon = self._config.horizon
        self._n_embd = self._config.n_embd
        self._n_layer = self._config.n_layer
        self._n_head = self._config.n_head
        self._state_dim = self._config.state_dim
        self._action_dim = self._config.action_dim
        self._dropout = self._config.dropout

        # Backbone configuration; a separate name avoids shadowing ``config``.
        gpt2_config = GPT2Config(
            n_positions=40 * (1 + self._horizon),
            n_embd=self._n_embd,
            n_layer=self._n_layer,
            n_head=self._n_head,
            resid_pdrop=self._dropout,
            embd_pdrop=self._dropout,
            attn_pdrop=self._dropout,
            use_cache=False,
        )

        self._transformer = GPT2Model(gpt2_config)
        # One token per transition: state + action + next_state + scalar reward.
        self._embed_transition = nn.Linear(
            2 * self._state_dim + self._action_dim + 1, self._n_embd
        )
        # Output head producing ``action_dim`` scores per position.
        self._pred = nn.Linear(self._n_embd, self._action_dim)

        self._device = device  # Device for the parameters and the zero padding.
        self.to(self._device)

    def forward(self, x, test=False):
        """
        Predict action scores from trajectory context and a query state.

        Args:
            x (dict):
                Dictionary of trajectory tensors.
                x["query_states"]: Tensor of shape (batch_size, state_dim)
                x["context_states"]: Tensor of shape
                    (batch_size, context_length, state_dim), or None / absent
                    when there is no context.
                x["context_actions"]: Tensor of shape (batch_size, context_length, action_dim)
                x["context_next_states"]: Tensor of shape (batch_size, context_length, state_dim)
                x["context_rewards"]: Tensor of shape (batch_size, context_length, 1)
                The context_* entries other than "context_states" are read
                only when "context_states" is given. Other keys
                (e.g. "optimal_actions") are not read here.
            test (bool): If True, return softmax probabilities for the last
                position as a NumPy array instead of raw scores.

        Returns:
            test=False: Tensor of shape (batch_size, context_length, action_dim)
                with raw scores for every context position (the query token's
                output is dropped; empty along dim 1 when there is no context).
            test=True: numpy.ndarray of shape (batch_size, action_dim) with
                softmax probabilities at the last position.
        """
        query_states = x["query_states"][:, None, :]  # (batch_size, 1, state_dim)
        batch_size = query_states.shape[0]

        # Padding for the query token's action / next-state / reward slots.
        # Only the first max(state_dim, action_dim, 1) columns are ever sliced,
        # and allocating directly on the model device avoids a host->device copy.
        pad_width = max(self._state_dim, self._action_dim, 1)
        zeros = torch.zeros(batch_size, 1, pad_width, device=self._device)
        action_pad = zeros[:, :, :self._action_dim]
        state_pad = zeros[:, :, :self._state_dim]
        reward_pad = zeros[:, :, :1]

        # Each *_seq has shape (batch_size, context_length + 1, feature_dim):
        # the query token first, followed by the context transitions.
        context_states = x.get("context_states")
        if context_states is None:
            state_seq = query_states
            action_seq = action_pad
            next_state_seq = state_pad
            reward_seq = reward_pad
        else:
            state_seq = torch.cat([query_states, context_states], dim=1)
            action_seq = torch.cat([action_pad, x["context_actions"]], dim=1)
            next_state_seq = torch.cat([state_pad, x["context_next_states"]], dim=1)
            reward_seq = torch.cat([reward_pad, x["context_rewards"]], dim=1)

        # (batch_size, context_length + 1, 2 * state_dim + action_dim + 1)
        seq = torch.cat([state_seq, action_seq, next_state_seq, reward_seq], dim=2)
        stacked_inputs = self._embed_transition(seq)
        transformer_outputs = self._transformer(inputs_embeds=stacked_inputs)
        # Scores for every position in the sequence.
        preds = self._pred(transformer_outputs["last_hidden_state"])

        if test:
            probs = torch.softmax(preds[:, -1, :], dim=-1)
            return probs.detach().cpu().numpy()

        return preds[:, 1:, :]
    
class PreferenceTransformer(nn.Module):
    """
    GPT-2 based model conditioned on a (preferred, non-preferred) trajectory pair.

    Token 0 is the embedded query state. Token ``t >= 1`` concatenates the
    embeddings of transition ``t - 1`` from the preferred trajectory
    (``traj_1``) and the non-preferred trajectory (``traj_2``), so each token
    is ``2 * n_embd`` wide and both trajectories must have the same length.
    """

    def __init__(self, config, device=torch.device("cuda")):
        super().__init__()

        # Hyper-parameters read from the experiment config.
        self._config = config
        self._horizon = self._config.horizon
        self._n_embd = self._config.n_embd
        self._n_layer = self._config.n_layer
        self._n_head = self._config.n_head
        self._state_dim = self._config.state_dim
        self._action_dim = self._config.action_dim
        self._dropout = self._config.dropout

        # Backbone configuration; width is 2 * n_embd because each token holds
        # one embedding per trajectory. A separate name avoids shadowing ``config``.
        gpt2_config = GPT2Config(
            n_positions=4 * (1 + self._horizon),
            n_embd=self._n_embd * 2,
            n_layer=self._n_layer,
            n_head=self._n_head,
            resid_pdrop=self._dropout,
            embd_pdrop=self._dropout,
            attn_pdrop=self._dropout,
            use_cache=False,
        )

        self._transformer = GPT2Model(gpt2_config)
        # Embeds (state, action, next_state) transitions of the preferred trajectory.
        self._embed_transition_1 = nn.Linear(2 * self._state_dim + self._action_dim, self._n_embd)
        # Embeds (state, action, next_state) transitions of the non-preferred trajectory.
        self._embed_transition_2 = nn.Linear(2 * self._state_dim + self._action_dim, self._n_embd)
        # Embeds the query state directly to the full 2 * n_embd token width.
        self._embed_state = nn.Linear(self._state_dim, 2 * self._n_embd)

        # Output head: ``action_dim`` scores per position (softmaxed when test=True).
        self._pred = nn.Linear(self._n_embd * 2, self._action_dim)
        # TODO: auxiliary heads might help, e.g.
        #   self._pred_actions = nn.Linear(self._n_embd, self._action_dim)
        #   self._pred_next_states = nn.Linear(self._n_embd, self._state_dim)
        #   self._pred_states = nn.Linear(self._n_embd, self._state_dim)

        self._device = device  # Device the parameters are moved to.
        self.to(self._device)

    @staticmethod
    def _transition_features(traj):
        """Concatenate states, actions and next states of one trajectory along the feature axis."""
        return torch.cat(
            [traj["context_states"], traj["context_actions"], traj["context_next_states"]],
            dim=2,
        )

    def forward(self, x, query_states, test=False):
        """
        Predict action scores from a preference pair and a query state.

        Args:
            x (dict):
                Dictionary with two trajectories, "traj_1" (preferred) and
                "traj_2" (non-preferred). Each is a dict with:
                    "context_states": Tensor of shape (batch_size, context_length, state_dim)
                    "context_actions": Tensor of shape (batch_size, context_length, action_dim)
                    "context_next_states": Tensor of shape (batch_size, context_length, state_dim)
                Both trajectories must share the same context_length.
                "context_rewards", if present, is not read: reward signals are
                not used, and traj_1 is assumed to be preferred over traj_2.
            query_states (Tensor): Tensor of shape (batch_size, state_dim).
            test (bool): If True, return softmax probabilities for the last
                position as a NumPy array instead of raw scores.

        Returns:
            test=False: Tensor of shape (batch_size, context_length, action_dim)
                with raw scores for every trajectory position (the query
                token's output is dropped).
            test=True: numpy.ndarray of shape (batch_size, action_dim) with
                softmax probabilities at the last position.
        """
        # (batch_size, context_length, n_embd) each.
        traj_1_embedding = self._embed_transition_1(self._transition_features(x["traj_1"]))
        traj_2_embedding = self._embed_transition_2(self._transition_features(x["traj_2"]))

        # Pair the t-th transition of both trajectories in one token:
        # (batch_size, context_length, 2 * n_embd).
        traj_embedding = torch.cat([traj_1_embedding, traj_2_embedding], dim=2)

        # (batch_size, 1, 2 * n_embd)
        query_embedding = self._embed_state(query_states)[:, None, :]

        # Query token first, then the paired trajectory tokens.
        stacked_inputs = torch.cat([query_embedding, traj_embedding], dim=1)
        transformer_outputs = self._transformer(inputs_embeds=stacked_inputs)
        # Scores for every position in the sequence.
        preds = self._pred(transformer_outputs["last_hidden_state"])

        if test:
            probs = torch.softmax(preds[:, -1, :], dim=-1)
            return probs.detach().cpu().numpy()
        return preds[:, 1:, :]