import torch
import torch.nn as nn

from transformers import GPT2Config, GPT2Model

from .utils import calculate_cumulative_rewards

class VValueTransformer(nn.Module):
    """
    Class for V-value transformers.

    A GPT-2 backbone is fed one token per context step. Each token is the
    linear embedding of the step's state concatenated with a scalar taken
    from the cumulative-reward sequence shifted right by one position (the
    first token gets 0). A linear head maps every hidden state to a single
    scalar prediction.

    Args:
        config: Object exposing the attributes ``horizon``, ``n_embd``,
            ``n_layer``, ``n_head``, ``state_dim``, ``action_dim`` and
            ``dropout``.
        gamma (float): Forwarded to ``calculate_cumulative_rewards``
            (presumably the discount factor -- confirm against ``utils``).
        device: Device the module is moved to at construction time. When
            ``None`` (the default), CUDA is used if it is available and the
            CPU otherwise. An explicitly passed device is used as given.
    """

    def __init__(self, config, gamma=0.9, device=None):
        super().__init__()
        self._config = config
        self._horizon = self._config.horizon
        self._n_embd = self._config.n_embd
        self._n_layer = self._config.n_layer
        self._n_head = self._config.n_head
        self._state_dim = self._config.state_dim
        self._action_dim = self._config.action_dim
        self._dropout = self._config.dropout

        # Backbone configuration. It gets its own name so the caller's
        # ``config`` argument is not shadowed. ``n_positions`` is the longest
        # token sequence the backbone's position embeddings can address.
        gpt2_config = GPT2Config(
            n_positions=4 * (1 + self._horizon),
            n_embd=self._n_embd,
            n_layer=self._n_layer,
            n_head=self._n_head,
            resid_pdrop=self._dropout,
            embd_pdrop=self._dropout,
            attn_pdrop=self._dropout,
            use_cache=False,
        )

        # Initializing the models
        self._transformer = GPT2Model(gpt2_config)
        # One token per step: state features plus one scalar reward feature.
        self._embed_transition = nn.Linear(self._state_dim + 1, self._n_embd)

        # Scalar prediction head applied to every position.
        self._pred = nn.Linear(self._n_embd, 1)

        self._gamma = gamma

        # Device the parameters are placed on at construction time.
        self._device = self._resolve_device(device)
        self.to(self._device)

    @staticmethod
    def _resolve_device(device):
        """Return ``device`` unchanged, or pick CUDA/CPU when it is ``None``."""
        if device is not None:
            return device
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, x):
        """
        Take as input trajectory information and predict one scalar per
        context position.

        Args:
            x (dict):
                A dictionary containing information about the trajectory.
                Only the following two entries are read by this method:
                x["context_states"]: Tensor of shape (batch_size, context_length, state_dim)
                x["context_rewards"]: Tensor of shape (batch_size, context_length, 1)
                Any other entries (e.g. "context_actions",
                "context_next_states", "query_states", "optimal_actions")
                are ignored here.

        Returns:
            torch.Tensor: Output of the prediction head on the transformer's
            last hidden state, of shape (batch_size, context_length, 1) for
            inputs shaped as documented above.
        """
        context_states = x["context_states"]
        batch_size = context_states.shape[0]

        # Drop the last cumulative reward so that, after the zero is
        # prepended below, position t carries the value computed for t - 1.
        # Assumes calculate_cumulative_rewards returns a tensor shaped like
        # its input, i.e. (batch_size, context_length, 1) -- TODO confirm.
        cumulative_rewards = calculate_cumulative_rewards(
            x["context_rewards"], self._gamma
        )[:, :-1, :]

        # Allocate the padding next to the data it is concatenated with
        # instead of on the device cached in __init__: that attribute goes
        # stale when the module is moved with .to()/.cuda() after
        # construction, which would make torch.cat fail on a device mismatch.
        reward_zeros = cumulative_rewards.new_zeros((batch_size, 1, 1))
        reward_seq = torch.cat(
            [reward_zeros, cumulative_rewards], dim=1
        )  # (batch_size, context_length, 1)

        seq = torch.cat(
            [reward_seq, context_states], dim=2
        )  # (batch_size, context_length, state_dim + 1)
        stacked_inputs = self._embed_transition(seq)
        transformer_outputs = self._transformer(inputs_embeds=stacked_inputs)
        preds = self._pred(transformer_outputs["last_hidden_state"])

        return preds
        
        