from torch.autograd import Function
from mage.models.autoencoders import SymbolWiseTransformer
from mage.models.transformers import *
from mage.models.ein import EinLinear
import numpy as np
import warnings
import time

class VectorQuantization(Function):
    """Map each input vector to the index of its nearest codebook row.

    ``inputs`` has trailing dimension equal to ``codebook.size(1)``; the result
    has shape ``inputs.shape[:-1]`` and holds integer indices into ``codebook``.
    The operation is not differentiable; use ``VectorQuantizationStraightThrough``
    for a straight-through gradient estimator.
    """

    @staticmethod
    def forward(ctx, inputs, codebook):
        with torch.no_grad():
            code_dim = codebook.size(1)
            leading_shape = inputs.size()[:-1]
            flat = inputs.view(-1, code_dim)

            # Squared Euclidean distance expanded as ||c||^2 + ||x||^2 - 2 x.c,
            # evaluated with a single fused addmm.
            squared_norms = (torch.sum(codebook ** 2, dim=1)
                             + torch.sum(flat ** 2, dim=1, keepdim=True))
            dists = torch.addmm(squared_norms, flat, codebook.t(),
                                alpha=-2.0, beta=1.0)

            nearest_flat = torch.min(dists, dim=1)[1]
            nearest = nearest_flat.view(*leading_shape)
            ctx.mark_non_differentiable(nearest)
            return nearest

    @staticmethod
    def backward(ctx, grad_output):
        message = ('Trying to call `.grad()` on graph containing '
                   '`VectorQuantization`. The function `VectorQuantization` '
                   'is not differentiable. Use `VectorQuantizationStraightThrough` '
                   'if you want a straight-through estimator of the gradient.')
        raise RuntimeError(message)

class VectorQuantizationStraightThrough(Function):
    """Nearest-codeword quantization with a straight-through gradient.

    ``forward`` returns ``(codes, flat_indices)``: ``codes`` has the shape of
    ``inputs`` and holds the selected codebook rows, and ``flat_indices`` is the
    1-D tensor of selected row indices (marked non-differentiable).

    ``backward`` passes the incoming gradient unchanged to ``inputs`` and
    scatter-adds it into the codebook rows that were selected.
    """

    @staticmethod
    def forward(ctx, inputs, codebook):
        flat_idx = vq(inputs, codebook).view(-1)
        ctx.save_for_backward(flat_idx, codebook)
        ctx.mark_non_differentiable(flat_idx)

        picked = codebook.index_select(0, flat_idx)
        return (picked.view_as(inputs), flat_idx)

    @staticmethod
    def backward(ctx, grad_output, grad_indices):
        wants_inputs, wants_codebook = ctx.needs_input_grad[0], ctx.needs_input_grad[1]

        grad_inputs = grad_output.clone() if wants_inputs else None

        grad_codebook = None
        if wants_codebook:
            flat_idx, codebook = ctx.saved_tensors
            grad_rows = grad_output.contiguous().view(-1, codebook.size(1))
            grad_codebook = torch.zeros_like(codebook)
            grad_codebook.index_add_(0, flat_idx, grad_rows)

        return (grad_inputs, grad_codebook)

# Functional entry points for the autograd Functions above:
#   vq(inputs, codebook)    -> nearest-code indices (non-differentiable)
#   vq_st(inputs, codebook) -> (codes, flat_indices) with straight-through gradient
vq, vq_st = VectorQuantization.apply, VectorQuantizationStraightThrough.apply

