import math
import logging
from typing import Optional, List, Union
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from ..zipformer.model import ZipformerEncoderModel
from ..zipformer.utils.padding import make_pad_mask

from ...auto.auto_config import AutoConfig
from ...utils.checkpoint import load_model_params
from transformers import (
    T5Tokenizer, CLIPTokenizer, BartTokenizer,
    T5ForConditionalGeneration, BartForConditionalGeneration,
)

# Pretrained checkpoint name -> tokenizer class, loaded via ``from_pretrained(name)``.
# Key order is user-visible: it is printed in the "Unsupported tokenizer type" error.
TOKENIZERS = {
    't5-small':                     T5Tokenizer,
    'facebook/bart-base':           BartTokenizer,
    'openai/clip-vit-base-patch32': CLIPTokenizer,
}

# Pretrained checkpoint name -> seq2seq model class. HFDecoderModel keeps only the
# decoder stack, the input embeddings and the lm_head of the loaded model.
# NOTE(review): HFDecoderModel was written for BART; the T5 entry is only partly
# handled there (attention-head lookup) — verify T5 end-to-end before relying on it.
DECODERS = {
    'facebook/bart-base': BartForConditionalGeneration,
    't5-small':           T5ForConditionalGeneration,
}

# Special-token attribute name -> token string that ensure_special_tokens() adds
# when the tokenizer leaves that attribute unset.
SPECIAL_TOKENS = dict(
    bos_token='<s>',
    eos_token='</s>',
    pad_token='<pad>',
    mask_token='<mask>',
)

def ensure_special_tokens(tokenizer, required_special_tokens=None):
    """Register every required special token that the tokenizer does not define yet.

    Args:
        tokenizer: object exposing special-token attributes (``bos_token``, ...)
            and an ``add_special_tokens(dict)`` method.
        required_special_tokens: mapping from special-token attribute name to the
            token string to add when that attribute is unset (None) or absent.
            Defaults to the module-level ``SPECIAL_TOKENS``.
    """
    # Resolve the default at call time instead of sharing a mutable dict as a default.
    if required_special_tokens is None:
        required_special_tokens = SPECIAL_TOKENS
    # A missing attribute is treated like an unset one rather than raising AttributeError.
    to_add = {
        name: token
        for name, token in required_special_tokens.items()
        if getattr(tokenizer, name, None) is None
    }
    if to_add:
        tokenizer.add_special_tokens(to_add)

def causal_mask(length: int, device: Optional[torch.device] = None) -> torch.BoolTensor:
    """Return a ``[length, length]`` boolean mask hiding future positions.

    Entry ``[i, j]`` is True (masked) exactly when ``j > i``, i.e. the strict
    upper triangle; this follows the "True = may not attend" convention of
    batch-first ``nn.TransformerDecoder`` masks.
    """
    steps = torch.arange(length, device=device)
    # Row index varies down dim 0, column index across dim 1.
    return steps.unsqueeze(0) > steps.unsqueeze(1)

