import numpy as np
import torch
import torch.nn as nn
import sys
import transformers
import torch.optim as optim
import time
import random
from torch.nn import functional as F

from decision_transformer.models.model import TrajectoryModel
from decision_transformer.models.trajectory_gpt2 import GPT2Model

import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerEncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer operating on batch-first tensors.

    Self-attention is followed by a ReLU feed-forward network.  Each sub-layer
    output goes through dropout, is added back to its input (residual), and is
    then normalised with LayerNorm.

    Args:
        embed_dim: model width (last dimension of the input).
        num_heads: number of attention heads.
        ff_dim: hidden width of the feed-forward network.
        dropout: dropout probability used inside attention and after each
            sub-layer.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        # Sub-modules are created in a fixed order so that parameter
        # initialisation and state_dict keys stay stable across versions.
        self.self_attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.linear1 = nn.Linear(embed_dim, ff_dim)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(ff_dim, embed_dim)

        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def _feed_forward(self, hidden):
        """Position-wise MLP: linear1 -> ReLU -> dropout -> linear2."""
        return self.linear2(self.dropout(F.relu(self.linear1(hidden))))

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Apply one encoder layer.

        Args:
            src: input of shape (batch, seq, embed_dim) (batch_first attention).
            src_mask: optional attention mask passed as ``attn_mask``.
            src_key_padding_mask: optional padding mask passed as
                ``key_padding_mask``.

        Returns:
            Tensor with the same shape as ``src``.
        """
        attended, _ = self.self_attn(
            src,
            src,
            src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )
        hidden = self.norm1(src + self.dropout1(attended))
        return self.norm2(hidden + self.dropout2(self._feed_forward(hidden)))

class TransformerEncoder(nn.Module):
    """Stack of ``TransformerEncoderLayer`` blocks with a learned positional table.

    The input is linearly projected to ``embed_dim``, and a learned positional
    embedding (zero-initialised) is added. Dropout is then applied, and the
    result passes through ``num_layers`` encoder layers.

    Args:
        input_dim: size of the last dimension of the input.
        embed_dim: model width.
        num_heads: attention heads per layer.
        ff_dim: feed-forward hidden width per layer.
        num_layers: number of stacked encoder layers.
        dropout: dropout probability.
        max_len: longest sequence the positional table supports.
    """

    def __init__(self, input_dim, embed_dim=256, num_heads=8, ff_dim=512, num_layers=8, dropout=0.1, max_len=5000):
        super().__init__()
        self.input_linear = nn.Linear(input_dim, embed_dim)
        self.pos_embedding = nn.Parameter(torch.zeros(1, max_len, embed_dim))
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dim, num_heads, ff_dim, dropout) for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None, key_padding_mask=None):
        """Encode a batch of sequences.

        Args:
            x: tensor of shape (batch, seq_len, input_dim).
            mask: optional attention mask forwarded to every layer as
                ``src_mask``.
            key_padding_mask: optional padding mask forwarded to every layer
                as ``src_key_padding_mask``.

        Returns:
            Tensor of shape (batch, seq_len, embed_dim).

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len`` of the positional
                table (previously this surfaced as an opaque shape error).
        """
        _, seq_len, _ = x.size()
        max_len = self.pos_embedding.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={max_len} "
                "of the positional embedding"
            )
        x = self.input_linear(x)
        x = x + self.pos_embedding[:, :seq_len, :]
        x = self.dropout(x)

        for layer in self.layers:
            x = layer(x, src_mask=mask, src_key_padding_mask=key_padding_mask)
        return x

class Cross_attetion_DecisionTransformer(TrajectoryModel):
    """Decision Transformer with a per-step latent ``z`` and a subgoal decoder.

    The GPT-2 backbone models the interleaved token sequence
    ``(R_1, s_1, z_1, a_1, R_2, s_2, z_2, a_2, ...)``. Each timestep contributes
    four tokens: return-to-go, state, latent and action. The latent ``z`` is
    produced by ``self.mlp`` from a (state, return) embedding pair. Actions are
    predicted from the backbone output at the ``z`` token.

    A separate cross-attention decoder (``self.decoder``) predicts
    ``nom_subgoal`` subgoal states per timestep while attending to the ``z``
    embeddings; see ``forward_pre``.

    NOTE(review): the decoder is fed ``hidden_size``-wide embeddings, so it
    presumably requires ``d_model == hidden_size`` — confirm with callers.
    """

    def __init__(
            self,
            device,
            state_dim,
            act_dim,
            hidden_size,
            d_model,
            nhead,
            num_layers,
            max_length=None,
            max_ep_len=4096,
            action_tanh=True,
            nom_subgoal=5,
            **kwargs
    ):
        super().__init__(state_dim, act_dim, max_length=max_length)
        self.state_dim = state_dim
        self.hidden_size = hidden_size
        self.input_dim = state_dim + 1  # not used by this class's methods
        self.output_dim = 16  # width of the latent z produced by self.mlp
        self.device = device
        # Record the decoder hyper-parameters actually used below (these were
        # previously hard-coded to 128 / 1 / 1 regardless of the arguments).
        self.d_model = d_model
        self.nhead = nhead
        self.num_layers = num_layers
        config = transformers.GPT2Config(
            vocab_size=1,  # doesn't matter -- we don't use the vocab
            n_embd=hidden_size,
            **kwargs
        )

        # note: the only difference between this GPT2Model and the default Huggingface version
        # is that the positional embeddings are removed (since we'll add those ourselves)
        # Sub-modules below are created in the original order so parameter
        # initialisation under a fixed seed and state_dict keys are unchanged.
        self.transformer = GPT2Model(config)
        self.embed_state = torch.nn.Linear(self.state_dim, hidden_size)
        self.embed_action = torch.nn.Linear(self.act_dim, hidden_size)
        self.embed_reward = torch.nn.Linear(1, hidden_size)
        self.embed_timestep = nn.Embedding(max_ep_len, hidden_size)
        self.embed_subgoal = torch.nn.Linear(self.state_dim, hidden_size)
        self.pre_z = torch.nn.Linear(self.hidden_size, self.output_dim)
        self.embed_z = torch.nn.Linear(self.output_dim, self.hidden_size)
        self.pre_subgoal = torch.nn.Linear(hidden_size, self.state_dim)
        self.embed_ln = nn.LayerNorm(hidden_size)
        self.nom_subgoal = nom_subgoal
        self.predict_state = torch.nn.Linear(hidden_size, self.state_dim)
        self.predict_action = nn.Sequential(
            *([nn.Linear(hidden_size, self.act_dim)] + ([nn.Tanh()] if action_tanh else []))
        )
        # Maps a concatenated (state, return) embedding pair to the latent z.
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size*2, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, self.output_dim),
            nn.ReLU()
        )
        # Not used by any method of this class; kept for state_dict compatibility.
        self.encoder = TransformerEncoder(input_dim=self.hidden_size*2)
        self.decoder = TransformerDecoder(d_model, nhead, num_layers)
        self.predict_return = torch.nn.Linear(hidden_size, 1)

    def _encode_trajectory(self, returns_embeddings, state_embeddings, z_embeddings,
                           action_embeddings, attention_mask):
        """Interleave per-step token embeddings, run GPT-2 and un-interleave.

        All four embeddings have shape (batch, seq_length, hidden_size), and
        ``attention_mask`` has shape (batch, seq_length).

        Returns:
            Tensor of shape (batch, 4, seq_length, hidden_size). Index 0/1/2/3
            along dim 1 holds the outputs at the R/s/z/a tokens.
        """
        batch_size, seq_length = state_embeddings.shape[0], state_embeddings.shape[1]
        # Sequence layout: (R_1, s_1, z_1, a_1, R_2, s_2, z_2, a_2, ...).
        stacked_inputs = torch.stack(
            (returns_embeddings, state_embeddings, z_embeddings, action_embeddings), dim=1
        ).permute(0, 2, 1, 3).reshape(batch_size, 4 * seq_length, self.hidden_size)
        stacked_inputs = self.embed_ln(stacked_inputs)
        # Each timestep's mask value applies to all four of its tokens.
        stacked_attention_mask = torch.stack(
            (attention_mask,) * 4, dim=1
        ).permute(0, 2, 1).reshape(batch_size, 4 * seq_length)
        # We feed input embeddings (not word indices as in NLP) to the model.
        transformer_outputs = self.transformer(
            inputs_embeds=stacked_inputs,
            attention_mask=stacked_attention_mask,
        )
        x = transformer_outputs['last_hidden_state']
        return x.reshape(batch_size, seq_length, 4, self.hidden_size).permute(0, 2, 1, 3)

    def forward(
        self,
        states,
        actions,
        rewards,
        dones,
        returns_to_go,
        timesteps,
        attention_mask,
        target_rtg,
        subgoal
        ):
        """Predict an action for every timestep of a batch of trajectories.

        Args:
            states: tensor indexed as (batch, seq_length, state_dim).
            actions: last dimension ``act_dim``, aligned per step with ``states``.
            rewards, dones, subgoal: accepted for interface compatibility; unused.
            returns_to_go: return-to-go per step (last dimension 1).
            timesteps: integer timestep indices for ``embed_timestep``.
            attention_mask: (batch, seq_length), where 1 = attend and
                0 = padding. ``None`` means attend to every step.
            target_rtg: target return per step. It must hold
                ``batch * seq_length`` scalars, because its embedding is
                reshaped to (batch, seq_length, hidden_size).

        Returns:
            Action predictions of shape (batch, seq_length, act_dim), read
            from the backbone output at each ``z`` token.
        """
        batch_size, seq_length = states.shape[0], states.shape[1]
        if attention_mask is None:
            # Create the default mask on the inputs' device. A CPU mask here
            # caused a device mismatch when the model ran on GPU.
            attention_mask = torch.ones(
                (batch_size, seq_length), dtype=torch.long, device=states.device
            )

        state_embeddings = self.embed_state(states)
        target_rtg_embeddings = self.embed_reward(target_rtg).reshape(
            batch_size, seq_length, self.hidden_size
        )
        # Latent z from the (state, target return) pair; same order as get_action.
        combined = torch.cat([state_embeddings, target_rtg_embeddings], dim=-1).reshape(
            batch_size, seq_length, self.hidden_size * 2
        )
        pre_z = self.mlp(combined)

        # Timestep embeddings act as positional encodings for all four token types.
        time_embeddings = self.embed_timestep(timesteps)
        x = self._encode_trajectory(
            self.embed_reward(returns_to_go) + time_embeddings,
            state_embeddings + time_embeddings,
            self.embed_z(pre_z) + time_embeddings,
            self.embed_action(actions) + time_embeddings,
            attention_mask,
        )
        # Actions are predicted from the z token (index 2 in R, s, z, a).
        return self.predict_action(x[:, 2])

    def _decode_subgoals(self, s, target_rtg, subgoal, tgt_mask, all_rtg):
        """Shared body of the ``forward_pre*`` methods.

        Returns:
            ``(pre_subgoal, pre_z)``. ``pre_subgoal`` has shape
            (batch, seq_length, nom_subgoal, state_dim). ``pre_z`` is the latent
            of shape (batch, seq_length, output_dim).
        """
        batch_size, seq_length = s.shape[0], s.shape[1]
        state_embeddings = self.embed_state(s)
        target_rtg_embeddings = self.embed_reward(target_rtg).reshape(
            batch_size, seq_length, self.hidden_size
        )
        # NOTE(review): this concatenates (rtg, state), whereas ``forward`` and
        # ``get_action`` concatenate (state, rtg) before the same ``self.mlp``.
        # The latents therefore differ between the two paths — confirm intent.
        combined = torch.cat([target_rtg_embeddings, state_embeddings], dim=-1).reshape(
            -1, self.hidden_size * 2
        )
        pre_z = self.mlp(combined).reshape(batch_size, seq_length, self.output_dim)
        # Decoder memory: one z token per (batch, timestep) pair, in the
        # sequence-first layout (1, batch*seq_length, hidden_size).
        z_embeddings = self.embed_z(pre_z).reshape(
            batch_size * seq_length, -1, self.hidden_size
        ).permute(1, 0, 2)
        subgoal_embeddings = self.embed_subgoal(subgoal).reshape(
            batch_size * seq_length, -1, self.hidden_size
        )
        all_rtg_embeddings = self.embed_reward(all_rtg).reshape(
            batch_size * seq_length, -1, self.hidden_size
        )
        # Interleave as (R_1, g_1, R_2, g_2, ...) for each (batch, timestep),
        # then switch to the sequence-first layout the decoder expects.
        stacked = torch.stack(
            [all_rtg_embeddings, subgoal_embeddings], dim=1
        ).permute(0, 2, 1, 3).reshape(
            batch_size * seq_length, 2 * self.nom_subgoal, self.hidden_size
        ).permute(1, 0, 2)
        pred = self.decoder(stacked, z_embeddings, tgt_mask).permute(1, 0, 2).reshape(
            batch_size, seq_length, -1, self.hidden_size
        )
        # Even positions are the return tokens; their outputs are projected to subgoals.
        pre_subgoal = self.pre_subgoal(pred[:, :, ::2, :])
        return pre_subgoal, pre_z

    @staticmethod
    def _sample_consecutive_z(pre_z):
        """Return ``(pre_z[:, i], pre_z[:, i + 1])`` for a uniformly random valid ``i``.

        The index range follows the actual sequence length. It used to be
        hard-coded to ``[0, 18]``, which only fit ``seq_length == 20``.

        Raises:
            ValueError: if there are fewer than two timesteps.
        """
        seq_length = pre_z.shape[1]
        if seq_length < 2:
            raise ValueError(
                f"need at least 2 timesteps to sample a consecutive pair, got {seq_length}"
            )
        i = random.randint(0, seq_length - 2)
        return pre_z[:, i, :], pre_z[:, i + 1, :]

    def forward_pre_noisy(self, s, target_rtg, subgoal, tgt_mask, all_rtg, noisy):
        """Compute the same outputs as ``forward_pre``; see there for arguments and returns.

        NOTE(review): a Gaussian sample (std 0.1) is drawn from NumPy's global
        RNG but never applied, and ``noisy`` is ignored. The draw is kept so
        the NumPy RNG stream is unchanged. Confirm where the noise was meant
        to be added.
        """
        noise_std_dev = 0.1
        _noise = np.random.normal(0, noise_std_dev, subgoal.shape)
        pre_subgoal, pre_z = self._decode_subgoals(s, target_rtg, subgoal, tgt_mask, all_rtg)
        z_t, z_next = self._sample_consecutive_z(pre_z)
        return pre_subgoal, z_t, z_next

    def forward_pre(self, s, target_rtg, subgoal, tgt_mask, all_rtg):
        """Predict subgoals with the cross-attention decoder.

        Args:
            s: states, indexed as (batch, seq_length, state_dim).
            target_rtg: target return per step (``batch * seq_length`` scalars).
            subgoal: subgoal states. Each (batch, timestep) must embed to
                ``nom_subgoal`` tokens, presumably
                (batch, seq_length, nom_subgoal, state_dim) — TODO confirm.
            tgt_mask: decoder self-attention mask over the ``2 * nom_subgoal``
                interleaved (return, subgoal) tokens.
            all_rtg: returns paired with each subgoal, presumably
                (batch, seq_length, nom_subgoal, 1) — TODO confirm.

        Returns:
            ``(pre_subgoal, z_i, z_next)``. ``pre_subgoal`` is the subgoal
            prediction of shape (batch, seq_length, nom_subgoal, state_dim).
            ``z_i`` and ``z_next`` are the latents at a random pair of
            consecutive timesteps.
        """
        pre_subgoal, pre_z = self._decode_subgoals(s, target_rtg, subgoal, tgt_mask, all_rtg)
        z_t, z_next = self._sample_consecutive_z(pre_z)
        return pre_subgoal, z_t, z_next

    def forward_pre_e(self, s, target_rtg, subgoal, tgt_mask, all_rtg):
        """Like ``forward_pre`` but return only the subgoal predictions."""
        pre_subgoal, _ = self._decode_subgoals(s, target_rtg, subgoal, tgt_mask, all_rtg)
        return pre_subgoal

    def loss_decoder(self, target, pred):
        """Mean squared error between ``pred`` and ``target`` (with broadcasting)."""
        subgoal_loss = torch.mean((pred - target) ** 2)
        return subgoal_loss

    def get_mask_subgoals(self, prev_subgoals, prev_rtg, mask_prob=0.3, mask_value=0.0):
        """Randomly mask whole timesteps of subgoals and their returns.

        Each (batch, timestep) position is selected independently with
        probability ``mask_prob``. At selected positions, every element of
        ``prev_subgoals`` and of ``prev_rtg`` is overwritten with
        ``mask_value``. The inputs are not modified.

        Args:
            prev_subgoals: tensor of shape (B, T, D).
            prev_rtg: tensor of shape (B, T, E) for any E (e.g. E == 1 or E == D).
            mask_prob: probability that a timestep is masked.
            mask_value: scalar written into masked entries.

        Returns:
            ``(masked_subgoals, masked_rtg)``.
        """
        B, T, _ = prev_subgoals.shape
        # Sampled with the CPU generator, then moved to the data device.
        mask = torch.rand(B, T) < mask_prob
        mask = mask.to(prev_subgoals.device).unsqueeze(-1)  # (B, T, 1)

        masked_subgoals = prev_subgoals.clone()
        masked_rtg = prev_rtg.clone()
        masked_subgoals[mask.expand_as(prev_subgoals)] = mask_value
        # Expand to prev_rtg's own shape. Expanding to prev_subgoals' shape
        # failed whenever prev_rtg had a different last dimension (e.g. 1).
        masked_rtg[mask.expand_as(prev_rtg)] = mask_value
        return masked_subgoals, masked_rtg

    def get_action(self, states, actions, rewards, z, returns_to_go, timesteps, **kwargs):
        """Return the action for the latest step of a single (unbatched) trajectory.

        Args:
            states, actions, returns_to_go, timesteps: history of one episode.
                They are reshaped to (1, T, ...) internally.
            rewards: unused; past rewards are not modelled.
            z: previously generated latents, reshaped to (1, -1, output_dim).
                The latent for the current step is appended before running
                the model.
            **kwargs: ignored.

        Returns:
            ``(action, z_history)``. ``action`` has shape (act_dim,).
            ``z_history`` is the latent sequence including the new latent,
            taken before truncation/padding to ``max_length`` (presumably
            passed back as ``z`` on the next call).
        """
        z = z.reshape(1, -1, self.output_dim)
        states = states.reshape(1, -1, self.state_dim)
        actions = actions.reshape(1, -1, self.act_dim)
        returns_to_go = returns_to_go.reshape(1, -1, 1)
        # Latent for the current step, from the latest state and return-to-go.
        combined = torch.cat(
            [
                self.embed_state(states[0, -1]).reshape(1, -1, self.hidden_size),
                self.embed_reward(returns_to_go[0, -1]).reshape(1, -1, self.hidden_size),
            ],
            dim=-1,
        )
        pre_z = self.mlp(combined)
        z = torch.cat([z, pre_z], dim=1)
        z_clone = z[:]
        timesteps = timesteps.reshape(1, -1)
        if self.max_length is not None:
            states = states[:, -self.max_length:]
            actions = actions[:, -self.max_length:]
            returns_to_go = returns_to_go[:, -self.max_length:]
            z = z[:, -self.max_length:]
            timesteps = timesteps[:, -self.max_length:]

            # Left-pad all tokens to max_length; padded steps get mask 0.
            attention_mask = torch.cat([torch.zeros(self.max_length-states.shape[1]), torch.ones(states.shape[1])])
            attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1)
            states = torch.cat(
                [torch.zeros((states.shape[0], self.max_length-states.shape[1], self.state_dim), device=states.device), states],
                dim=1).to(dtype=torch.float32)
            actions = torch.cat(
                [torch.zeros((actions.shape[0], self.max_length - actions.shape[1], self.act_dim),
                             device=actions.device), actions],
                dim=1).to(dtype=torch.float32)
            returns_to_go = torch.cat(
                [torch.zeros((returns_to_go.shape[0], self.max_length-returns_to_go.shape[1], 1), device=returns_to_go.device), returns_to_go],
                dim=1).to(dtype=torch.float32)
            z = torch.cat(
                [torch.zeros((z.shape[0], self.max_length-z.shape[1], self.output_dim), device=z.device), z],
                dim=1).to(dtype=torch.float32)
            timesteps = torch.cat(
                [torch.zeros((timesteps.shape[0], self.max_length-timesteps.shape[1]), device=timesteps.device), timesteps],
                dim=1
            ).to(dtype=torch.long)
        else:
            # Attend to every step. Leaving this as None crashed the mask
            # stacking in _encode_trajectory (torch.stack on None).
            attention_mask = torch.ones(
                (1, states.shape[1]), dtype=torch.long, device=states.device
            )

        # Timestep embeddings act as positional encodings for all four token types.
        time_embeddings = self.embed_timestep(timesteps)
        x = self._encode_trajectory(
            self.embed_reward(returns_to_go) + time_embeddings,
            self.embed_state(states) + time_embeddings,
            self.embed_z(z) + time_embeddings,
            self.embed_action(actions) + time_embeddings,
            attention_mask,
        )
        # Action for the last step, predicted from its z token (index 2).
        action_preds = self.predict_action(x[:, 2])[0, -1]
        return action_preds, z_clone
class TransformerDecoderLayer(nn.Module):
    """Post-norm Transformer decoder layer with sequence-first tensors.

    The layer applies three sub-layers in turn: self-attention over the
    target, cross-attention from target to memory, and a ReLU feed-forward
    network. Each sub-layer output passes through dropout, a residual add and
    LayerNorm. The attention modules are built without ``batch_first``, so
    inputs are laid out as (seq, batch, d_model).

    Args:
        d_model: model width.
        nhead: number of attention heads.
        dim_feedforward: hidden width of the feed-forward network.
        dropout: dropout probability.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        # Creation order is kept fixed so parameter initialisation and
        # state_dict keys stay the same.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Run one decoder layer.

        Args:
            tgt: target sequence, (tgt_len, batch, d_model).
            memory: encoder output to attend to, (mem_len, batch, d_model).
            tgt_mask / tgt_key_padding_mask: masks for the self-attention.
            memory_mask / memory_key_padding_mask: masks for the cross-attention.

        Returns:
            Tensor with the same shape as ``tgt``.
        """
        # 1) Self-attention over the target sequence.
        self_out, _ = self.self_attn(
            tgt, tgt, tgt,
            attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask,
        )
        hidden = self.norm1(tgt + self.dropout1(self_out))

        # 2) Cross-attention: target positions query the memory.
        cross_out, _ = self.multihead_attn(
            hidden, memory, memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )
        hidden = self.norm2(hidden + self.dropout2(cross_out))

        # 3) Position-wise feed-forward network.
        ff_out = self.linear2(self.dropout(F.relu(self.linear1(hidden))))
        return self.norm3(hidden + self.dropout3(ff_out))

class TransformerDecoder(nn.Module):
    """Stack of ``TransformerDecoderLayer`` blocks followed by a final LayerNorm.

    Args:
        d_model: model width.
        nhead: attention heads per layer.
        num_layers: number of stacked layers.
        dim_feedforward: feed-forward hidden width per layer.
        dropout: dropout probability.
    """

    def __init__(self, d_model, nhead, num_layers, dim_feedforward=2048, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout)
            for _ in range(num_layers)
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Pass ``tgt`` through every layer (each attending to ``memory``), then normalise.

        All masks are forwarded unchanged to each layer. The output has the
        same shape as ``tgt``.
        """
        out = tgt
        for block in self.layers:
            out = block(
                out,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
            )
        return self.norm(out)

