# Modified from DPLM, ESM, and MAR 
#     DPLM: https://github.com/bytedance/dplm/blob/main/src/byprot/models/dplm/dplm.py
#     ESM: https://huggingface.co/docs/transformers/model_doc/esm
#     MAR: https://github.com/LTH14/mar/blob/main/models/mar.py

from typing import List, Optional, Tuple, Union

import math
import numpy as np
from dataclasses import dataclass

import torch
import torch.nn as nn

from transformers.models.esm.modeling_esm import EsmModel, EsmPreTrainedModel, EsmEmbeddings, EsmPooler, EsmLMHead
from transformers.utils import ModelOutput

from byprot.models.dplm.modules.dplm_modeling_esm import ModifiedEsmEncoder, sample_from_categorical
from byprot.models.utils import sample_from_categorical, stochastic_sample_from_categorical, top_k_top_p_filtering, topk_masking
from byprot.modules.cross_entropy import RDMCrossEntropyLoss

from .diffloss_construct import DiffLoss

@dataclass
class ConstructModelOutput(ModelOutput):
    """Named container for an encoder pass over the sequence + structure tracks.

    All fields default to ``None``. NOTE(review): presumably mirrors the
    ``(last_hidden_state, seq_embedding, struct_embedding)`` tuple returned by
    ``ModifiedEsmModel.forward``; that forward currently returns a plain tuple.
    """

    # encoder output hidden states
    last_hidden_state: Optional[torch.FloatTensor] = None
    # sequence-track embedding before it is summed with the structure track
    seq_embedding: Optional[torch.FloatTensor] = None
    # structure-track embedding before it is summed with the sequence track
    struct_embedding: Optional[torch.FloatTensor] = None