class SinusoidalPositionalEncoding(nn.Module):
    """Classic fixed (non-trainable) sine/cosine positional encoding.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``
    with ``w_k = exp(-2k * ln(10000) / d_model)``. The table covers ``max_len``
    positions and is registered as a non-persistent buffer, so it is not saved in
    ``state_dict`` but follows the module across device moves.
    """

    def __init__(self, d_model, max_len=4096):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        # With an odd d_model there is one fewer cosine column than sine column;
        # slice the frequencies so the shapes agree instead of raising.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x):
        """Return the encodings for the first ``x.size(1)`` positions.

        Args:
            x: tensor whose dim 1 is the sequence length (e.g. ``[B, T]`` token
                ids); only its length and device are used.
        Returns:
            Tensor ``[1, T, d_model]``, broadcastable over the batch dimension.
        Raises:
            ValueError: if ``T`` exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(0)
        if seq_len > max_len:
            raise ValueError(f"Sequence length {seq_len} exceeds positional encoding max_len {max_len}")
        return self.pe[:seq_len].unsqueeze(0).to(x.device)

class DecoderModel(nn.Module):
    """
    Batch-first ``nn.TransformerDecoder`` with its own token embedding, fixed
    sinusoidal positions and a bias-free projection to the vocabulary.

    ``forward`` returns decoder hidden states; ``output_proj`` is kept as a
    separate attribute for the caller to apply.
    """

    def __init__(self, config, tokenizer):
        super().__init__()

        self.vocab_size = len(tokenizer)

        # Model width: explicit d_model if configured, else the widest encoder stage.
        if hasattr(config, 'd_model'):
            self.d_model = config.d_model
        else:
            self.d_model = max(config.encoder_dim)
        self.decoder_nhead = config.decoder_nhead

        self.text_embed = nn.Embedding(self.vocab_size, self.d_model)
        self.output_proj = nn.Linear(self.d_model, self.vocab_size, bias=False)
        if config.decoder_shared_emb:
            # Weight tying: the output projection reuses the embedding matrix.
            self.output_proj.weight = self.text_embed.weight

        # Optional hyper-parameters, each falling back to a default when absent.
        self.dim_feedforward = getattr(config, 'dim_feedforward', 4 * self.d_model)
        self.dropout = getattr(config, 'decoder_dropout', 0.1)
        self.activation = getattr(config, 'decoder_activation', 'relu')
        self.norm_first = getattr(config, 'decoder_norm_first', False)
        self.bias = getattr(config, 'decoder_bias', False)
        self.num_layers = getattr(config, 'num_decoder_layers', 6)

        self.positional_encoding = SinusoidalPositionalEncoding(self.d_model)

        layer_template = nn.TransformerDecoderLayer(
            d_model=self.d_model,
            nhead=self.decoder_nhead,
            dim_feedforward=self.dim_feedforward,
            dropout=self.dropout,
            activation=self.activation,
            batch_first=True,
            norm_first=self.norm_first,
            bias=self.bias,
        )
        # nn.TransformerDecoder clones the template num_layers times.
        self.decoder = nn.TransformerDecoder(layer_template, num_layers=self.num_layers)

    def forward(self, tgt, memory, tgt_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """
        Decode target ids against encoder memory.

        Args:
            tgt: Tensor [B, T], target token ids.
            memory: Tensor [B, S, D], encoder output embeddings.
            tgt_mask: optional boolean self-attention mask (True = blocked), in a
                shape accepted by ``nn.TransformerDecoder`` ([T, T] or [B*nhead, T, T]).
            tgt_key_padding_mask: optional [B, T] mask, True at padded targets.
            memory_key_padding_mask: optional [B, S] mask, True at padded frames.

        Returns:
            Tensor [B, T, D], decoder hidden states (``output_proj`` not applied).
        """
        hidden = self.text_embed(tgt)
        hidden = hidden + self.positional_encoding(tgt)
        return self.decoder(
            tgt=hidden,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )

class HFDecoderModel(nn.Module):
    """
    Wrapper around the decoder stack of a pretrained HuggingFace seq2seq model,
    selected by ``config.text_decoder_type`` from ``DECODERS``.

    Used for evaluating the audio encoder with a pre-trained transformer decoder:
    the (projected) encoder output is fed to the decoder as ``encoder_hidden_states``.
    ``forward`` already applies ``lm_head`` and returns vocabulary logits, so
    ``output_proj`` is an identity.

    NOTE(review): the masking logic below was written for BART. Only the head-count
    lookup accounts for T5. The HF T5 model may rescale hidden states before its
    lm_head when embeddings are tied, and BART's generation model may add a
    ``final_logits_bias``. This wrapper calls ``lm_head`` directly and does neither.
    Confirm against the installed transformers version.
    """
    def __init__(self, config, tokenizer):
        super().__init__()

        # Basic model parameters
        assert config.text_decoder_type is not None, "text_decoder_type must be specified in config"
        # Load the full pretrained seq2seq model; only its decoder, input embeddings
        # and lm_head are kept below.
        model = DECODERS[config.text_decoder_type].from_pretrained(config.text_decoder_type)
        if len(tokenizer) != model.config.vocab_size:
            model.resize_token_embeddings(len(tokenizer))  # Resize embeddings if tokenizer vocab size does not match
        
        self.decoder = model.get_decoder()
        self.text_embed = model.get_input_embeddings()
        self.lm_head = model.lm_head
        self.d_model = self.decoder.config.d_model
        self.vocab_size = self.decoder.config.vocab_size
        # For bart, it is decoder_attention_heads. For T5, it is num_heads
        self.decoder_nhead = self.decoder.config.decoder_attention_heads if hasattr(self.decoder.config, 'decoder_attention_heads') else self.decoder.config.num_heads
        self.output_proj = nn.Identity()  # forward() already applies lm_head, so no extra projection

    def _build_4d_attn_mask(
        self,
        tgt_mask: Optional[torch.Tensor],
        tgt_key_padding_mask: Optional[torch.Tensor],
        seq_len: int,
        dtype: torch.dtype,
        device: torch.device,
        batch_size: int,
    ) -> Optional[torch.Tensor]:
        """
        Combine the caller's boolean masks into one additive mask of shape [B, 1, T, T].

        Args:
            tgt_mask: boolean self-attention mask, True = position may not be attended.
                Either [T, T] (shared by the whole batch) or 3D. A 3D mask is reshaped
                to [B, -1, T, T] and only index 0 of dim 1 is kept. For a
                [B*nhead, T, T] mask that is the first head's mask per sample; a
                [B, T, T] mask is used as-is.
            tgt_key_padding_mask: [B, T] boolean, True = padding token.
            seq_len, dtype, device, batch_size: shape/dtype/device of the result.

        Returns:
            None when both masks are None. Otherwise a tensor where allowed entries
            are 0 and blocked entries are ``torch.finfo(dtype).min``.

        Raises:
            ValueError: if tgt_mask is neither 2D nor 3D.

        NOTE(review): how a 4D ``attention_mask`` is interpreted by the HF decoder
        (already additive vs. a 1/0 mask that gets inverted) has changed between
        transformers versions. Verify that this additive form is what the
        installed version expects.
        """
        if tgt_mask is None and tgt_key_padding_mask is None:
            return None  # let the HF decoder build its default (causal) mask

        finfo_min = torch.finfo(dtype).min

        # 1) Start from all zeros: every position may be attended.
        attn = torch.zeros(batch_size, 1, seq_len, seq_len, dtype=dtype, device=device)

        # 2) Custom structural mask (may be non-autoregressive): True -> finfo_min.
        if tgt_mask is not None:
            if tgt_mask.dim() == 2:               # [T,T] -> broadcast over the batch
                mask_ = tgt_mask.unsqueeze(0).expand(batch_size, -1, -1)
            elif tgt_mask.dim() == 3:             # [B*nhead,T,T] (or [B,T,T]) -> first head per sample
                mask_ = tgt_mask.reshape(batch_size, -1, seq_len, seq_len)
                mask_ = mask_[:, 0, :, :]
            else:
                raise ValueError("tgt_mask must be [T,T] or [B,T,T]")
            mask_ = mask_.to(device=device)
            attn.masked_fill_(mask_.unsqueeze(1), finfo_min)

        # 3) Padding mask: pad tokens may not be attended to (key direction).
        if tgt_key_padding_mask is not None:
            # Key direction: block every column that is a pad token.
            key_mask = tgt_key_padding_mask[:, None, None, :].to(device=device)   # [B,1,1,T]
            attn.masked_fill_(key_mask, finfo_min)
            # Query-direction masking (blocking the rows of pad tokens) is left disabled:
            # query_mask = tgt_key_padding_mask[:, None, :, None].to(device=device)  # [B,1,T,1]
            # attn.masked_fill_(query_mask, finfo_min)


        return attn

    def forward(
        self,
        tgt: torch.Tensor,                     # [B,T]
        memory: torch.Tensor,                  # [B,S,D]
        tgt_mask: Optional[torch.Tensor] = None,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,      # True=pad
        memory_key_padding_mask: Optional[torch.Tensor] = None,   # True=pad
    ) -> torch.Tensor:
        """
        Run the HF decoder on ``tgt`` attending to ``memory`` and return logits.

        - If neither tgt_mask nor tgt_key_padding_mask is given, ``attention_mask``
          passed to the decoder is None, so the HF decoder uses its default causal mask.
        - Otherwise the additive 4D mask from ``_build_4d_attn_mask`` is passed
          instead. Because it is built from tgt_mask, the attention pattern can
          differ from causal (e.g. non-autoregressive).

        Returns:
            Tensor [B, T, vocab_size]: ``lm_head`` applied to the decoder's last hidden state.
        """
        B, T = tgt.shape
        device = tgt.device
        dtype = memory.dtype  # the additive mask uses the encoder output's dtype

        # Decoder padding mask in HF convention (True/1 = keep, False/0 = pad).
        # NOTE(review): dec_pad is only passed when custom_4d is None, which happens only
        # when tgt_key_padding_mask is None, so in practice it is always None when used.
        if tgt_key_padding_mask is not None:
            dec_pad = (~tgt_key_padding_mask)
        else:
            dec_pad = None

        # Encoder padding mask, inverted the same way (True = keep).
        if memory_key_padding_mask is not None:
            enc_pad = (~memory_key_padding_mask)
        else:
            enc_pad = None

        # Custom 4D self-attention mask; when present it replaces the decoder's own causal mask.
        custom_4d = self._build_4d_attn_mask(
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            seq_len=T,
            dtype=dtype,
            device=device,
            batch_size=B,
        )

        decoder_outputs = self.decoder(
            input_ids=tgt,
            attention_mask=dec_pad if custom_4d is None else custom_4d,  # None -> the HF decoder builds its own causal mask
            encoder_hidden_states=memory,
            encoder_attention_mask=enc_pad,
            use_cache=False,
        )

        logits = self.lm_head(decoder_outputs.last_hidden_state)  # [B,T,vocab_size]
        return logits

class ZipformerForAudioCaptioningModel(ZipformerEncoderModel):
    """
    Audio captioning model: a Zipformer audio encoder followed by a text decoder.

    The decoder is a pretrained HuggingFace decoder (``HFDecoderModel``) when
    ``config.text_decoder_type`` is set, otherwise a freshly initialised
    ``DecoderModel``. Training uses teacher forcing under a causal mask;
    evaluation decodes greedily (argmax) starting from BOS.
    """

    @classmethod
    def from_pretrained(cls, exp_dir, checkpoint_filename='pretrained.pt'):
        """
        Build the model from the config in ``exp_dir`` and load the weights stored
        in ``exp_dir / checkpoint_filename``.
        """
        config = AutoConfig.from_pretrained(exp_dir)
        model = cls(config)
        ckpt_path = Path(exp_dir) / checkpoint_filename
        load_model_params(model, ckpt_path)
        return model

    def __init__(self, config):
        super().__init__(config)

        self.text_tokenizer_type = config.text_tokenizer_type
        self._build_tokenizer(self.text_tokenizer_type)
        if config.text_decoder_type is not None:
            logging.info(f"Using text decoder type: {config.text_decoder_type}")
            self.text_decoder = HFDecoderModel(config, self.tokenizer)
            if config.train_decoder_xattn_only:
                logging.info("Training only cross-attention layers in the text decoder")
                decoder = self.text_decoder.decoder
                for p in decoder.parameters():
                    p.requires_grad = False
                # Re-enable only the cross-attention blocks and their layer norms.
                # NOTE(review): these parameter names follow BART's decoder
                # ("layers.<i>.encoder_attn..."); other decoders such as T5 may name
                # them differently, leaving nothing trainable — verify.
                for n, p in decoder.named_parameters():
                    if n.startswith("layers") and (
                        ".encoder_attn." in n or ".encoder_attn_layer_norm." in n
                    ):
                        p.requires_grad = True
                # Keep the token embeddings and the output head frozen.
                for module in (
                    self.text_decoder.text_embed,
                    self.text_decoder.output_proj,
                    self.text_decoder.lm_head,
                ):
                    for p in module.parameters():
                        p.requires_grad = False
        else:
            logging.info("Using default TransformerDecoder")
            self.text_decoder = DecoderModel(config, self.tokenizer)

        # Project encoder features only when their width differs from the decoder's.
        encoder_dim = max(config.encoder_dim)
        self.enc_dec_proj = (
            nn.Linear(encoder_dim, self.text_decoder.d_model, bias=False)
            if encoder_dim != self.text_decoder.d_model
            else None
        )
        # Padded target positions are excluded from the loss via ignore_index.
        # NOTE(review): if a tokenizer's pad token equals its EOS token (CLIP's
        # tokenizer may be configured this way — verify), EOS targets are ignored too
        # and the model is never trained to stop.
        self.criterion = nn.CrossEntropyLoss(
            ignore_index=self.id_pad,
            label_smoothing=getattr(config, 'label_smoothing', 0.1),
        )

    # --------------------------- Tokenizer initialization ----------------- #
    def _build_tokenizer(self, tokenizer_type: str):
        """
        Load the tokenizer named ``tokenizer_type`` and cache its special-token ids.

        Sets ``self.tokenizer``, ``self.id_bos``, ``self.id_eos``, ``self.id_pad`` and
        ``self.vocab_size`` (which counts any special tokens added here).

        Raises:
            ValueError: if ``tokenizer_type`` is not a key of ``TOKENIZERS``.
        """
        if tokenizer_type not in TOKENIZERS:
            raise ValueError(f"Unsupported tokenizer type '{tokenizer_type}'. Supported: {list(TOKENIZERS.keys())}")

        self.tokenizer = TOKENIZERS[tokenizer_type].from_pretrained(tokenizer_type)
        ensure_special_tokens(self.tokenizer)

        self.id_bos, self.id_eos, self.id_pad = self.tokenizer.bos_token_id, self.tokenizer.eos_token_id, self.tokenizer.pad_token_id
        self.vocab_size = len(self.tokenizer)

    # --------------------------- Encoder wrapper -------------------------- #
    def encode_audio(self, x, x_lens):
        """
        Run the Zipformer encoder.

        Returns:
            encoder_out: Tensor [B, S, encoder_dim].
            padding_mask: Tensor [B, S] from ``make_pad_mask``; it is passed on as the
                decoder's ``memory_key_padding_mask`` (True = padded frame).
        """
        encoder_output = self.forward_encoder(x, x_lens)
        encoder_out, encoder_out_lens = encoder_output.encoder_out, encoder_output.encoder_out_lens
        padding_mask = make_pad_mask(encoder_out_lens, encoder_out.size(1)).to(encoder_out.device)

        return encoder_out, padding_mask

    def _encode_and_project(self, x, x_lens):
        """Encode audio and, when needed, project it to the decoder width; returns (audio_embs, padding_mask)."""
        audio_embs, padding_mask = self.encode_audio(x, x_lens)
        if self.enc_dec_proj is not None:
            audio_embs = self.enc_dec_proj(audio_embs)
        return audio_embs, padding_mask

    # --------------------------- Text helper ------------------------------ #
    def _tokenise(self, text: List[str]):
        """Tokenize captions to id lists, without special tokens, truncated to 128 tokens."""
        ids = self.tokenizer(text, truncation=True, max_length=128, add_special_tokens=False).input_ids
        return ids

    def _prepare_tgt(self, text: List[str]):
        """
        Prepare teacher-forcing decoder inputs.

        Args:
            text: List[str], one caption per sample.
        Returns:
            tgt_in: LongTensor [B, T], [BOS] + ids, right-padded with PAD.
            tgt_out: LongTensor [B, T], ids + [EOS], right-padded with PAD.
            tgt_pad_mask: BoolTensor [B, T], True where tgt_in is PAD.
            tgt_mask: BoolTensor [T, T], causal mask (True above the diagonal).
        """
        PAD, BOS, EOS = self.id_pad, self.id_bos, self.id_eos
        device = next(self.parameters()).device      # always valid

        ids = self._tokenise(text)

        tgt_in  = [torch.tensor([BOS] + s, dtype=torch.long) for s in ids]
        tgt_out = [torch.tensor(s + [EOS], dtype=torch.long) for s in ids]
        tgt_in  = nn.utils.rnn.pad_sequence(tgt_in,  batch_first=True, padding_value=PAD).to(device)
        tgt_out = nn.utils.rnn.pad_sequence(tgt_out, batch_first=True, padding_value=PAD).to(device)

        tgt_pad_mask = tgt_in.eq(PAD)                   # True → ignore position
        tgt_mask = causal_mask(tgt_in.size(1), device=device)

        return tgt_in, tgt_out, tgt_pad_mask, tgt_mask

    # --------------------------- Forward ---------------------------------- #
    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        text: Optional[List[str]] = None,
    ):
        """
        Args:
            x: Tensor [B, T, D], input features.
            x_lens: Tensor [B], valid lengths of ``x``.
            text: List[str] of reference captions; required in training mode.
        Returns:
            Training mode: ``(loss, None)``, where loss is the token-level cross-entropy.
            Eval mode: ``(None, captions)``, where captions is a List[str] from
            greedy decoding (``text`` is ignored).
        """
        audio_embs, padding_mask = self._encode_and_project(x, x_lens)

        if not self.training:
            return None, self._generate(audio_embs, padding_mask)

        assert text is not None, "Text must be provided during training" 
        tgt_in, tgt_out, tgt_pad_mask, tgt_mask = self._prepare_tgt(text)
        decoder_out = self.text_decoder(
            tgt=tgt_in,
            memory=audio_embs,
            tgt_mask=tgt_mask,                      # causal mask
            tgt_key_padding_mask=tgt_pad_mask,      # padding mask for decoder input
            memory_key_padding_mask=padding_mask,   # padding mask for encoder output
        )
        logits = self.text_decoder.output_proj(decoder_out)
        loss = self.criterion(logits.reshape(-1, self.vocab_size), tgt_out.reshape(-1))
        return loss, None

    @torch.inference_mode()
    def _generate(self, audio_embs, padding_mask, max_len=128):
        """
        Greedily decode captions.

        Args:
            audio_embs: Tensor [B, S, D], decoder-width audio embeddings.
            padding_mask: Tensor [B, S], True at padded encoder frames.
            max_len: maximum number of tokens generated after BOS.
        Returns:
            List[str] of length B, decoded with special tokens skipped.
        """
        device = audio_embs.device
        B = audio_embs.size(0)
        tokens = torch.full((B, 1), self.id_bos, dtype=torch.long, device=device)  # [B, 1]
        finished = torch.zeros(B, dtype=torch.bool, device=device)

        for _ in range(max_len):
            # The whole prefix is re-decoded at every step (no key/value cache).
            tgt_mask = causal_mask(tokens.size(1), device=device)
            dec_out = self.text_decoder(
                tgt=tokens,
                memory=audio_embs,
                tgt_mask=tgt_mask,
                memory_key_padding_mask=padding_mask,
            )
            next_tok = self.text_decoder.output_proj(dec_out[:, -1]).argmax(-1)
            tokens = torch.cat([tokens, next_tok[:, None]], dim=1)
            finished |= next_tok.eq(self.id_eos)
            if finished.all():
                break

        # Overwrite the first EOS and everything after it with PAD. A running count
        # of EOS tokens is > 0 exactly from the first EOS onward.
        after_eos = tokens.eq(self.id_eos).long().cumsum(dim=1) > 0
        tokens = tokens.masked_fill(after_eos, self.id_pad)

        text_ids = tokens[:, 1:].cpu()  # drop the BOS column -> [B, T']
        return self.tokenizer.batch_decode(text_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)

    def generate(self, input):
        """
        Caption audio.

        Args:
            input: either an ``(x, x_lens)`` feature tuple, or anything accepted by
                ``self.extract_feature``.
        Returns:
            List[str] of greedy-decoded captions.
        """
        if isinstance(input, tuple) and len(input) == 2:
            x, x_lens = input
        else:
            x, x_lens = self.extract_feature(input)

        audio_embs, padding_mask = self._encode_and_project(x, x_lens)
        return self._generate(audio_embs, padding_mask)
        
class ZipformerForAudioCaptioningWithMaskingModel(ZipformerForAudioCaptioningModel):
    """
    Zipformer captioning model trained on a mix of causal and parallel targets.

    In each training batch a ``config.parallel_decoding_prob`` share of the samples
    gets an input made entirely of MASK tokens (one per caption token) with no
    causal restriction, and must predict every caption token at once. The remaining
    samples use the usual BOS-shifted input under a causal mask.
    """
    def __init__(self, config):
        super().__init__(config)
        assert config.parallel_decoding_prob >= 0.0 and config.parallel_decoding_prob <= 1.0, \
            f"Parallel decoding probability must be in [0.0, 1.0], got {config.parallel_decoding_prob}"
        self.parallel_decoding_prob = config.parallel_decoding_prob
        if self.parallel_decoding_prob == 0.0:
            logging.info("Parallel decoding is disabled (probability set to 0.0). "
                  "This will use causal decoding for all inputs.")
        elif self.parallel_decoding_prob == 1.0:
            logging.info("Parallel decoding is enabled (probability set to 1.0). "
                  "This will use parallel decoding for all inputs.")

        self.id_mask = self.tokenizer.mask_token_id

    def _prepare_tgt(self, text: List[str], parallel_decoding_prob: float):
        """
        Build decoder inputs/targets for a batch that mixes causal and parallel samples.

        Args:
            text: List[str], one caption per sample.
            parallel_decoding_prob: share of the batch (rounded to whole samples),
                chosen at random, that is prepared in parallel mode.
        Returns:
            tgt_in: LongTensor [B, T]. Causal rows hold [BOS] + ids; parallel rows hold
                MASK repeated len(ids) times; PAD elsewhere.
            tgt_out: LongTensor [B, T]. Causal rows hold ids + [EOS]; parallel rows
                hold ids (no EOS); PAD elsewhere.
            tgt_pad_mask: BoolTensor [B, T], True where tgt_in is PAD.
            tgt_mask: BoolTensor [B * nhead, T, T]. Causal rows are masked above the
                diagonal; parallel rows have no restriction (all False).
        """
        PAD, BOS, EOS, MASK = self.id_pad, self.id_bos, self.id_eos, self.id_mask
        device = next(self.parameters()).device      # always valid

        ids = self._tokenise(text)
        lengths = [len(s) for s in ids]
        B = len(ids)

        n_parallel = int(round(B * parallel_decoding_prob))
        perm = torch.randperm(B, device=device)
        parallel_ids = perm[:n_parallel]
        causal_ids = perm[n_parallel:]
        # Python-int copies for indexing the Python lists: one device sync instead
        # of one per element.
        parallel_list = parallel_ids.tolist()
        causal_list = causal_ids.tolist()

        max_causal_len = max((lengths[i] + 1 for i in causal_list), default=0)  # +1 for BOS/EOS
        max_parallel_len = max((lengths[i] for i in parallel_list), default=0)
        max_len = max(max_causal_len, max_parallel_len)

        tgt_in = torch.full((B, max_len), PAD, dtype=torch.long, device=device)  # [B, max_len]
        tgt_out = torch.full_like(tgt_in, PAD)  # [B, max_len]

        if causal_list:
            tgt_in_causal = [torch.tensor([BOS] + ids[i], dtype=torch.long) for i in causal_list]
            tgt_out_causal = [torch.tensor(ids[i] + [EOS], dtype=torch.long) for i in causal_list]
            tgt_in_causal = nn.utils.rnn.pad_sequence(tgt_in_causal, batch_first=True, padding_value=PAD).to(device)
            tgt_out_causal = nn.utils.rnn.pad_sequence(tgt_out_causal, batch_first=True, padding_value=PAD).to(device)

            T_causal = tgt_in_causal.size(1)
            tgt_in[causal_ids, :T_causal] = tgt_in_causal
            tgt_out[causal_ids, :T_causal] = tgt_out_causal

        if parallel_list:
            tgt_out_parallel = [torch.tensor(ids[i], dtype=torch.long) for i in parallel_list]  # no EOS token
            tgt_in_parallel = [torch.full_like(t, MASK) for t in tgt_out_parallel]  # one MASK per target token
            tgt_out_parallel = nn.utils.rnn.pad_sequence(tgt_out_parallel, batch_first=True, padding_value=PAD).to(device)
            tgt_in_parallel = nn.utils.rnn.pad_sequence(tgt_in_parallel, batch_first=True, padding_value=PAD).to(device)

            T_parallel = tgt_out_parallel.size(1)
            tgt_in[parallel_ids, :T_parallel] = tgt_in_parallel
            tgt_out[parallel_ids, :T_parallel] = tgt_out_parallel

        tgt_pad_mask = tgt_in.eq(PAD)                   # True → ignore position

        per_sample_mask = torch.zeros((B, max_len, max_len), dtype=torch.bool, device=device)  # [B, T, T]
        if causal_list:
            per_sample_mask[causal_ids] = causal_mask(max_len, device=device)

        # 3D attention masks are laid out as [B * nhead, T, T] with the head index
        # varying fastest, so each sample's mask is repeated nhead times in a row.
        num_heads = self.text_decoder.decoder_nhead
        tgt_mask = per_sample_mask.repeat_interleave(num_heads, dim=0)  # [B*num_heads, T, T]

        return tgt_in, tgt_out, tgt_pad_mask, tgt_mask

    def _teacher_forced_loss(self, audio_embs, padding_mask, text, parallel_decoding_prob):
        """Cross-entropy on ``text`` with a ``parallel_decoding_prob`` share of parallel (MASK-input) samples."""
        tgt_in, tgt_out, tgt_pad_mask, tgt_mask = self._prepare_tgt(text, parallel_decoding_prob)
        decoder_out = self.text_decoder(
            tgt=tgt_in,
            memory=audio_embs,
            tgt_mask=tgt_mask,                      # per-sample causal / unrestricted mask
            tgt_key_padding_mask=tgt_pad_mask,      # padding mask for decoder input
            memory_key_padding_mask=padding_mask,   # padding mask for encoder output
        )
        logits = self.text_decoder.output_proj(decoder_out)
        return self.criterion(logits.reshape(-1, self.vocab_size), tgt_out.reshape(-1))

    # --------------------------- Forward ---------------------------------- #
    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        text: Optional[List[str]] = None,
    ):
        """
        Args:
            x: Tensor [B, T, D], input features.
            x_lens: Tensor [B], valid lengths of ``x``.
            text: List[str] of reference captions; required in training mode.
        Returns:
            Training mode: ``(loss, None)``, the cross-entropy over the mixed
            causal/parallel batch.
            Eval mode: ``(ce_loss, captions)``. ``captions`` comes from greedy causal
            decoding. ``ce_loss`` is the all-parallel cross-entropy against ``text``,
            or None when ``text`` is not given.
        """
        audio_embs, padding_mask = self.encode_audio(x, x_lens)

        if self.enc_dec_proj is not None:
            audio_embs = self.enc_dec_proj(audio_embs)

        if self.training:
            assert text is not None, "Text must be provided during training" 
            loss = self._teacher_forced_loss(audio_embs, padding_mask, text, self.parallel_decoding_prob)
            return loss, None

        output_text = self._generate(audio_embs, padding_mask, force_causal=True)
        # Without references there is nothing to score; previously this raised ValueError.
        ce_loss = None
        if text is not None:
            ce_loss = self._generate(audio_embs, padding_mask, force_causal=False, text=text)
        return ce_loss, output_text

    @torch.inference_mode()
    def _generate(self, audio_embs, padding_mask, force_causal, text=None, max_len=128):
        """
        Greedy causal decoding, or parallel-mode scoring of reference captions.

        Args:
            audio_embs: Tensor [B, S, D], decoder-width audio embeddings.
            padding_mask: Tensor [B, S], True at padded encoder frames.
            force_causal: if True, greedily decode captions. If False, return the
                cross-entropy of ``text`` with every sample in parallel mode.
            text: List[str] reference captions; required when force_causal is False.
            max_len: maximum generated length (causal mode only).
        Returns:
            List[str] of captions when force_causal is True, otherwise a loss tensor.
        Raises:
            ValueError: if force_causal is False and text is None.
        """
        if force_causal:
            return super()._generate(audio_embs, padding_mask, max_len)
        if text is None:
            raise ValueError("Text must be provided for non-causal generation")
        return self._teacher_forced_loss(audio_embs, padding_mask, text, 1.0)

    def generate(self, input):
        """
        Caption audio with greedy causal decoding.

        Args:
            input: either an ``(x, x_lens)`` feature tuple, or anything accepted by
                ``self.extract_feature``.
        Returns:
            List[str] of captions.
        """
        if isinstance(input, tuple) and len(input) == 2:
            x, x_lens = input
        else:
            x, x_lens = self.extract_feature(input)

        audio_embs, padding_mask = self.encode_audio(x, x_lens)
        if self.enc_dec_proj is not None:
            audio_embs = self.enc_dec_proj(audio_embs)

        return self._generate(audio_embs, padding_mask, force_causal=True)