class VQEmbeddingMovingAverage(nn.Module):
    """VQ codebook of ``K`` codes of size ``D`` updated by exponential moving averages.

    The codebook (``embedding``) and the EMA statistics (``ema_count``, ``ema_w``)
    are registered buffers, not parameters: they change only inside
    ``straight_through`` while the module is in training mode.
    """

    def __init__(self, K, D, decay=0.99):
        super().__init__()
        embedding = torch.zeros(K, D)
        embedding.uniform_(-1./K, 1./K)
        self.decay = decay

        self.register_buffer("embedding", embedding)
        self.register_buffer("ema_count", torch.ones(K))
        self.register_buffer("ema_w", self.embedding.clone())

    def forward(self, z_e_x):
        """Return the nearest-code indices for ``z_e_x`` (shape ``z_e_x.shape[:-1]``)."""
        z_e_x_ = z_e_x.contiguous()
        # ``self.embedding`` is a plain buffer tensor here; it has no ``.weight``.
        latents = vq(z_e_x_, self.embedding)
        return latents

    def straight_through(self, z_e_x):
        """Quantize ``z_e_x`` and, in training mode, advance the EMA codebook.

        Returns:
            z_q_x: straight-through quantized codes (gradient flows to ``z_e_x``).
            z_q_x_bar: codes gathered from the (possibly just updated) codebook,
                carrying no gradient since the codebook is a buffer.
            indices: flat 1-D tensor of selected code indices.
        """
        K, D = self.embedding.size()

        z_e_x_ = z_e_x.contiguous()
        z_q_x_, indices = vq_st(z_e_x_, self.embedding)
        z_q_x = z_q_x_.contiguous()

        if self.training:
            # EMA statistics are not learned by backprop; avoid building a graph.
            with torch.no_grad():
                encodings = F.one_hot(indices, K).float()
                self.ema_count = self.decay * self.ema_count + (1 - self.decay) * torch.sum(encodings, dim=0)

                dw = encodings.transpose(1, 0) @ z_e_x_.reshape([-1, D])
                self.ema_w = self.decay * self.ema_w + (1 - self.decay) * dw

                self.embedding = self.ema_w / self.ema_count.unsqueeze(-1)

        z_q_x_bar_flatten = torch.index_select(self.embedding, dim=0, index=indices)
        z_q_x_bar_ = z_q_x_bar_flatten.view_as(z_e_x_)
        z_q_x_bar = z_q_x_bar_.contiguous()

        return z_q_x, z_q_x_bar, indices

class VQEmbedding(nn.Module):
    """Gradient-trained VQ codebook of ``K`` codes of size ``D`` held in an ``nn.Embedding``."""

    def __init__(self, K, D):
        super().__init__()
        self.embedding = nn.Embedding(K, D)
        bound = 1. / K
        self.embedding.weight.data.uniform_(-bound, bound)

    def forward(self, z_e_x):
        """Return the nearest-code indices for ``z_e_x`` (shape ``z_e_x.shape[:-1]``)."""
        return vq(z_e_x.contiguous(), self.embedding.weight)

    def straight_through(self, z_e_x):
        """Quantize ``z_e_x``.

        Returns:
            z_q_x: straight-through codes; the codebook is detached here, so only
                ``z_e_x`` receives gradient through this output.
            z_q_x_bar: the same codes gathered from the live weight, so the
                codebook receives gradient through this output.
            indices: flat 1-D tensor of selected code indices.
        """
        contiguous_in = z_e_x.contiguous()
        weight = self.embedding.weight

        codes, indices = vq_st(contiguous_in, weight.detach())
        z_q_x = codes.contiguous()

        gathered = torch.index_select(weight, dim=0, index=indices)
        z_q_x_bar = gathered.view_as(contiguous_in).contiguous()

        return z_q_x, z_q_x_bar, indices

