"""
highly based on
https://github.com/kzl/decision-transformer/blob/master/atari/mingpt/model_atari.py
https://github.com/kzl/decision-transformer/blob/master/gym/decision_transformer/models/decision_transformer.py
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

import math
import gym
import contextlib
import numpy as np


class CausalSelfAttention(nn.Module):
    """
    Multi-head causal self-attention combined with an additive padding mask.

    ``config`` keys read: 'n_embd', 'n_head' (must divide 'n_embd'),
    'attn_pdrop', 'resid_pdrop' and 'n_ctx' (side length of the causal mask,
    i.e. the longest sequence this layer accepts).
    """

    def __init__(self, config):
        super().__init__()
        width = config['n_embd']
        heads = config['n_head']
        assert width % heads == 0
        ctx = config['n_ctx']

        # Key / query / value projections for all heads at once.
        # (Creation order matters for parameter order and seeded init.)
        self.key = nn.Linear(width, width)
        self.query = nn.Linear(width, width)
        self.value = nn.Linear(width, width)

        # Regularisation on the attention weights and on the output.
        self.attn_drop = nn.Dropout(config['attn_pdrop'])
        self.resid_drop = nn.Dropout(config['resid_pdrop'])

        # Lower-triangular (1, 1, n_ctx, n_ctx) mask: query i may see keys j <= i.
        lower = torch.ones(ctx, ctx).tril().view(1, 1, ctx, ctx)
        self.register_buffer("bias", lower)
        # Score substituted at causally masked positions before the softmax.
        self.register_buffer("masked_bias", torch.tensor(-1e4))

        self.proj = nn.Linear(width, width)
        self.n_head = heads

    def forward(self, x, mask):
        """
        Args:
            x: (B, T, C) token embeddings.
            mask: B*T values (viewed as (B, T)); 1 marks a key that may be
                attended to, 0 marks padding (penalised by -10000).

        Returns:
            (B, T, C) projected attention output (after residual dropout).

        Side effect: the post-softmax, pre-dropout attention weights
        (B, n_head, T, T) are stored in ``self._attn_map``.
        """
        batch, steps, width = x.size()
        head_dim = width // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, nh, T, hs)
            return t.view(batch, steps, self.n_head, head_dim).transpose(1, 2)

        k = split_heads(self.key(x))
        q = split_heads(self.query(x))
        v = split_heads(self.value(x))

        # Padding mask as an additive penalty, broadcast over heads and queries.
        pad_penalty = (1.0 - mask.view(batch, -1)[:, None, None, :]) * -10000.0

        # (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        scale = 1.0 / math.sqrt(k.size(-1))
        scores = (q @ k.transpose(-2, -1)) * scale
        allowed = self.bias[:, :, :steps, :steps].bool()
        scores = torch.where(allowed, scores, self.masked_bias.to(scores.dtype))
        scores = scores + pad_penalty

        weights = F.softmax(scores, dim=-1)
        self._attn_map = weights.clone()

        # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) -> (B, T, C)
        mixed = self.attn_drop(weights) @ v
        merged = mixed.transpose(1, 2).contiguous().view(batch, steps, width)

        return self.resid_drop(self.proj(merged))


class Block(nn.Module):
    """
    Pre-LayerNorm GPT block: causal self-attention, then a GELU MLP.

    Each sub-layer is wrapped in a residual connection. ``config`` keys read
    here: 'n_embd', 'n_inner', 'resid_pdrop' (plus whatever
    CausalSelfAttention reads).
    """

    def __init__(self, config):
        super().__init__()
        width = config['n_embd']
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        inner = config['n_inner']
        self.mlp = nn.Sequential(
            nn.Linear(width, inner),
            nn.GELU(),
            nn.Linear(inner, width),
            nn.Dropout(config['resid_pdrop']),
        )

    def forward(self, inputs_embeds, attention_mask):
        """Apply attention + MLP with residuals; ``attention_mask`` goes to the attention."""
        attended = self.attn(self.ln1(inputs_embeds), attention_mask)
        hidden = inputs_embeds + attended
        return hidden + self.mlp(self.ln2(hidden))


class MMDnet(nn.Module):
    """
    Map (state, action) pairs to ``particle_num`` output values.

    States go through a 3-layer MLP and actions through a 2-layer MLP (both
    to ``hidden_size``). The two feature vectors are concatenated on the last
    axis and fed to a 3-layer head producing ``particle_num`` outputs.
    """

    def __init__(self, state_dim, action_dim, hidden_size, particle_num):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.hidden_size = hidden_size
        self.particle_num = particle_num

        hid = hidden_size
        self.statenet = nn.Sequential(
            nn.Linear(state_dim, hid), nn.ReLU(),
            nn.Linear(hid, hid), nn.ReLU(),
            nn.Linear(hid, hid),
        )
        self.actionnet = nn.Sequential(
            nn.Linear(action_dim, hid), nn.ReLU(),
            nn.Linear(hid, hid),
        )
        self.mmdnet = nn.Sequential(
            nn.Linear(2 * hid, hid), nn.ReLU(),
            nn.Linear(hid, hid), nn.ReLU(),
            nn.Linear(hid, particle_num),
        )

    def forward(self, states, actions):
        """Return ``particle_num`` values per (state, action) pair (last axis)."""
        state_features = self.statenet(states)
        action_features = self.actionnet(actions)
        joint = torch.cat([state_features, action_features], dim=-1)
        return self.mmdnet(joint)


class DecisionTransformer(nn.Module):
    """
    GPT-style model over interleaved trajectory tokens.

    The per-timestep token layout is chosen by ``config['model_type']``:

    * ``'dt'``   -> (R_t, s_t, a_t)       return-to-go, state, action
    * ``'bc'``   -> (s_t, a_t)            state, action
    * ``'mgdt'`` -> (s_t, R_t, a_t, r_t)  state, return-to-go, action, reward

    ``config['length_times']`` is the number of tokens per timestep. It is used
    to stack the attention mask and to un-interleave the transformer output,
    so it must match the layout above (3, 2 or 4).

    For 'dt' and 'bc' the return predictions are produced by ``mmdnet`` (and
    its frozen copy ``target_mmdnet``) from the raw states and actions, not
    from the transformer output.
    """

    # model_type values handled by forward()
    _MODEL_TYPES = ('dt', 'bc', 'mgdt')

    def __init__(self, config, action_tanh=True, **kwargs):
        """
        Args:
            config: dict of hyper-parameters. Read here: 'particle_num',
                'length_times', 'hidden_size' (must equal 'n_embd'), 'K',
                'max_ep_len', 'env_name', 'n_layer', 'model_type',
                'sample_return' (only when model_type is 'mgdt') and,
                optionally, 'action_tanh'. The dict is also passed to every
                Block.
            action_tanh: whether predicted actions are squashed with tanh;
                used only when ``config`` has no 'action_tanh' key.
            **kwargs: accepted for call-site compatibility; ignored.
        """
        super().__init__()

        self.config = config
        self.particle_num = config['particle_num']
        self.length_times = config['length_times']
        self.hidden_size = config['hidden_size']
        assert self.hidden_size == config['n_embd']
        self.max_length = config['K']
        self.max_ep_len = config['max_ep_len']

        # Within this class the environment is only used to read the
        # observation / action sizes.
        self.env = gym.make(config['env_name'])
        self.state_dim = self.env.observation_space.shape[0]
        self.act_dim = self.env.action_space.shape[0]

        # GPT trunk without built-in positional embeddings; timestep
        # embeddings are added to the tokens in forward() instead.
        self.transformer = nn.ModuleList([Block(config) for _ in range(config['n_layer'])])

        self.embed_timestep = nn.Embedding(self.max_ep_len, self.hidden_size)
        self.embed_return = torch.nn.Linear(1, self.hidden_size)
        self.embed_reward = torch.nn.Linear(1, self.hidden_size)
        self.embed_state = torch.nn.Linear(self.state_dim, self.hidden_size)
        self.embed_action = torch.nn.Linear(self.act_dim, self.hidden_size)

        self.embed_ln = nn.LayerNorm(self.hidden_size)

        # Prediction heads. Creation order is kept stable so parameter order
        # (optimizer state) and seeded initialisation do not change.
        self.predict_state = torch.nn.Linear(self.hidden_size, self.state_dim)
        use_tanh = config.get('action_tanh', action_tanh)
        self.predict_action = nn.Sequential(
            *([nn.Linear(self.hidden_size, self.act_dim)] + ([nn.Tanh()] if use_tanh else []))
        )
        if self.config['model_type'] in ['mgdt']:
            if self.config['sample_return'] == False:
                self.predict_return = torch.nn.Linear(self.hidden_size, 1)
            else:
                self.predict_return_mu = torch.nn.Linear(self.hidden_size, 1)
                self.predict_return_sigma = torch.nn.Linear(self.hidden_size, 1)
        else:
            self.predict_return = torch.nn.Linear(self.hidden_size, 1)
        self.predict_reward = torch.nn.Linear(self.hidden_size, 1)

        self.mmdnet = MMDnet(
            state_dim=self.state_dim,
            action_dim=self.act_dim,
            hidden_size=self.hidden_size,
            particle_num=self.particle_num
        )

        # Target copy: starts equal to mmdnet and is excluded from autograd.
        self.target_mmdnet = MMDnet(
            state_dim=self.state_dim,
            action_dim=self.act_dim,
            hidden_size=self.hidden_size,
            particle_num=self.particle_num
        )
        self.target_mmdnet.load_state_dict(self.mmdnet.state_dict())
        for param in self.target_mmdnet.parameters():
            param.requires_grad = False
        self.target_mmdnet.eval()

    def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_mask=None):
        """
        Run the transformer over a batch of trajectory segments.

        Args (shapes as implied by the embedding layers and reshapes below):
            states: (B, T, state_dim)
            actions: (B, T, act_dim)
            rewards: (B, T, 1)
            returns_to_go: (B, T, 1)
            timesteps: (B, T) integer indices into ``embed_timestep``
            attention_mask: (B, T); 1 = may be attended to, 0 = padding.
                Defaults to all ones on the device of ``states``.

        Returns:
            ``(state_preds, action_preds, return_preds, reward_preds)``;
            entries the current model_type does not produce are ``None``.
            For 'dt'/'bc' ``return_preds`` is ``[mmdnet_out, target_mmdnet_out]``;
            for 'mgdt' it is a tensor, or ``[mu, sigma]`` when
            ``config['sample_return']`` is not False.

        Raises:
            ValueError: if ``config['model_type']`` is not 'dt', 'bc' or 'mgdt'.
        """
        model_type = self.config['model_type']
        if model_type not in self._MODEL_TYPES:
            # Previously this surfaced as an UnboundLocalError on stacked_inputs.
            raise ValueError(f"Unknown model_type: {model_type}")

        batch_size, seq_length = states.shape[0], states.shape[1]

        if attention_mask is None:
            # Build the default mask on the inputs' device so the additive
            # padding mask in the attention layers does not mix devices.
            attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long, device=states.device)

        # Embed each modality with its own head.
        state_embeddings = self.embed_state(states)
        action_embeddings = self.embed_action(actions)
        returns_embeddings = self.embed_return(returns_to_go)
        rewards_embeddings = self.embed_reward(rewards)
        time_embeddings = self.embed_timestep(timesteps)
        if not self.config['is_stitch']:
            # Timestep embeddings play the role of positional embeddings;
            # they are skipped entirely when config['is_stitch'] is truthy.
            state_embeddings = state_embeddings + time_embeddings
            action_embeddings = action_embeddings + time_embeddings
            returns_embeddings = returns_embeddings + time_embeddings
            rewards_embeddings = rewards_embeddings + time_embeddings

        # Interleave the per-timestep tokens in the layout of this model_type.
        if model_type == 'dt':
            tokens = (returns_embeddings, state_embeddings, action_embeddings)
        elif model_type == 'bc':
            tokens = (state_embeddings, action_embeddings)
        else:  # 'mgdt'
            tokens = (state_embeddings, returns_embeddings, action_embeddings, rewards_embeddings)
        stacked_inputs = torch.stack(tokens, dim=1).permute(0, 2, 1, 3).reshape(
            batch_size, len(tokens) * seq_length, self.hidden_size)
        stacked_inputs = self.embed_ln(stacked_inputs)

        # Repeat each timestep's mask value for every token of that timestep.
        stacked_attention_mask = torch.stack(
            [attention_mask] * self.length_times, dim=1
        ).permute(0, 2, 1).reshape(batch_size, self.length_times * seq_length).to(stacked_inputs.dtype)

        x = stacked_inputs
        for block in self.transformer:
            x = block(x, stacked_attention_mask)

        # (B, length_times, T, C): x[:, i, t] is the output for token slot i
        # of timestep t, in the layout order listed in the class docstring.
        x = x.reshape(batch_size, seq_length, self.length_times, self.hidden_size).permute(0, 2, 1, 3)

        if model_type == 'dt':
            return_preds = self.mmdnet(states, actions)
            target_return_preds = self.target_mmdnet(states, actions)
            state_preds = self.predict_state(x[:, 2])    # from the action token
            action_preds = self.predict_action(x[:, 1])  # from the state token
            return state_preds, action_preds, [return_preds, target_return_preds], None

        if model_type == 'bc':
            return_preds = self.mmdnet(states, actions)
            target_return_preds = self.target_mmdnet(states, actions)
            action_preds = self.predict_action(x[:, 0])  # from the state token
            return None, action_preds, [return_preds, target_return_preds], None

        # 'mgdt'
        if self.config['sample_return'] == False:
            return_preds = self.predict_return(x[:, 0])  # from the state token
        else:
            return_preds_mu = self.predict_return_mu(x[:, 0])
            return_preds_sigma = self.predict_return_sigma(x[:, 0])
        reward_preds = self.predict_reward(x[:, 2])  # from the action token
        action_preds = self.predict_action(x[:, 1])  # from the return token
        if self.config['sample_return'] == False:
            return None, action_preds, return_preds, reward_preds
        return None, action_preds, [return_preds_mu, return_preds_sigma], reward_preds

    @staticmethod
    def _left_pad(x, length):
        """Left-pad ``x`` with zeros along dim 1 up to ``length`` steps (same dtype/device)."""
        pad_shape = (x.shape[0], length - x.shape[1]) + tuple(x.shape[2:])
        return torch.cat([x.new_zeros(pad_shape), x], dim=1)

    def get_action(self, states, actions, rewards, returns_to_go, timesteps, **kwargs):
        """
        Predict the action for the most recent step of a single trajectory.

        Inputs are reshaped into one sequence, truncated to the last ``K``
        (``self.max_length``) steps, left-padded with zeros to ``K`` and
        masked accordingly. ``kwargs`` are passed through to ``forward``.
        NOTE(review): forward() accepts no extra keyword arguments, so any
        kwargs currently raise TypeError.

        Returns:
            'bc' / 'dt': the predicted action for the last step.
            'mgdt': ``(action, return_pred)``; ``return_pred`` is a tensor, or
            ``[mu, sigma]`` when ``config['sample_return']`` is not False.
        """
        states = states.reshape(1, -1, self.state_dim)
        actions = actions.reshape(1, -1, self.act_dim)
        returns_to_go = returns_to_go.reshape(1, -1, 1)
        rewards = rewards.reshape(1, -1, 1)
        timesteps = timesteps.reshape(1, -1)

        if self.max_length is not None:
            # Keep only the most recent K steps.
            states = states[:, -self.max_length:]
            actions = actions[:, -self.max_length:]
            returns_to_go = returns_to_go[:, -self.max_length:]
            rewards = rewards[:, -self.max_length:]
            timesteps = timesteps[:, -self.max_length:]

            # Padding goes on the left, so the real steps end the sequence.
            n_valid = states.shape[1]
            attention_mask = torch.cat([
                torch.zeros(self.max_length - n_valid, dtype=torch.long, device=states.device),
                torch.ones(n_valid, dtype=torch.long, device=states.device),
            ]).reshape(1, -1)
            states = self._left_pad(states, self.max_length).to(dtype=torch.float32)
            actions = self._left_pad(actions, self.max_length).to(dtype=torch.float32)
            returns_to_go = self._left_pad(returns_to_go, self.max_length).to(dtype=torch.float32)
            rewards = self._left_pad(rewards, self.max_length).to(dtype=torch.float32)
            timesteps = self._left_pad(timesteps, self.max_length).to(dtype=torch.long)
        else:
            attention_mask = None

        _, action_preds, return_preds, reward_preds = self.forward(
            states, actions, rewards, returns_to_go, timesteps, attention_mask=attention_mask, **kwargs)

        if self.config['model_type'] in ['bc', 'dt']:
            return action_preds[0, -1]
        elif self.config['model_type'] in ['mgdt']:
            if self.config['sample_return'] == False:
                return action_preds[0, -1], return_preds[0, -1]
            else:
                return action_preds[0, -1], [return_preds[0][0, -1], return_preds[1][0, -1]]

    def get_return(self, states, actions):
        """
        Evaluate ``mmdnet`` on ``states[-2]`` and ``actions[-1]``.

        Both are indexed on their leading dimension.
        NOTE(review): presumably "previous state with latest action";
        confirm against the caller.
        """
        return self.mmdnet(states[-2], actions[-1])

    def get_mmd_parameters(self):
        """Parameters of ``mmdnet`` followed by those of ``target_mmdnet`` (frozen)."""
        return list(self.mmdnet.parameters()) + list(self.target_mmdnet.parameters())

    def get_decision_transformer_parameters(self):
        """All parameters of this module except those of the two MMD networks."""
        mmd_params = set(self.get_mmd_parameters())
        return [p for p in self.parameters() if p not in mmd_params]




class InnerTransformerBlock(nn.Module):
    """
    Unmasked pre-LayerNorm transformer block over batch-first tokens.

    Computes ``x + MHA(LN1(x))`` and then ``x + FFN(LN2(x))``, where the FFN
    is Linear -> activation -> Linear -> Dropout with a hidden width of
    ``dim_expand * embed_dim``. Input and output shape: (B, L, embed_dim).
    """

    def __init__(self,
                 embed_dim: int,
                 num_heads: int = 1,
                 attn_dropout: float = 0.0,
                 ffn_dropout: float = 0.0,
                 dim_expand: int = 1,
                 activation_fn=nn.ReLU):
        super().__init__()
        ffn_width = dim_expand * embed_dim

        self.ln1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=attn_dropout,
            batch_first=True,
        )
        self.ln2 = nn.LayerNorm(embed_dim)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ffn_width),
            activation_fn(),
            nn.Linear(ffn_width, embed_dim),
            nn.Dropout(ffn_dropout),
        )

    def forward(self, x):
        normed = self.ln1(x)
        # nn.MultiheadAttention returns (output, weights); only the output is used.
        x = x + self.attn(normed, normed, normed)[0]
        return x + self.ffn(self.ln2(x))


class TITStateEncoder(nn.Module):
    """
    Encode every state vector into a single embedding with an inner transformer.

    Per state: zero-pad on the right to a multiple of ``patch_dim``, cut into
    non-overlapping patches with a stride == kernel Conv1d (one embed_dim
    token per patch), prepend a learnable [CLS] token, run ``num_blocks``
    InnerTransformerBlocks, and return the [CLS] output.
    """

    def __init__(self,
                 state_dim: int,
                 embed_dim: int,
                 patch_dim: int = None,
                 num_blocks: int = 1,
                 num_heads: int = 1,
                 attn_dropout: float = 0.0,
                 ffn_dropout: float = 0.0,
                 dim_expand: int = 1,
                 activation_fn=nn.ReLU):
        super().__init__()
        self.state_dim = state_dim
        self.embed_dim = embed_dim
        self.patch_dim = 1 if patch_dim is None else patch_dim

        # One output token per patch of `patch_dim` consecutive state entries.
        self.obs_patch_embed = nn.Conv1d(
            in_channels=1,
            out_channels=embed_dim,
            kernel_size=self.patch_dim,
            stride=self.patch_dim,
            bias=False,
        )

        # Learnable [CLS] token whose final representation is the state embedding.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, mean=0.0, std=0.02)

        block_kwargs = dict(
            embed_dim=embed_dim,
            num_heads=num_heads,
            attn_dropout=attn_dropout,
            ffn_dropout=ffn_dropout,
            dim_expand=dim_expand,
            activation_fn=activation_fn,
        )
        self.blocks = nn.ModuleList(
            [InnerTransformerBlock(**block_kwargs) for _ in range(num_blocks)]
        )

    @staticmethod
    def _pad_to_multiple(x, multiple: int):
        """Zero-pad a (B, D) tensor on the right so D is a multiple; return (x, pad)."""
        _, width = x.shape
        remainder = width % multiple
        if not remainder:
            return x, 0
        extra = multiple - remainder
        return nn.functional.pad(x, (0, extra), value=0.0), extra

    def _patch_embed_one_step(self, step_states):
        """(B, D) states -> (B, L, embed_dim) patch tokens."""
        padded, _ = self._pad_to_multiple(step_states, self.patch_dim)
        as_signal = padded.unsqueeze(1)                   # (B, 1, D_pad)
        return self.obs_patch_embed(as_signal).transpose(1, 2)

    def forward(self, states):
        """
        states: (B, T, D)
        return: (B, T, embed_dim)
        """
        B, T, D = states.shape
        assert D == self.state_dim, f"state_dim mismatch: got {D}, expect {self.state_dim}"

        flat_states = states.reshape(B * T, D)
        patch_tokens = self._patch_embed_one_step(flat_states)          # (B*T, L, C)
        cls_tokens = self.cls_token.expand(patch_tokens.size(0), -1, -1)  # (B*T, 1, C)
        seq = torch.cat([cls_tokens, patch_tokens], dim=1)              # (B*T, 1+L, C)

        for layer in self.blocks:
            seq = layer(seq)

        return seq[:, 0, :].view(B, T, self.embed_dim)


class TIT_DecisionTransformer(nn.Module):
    """
    Decision-Transformer variant whose state tokens come from a TITStateEncoder.

    The per-timestep token layout is chosen by ``config['model_type']``:

    * ``'dt'``   -> (R_t, s_t, a_t)       return-to-go, state, action
    * ``'bc'``   -> (s_t, a_t)            state, action
    * ``'mgdt'`` -> (s_t, R_t, a_t, r_t)  state, return-to-go, action, reward

    ``config['length_times']`` must equal the number of tokens per timestep
    (3, 2 or 4). ``config['use_tit']`` (default True) selects the TIT state
    encoder; when False, states are embedded with a single linear layer.
    """

    def __init__(self, config, **kwargs):
        """
        Args:
            config: dict of hyper-parameters. Read here: 'particle_num',
                'length_times', 'hidden_size' (must equal 'n_embd'), 'K',
                'max_ep_len', 'env_name', 'n_layer', 'action_tanh',
                'model_type', 'sample_return' (only for 'mgdt'), and the
                optional 'use_tit' and 'inner_*' encoder settings.
            **kwargs: accepted for call-site compatibility; ignored.
        """
        super().__init__()

        self.config = config
        self.particle_num = config['particle_num']
        self.length_times = config['length_times']
        self.hidden_size = config['hidden_size']
        assert self.hidden_size == config['n_embd']
        self.max_length = config['K']
        self.max_ep_len = config['max_ep_len']

        # Within this class the environment is only used to read the
        # observation / action sizes.
        self.env = gym.make(config['env_name'])
        self.state_dim = self.env.observation_space.shape[0]
        self.act_dim = self.env.action_space.shape[0]

        self.transformer = nn.ModuleList([Block(config) for _ in range(config['n_layer'])])

        # Embeddings (creation order kept stable for parameter order / seeding).
        self.embed_timestep = nn.Embedding(self.max_ep_len, self.hidden_size)
        self.embed_return = nn.Linear(1, self.hidden_size)
        self.embed_reward = nn.Linear(1, self.hidden_size)
        self.embed_action = nn.Linear(self.act_dim, self.hidden_size)

        # Previously hard-coded to True, which made the linear branch dead code.
        self.use_tit = bool(config.get('use_tit', True))
        if self.use_tit:
            inner_patch_dim   = int(config.get('inner_patch_dim', 11))
            inner_num_blocks  = int(config.get('inner_num_blocks', 1))
            inner_num_heads   = int(config.get('inner_num_heads', 1))
            inner_attn_drop   = float(config.get('inner_attention_dropout', 0.0))
            inner_ffn_drop    = float(config.get('inner_ffn_dropout', 0.0))
            inner_dim_expand  = int(config.get('inner_dim_expand', 1))
            inner_act_name    = str(config.get('inner_activation', 'relu')).lower()
            act_map = {'relu': nn.ReLU, 'gelu': nn.GELU, 'silu': nn.SiLU, 'tanh': nn.Tanh}
            # Unknown activation names fall back to ReLU.
            inner_activation  = act_map.get(inner_act_name, nn.ReLU)

            self.state_encoder = TITStateEncoder(
                state_dim=self.state_dim,
                embed_dim=self.hidden_size,
                patch_dim=inner_patch_dim,
                num_blocks=inner_num_blocks,
                num_heads=inner_num_heads,
                attn_dropout=inner_attn_drop,
                ffn_dropout=inner_ffn_drop,
                dim_expand=inner_dim_expand,
                activation_fn=inner_activation
            )
        else:
            self.embed_state = nn.Linear(self.state_dim, self.hidden_size)

        self.embed_ln = nn.LayerNorm(self.hidden_size)

        self.predict_state = nn.Linear(self.hidden_size, self.state_dim)
        self.predict_action = nn.Sequential(
            *([nn.Linear(self.hidden_size, self.act_dim)] + ([nn.Tanh()] if config['action_tanh'] else []))
        )
        if self.config['model_type'] in ['mgdt']:
            if not self.config['sample_return']:
                self.predict_return = nn.Linear(self.hidden_size, 1)
            else:
                self.predict_return_mu = nn.Linear(self.hidden_size, 1)
                self.predict_return_sigma = nn.Linear(self.hidden_size, 1)
        else:
            self.predict_return = nn.Linear(self.hidden_size, 1)
        self.predict_reward = nn.Linear(self.hidden_size, 1)

        self.mmdnet = MMDnet(
            state_dim=self.state_dim,
            action_dim=self.act_dim,
            hidden_size=self.hidden_size,
            particle_num=self.particle_num
        )
        # Target copy: starts equal to mmdnet and is excluded from autograd.
        self.target_mmdnet = MMDnet(
            state_dim=self.state_dim,
            action_dim=self.act_dim,
            hidden_size=self.hidden_size,
            particle_num=self.particle_num
        )
        self.target_mmdnet.load_state_dict(self.mmdnet.state_dict())
        for p in self.target_mmdnet.parameters():
            p.requires_grad = False
        self.target_mmdnet.eval()

    def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_mask=None):
        """
        Run the transformer over a batch of trajectory segments.

        Args (shapes as implied by the encoders and reshapes below):
            states: (B, T, state_dim)
            actions: (B, T, act_dim)
            rewards: (B, T, 1)
            returns_to_go: (B, T, 1)
            timesteps: (B, T) integer indices into ``embed_timestep``
            attention_mask: (B, T); 1 = may be attended to, 0 = padding.
                Defaults to all ones on the device of ``states``.

        Returns:
            ``(state_preds, action_preds, return_preds, reward_preds)``;
            entries the current model_type does not produce are ``None``.

        Raises:
            ValueError: if ``config['model_type']`` is not 'dt', 'bc' or 'mgdt'.
        """
        model_type = self.config['model_type']
        if model_type not in ('dt', 'bc', 'mgdt'):
            raise ValueError(f"Unknown model_type: {self.config['model_type']}")

        B, T = states.shape[0], states.shape[1]

        if attention_mask is None:
            attention_mask = torch.ones((B, T), dtype=torch.long, device=states.device)

        if self.use_tit:
            state_embeddings = self.state_encoder(states)           # (B, T, C)
        else:
            state_embeddings = self.embed_state(states)             # (B, T, C)

        action_embeddings = self.embed_action(actions)
        returns_embeddings = self.embed_return(returns_to_go)
        rewards_embeddings = self.embed_reward(rewards)
        time_embeddings = self.embed_timestep(timesteps)

        if not self.config['is_stitch']:
            # Timestep embeddings act as positional embeddings.
            state_embeddings = state_embeddings + time_embeddings
            action_embeddings = action_embeddings + time_embeddings
            returns_embeddings = returns_embeddings + time_embeddings
            rewards_embeddings = rewards_embeddings + time_embeddings

        if model_type == 'dt':
            tokens = (returns_embeddings, state_embeddings, action_embeddings)
        elif model_type == 'bc':
            tokens = (state_embeddings, action_embeddings)
        else:  # 'mgdt'
            tokens = (state_embeddings, returns_embeddings, action_embeddings, rewards_embeddings)
        stacked_inputs = torch.stack(tokens, dim=1).permute(0, 2, 1, 3).reshape(
            B, len(tokens) * T, self.hidden_size)
        stacked_inputs = self.embed_ln(stacked_inputs)

        # Repeat each timestep's mask value for every token of that timestep.
        stacked_attention_mask = torch.stack(
            [attention_mask] * self.length_times, dim=1
        ).permute(0, 2, 1).reshape(B, self.length_times * T).to(stacked_inputs.dtype)

        x = stacked_inputs
        for block in self.transformer:
            x = block(x, stacked_attention_mask)

        # (B, length_times, T, C): x[:, i, t] is token slot i of timestep t.
        x = x.reshape(B, T, self.length_times, self.hidden_size).permute(0, 2, 1, 3)

        if model_type == 'dt':
            return_preds = self.mmdnet(states, actions)
            target_return_preds = self.target_mmdnet(states, actions)
            state_preds = self.predict_state(x[:, 2])    # from the action token
            action_preds = self.predict_action(x[:, 1])  # from the state token
            return state_preds, action_preds, [return_preds, target_return_preds], None

        if model_type == 'bc':
            return_preds = self.mmdnet(states, actions)
            target_return_preds = self.target_mmdnet(states, actions)
            action_preds = self.predict_action(x[:, 0])  # from the state token
            return None, action_preds, [return_preds, target_return_preds], None

        # 'mgdt'
        if not self.config['sample_return']:
            return_preds = self.predict_return(x[:, 0])  # from the state token
        else:
            return_preds_mu = self.predict_return_mu(x[:, 0])
            return_preds_sigma = self.predict_return_sigma(x[:, 0])
        reward_preds = self.predict_reward(x[:, 2])     # from the action token
        action_preds = self.predict_action(x[:, 1])     # from the return token

        if not self.config['sample_return']:
            return None, action_preds, return_preds, reward_preds
        return None, action_preds, [return_preds_mu, return_preds_sigma], reward_preds

    @staticmethod
    def _left_pad(x, length):
        """Left-pad ``x`` with zeros along dim 1 up to ``length`` steps (same dtype/device)."""
        pad_shape = (x.shape[0], length - x.shape[1]) + tuple(x.shape[2:])
        return torch.cat([x.new_zeros(pad_shape), x], dim=1)

    def get_action(self, states, actions, rewards, returns_to_go, timesteps, **kwargs):
        """
        Predict the action for the most recent step of a single trajectory.

        Inputs are reshaped into one sequence, truncated to the last ``K``
        steps, left-padded with zeros to ``K`` and masked accordingly.
        NOTE(review): ``kwargs`` are passed to forward(), which accepts none,
        so any kwargs currently raise TypeError.

        Returns:
            'bc' / 'dt': the predicted action for the last step.
            'mgdt': ``(action, return_pred)``; ``return_pred`` is a tensor, or
            ``[mu, sigma]`` when ``config['sample_return']`` is truthy.
        """
        states = states.reshape(1, -1, self.state_dim)
        actions = actions.reshape(1, -1, self.act_dim)
        returns_to_go = returns_to_go.reshape(1, -1, 1)
        rewards = rewards.reshape(1, -1, 1)
        timesteps = timesteps.reshape(1, -1)

        if self.max_length is not None:
            states = states[:, -self.max_length:]
            actions = actions[:, -self.max_length:]
            returns_to_go = returns_to_go[:, -self.max_length:]
            rewards = rewards[:, -self.max_length:]
            timesteps = timesteps[:, -self.max_length:]

            n_valid = states.shape[1]
            attention_mask = torch.cat([
                torch.zeros(self.max_length - n_valid, dtype=torch.long, device=states.device),
                torch.ones(n_valid, dtype=torch.long, device=states.device),
            ]).reshape(1, -1)

            states = self._left_pad(states, self.max_length).to(dtype=torch.float32)
            actions = self._left_pad(actions, self.max_length).to(dtype=torch.float32)
            returns_to_go = self._left_pad(returns_to_go, self.max_length).to(dtype=torch.float32)
            rewards = self._left_pad(rewards, self.max_length).to(dtype=torch.float32)
            timesteps = self._left_pad(timesteps, self.max_length).to(dtype=torch.long)
        else:
            attention_mask = None

        _, action_preds, return_preds, reward_preds = self.forward(
            states, actions, rewards, returns_to_go, timesteps, attention_mask=attention_mask, **kwargs
        )

        if self.config['model_type'] in ['bc', 'dt']:
            return action_preds[0, -1]
        elif self.config['model_type'] in ['mgdt']:
            if not self.config['sample_return']:
                return action_preds[0, -1], return_preds[0, -1]
            else:
                return action_preds[0, -1], [return_preds[0][0, -1], return_preds[1][0, -1]]

    def get_return(self, states, actions):
        """
        Evaluate ``mmdnet`` on ``states[-2]`` and ``actions[-1]`` (leading-dim indexing).

        NOTE(review): presumably "previous state with latest action"; confirm
        against the caller.
        """
        return self.mmdnet(states[-2], actions[-1])

    def get_mmd_parameters(self):
        """Parameters of ``mmdnet`` followed by those of ``target_mmdnet`` (frozen)."""
        return list(self.mmdnet.parameters()) + list(self.target_mmdnet.parameters())

    def get_decision_transformer_parameters(self):
        """All parameters of this module except those of the two MMD networks."""
        mmd_params = set(self.get_mmd_parameters())
        return [p for p in self.parameters() if p not in mmd_params]