# class PlanTransformer(TrajectoryModel):
#     """
#     Using second transformer as anti-causal aggregator
#     """
#     def __init__(
#             self,
#             state_dim,
#             act_dim,
#             hidden_size,
#             device,
#             output_dim,
#             max_length=None,
#             max_ep_len=4096,
#             **kwargs
#     ):
        
#         super().__init__(state_dim, act_dim,max_length=max_length)
#         self.state_dim = state_dim
#         self.hidden_size = hidden_size
#         config = transformers.GPT2Config(
#             vocab_size=1,  # doesn't matter -- we don't use the vocab
#             n_embd=hidden_size,
#             **kwargs
#         )
#         self.device =device
#         self.transformer = GPT2Model(config)
#         self.output_dim = output_dim
#         self.embed_timestep = nn.Embedding(max_ep_len, hidden_size)
#         self.embed_subgoal = torch.nn.Linear(self.state_dim, hidden_size)
#         self.embed_z = torch.nn.Linear(self.output_dim, hidden_size)
#         self.pre_subgoal = torch.nn.Linear(hidden_size,self.state_dim)
#         self.embed_rtg = torch.nn.Linear(1,self.hidden_size)

#         self.embed_ln = nn.LayerNorm(hidden_size)
#     def forward(self, pre_z,subgoal,target_rtg,plan_masks):
#         batch_size, seq_length = pre_z.shape[0], pre_z.shape[1]
#         pre_z = pre_z.unsqueeze(2)
#         pre_z = (self.embed_z(pre_z)).reshape(batch_size*seq_length,-1,self.hidden_size)
#         subgoal = self.embed_subgoal(subgoal)
#         target_rtg = self.embed_rtg(target_rtg)
#         stacked_inputs = torch.stack(
#             (target_rtg,subgoal),dim=2
#         ).permute(0, 1, 3, 2, 4).reshape(batch_size*seq_length, -1, self.hidden_size)
#         stacked_inputs = torch.cat((pre_z,stacked_inputs),dim=1)
#         stacked_inputs = self.embed_ln(stacked_inputs)