class VQStepWiseTransformer(nn.Module):
    """Multi-scale residual VQ autoencoder over fixed-length trajectories.

    ``encode`` maps a (B, T, transition_dim) trajectory to per-step latents of
    width ``config.trajectory_embd``. Quantization proceeds scale by scale over
    ``v_patch_nums``. At each scale the current residual is resized along the
    time axis to ``pn`` positions with ``scale_down``, quantized with the shared
    codebook, resized back with ``scale_up`` and subtracted from the residual.
    The last scale is used at full length without resizing.
    """

    def __init__(self, config, feature_dim, v_patch_nums):
        super().__init__()
        self.K = config.K
        self.latent_size = config.trajectory_embd
        self.embedding_dim = config.n_embd
        self.trajectory_length = config.history_horizon + config.horizon

        self.beta = 0.25
        self.v_patch_nums = v_patch_nums
        self.observation_dim = feature_dim
        self.joined_dim = self.observation_dim + 1
        self.action_dim = config.action_dim
        self.transition_dim = config.transition_dim
        self.state_conditional = config.state_conditional

        if "ma_update" in config and not (config.ma_update):
            self.codebook = VQEmbedding(config.K, config.trajectory_embd)
            self.ma_update = False
        else:
            self.codebook = VQEmbeddingMovingAverage(config.K, config.trajectory_embd)
            self.ma_update = True

        self.encoder = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        self.pos_emb = nn.Parameter(torch.zeros(1, self.trajectory_length, config.n_embd))
        self.embed = nn.Linear(self.transition_dim, self.embedding_dim)
        self.cast_embed = nn.Linear(self.embedding_dim, self.latent_size)
        self.polymerization = nn.Linear(self.trajectory_length, 1)
        self.pred_action = nn.Linear(self.embedding_dim, self.action_dim)
        self.scale_up = nn.ModuleList([nn.Linear(in_features, self.trajectory_length) for in_features in self.v_patch_nums])
        self.scale_down = nn.ModuleList([nn.Linear(self.trajectory_length, in_features) for in_features in self.v_patch_nums])
        self.decoder = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        self.latent_mixing = nn.Linear(self.latent_size + self.joined_dim, self.embedding_dim)
        self.predict = nn.Linear(self.embedding_dim, self.transition_dim)

        self.ln_f = nn.LayerNorm(config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)

    def encode(self, joined_inputs):
        """Embed a (B, T, transition_dim) trajectory into (B, T, latent_size) features."""
        joined_inputs = joined_inputs.to(dtype=torch.float32)
        b, t, joined_dimension = joined_inputs.size()
        token_embeddings = self.embed(joined_inputs)
        position_embeddings = self.pos_emb[:, :t, :]
        x = self.drop(token_embeddings + position_embeddings)
        x = self.encoder(x)
        x = self.cast_embed(x)
        return x

    def decode(self, latents, rstate):
        """Decode (B, T, latent_size) latents conditioned on ``rstate``.

        Equivalent to ``decode_p2(decoder(decode_p1(latents, rstate)), rstate)``.
        Returns ``(joined_pred, action_pred)``; see ``decode_p2``.
        """
        x = self.decode_p1(latents, rstate)
        x = self.decoder(x)
        return self.decode_p2(x, rstate)

    def decode_p1(self, latents, rstate):
        """Mix latents with the (optionally zeroed) reference state and add positions."""
        B, T, _ = latents.shape
        rstate_flat = torch.reshape(rstate, shape=[B, 1, -1]).repeat(1, T, 1)
        if not self.state_conditional:
            rstate_flat = torch.zeros_like(rstate_flat)
        inputs = torch.cat([rstate_flat, latents], dim=-1)
        inputs = self.latent_mixing(inputs)
        inputs = inputs + self.pos_emb[:, :inputs.shape[1]]
        return inputs

    def decode_p2(self, x, rstate):
        """Project decoder outputs to trajectory and action predictions.

        ``joined_pred`` has ``transition_dim`` columns: the last column goes
        through a sigmoid, and ``rstate`` is added to the first ``joined_dim``
        columns. ``action_pred`` pools over time with ``polymerization``; this
        requires T == trajectory_length.
        """
        B, T, _ = x.shape
        x = self.ln_f(x)
        action_pred = self.polymerization(x.transpose(1, 2))
        action_pred = self.pred_action(action_pred.squeeze(-1))

        joined_pred = self.predict(x)
        joined_pred[:, :, -1] = torch.sigmoid(joined_pred[:, :, -1])
        joined_pred[:, :, :self.joined_dim] += torch.reshape(rstate, shape=[B, 1, -1])

        return joined_pred, action_pred

    def get_idx(self, f):
        """Return the per-scale code indices of encoded features ``f``.

        Args:
            f: (B, T, latent_size) output of ``encode``.

        Returns:
            List with one (B, pn) index tensor per entry of ``v_patch_nums``.

        The codebook is put in eval mode for the duration of the call. Otherwise
        the EMA codebook would be updated as a side effect of merely reading
        indices (``straight_through`` updates it in training mode).
        """
        b = f.shape[0]
        f_rest = f.transpose(1, 2).clone()
        idx = []
        last_scale = len(self.v_patch_nums) - 1

        was_training = self.codebook.training
        self.codebook.eval()
        try:
            for si in range(len(self.v_patch_nums)):
                is_last = si == last_scale
                rest_BE = f_rest if is_last else self.scale_down[si](f_rest)
                _, latents, indices = self.codebook.straight_through(rest_BE.transpose(1, 2))
                h_BET = latents.transpose(1, 2) if is_last else self.scale_up[si](latents.transpose(1, 2))

                idx.append(indices.view(b, -1))
                f_rest = f_rest - h_BET
        finally:
            self.codebook.train(was_training)
        return idx

    def get_codebook_usage(self, joined_inputs):
        """Fraction of the ``K`` codes used by any scale for this batch (no gradients)."""
        with torch.no_grad():
            f = self.encode(joined_inputs)
            idx = self.get_idx(f)
            combined = torch.cat([t.view(-1) for t in idx])
            num_unique = len(torch.unique(combined))
        return num_unique / self.K

    def forward(self, joined_inputs, rstate):
        """Encode, quantize scale by scale, and decode.

        Returns:
            joined_pred, pred_actions: outputs of ``decode`` on the
                straight-through reconstruction.
            mean_vq_loss: per-scale mean of a beta-weighted commitment term
                plus a codebook term, comparing the cumulative reconstruction
                against the encoder features.
            f_hat: (B, latent_size, T) cumulative reconstruction.
            feat: (B, T, latent_size) encoder features.
        """
        feat = self.encode(joined_inputs)

        f_rest = feat.transpose(1, 2).clone()
        f_hat = torch.zeros_like(f_rest)
        f_hat_st = torch.zeros_like(f_rest)

        mean_vq_loss = 0.0
        SN = len(self.v_patch_nums)
        for si, pn in enumerate(self.v_patch_nums):
            rest_BE = self.scale_down[si](f_rest) if si != SN - 1 else f_rest
            latents_st, latents, _ = self.codebook.straight_through(rest_BE.transpose(1, 2))
            h_BET_st = self.scale_up[si](latents_st.transpose(1, 2)) if si != SN - 1 else latents_st.transpose(1, 2)
            h_BET = self.scale_up[si](latents.transpose(1, 2)) if si != SN - 1 else latents.transpose(1, 2)

            f_hat = f_hat + h_BET
            f_hat_st = f_hat_st + h_BET_st
            f_rest = f_rest - h_BET

            mean_vq_loss += F.mse_loss(f_hat.detach(), feat.transpose(1, 2), reduction='mean') * self.beta + F.mse_loss(f_hat, feat.transpose(1, 2).detach(), reduction='mean')

        mean_vq_loss *= 1. / SN

        joined_pred, pred_actions = self.decode(f_hat_st.transpose(1, 2), rstate)

        return joined_pred, pred_actions, mean_vq_loss, f_hat, feat


