import torch
import torch.nn as nn
import numpy as np
# from networks.layers import *
import torch.nn.functional as F
import clip
from model.encode_text import T5TextEncoder, CLIPTextEncoder
from einops import repeat
from functools import partial
from model.transformer.tools import *
from torch.distributions.categorical import Categorical

class InputProcess(nn.Module):
    """Linear projection of per-token features, emitted sequence-first.

    Takes a batch-first tensor ``[bs, ntokens, input_feats]`` and returns
    ``[ntokens, bs, latent_dim]``.
    """

    def __init__(self, input_feats, latent_dim):
        super().__init__()
        self.input_feats = input_feats
        self.latent_dim = latent_dim
        # The attribute name is part of the state-dict layout; keep it stable.
        self.poseEmbedding = nn.Linear(input_feats, latent_dim)

    def forward(self, x):
        """Swap the batch and sequence axes, then project the last axis."""
        seq_first = x.permute(1, 0, 2)  # [ntokens, bs, input_feats]
        return self.poseEmbedding(seq_first)  # [ntokens, bs, latent_dim]

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (borrowed from MDM) followed by dropout.

    The table is precomputed for ``max_len`` positions and stored in the
    ``pe`` buffer with shape ``[max_len, 1, d_model]`` so it broadcasts over
    the batch axis of a sequence-first input.

    :param d_model: feature width of the inputs; both even and odd widths are supported.
    :param dropout: dropout probability applied after adding the encoding.
    :param max_len: number of positions in the precomputed table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=500):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Frequencies are computed in log-space (exp of a scaled arange).
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns;
        # trimming the angles avoids a shape mismatch. For even d_model the
        # slice is a no-op, so the table is unchanged.
        pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)  # [max_len, 1, d_model]

        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding for the first ``x.shape[0]`` positions, then apply dropout.

        :param x: sequence-first tensor whose leading axis indexes positions
            and whose last axis has width ``d_model``.
        :return: tensor with the same shape as ``x``.
        """
        x = x + self.pe[:x.shape[0], :]
        return self.dropout(x)

class OutputProcess_Bert(nn.Module):
    """BERT-style prediction head: dense -> GELU -> LayerNorm -> linear projection.

    Consumes ``[seqlen, bs, latent_dim]`` and produces ``[bs, out_feats, seqlen]``.
    """

    def __init__(self, out_feats, latent_dim):
        super().__init__()
        # Attribute names define the state-dict keys; do not rename them.
        self.dense = nn.Linear(latent_dim, latent_dim)
        self.transform_act_fn = F.gelu
        self.LayerNorm = nn.LayerNorm(latent_dim, eps=1e-12)
        self.poseFinal = nn.Linear(latent_dim, out_feats)  # projection keeps its bias

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Transform the hidden states and project them to ``out_feats`` scores."""
        activated = self.transform_act_fn(self.dense(hidden_states))
        scores = self.poseFinal(self.LayerNorm(activated))  # [seqlen, bs, out_feats]
        return scores.permute(1, 2, 0)  # [bs, out_feats, seqlen]

class OutputProcess(nn.Module):
    """Output head applying dense, GELU, LayerNorm and a final linear layer in turn.

    Input is ``[seqlen, bs, latent_dim]``; output is ``[bs, out_feats, seqlen]``.
    """

    def __init__(self, out_feats, latent_dim):
        super().__init__()
        # Attribute names define the state-dict keys; do not rename them.
        self.dense = nn.Linear(latent_dim, latent_dim)
        self.transform_act_fn = F.gelu
        self.LayerNorm = nn.LayerNorm(latent_dim, eps=1e-12)
        self.poseFinal = nn.Linear(latent_dim, out_feats)  # projection keeps its bias

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the four stages in order and move the sequence axis last."""
        stages = (self.dense, self.transform_act_fn, self.LayerNorm, self.poseFinal)
        for stage in stages:
            hidden_states = stage(hidden_states)
        # hidden_states is now [seqlen, bs, out_feats]
        return hidden_states.permute(1, 2, 0)  # [bs, out_feats, seqlen]


class MaskTransformer(nn.Module):
    def __init__(self, code_dim, latent_dim=256, ff_size=1024, num_layers=8,
                 num_heads=4, dropout=0.1, text_dim=512, cond_drop_prob=0.1,
                 device=None, cfg=None, **kargs):
        """Build the masked-token transformer and attach its text encoder.

        :param code_dim: width of the learned token embedding (``token_emb``).
        :param latent_dim: transformer model width (``d_model``).
        :param ff_size: feed-forward width inside each encoder layer.
        :param num_layers: number of stacked encoder layers.
        :param num_heads: attention heads per encoder layer.
        :param dropout: dropout rate for the encoder layers and the positional encoding.
        :param text_dim: input width of ``cond_emb``, i.e. the text feature width.
        :param cond_drop_prob: per-sample probability used by ``mask_cond`` to
            zero the condition while training.
        :param device: stored on the instance and passed to ``T5TextEncoder``.
        :param cfg: config object. This constructor reads ``cfg.vq.nb_code``,
            ``cfg.text_embedder.version`` and ``cfg.data.max_text_length``.
            NOTE(review): despite the ``None`` default it is dereferenced
            unconditionally, so a real config is required.
        :param kargs: accepted and ignored.
        """
        super(MaskTransformer, self).__init__()
        print(f'latent_dim: {latent_dim}, ff_size: {ff_size}, nlayers: {num_layers}, nheads: {num_heads}, dropout: {dropout}')

        self.code_dim = code_dim
        self.latent_dim = latent_dim
        self.text_dim = text_dim
        self.dropout = dropout
        self.cfg = cfg
        self.device = device

        self.cond_drop_prob = cond_drop_prob


        '''
        Preparing Networks
        '''
        # Token embeddings (code_dim) are projected to latent_dim and get a
        # sinusoidal positional encoding before entering the encoder.
        self.input_process = InputProcess(self.code_dim, self.latent_dim)
        self.position_enc = PositionalEncoding(self.latent_dim, self.dropout)

        seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
                                                          nhead=num_heads,
                                                          dim_feedforward=ff_size,
                                                          dropout=dropout,
                                                          activation='gelu')

        self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
                                                     num_layers=num_layers)

        # Projects text features (text_dim) into the transformer width.
        self.cond_emb = nn.Linear(self.text_dim, self.latent_dim)

        _num_tokens = cfg.vq.nb_code + 2  # two dummy tokens, one for masking, one for padding
        self.mask_id = cfg.vq.nb_code
        self.pad_id = cfg.vq.nb_code + 1

        # The head predicts over the nb_code real codes only; the two dummy
        # ids exist in the embedding table but are never output classes.
        self.output_process = OutputProcess_Bert(out_feats=cfg.vq.nb_code, latent_dim=latent_dim)

        self.token_emb = nn.Embedding(_num_tokens, self.code_dim)

        # Runs before text_emb is attached below, so the custom initializer
        # only visits the modules created above.
        self.apply(self.__init_weights)

        '''
        Preparing frozen weights
        '''

        self.text_emb = T5TextEncoder(
            device, 
            local_files_only=False, 
            from_pretrained=cfg.text_embedder.version, 
            model_max_length=cfg.data.max_text_length
        )

        # Maps a time value to a masking ratio; used by forward() and generate().
        self.noise_schedule = cosine_schedule


    def __init_weights(self, module):
        """Per-module initializer, meant to be passed to ``self.apply``.

        Linear and Embedding weights are drawn from a normal distribution with
        mean 0 and std 0.02, Linear biases (when present) are zeroed, and
        LayerNorm is reset to weight 1 / bias 0. Other modules are untouched.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            has_bias = isinstance(module, nn.Linear) and module.bias is not None
            if has_bias:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


    def encode_text(self, raw_text):
        """Encode ``raw_text`` with the attached text encoder.

        :return: the ``(embeddings, mask)`` pair produced by
            ``self.text_emb.get_text_embeddings``.
        """
        embeddings, text_mask = self.text_emb.get_text_embeddings(raw_text)
        return embeddings, text_mask

    def mask_cond(self, cond, force_mask=False):
        """Zero the condition, either for every sample or at random per sample.

        :param cond: 3-D condition tensor with the batch on the first axis.
        :param force_mask: when True, return an all-zero tensor shaped like ``cond``.
        :return: zeros if ``force_mask``; otherwise, while training with a
            positive ``cond_drop_prob``, ``cond`` with each sample zeroed with
            that probability; otherwise ``cond`` itself.
        """
        batch_size, _, _ = cond.shape  # the unpack also enforces a 3-D input

        if force_mask:
            return torch.zeros_like(cond)

        if not (self.training and self.cond_drop_prob > 0.):
            return cond

        # One Bernoulli draw per sample; 1 means "drop this sample's condition".
        drop_probs = torch.ones(batch_size, device=cond.device) * self.cond_drop_prob
        dropped = torch.bernoulli(drop_probs).view(batch_size, 1, 1)
        return cond * (1. - dropped)

    def trans_forward(self, motion_ids, cond, cond_padding_mask, motion_padding_mask, force_mask=False):
        '''
        Embed motion tokens and text features, run them through the encoder as
        one concatenated sequence (text first), and predict token logits for
        the motion positions.

        :param motion_ids: (b, seqlen) token ids
        :param cond: (b, t_seqlen, text_dim) text features
        :param cond_padding_mask: (b, t_seqlen), all pad positions are TRUE else FALSE
        :param motion_padding_mask: (b, seqlen), all pad positions are TRUE else FALSE
        :param force_mask: unused here; kept so existing callers keep working
        :return:
            -logits: (b, num_token, seqlen)
        '''
        # Only the text length is needed; the unpack also enforces a 3-D cond.
        _, text_len, _ = cond.shape

        # (b, seqlen) -> (b, seqlen, code_dim) -> (seqlen, b, latent_dim)
        motion_feats = self.input_process(self.token_emb(motion_ids))
        # (b, t_seqlen, text_dim) -> (t_seqlen, b, latent_dim)
        text_feats = self.cond_emb(cond).permute(1, 0, 2)

        # Each stream gets its own positional encoding, starting at position 0.
        motion_feats = self.position_enc(motion_feats)
        text_feats = self.position_enc(text_feats)

        # Text tokens are placed in front of the motion tokens.
        joint_seq = torch.cat([text_feats, motion_feats], dim=0)  # (t_seqlen+seqlen, b, latent_dim)
        joint_pad = torch.cat([cond_padding_mask, motion_padding_mask], dim=1).bool()  # (b, t_seqlen+seqlen)

        encoded = self.seqTransEncoder(joint_seq, src_key_padding_mask=joint_pad)
        motion_out = encoded[text_len:]  # drop the text positions -> (seqlen, b, latent_dim)
        return self.output_process(motion_out)  # (b, ntoken, seqlen)

    def forward(self, ids, y, m_lens):
        '''
        Training step: randomly mask motion tokens, corrupt the input with a
        BERT-style scheme and score the predictions on the masked positions.

        :param ids: (b, n) motion token ids
        :param y: raw text, passed to encode_text
        :param m_lens: (b,) lengths used to build the padding mask
        :return: (ce_loss, pred_id, acc) as produced by cal_performance
        '''
        batch_size, seq_len = ids.shape
        device = ids.device

        # TRUE on real tokens, FALSE on padded positions.
        valid = lengths_to_mask(m_lens, seq_len)  # (b, n)
        ids = torch.where(valid, ids, self.pad_id)

        # No gradients are tracked through text encoding.
        with torch.no_grad():
            cond_embs, cond_att_mask = self.encode_text(y)
            cond_padding_mask = cond_att_mask == 0

        # One masking ratio per sample, taken from the noise schedule.
        rand_time = uniform((batch_size,), device=device)
        mask_ratio = self.noise_schedule(rand_time)
        n_to_mask = (seq_len * mask_ratio).round().clamp(min=1)

        # Random rank per position; ranks below n_to_mask are selected, and the
        # selection is then restricted to non-padded positions.
        perm_rank = torch.rand((batch_size, seq_len), device=device).argsort(dim=-1)
        mask = (perm_rank < n_to_mask.unsqueeze(-1)) & valid

        # Training target: true ids on selected positions, mask_id elsewhere
        # (mask_id is handed to cal_performance as ignore_index).
        labels = torch.where(mask, ids, self.mask_id)

        # BERT-style corruption of the selected positions:
        #   10% get a random code; of the remaining 90%, 88% get the mask
        #   token and the rest keep the correct token.
        swap_random = get_mask_subset_prob(mask, 0.1)
        random_ids = torch.randint_like(ids, high=self.cfg.vq.nb_code)
        x_ids = torch.where(swap_random, random_ids, ids)
        swap_mask_token = get_mask_subset_prob(mask & ~swap_random, 0.88)
        x_ids = torch.where(swap_mask_token, self.mask_id, x_ids)

        # Randomly drop the condition for some samples (training only).
        cond_embs = self.mask_cond(cond_embs)

        logits = self.trans_forward(x_ids, cond_embs, cond_padding_mask, ~valid)
        ce_loss, pred_id, acc = cal_performance(logits, labels, ignore_index=self.mask_id)
        return ce_loss, pred_id, acc

    def forward_with_cond_scale(self,
                                motion_ids,
                                cond_embs,
                                cond_padding_mask,
                                motion_padding_mask,
                                cond_scale=3,
                                force_mask=False):
        """Predict logits with guidance between conditional and unconditional passes.

        :param motion_ids: (b, seqlen) token ids
        :param cond_embs: (b, t_seqlen, text_dim) text features
        :param cond_padding_mask: (b, t_seqlen), TRUE on padded text positions
        :param motion_padding_mask: (b, seqlen), TRUE on padded motion positions
        :param cond_scale: guidance weight ``s`` in
            ``uncond + (cond - uncond) * s``
        :param force_mask: when True, the condition is zeroed and only the
            unconditional logits are returned (no guidance is applied)
        :return: logits of shape (b, num_token, seqlen)
        """
        if force_mask:
            # Previously this flag was accepted but ignored, so callers asking
            # for an unconditional prediction still received guided logits.
            return self.trans_forward(motion_ids,
                                      self.mask_cond(cond_embs, force_mask=True),
                                      cond_padding_mask,
                                      motion_padding_mask)

        # Run both branches in a single batched pass: first half is the
        # unconditional (zeroed condition) copy, second half the conditional one.
        input_motion_ids = torch.cat([motion_ids, motion_ids], dim=0)
        input_cond_embs = torch.cat([self.mask_cond(cond_embs, force_mask=True),
                                     self.mask_cond(cond_embs, force_mask=False)], dim=0)
        input_cond_padding_mask = torch.cat([cond_padding_mask, cond_padding_mask], dim=0)
        input_motion_padding_mask = torch.cat([motion_padding_mask, motion_padding_mask], dim=0)

        output_logits = self.trans_forward(input_motion_ids, input_cond_embs,
                                           input_cond_padding_mask, input_motion_padding_mask)
        aux_logits, logits = output_logits.chunk(2, dim=0)

        scaled_logits = aux_logits + (logits - aux_logits) * cond_scale
        return scaled_logits

    @torch.no_grad()
    @eval_decorator
    def generate(self,
                 conds,
                 m_lens,
                 timesteps: int,
                 cond_scale: int,
                 temperature=1,
                 topk_filter_thres=0.9,
                 gsample=False,
                 force_mask=False
                 ):
        """Iteratively decode motion tokens from a fully masked sequence.

        Every non-pad position starts as ``mask_id``. At each step the
        lowest-scoring positions are (re-)masked, logits are predicted with
        ``forward_with_cond_scale``, new ids are sampled for the masked
        positions, and their scores are refreshed from the unfiltered softmax.

        :param conds: raw text, passed to ``encode_text``.
        :param m_lens: (b,) tensor of sequence lengths.
            NOTE(review): it is multiplied with a tensor created on the model's
            device, so it presumably must live on that device too -- confirm.
        :param timesteps: number of decoding iterations.
        :param cond_scale: guidance scale forwarded to ``forward_with_cond_scale``.
        :param temperature: sampling temperature, constant across iterations.
        :param topk_filter_thres: threshold handed to ``top_k``.
        :param gsample: sample with ``gumbel_sample`` if True, otherwise draw
            from a ``Categorical`` over the tempered softmax.
        :param force_mask: forwarded to ``forward_with_cond_scale``.
        :return: (b, seq_len) tensor of token ids with padded positions set to -1.
        """
        device = next(self.parameters()).device
        seq_len = max(m_lens)

        # Redundant with the @torch.no_grad() decorator, but harmless.
        with torch.no_grad():
            cond_embs, cond_att_mask = self.encode_text(conds)
            cond_padding_mask = (cond_att_mask==0)


        # TRUE on padded positions.
        padding_mask = ~lengths_to_mask(m_lens, seq_len)

        # Start from all tokens being masked
        ids = torch.where(padding_mask, self.pad_id, self.mask_id)
        # Pads get a huge score so they are never among the lowest-scoring positions.
        scores = torch.where(padding_mask, 1e5, 0.)
        starting_temperature = temperature

        # timestep is evenly spaced over [0, 1] (both ends included);
        # steps_until_x0 is currently unused because annealing is disabled below.
        for timestep, steps_until_x0 in zip(torch.linspace(0, 1, timesteps, device=device), reversed(range(timesteps))):
            rand_mask_prob = self.noise_schedule(timestep)  # Tensor

            '''
            Maskout, and cope with variable length
            '''
            # The ratio is applied to each sample's own length rather than seq_len;
            # at least one token is always re-masked.
            num_token_masked = torch.round(rand_mask_prob * m_lens).clamp(min=1)  # (b, )

            # select num_token_masked tokens with lowest scores to be masked
            sorted_indices = scores.argsort(
                dim=1)  # (b, k), sorted_indices[i, j] = the index of j-th lowest element in scores on dim=1
            ranks = sorted_indices.argsort(dim=1)  # (b, k), rank[i, j] = the rank (0: lowest) of scores[i, j] on dim=1
            is_mask = (ranks < num_token_masked.unsqueeze(-1))
            ids = torch.where(is_mask, self.mask_id, ids)

            '''
            Preparing input
            '''
            # (b, num_token, seqlen)
            logits = self.forward_with_cond_scale(ids, cond_embs, cond_padding_mask,
                                                  motion_padding_mask=padding_mask,
                                                  cond_scale=cond_scale,
                                                  force_mask=force_mask)
            

            logits = logits.permute(0, 2, 1)  # (b, seqlen, ntoken)
            # clean low prob token
            filtered_logits = top_k(logits, topk_filter_thres, dim=-1)

            '''
            Update ids
            '''
            # Temperature annealing is disabled: every step uses the starting value.
            temperature = starting_temperature
            if gsample:  # use gumbel_softmax sampling
                pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)  # (b, seqlen)
            else:  # use multinomial sampling
                probs = F.softmax(filtered_logits / temperature, dim=-1)  # (b, seqlen, ntoken)
                pred_ids = Categorical(probs).sample()  # (b, seqlen)

            # Only positions masked in this step take the newly sampled id.
            ids = torch.where(is_mask, pred_ids, ids)

            '''
            Updating scores
            '''
            # Confidence = probability of the sampled id under the unfiltered,
            # temperature-free distribution.
            probs_without_temperature = logits.softmax(dim=-1)  # (b, seqlen, ntoken)
            scores = probs_without_temperature.gather(2, pred_ids.unsqueeze(dim=-1))  # (b, seqlen, 1)
            scores = scores.squeeze(-1)  # (b, seqlen)

            # We do not want to re-mask the previously kept tokens, or pad tokens
            scores = scores.masked_fill(~is_mask, 1e5)

        ids = torch.where(padding_mask, -1, ids)
        return ids


