# Adapted from https://raw.githubusercontent.com/nikhilbarhate99/min-decision-transformer/refs/heads/master/decision_transformer/model.py
# Original work Copyright (c) 2022 Nikhil Barhate
# Modifications Copyright (c) 2025 King.com Ltd

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedCausalAttention(nn.Module):
    """Multi-head self-attention restricted by a lower-triangular (causal) mask.

    Args:
        h_dim: Model width. ``forward`` splits it across ``n_heads`` heads of
            ``h_dim // n_heads`` channels each.
        max_T: Largest sequence length the causal mask is built for.
        n_heads: Number of attention heads.
        drop_p: Dropout probability used after the attention product and after
            the output projection.
    """

    def __init__(self, h_dim, max_T, n_heads, drop_p):
        super().__init__()

        self.n_heads = n_heads
        self.max_T = max_T

        # Independent query / key / value projections, then an output projection.
        self.q_net = nn.Linear(h_dim, h_dim)
        self.k_net = nn.Linear(h_dim, h_dim)
        self.v_net = nn.Linear(h_dim, h_dim)
        self.proj_net = nn.Linear(h_dim, h_dim)

        self.att_drop = nn.Dropout(drop_p)
        self.proj_drop = nn.Dropout(drop_p)

        # 0/1 lower-triangular matrix reshaped to (1, 1, max_T, max_T):
        # entry [i, j] is 1 iff j <= i. In ``forward`` the positions holding 0
        # are filled with -inf before the softmax, so query i never attends to
        # a later key j > i.
        causal = torch.tril(torch.ones(max_T, max_T))
        # A buffer (not a Parameter): it follows the module across devices but
        # is never updated by the optimizer.
        self.register_buffer('mask', causal[None, None, :, :])

    def forward(self, x):
        """Apply causal multi-head self-attention.

        Args:
            x: Tensor of shape (B, T, C) where C is the model width (h_dim).

        Returns:
            Tensor of shape (B, T, C).
        """
        batch, seq_len, width = x.shape
        heads = self.n_heads
        head_dim = width // heads

        def split_heads(t):
            # (B, T, C) -> (B, heads, T, head_dim)
            return t.view(batch, seq_len, heads, head_dim).transpose(1, 2)

        query = split_heads(self.q_net(x))
        key = split_heads(self.k_net(x))
        value = split_heads(self.v_net(x))

        # Scaled dot-product scores of shape (B, heads, T, T);
        # rows index query positions, columns index key positions.
        scores = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
        visible = self.mask[..., :seq_len, :seq_len]
        scores = scores.masked_fill(visible == 0, float('-inf'))
        # Masked (-inf) entries become exactly 0 after the softmax.
        probs = F.softmax(scores, dim=-1)

        # Dropout is applied to the attended values (B, heads, T, head_dim),
        # not to the attention probabilities themselves.
        attended = self.att_drop(torch.matmul(probs, value))

        # Merge heads: (B, heads, T, head_dim) -> (B, T, heads * head_dim).
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, heads * head_dim)
        return self.proj_drop(self.proj_net(merged))


class Block(nn.Module):
    """Post-LayerNorm transformer block.

    Causal self-attention and a position-wise MLP are each wrapped in a
    residual connection that is followed by LayerNorm.

    Args:
        h_dim: Model width.
        max_T: Largest sequence length supported by the attention mask.
        n_heads: Number of attention heads.
        drop_p: Dropout probability for attention and the MLP output.
    """

    def __init__(self, h_dim, max_T, n_heads, drop_p):
        super().__init__()
        self.attention = MaskedCausalAttention(h_dim, max_T, n_heads, drop_p)
        expanded = 4 * h_dim  # MLP hidden width
        self.mlp = nn.Sequential(
            nn.Linear(h_dim, expanded),
            nn.GELU(),
            nn.Linear(expanded, h_dim),
            nn.Dropout(drop_p),
        )
        self.ln1 = nn.LayerNorm(h_dim)
        self.ln2 = nn.LayerNorm(h_dim)

    def forward(self, x):
        """Return ``ln2(y + mlp(y))`` where ``y = ln1(x + attention(x))``."""
        attended = self.ln1(x + self.attention(x))
        return self.ln2(attended + self.mlp(attended))