class VQContinuousVAE(nn.Module):
    """Trajectory autoencoder wrapping a multi-scale ``VQStepWiseTransformer``.

    Trajectories are (batch, time, features) tensors. Steps flagged by
    ``terminals`` have their features replaced with ``padding_vector`` before
    encoding, and ``terminals`` is appended as the last input column.
    """

    def __init__(self, config):
        super().__init__()

        self.trajectory_embd = config.trajectory_embd
        self.vocab_size = config.vocab_size
        self.observation_dim = config.observation_dim

        self.action_dim = config.action_dim
        self.joined_dim = 1 + self.observation_dim
        self.history_horizon = config.history_horizon
        self.horizon = config.horizon
        self.trajectory_length = config.history_horizon + config.horizon
        self.use_action = config.use_action
        self.transition_dim = config.transition_dim

        self.action_weight = config.action_weight
        self.reward_weight = config.reward_weight
        self.value_weight = config.value_weight
        self.position_weight = config.position_weight
        self.current_obs_weight = config.current_obs_weight
        self.current_action_weight = config.current_action_weight
        self.next_obs_weight = config.next_obs_weight
        self.next_action_weight = config.next_action_weight

        maxp = self.trajectory_length
        self.v_patch_nums = (1, 2, 4, maxp//4, maxp//2, maxp*3//4, maxp)

        self.model = VQStepWiseTransformer(config, config.observation_dim, self.v_patch_nums)

        self.padding_vector = torch.zeros(self.transition_dim-1)
        self.apply(self._init_weights)

    def set_padding_vector(self, padding):
        """Set the feature vector substituted for terminal steps."""
        self.padding_vector = padding

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def configure_optimizers(self, train_config):
        """Build AdamW with weight decay only on Linear/Conv/LSTM/EinLinear weights."""
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, EinLinear, torch.nn.Conv1d, torch.nn.LSTM)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn

                if pn.endswith('bias'):
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)
                elif isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
        if isinstance(self.model, SymbolWiseTransformer) or isinstance(self.model, VQStepWiseTransformer):
            no_decay.add('model.pos_emb')

        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(param_dict.keys() - union_params), )

        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def _pad_terminal_steps(self, joined_inputs, terminals):
        """Replace the features of steps where ``terminals`` is 1 with ``padding_vector``."""
        b, t, _ = joined_inputs.size()
        padded = self.padding_vector.clone().detach().to(dtype=torch.float32, device=joined_inputs.device).repeat(b, t, 1)
        terminal_mask = torch.clone(1 - terminals).repeat(1, 1, joined_inputs.shape[-1])
        return joined_inputs * terminal_mask + (1 - terminal_mask) * padded

    @torch.no_grad()
    def encode(self, joined_inputs, terminals):
        """Return encoder features for a terminal-padded trajectory (no gradients)."""
        joined_inputs = self._pad_terminal_steps(joined_inputs, terminals)
        trajectory_feature = self.model.encode(torch.cat([joined_inputs, terminals], dim=2))
        return trajectory_feature

    @torch.no_grad()
    def encode_to_idx(self, joined_inputs, terminals):
        """Return the per-scale code indices of a terminal-padded trajectory."""
        joined_inputs = self._pad_terminal_steps(joined_inputs, terminals)
        trajectory_feature = self.model.encode(torch.cat([joined_inputs, terminals], dim=2))
        indices = self.model.get_idx(trajectory_feature)
        return indices

    def decode(self, latent, rstate):
        return self.model.decode(latent, rstate)

    def decode_from_indices(self, indices, state):
        # NOTE(review): placeholder, not implemented; returns None.
        pass

    def forward(self, joined_inputs, targets=None, mask=None, ts=None, terminals=None):
        """Reconstruct a trajectory and compute training losses.

        ``terminals`` is required: it is concatenated to the model input.
        ``mask`` (optional) weights the per-element reconstruction error; when
        omitted, every element is weighted by 1.

        Returns:
            reconstructed, pred_actions, reconstruction_loss, mean_vq_loss,
            action_loss, codebook_usage. The three losses are None when
            ``targets`` is None.
        """
        actions = joined_inputs[:, self.history_horizon, self.joined_dim:-1]

        if not self.use_action:
            joined_inputs = joined_inputs[:, :, :self.joined_dim]
            if targets is not None:
                targets = targets[:, :, :self.joined_dim]
            if mask is not None:
                mask = mask[:, :, :self.joined_dim]

        joined_inputs = joined_inputs.to(dtype=torch.float32)
        b, t, joined_dimension = joined_inputs.size()

        padded = self.padding_vector.clone().detach().to(dtype=torch.float32, device=joined_inputs.device).repeat(b, t, 1)

        if terminals is not None:
            terminal_mask = torch.clone(1 - terminals)
            joined_inputs = joined_inputs*terminal_mask+(1-terminal_mask)*padded

        # Reference state: the step at ``history_horizon``, first joined_dim columns.
        rstate = joined_inputs[:, self.history_horizon, :self.joined_dim]
        model_inputs = torch.cat([joined_inputs, terminals], dim=2)
        reconstructed, pred_actions, mean_vq_loss, latents, feature = self.model(model_inputs, rstate)
        pred_trajectory = torch.reshape(reconstructed[:, :, :-1], shape=[b, t, joined_dimension])
        pred_terminals = reconstructed[:, :, -1, None]

        if targets is not None:
            weights = torch.cat([
                torch.ones(3, device=joined_inputs.device)*self.position_weight,
                torch.ones(self.observation_dim-2, device=joined_inputs.device),
            ])
            if self.use_action:
                weights = torch.cat([weights, torch.ones(self.action_dim, device=joined_inputs.device)*self.action_weight])

            mse = F.mse_loss(pred_trajectory, joined_inputs, reduction='none')*weights[None, None, :]

            current_obs_loss = self.current_obs_weight*F.mse_loss(joined_inputs[:, self.history_horizon, :self.joined_dim],
                                                                  pred_trajectory[:, self.history_horizon, :self.joined_dim])
            next_obs_loss = self.next_obs_weight*F.mse_loss(joined_inputs[:, self.history_horizon+1, :self.joined_dim],
                                                            pred_trajectory[:, self.history_horizon+1, :self.joined_dim])
            if self.use_action:
                current_action_loss = self.current_action_weight*F.mse_loss(joined_inputs[:, self.history_horizon, self.joined_dim:],
                                                                            pred_trajectory[:, self.history_horizon, self.joined_dim:])
                next_action_loss = self.next_action_weight*F.mse_loss(joined_inputs[:, self.history_horizon+1, self.joined_dim:],
                                                                            pred_trajectory[:, self.history_horizon+1, self.joined_dim:])
            else:
                current_action_loss = 0
                next_action_loss = 0

            cross_entropy = F.binary_cross_entropy(pred_terminals, torch.clip(terminals.float(), 0.0, 1.0))
            masked_mse = mse
            if mask is not None:
                masked_mse = masked_mse * mask
            masked_mse = masked_mse * terminal_mask
            reconstruction_loss = masked_mse.mean() + cross_entropy
            reconstruction_loss = reconstruction_loss + current_obs_loss + current_action_loss + next_obs_loss + next_action_loss

            action_loss = F.mse_loss(pred_actions, actions, reduction='mean')
        else:
            reconstruction_loss = None
            mean_vq_loss = None
            # Previously left unbound, which made the return below raise.
            action_loss = None

        codebook_usage = self.model.get_codebook_usage(model_inputs)
        return reconstructed, pred_actions, reconstruction_loss, mean_vq_loss, action_loss, codebook_usage