#         transformer_outputs = self.transformer(
#             inputs_embeds=stacked_inputs,
#             attention_mask = plan_masks
#         )
#         x = transformer_outputs['last_hidden_state']
#         # reshape x so that the second dimension corresponds to
#         # predicting states (1)
#         # get predictions
#         a = x[:, [1, 3, 5, 7, 9], :]  # for every batch element, take the 5 tokens at odd positions 1, 3, 5, 7, 9 (each hidden_size-dim)
#         a = a.reshape(batch_size,seq_length,-1,self.hidden_size)
#         predict_subgoal = self.pre_subgoal(a) 
#         return predict_subgoal
#     def get_subgoal(self, states,subgoal,returns_to_go,iter_num,**kwargs):
#         # we don't care about the past rewards in this model
#         states = states.reshape(1, -1, self.state_dim)
#         if iter_num:
#             subgoal = subgoal.reshape(1,iter_num,self.state_dim)
#             states = torch.cat([states, subgoal], dim=1)
#         returns_to_go = returns_to_go.reshape(1,-1, 1)
#         # plan_timesteps = plan_timesteps.reshape(1, -1)

#         if self.max_length is not None:
#             states = states[:,-self.max_length:]
#             returns_to_go = returns_to_go[:,-self.max_length:]
#             # plan_timesteps = plan_timesteps[:,-self.max_length:]