class ExtendMaskTransformer(nn.Module):
    def __init__(self, code_dim, latent_dim=256, ff_size=1024, num_layers=8,
                 num_heads=4, dropout=0.1, text_dim=512, cond_drop_prob=0.1,
                 device=None, cfg=None, full_length=80, scales=(8, 4, 2, 1)):
        """Build the multi-scale masked-token transformer and attach its text encoder.

        :param code_dim: width of the learned token embedding (``token_emb``).
        :param latent_dim: transformer model width (``d_model``).
        :param ff_size: feed-forward width inside each transformer layer.
        :param num_layers: number of stacked transformer layers.
        :param num_heads: attention heads per layer.
        :param dropout: dropout rate for the transformer layers and positional encoding.
        :param text_dim: input width of ``cond_emb``, i.e. the text feature width.
        :param cond_drop_prob: per-sample probability used by ``mask_cond`` to
            zero the condition while training.
        :param device: stored on the instance and passed to ``T5TextEncoder``.
        :param cfg: config object. This constructor reads ``cfg.model.fuse_mode``,
            ``cfg.vq.nb_code``, ``cfg.text_embedder.version`` and
            ``cfg.data.max_text_length``; a real config is required despite
            the ``None`` default.
        :param full_length: token length of the finest scale.
        :param scales: downsampling factors, one per scale; copied into a list
            so instances never share (or mutate) the caller's sequence.
        :raises ValueError: if ``cfg.model.fuse_mode`` is neither
            ``'in_context'`` nor ``'cross_attention'``.
        """
        super(ExtendMaskTransformer, self).__init__()
        print(f'latent_dim: {latent_dim}, ff_size: {ff_size}, nlayers: {num_layers}, nheads: {num_heads}, dropout: {dropout}')

        self.code_dim = code_dim
        self.latent_dim = latent_dim
        self.text_dim = text_dim
        self.dropout = dropout
        self.cfg = cfg
        self.device = device
        self.full_length = full_length
        # Private copy: a shared default (or caller-owned list) must not be
        # aliased across instances.
        self.scales = list(scales)

        self.cond_drop_prob = cond_drop_prob

        # ---- Preparing networks ----
        self.input_process = InputProcess(self.code_dim, self.latent_dim)
        self.position_enc = PositionalEncoding(self.latent_dim, self.dropout)

        fuse_mode = self.cfg.model.fuse_mode
        if fuse_mode == 'in_context':
            # Text and motion tokens are concatenated into one encoder sequence.
            seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
                                                              nhead=num_heads,
                                                              dim_feedforward=ff_size,
                                                              dropout=dropout,
                                                              activation='gelu')
            self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
                                                         num_layers=num_layers)
        elif fuse_mode == 'cross_attention':
            # Motion tokens attend to the text tokens through decoder cross-attention.
            seqTransEncoderLayer = nn.TransformerDecoderLayer(d_model=self.latent_dim,
                                                              nhead=num_heads,
                                                              dim_feedforward=ff_size,
                                                              dropout=dropout,
                                                              activation='gelu')
            self.seqTransEncoder = nn.TransformerDecoder(seqTransEncoderLayer,
                                                         num_layers=num_layers)
        else:
            # Fail at construction time instead of leaving seqTransEncoder
            # undefined and crashing later inside trans_forward.
            raise ValueError(
                f"Unsupported cfg.model.fuse_mode {fuse_mode!r}; "
                "expected 'in_context' or 'cross_attention'"
            )

        # Projects text features (text_dim) into the transformer width.
        self.cond_emb = nn.Linear(self.text_dim, self.latent_dim)

        _num_tokens = cfg.vq.nb_code + 2  # two dummy tokens, one for masking, one for padding
        self.mask_id = cfg.vq.nb_code
        self.pad_id = cfg.vq.nb_code + 1

        # The head predicts over the nb_code real codes only.
        self.output_process = OutputProcess_Bert(out_feats=cfg.vq.nb_code, latent_dim=latent_dim)

        self.token_emb = nn.Embedding(_num_tokens, self.code_dim)

        # Runs before text_emb is attached below, so the custom initializer
        # only visits the modules created above.
        self.apply(self.__init_weights)

        # ---- Preparing frozen weights ----
        self.text_emb = T5TextEncoder(
            device,
            local_files_only=False,
            from_pretrained=cfg.text_embedder.version,
            model_max_length=cfg.data.max_text_length
        )

        # Maps a time value to a masking ratio; used by forward() and generate().
        self.noise_schedule = cosine_schedule


    def __init_weights(self, module):
        """Per-module initializer, meant to be passed to ``self.apply``.

        Linear and Embedding weights are drawn from a normal distribution with
        mean 0 and std 0.02, Linear biases (when present) are zeroed, and
        LayerNorm is reset to weight 1 / bias 0. Other modules are untouched.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            has_bias = isinstance(module, nn.Linear) and module.bias is not None
            if has_bias:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


    # def encode_text(self, raw_text):
    #     device = next(self.parameters()).device
    #     text = clip.tokenize(raw_text, truncate=True).to(device)
    #     feat_clip_text = self.clip_model.encode_text(text).float()
    #     return feat_clip_text

    def encode_text(self, raw_text):
        """Encode ``raw_text`` with the attached text encoder.

        :return: the ``(embeddings, mask)`` pair produced by
            ``self.text_emb.get_text_embeddings``.
        """
        embeddings, text_mask = self.text_emb.get_text_embeddings(raw_text)
        return embeddings, text_mask

    def mask_cond(self, cond, force_mask=False):
        """Zero the condition, either for every sample or at random per sample.

        :param cond: 3-D condition tensor with the batch on the first axis.
        :param force_mask: when True, return an all-zero tensor shaped like ``cond``.
        :return: zeros if ``force_mask``; otherwise, while training with a
            positive ``cond_drop_prob``, ``cond`` with each sample zeroed with
            that probability; otherwise ``cond`` itself.
        """
        batch_size, _, _ = cond.shape  # the unpack also enforces a 3-D input

        if force_mask:
            return torch.zeros_like(cond)

        if not (self.training and self.cond_drop_prob > 0.):
            return cond

        # One Bernoulli draw per sample; 1 means "drop this sample's condition".
        drop_probs = torch.ones(batch_size, device=cond.device) * self.cond_drop_prob
        dropped = torch.bernoulli(drop_probs).view(batch_size, 1, 1)
        return cond * (1. - dropped)

    def trans_forward(self, motion_ids, cond, cond_padding_mask, motion_padding_mask):
        '''
        Embed motion tokens and text features, fuse them according to
        ``cfg.model.fuse_mode`` and predict token logits for the motion positions.

        :param motion_ids: (b, seqlen) token ids
        :param cond: (b, t_seqlen, text_dim) text features
        :param cond_padding_mask: (b, t_seqlen), all pad positions are TRUE else FALSE
        :param motion_padding_mask: (b, seqlen), all pad positions are TRUE else FALSE
        :return:
            -logits: (b, num_token, seqlen)
        :raises ValueError: if ``cfg.model.fuse_mode`` is neither
            'in_context' nor 'cross_attention'
        '''
        fuse_mode = self.cfg.model.fuse_mode
        if fuse_mode not in ('in_context', 'cross_attention'):
            # Previously an unknown mode fell through both branches and died
            # with an UnboundLocalError on `logits`.
            raise ValueError(
                f"Unsupported cfg.model.fuse_mode {fuse_mode!r}; "
                "expected 'in_context' or 'cross_attention'"
            )

        # Only the text length is needed; the unpack also enforces a 3-D cond.
        _, t_seqlen, _ = cond.shape

        # (b, seqlen) -> (b, seqlen, code_dim) -> (seqlen, b, latent_dim)
        x = self.input_process(self.token_emb(motion_ids))
        # (b, t_seqlen, text_dim) -> (t_seqlen, b, latent_dim)
        cond = self.cond_emb(cond).permute(1, 0, 2)

        # Each stream gets its own positional encoding, starting at position 0.
        x = self.position_enc(x)
        cond = self.position_enc(cond)

        if fuse_mode == 'in_context':
            # Text tokens are prepended to the motion tokens and encoded jointly;
            # the text positions are dropped from the output afterwards.
            xseq = torch.cat([cond, x], dim=0)  # (t_seqlen+seqlen, b, latent_dim)
            padding_mask = torch.cat([cond_padding_mask, motion_padding_mask], dim=1).bool()  # (b, t_seqlen+seqlen)
            output = self.seqTransEncoder(xseq, src_key_padding_mask=padding_mask)[t_seqlen:]  # (seqlen, b, e)
        else:
            # 'cross_attention': motion tokens are the target, text is the memory.
            output = self.seqTransEncoder(x,
                                          cond,
                                          tgt_key_padding_mask=motion_padding_mask,
                                          memory_key_padding_mask=cond_padding_mask,
                                          )

        logits = self.output_process(output)  # (seqlen, b, e) -> (b, ntoken, seqlen)
        return logits

    def forward(self, id_list, y, m_lens):
        '''
        Training step over multi-scale token ids: concatenate all scales along
        the sequence axis, randomly mask tokens, corrupt the input with a
        BERT-style scheme and score the predictions on the masked positions.

        :param id_list: sequence of (b, n_i) id tensors, one per scale; the
            last one must have n_i == self.full_length
        :param y: raw text, passed to encode_text
        :param m_lens: (b,) lengths at full resolution; rescaled per scale as
            (m_lens * n_i) // full_length
        :return: (ce_loss, pred_id, acc) as produced by cal_performance
        '''
        assert self.full_length == id_list[-1].shape[1]

        # One validity mask per scale (TRUE on real tokens, FALSE on padding),
        # built from the length rescaled to that scale's resolution.
        per_scale_valid = [
            lengths_to_mask((m_lens * level.shape[1]) // self.full_length, level.shape[1])
            for level in id_list
        ]

        ids = torch.cat(list(id_list), dim=1)
        valid = torch.cat(per_scale_valid, dim=1)

        assert ids.shape[:2] == valid.shape[:2]

        batch_size, total_len = ids.shape
        device = ids.device

        ids = torch.where(valid, ids, self.pad_id)

        # No gradients are tracked through text encoding.
        with torch.no_grad():
            cond_embs, cond_att_mask = self.encode_text(y)
            cond_padding_mask = cond_att_mask == 0

        # One masking ratio per sample, taken from the noise schedule.
        rand_time = uniform((batch_size,), device=device)
        mask_ratio = self.noise_schedule(rand_time)
        n_to_mask = (total_len * mask_ratio).round().clamp(min=1)

        # Random rank per position; ranks below n_to_mask are selected, and the
        # selection is then restricted to non-padded positions.
        perm_rank = torch.rand((batch_size, total_len), device=device).argsort(dim=-1)
        mask = (perm_rank < n_to_mask.unsqueeze(-1)) & valid

        # Training target: true ids on selected positions, mask_id elsewhere
        # (mask_id is handed to cal_performance as ignore_index).
        labels = torch.where(mask, ids, self.mask_id)

        # BERT-style corruption of the selected positions:
        #   10% get a random code; of the remaining 90%, 88% get the mask
        #   token and the rest keep the correct token.
        swap_random = get_mask_subset_prob(mask, 0.1)
        random_ids = torch.randint_like(ids, high=self.cfg.vq.nb_code)
        x_ids = torch.where(swap_random, random_ids, ids)
        swap_mask_token = get_mask_subset_prob(mask & ~swap_random, 0.88)
        x_ids = torch.where(swap_mask_token, self.mask_id, x_ids)

        # Randomly drop the condition for some samples (training only).
        cond_embs = self.mask_cond(cond_embs)

        logits = self.trans_forward(x_ids, cond_embs, cond_padding_mask, ~valid)
        ce_loss, pred_id, acc = cal_performance(logits, labels, ignore_index=self.mask_id)
        return ce_loss, pred_id, acc

    def forward_with_cond_scale(self,
                                motion_ids,
                                cond_embs,
                                cond_padding_mask,
                                motion_padding_mask,
                                cond_scale=3,
                                force_mask=False):
        """Predict logits with guidance between conditional and unconditional passes.

        :param motion_ids: (b, seqlen) token ids
        :param cond_embs: (b, t_seqlen, text_dim) text features
        :param cond_padding_mask: (b, t_seqlen), TRUE on padded text positions
        :param motion_padding_mask: (b, seqlen), TRUE on padded motion positions
        :param cond_scale: guidance weight ``s`` in
            ``uncond + (cond - uncond) * s``
        :param force_mask: when True, the condition is zeroed and only the
            unconditional logits are returned (no guidance is applied)
        :return: logits of shape (b, num_token, seqlen)
        """
        if force_mask:
            # Previously this flag was accepted but ignored, so callers asking
            # for an unconditional prediction still received guided logits.
            return self.trans_forward(motion_ids,
                                      self.mask_cond(cond_embs, force_mask=True),
                                      cond_padding_mask,
                                      motion_padding_mask)

        # Run both branches in a single batched pass: first half is the
        # unconditional (zeroed condition) copy, second half the conditional one.
        input_motion_ids = torch.cat([motion_ids, motion_ids], dim=0)
        input_cond_embs = torch.cat([self.mask_cond(cond_embs, force_mask=True),
                                     self.mask_cond(cond_embs, force_mask=False)], dim=0)
        input_cond_padding_mask = torch.cat([cond_padding_mask, cond_padding_mask], dim=0)
        input_motion_padding_mask = torch.cat([motion_padding_mask, motion_padding_mask], dim=0)

        output_logits = self.trans_forward(input_motion_ids, input_cond_embs,
                                           input_cond_padding_mask, input_motion_padding_mask)
        aux_logits, logits = output_logits.chunk(2, dim=0)

        scaled_logits = aux_logits + (logits - aux_logits) * cond_scale
        return scaled_logits

    @torch.no_grad()
    @eval_decorator
    def generate(self,
                 conds,
                 m_lens,
                 timesteps: int,
                 cond_scale: int,
                 temperature=1,
                 topk_filter_thres=0.9,
                 gsample=False,
                 force_mask=False,
                 ):
        """Iteratively decode multi-scale motion tokens from a fully masked sequence.

        The tokens of all scales in ``self.scales`` are laid out back to back
        along the sequence axis (``full_length // scale`` slots per scale).
        Every non-pad position starts as ``mask_id``. At each step the
        lowest-scoring positions are (re-)masked, logits are predicted with
        ``forward_with_cond_scale``, new ids are sampled for the masked
        positions, and their scores are refreshed from the unfiltered softmax.

        :param conds: raw text, passed to ``encode_text``.
        :param m_lens: (b,) tensor of full-resolution lengths; each scale uses
            ``m_lens // scale`` valid tokens.
            NOTE(review): it is multiplied with a tensor created on the model's
            device, so it presumably must live on that device too -- confirm.
        :param timesteps: number of decoding iterations.
        :param cond_scale: guidance scale forwarded to ``forward_with_cond_scale``.
        :param temperature: sampling temperature, constant across iterations.
        :param topk_filter_thres: threshold handed to ``top_k``.
        :param gsample: sample with ``gumbel_sample`` if True, otherwise draw
            from a ``Categorical`` over the tempered softmax.
        :param force_mask: forwarded to ``forward_with_cond_scale``.
        :return: list with one (b, full_length // scale) id tensor per scale,
            in ``self.scales`` order, with padded positions set to -1.
        """
        device = next(self.parameters()).device
        seq_len = max(m_lens)  # not used below
        non_padding_mask = []
        lengths_div = []  # number of slots reserved for each scale
        new_mlens = torch.zeros_like(m_lens)  # total valid tokens across all scales, per sample
        for scale in self.scales:
            non_padding_mask.append(
                lengths_to_mask((m_lens//scale).long(), int(self.full_length//scale))
            )
            lengths_div.append(int(self.full_length//scale))
            new_mlens += (m_lens // scale).long()

        non_padding_mask = torch.cat(non_padding_mask, dim=1)
        # TRUE on padded positions; pads sit at the tail of each scale's segment.
        padding_mask = ~non_padding_mask

        # Redundant with the @torch.no_grad() decorator, but harmless.
        with torch.no_grad():
            cond_embs, cond_att_mask = self.encode_text(conds)
            cond_padding_mask = (cond_att_mask==0)


        # Start from all tokens being masked
        ids = torch.where(padding_mask, self.pad_id, self.mask_id)
        # Pads get a huge score so they are never among the lowest-scoring positions.
        scores = torch.where(padding_mask, 1e5, 0.)
        starting_temperature = temperature

        # timestep is evenly spaced over [0, 1] (both ends included);
        # steps_until_x0 is currently unused because annealing is disabled below.
        for timestep, steps_until_x0 in zip(torch.linspace(0, 1, timesteps, device=device), reversed(range(timesteps))):
            rand_mask_prob = self.noise_schedule(timestep)  # Tensor

            '''
            Maskout, and cope with variable length
            '''
            # The ratio is applied to each sample's total valid-token count
            # across scales; at least one token is always re-masked.
            num_token_masked = torch.round(rand_mask_prob * new_mlens).clamp(min=1)  # (b, )

            # select num_token_masked tokens with lowest scores to be masked
            sorted_indices = scores.argsort(
                dim=1)  # (b, k), sorted_indices[i, j] = the index of j-th lowest element in scores on dim=1
            ranks = sorted_indices.argsort(dim=1)  # (b, k), rank[i, j] = the rank (0: lowest) of scores[i, j] on dim=1
            is_mask = (ranks < num_token_masked.unsqueeze(-1))
            ids = torch.where(is_mask, self.mask_id, ids)

            '''
            Preparing input
            '''
            # (b, num_token, seqlen)
            logits = self.forward_with_cond_scale(ids, cond_embs, cond_padding_mask,
                                                  motion_padding_mask=padding_mask,
                                                  cond_scale=cond_scale,
                                                  force_mask=force_mask)
            

            logits = logits.permute(0, 2, 1)  # (b, seqlen, ntoken)
            # clean low prob token
            filtered_logits = top_k(logits, topk_filter_thres, dim=-1)

            '''
            Update ids
            '''
            # Temperature annealing is disabled: every step uses the starting value.
            temperature = starting_temperature
            if gsample:  # use gumbel_softmax sampling
                pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)  # (b, seqlen)
            else:  # use multinomial sampling
                probs = F.softmax(filtered_logits / temperature, dim=-1)  # (b, seqlen, ntoken)
                pred_ids = Categorical(probs).sample()  # (b, seqlen)

            # Only positions masked in this step take the newly sampled id.
            ids = torch.where(is_mask, pred_ids, ids)

            '''
            Updating scores
            '''
            # Confidence = probability of the sampled id under the unfiltered,
            # temperature-free distribution.
            probs_without_temperature = logits.softmax(dim=-1)  # (b, seqlen, ntoken)
            scores = probs_without_temperature.gather(2, pred_ids.unsqueeze(dim=-1))  # (b, seqlen, 1)
            scores = scores.squeeze(-1)  # (b, seqlen)

            # We do not want to re-mask the previously kept tokens, or pad tokens
            scores = scores.masked_fill(~is_mask, 1e5)

        ids = torch.where(padding_mask, -1, ids)
        # Split the flat sequence back into one tensor per scale.
        return_list = []
        start = 0
        for length in lengths_div:
            return_list.append(ids[..., start:start+int(length)])
            start += length
        return return_list
    

class MoMaskPlus(nn.Module):
    def __init__(self, code_dim, latent_dim=256, ff_size=1024, num_layers=8,
                 num_heads=4, dropout=0.1, text_dim=512, cond_drop_prob=0.1,
                 device=None, cfg=None, full_length=80, scales=[8, 4, 2, 1]):
        '''
        Build the multi-scale masked-token transformer.

        :param code_dim: width of the token embedding table (input width of ``input_process``)
        :param latent_dim: transformer model width
        :param ff_size: feed-forward width of every transformer layer
        :param num_layers: number of transformer layers
        :param num_heads: number of attention heads
        :param dropout: dropout rate of the transformer layers and the positional encoding
        :param text_dim: width of the text embeddings fed to ``cond_emb``
        :param cond_drop_prob: probability of zeroing a sample's condition during training
        :param device: device handed to the text encoder
        :param cfg: config object; ``cfg.model.fuse_mode``, ``cfg.vq.nb_code``,
            ``cfg.text_embedder.version`` and (T5 only) ``cfg.data.max_text_length``
            are read here. Required in practice despite the ``None`` default.
        :param full_length: reference token count; level ``i`` holds
            ``full_length // scales[i]`` tokens
        :param scales: down-sampling factor of each level
        :raises ValueError: if ``cfg.model.fuse_mode`` is neither ``'in_context'``
            nor ``'cross_attention'``
        '''
        super(MoMaskPlus, self).__init__()
        print(f'latent_dim: {latent_dim}, ff_size: {ff_size}, nlayers: {num_layers}, nheads: {num_heads}, dropout: {dropout}')

        self.code_dim = code_dim
        self.latent_dim = latent_dim
        self.text_dim = text_dim
        self.dropout = dropout
        self.cfg = cfg
        self.device = device
        self.full_length = full_length
        # Copy: the default list object is shared by every call, so the
        # instance must never hold (and possibly mutate) a reference to it.
        self.scales = list(scales)
        self.patch_sizes = [int(full_length // scale) for scale in self.scales]
        self.cond_drop_prob = cond_drop_prob

        init_std = math.sqrt(1 / self.latent_dim / 3)

        # ---- Preparing networks ----
        self.input_process = InputProcess(self.code_dim, self.latent_dim)
        self.position_enc = PositionalEncoding(self.latent_dim, self.dropout)

        fuse_mode = self.cfg.model.fuse_mode
        if fuse_mode == 'in_context':
            # Text tokens are prepended to the motion tokens and a plain encoder attends over both.
            seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
                                                              nhead=num_heads,
                                                              dim_feedforward=ff_size,
                                                              dropout=dropout,
                                                              activation='gelu')
            self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
                                                         num_layers=num_layers)
        elif fuse_mode == 'cross_attention':
            # Motion tokens are the decoder target, text tokens are its memory.
            seqTransEncoderLayer = nn.TransformerDecoderLayer(d_model=self.latent_dim,
                                                              nhead=num_heads,
                                                              dim_feedforward=ff_size,
                                                              dropout=dropout,
                                                              activation='gelu')
            self.seqTransEncoder = nn.TransformerDecoder(seqTransEncoderLayer,
                                                         num_layers=num_layers)
        else:
            # Fail at construction time instead of leaving the model without a
            # transformer and crashing on the first forward pass.
            raise ValueError(
                f"Unsupported cfg.model.fuse_mode {fuse_mode!r}; "
                "expected 'in_context' or 'cross_attention'"
            )

        self.lvl_embed = nn.Embedding(len(self.patch_sizes), self.latent_dim)
        # NOTE(review): self.apply(self.__init_weights) below re-initialises every
        # nn.Embedding (this one included) with N(0, 0.02), so this trunc-normal
        # init does not survive construction. Confirm which init is intended.
        nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)

        self.cond_emb = nn.Linear(self.text_dim, self.latent_dim)

        _num_tokens = cfg.vq.nb_code + 2  # two dummy tokens, one for masking, one for padding
        self.mask_id = cfg.vq.nb_code
        self.pad_id = cfg.vq.nb_code + 1

        # Level index of every position in the concatenated multi-scale sequence:
        # level i (0-based) is repeated patch_sizes[i] times, e.g. [0]*10 + [1]*20 + ...
        d = torch.cat([torch.full((ps,), i) for i, ps in enumerate(self.patch_sizes)])
        self.register_buffer('lvl_1L', d.contiguous())

        self.output_process = OutputProcess_Bert(out_feats=cfg.vq.nb_code, latent_dim=latent_dim)

        self.token_emb = nn.Embedding(_num_tokens, self.code_dim)

        self.apply(self.__init_weights)

        # ---- Preparing the text encoder (built after apply(), so it is not re-initialised) ----
        if 't5' in cfg.text_embedder.version:
            self.text_emb = T5TextEncoder(
                device,
                local_files_only=False,
                from_pretrained=cfg.text_embedder.version,
                model_max_length=cfg.data.max_text_length
            )
        else:
            self.text_emb = CLIPTextEncoder(
                device,
                clip_version=cfg.text_embedder.version
            )

        self.noise_schedule = cosine_schedule


    def __init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


    # def encode_text(self, raw_text):
    #     device = next(self.parameters()).device
    #     text = clip.tokenize(raw_text, truncate=True).to(device)
    #     feat_clip_text = self.clip_model.encode_text(text).float()
    #     return feat_clip_text

    def encode_text(self, raw_text):
        text_embedding, mask = self.text_emb.get_text_embeddings(raw_text)
        return text_embedding, mask
    
    def sinusoidal_encoding(self, t):
        """
        Compute sinusoidal positional encoding for a batch of timesteps t.
        Args:
            t (Tensor): Shape (B, L), representing the timestep indices.
            d_model (int): Embedding dimension.

        Returns:
            Tensor of shape (B, L, D).
        """
        div_term = torch.exp(torch.arange(0, self.latent_dim, 2, dtype=torch.float32, device=t.device) * (-math.log(10000.0) / self.latent_dim))
        
        pe = torch.zeros(*t.shape, self.latent_dim, device=t.device)  # (B, L, D)
        pe[..., 0::2] = torch.sin(t.unsqueeze(-1) * div_term)  # Apply sin to even indices
        pe[..., 1::2] = torch.cos(t.unsqueeze(-1) * div_term)  # Apply cos to odd indices
        
        return pe
    
    def get_pe_from_mlens(self, mlens, max_len):
        B = len(mlens)
        t = torch.arange(max_len, device=mlens.device).unsqueeze(0).expand(B, max_len) # [0, 1, 2, 3,..., max_len]
        T = mlens.unsqueeze(1).expand(B, max_len) # [12, 12, 12, 12, 12, ..., 12]
        t_progress = ((T - t - 1) / (T - 1 + 1e-4)) * 80 # [11/11, 10/11, 9/11, ..., 0/11] * 80
        torch.clamp_min_(t_progress, 0.)
        return self.sinusoidal_encoding(t_progress)


    def mask_cond(self, cond, force_mask=False):
        bs, _, _ =  cond.shape
        if force_mask:
            return torch.zeros_like(cond)
        elif self.training and self.cond_drop_prob > 0.:
            mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_drop_prob).view(bs, 1, 1)
            return cond * (1. - mask)
        else:
            return cond

    def trans_forward(self, motion_ids, cond, toa_pe, cond_padding_mask, motion_padding_mask):
        '''
        :param motion_ids: (b, seqlen)
        :cond_padding_mask: (b, seqlen), all pad positions are TRUE else FALSE
        :motion_padding_mask: (b, t_seqlen), all pad positions are TRUE else FALSE
        :param cond: (b, t_seqlen, embed_dim) for text
        :param force_mask: boolean
        :return:
            -logits: (b, num_token, seqlen)
        '''
        b, t_seqlen, _ = cond.shape
        # cond = self.mask_cond(cond, force_mask=force_mask)

        # print(motion_ids.shape)
        x = self.token_emb(motion_ids)
        # print(x.shape)
        # (b, seqlen, d) -> (seqlen, b, latent_dim)
        x = self.input_process(x)

        cond = self.cond_emb(cond).permute(1, 0, 2) #(t, b, latent_dim)

        x = self.position_enc(x)
        cond = self.position_enc(cond)

        if self.cfg.model.use_toa_pe:
            x = x + toa_pe.permute(1, 0, 2)

        if self.cfg.model.use_lvl_pe:
            x = x + self.lvl_embed(self.lvl_1L).unsqueeze(1) 

        if self.cfg.model.fuse_mode == 'in_context':
            xseq = torch.cat([cond, x], dim=0) #(seqlen+t_seqlen, b, latent_dim)
            padding_mask = torch.cat([cond_padding_mask, motion_padding_mask], dim=1).bool() #(b, seqlen+t_seqlen)
            output = self.seqTransEncoder(xseq, src_key_padding_mask=padding_mask)[t_seqlen:] #(seqlen, b, e)
            logits = self.output_process(output) #(seqlen, b, e) -> (b, ntoken, seqlen)
        elif self.cfg.model.fuse_mode == 'cross_attention':
            output = self.seqTransEncoder(x, 
                                          cond, 
                                          tgt_key_padding_mask=motion_padding_mask, 
                                          memory_key_padding_mask=cond_padding_mask,
                                          )
            logits = self.output_process(output) #(seqlen, b, e) -> (b, ntoken, seqlen)
        return logits

    def forward(self, id_list, y, m_lens):
        '''
        Training step: mask a random subset of the multi-scale tokens and score
        the model's reconstruction of them.

        :param id_list: one (b, n_i) id tensor per entry of ``self.scales``; the
            last one must hold ``self.full_length`` tokens
        :param y: raw text, forwarded to ``encode_text``
        :param m_lens: (b,) lengths; the valid length at a scale is ``m_lens // scale``
        :return: ``(ce_loss, pred_id, acc)`` as returned by ``cal_performance``
        '''
        assert self.full_length == id_list[-1].shape[1]

        # Collect, per scale: the ids, the valid-position mask and the time-to-arrival encoding.
        per_scale_ids, per_scale_valid, per_scale_pe = [], [], []
        for scale, scale_ids in zip(self.scales, id_list):
            n_scale_tokens = scale_ids.shape[1]
            scale_lens = (m_lens // scale).long()
            per_scale_valid.append(lengths_to_mask(scale_lens, n_scale_tokens))
            per_scale_ids.append(scale_ids)
            per_scale_pe.append(self.get_pe_from_mlens(scale_lens, n_scale_tokens))

        ids = torch.cat(per_scale_ids, dim=1)
        # Positions that are PADDED are ALL FALSE
        non_pad_mask = torch.cat(per_scale_valid, dim=1)
        assert ids.shape[:2] == non_pad_mask.shape[:2]
        time_to_arrival_pe = torch.cat(per_scale_pe, dim=1)

        bs, ntokens = ids.shape
        device = ids.device
        ids = torch.where(non_pad_mask, ids, self.pad_id)

        with torch.no_grad():
            cond_embs, cond_att_mask = self.encode_text(y)
            cond_padding_mask = (cond_att_mask == 0)

        # ---- Prepare mask ----
        mask_ratio = self.noise_schedule(uniform((bs,), device=device))
        n_to_mask = (ntokens * mask_ratio).round().clamp(min=1)

        # A random permutation rank per position; the lowest ranks get masked.
        random_rank = torch.rand((bs, ntokens), device=device).argsort(dim=-1)
        # Positions to be MASKED are ALL TRUE, and must also be NON-PADDED
        mask = (random_rank < n_to_mask.unsqueeze(-1)) & non_pad_mask

        # Training target (not the input): unmasked positions carry mask_id and are ignored by the loss.
        labels = torch.where(mask, ids, self.mask_id)

        # ---- BERT-style corruption of the masked positions ----
        # get_mask_subset_prob presumably keeps each True entry with the given probability (TODO confirm).
        nb_code = self.cfg.vq.nb_code
        # Step 1: 10% replace with an incorrect (random) token
        corrupt = get_mask_subset_prob(mask, 0.1)
        x_ids = torch.where(corrupt, torch.randint_like(ids, high=nb_code), ids)
        # Step 2: of the remaining 90%, 88% become the mask token; the rest keep the correct token
        to_mask_token = get_mask_subset_prob(mask & ~corrupt, 0.88)
        x_ids = torch.where(to_mask_token, self.mask_id, x_ids)

        # Optionally perturb some visible, non-padded tokens with random ids.
        if self.training and self.cfg.training.pert_prob > 0.:
            perturb = get_mask_subset_prob(~mask & non_pad_mask, self.cfg.training.pert_prob)
            x_ids = torch.where(perturb, torch.randint_like(x_ids, high=nb_code), x_ids)

        cond_embs = self.mask_cond(cond_embs)

        logits = self.trans_forward(x_ids, cond_embs, time_to_arrival_pe, cond_padding_mask, ~non_pad_mask)
        ce_loss, pred_id, acc = cal_performance(logits, labels, ignore_index=self.mask_id)
        return ce_loss, pred_id, acc

    def forward_with_cond_scale(self,
                                motion_ids,
                                cond_embs,
                                time_to_arrival_pe,
                                cond_padding_mask,
                                motion_padding_mask,
                                cond_scale=3):
        # bs = motion_ids.shape[0]
        # if cond_scale == 1:

        input_motion_ids = torch.cat([motion_ids, motion_ids], dim=0)
        input_cond_embs = torch.cat([self.mask_cond(cond_embs, force_mask=True),
                                     self.mask_cond(cond_embs, force_mask=False)], dim=0)
        input_cond_padding_mask = torch.cat([cond_padding_mask, cond_padding_mask], dim=0)
        input_motion_padding_mask = torch.cat([motion_padding_mask, motion_padding_mask], dim=0)
        input_toa_pe = torch.cat([time_to_arrival_pe, time_to_arrival_pe], dim=0)

        output_logits = self.trans_forward(input_motion_ids, 
                                           input_cond_embs, 
                                           input_toa_pe,
                                           input_cond_padding_mask, 
                                           input_motion_padding_mask)
        aux_logits, logits = output_logits.chunk(2, dim=0)

        scaled_logits = aux_logits + (logits - aux_logits) * cond_scale
        return scaled_logits

    @torch.no_grad()
    @eval_decorator
    def generate(self,
                 conds,
                 m_lens,
                 timesteps: int,
                 cond_scale: int,
                 temperature=1,
                 topk_filter_thres=0.9,
                 gsample=False,
                #  scales=[8, 4, 2, 1],
                 ):
        '''
        Iteratively decode the token ids of every scale, starting from a fully
        masked sequence and re-masking the lowest-confidence tokens each step.

        :param conds: raw text, forwarded to ``encode_text``
        :param m_lens: (b,) tensor of lengths; the valid length at a scale is ``m_lens // scale``
        :param timesteps: number of decoding iterations
        :param cond_scale: guidance scale passed to ``forward_with_cond_scale``
        :param temperature: sampling temperature, kept constant over all iterations
        :param topk_filter_thres: threshold passed to ``top_k``
        :param gsample: sample with ``gumbel_sample`` if True, else from a ``Categorical``
        :return: list with one id tensor per scale, each ``full_length // scale``
            wide along the last dim; padded positions are set to -1
        '''
        device = next(self.parameters()).device
        seq_len = max(m_lens)  # not used below
        # Per-scale bookkeeping: valid-position masks, token counts, and encodings.
        non_padding_mask = []
        lengths_div = []
        new_mlens = torch.zeros_like(m_lens)  # total number of valid tokens over all scales
        time_to_arrival_pe = []
        for scale in self.scales:
            non_padding_mask.append(
                lengths_to_mask((m_lens//scale).long(), int(self.full_length//scale))
            )
            lengths_div.append(int(self.full_length//scale))
            new_mlens += m_lens // scale
            time_to_arrival_pe.append(self.get_pe_from_mlens((m_lens//scale).long(), int(self.full_length//scale)))

        non_padding_mask = torch.cat(non_padding_mask, dim=1)
        padding_mask = ~non_padding_mask
        time_to_arrival_pe = torch.cat(time_to_arrival_pe, dim=1)

        # The method already runs under torch.no_grad(); this inner context is redundant but harmless.
        with torch.no_grad():
            cond_embs, cond_att_mask = self.encode_text(conds)
            cond_padding_mask = (cond_att_mask==0)


        # Start from all tokens being masked.
        # Padded positions get a huge score so they rank last when picking tokens to mask.
        ids = torch.where(padding_mask, self.pad_id, self.mask_id)
        scores = torch.where(padding_mask, 1e5, 0.)
        starting_temperature = temperature

        for timestep in torch.linspace(0, 1, timesteps, device=device):
            # 0 <= timestep <= 1 (linspace includes both end points)
            rand_mask_prob = self.noise_schedule(timestep)  # Tensor

            '''
            Maskout, and cope with variable length
            '''
            # The ratio is taken w.r.t. each sample's number of valid tokens, not the padded length;
            # at least one token is (re-)masked in every iteration.
            num_token_masked = torch.round(rand_mask_prob * new_mlens).clamp(min=1)  # (b, )

            # select num_token_masked tokens with lowest scores to be masked
            sorted_indices = scores.argsort(dim=1)  # (b, k), sorted_indices[i, j] = the index of j-th lowest element in scores on dim=1
            ranks = sorted_indices.argsort(dim=1)  # (b, k), rank[i, j] = the rank (0: lowest) of scores[i, j] on dim=1
            is_mask = (ranks < num_token_masked.unsqueeze(-1))
            ids = torch.where(is_mask, self.mask_id, ids)

            '''
            Preparing input
            '''
            # guided logits: (b, num_token, seqlen)
            logits = self.forward_with_cond_scale(ids, 
                                                  cond_embs, 
                                                  time_to_arrival_pe=time_to_arrival_pe,
                                                  cond_padding_mask=cond_padding_mask,
                                                  motion_padding_mask=padding_mask,
                                                  cond_scale=cond_scale)
            

            logits = logits.permute(0, 2, 1)  # (b, seqlen, ntoken)
            # clean low prob token
            filtered_logits = top_k(logits, topk_filter_thres, dim=-1)

            '''
            Update ids
            '''
            # Temperature annealing is currently disabled: the starting temperature is used at every step.
            temperature = starting_temperature
            # (disabled) temperature = starting_temperature * (steps_until_x0 / timesteps)
            # (disabled) temperature = max(temperature, 1e-4)
            if gsample:  # use gumbel_softmax sampling
                pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)  # (b, seqlen)
            else:  # use multinomial sampling
                probs = F.softmax(filtered_logits / temperature, dim=-1)  # (b, seqlen, ntoken)
                pred_ids = Categorical(probs).sample()  # (b, seqlen)

            # Only the positions masked in this iteration receive new predictions.
            ids = torch.where(is_mask, pred_ids, ids)

            '''
            Updating scores
            '''
            # Confidence = probability (without temperature or top-k filtering) of the sampled token.
            probs_without_temperature = logits.softmax(dim=-1)  # (b, seqlen, ntoken)
            scores = probs_without_temperature.gather(2, pred_ids.unsqueeze(dim=-1))  # (b, seqlen, 1)
            scores = scores.squeeze(-1)  # (b, seqlen)

            # We do not want to re-mask the previously kept tokens, or pad tokens
            scores = scores.masked_fill(~is_mask, 1e5)

        # Mark padded positions with -1 and split the flat sequence back into per-scale chunks.
        ids = torch.where(padding_mask, -1, ids)
        return_list = []
        start = 0
        for length in lengths_div:
            return_list.append(ids[..., start:start+length])
            start += length
        return return_list


class ResidualTransformer(nn.Module):
    def __init__(self, code_dim, latent_dim=256, ff_size=1024, num_layers=8, cond_drop_prob=0.1,
                 num_heads=4, dropout=0.1, text_dim=512, shared_codebook=False, share_weight=False,
                 device=None, cfg=None, **kargs):
        '''
        Build the residual-layer transformer.

        :param code_dim: width of the code vectors fed to ``input_process`` and of the
            per-layer token-embedding / output-projection tables
        :param latent_dim: transformer model width
        :param ff_size: feed-forward width of every encoder layer
        :param num_layers: number of encoder layers
        :param cond_drop_prob: probability of zeroing a sample's condition during training
        :param num_heads: number of attention heads
        :param dropout: dropout rate of the encoder layers and the positional encoding
        :param text_dim: width of the text embeddings fed to ``cond_emb``
        :param shared_codebook: use one token table for all residual layers
        :param share_weight: tie token-embedding and output-projection tables
        :param device: device handed to the text encoder
        :param cfg: config object; ``cfg.vq.num_quantizers``, ``cfg.vq.nb_code``,
            ``cfg.text_embedder.version`` and ``cfg.data.max_text_length`` are read here
        :param kargs: accepted and ignored
        '''
        super(ResidualTransformer, self).__init__()
        print(f'latent_dim: {latent_dim}, ff_size: {ff_size}, nlayers: {num_layers}, nheads: {num_heads}, dropout: {dropout}')

        self.code_dim = code_dim
        self.latent_dim = latent_dim
        self.text_dim = text_dim
        self.dropout = dropout
        self.cfg = cfg
        self.device = device
        self.cond_drop_prob = cond_drop_prob

        # ---- Preparing networks ----
        self.input_process = InputProcess(self.code_dim, self.latent_dim)
        self.position_enc = PositionalEncoding(self.latent_dim, self.dropout)

        encoder_layer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
                                                   nhead=num_heads,
                                                   dim_feedforward=ff_size,
                                                   dropout=dropout,
                                                   activation='gelu')
        self.seqTransEncoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # The quantizer-layer id is one-hot encoded and linearly embedded into one extra token.
        num_quantizers = self.cfg.vq.num_quantizers
        self.encode_quant = partial(F.one_hot, num_classes=num_quantizers)
        self.quant_emb = nn.Linear(num_quantizers, self.latent_dim)

        self.cond_emb = nn.Linear(self.text_dim, self.latent_dim)

        num_tokens = cfg.vq.nb_code + 1  # codebook entries plus one dummy token for padding
        self.pad_id = cfg.vq.nb_code

        self.output_process = OutputProcess(out_feats=code_dim, latent_dim=latent_dim)

        def _normal(*shape):
            # Fresh N(0, 0.02) tensor of the requested shape.
            return torch.normal(mean=0, std=0.02, size=shape)

        if shared_codebook:
            # NOTE(review): the Parameters created in this branch are only kept through
            # expanded views and are never assigned as module attributes, so they do not
            # appear to be registered with the module (parameters()/state_dict). Confirm
            # before relying on shared_codebook=True.
            token_embed = nn.Parameter(_normal(num_tokens, code_dim))
            self.token_embed_weight = token_embed.expand(num_quantizers - 1, num_tokens, code_dim)
            if share_weight:
                self.output_proj_weight = self.token_embed_weight
                self.output_proj_bias = None
            else:
                output_proj = nn.Parameter(_normal(num_tokens, code_dim))
                output_bias = nn.Parameter(torch.zeros(size=(num_tokens,)))
                self.output_proj_weight = output_proj.expand(num_quantizers - 1, num_tokens, code_dim)
                self.output_proj_bias = output_bias.expand(num_quantizers - 1, num_tokens)
        elif share_weight:
            # The full tables are assembled from these pieces by process_embed_proj_weight():
            # the shared block serves as the tail of the embedding table and the head of
            # the projection table.
            self.embed_proj_shared_weight = nn.Parameter(_normal(num_quantizers - 2, num_tokens, code_dim))
            self.token_embed_weight_ = nn.Parameter(_normal(1, num_tokens, code_dim))
            self.output_proj_weight_ = nn.Parameter(_normal(1, num_tokens, code_dim))
            self.output_proj_bias = None
            self.registered = False
        else:
            # Independent tables per residual layer.
            self.output_proj_weight = nn.Parameter(_normal(num_quantizers - 1, num_tokens, code_dim))
            # NOTE(review): the bias has num_quantizers rows while the weight has num_quantizers - 1.
            self.output_proj_bias = nn.Parameter(torch.zeros(size=(num_quantizers, num_tokens)))
            self.token_embed_weight = nn.Parameter(_normal(num_quantizers - 1, num_tokens, code_dim))

        self.apply(self.__init_weights)
        self.shared_codebook = shared_codebook
        self.share_weight = share_weight

        # Built after apply(), so the text encoder is not touched by __init_weights.
        self.text_emb = T5TextEncoder(
            device,
            local_files_only=False,
            from_pretrained=cfg.text_embedder.version,
            model_max_length=cfg.data.max_text_length
        )

    # def

    def mask_cond(self, cond, force_mask=False):
        bs, _, _ =  cond.shape
        if force_mask:
            return torch.zeros_like(cond)
        elif self.training and self.cond_drop_prob > 0.:
            mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_drop_prob).view(bs, 1, 1)
            return cond * (1. - mask)
        else:
            return cond

    def __init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def parameters_wo_clip(self):
        return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')]

    def load_and_freeze_clip(self, clip_version):
        '''
        Load a CLIP model, switch it to eval mode and freeze all of its parameters.

        :param clip_version: model identifier passed to ``clip.load``
        :return: the frozen CLIP model
        '''
        # Loaded on CPU first; jit must stay disabled for training.
        clip_model, _ = clip.load(clip_version, device='cpu', jit=False)

        # Weight conversion is only applied when the target device is not the CPU
        # (the original note says to disable it when running on CPU).
        targets_cpu = str(self.device) == "cpu"
        if not targets_cpu:
            clip.model.convert_weights(clip_model)

        # Freeze CLIP weights
        clip_model.eval()
        for param in clip_model.parameters():
            param.requires_grad = False

        return clip_model

    # def encode_text(self, raw_text):
    #     device = next(self.parameters()).device
    #     text = clip.tokenize(raw_text, truncate=True).to(device)
    #     feat_clip_text = self.clip_model.encode_text(text).float()
    #     return feat_clip_text

    def encode_text(self, raw_text):
        text_embedding, mask = self.text_emb.get_text_embeddings(raw_text)
        return text_embedding, mask

    def q_schedule(self, bs, low, high):
        '''
        Sample one value per batch element as ``round(noise * (high - low)) + low``.

        ``noise`` comes from ``uniform((bs,), device=self.device)``, presumably
        U[0, 1) (TODO confirm), which would make the result lie in ``[low, high]``.

        :param bs: batch size
        :param low: lower end of the range
        :param high: upper end of the range
        :return: (bs,) tensor of rounded values
        '''
        span = high - low
        noise = uniform((bs,), device=self.device)
        return low + torch.round(noise * span)

    def process_embed_proj_weight(self):
        if self.share_weight and (not self.shared_codebook):
            # if not self.registered:
            self.output_proj_weight = torch.cat([self.embed_proj_shared_weight, self.output_proj_weight_], dim=0)
            self.token_embed_weight = torch.cat([self.token_embed_weight_, self.embed_proj_shared_weight], dim=0)
            # self.registered = True

    def output_project(self, logits, qids):
        '''
        :logits: (bs, code_dim, seqlen)
        :qids: (bs)

        :return:
            -logits (bs, ntoken, seqlen)
        '''
        # (num_qlayers-1, num_token, code_dim) -> (bs, ntoken, code_dim)
        output_proj_weight = self.output_proj_weight[qids]
        # (num_qlayers, ntoken) -> (bs, ntoken)
        output_proj_bias = None if self.output_proj_bias is None else self.output_proj_bias[qids]

        output = torch.einsum('bnc, bcs->bns', output_proj_weight, logits)
        if output_proj_bias is not None:
            output += output + output_proj_bias.unsqueeze(-1)
        return output

    def trans_forward(self, motion_codes, qids, cond, cond_padding_mask, motion_padding_mask, force_mask=False):
        '''
        :param motion_codes: (b, seqlen, d)
        :cond_padding_mask: (b, seqlen), all pad positions are TRUE else FALSE
        :motion_padding_mask: (b, t_seqlen), all pad positions are TRUE else FALSE
        :param qids: (b), quantizer layer ids
        :param cond: (b, embed_dim) for text, (b, num_actions) for action
        :return:
            -logits: (b, num_token, seqlen)
        '''
        b, t_seqlen, _ = cond.shape
        cond = self.mask_cond(cond, force_mask=force_mask)

        # (b, seqlen, d) -> (seqlen, b, latent_dim)
        x = self.input_process(motion_codes)

        # (b, num_quantizer)
        q_onehot = self.encode_quant(qids).float().to(x.device)

        q_emb = self.quant_emb(q_onehot).unsqueeze(0)  # (1, b, latent_dim)
        cond = self.cond_emb(cond).permute(1, 0, 2)  # (tseqlen, b, latent_dim)

        x = self.position_enc(x)
        cond = self.position_enc(cond)

        xseq = torch.cat([cond, q_emb, x], dim=0)  # (seqlen+tseqlen+1, b, latent_dim)

        padding_mask = torch.cat([cond_padding_mask, torch.zeros_like(motion_padding_mask[:, 0:1]), motion_padding_mask], dim=1).bool()  # (b, seqlen+2)
        output = self.seqTransEncoder(xseq, src_key_padding_mask=padding_mask)[t_seqlen+1:]  # (seqlen, b, e)
        logits = self.output_process(output)
        return logits

    def forward_with_cond_scale(self,
                                motion_codes,
                                q_id,
                                cond_embs,
                                cond_padding_mask,
                                motion_padding_mask,
                                cond_scale=3,
                                force_mask=False):
        bs = motion_codes.shape[0]
        # if cond_scale == 1:
        qids = torch.full((bs,), q_id, dtype=torch.long, device=motion_codes.device)
        if force_mask:
            logits = self.trans_forward(motion_codes, qids, cond_embs, cond_padding_mask, motion_padding_mask, force_mask=True)
            logits = self.output_project(logits, qids-1)
            return logits

        logits = self.trans_forward(motion_codes, qids, cond_embs, cond_padding_mask, motion_padding_mask)
        logits = self.output_project(logits, qids-1)
        if cond_scale == 1:
            return logits

        aux_logits = self.trans_forward(motion_codes, qids, cond_embs, cond_padding_mask, motion_padding_mask, force_mask=True)
        aux_logits = self.output_project(aux_logits, qids-1)

        scaled_logits = aux_logits + (logits - aux_logits) * cond_scale
        return scaled_logits

    def forward(self, all_indices, y, m_lens, active_q_layers=-1):
        '''
        Training step: pick one residual layer per sample and predict its tokens
        from the summed embeddings of all lower layers.

        :param all_indices: (b, n, q) token ids of every quantizer layer
        :param y: raw text, forwarded to ``encode_text``
        :param m_lens: (b,) lengths
        :param active_q_layers: currently ignored (see the note below)
        :return: ``(ce_loss, pred_id, acc)`` as returned by ``cal_performance``
        '''
        # NOTE(review): the argument is overwritten unconditionally, so the value
        # supplied by the caller never takes effect and the layer is always drawn
        # from the full range. Kept as is; confirm whether this is intentional.
        active_q_layers = -1

        self.process_embed_proj_weight()

        bs, ntokens, num_quant_layers = all_indices.shape
        device = all_indices.device

        # Positions that are PADDED are ALL FALSE
        non_pad_mask = lengths_to_mask(m_lens, ntokens)  # (b, n)
        layerwise_valid = repeat(non_pad_mask, 'b n -> b n q', q=num_quant_layers)
        all_indices = torch.where(layerwise_valid, all_indices, self.pad_id)  # (b, n, q)

        # Randomly choose the layer each sample is trained on (low=1 skips the base layer).
        use_all_layers = active_q_layers == -1 or active_q_layers >= num_quant_layers
        upper = num_quant_layers if use_all_layers else active_q_layers + 1
        active_q_layers = q_schedule(bs, low=1, high=upper, device=device)

        # Look up the embedding of every token of layers 0..q-2 in its own table.
        token_embed = repeat(self.token_embed_weight, 'q c d-> b c d q', b=bs)
        gather_indices = repeat(all_indices[..., :-1], 'b n q -> b n d q', d=token_embed.shape[2])
        all_codes = token_embed.gather(1, gather_indices)  # (b, n, d, q-1)

        # Running sum over layers: entry k holds the sum of layers 0..k.
        cumsum_codes = torch.cumsum(all_codes, dim=-1)  # (b, n, d, q-1)

        batch_idx = torch.arange(bs)
        active_indices = all_indices[batch_idx, :, active_q_layers]  # (b, n) targets
        history_sum = cumsum_codes[batch_idx, :, :, active_q_layers - 1]  # sum of all layers below the target

        with torch.no_grad():
            cond_vector, cond_att_mask = self.encode_text(y)
            cond_padding_mask = (cond_att_mask == 0)

        # force_mask=False: condition dropping is left to mask_cond's training-time randomness.
        features = self.trans_forward(history_sum, active_q_layers, cond_vector, cond_padding_mask,
                                      ~non_pad_mask, False)
        logits = self.output_project(features, active_q_layers - 1)
        ce_loss, pred_id, acc = cal_performance(logits, active_indices, ignore_index=self.pad_id)

        return ce_loss, pred_id, acc

    @torch.no_grad()
    @eval_decorator
    def generate(self,
                 motion_ids,
                 conds,
                 m_lens,
                 temperature=1,
                 topk_filter_thres=0.9,
                 cond_scale=2,
                 num_res_layers=-1, # If it's -1, use all.
                 ):
        """Sample the remaining quantizer layers one after another.

        Starting from the ids of the first layer, every following layer is
        predicted from the running sum of the embeddings of all previously
        chosen tokens, then sampled with top-k filtering + gumbel sampling.

        :param motion_ids: (b, n) token ids of the first quantizer layer
        :param conds: batch of conditions handed to ``self.encode_text``
        :param m_lens: (b,) valid lengths; positions beyond them are padding
        :param temperature: sampling temperature given to ``gumbel_sample``
        :param topk_filter_thres: threshold given to ``top_k``
        :param cond_scale: forwarded to ``self.forward_with_cond_scale``
        :param num_res_layers: number of extra layers to sample; ``-1`` (or any
            value >= ``cfg.vq.num_quantizers``) means "all of them"
        :return: (b, n, q) ids, one slice per layer along the last axis, with
            every entry equal to ``self.pad_id`` replaced by ``-1``
        """
        # presumably refreshes self.token_embed_weight -- helper is defined elsewhere
        self.process_embed_proj_weight()

        n_frames = motion_ids.shape[1]
        n_samples = len(conds)

        # The whole method already runs under torch.no_grad() (decorator).
        text_feats, text_att_mask = self.encode_text(conds)
        text_pad_mask = (text_att_mask == 0)

        # True at padded positions
        pad_positions = ~lengths_to_mask(m_lens, n_frames)
        layer_ids = torch.where(pad_positions, self.pad_id, motion_ids)
        per_layer_ids = [layer_ids]

        # Total number of layers in the output (first layer included)
        total_layers = self.cfg.vq.num_quantizers
        use_all_layers = num_res_layers == -1 or num_res_layers >= total_layers
        if not use_all_layers:
            total_layers = num_res_layers + 1

        history_sum = 0
        for layer in range(1, total_layers):
            # Embed the ids chosen for the previous layer and accumulate them
            codebook = repeat(self.token_embed_weight[layer - 1], 'c d -> b c d', b=n_samples)
            gather_index = repeat(layer_ids, 'b n -> b n d', d=codebook.shape[-1])
            history_sum += codebook.gather(1, gather_index)

            logits = self.forward_with_cond_scale(history_sum, layer, text_feats, text_pad_mask,
                                                  pad_positions, cond_scale=cond_scale)
            logits = logits.permute(0, 2, 1)  # (b, seqlen, ntoken)

            # Drop low-probability tokens before sampling
            kept_logits = top_k(logits, topk_filter_thres, dim=-1)
            sampled = gumbel_sample(kept_logits, temperature=temperature, dim=-1)  # (b, seqlen)

            # Padded positions always carry the pad id
            layer_ids = torch.where(pad_positions, self.pad_id, sampled)
            per_layer_ids.append(layer_ids)

        stacked = torch.stack(per_layer_ids, dim=-1)
        return torch.where(stacked == self.pad_id, -1, stacked)

    @torch.no_grad()
    @eval_decorator
    def edit(self,
            motion_ids,
            conds,
            m_lens,
            temperature=1,
            topk_filter_thres=0.9,
            cond_scale=2
            ):
        """Sample every remaining quantizer layer for an edited base sequence.

        Follows the same procedure as ``generate`` in this class, always using
        all ``cfg.vq.num_quantizers`` layers: each layer is predicted from the
        running sum of the embeddings of the tokens chosen so far.

        :param motion_ids: (b, n) token ids of the first quantizer layer
        :param conds: batch of conditions handed to ``self.encode_text``
        :param m_lens: (b,) valid lengths; positions beyond them are padding
        :param temperature: sampling temperature given to ``gumbel_sample``
        :param topk_filter_thres: threshold given to ``top_k``
        :param cond_scale: forwarded to ``self.forward_with_cond_scale``
        :return: (b, n, q) ids, one slice per layer along the last axis, with
            every entry equal to ``self.pad_id`` replaced by ``-1``
        """
        # presumably refreshes self.token_embed_weight -- helper is defined elsewhere
        self.process_embed_proj_weight()

        seq_len = motion_ids.shape[1]
        batch_size = len(conds)

        # encode_text returns (embeddings, attention_mask), exactly as it is
        # consumed in forward() and generate(); the padding mask must be passed
        # on to forward_with_cond_scale together with the embeddings.
        cond_embs, cond_att_mask = self.encode_text(conds)
        cond_padding_mask = (cond_att_mask == 0)

        # True at padded positions
        padding_mask = ~lengths_to_mask(m_lens, seq_len)
        motion_ids = torch.where(padding_mask, self.pad_id, motion_ids)
        all_indices = [motion_ids]
        history_sum = 0

        for i in range(1, self.cfg.vq.num_quantizers):
            # Embed the ids chosen for the previous layer and accumulate them
            token_embed = self.token_embed_weight[i-1]
            token_embed = repeat(token_embed, 'c d -> b c d', b=batch_size)
            gathered_ids = repeat(motion_ids, 'b n -> b n d', d=token_embed.shape[-1])
            history_sum += token_embed.gather(1, gathered_ids)

            # Same positional contract as the call in generate()
            logits = self.forward_with_cond_scale(history_sum, i, cond_embs, cond_padding_mask, padding_mask, cond_scale=cond_scale)

            logits = logits.permute(0, 2, 1)  # (b, seqlen, ntoken)
            # clean low prob token
            filtered_logits = top_k(logits, topk_filter_thres, dim=-1)

            pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)  # (b, seqlen)

            # Padded positions always carry the pad id
            ids = torch.where(padding_mask, self.pad_id, pred_ids)

            motion_ids = ids
            all_indices.append(ids)

        all_indices = torch.stack(all_indices, dim=-1)
        all_indices = torch.where(all_indices==self.pad_id, -1, all_indices)
        return all_indices


class BottomMaskTransformer(nn.Module):
    def __init__(self, code_dim, latent_dim=256, ff_size=1024, num_layers=8,
                 num_heads=4, dropout=0.1, text_dim=512, cond_drop_prob=0.1,
                 device=None, cfg=None, **kargs):
        """Build the embeddings, transformer encoder and text encoder.

        :param code_dim: width of both token embedding tables and input size of
            the two ``InputProcess`` projections
        :param latent_dim: transformer model width (``d_model``)
        :param ff_size: ``dim_feedforward`` of every encoder layer
        :param num_layers: number of stacked encoder layers
        :param num_heads: attention heads per encoder layer
        :param dropout: dropout used by the positional encoding and the encoder layers
        :param text_dim: feature size of the text embeddings fed to ``cond_emb``
        :param cond_drop_prob: stored for ``mask_cond`` (condition dropout while training)
        :param device: stored on the instance and handed to ``T5TextEncoder``
        :param cfg: config object; ``cfg.vq.nb_code_b``, ``cfg.vq.nb_code_t``,
            ``cfg.text_embedder.version`` and ``cfg.data.max_text_length`` are read
            here, so the ``None`` default cannot actually be used
        :param kargs: accepted but not used in this constructor
        """
        super(BottomMaskTransformer, self).__init__()
        print(f'latent_dim: {latent_dim}, ff_size: {ff_size}, nlayers: {num_layers}, nheads: {num_heads}, dropout: {dropout}')

        self.code_dim = code_dim
        self.latent_dim = latent_dim
        self.text_dim = text_dim
        self.dropout = dropout
        self.cfg = cfg
        self.device = device

        self.cond_drop_prob = cond_drop_prob


        '''
        Preparing Networks
        '''
        # Separate code_dim -> latent_dim projections for the top ("t") and
        # bottom ("b") token streams; one positional encoding shared by all streams.
        self.input_process_t = InputProcess(self.code_dim, self.latent_dim)
        self.input_process_b = InputProcess(self.code_dim, self.latent_dim)
        self.position_enc = PositionalEncoding(self.latent_dim, self.dropout)

        seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
                                                          nhead=num_heads,
                                                          dim_feedforward=ff_size,
                                                          dropout=dropout,
                                                          activation='gelu')

        self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
                                                     num_layers=num_layers)

        # Projects text features (text_dim) into the transformer width
        self.cond_emb = nn.Linear(self.text_dim, self.latent_dim)

        _num_tokens_b = cfg.vq.nb_code_b + 2  # two dummy tokens, one for masking, one for padding
        _num_tokens_t = cfg.vq.nb_code_t + 1  # one dummy token, used as the padding id
        self.mask_id = cfg.vq.nb_code_b
        self.pad_id_b = cfg.vq.nb_code_b + 1
        self.pad_id_t = cfg.vq.nb_code_t

        # Output head only scores the nb_code_b real codes (no mask/pad classes)
        self.output_process = OutputProcess_Bert(out_feats=cfg.vq.nb_code_b, latent_dim=latent_dim)

        self.token_emb_t = nn.Embedding(_num_tokens_t, self.code_dim)
        self.token_emb_b = nn.Embedding(_num_tokens_b, self.code_dim)

        # Re-initialise every submodule registered so far. This runs before the
        # text encoder below is created, so that encoder is not touched by it.
        self.apply(self.__init_weights)

        '''
        Preparing frozen weights
        '''

        # Legacy CLIP conditioning path, kept for reference:
        # if self.cond_mode == 'text':
        #     print('Loading CLIP...')
        #     self.clip_version = clip_version
        #     self.clip_model = self.load_and_freeze_clip(clip_version)

        # Text encoder used by encode_text()
        self.text_emb = T5TextEncoder(
            device, 
        #  use_text_preprocessing,
            local_files_only=False, 
            from_pretrained=cfg.text_embedder.version, 
            model_max_length=cfg.data.max_text_length
        )

        # Maps a time in [0, 1] to a masking ratio (used in forward and generate)
        self.noise_schedule = cosine_schedule


    def load_and_freeze_token_emb(self, codebook):
        '''
        Replace the bottom token embedding table with a fixed codebook.

        Two all-zero rows are appended to the codebook for the two dummy ids of
        the bottom vocabulary (``self.mask_id`` and ``self.pad_id_b``), and the
        resulting table is frozen.

        NOTE(review): this class has no ``token_emb`` attribute; the bottom table
        ``token_emb_b`` is the one with two dummy tokens, so it is assumed to be
        the intended target -- confirm against the caller.

        :param codebook: (c, d) tensor with c == cfg.vq.nb_code_b and d == code_dim
        :return: None
        :raises AssertionError: if the module is not in training mode
        :raises ValueError: if the padded codebook does not match the shape of
            the existing ``token_emb_b`` table
        '''
        assert self.training, 'Only necessary in training mode'
        _, d = codebook.shape
        # new_zeros keeps the codebook's dtype and device for the dummy rows
        padded = torch.cat([codebook, codebook.new_zeros((2, d))], dim=0)  # add two dummy tokens, 0 vectors

        expected_shape = tuple(self.token_emb_b.weight.shape)
        if tuple(padded.shape) != expected_shape:
            # A mismatching table would silently break mask_id / pad_id_b lookups
            raise ValueError(
                f'codebook of shape {tuple(codebook.shape)} does not fit the bottom token '
                f'embedding of shape {expected_shape} (two rows are reserved for dummy tokens)'
            )

        self.token_emb_b.weight = nn.Parameter(padded, requires_grad=False)
        self.token_emb_b.requires_grad_(False)
        print("Token embedding initialized!")

    def __init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def parameters_wo_clip(self):
        """Return a list of all parameters whose qualified name does not start with ``clip_model.``."""
        kept = []
        for qualified_name, param in self.named_parameters():
            if qualified_name.startswith('clip_model.'):
                continue
            kept.append(param)
        return kept

    def load_and_freeze_clip(self, clip_version):
        """Load a CLIP model, put it in eval mode and freeze all its weights.

        :param clip_version: model identifier forwarded to ``clip.load``
        :return: the frozen CLIP model (the preprocessing transform is discarded)
        """
        # Always loaded on CPU first; jit must stay disabled for training.
        clip_model, _ = clip.load(clip_version, device='cpu', jit=False)

        # convert_weights is needed because the model was not loaded directly
        # onto the GPU; it is skipped when the target device is the CPU.
        target_is_cpu = str(self.device) == "cpu"
        if not target_is_cpu:
            clip.model.convert_weights(clip_model)

        # Freeze CLIP weights
        clip_model.eval()
        clip_model.requires_grad_(False)

        return clip_model

    # def encode_text(self, raw_text):
    #     device = next(self.parameters()).device
    #     text = clip.tokenize(raw_text, truncate=True).to(device)
    #     feat_clip_text = self.clip_model.encode_text(text).float()
    #     return feat_clip_text

    def encode_text(self, raw_text):
        """Encode raw text with ``self.text_emb``.

        :param raw_text: input forwarded unchanged to ``get_text_embeddings``
        :return: tuple ``(text_embedding, mask)`` as produced by the text encoder
        """
        encoded = self.text_emb.get_text_embeddings(raw_text)
        # Unpack explicitly so that exactly two items are required and a tuple is returned
        embedding, attention_mask = encoded
        return embedding, attention_mask

    def mask_cond(self, cond, force_mask=False):
        """Optionally zero out the condition (per sample) for condition dropout.

        :param cond: 3-D condition tensor with the batch on the first axis
        :param force_mask: if True, return an all-zero tensor shaped like ``cond``
        :return: zeros when ``force_mask`` is set; during training with
            ``cond_drop_prob > 0`` each sample is zeroed independently with that
            probability; otherwise ``cond`` itself is returned
        """
        # Also enforces that cond is 3-D, even on the force_mask path
        batch, _, _ = cond.shape

        if force_mask:
            return torch.zeros_like(cond)

        dropout_active = self.training and self.cond_drop_prob > 0.
        if not dropout_active:
            return cond

        # One Bernoulli draw per sample: 1 -> drop this sample's condition
        drop_probs = torch.ones(batch, device=cond.device) * self.cond_drop_prob
        dropped = torch.bernoulli(drop_probs).view(batch, 1, 1)
        return cond * (1. - dropped)

    def trans_forward(self, motion_ids, top_motion_ids, cond, cond_padding_mask, top_motion_padding_mask, motion_padding_mask, force_mask=False):
        '''
        Run the transformer over [text | top tokens | bottom tokens] and decode
        the bottom positions only.

        :param motion_ids: (b, seqlen) bottom token ids
        :param top_motion_ids: (b, top_seqlen) top token ids
        :param cond: (b, t_seqlen, text_dim) text features
        :param cond_padding_mask: (b, t_seqlen), all pad positions are TRUE else FALSE
        :param top_motion_padding_mask: (b, top_seqlen), all pad positions are TRUE else FALSE
        :param motion_padding_mask: (b, seqlen), all pad positions are TRUE else FALSE
        :param force_mask: boolean, forwarded to mask_cond (zero out the condition)
        :return:
            -logits: output of self.output_process for the bottom positions,
             presumably (b, num_token, seqlen) -- layout is defined by output_process
        '''
        n_text = cond.shape[1]
        n_top = top_motion_ids.shape[1]
        prefix_len = n_text + n_top

        # Condition dropout / forced unconditional pass
        text_feats = self.mask_cond(cond, force_mask=force_mask)

        # ids -> code embeddings -> (len, b, latent_dim)
        bottom_tokens = self.input_process_b(self.token_emb_b(motion_ids))
        top_tokens = self.input_process_t(self.token_emb_t(top_motion_ids))
        text_tokens = self.cond_emb(text_feats).permute(1, 0, 2)  # (t_seqlen, b, latent_dim)

        # Each stream gets its own positional encoding starting at position 0
        bottom_tokens = self.position_enc(bottom_tokens)
        top_tokens = self.position_enc(top_tokens)
        text_tokens = self.position_enc(text_tokens)

        # Sequence and key-padding mask are concatenated in the same stream order
        sequence = torch.cat([text_tokens, top_tokens, bottom_tokens], dim=0)
        key_padding = torch.cat(
            [cond_padding_mask, top_motion_padding_mask, motion_padding_mask], dim=1
        ).bool()  # (b, t_seqlen + top_seqlen + seqlen)

        encoded = self.seqTransEncoder(sequence, src_key_padding_mask=key_padding)
        # Keep only the outputs at the bottom-token positions
        return self.output_process(encoded[prefix_len:])

    def forward(self, ids, top_ids, y, m_lens):
        '''
        Masked-token training step for the bottom tokens.

        :param ids: (b, n) bottom token ids
        :param top_ids: (b, n_top) top token ids; their valid length is m_lens // 2
        :param y: raw text handed to encode_text
        :param m_lens: (b,) valid lengths of the bottom sequences
        :return: (ce_loss, pred_id, acc) as computed by cal_performance
        '''
        batch, n_tokens = ids.shape
        dev = ids.device

        # True at valid (non-padded) positions; padded ids become the pad token
        valid_b = lengths_to_mask(m_lens, n_tokens)  # (b, n)
        ids = torch.where(valid_b, ids, self.pad_id_b)

        valid_t = lengths_to_mask(m_lens//2, top_ids.shape[1])
        top_ids = torch.where(valid_t, top_ids, self.pad_id_t)

        # Text features are computed without gradients
        with torch.no_grad():
            text_feats, text_att_mask = self.encode_text(y)
            text_pad_mask = (text_att_mask == 0)

        # Draw a masking ratio per sample from the noise schedule
        masking_ratio = self.noise_schedule(uniform((batch,), device=dev))
        n_to_mask = (n_tokens * masking_ratio).round().clamp(min=1)

        # Random rank per position; the lowest ranks get masked, but only
        # where the position is not padding
        random_rank = torch.rand((batch, n_tokens), device=dev).argsort(dim=-1)
        target_mask = (random_rank < n_to_mask.unsqueeze(-1)) & valid_b

        # Training target (not the input): original id at masked positions,
        # mask_id elsewhere so that cal_performance ignores those positions
        labels = torch.where(target_mask, ids, self.mask_id)

        # BERT-style corruption of the input.
        # 1) 10% of the masked positions receive a random code id
        random_sel = get_mask_subset_prob(target_mask, 0.1)
        random_ids = torch.randint_like(ids, high=self.cfg.vq.nb_code_b)
        corrupted = torch.where(random_sel, random_ids, ids)
        # 2) 88% of the remaining masked positions receive the mask token;
        #    the rest keep their original id
        mask_sel = get_mask_subset_prob(target_mask & ~random_sel, 0.88)
        corrupted = torch.where(mask_sel, self.mask_id, corrupted)

        logits = self.trans_forward(corrupted, top_ids, text_feats, text_pad_mask,
                                    ~valid_t, ~valid_b, force_mask=False)
        ce_loss, pred_id, acc = cal_performance(logits, labels, ignore_index=self.mask_id)

        return ce_loss, pred_id, acc

    def forward_with_cond_scale(self,
                                motion_ids,
                                top_motion_ids,
                                cond_embs,
                                cond_padding_mask,
                                top_motion_padding_mask,
                                motion_padding_mask,
                                cond_scale=3,
                                force_mask=False):
        """Run ``trans_forward`` with classifier-free guidance.

        ``force_mask=True`` returns the unconditional logits only, and
        ``cond_scale == 1`` returns the conditional logits only. Otherwise both
        passes run and the result is ``uncond + (cond - uncond) * cond_scale``.

        NOTE(review): the 2nd and 3rd arguments are handed to ``trans_forward``
        in swapped order (3rd first). ``generate`` in this class passes
        ``(ids, cond_embs, top_motion_ids, ...)`` positionally, so
        ``trans_forward`` ends up receiving the top ids and the condition in its
        declared order; the parameter names here therefore do not match what
        that caller puts in those slots. Keyword callers should double-check.
        """
        def run(masked):
            # Single place that fixes the argument order used for trans_forward
            return self.trans_forward(motion_ids, cond_embs, top_motion_ids,
                                      cond_padding_mask, top_motion_padding_mask,
                                      motion_padding_mask, force_mask=masked)

        if force_mask:
            return run(True)

        cond_logits = run(False)
        if cond_scale == 1:
            return cond_logits

        uncond_logits = run(True)
        return uncond_logits + (cond_logits - uncond_logits) * cond_scale

    @torch.no_grad()
    @eval_decorator
    def generate(self,
                 conds,
                 top_motion_ids,
                 m_lens,
                 timesteps: int,
                 cond_scale: int,
                 temperature=1,
                 topk_filter_thres=0.9,
                 gsample=False,
                 force_mask=False
                 ):
        """Iteratively decode bottom tokens conditioned on text and top tokens.

        All valid positions start as mask tokens. At every step the positions
        with the lowest confidence scores are (re-)masked, predicted again and
        sampled; the number of masked positions follows ``self.noise_schedule``.

        :param conds: raw text handed to ``self.encode_text``
        :param top_motion_ids: (b, n_top) top token ids; the generated bottom
            sequence has length ``2 * n_top``
        :param m_lens: (b,) valid bottom lengths; top lengths are ``m_lens // 2``
        :param timesteps: number of decoding iterations
        :param cond_scale: guidance scale forwarded to ``forward_with_cond_scale``
        :param temperature: sampling temperature (constant over all steps)
        :param topk_filter_thres: threshold given to ``top_k``
        :param gsample: use ``gumbel_sample`` if True, else sample from a
            ``Categorical`` over the softmax of the filtered logits
        :param force_mask: forwarded to ``forward_with_cond_scale``
        :return: (b, 2 * n_top) bottom token ids with padded positions set to -1
        """

        device = next(self.parameters()).device
        # The bottom sequence is twice as long as the top one
        batch_size, seq_len = top_motion_ids.shape[0], top_motion_ids.shape[1]*2

        with torch.no_grad():
            cond_embs, cond_att_mask = self.encode_text(conds)
            cond_padding_mask = (cond_att_mask==0)


        # True at padded positions
        padding_mask = ~lengths_to_mask(m_lens, seq_len)

        top_padding_mask = ~lengths_to_mask(m_lens//2, top_motion_ids.shape[1])
        top_motion_ids = torch.where(top_padding_mask, self.pad_id_t, top_motion_ids)

        # Start from all tokens being masked
        ids = torch.where(padding_mask, self.pad_id_b, self.mask_id)
        # Padded positions get a huge score so they are never selected for masking
        scores = torch.where(padding_mask, 1e5, 0.)
        starting_temperature = temperature

        # steps_until_x0 is only referenced by the disabled temperature annealing below
        for timestep, steps_until_x0 in zip(torch.linspace(0, 1, timesteps, device=device), reversed(range(timesteps))):
            # timestep runs over linspace(0, 1, timesteps), both endpoints included
            rand_mask_prob = self.noise_schedule(timestep)  # Tensor

            '''
            Maskout, and cope with variable length
            '''
            # The ratio is applied to each sample's own length, not to seq_len;
            # at least one position is masked per sample
            num_token_masked = torch.round(rand_mask_prob * m_lens).clamp(min=1)  # (b, )

            # select num_token_masked tokens with lowest scores to be masked
            sorted_indices = scores.argsort(
                dim=1)  # (b, k), sorted_indices[i, j] = the index of j-th lowest element in scores on dim=1
            ranks = sorted_indices.argsort(dim=1)  # (b, k), rank[i, j] = the rank (0: lowest) of scores[i, j] on dim=1
            is_mask = (ranks < num_token_masked.unsqueeze(-1))
            ids = torch.where(is_mask, self.mask_id, ids)

            '''
            Preparing input
            '''
            # NOTE(review): cond_embs is passed in the 2nd positional slot and
            # top_motion_ids in the 3rd, while forward_with_cond_scale names those
            # slots top_motion_ids / cond_embs. That function forwards them to
            # trans_forward in swapped order, so trans_forward receives
            # (ids, top_motion_ids, cond, ...) as declared.
            # Result: (b, num_token, seqlen) per the permute below
            logits = self.forward_with_cond_scale(ids, cond_embs, top_motion_ids, cond_padding_mask,
                                                  top_motion_padding_mask=top_padding_mask,
                                                  motion_padding_mask=padding_mask,
                                                  cond_scale=cond_scale,
                                                  force_mask=force_mask)
            

            logits = logits.permute(0, 2, 1)  # (b, seqlen, ntoken)
            # clean low prob token
            filtered_logits = top_k(logits, topk_filter_thres, dim=-1)

            '''
            Update ids
            '''
            # Temperature is held constant at the caller-supplied value;
            # the annealing variant is disabled:
            temperature = starting_temperature
            # temperature = starting_temperature * (steps_until_x0 / timesteps)
            # temperature = max(temperature, 1e-4)
            if gsample:  # use gumbel_softmax sampling
                pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)  # (b, seqlen)
            else:  # use multinomial sampling
                probs = F.softmax(filtered_logits / temperature, dim=-1)  # (b, seqlen, ntoken)
                pred_ids = Categorical(probs).sample()  # (b, seqlen)

            # Only the positions masked in this step take the new prediction
            ids = torch.where(is_mask, pred_ids, ids)

            '''
            Updating scores
            '''
            # Confidence = probability (unfiltered, no temperature) of the sampled token
            probs_without_temperature = logits.softmax(dim=-1)  # (b, seqlen, ntoken)
            scores = probs_without_temperature.gather(2, pred_ids.unsqueeze(dim=-1))  # (b, seqlen, 1)
            scores = scores.squeeze(-1)  # (b, seqlen)

            # We do not want to re-mask the previously kept tokens, or pad tokens
            scores = scores.masked_fill(~is_mask, 1e5)

        # Padded positions are reported as -1
        ids = torch.where(padding_mask, -1, ids)
        return ids