class TransformerPrior(nn.Module):
    """Next-scale autoregressive prior over the multi-scale code indices of a VQ VAE.

    Token positions are grouped by scale (``config.v_patch_nums``). The mask
    built in ``__init__`` lets every token attend to tokens of its own scale
    and of all earlier (coarser) scales. The first position is a start token
    computed from the states, and every position is conditioned on the
    return-to-go.
    """

    def __init__(self, config):
        super().__init__()

        self.joint_dim = config.observation_dim + 1
        self.state_emb = nn.Linear(config.observation_dim, config.n_embd)
        self.combine_proj = nn.Linear(config.n_embd + 1, config.n_embd)

        self.word_emb = nn.Linear(config.trajectory_embd, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.blocks = nn.Sequential(*[AttentionBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.K, bias=False)
        self.observation_dim = config.observation_dim
        self.action_dim = config.action_dim
        self.history_horizon = config.history_horizon
        self.horizon = config.horizon
        self.max_path_length = config.max_path_length

        self.patch_nums = config.v_patch_nums
        self.L = sum(self.patch_nums)
        self.lvl_embed = nn.Embedding(len(self.patch_nums), config.n_embd)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.L, config.n_embd))
        d: torch.Tensor = torch.cat([torch.full((pn,), i) for i, pn in enumerate(self.patch_nums)]).view(1, self.L, 1)
        dT = d.transpose(1, 2)
        # mask[i, j] = 1 iff scale(i) >= scale(j)
        attn_bias_for_masking = torch.where(d >= dT, 1, 0).reshape(1, 1, self.L, self.L)
        self.register_buffer('attn_bias_for_masking', attn_bias_for_masking.contiguous())
        # (start, end) token range of each scale
        self.begin_ends = []
        cur = 0
        for i, pn in enumerate(self.patch_nums):
            self.begin_ends.append((cur, cur + pn))
            cur += pn

        hidden_dim = 256
        self.inv_model = nn.Sequential(
            nn.Linear(2 * self.observation_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, self.action_dim),
        )

        # adapter applied residually after each of the VAE decoder blocks
        self.adapter_layers = nn.ModuleList([
            Adapter(config.n_embd, hidden_dim) for _ in range(config.n_layer)
        ])

        self.vocab_size = config.K
        self.embedding_dim = config.n_embd
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def configure_optimizers(self, train_config):
        """Build AdamW with weight decay only on Linear/Conv/LSTM/EinLinear weights."""
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, EinLinear, torch.nn.Conv1d, torch.nn.LSTM)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn

                if pn.endswith('bias'):
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)
                elif pn.endswith('ada_gss'):
                    no_decay.add(fpn)
                elif isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)

        no_decay.add('pos_embed')
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                    % (str(param_dict.keys() - union_params), )

        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    @staticmethod
    def _codebook_vectors(vae):
        """Return the detached (K, D) codebook matrix of ``vae`` for either codebook type."""
        codebook = vae.model.codebook.embedding
        # The EMA codebook stores a plain tensor buffer; VQEmbedding wraps nn.Embedding.
        weight = codebook if isinstance(codebook, torch.Tensor) else codebook.weight
        return weight.detach()

    def _level_embeddings(self, device):
        """Return the per-token scale embeddings, shape (1, L, n_embd)."""
        levels = torch.cat([torch.full((pn,), i, dtype=torch.long, device=device)
                            for i, pn in enumerate(self.patch_nums)])
        return self.lvl_embed(levels)[None, :, :]

    def _start_token(self, states):
        """Compute the start-of-sequence embedding from the conditioning states."""
        if states.shape[1] == 1:
            return self.state_emb(states)
        # NOTE(review): ``self.lstm`` is never created in ``__init__``; this branch
        # raises AttributeError unless one is attached elsewhere.
        output, (h_n, _) = self.lstm(states)
        return h_n[-1].unsqueeze(1)

    @torch.no_grad()
    def get_inputs(self, idx, vae):
        """Build teacher-forcing token inputs for scales 1..SN-1 from ground-truth indices.

        For each scale ``si`` except the last, that scale's codes are upsampled
        into the cumulative reconstruction ``f_hat``. The result is then resized
        to the next scale's length; the last scale uses ``f_hat`` at full length,
        mirroring the VAE. Returns a (B, sum(patch_nums[1:]), latent_size) tensor.
        """
        model = vae.model
        codebook = self._codebook_vectors(vae)
        B, T = idx[-1].shape
        SN = len(self.patch_nums)
        f_hat = torch.zeros(B, model.latent_size, T, device=idx[0].device)
        next_scale = []
        for si in range(SN - 1):
            pn = self.patch_nums[si]
            # (B*pn, D) -> (B, pn, D) -> (B, D, pn): same layout the VAE quantizes in.
            codes = torch.index_select(codebook, dim=0, index=idx[si].flatten()).reshape(B, pn, -1).transpose(1, 2)
            f_hat.add_(model.scale_up[si](codes))
            nxt = si + 1
            if nxt == SN - 1:
                # The VAE never resizes the last scale (its scale_down is unused).
                next_scale.append(f_hat.clone())
            else:
                next_scale.append(model.scale_down[nxt](f_hat))
        return torch.cat(next_scale, dim=-1).transpose(1, 2)

    def forward(self, rtg, states, gt_idx=None, vae=None, temp=0.5):
        """Teacher-forced pass over all scales.

        Returns ``(logits, loss, logits_idx, pred_traj, pred_actions)``:
        - ``loss`` is the mean cross-entropy of ``logits`` against ``gt_idx``.
        - ``logits_idx`` splits ``logits`` per scale.
        - The predictions decode hard Gumbel-softmax samples of ``logits``
          through the VAE decoder with adapters.
        """
        b, _, _ = states.shape
        states = states.to(dtype=torch.float32)
        rtg = rtg.to(dtype=torch.float32)
        state = states[:, -1, :]

        var_inputs = self.get_inputs(gt_idx, vae)
        sos = self._start_token(states)
        inputs = torch.cat([sos, self.word_emb(var_inputs)], dim=1)

        B, T, E = inputs.shape
        x = inputs + self._level_embeddings(inputs.device)

        x = torch.cat([x, torch.reshape(rtg, shape=[B, 1, -1]).repeat(1, T, 1)], dim=-1)
        x = self.drop(self.combine_proj(x))
        for block in self.blocks:
            block.attn.kv_caching(False)
        for block in self.blocks:
            x = block(x, attn_bias=self.attn_bias_for_masking)
        x = self.ln_f(x)

        logits = self.head(x).reshape(b, -1, self.vocab_size)

        gt = torch.cat(gt_idx, dim=1)
        loss = F.cross_entropy(logits.reshape(-1, self.vocab_size), gt.reshape([-1]), reduction='none').mean()

        logits_idx = [logits[:, start:end] for start, end in self.begin_ends]

        cb_weights = gumbel_softmax(logits, temperature=temp, hard=True)
        latents = self.get_latent(cb_weights, vae)
        pred_traj, pred_actions = self.decode(latents, vae, rstate=torch.cat([rtg.reshape(-1, 1), state], dim=-1))

        return logits, loss, logits_idx, pred_traj, pred_actions

    def get_latent(self, cb_weights, vae):
        """Turn per-token code weights (B, L, K) into a (B, T, latent_size) VAE latent.

        Each scale's codes are upsampled with the VAE's ``scale_up`` and summed.
        The last scale is added as-is, matching the VAE and ``sample``.
        """
        model = vae.model
        B, T, K = cb_weights.shape
        cb_vec = cb_weights.reshape(-1, K) @ self._codebook_vectors(vae)
        cb_vec = cb_vec.reshape(B, T, -1).transpose(1, 2)
        last_scale = len(self.begin_ends) - 1
        latents = torch.zeros(B, model.latent_size, self.patch_nums[-1], device=cb_weights.device)
        for i, (st, ed) in enumerate(self.begin_ends):
            scale_vec = cb_vec[..., st:ed]
            latents = latents + (scale_vec if i == last_scale else model.scale_up[i](scale_vec))
        return latents.transpose(1, 2)

    def decode(self, latents, vae, rstate):
        """Run the VAE decoder with a residual adapter after every decoder block."""
        x = vae.model.decode_p1(latents, rstate)
        assert len(vae.model.decoder) == len(self.adapter_layers), "Mismatch in block and adapter count"
        for block, adapter in zip(vae.model.decoder, self.adapter_layers):
            x = block(x)
            x = x + adapter(x)
        return vae.model.decode_p2(x, rstate)

    def pred_loss(self, traj, pred_traj, mask=None, terminals=None, padded=None, weight=2, beta=0.2):
        """Weighted trajectory reconstruction + terminal loss, scaled by ``beta``.

        When ``terminals`` is given, ``padded`` is required and replaces the
        terminal steps of ``traj``. The first two time steps get an extra
        ``(weight - 1)`` MSE term. ``mask=None`` or ``terminals=None`` means
        no masking / no terminal loss.
        """
        pred_terminals = pred_traj[..., -1, None]
        pred_traj = pred_traj[..., :-1]
        b, t, joined_dim = traj.shape

        terminal_mask = None
        terminal_loss = 0.0
        if terminals is not None:
            padded = padded.repeat(b, t, 1)
            terminal_mask = torch.clone(1 - terminals).repeat(1, 1, traj.shape[-1])
            traj = traj*terminal_mask+(1-terminal_mask)*padded
            terminal_loss = F.binary_cross_entropy(pred_terminals, torch.clip(terminals.float(), 0.0, 1.0))

        err = F.mse_loss(traj, pred_traj, reduction='none')
        if mask is not None:
            err = err * mask[..., :joined_dim]
        if terminal_mask is not None:
            err = err * terminal_mask
        loss = err.mean() + terminal_loss
        loss += F.mse_loss(traj[:, :2, :], pred_traj[:, :2, :], reduction='mean') * (weight - 1)
        return loss * beta

    def action_loss(self, actions, pred_actions, beta=0.2):
        return F.mse_loss(actions, pred_actions, reduction='mean') * beta

    def inv_loss(self, obs, next_obs, action):
        """MSE of the inverse-dynamics model's action prediction from (obs, next_obs)."""
        pred_action = self.inv_model(torch.cat([obs, next_obs], dim=-1))
        return F.mse_loss(pred_action, action)

    @torch.no_grad()
    def sample(self, rtg, representation, trajectory):
        """Greedily generate indices scale by scale and decode predicted actions."""
        if len(trajectory.shape) == 2:
            trajectory = trajectory[None, :, :]
        b, *_ = trajectory.shape
        model = representation.model
        codebook = self._codebook_vectors(representation)
        last_scale = len(self.patch_nums) - 1

        states = trajectory[:, :representation.history_horizon+1, :]
        state = trajectory[:, representation.history_horizon, :representation.observation_dim]
        state = state.to(dtype=torch.float32)
        rtg = torch.tensor([rtg], dtype=torch.float32, device=state.device)

        x = self._start_token(states)
        level_embeddings = self._level_embeddings(trajectory.device)
        latents = torch.zeros(b, model.latent_size, self.patch_nums[-1], device=trajectory.device)

        for block in self.blocks:
            block.attn.kv_caching(True)
        try:
            for si, pn in enumerate(self.patch_nums):
                start, end = self.begin_ends[si]
                x = x + level_embeddings[:, start:end, :]
                x = torch.cat([x, torch.reshape(rtg, shape=[-1, 1, 1]).repeat(1, pn, 1)], dim=-1)
                x = self.combine_proj(x)

                for block in self.blocks:
                    x = block(x, attn_bias=None)
                x = self.ln_f(x)
                logits = self.head(x).reshape(b, -1, self.vocab_size)

                idx = torch.argmax(logits, dim=-1)
                # (b*pn, D) -> (b, pn, D) -> (b, D, pn), as in the VAE and get_latent.
                latent = torch.index_select(codebook, dim=0, index=idx.flatten()).reshape(b, pn, -1).transpose(1, 2)
                latents += latent if si == last_scale else model.scale_up[si](latent)

                if si != last_scale:
                    nxt = si + 1
                    # Last scale is used at full length, mirroring get_inputs and the VAE.
                    nxt_input = latents if nxt == last_scale else model.scale_down[nxt](latents)
                    x = self.word_emb(nxt_input.transpose(1, 2))
        finally:
            for block in self.blocks:
                block.attn.kv_caching(False)

        pred_traj, pred_actions = self.decode(latents.transpose(1, 2), representation, torch.cat([rtg.reshape(-1, 1), state], dim=-1))
        return pred_actions