class ModifiedEsmModel(EsmModel):
    """ESM backbone with an additional continuous structure input track.

    Differences from ``EsmModel`` visible here:
      * ``z_proj`` (LayerNorm + Linear) projects per-residue structure embeddings of
        width ``config.struct_token_dim`` to ``config.hidden_size``;
      * ``forward`` adds that projection to the token embeddings before the encoder
        and returns a plain tuple rather than a ``ModelOutput``.
    """

    def __init__(self, config, add_pooling_layer=True):
        # EsmModel.__init__ is not called; only the submodules created below exist.
        EsmPreTrainedModel.__init__(self, config)
        self.config = config

        # structure track: continuous "token" projection.
        # The LayerNorm width must equal the Linear input width (struct_token_dim);
        # a hard-coded 20 breaks every config whose struct_token_dim differs.
        self.z_proj = nn.Sequential(
            nn.LayerNorm(config.struct_token_dim, eps=1e-6),
            nn.Linear(config.struct_token_dim, config.hidden_size, bias=True),
        )

        # sequence track: discrete token mapping
        self.embeddings = EsmEmbeddings(config)

        self.encoder = ModifiedEsmEncoder(config)

        self.pooler = EsmPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    # Mostly taken from EsmModel.forward; the only addition is the z_proj structure track.
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        # structure track
        input_struct_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Tuple[torch.Tensor, ...]:
        """Encode sequence tokens plus structure embeddings.

        Args:
            input_ids / inputs_embeds: sequence-track input; exactly one must be given.
            input_struct_embeds: structure-track input. Despite the ``Optional`` hint it
                is required, because ``z_proj`` is applied to it unconditionally. Its
                last dimension must be ``config.struct_token_dim``.
            Remaining arguments follow ``EsmModel.forward``.

        Returns:
            If ``return_dict`` (default ``config.use_return_dict``) is truthy, the tuple
            ``(sequence_output, seq_embedding_output, struct_embedding_output)``: the
            encoder output plus the two embeddings before they were summed. Otherwise
            ``(sequence_output, pooled_output) + encoder_outputs[1:]``.
            ``ConstructModel.encode`` unpacks the first form positionally, so it stays a
            tuple rather than a ``ModelOutput``.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = encoder_attention_mask

        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        seq_embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        # structure track: project and add onto the token embeddings
        struct_embedding_output = self.z_proj(input_struct_embeds)
        embedding_output = seq_embedding_output + struct_embedding_output

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return sequence_output, seq_embedding_output, struct_embedding_output

class ConstructModel(nn.Module):
    
    def __init__(self, config, dplm_config, tokenizer):
        """Build the joint sequence/structure model.

        Args:
            config: model config; fields read here are ``struct_token_dim``,
                ``num_seq_diffusion_timesteps``, ``num_struct_diffusion_timesteps``,
                ``hidden_size``, ``diffloss_w``, ``diffloss_d``,
                ``num_diffloss_sampling_steps`` and ``diffusion_batch_mul``.
            dplm_config: config for the ESM backbone / LM head. NOTE: it is mutated
                in place (``struct_token_dim`` is copied onto it).
            tokenizer: ESM-style tokenizer providing the special-token ids and a
                private ``_token_to_id`` mapping containing ``'X'``.

        NOTE(review): ``self.struct_mask_ratio_generator`` (read by
        ``random_struct_masking`` when ``mask_rate`` is None) is not assigned here.
        """
        super().__init__()
        
        self.config = config
        
        # the backbone's z_proj needs the structure width, so forward it to the ESM config
        dplm_config.struct_token_dim = config.struct_token_dim
        
        self.esm = ModifiedEsmModel(dplm_config, add_pooling_layer=False)
        # learnable embedding placed at masked structure positions (initialized to zeros)
        self.struct_mask_token = nn.Parameter(torch.zeros(1, 1, config.struct_token_dim))
        
        self.lm_head = EsmLMHead(dplm_config)
        # attribute name (including the spelling) is kept as-is; other methods refer to it
        self.seq_critertion = RDMCrossEntropyLoss()
        self.num_seq_diffusion_timesteps = config.num_seq_diffusion_timesteps
        self.num_struct_diffusion_timesteps = config.num_struct_diffusion_timesteps

        # structure track: DiffLoss denoiser conditioned on the shared latent.
        # NOTE(review): z_channels comes from config.hidden_size while the backbone is
        # built from dplm_config; assumes both hidden sizes agree.
        self.diffloss = DiffLoss(
            target_channels=config.struct_token_dim,
            z_channels=config.hidden_size,
            width=config.diffloss_w,
            depth=config.diffloss_d,
            num_sampling_steps=config.num_diffloss_sampling_steps,
            # grad_checkpointing=grad_checkpointing
        )
        # number of times each (latent, target) pair is replicated in the diffusion loss
        self.diffusion_batch_mul = config.diffusion_batch_mul
        
        # special-token ids; <cls> plays the role of <bos>
        self.mask_id = tokenizer.mask_token_id
        self.pad_id = tokenizer.pad_token_id
        self.bos_id = tokenizer.cls_token_id
        self.eos_id = tokenizer.eos_token_id
        # relies on the tokenizer's private vocabulary dict
        self.x_id = tokenizer._token_to_id['X']
        
        # self.contact_head = None
        self.tokenizer = tokenizer

    # larger t means more masking
    def seq_q_sample(self, seq_0, t, non_special_mask):
        T = self.num_seq_diffusion_timesteps
        # randomly mask t/T tokens
        u = torch.rand_like(seq_0, dtype=torch.float)
        t_mask = (u < (t / T)[:, None]) & non_special_mask
        seq_t = seq_0.masked_fill(t_mask, self.mask_id)

        return {
            "x_t": seq_t,
            "mask_mask": t_mask,
        }

    def struct_q_sample(self, struct_0, t, non_special_mask):
        T = self.num_struct_diffusion_timesteps
        # randomly mask t/T tokens
        u = torch.rand((struct_0.shape[0], struct_0.shape[1]), dtype=torch.float, device=struct_0.device)
        t_mask = (u < (t / T)[:, None]) & non_special_mask
        
        struct_t = self.struct_mask_token.repeat(t_mask.shape[0], t_mask.shape[1], 1).to(struct_0.dtype).clone()
        struct_t[~t_mask.bool()] = struct_0[~t_mask.bool()]

        return {
            "x_t": struct_t,
            "mask_mask": t_mask
        }
    
    # lambda^t: larger t means less weight
    def get_weight_coeff_by_masks(self, non_special_mask, loss_mask, weighting="linear"):
        """Per-row loss weight derived from how much of the row is masked.

        Let ``n`` be the number of non-special positions in a row and ``m`` the
        number of masked (loss) positions, clamped to at least 1.

        Args:
            weighting: ``"linear"`` -> ``1 - (m - 1) / n`` (more masking, less weight);
                ``"constant"`` -> 1; ``"inverse"`` -> ``n / m``.

        Returns:
            float tensor of shape (batch, 1).

        Raises:
            KeyError: for an unknown ``weighting``.
        """
        total_l = non_special_mask.sum(dim=1, keepdim=True).float()
        mask_l = loss_mask.sum(dim=1, keepdim=True).float()
        mask_l = torch.clamp(mask_l, min=1.0)  # avoid division by zero

        # Only the selected scheme is computed.
        if weighting == "linear":
            # Rows without any non-special token (n == 0) would give 0/0 = NaN.
            # For every n >= 1 this clamp is a no-op.
            seq_weight = 1 - (mask_l - 1) / torch.clamp(total_l, min=1.0)
        elif weighting == "constant":
            seq_weight = torch.ones_like(mask_l)
        elif weighting == "inverse":
            seq_weight = total_l / mask_l
        else:
            raise KeyError(f"unknown weighting scheme: {weighting!r}")
        return seq_weight.float()

    # get a random order of residues for the structure track
    def sample_struct_orders(self, non_special_mask):
        """Return, per row, the indices of non-special positions in a random order.

        Uses NumPy's global RNG (``np.random.shuffle``), one shuffle per row.

        Returns:
            list (length = batch size) of 1-D integer NumPy arrays.
        """
        host_mask = non_special_mask.cpu().numpy()
        shuffled = []
        for row in host_mask:
            positions = np.flatnonzero(row)
            np.random.shuffle(positions)
            shuffled.append(positions)
        return shuffled
    
    # given a random order, mask the first mask_rate tokens
    def random_struct_masking(self, non_special_mask, orders, mask_rate=None, struct_mask=None):
        bsz, _ = non_special_mask.shape
        seq_len = non_special_mask.sum(dim=1).max().item()
        
        if mask_rate is None:
            mask_rate = self.struct_mask_ratio_generator.rvs(1)[0]
            num_masked_tokens = int(np.ceil(seq_len * mask_rate))
        else:
            num_masked_tokens = int(np.floor(seq_len * mask_rate))
            struct_mask_len = torch.sum(struct_mask[0]).cpu().numpy()
            num_masked_tokens = np.maximum(1, np.minimum(int(struct_mask_len) - 1, num_masked_tokens))
        
        mask = torch.zeros_like(non_special_mask, device=non_special_mask.device)
        for i in range(bsz):
            mask[i, orders[i][:num_masked_tokens]] = 1
        
        return mask
    
    # encoding the input seq_ids and struct_embeds into lantent representations
    def encode(self, input_seq_ids, input_struct_embeds, position_ids=None, cal_seq_logits=False, add_orth_term=True):

        attention_mask = input_seq_ids.ne(self.pad_id)
        (
            seq_struct_latent, seq_embedding, struct_embedding
        ) = self.esm(
            input_seq_ids,
            input_struct_embeds=input_struct_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        outputs = {
            'seq_struct_latent': seq_struct_latent
        }
        if cal_seq_logits:
            seq_logits = self.lm_head(seq_struct_latent)
            outputs['seq_logits'] = seq_logits
        if add_orth_term:
            outputs['seq_embedding_before_sum'] = seq_embedding
            outputs['struct_embedding_before_sum'] = struct_embedding

        return outputs
    
    def cal_struct_loss(self, z, target, mask, weight_coeff=None):
        # z: noisy seq_struct_latent; target: gt_struct_embeds; mask: struct_loss_mask
        bsz, seq_len, _ = target.shape
        # repeat the target and z for diffusion_batch_mul times (on page 4 of the MAR paper)
        target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        z = z.reshape(bsz*seq_len, -1).repeat(self.diffusion_batch_mul, 1)
        mask = mask.reshape(bsz*seq_len).repeat(self.diffusion_batch_mul)
        weight_coeff = weight_coeff.repeat(1, seq_len)
        weight_coeff = weight_coeff.reshape(bsz*seq_len).repeat(self.diffusion_batch_mul)
        loss, pred_target = self.diffloss(z=z, target=target, weights=weight_coeff, mask=mask)
        return loss, pred_target

    # this setting is eventually given up
    def cal_orth_term(self, seq_embedding_before_sum, struct_embedding_before_sum, non_special_mask):
        """Orthogonality regularizer between the two embedding tracks.

        Multiplies the embeddings elementwise and zeroes special positions. It then
        sums over dim 1 (positions), squares, and averages over everything that is left.
        """
        product = seq_embedding_before_sum * struct_embedding_before_sum
        product = product.masked_fill(~non_special_mask.unsqueeze(-1), 0.0)
        summed = product.sum(dim=1)
        return summed.pow(2).mean()

    # the main forward function: given the initial seq_ids and struct_embeds, return the loss
    def forward(self, input_seq_ids, input_struct_embeds, 
                complementary_masking=False, self_mixup=True, cfg_training=False, 
                seq_reweighting="linear", struct_reweighting="constant",
                position_ids=None, add_orth_term=False):
        
        non_special_mask = self.get_non_special_sym_mask(input_seq_ids)

        batch_size = non_special_mask.size(0)
        cfg_size = batch_size // 5

        # --------------------------------------------------------------------------
        # preperations for the sequence track
        # randomly sample the timestep for each sequence, and mask the corresponding proportion of tokens
        num_seq_timesteps = self.num_seq_diffusion_timesteps
        init_seq_t = torch.randint(
            1, num_seq_timesteps + 1,
            (batch_size, ),
            device=input_seq_ids.device
        )
        seq_t = init_seq_t.clone()
        if cfg_training:
            seq_t[:cfg_size] = num_seq_timesteps  # cfg_size sequences are set to the maximum timestep
        
        input_seq_ids_t, seq_loss_mask = list(
            self.seq_q_sample(
                input_seq_ids, seq_t,
                non_special_mask=non_special_mask
            ).values()
        )
        # weight the loss of each timestep (larger t means less weight)
        seq_weight = self.get_weight_coeff_by_masks(non_special_mask, seq_loss_mask, weighting=seq_reweighting)
        if cfg_training:
            seq_weight[:cfg_size] = 0.0
            seq_loss_mask[:cfg_size, ...] = False
        
        if complementary_masking:
            comp_seq_ids = input_seq_ids.clone()
            comp_seq_loss_mask = (~seq_loss_mask)
            if cfg_training:
                comp_seq_loss_mask[:cfg_size, ...] = True
            comp_seq_loss_mask = comp_seq_loss_mask & non_special_mask
            comp_seq_ids_t = comp_seq_ids.masked_fill(comp_seq_loss_mask, self.mask_id)
            comp_seq_weight = self.get_weight_coeff_by_masks(non_special_mask, comp_seq_loss_mask, weighting=seq_reweighting)
            if cfg_training:
                comp_seq_weight[:cfg_size] = 0.0
                comp_seq_loss_mask[:cfg_size, ...] = False
            
            input_seq_ids = torch.cat([input_seq_ids, input_seq_ids], dim=0)
            input_seq_ids_t = torch.cat([input_seq_ids_t, comp_seq_ids_t], dim=0)
            seq_loss_mask = torch.cat([seq_loss_mask, comp_seq_loss_mask], dim=0)
            seq_weight = torch.cat([seq_weight, comp_seq_weight], dim=0)
            del comp_seq_ids, comp_seq_ids_t, comp_seq_loss_mask, comp_seq_weight

        # --------------------------------------------------------------------------
        # preperations for the structure track
        # global normalization has been done in the dataloader
        gt_struct_embeds = input_struct_embeds.clone().detach()

        num_struct_timesteps = num_seq_timesteps
        struct_t = init_seq_t.clone()
        if cfg_training:
            struct_t[cfg_size:cfg_size*2] = num_struct_timesteps
        
        input_struct_embeds_t, struct_loss_mask = list(
            self.struct_q_sample(
                input_struct_embeds, struct_t,
                non_special_mask=non_special_mask
            ).values()
        )
        # weight the loss of each timestep (larger t means less weight)
        struct_weight = self.get_weight_coeff_by_masks(non_special_mask, struct_loss_mask, weighting=struct_reweighting)
        if cfg_training:
            struct_weight[cfg_size:cfg_size*2] = 0.0
            struct_loss_mask[cfg_size:cfg_size*2, ...] = False
        
        if complementary_masking:
            comp_struct_embeds = input_struct_embeds.clone()
            comp_struct_loss_mask = (~struct_loss_mask)
            if cfg_training:
                comp_struct_loss_mask[cfg_size:cfg_size*2, ...] = True
            comp_struct_loss_mask = comp_struct_loss_mask & non_special_mask
            comp_struct_embeds_t = self.struct_mask_token.repeat(comp_struct_loss_mask.shape[0], comp_struct_loss_mask.shape[1], 1).to(input_struct_embeds.dtype).clone()
            comp_struct_embeds_t[~comp_struct_loss_mask.bool()] = comp_struct_embeds[~comp_struct_loss_mask.bool()]
            comp_struct_weight = self.get_weight_coeff_by_masks(non_special_mask, comp_struct_loss_mask, weighting=struct_reweighting)
            if cfg_training:
                comp_struct_weight[cfg_size:cfg_size*2] = 0.0
                comp_struct_loss_mask[cfg_size:cfg_size*2, ...] = False
            
            input_struct_embeds = torch.cat([input_struct_embeds, input_struct_embeds], dim=0)
            input_struct_embeds_t = torch.cat([input_struct_embeds_t, comp_struct_embeds_t], dim=0)
            struct_loss_mask = torch.cat([struct_loss_mask, comp_struct_loss_mask], dim=0)
            struct_weight = torch.cat([struct_weight, comp_struct_weight], dim=0)
            del comp_struct_embeds, comp_struct_embeds_t, comp_struct_loss_mask, comp_struct_weight

            gt_struct_embeds = torch.cat([gt_struct_embeds, gt_struct_embeds], dim=0)
            position_ids = torch.cat([position_ids, position_ids], dim=0) if position_ids is not None else None

        # --------------------------------------------------------------------------
        # forward encoding
        encoding_outputs = self.encode(
            input_seq_ids=input_seq_ids_t, 
            input_struct_embeds=input_struct_embeds_t,
            position_ids=position_ids,
            add_orth_term=add_orth_term,
        )
        seq_struct_latent = encoding_outputs['seq_struct_latent']
        seq_embedding_before_sum = encoding_outputs.get('seq_embedding_before_sum', None)
        struct_embedding_before_sum = encoding_outputs.get('struct_embedding_before_sum', None)

        # --------------------------------------------------------------------------
        # compute ce loss for the sequence track
        seq_logits = self.lm_head(seq_struct_latent)
        seq_ce_loss, _ = self.seq_critertion(
            scores=seq_logits, target=input_seq_ids,
            label_mask=seq_loss_mask,
            weights=seq_weight,
            watch_t1_t2_loss=False,
            cal_constant_loss=False
        )

        # --------------------------------------------------------------------------
        # compute diffusion loss for the structure track
        struct_diff_loss, pred_struct_embeds = self.cal_struct_loss(z=seq_struct_latent, target=gt_struct_embeds, 
                                                                    mask=struct_loss_mask, weight_coeff=struct_weight)

        outputs = {
            'seq_ce_loss': seq_ce_loss,
            'struct_diff_loss': struct_diff_loss
        }
        if add_orth_term:
            if complementary_masking:
                non_special_mask = non_special_mask.repeat(2, 1)
            orth_loss = self.cal_orth_term(seq_embedding_before_sum, struct_embedding_before_sum, non_special_mask=non_special_mask)
            outputs['orth_loss'] = orth_loss
        

        if self_mixup:

            # initialize the mixup_seq_ids: <cls> mixup sequence <eos> <pad> ...
            _, _mixup_seq_ids = seq_logits.log_softmax(dim=-1).max(dim=-1)
            mixup_seq_ids = input_seq_ids.clone()
            mixup_seq_ids[non_special_mask.bool()] = _mixup_seq_ids[non_special_mask.bool()]
            # mixup_seq_loss_mask is inverted to seq_loss_mask
            if cfg_training:
                invert_seq_loss_mask = (~seq_loss_mask)
                invert_seq_loss_mask[:cfg_size] = True
            else:
                invert_seq_loss_mask = (~seq_loss_mask)
            invert_seq_loss_mask = invert_seq_loss_mask & non_special_mask
            # mask the mixup_seq_ids with the mixup_seq_loss_mask
            mixup_seq_ids_t = mixup_seq_ids.masked_fill(invert_seq_loss_mask, self.mask_id)
            # according to dplm2's implementation, mixup_seq_weight is the same to seq_weight
            # mixup_seq_loss_mask (indicating the training target) is the same to non_special_mask (all positions)
            mixup_seq_loss_mask = non_special_mask.clone()
            mixup_seq_weight = seq_weight.clone()
            if cfg_training:
                mixup_seq_loss_mask[:cfg_size, ...] = False
            
            # initialize the mixup_struct_embeds: 0.0 mixup struct_embeds 0.0 0.0 ...
            # correctly set the shape pred_struct_embeds: [bsz * seq_len * diffusion_batch_mul, 20] -> [bsz, seq_len, 20]
            pred_struct_embeds = pred_struct_embeds.reshape(self.diffusion_batch_mul, -1, pred_struct_embeds.shape[-1])
            pred_struct_embeds = pred_struct_embeds.mean(dim=0)
            pred_struct_embeds = pred_struct_embeds.reshape(struct_loss_mask.shape[0], struct_loss_mask.shape[1], -1)
            # pad the struct_embeds with zeros according to non_special_mask
            pred_struct_embeds = pred_struct_embeds.masked_fill(~non_special_mask.unsqueeze(-1), 0.0)
            # mixup_struct_loss_mask is inverted to struct_loss_mask
            if cfg_training:
                invert_struct_loss_mask = (~struct_loss_mask)
                invert_struct_loss_mask[cfg_size:cfg_size*2] = True
            else:
                invert_struct_loss_mask = (~struct_loss_mask)
            invert_struct_loss_mask = invert_struct_loss_mask & non_special_mask
            # except for the masked tokens, the mixup_struct_embeds_t are set to the pred_struct_embeds
            mixup_struct_embeds_t = self.struct_mask_token.repeat(struct_loss_mask.shape[0], struct_loss_mask.shape[1], 1).to(input_struct_embeds.dtype).clone()
            mixup_struct_embeds_t[~invert_struct_loss_mask.bool()] = pred_struct_embeds[~invert_struct_loss_mask.bool()]
            # according to dplm2's implementation, mixup_struct_weight is the same to struct_weight
            # mixup_struct_loss_mask (indicating the training target) is the same to non_special_mask (all positions)
            mixup_struct_loss_mask = non_special_mask.clone()
            mixup_struct_weight = struct_weight.clone()
            if cfg_training:
                mixup_struct_loss_mask[cfg_size:cfg_size*2, ...] = False

            # forward
            mixup_encoding_output = self.encode(
                input_seq_ids=mixup_seq_ids_t, 
                input_struct_embeds=mixup_struct_embeds_t,
                position_ids=position_ids,
            )
            mixup_seq_struct_latent = mixup_encoding_output['seq_struct_latent']
            mixup_seq_embedding_before_sum = mixup_encoding_output.get('seq_embedding_before_sum', None)
            mixup_struct_embedding_before_sum = mixup_encoding_output.get('struct_embedding_before_sum', None)

            # compute mixup ce loss
            mixup_seq_logits = self.lm_head(mixup_seq_struct_latent)
            mixup_ce_loss, _ = self.seq_critertion(
                scores=mixup_seq_logits, target=input_seq_ids,
                label_mask=mixup_seq_loss_mask,
                weights=mixup_seq_weight,
                watch_t1_t2_loss=False,
                cal_constant_loss=False
            )
            # compute mixup diffusion loss
            mixup_diff_loss, _ = self.cal_struct_loss(z=mixup_seq_struct_latent, target=gt_struct_embeds, 
                                                      mask=mixup_struct_loss_mask, weight_coeff=mixup_struct_weight)
            
            outputs['mixup_seq_ce_loss'] = mixup_ce_loss
            outputs['mixup_struct_diff_loss'] = mixup_diff_loss
            if add_orth_term:
                non_special_mask = non_special_mask.repeat(2, 1)
                mixup_orth_loss = self.cal_orth_term(mixup_seq_embedding_before_sum, mixup_struct_embedding_before_sum, non_special_mask=non_special_mask)
                outputs['mixup_orth_loss'] = mixup_orth_loss
        
        return outputs

    def initialize_output_tokens(self, batch):
        
        # In inference, we have initial batch with input ids like: <cls> <mask> <mask> ... <mask> <eos>
        init_seq_tokens = batch['input_ids']
        # [bsz, seq_len]: 1 for general content, 0 for special tokens (i.e., <pad>, <bos>, <eos>)
        init_seq_scores = torch.zeros_like(init_seq_tokens, dtype=torch.float)

        non_special_mask = self.get_non_special_sym_mask(init_seq_tokens)
        if 'struct_latent' in batch:
            partial_masks = torch.eq(init_seq_tokens, self.mask_id)
            non_special_mask = self.get_non_special_sym_mask(init_seq_tokens, partial_masks=partial_masks)
            init_struct_embeds = batch['struct_latent']
        else:
            init_struct_embeds = torch.zeros((init_seq_tokens.shape[0], init_seq_tokens.shape[1], 20), \
                                             dtype=torch.float32, device=init_seq_tokens.device)
        
        init_struct_embeds[non_special_mask] = self.struct_mask_token

        return init_seq_tokens, init_seq_scores, init_struct_embeds, non_special_mask
    
    # decode one step: from the i step to the i-1 step
    def decode(self, prev_decoder_out, sampling_strategy='vanilla', seq_cfg=1.0, seq_cfg_schedule='linear', enable_resample=True):
        """Run one sequence-track decoding step (step i -> i-1).

        Args:
            prev_decoder_out: dict with 'seq_tokens', 'seq_scores', 'struct_embeds',
                'seq_mask_to_pred', 'seq_temp', 'step', 'max_step' and
                'uncon_struct_embeds'.
            sampling_strategy: 'vanilla', 'argmax', 'gumbel_argmax' or
                'annealing@<max_temp>:<min_temp>'.
            seq_cfg: classifier-free guidance scale on the logits; 1.0 disables it.
            seq_cfg_schedule: 'linear' (decays from seq_cfg at step 0 toward 1) or 'constant'.
            enable_resample: call ``self.resample`` on the sampled tokens/scores
                (not used by 'argmax').

        Returns:
            dict with updated 'seq_tokens' and 'seq_scores' (only positions in
            'seq_mask_to_pred' change) and the step's 'seq_struct_latent'.

        Raises:
            NotImplementedError: for an unknown schedule or sampling strategy.
        """
        # seq_tokens / seq_scores are updated by this step; the other fields are read-only here
        seq_tokens = prev_decoder_out['seq_tokens'].clone()
        seq_scores = prev_decoder_out['seq_scores'].clone()
        struct_embeds = prev_decoder_out['struct_embeds'].clone()
        seq_mask_to_pred = prev_decoder_out['seq_mask_to_pred'].clone()
        seq_temp = prev_decoder_out['seq_temp']
        step, max_step = prev_decoder_out['step'], prev_decoder_out['max_step']
        uncon_struct_embeds = prev_decoder_out['uncon_struct_embeds']

        with torch.no_grad():
            encode_out = self.encode(input_seq_ids=seq_tokens, input_struct_embeds=struct_embeds, cal_seq_logits=True)
        seq_struct_latent = encode_out['seq_struct_latent']
        seq_logits = encode_out['seq_logits']

        if seq_cfg != 1.0:
            with torch.no_grad():
                uncon_seq_logits = self.encode(
                    input_seq_ids=seq_tokens,
                    input_struct_embeds=uncon_struct_embeds,
                    cal_seq_logits=True
                )['seq_logits']
            if seq_cfg_schedule == "linear":
                cfg_step = 1 + (seq_cfg - 1) * (max_step - step) / max_step
            elif seq_cfg_schedule == "constant":
                cfg_step = seq_cfg
            else:
                raise NotImplementedError
            # guidance on logits: (1 - w) * unconditional + w * conditional
            seq_logits = (1 - cfg_step) * uncon_seq_logits + cfg_step * seq_logits
        else:
            cfg_step = 1.0

        if seq_logits.dtype != seq_scores.dtype:
            seq_logits = seq_logits.type_as(seq_scores)
        # never sample special tokens or the unknown residue 'X'
        for banned_id in (self.mask_id, self.x_id, self.pad_id, self.bos_id, self.eos_id):
            seq_logits[..., banned_id] = -math.inf

        seq_logits = top_k_top_p_filtering(seq_logits, top_p=0.95)

        # sample candidates everywhere; only seq_mask_to_pred positions are written back below
        if sampling_strategy == 'vanilla':
            _seq_tokens, _seq_scores = sample_from_categorical(seq_logits, temperature=seq_temp)
            if enable_resample:
                # NOTE: _seq_tokens and _seq_scores are revised in place by `resample`
                self.resample(_seq_tokens, _seq_scores, struct_embeds, ratio=0.20,
                              sampling_strategy='vanilla', temperature=2.0,
                              cfg=cfg_step, uncon_struct_embeds=uncon_struct_embeds)
        elif sampling_strategy == 'argmax':
            _seq_scores, _seq_tokens = seq_logits.log_softmax(dim=-1).max(-1)
        elif sampling_strategy == 'gumbel_argmax':
            noise_scale = seq_temp
            _seq_tokens, _seq_scores = stochastic_sample_from_categorical(seq_logits, temperature=0.0, noise_scale=noise_scale)
            if enable_resample:
                # NOTE: _seq_tokens and _seq_scores are revised in place by `resample`
                self.resample(_seq_tokens, _seq_scores, struct_embeds, ratio=0.20,
                              sampling_strategy='gumbel_argmax', temperature=1.0,
                              cfg=cfg_step, uncon_struct_embeds=uncon_struct_embeds)
        elif sampling_strategy.startswith("annealing"):
            # format: "annealing@<max_temp>:<min_temp>"; temperature decays linearly over steps
            max_temp, min_temp = map(
                float, sampling_strategy.split("@")[1].split(":")
            )
            rate = 1 - step / max_step
            temperature = min_temp + (max_temp - min_temp) * rate
            _seq_tokens, _seq_scores = sample_from_categorical(
                seq_logits, temperature=temperature
            )
            if enable_resample:
                # NOTE: _seq_tokens and _seq_scores are revised in place by `resample`.
                # Guidance arguments are passed as in the other branches.
                self.resample(_seq_tokens, _seq_scores, struct_embeds, ratio=0.20,
                              sampling_strategy='vanilla', temperature=2.0,
                              cfg=cfg_step, uncon_struct_embeds=uncon_struct_embeds)
        else:
            raise NotImplementedError

        seq_tokens.masked_scatter_(seq_mask_to_pred, _seq_tokens[seq_mask_to_pred])
        seq_scores.masked_scatter_(seq_mask_to_pred, _seq_scores[seq_mask_to_pred])

        return dict(
            seq_tokens=seq_tokens,
            seq_scores=seq_scores,
            seq_struct_latent=seq_struct_latent
        )

    def get_non_special_sym_mask(self, output_tokens, partial_masks=None):
        non_special_sym_mask = (
            output_tokens.ne(self.pad_id) &
            output_tokens.ne(self.bos_id) &
            output_tokens.ne(self.eos_id)
        )
        if partial_masks is not None:
            non_special_sym_mask = non_special_sym_mask & partial_masks
        return non_special_sym_mask
    
    def get_predicted_struct_embeds(self, struct_embeds, seq_struct_latent, struct_mask_to_pred, temperature, struct_cfg):
        z = seq_struct_latent[struct_mask_to_pred.nonzero(as_tuple=True)]
        _struct_embeds = self.diffloss.sample(z=z, temperature=temperature, cfg=struct_cfg)
        if not struct_cfg == 1.0:
            _struct_embeds, _ = _struct_embeds.chunk(2, dim=0)
            struct_mask_to_pred, _ = struct_mask_to_pred.chunk(2, dim=0)
        updated_struct_embeds = struct_embeds.clone()
        updated_struct_embeds[struct_mask_to_pred.nonzero(as_tuple=True)] = _struct_embeds
        return updated_struct_embeds
    
    def generate(
            self, batch, 
            seq_temp=None, struct_temp=None,
            sampling_strategy='gumbel_argmax', # vanilla, annealing@2.0-1.0, argmax, gumbel_argmax
            unmasking_strategy='deterministic', # deterministic, stochastic1.0, mix1.5
            seq_cfg=1.0, seq_cfg_schedule='constant',
            struct_cfg=1.0, struct_cfg_schedule='constant',
            cover_ori_motif=False,
        ):
        """Jointly generate the sequence and structure tracks by iterative unmasking.

        The loop runs ``max_iter`` steps. ``max_iter`` is the largest number of
        structure positions to generate in any batch row. Each step does three things:
          1. ``self.decode`` predicts sequence tokens/scores and the shared latent.
          2. ``self._reparam_decoding`` re-masks low-confidence sequence tokens
             (linear schedule, ``unmasking_strategy`` as the top-k mode).
          3. The diffusion head samples structure embeddings for positions that
             leave the structure mask at this step. On the final step it samples
             every position still masked.

        Args:
            batch: input forwarded to ``self.initialize_output_tokens``. A
                ``'struct_latent'`` key together with ``cover_ori_motif`` changes
                which structure positions are generated.
            seq_temp: stored as ``seq_temp`` in the decoder state passed to ``self.decode``.
            struct_temp: temperature passed to the structure sampler.
            sampling_strategy: forwarded to ``self.decode``.
            unmasking_strategy: top-k mode for ``_reparam_decoding``
                (``'deterministic'``, ``'stochastic<scale>'`` or ``'mix<scale>'``).
            seq_cfg, seq_cfg_schedule: sequence guidance settings, forwarded to ``self.decode``.
            struct_cfg: structure guidance weight; ``1.0`` disables structure guidance.
            struct_cfg_schedule: ``'constant'``, or ``'linear'``. With ``'linear'``
                the weight goes from ``struct_cfg`` at step 0 down toward 1 at the last step.
            cover_ori_motif: if truthy and ``batch`` contains ``'struct_latent'``,
                structure is generated for every position that
                ``get_non_special_sym_mask`` marks as non-special. Otherwise only
                the positions of the initial ``non_special_mask`` are generated.

        Returns:
            Tuple ``(seq_tokens, seq_scores, struct_embeds)`` after the last step.

        Raises:
            NotImplementedError: unknown ``struct_cfg_schedule`` (raised on the first step).
        """
        init_seq_tokens, init_seq_scores, init_struct_embeds, non_special_mask = self.initialize_output_tokens(batch)

        seq_non_special_mask = non_special_mask.clone()
        # `and` (not bitwise `&`): cover_ori_motif is a flag, and `True & 2` would be 0.
        if 'struct_latent' in batch and cover_ori_motif:
            struct_non_special_mask = self.get_non_special_sym_mask(init_seq_tokens)
        else:
            struct_non_special_mask = non_special_mask.clone()
        struct_orders = self.sample_struct_orders(struct_non_special_mask)
        # The number of steps equals the largest per-row count of structure positions to generate.
        seq_len = struct_non_special_mask.sum(dim=1).max().item()
        max_iter = seq_len

        prev_decoder_out = dict(
            uncon_seq_tokens=init_seq_tokens.clone(),
            uncon_struct_embeds=init_struct_embeds.clone(),
            seq_tokens=init_seq_tokens,
            seq_scores=init_seq_scores,
            struct_embeds=init_struct_embeds,
            seq_mask_to_pred=seq_non_special_mask,
            struct_mask=struct_non_special_mask,
            step=0,
            max_step=max_iter,
            seq_temp=seq_temp,
        )

        for step in range(max_iter):

            # 2.1: predict sequence tokens/scores and the shared latent
            with torch.no_grad():
                decoder_out = self.decode(
                    prev_decoder_out=prev_decoder_out,
                    sampling_strategy=sampling_strategy,
                    seq_cfg=seq_cfg, seq_cfg_schedule=seq_cfg_schedule
                )

            seq_tokens, seq_scores, seq_struct_latent = (
                decoder_out['seq_tokens'],
                decoder_out['seq_scores'],
                decoder_out['seq_struct_latent']
            )

            # 2.2: re-mask skeptical parts of low confidence seq tokens
            # (the structure track does not participate in this step)
            result_seq_mask_to_pred, result_seq_tokens, result_seq_scores = self._reparam_decoding(
                prev_seq_tokens=prev_decoder_out['seq_tokens'].clone(),
                prev_seq_scores=prev_decoder_out['seq_scores'].clone(),
                cur_seq_tokens=seq_tokens.clone(),
                cur_seq_scores=seq_scores.clone(),
                decoding_strategy=f'reparam-uncond-{unmasking_strategy}-linear',
                xt_neq_x0=prev_decoder_out['seq_mask_to_pred'],
                non_special_mask=seq_non_special_mask,
                t=step + 1,
                max_step=max_iter,
                noise=self.mask_id,
            )

            # 2.3: decode the structure track at this step
            struct_embeds = prev_decoder_out['struct_embeds'].clone()
            if struct_cfg != 1.0:
                # Structure guidance: append the latent computed from the initial
                # (presumably fully masked) sequence tokens along the batch dim.
                # Conditional rows come first, unconditional rows second.
                uncon_seq_tokens = prev_decoder_out['uncon_seq_tokens'].clone()
                with torch.no_grad():
                    uncon_seq_struct_latent = self.encode(
                        input_seq_ids=uncon_seq_tokens,
                        input_struct_embeds=struct_embeds,
                        cal_seq_logits=False
                    )['seq_struct_latent']
                seq_struct_latent = torch.cat([seq_struct_latent, uncon_seq_struct_latent], dim=0)

            struct_mask = prev_decoder_out['struct_mask'].clone()
            # cosine decay the mask ratio from 1.0 to 0.0
            mask_struct_ratio = np.cos(math.pi / 2. * (step - (max_iter - seq_len) + 1) / seq_len)
            # the factual calculation of num_masked_tokens in the self.random_struct_masking leads to linear decoding
            struct_mask_next = self.random_struct_masking(
                non_special_mask=struct_non_special_mask,
                orders=struct_orders,
                mask_rate=mask_struct_ratio,
                struct_mask=struct_mask
            )
            if step >= max_iter - 1:
                # Final step: predict every structure position that is still masked.
                struct_mask_to_pred = struct_mask.bool()
            else:
                # xor gives the positions that differ between the current and next mask,
                # i.e. the positions that leave the mask at this step.
                struct_mask_to_pred = torch.logical_xor(struct_mask.bool(), struct_mask_next.bool())

            if struct_cfg != 1.0:
                # Match the doubled (cond + uncond) latent batch built above.
                struct_mask_to_pred = torch.cat([struct_mask_to_pred, struct_mask_to_pred], dim=0)
            if struct_cfg_schedule == "linear":
                cfg_iter = 1 + (struct_cfg - 1) * (max_iter - step) / max_iter
            elif struct_cfg_schedule == "constant":
                cfg_iter = struct_cfg
            else:
                raise NotImplementedError
            with torch.no_grad():
                struct_embeds = self.get_predicted_struct_embeds(
                    struct_embeds,
                    seq_struct_latent,
                    struct_mask_to_pred,
                    struct_temp,
                    struct_cfg=cfg_iter
                )
            struct_mask = struct_mask_next

            prev_decoder_out.update(
                seq_tokens=result_seq_tokens,
                seq_scores=result_seq_scores,
                struct_embeds=struct_embeds,
                seq_mask_to_pred=result_seq_mask_to_pred,
                struct_mask=struct_mask,
                step=step + 1
            )

        return prev_decoder_out['seq_tokens'], prev_decoder_out['seq_scores'], prev_decoder_out['struct_embeds']

    def resample(self, _tokens, _scores, struct_embeds, ratio, sampling_strategy, temperature, cfg=1.0, uncon_struct_embeds=None):
        """Re-sample rows of ``_tokens`` that are dominated by repeated tokens.

        A row is flagged when some token occurs more than ``len(row) * ratio``
        times. In a flagged row, every position holding any such over-represented
        token is replaced by ``self.mask_id``. Those positions are then re-predicted
        from ``self.encode`` logits, conditioned on the row's structure embeddings.
        ``_tokens`` and ``_scores`` are updated IN PLACE; the method returns None.

        Args:
            _tokens: token ids, iterated row by row (each row is a 1-D tensor).
            _scores: per-token scores aligned with ``_tokens``; updated in place.
            struct_embeds: structure embeddings indexed by row like ``_tokens``.
            ratio: frequency threshold as a fraction of the row length.
            sampling_strategy: ``'vanilla'`` or ``'gumbel_argmax'``.
            temperature: sampling temperature for ``'vanilla'``, or noise scale
                for ``'gumbel_argmax'``.
            cfg: guidance weight; ``1.0`` disables guidance. Otherwise the logits become
                ``(1 - cfg) * uncond + cfg * cond``.
            uncon_struct_embeds: unconditional structure embeddings. Must be given
                when ``cfg != 1.0`` and at least one row is flagged.

        Raises:
            NotImplementedError: unknown ``sampling_strategy`` (only reached when
                some row is flagged).

        A ``ValueError`` raised while filtering/sampling is printed and the batch
        is left unchanged (best-effort behaviour).
        """
        to_be_resample_idx = []
        resample_input = []
        resample_input_mask = []
        resample_input_scores = []
        resample_struct_embeds = [] # + struct track
        resample_uncon_struct_embeds = [] # + uncon struct for cfg

        for i, seq in enumerate(_tokens):
            threshold = len(seq) * ratio
            # Count token frequencies on-device in one call. Calling int() per
            # element would force one host-device sync per token on GPU.
            values, counts = torch.unique(seq, return_counts=True)
            if counts.numel() == 0 or counts.max().item() <= threshold:
                continue
            # Mask every position whose token is over-represented.
            frequent = values[counts > threshold]
            mask = (seq.unsqueeze(-1) == frequent).any(dim=-1)

            to_be_resample_idx.append(i)
            resample_input_scores.append(_scores[i])
            resample_input_mask.append(mask)
            resample_input.append(seq.masked_fill(mask, self.mask_id))
            resample_struct_embeds.append(struct_embeds[i]) # + struct track
            if cfg != 1.0:
                resample_uncon_struct_embeds.append(uncon_struct_embeds[i]) # + uncon struct for cfg

        if not to_be_resample_idx:
            return

        # Resample the sequences that have tokens with higher frequency than the threshold.
        resample_input = torch.stack(resample_input, dim=0).type_as(_tokens)
        resample_input_scores = torch.stack(resample_input_scores, dim=0).type_as(_scores)
        resample_input_mask = torch.stack(resample_input_mask, dim=0).type_as(_tokens).bool()
        resample_struct_embeds = torch.stack(resample_struct_embeds, dim=0).type_as(struct_embeds) # + struct track
        resample_logits = self.encode(
            input_seq_ids=resample_input,
            input_struct_embeds=resample_struct_embeds, # + struct track
            cal_seq_logits=True
        )['seq_logits']
        if cfg != 1.0:
            resample_uncon_struct_embeds = torch.stack(resample_uncon_struct_embeds, dim=0).type_as(uncon_struct_embeds)
            resample_uncon_logits = self.encode(
                input_seq_ids=resample_input,
                input_struct_embeds=resample_uncon_struct_embeds, # + uncon struct for cfg
                cal_seq_logits=True
            )['seq_logits']
            # (1-w) * unconditional logits + w * conditional logits
            resample_logits = (1 - cfg) * resample_uncon_logits + cfg * resample_logits
        if resample_logits.dtype != _scores.dtype:
            resample_logits = resample_logits.type_as(_scores)
        # Never sample special / mask tokens.
        for special_id in (self.mask_id, self.x_id, self.pad_id, self.bos_id, self.eos_id):
            resample_logits[..., special_id] = -math.inf

        try:
            resample_logits = top_k_top_p_filtering(resample_logits, top_p=0.95)
            if sampling_strategy == 'vanilla':
                resample_tokens, resample_scores = sample_from_categorical(resample_logits, temperature=temperature)
            elif sampling_strategy == 'gumbel_argmax':
                assert resample_logits.size(0) == len(to_be_resample_idx)
                noise_scale = temperature
                resample_tokens, resample_scores = stochastic_sample_from_categorical(resample_logits, temperature=0.0, noise_scale=noise_scale)
            else:
                raise NotImplementedError
        except ValueError as e:
            # Best-effort: leave the batch untouched if filtering/sampling fails.
            print(f"Value Error during resampling: {e}")
            return

        # Only the masked (over-represented) positions take the new samples.
        resample_input.masked_scatter_(resample_input_mask, resample_tokens[resample_input_mask])
        resample_input_scores.masked_scatter_(resample_input_mask, resample_scores[resample_input_mask])
        _tokens[to_be_resample_idx], _scores[to_be_resample_idx] = resample_input, resample_input_scores

    def _reparam_decoding(
        self,
        prev_seq_tokens,
        prev_seq_scores,
        cur_seq_tokens,
        cur_seq_scores,
        decoding_strategy,
        xt_neq_x0,
        non_special_mask,
        t,
        max_step,
        noise,
    ):
        """
            This function is used to perform reparameterized decoding.
        """
        # output_tokens: [B, N]
        # output_scores: [B, N]
        # cur_tokens: [B, N]
        # cur_scores: [B, N]
        # xt_neq_x0: equivalent to not_b_t [B, N]
        # non_special_sym_mask: [B, N]
        # noise: either [B, N] or scalar (if using the mask noise)

        # decoding_strategy needs to take the form of "reparam-<conditioning>-<topk_mode>-<schedule>"
        _, condition, topk_mode, schedule = decoding_strategy.split("-")

        # first set the denoising rate according to the schedule
        if schedule == "linear":
            rate = 1 - t / max_step
        elif schedule == "cosine":
            rate = np.cos(t / max_step * np.pi * 0.5)
        else:
            raise NotImplementedError

        # compute the cutoff length for denoising top-k positions
        cutoff_len = (
            non_special_mask.sum(1, keepdim=True).type_as(prev_seq_scores) * rate
        ).long()
        # set the scores of special symbols to a large value so that they will never be selected
        _scores_for_topk = cur_seq_scores.masked_fill(~non_special_mask, 1000.0)
                
        # the top-k selection can be done in two ways: stochastic by injecting Gumbel noise or deterministic
        if topk_mode.startswith("stochastic"):
            noise_scale = float(topk_mode.replace("stochastic", ""))
            lowest_k_mask = topk_masking(_scores_for_topk, cutoff_len, stochastic=True, temp=noise_scale * rate)
        elif topk_mode.startswith("mix"):
            to_be_resample = []
            for i, seq in enumerate(cur_seq_tokens):
                most_token_dict = {}
                # most_token = None
                most_token_num = -1
                for j, token in enumerate(seq):
                    token = int(token)
                    if token == self.pad_id:
                        continue
                    if token not in most_token_dict:
                        most_token_dict[token] = [j]
                    else:
                        most_token_dict[token].append(j)
                    if len(most_token_dict[token]) > most_token_num:
                        # most_token = token
                        most_token_num = len(most_token_dict[token])
                if most_token_num > len(seq) * 0.20:
                    to_be_resample.append(i)
            lowest_k_mask = topk_masking(_scores_for_topk, cutoff_len, stochastic=False)
            if len(to_be_resample) > 0:
                noise_scale = float(topk_mode.replace("mix", ""))
                lowest_k_mask[to_be_resample] = topk_masking(_scores_for_topk[to_be_resample], cutoff_len[to_be_resample], 
                                                             stochastic=True, temp=noise_scale * rate)
        elif topk_mode == "deterministic":
            lowest_k_mask = topk_masking(_scores_for_topk, cutoff_len, stochastic=False)
        else:
            raise NotImplementedError
        
        if condition == "cond":
            not_v1_t = (cur_seq_tokens == prev_seq_tokens) & (cur_seq_scores < prev_seq_scores) & lowest_k_mask
        elif condition == "uncond":
            not_v1_t = lowest_k_mask
        else:
            raise NotImplementedError

        # for b_t = 0, the token is set to noise if it is in the lowest k scores.
        not_v2_t = lowest_k_mask

        last_mask_position = xt_neq_x0
        masked_to_noise = (~xt_neq_x0 & not_v1_t) | (xt_neq_x0 & not_v2_t)
        if isinstance(noise, torch.Tensor):
            prev_seq_tokens.masked_scatter_(masked_to_noise, noise[masked_to_noise])
        elif isinstance(noise, (int, float)):
            prev_seq_tokens.masked_fill_(masked_to_noise, noise)
        else:
            raise NotImplementedError("noise should be either a tensor or a scalar")
        prev_seq_scores.masked_fill_(masked_to_noise, -math.inf)

        masked_to_x0 = xt_neq_x0 & ~not_v2_t
        prev_seq_tokens.masked_scatter_(masked_to_x0, cur_seq_tokens[masked_to_x0])
        prev_seq_scores.masked_scatter_(masked_to_x0, cur_seq_scores[masked_to_x0])
        assert ((masked_to_x0 & last_mask_position) == masked_to_x0).all()
        
        new_xt_neq_x0 = (xt_neq_x0 | not_v1_t) & not_v2_t
        assert (new_xt_neq_x0 == not_v2_t).all()
        return new_xt_neq_x0, prev_seq_tokens, prev_seq_scores