class MLP(nn.Module):
    """Prediction head: a stack of hidden blocks followed by a linear output layer.

    Each hidden block is Linear(h_dim, h_dim) -> ReLU -> Dropout.

    Args:
        num_layers: Number of hidden blocks; 0 yields a single linear layer.
        h_dim: Input and hidden width.
        out_dim: Output width.
        drop_p: Dropout probability inside each hidden block.
        squash_output: If True, apply ``tanh`` so outputs lie in [-1, 1].
    """

    def __init__(self, num_layers, h_dim, out_dim, drop_p, squash_output=False):
        super().__init__()

        self.layers = nn.ModuleList(
            nn.Sequential(
                nn.Linear(h_dim, h_dim),
                nn.ReLU(),
                nn.Dropout(drop_p),
            )
            for _ in range(num_layers)
        )

        self.out = nn.Linear(h_dim, out_dim)

        self.squash_output = squash_output

    def forward(self, x):
        """Map ``x`` with last dimension ``h_dim`` to last dimension ``out_dim``."""
        for layer in self.layers:
            x = layer(x)
        x = self.out(x)

        if self.squash_output:
            x = torch.tanh(x)

        return x


class DecisionTransformer(nn.Module):
    """Decision Transformer over interleaved (return-to-go, state, action) tokens.

    With ``which_model="dt"`` the transformer input is the token sequence
    (r_0, s_0, a_0, r_1, s_1, a_1, ...). With ``which_model="traj_pdt"`` a
    trajectory prompt of ``traj_prompt_j * traj_prompt_h`` steps, embedded with
    its own projection layers, is prepended (per modality) before interleaving.

    Args:
        state_dim: Size of each state vector.
        act_dim: Size of each action vector.
        n_blocks: Number of transformer blocks.
        h_dim: Token embedding width.
        context_len: Number of timesteps in the main (non-prompt) sequence.
        n_heads: Attention heads per block.
        transformer_drop_p: Dropout probability inside the transformer blocks.
        max_timestep: Number of rows in the timestep embedding tables; timestep
            indices must be smaller than this.
        which_model: ``"dt"`` or ``"traj_pdt"``.
        traj_prompt_j: Number of episode segments per trajectory prompt
            (``traj_pdt`` only; must be a positive integer).
        traj_prompt_h: Number of steps per episode segment in the prompt
            (``traj_pdt`` only; must be a positive integer).
        mlp_drop_p: Dropout probability inside the prediction-head MLPs.
        mlp_num_layers: Number of hidden layers in each prediction-head MLP.
    """

    def __init__(self, state_dim, act_dim, n_blocks, h_dim, context_len,
                 n_heads, transformer_drop_p, max_timestep=4096, which_model="dt",
                 traj_prompt_j=None,  # the number of episode segments per traj prompt
                 traj_prompt_h=None,  # the number of steps per episode segment in the traj prompt
                 mlp_drop_p=0.0, mlp_num_layers=2
                 ):
        super().__init__()

        self.state_dim = state_dim
        self.act_dim = act_dim
        self.h_dim = h_dim
        self.which_model = which_model
        assert self.which_model in ["dt", "traj_pdt"]
        self.traj_prompt_j = traj_prompt_j
        self.traj_prompt_h = traj_prompt_h

        # Validate the prompt configuration before it is used to size the
        # attention mask below; otherwise a missing value surfaces as an
        # unhelpful ``None * int`` TypeError instead of this message.
        if self.which_model == "traj_pdt":
            assert traj_prompt_j is not None and traj_prompt_h is not None, "traj_prompt_j and traj_prompt_h must be positive integers"
            assert traj_prompt_j > 0 and traj_prompt_h > 0, "traj_prompt_j and traj_prompt_h must be positive integers"

        # Total number of tokens fed to the transformer: 3 tokens (r, s, a) per step.
        if self.which_model == "dt":
            input_seq_len = 3 * context_len
        elif self.which_model == "traj_pdt":
            input_seq_len = 3 * (context_len + traj_prompt_j * traj_prompt_h)
        else:
            raise NotImplementedError

        ### transformer blocks
        blocks = [Block(h_dim, input_seq_len, n_heads, transformer_drop_p) for _ in range(n_blocks)]
        self.transformer = nn.Sequential(*blocks)

        ### projection heads (project to embedding)
        self.embed_ln = nn.LayerNorm(h_dim)
        self.embed_timestep = nn.Embedding(max_timestep, h_dim)
        self.embed_rtg = torch.nn.Linear(1, h_dim)
        self.embed_state = torch.nn.Linear(state_dim, h_dim)

        self.embed_action = torch.nn.Linear(act_dim, h_dim)
        use_action_tanh = True  # True for continuous actions

        ### prediction heads
        self.predict_rtg = MLP(num_layers=mlp_num_layers, h_dim=h_dim, out_dim=1, drop_p=mlp_drop_p)
        self.predict_state = MLP(num_layers=mlp_num_layers, h_dim=h_dim, out_dim=state_dim, drop_p=mlp_drop_p)
        self.predict_action = MLP(num_layers=mlp_num_layers, h_dim=h_dim, out_dim=act_dim, drop_p=mlp_drop_p, squash_output=use_action_tanh)

        # prompt embeddings (separate from the main-sequence embeddings)
        if self.which_model == "dt":
            pass

        elif self.which_model == "traj_pdt":
            self.embed_traj_prompt_timesteps = torch.nn.Embedding(max_timestep, h_dim)
            self.embed_traj_prompt_states = torch.nn.Linear(state_dim, h_dim)
            self.embed_traj_prompt_rtg = torch.nn.Linear(1, h_dim)
            self.embed_traj_prompt_actions = torch.nn.Linear(act_dim, h_dim)

        else:
            raise NotImplementedError

    def forward(
            self, timesteps, states, actions, returns_to_go,
            traj_prompt_timesteps=None, traj_prompt_states=None, traj_prompt_actions=None, traj_prompt_rtgs=None,
            return_token_features=False
    ):
        """Run the model and return per-step predictions.

        Args:
            timesteps: Integer indices looked up in ``embed_timestep``. Their
                embeddings are added to every modality, so the result must
                broadcast against (B, T, h_dim) (presumably shape (B, T)).
            states: (B, T, state_dim) states.
            actions: (B, T, act_dim) actions aligned with ``states``.
            returns_to_go: (B, T, 1) returns-to-go aligned with ``states``.
            traj_prompt_timesteps, traj_prompt_states, traj_prompt_actions,
            traj_prompt_rtgs: Trajectory-prompt counterparts of the four
                inputs above, with ``traj_prompt_j * traj_prompt_h`` steps
                along dim 1. Required for ``"traj_pdt"``; ignored for ``"dt"``.
            return_token_features: If True, also return the transformer
                outputs at the rtg, state and action token positions.

        Returns:
            ``(state_preds_seq, action_preds_seq, return_preds_seq,
            action_logits, action_preds_prompt, state_preds_prompt,
            return_preds_prompt)``, followed by ``(rtg_tokens_repr,
            state_tokens_repr, action_tokens_repr)`` when
            ``return_token_features`` is True. ``action_logits`` is always
            None and the ``*_prompt`` entries are None for ``"dt"``. Token
            features have shape (B, P + T, h_dim), where P is the prompt
            length (0 for ``"dt"``).

        Raises:
            ValueError: if ``which_model == "traj_pdt"`` and any prompt input is None.
        """
        B, T, _ = states.shape

        # Timestep embeddings play the role of positional embeddings and are
        # shared by the three tokens of each step.
        time_embeddings = self.embed_timestep(timesteps)
        state_embeddings = self.embed_state(states) + time_embeddings
        action_embeddings = self.embed_action(actions) + time_embeddings
        returns_embeddings = self.embed_rtg(returns_to_go) + time_embeddings

        if self.which_model == "dt":
            prompt_len = 0
            rtg_embs, state_embs, action_embs = returns_embeddings, state_embeddings, action_embeddings

        elif self.which_model == "traj_pdt":
            prompt_inputs = (traj_prompt_timesteps, traj_prompt_states, traj_prompt_actions, traj_prompt_rtgs)
            if any(p is None for p in prompt_inputs):
                raise ValueError(
                    "which_model='traj_pdt' requires traj_prompt_timesteps, traj_prompt_states, "
                    "traj_prompt_actions and traj_prompt_rtgs"
                )
            prompt_len = self.traj_prompt_j * self.traj_prompt_h

            # Embed the trajectory prompt with its own projection layers.
            prompt_time_embeddings = self.embed_traj_prompt_timesteps(traj_prompt_timesteps)
            traj_prompt_state_embeddings = self.embed_traj_prompt_states(traj_prompt_states) + prompt_time_embeddings
            traj_prompt_action_embeddings = self.embed_traj_prompt_actions(traj_prompt_actions) + prompt_time_embeddings
            traj_prompt_returns_embeddings = self.embed_traj_prompt_rtg(traj_prompt_rtgs) + prompt_time_embeddings

            # Prepend the prompt along the time dimension, per modality.
            rtg_embs = torch.cat([traj_prompt_returns_embeddings, returns_embeddings], dim=1)
            state_embs = torch.cat([traj_prompt_state_embeddings, state_embeddings], dim=1)
            action_embs = torch.cat([traj_prompt_action_embeddings, action_embeddings], dim=1)

        else:
            raise NotImplementedError

        n_steps = T + prompt_len

        # Interleave modalities per step as (r_0, s_0, a_0, r_1, s_1, a_1, ...),
        # producing (B, 3 * n_steps, h_dim).
        h = torch.stack(
            (rtg_embs, state_embs, action_embs), dim=1
        ).permute(0, 2, 1, 3).reshape(B, 3 * n_steps, self.h_dim)

        h = self.embed_ln(h)  # layer norm

        # transformer
        h = self.transformer(h)

        # Undo the interleaving: (B, 3 * n_steps, h_dim) -> (B, 3, n_steps, h_dim).
        # Because attention is causal, h[:, 0, t] is conditioned on all tokens
        # up to r_t, h[:, 1, t] up to s_t, and h[:, 2, t] up to a_t.
        h = h.reshape(B, n_steps, 3, self.h_dim).permute(0, 2, 1, 3)
        rtg_tokens_repr = h[:, 0]
        state_tokens_repr = h[:, 1]
        action_tokens_repr = h[:, 2]

        # Next rtg and next state are predicted from the action token (sees
        # r_t, s_t, a_t); the action is predicted from the state token (sees
        # r_t, s_t but not a_t).
        return_preds = self.predict_rtg(action_tokens_repr)
        state_preds = self.predict_state(action_tokens_repr)
        action_preds = self.predict_action(state_tokens_repr)

        # Split prompt positions from the main sequence. Slicing by prompt_len
        # (rather than -T) stays correct when T == 0.
        return_preds_seq = return_preds[:, prompt_len:, :]
        state_preds_seq = state_preds[:, prompt_len:, :]
        action_preds_seq = action_preds[:, prompt_len:, :]

        if self.which_model == "traj_pdt":
            return_preds_prompt = return_preds[:, :prompt_len, :]
            state_preds_prompt = state_preds[:, :prompt_len, :]
            action_preds_prompt = action_preds[:, :prompt_len, :]
        else:
            return_preds_prompt = None
            state_preds_prompt = None
            action_preds_prompt = None

        action_logits = None

        if return_token_features:
            return state_preds_seq, action_preds_seq, return_preds_seq, action_logits, action_preds_prompt, state_preds_prompt, return_preds_prompt, rtg_tokens_repr, state_tokens_repr, action_tokens_repr
        else:
            return state_preds_seq, action_preds_seq, return_preds_seq, action_logits, action_preds_prompt, state_preds_prompt, return_preds_prompt