#             # pad all tokens to sequence length
#             attention_mask = torch.cat([torch.zeros(self.max_length-states.shape[1]), torch.ones(states.shape[1])])
#             attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1)
#             states = torch.cat(
#                 [torch.zeros((states.shape[0], self.max_length-states.shape[1], self.state_dim), device=states.device), states],
#                 dim=1).to(dtype=torch.float32)
#             returns_to_go = torch.cat(
#                 [torch.zeros((returns_to_go.shape[0], self.max_length-returns_to_go.shape[1], 1), device=returns_to_go.device), returns_to_go],
#                 dim=1).to(dtype=torch.float32)
#         else:
#             attention_mask = None
#         batch_size, seq_length = states.shape[0], states.shape[1]
#         state_embeddings = self.embed_state(states)
#         returns_embeddings = self.embed_reward(returns_to_go)
#         # time_embeddings = self.embed_timestep(plan_timesteps)
#         # time embeddings are treated similar to positional embeddings
#         # state_embeddings = state_embeddings + time_embeddings
#         # returns_embeddings = returns_embeddings + time_embeddings
#         state_embeddings = state_embeddings.reshape(1,-1,self.hidden_size)
#         returns_embeddings = returns_embeddings.reshape(1,-1,self.hidden_size)


#         # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
#         # which works nice in an autoregressive sense since states predict actions
#         stacked_inputs = torch.stack(
#             (returns_embeddings, state_embeddings), dim=1
#         ).permute(0, 2, 1, 3).reshape(batch_size, -1, self.hidden_size)
#         stacked_inputs = self.embed_ln(stacked_inputs)

#         # to make the attention mask fit the stacked inputs, have to stack it as well
#         stacked_attention_mask = torch.stack(
#             (attention_mask, attention_mask), dim=1
#         ).permute(0, 2, 1).reshape(batch_size,seq_length*2)

#         # we feed in the input embeddings (not word indices as in NLP) to the model
#         transformer_outputs = self.transformer(
#             inputs_embeds=stacked_inputs,
#             attention_mask=stacked_attention_mask,
#         )
#         x = transformer_outputs['last_hidden_state']

#         # reshape x so that the second dimension corresponds to the original
#         # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
#         subgoal_preds = self.pre_state(x[0,-1]) # predict next action given state

#         return subgoal_preds,

