
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import typing as tp
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple
from typing import Dict, Optional, Union
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
from .utils.common import IGNORE_ID
from .label_smoothing_loss import LabelSmoothingLoss, BalancedLabelSmoothingLoss
from ...utils.common import th_accuracy
from ...utils.attn_mask import *
from .transformer.encoder import ConformerEncoder, TransformerEncoder
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from hyperpyyaml import load_hyperpyyaml
import numpy as np
from .cli.frontend import CosyVoiceFrontEnd
from .flow.flow import MaskedDiffWithXvec
from .flow.length_regulator import InterpolateRegulator
from .flow.flow_matching import ConditionalCFM
from .flow.decoder import ConditionalDecoder
from .hifigan.generator import HiFTGenerator
from .hifigan.f0_predictor import ConvRNNF0Predictor
from omegaconf import DictConfig
from .llm.llm import Qwen2Encoder, Qwen2LM, Qwen2ForCausalLM, make_pad_mask
from .llm.qwen2 import Qwen2DecoderLayer, Qwen2RMSNorm, Qwen2PreTrainedModel, Cache, DynamicCache, _prepare_4d_causal_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask, BaseModelOutputWithPast, Qwen2NAREncoder
from .utils.common import fade_in_out
import uuid
from contextlib import nullcontext
import threading
import time
from .utils.common import ras_sampling
from copy import deepcopy

import torch
import torch.nn.functional as F

def check_target_ratio(tp_tokens, end_idx, target_token_list, threshold=0.7):
    """Return ``tp_tokens[:end_idx]`` unless it is dominated by target tokens.

    Counts how many elements of the prefix ``tp_tokens[:end_idx]`` are in
    ``target_token_list``. If their share is strictly greater than
    ``threshold`` (0.7 by default), the prefix is rejected.

    Args:
        tp_tokens: sliceable sequence of token ids -- a 1-D tensor (elements
            are converted with ``.item()``) or a plain sequence of ints.
        end_idx: exclusive end index of the prefix to inspect.
        target_token_list: iterable of token ids to count.
        threshold: maximum allowed fraction of target tokens.

    Returns:
        ``None`` if the prefix is empty or its target-token fraction exceeds
        ``threshold``; otherwise the prefix ``tp_tokens[:end_idx]`` itself.
    """
    target_set = set(target_token_list)
    segment = tp_tokens[:end_idx]
    total_count = len(segment)
    if total_count == 0:
        return None

    # Tensor elements are 0-d tensors and must be unwrapped before the set
    # lookup; plain ints are used as-is.
    target_count = sum(
        1
        for token in segment
        if (token.item() if hasattr(token, "item") else token) in target_set
    )

    if target_count / total_count > threshold:
        return None
    return segment

def biased_cross_attention(query, key, value, bias_vector):
    """Single-query scaled dot-product attention re-weighted by a per-key bias.

    The softmax attention weights are multiplied element-wise by
    ``bias_vector`` and then renormalised to sum to 1. A small epsilon guards
    the division, so an all-zero bias yields all-zero weights and a zero
    output.

    Args:
        query: (batch_size, 1, d_model)
        key: (batch_size, seq_len, d_model)
        value: (batch_size, seq_len, d_value)
        bias_vector: (batch_size, seq_len), one multiplicative weight per key
            for each batch element.

    Returns:
        Tuple ``(attn_output, attn_probs)`` with shapes
        ``(batch_size, 1, d_value)`` and ``(batch_size, seq_len)``. Note that
        ``attn_output`` keeps the singleton query dimension.
    """
    d_model = key.size(-1)

    # Scaled dot-product scores and their softmax: (batch_size, 1, seq_len)
    attn_scores = torch.matmul(query, key.transpose(-2, -1)) / d_model ** 0.5
    attn_probs = F.softmax(attn_scores, dim=-1)

    # The bias is applied after the softmax, broadcast over the query axis,
    # and the result is renormalised.
    biased_attn_probs = attn_probs * bias_vector.unsqueeze(1)
    biased_attn_probs = biased_attn_probs / (biased_attn_probs.sum(dim=-1, keepdim=True) + 1e-8)

    attn_output = torch.matmul(biased_attn_probs, value)  # (batch_size, 1, d_value)
    return attn_output, biased_attn_probs.squeeze(1)


def beam_search_from_logits(logits, beam_width=3, end_token=None):
    """Beam-search decode a single sequence of per-frame logits.

    Each hypothesis is scored by the sum of per-frame log-probabilities of its
    tokens. At each frame, every kept hypothesis is extended by that frame's
    top-``beam_width`` tokens and the best ``beam_width`` candidates are kept.

    Args:
        logits: tensor of shape (1, T, n_class); only batch size 1 is supported.
        beam_width: number of hypotheses kept after each frame. Values larger
            than ``n_class`` are clamped to ``n_class`` when picking candidates.
        end_token: optional token id; decoding stops early once every kept
            hypothesis ends with it. ``None`` disables early stopping.

    Returns:
        list[int]: the highest-scoring token sequence (empty when T == 0).

    Raises:
        ValueError: if the batch size is not 1 or ``beam_width`` < 1.

    Note:
        The per-frame scores do not depend on earlier tokens. The best
        hypothesis therefore coincides with the per-frame argmax, up to ties
        and early stopping.
    """
    batch_size, T, n_class = logits.size()
    if batch_size != 1:
        raise ValueError("当前只支持 batch_size=1")
    if beam_width < 1:
        raise ValueError(f"beam_width must be >= 1, got {beam_width}")
    if T == 0:
        return []

    # log_softmax for numerical stability: (T, n_class)
    log_probs = F.log_softmax(logits, dim=-1).squeeze(0)

    # A frame's top-k candidates are the same for every beam, so compute them
    # once for all frames instead of once per (frame, beam).
    k = min(beam_width, n_class)
    top_log_probs, top_tokens = torch.topk(log_probs, k, dim=-1)
    top_log_probs = top_log_probs.tolist()
    top_tokens = top_tokens.tolist()

    # Each beam is (token sequence, cumulative log-prob score).
    beams = [([], 0.0)]
    for step_scores, step_tokens in zip(top_log_probs, top_tokens):
        candidates = [
            (seq + [token], score + token_score)
            for seq, score in beams
            for token, token_score in zip(step_tokens, step_scores)
        ]
        beams = sorted(candidates, key=lambda beam: beam[1], reverse=True)[:beam_width]

        if end_token is not None and all(seq and seq[-1] == end_token for seq, _ in beams):
            break

    return beams[0][0]


def argmax_with_block_penalty(logits, target_token_id, repeat_threshold=20, penalty_scale=0.5):
    """Greedy-decode ``logits`` after penalising long runs of one token.

    A run is a maximal stretch of frames whose argmax is ``target_token_id``.
    For each run of at least ``repeat_threshold`` frames, the target logit at
    each frame of the run is lowered by ``penalty_scale * |logit| * decay``.
    ``decay`` falls linearly from 1 at the run's first frame to ``1/len`` at
    its last, so earlier frames are penalised more. The argmax is then
    recomputed.

    The absolute value matters: subtracting a fraction of a *negative* logit
    would raise it and reward the run instead of penalising it.

    Args:
        logits: tensor of shape (B, T, n_class).
        target_token_id: token id whose long runs are penalised.
        repeat_threshold: minimum run length that triggers the penalty.
        penalty_scale: fraction of the absolute logit removed at decay 1.

    Returns:
        Tensor of shape (B, T) with the adjusted argmax token ids.
    """
    batch, _, _ = logits.size()
    pred_rows = torch.argmax(logits, dim=-1).tolist()  # B lists of T ints
    penalty_mask = torch.zeros_like(logits)

    def _penalize_run(b, start, length):
        # Write a decaying penalty for every frame of one run into penalty_mask.
        for i in range(length):
            t_idx = start + i
            decay = (length - i) / length
            base_score = logits[b, t_idx, target_token_id].item()
            penalty_mask[b, t_idx, target_token_id] = penalty_scale * abs(base_score) * decay

    for b in range(batch):
        run_start, run_len = 0, 0
        for t, token in enumerate(pred_rows[b]):
            if token == target_token_id:
                if run_len == 0:
                    run_start = t
                run_len += 1
                continue
            # The run just ended; run_len > 0 guards against repeat_threshold <= 0.
            if run_len > 0 and run_len >= repeat_threshold:
                _penalize_run(b, run_start, run_len)
            run_len = 0
        # A run that reaches the end of the sequence.
        if run_len > 0 and run_len >= repeat_threshold:
            _penalize_run(b, run_start, run_len)

    return torch.argmax(logits - penalty_mask, dim=-1)


def argmax_with_blocklist_penalty(logits, target_token_list, repeat_threshold=20, penalty_scale=0.5):
    """Greedy-decode ``logits`` after penalising long runs of listed tokens.

    A run is a maximal stretch of frames whose argmax token is in
    ``target_token_list``; the tokens inside one run may differ. For each run
    of at least ``repeat_threshold`` frames, the logit of the token predicted
    at each frame is lowered by ``penalty_scale * |logit| * decay``. ``decay``
    falls linearly from 1 at the run's first frame to ``1/len`` at its last,
    so earlier frames are penalised more. The argmax is then recomputed.

    The absolute value matters: subtracting a fraction of a *negative* logit
    would raise it and reward the run instead of penalising it.

    Args:
        logits: tensor of shape (B, T, n_class).
        target_token_list: iterable of token ids treated as one blocked class.
        repeat_threshold: minimum run length that triggers the penalty.
        penalty_scale: fraction of the absolute logit removed at decay 1.

    Returns:
        Tensor of shape (B, T) with the adjusted argmax token ids.
    """
    target_set = set(target_token_list)
    batch, _, _ = logits.size()
    pred_rows = torch.argmax(logits, dim=-1).tolist()  # B lists of T ints
    penalty_mask = torch.zeros_like(logits)

    def _penalize_run(b, start, length):
        # Penalise the predicted token at every frame of one run. All frames in
        # a run are target tokens by construction.
        for i in range(length):
            t_idx = start + i
            token_id = pred_rows[b][t_idx]
            decay = (length - i) / length
            base_score = logits[b, t_idx, token_id].item()
            penalty_mask[b, t_idx, token_id] = penalty_scale * abs(base_score) * decay

    for b in range(batch):
        run_start, run_len = 0, 0
        for t, token in enumerate(pred_rows[b]):
            if token in target_set:
                if run_len == 0:
                    run_start = t
                run_len += 1
                continue
            # The run just ended; run_len > 0 guards against repeat_threshold <= 0.
            if run_len > 0 and run_len >= repeat_threshold:
                _penalize_run(b, run_start, run_len)
            run_len = 0
        # A run that reaches the end of the sequence.
        if run_len > 0 and run_len >= repeat_threshold:
            _penalize_run(b, run_start, run_len)

    return torch.argmax(logits - penalty_mask, dim=-1)



class BatchNormConv1d(nn.Module):
    """Bias-free 1-D convolution, optional activation, then batch normalisation.

    Note the order: the activation (if any) is applied *before* the
    BatchNorm, not after it.

    Args:
        in_dim / out_dim: input and output channel counts.
        kernel_size, stride, padding: forwarded to ``nn.Conv1d``.
        activation: optional callable applied to the convolution output.
    """

    def __init__(self, in_dim, out_dim, kernel_size, stride, padding,
                 activation=None):
        super().__init__()
        # bias=False, presumably because the following BatchNorm has its own shift.
        self.conv1d = nn.Conv1d(
            in_dim,
            out_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm1d(out_dim)
        self.activation = activation

    def forward(self, x):
        """Map (B, in_dim, T) to (B, out_dim, T') via conv -> activation -> BN."""
        features = self.conv1d(x)
        if self.activation is None:
            return self.bn(features)
        return self.bn(self.activation(features))

class Predictor(nn.Module):
    """Two stacked conv blocks followed by a per-frame linear classifier.

    Args:
        input_dim: channel count of the input features.
        hidden_dim: channel count inside the conv stack.
        ker: convolution kernel size; padding is ``ker // 2`` with stride 1.
        text_class: number of output classes per frame.
    """

    def __init__(self, input_dim, hidden_dim, ker, text_class):
        super().__init__()
        pad = ker // 2  # with stride 1 this keeps the length for odd kernels
        first = BatchNormConv1d(input_dim, hidden_dim, ker, 1, pad)
        second = BatchNormConv1d(hidden_dim, hidden_dim, ker, 1, pad)
        self.convs = nn.ModuleList([first, second])
        self.text_fc_out = nn.Linear(hidden_dim, text_class)

    def forward(self, x):
        """Map (B, T, C) features to (B, T, text_class) logits.

        No padding mask is applied inside the conv stack.
        """
        channels_first = x.transpose(1, 2)
        for block in self.convs:
            channels_first = block(channels_first)
        return self.text_fc_out(channels_first.transpose(1, 2))

@dataclass
class Cosyvoice2Config(FairseqDataclass):
    """Configuration for the ``cosyvoice2_inneremo_causalalign_dynattb`` model.

    Several fields are not referenced anywhere in the visible part of this
    file; they may be consumed elsewhere or be left over from other variants.
    """
    # Not referenced in the visible part of this file.
    text_encoder_input_size: int = field(default=12)
    # Width of llm_embedding, inneremo_emo_embedding and speech_embedding.
    llm_input_size: int = field(default=12)
    # Width of the LLM output fed to llm_decoder, aligner_fusion and aligner_decoder.
    llm_output_size: int = field(default=12)
    # Not referenced in the visible part of this file.
    text_token_size: int = field(default=12)
    # Speech vocabulary size. The speech embedding and decoder have
    # speech_token_size + 3 entries, and the id speech_token_size itself is
    # used as the end-of-speech target.
    speech_token_size: int = field(default=12)
    # Passed as normalize_length to both LabelSmoothingLoss criteria.
    length_normalized_loss: bool = field(default=False)
    # Label-smoothing factor for both LabelSmoothingLoss criteria.
    lsm_weight: float = field(default=0.0)
    # Not referenced in the visible part of this file.
    spk_embed_dim: int = field(default=192)
    
    # If true, build_model freezes every parameter whose name does not start
    # with "aligner" or "inneremo".
    tts_freeze: bool = field(default=False)
    
    # Forwarded to Qwen2Encoder as lora_rank.
    lora_rank: int = field(default=-1)
    # Number of decoder layers in the emotion aligner (num_hidden_layers).
    aligner_layer: int = field(default=-1)
    # The two aligner conv settings below are not referenced in the visible code.
    aligner_convdim: int = field(default=512)
    aligner_convker: int = field(default=5)
    
    # Stored on the model as self.partial_train; its use is not visible here.
    partial_train: bool = field(
        default=False
    )
    # Not referenced in the visible part of this file.
    partial_layers: str = field(
        default=""
    )
    # Path to the pretrained Qwen checkpoint. If it does not exist, the model
    # falls back to $COSYVOICE2HOME/CosyVoice-BlankEN.
    qwen_pretrained_path: str = field(
        default=""
    )
    # Exported as the DYNMASK environment variable when that is unset.
    dyn_ab_num_patterns: int = field(default=0)
    # Extra emotion-feature width concatenated to the LLM output before
    # aligner_fusion; also exported as EMODIM.
    dyn_emo_dim: int = field(default=0)


class TPModule(Qwen2PreTrainedModel):
    """A stack of Qwen2 decoder layers plus a final RMSNorm, driven by embeddings.

    There is no token embedding and no output head: callers pass
    ``inputs_embeds`` directly and receive the normalised hidden states.
    """
    def __init__(self, config):
        super().__init__(config)
        # One decoder layer per config.num_hidden_layers, each given its index.
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Selects which attention-mask format forward() builds below.
        self._attn_implementation = config._attn_implementation
        self.post_init()
        
    def forward(
        self,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        position_ids=None
    ):
        """Run the decoder stack over ``inputs_embeds``.

        Args:
            attention_mask: padding mask, presumably 2-D (batch, seq) -- the
                flash-attention branch indexes it as ``attention_mask[:, -1]``.
                It is converted below into the format the configured attention
                implementation expects.
            past_key_values: a ``Cache`` or a legacy cache; only used when
                ``use_cache`` is true.
            inputs_embeds: input embeddings, unpacked as (batch, seq, hidden).
            use_cache: defaults to ``config.use_cache``.
            return_dict: defaults to ``config.use_return_dict``.
                NOTE(review): it is resolved but otherwise unused; a
                ``BaseModelOutputWithPast`` is always returned.
            output_attentions: forwarded to every layer. It also decides which
                layer output index holds the cache.
            position_ids: forwarded unchanged to every layer; this method does
                not compute them when ``None``.

        Returns:
            ``BaseModelOutputWithPast`` with the normalised
            ``last_hidden_state``, the updated cache (legacy format if a legacy
            cache was passed in) and always-empty ``hidden_states`` /
            ``attentions`` tuples.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        batch_size, seq_length, _ = inputs_embeds.shape
        past_key_values_length = 0
        if use_cache:
            # Wrap a legacy (non-Cache) cache so the layers can update it, and
            # remember to convert back on the way out.
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(
                    past_key_values
                )
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
            # Right padding is detected when the last column is not all ones.
            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
            if is_padding_right:
                raise ValueError(
                    "You are attempting to perform batched generation with padding_side='right'"
                    " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
                    " call `tokenizer.padding_side  = 'left'` before tokenizing the input. "
                )

        if self._attn_implementation == "flash_attention_2":
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif self._attn_implementation == "sdpa" and not output_attentions:
            # output_attentions=True can not be supported when using SDPA, and we fall back on
            # the manual implementation that requires a 4D causal mask in all cases.
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
                sliding_window=self.config.sliding_window,
            )

        hidden_states = inputs_embeds
        next_decoder_cache = None

        for decoder_layer in self.layers:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache follows the attention weights in the layer output
                # tuple when those are requested.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        hidden_states = self.norm(hidden_states)
        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=(),
            attentions=(),
        )

try:
    @register_model("cosyvoice2_inneremo_causalalign_dynattb", dataclass=Cosyvoice2Config)
    class CosyVoice2Origin(BaseFairseqModel):
        def __init__(self, cfg:Cosyvoice2Config, task):
            """Build the speech-token LLM and the emotion aligner from ``cfg``.

            Args:
                cfg: model configuration. NOTE: ``cfg.qwen_pretrained_path`` is
                    overwritten in place when the configured path does not exist.
                task: fairseq task; not used by the visible code.

            Raises:
                ValueError: if ``cfg.qwen_pretrained_path`` does not exist and
                    the ``COSYVOICE2HOME`` environment variable is not set.
            """
            super().__init__()
            self.partial_train = cfg.partial_train
            self.cfg = cfg
            # Publish the dynamic-attention settings through the environment
            # unless the launcher already did; they are presumably read by other
            # modules (not visible here). os.environ only accepts str values, so
            # the int config fields are converted (assigning the raw int raised
            # TypeError). Each variable gets its default independently.
            if os.environ.get("DYNMASK") is None:
                os.environ["DYNMASK"] = str(cfg.dyn_ab_num_patterns)
            if os.environ.get("EMODIM") is None:
                os.environ["EMODIM"] = str(cfg.dyn_emo_dim)
            self.lora_rank = cfg.lora_rank
            self.llm_input_size = self.cfg.llm_input_size
            self.speech_token_size = self.cfg.speech_token_size
            # 2. build speech token language model related modules
            # Row indices into llm_embedding for the sos/eos and task-id embeddings.
            self.sos_eos = 0
            self.task_id = 1
            self.fill_token = 2  # not used in the visible code
            self.llm_embedding = torch.nn.Embedding(2, self.cfg.llm_input_size)
            # 5 entries (presumably emotion classes); the aligner head below also predicts 5 classes.
            self.inneremo_emo_embedding = torch.nn.Embedding(5, self.cfg.llm_input_size)
            if not os.path.exists(cfg.qwen_pretrained_path):
                cosyvoice_home = os.environ.get("COSYVOICE2HOME")
                if cosyvoice_home is None:
                    raise ValueError(
                        f"qwen_pretrained_path {cfg.qwen_pretrained_path!r} does not exist "
                        "and the COSYVOICE2HOME environment variable is not set"
                    )
                cfg.qwen_pretrained_path = cosyvoice_home + "/CosyVoice-BlankEN"
            self.llm = Qwen2Encoder(
                cfg.qwen_pretrained_path,
                lora_rank=self.cfg.lora_rank,
            )
            # speech_token_size + 3 outputs; the id speech_token_size is used as
            # the end-of-speech target in tts_forward.
            self.llm_decoder = nn.Linear(self.cfg.llm_output_size, self.cfg.speech_token_size + 3)
            self.criterion_ce = LabelSmoothingLoss(
                size=self.cfg.speech_token_size + 3,
                padding_idx=IGNORE_ID,
                smoothing=self.cfg.lsm_weight,
                normalize_length=self.cfg.length_normalized_loss,
            )

            # Emotion aligner: a separate Qwen2 stack with cfg.aligner_layer
            # layers, built from a copy of the LLM's config.
            self.text_pad_idx = 50257  # not used in the visible code
            aligner_config = deepcopy(self.llm.model.preset_config)
            aligner_config._name_or_path = ""
            aligner_config.num_hidden_layers = cfg.aligner_layer
            # Fuses the LLM output with a dyn_emo_dim-wide emotion feature.
            self.aligner_fusion = nn.Linear(cfg.dyn_emo_dim+self.cfg.llm_output_size, self.cfg.llm_output_size)
            self.aligner = Qwen2ForCausalLM(
                aligner_config,
            )
            self.aligner_decoder = nn.Linear(self.cfg.llm_output_size, 5)

            self.aligner_criterion_ce = LabelSmoothingLoss(
                size=5,
                padding_idx=IGNORE_ID,
                smoothing=self.cfg.lsm_weight,
                normalize_length=self.cfg.length_normalized_loss,
            )

            # 3. [Optional] build speech token related modules
            self.speech_embedding = torch.nn.Embedding(self.cfg.speech_token_size + 3, self.cfg.llm_input_size)
            self.hift = None
            self.flow = None
            self.sample_rate = 24000
            # infer
            self.mel_cache_len = 20
            self.tts_freeze = self.cfg.tts_freeze
            self.update_num = 0

        def set_num_updates(self, num_updates):
            """Record the current optimizer update count.

            Delegates to the base-class implementation first, then keeps a
            local copy in ``self.update_num``.
            """
            super().set_num_updates(num_updates)
            self.update_num = num_updates  # local mirror of the update count
        
        @classmethod
        def build_model(cls, cfg: Cosyvoice2Config, task):
            """Build a new model instance.

            ``cls`` is instantiated rather than a hard-coded class name, so
            subclasses that inherit this method build instances of themselves.
            When ``cfg.tts_freeze`` is set, only parameters whose names start
            with ``aligner`` or ``inneremo`` stay trainable; all others are frozen.
            """
            model = cls(cfg, task)
            if cfg.tts_freeze:
                trainable_prefixes = ("aligner", "inneremo")
                for name, param in model.named_parameters():
                    param.requires_grad = name.startswith(trainable_prefixes)
            return model
        
        def encode(
                self,
                text: torch.Tensor,
                text_lengths: torch.Tensor,
        ):
            encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
            encoder_out_lens = encoder_mask.squeeze(1).sum(1)
            encoder_out = self.text_encoder_affine_layer(encoder_out)
            return encoder_out, encoder_out_lens

        def pad_unpad_sequence(self, sos_eos_emb, 
                               text_token, text_token_len, 
                               prompt_texts, prompt_text_token_len,
                               task_id_emb, 
                               speech_token, speech_token_len, 
                               prompt_speech_token, prompt_speech_token_len, 
                               emos, iter_num,
                               tgt_emos_prompt, tgt_emos):
            """Assemble per-sample training LM inputs plus aligned emotion features.

            For sample ``i`` the LM sequence is::

                [sos]
                + [emos[i, j], prompt_texts[i, j, :len]]  for j < iter_num[i]
                + [emos[i, j], text_token[i, j, :len]]    for j < iter_num[i]
                + [task_id] + prompt_speech_token[i] + speech_token[i]

            The emotion sequence mirrors it position by position: the sos and
            task-id embeddings at the same slots, each text segment's emotion
            embedding repeated over the segment, and the per-frame prompt/target
            emotion embeddings over the speech frames. The speech-side inputs are
            first unpadded with their length tensors.

            Returns:
                ``(lm_input, lm_input_len, text_side_lens, emo_input)``:
                ``lm_input`` and ``emo_input`` are padded (batch_first) with
                IGNORE_ID; ``lm_input_len`` holds int32 lengths; and
                ``text_side_lens[i]`` counts the text-side positions (excluding
                sos). It is accumulated from the length entries, so it is a
                tensor when those are tensors.
            """
            batch_size, _, _, _ = text_token.size()

            def unpad(padded, lengths):
                return unpad_sequence(padded, lengths.cpu(), batch_first=True)

            speech_token = unpad(speech_token, speech_token_len)
            tgt_emos = unpad(tgt_emos, speech_token_len)
            prompt_speech_token = unpad(prompt_speech_token, prompt_speech_token_len)
            tgt_emos_prompt = unpad(tgt_emos_prompt, prompt_speech_token_len)

            sos = sos_eos_emb.squeeze(dim=0)
            task_id = task_id_emb.squeeze(dim=0)
            # Prompt text segments come first, then the target text segments.
            text_segments = (
                (prompt_texts, prompt_text_token_len),
                (text_token, text_token_len),
            )

            lm_input, emo_input, text_side_lens = [], [], []
            for i in range(batch_size):
                lm_parts = [sos]
                emo_parts = [sos]
                text_side_len = 0
                for tokens, token_lens in text_segments:
                    for j in range(iter_num[i]):
                        seg_len = token_lens[i, j]
                        emo_row = emos[i, j:j+1, :]
                        lm_parts.append(torch.cat([emo_row, tokens[i, j, :seg_len, :]], dim=0))
                        # +1 accounts for the emotion embedding prepended to the segment.
                        text_side_len += seg_len + 1
                        emo_parts.append(emo_row.expand(seg_len + 1, -1))
                emo_parts.extend([task_id, tgt_emos_prompt[i], tgt_emos[i]])
                lm_parts.extend([task_id, prompt_speech_token[i], speech_token[i]])

                lm_input.append(torch.cat(lm_parts, dim=0))
                emo_input.append(torch.cat(emo_parts, dim=0))
                text_side_lens.append(text_side_len)

            lm_input_len = torch.tensor([seq.size(0) for seq in lm_input], dtype=torch.int32)
            lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
            emo_input = pad_sequence(emo_input, batch_first=True, padding_value=IGNORE_ID)
            return lm_input, lm_input_len, text_side_lens, emo_input

        def pad_unpad_sequence_infer(self, sos_eos_emb, 
                               text_token, text_token_len, 
                               prompt_texts, prompt_text_token_len,
                               task_id_emb, 
                               prompt_speech_token, prompt_speech_token_len, 
                               emos, iter_num,
                               tgt_emos_prompt):
            """Assemble per-sample inference LM inputs (no target speech) plus emotions.

            For sample ``i`` the LM sequence is::

                [sos]
                + [emos[i, j], prompt_texts[i, j, :len]]  for j < iter_num[i]
                + [emos[i, j], text_token[i, j, :len]]    for j < iter_num[i]
                + [task_id] + prompt_speech_token[i]

            The emotion sequence mirrors it position by position.

            NOTE(review): ``prompt_speech_token`` and ``tgt_emos_prompt`` are
            indexed per sample without unpadding, so any padded frames are kept,
            and ``prompt_speech_token_len`` is unused.

            Returns:
                ``(lm_input, lm_input_len, text_side_lens, emo_input)``, with
                the same meaning as in the training-time builder.
            """
            batch_size, _, _, _ = text_token.size()
            sos = sos_eos_emb.squeeze(dim=0)
            task_id = task_id_emb.squeeze(dim=0)
            # Prompt text segments come first, then the target text segments.
            text_segments = (
                (prompt_texts, prompt_text_token_len),
                (text_token, text_token_len),
            )

            lm_input, emo_input, text_side_lens = [], [], []
            for i in range(batch_size):
                lm_parts = [sos]
                emo_parts = [sos]
                text_side_len = 0
                for tokens, token_lens in text_segments:
                    for j in range(iter_num[i]):
                        seg_len = token_lens[i, j]
                        emo_row = emos[i, j:j+1, :]
                        lm_parts.append(torch.cat([emo_row, tokens[i, j, :seg_len, :]], dim=0))
                        # +1 accounts for the emotion embedding prepended to the segment.
                        text_side_len += seg_len + 1
                        emo_parts.append(emo_row.expand(seg_len + 1, -1))
                emo_parts.extend([task_id, tgt_emos_prompt[i]])
                lm_parts.extend([task_id, prompt_speech_token[i]])

                lm_input.append(torch.cat(lm_parts, dim=0))
                emo_input.append(torch.cat(emo_parts, dim=0))
                text_side_lens.append(text_side_len)

            lm_input_len = torch.tensor([seq.size(0) for seq in lm_input], dtype=torch.int32)
            lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
            emo_input = pad_sequence(emo_input, batch_first=True, padding_value=IGNORE_ID)
            return lm_input, lm_input_len, text_side_lens, emo_input

        def tts_forward(self, batch):
            """Run the speech-token LM over one training batch.

            Reads these keys from ``batch``: ``tgt_texts``, ``tgt_text_lens``,
            ``tgt_codecs``, ``tgt_codecs_lens``, ``prompt_texts``,
            ``prompt_text_lens``, ``prompt_codecs``, ``prompt_codecs_lens``,
            ``emos``, ``tgt_emos``, ``iter_num``, ``tgt_emos_prompt`` and
            ``all_bias``.

            Returns:
                ``(lm_output, logits, codec_target, emo_target, lm_input_len,
                emotion_feature)``. ``logits`` come from ``self.llm_decoder``,
                and ``emotion_feature`` is the second value returned by
                ``self.llm``.
            """
            tgt_texts = batch['tgt_texts']
            tgt_text_lens = batch['tgt_text_lens']
            tgt_codecs = batch['tgt_codecs']
            tgt_codecs_lens = batch['tgt_codecs_lens']
            
            prompt_texts = batch['prompt_texts']
            prompt_text_lens = batch['prompt_text_lens']
            prompt_codecs = batch['prompt_codecs']
            prompt_codecs_lens = batch['prompt_codecs_lens']
            
            emos = batch['emos']
            tgt_emos = batch['tgt_emos']
            iter_num = batch['iter_num']
            
            tgt_emos_prompt = batch['tgt_emos_prompt']
            attn_bias = batch['all_bias']

            device = tgt_texts.device
            
            # Text ids are embedded with the Qwen LLM's own token embedding table.
            text = self.llm.model.model.embed_tokens(tgt_texts)
            prompt_texts = self.llm.model.model.embed_tokens(prompt_texts)
            
            # Learned sos/eos and task-id embeddings, shaped (1, 1, D).
            sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
            task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
            
            speech_token = self.speech_embedding(tgt_codecs)
            prompt_speech_token = self.speech_embedding(prompt_codecs)
            emos = self.inneremo_emo_embedding(emos)
            tgt_emos_prompt_emb = self.inneremo_emo_embedding(tgt_emos_prompt)
            tgt_emos_emb = self.inneremo_emo_embedding(tgt_emos)
            
            # Build the per-sample sequences and pad them into a batch.
            lm_input, lm_input_len, text_side_lens, emotion_feature = self.pad_unpad_sequence(
                sos_eos_emb, text, tgt_text_lens, 
                prompt_texts, prompt_text_lens,
                task_id_emb, 
                speech_token, tgt_codecs_lens,
                prompt_speech_token, prompt_codecs_lens, 
                emos, iter_num,
                tgt_emos_prompt_emb, tgt_emos_emb
            )
            
            # Build the targets. Every text-side, task-id and prompt-speech
            # position is ignored; then come the target codec ids, then
            # self.speech_token_size as a final end token. The target list has
            # the same length as the LM input sequence. text_side_lens[i] must
            # support .cpu() here, i.e. be a tensor (true when iter_num[i] > 0
            # and the text lengths are tensors).
            with torch.no_grad():
                codec_target = [
                    torch.tensor(
                        [IGNORE_ID] * (text_side_lens[i] + 1 + prompt_codecs_lens[i]).cpu().item() + 
                        tgt_codecs[i, :tgt_codecs_lens[i]].tolist() + 
                        [self.speech_token_size]
                        ) for i in range(lm_input.size(0)
                    )
                ]
                
                # Same layout for the emotion labels, ignoring the final slot.
                emo_target = [
                    torch.tensor(
                        [IGNORE_ID] * (text_side_lens[i] + 1 + prompt_codecs_lens[i]).cpu().item() + 
                        tgt_emos[i, :tgt_codecs_lens[i]].tolist() + 
                        [IGNORE_ID]
                        ) for i in range(tgt_emos.size(0)
                    )
                ]
                codec_target = pad_sequence(
                    codec_target, batch_first=True, padding_value=IGNORE_ID
                ).to(device)
                emo_target = pad_sequence(
                    emo_target, batch_first=True, padding_value=IGNORE_ID
                ).to(device)
            # The LLM also receives the dynamic attention bias and the emotion
            # features, both cast to the input dtype.
            lm_output, emotion_feature = self.llm(
                lm_input, lm_input_len.to(device), 
                dyn_attn_bias=attn_bias.to(device).to(lm_input.dtype),
                emotion_feature=emotion_feature.to(device).to(lm_input.dtype)
            )
            logits = self.llm_decoder(lm_output) 
            return lm_output, logits, codec_target, emo_target, lm_input_len, emotion_feature
        
        def forward(
                self,
                batch: dict,
        ) -> Dict[str, Optional[torch.Tensor]]:
            """Compute the training losses for one batch.

            Runs ``tts_forward`` for the speech-token LM. It then concatenates
            the LM output with the emotion features, fuses them with
            ``aligner_fusion``, and runs the ``aligner`` stack to predict one
            of 5 emotion labels per position.

            Returns:
                A dict with ``codec_loss`` and ``codec_acc`` (speech-token LM),
                ``emo_loss`` and ``emo_acc`` (aligner emotion head), and
                ``loss``, the unweighted sum of the two losses.
            """
            tgt_texts = batch['tgt_texts']
            device = tgt_texts.device
            
            lm_output, logits, codec_target, emo_target, lm_input_len, emotion_feature = self.tts_forward(batch)
                
            # logits: B, T, n_class 
            # codec_target: B, T (padded with IGNORE_ID)
            codec_loss = self.criterion_ce(logits, codec_target)
            with torch.no_grad():
                codec_acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), codec_target, ignore_label=IGNORE_ID)
        
            # Emotion aligner over the LM output.
            T = lm_output.size(1)
            # Presumably True on valid (non-padded) positions, assuming
            # make_pad_mask marks padding with True -- TODO confirm.
            masks = ~make_pad_mask(lm_input_len.to(device), T)
            
            inp = self.aligner_fusion(torch.cat([lm_output,emotion_feature], dim=-1))
            # NOTE(review): assumes the aligner returns a 2-item sequence whose
            # first item exposes .hidden_states -- confirm against Qwen2ForCausalLM.
            lm_output, _ = self.aligner(
                inputs_embeds=inp, 
                attention_mask=masks,
                return_dict=True,
                output_hidden_states=True,
            )
            emo_logits = self.aligner_decoder(lm_output.hidden_states[-1])
            emo_loss = self.aligner_criterion_ce(emo_logits, emo_target)
            with torch.no_grad():
                emo_acc = th_accuracy(emo_logits.view(-1, 5), emo_target, ignore_label=IGNORE_ID)
            return {
                'codec_loss': codec_loss, 
                'codec_acc': codec_acc,
                'emo_loss': emo_loss,
                "emo_acc": emo_acc,
                "loss": emo_loss + codec_loss,
            }

        def sampling_ids(
                self,
                weighted_scores: torch.Tensor,
                decoded_tokens: List,
                sampling: int,
                ignore_eos: bool = True,
        ):
            while True:
                top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
                if (not ignore_eos) or (self.speech_token_size not in top_ids):
                    break
            return top_ids
        
        def sampling_tokens(
                self,
                weighted_scores: torch.Tensor,
                decoded_tokens: List,
                sampling: int,
                ignore_eos: bool = True,
        ):
            while True:
                top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
                if (not ignore_eos) or (self.speech_token_size not in top_ids):
                    break
            return top_ids

        def repetition_penalty(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, penalty=1.2):
            """Penalise the scores of ids that were already generated.

            For each id in ``input_ids`` (gathered along dim 1 of ``scores``),
            a non-negative score is divided by ``penalty`` and a negative score
            is multiplied by it. Either way the token becomes less likely.

            Returns:
                A new tensor; ``scores`` itself is left unmodified.
            """
            picked = scores.gather(1, input_ids)
            # Dividing a negative logit would *raise* it, hence the sign split.
            penalized = torch.where(picked < 0, picked * penalty, picked / penalty)
            return scores.scatter(1, input_ids, penalized)
        
        @torch.inference_mode()
        def inference(
                self,
                batch,
                sampling: int = 25,
                max_token_text_ratio: float = 20,
                min_token_text_ratio: float = 2,
                a_h_n_sa_su_text_lens=None,
                a_h_n_sa_su_st_lens=None,
                text_control=(-1, -1),
                last_item=False,
                lang=""
        ):
            """Decode speech tokens step by step, rebuilding the attention bias after each token.

            Each step does the following:
            - runs ``self.llm.forward_one_step`` and samples a token from ``self.llm_decoder``;
            - predicts an emotion id for the emitted token (argmax of ``self.aligner_decoder``
              on top of ``self.aligner``);
            - recomputes the dynamic attention bias with ``get_attn_bias`` from the updated
              speech-side segment counts.

            Args:
                batch: mapping of prepared inputs (see the ``batch[...]`` reads below).
                sampling: forwarded to ``self.sampling_ids``.
                max_token_text_ratio: decoding runs for at most
                    ``sum(tgt_text_lens) * max_token_text_ratio`` steps.
                min_token_text_ratio: currently unused. The minimum length is fixed at 0,
                    so EOS is accepted from the first step.
                a_h_n_sa_su_text_lens, a_h_n_sa_su_st_lens, text_control: forwarded
                    unchanged to ``self.llm.forward_one_step``.
                last_item, lang: currently unused.

            Returns:
                list[int]: decoded token ids. The EOS id (``self.speech_token_size``) and any
                id above it are not included.
            """
            tgt_texts = batch['tgt_texts']
            tgt_text_lens = batch['tgt_text_lens']

            prompt_texts = batch['prompt_texts']
            prompt_text_lens = batch['prompt_text_lens']
            prompt_codecs = batch['prompt_codecs']
            prompt_codecs_lens = batch['prompt_codecs_lens']

            emos = batch['emos']
            tgt_emos = batch['tgt_emos']  # NOTE(review): read but not used below
            iter_num = batch['iter_num']

            tgt_emos_prompt = batch['tgt_emos_prompt']
            attn_bias = batch['all_bias']

            # 1. embed the text ids with the LLM's own token embedding table
            text = self.llm.model.model.embed_tokens(tgt_texts)
            prompt_texts = self.llm.model.model.embed_tokens(prompt_texts)
            device = tgt_texts.device

            # 2. special-token, prompt-speech and emotion embeddings, concatenated into the LLM input
            sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
            task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
            prompt_speech_token = self.speech_embedding(prompt_codecs)
            emos = self.inneremo_emo_embedding(emos)
            tgt_emos_prompt_emb = self.inneremo_emo_embedding(tgt_emos_prompt)

            lm_input, _lm_input_len, _text_side_lens, emotion_feature = self.pad_unpad_sequence_infer(
                sos_eos_emb, text, tgt_text_lens,
                prompt_texts, prompt_text_lens,
                task_id_emb,
                prompt_speech_token, prompt_codecs_lens,
                emos, iter_num,
                tgt_emos_prompt_emb
            )

            # 3. decoding length bounds
            min_len = 0
            max_len = int(torch.sum(tgt_text_lens).item() * max_token_text_ratio)

            # 4. step-by-step decode
            out_tokens = []
            speech_emos = []
            cache = None
            cache_align = None

            # Segment bookkeeping used to rebuild the attention bias every step.
            text_side_prompt4mask = batch["text_side_prompt"]
            text_side_tgt4mask = batch["text_side_tgt"]
            speech_side_prompt4mask = batch["speech_side_prompt"]
            speech_side_speech4mask = batch["speech_side_speech"]
            emo_state = {"cur_idx": -1, "prev": None}
            attention_bias = "0_1_2_3_4_5_6"
            attention_range = [0.1, 5.0]
            atten_bias_func = {
                "0": lower_triangle,
                "1": tgt_st_paired_emo_and_all_tt,
                "2": all_paired_emo,
                "3": all_st_paired_emo,
                "4": tgt_st_paired_emo,
                "5": st_paired_emo_tgt_st_all_tt,
                "6": all_st_paired_emo_and_all_tt,
            }
            for i in range(max_len):
                # Causal mask over the current input chunk: the whole prefix on the first
                # step, one token afterwards (earlier positions are presumably served from `cache`).
                temp_mask = torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool)
                y_pred, cache, emotion_feature = self.llm.forward_one_step(
                    lm_input,
                    masks=temp_mask,
                    cache=cache,
                    dyn_attn_bias=attn_bias.to(device).to(lm_input.dtype),
                    emotion_feature=emotion_feature.to(device).to(lm_input.dtype),
                    a_h_n_sa_su_text_lens=a_h_n_sa_su_text_lens,
                    a_h_n_sa_su_st_lens=a_h_n_sa_su_st_lens,
                    text_control=text_control
                )
                logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=i < min_len).item()
                if top_ids == self.speech_token_size:
                    break  # EOS
                if top_ids > self.speech_token_size:
                    continue  # ids above the EOS id are skipped, not emitted
                out_tokens.append(top_ids)
                lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)

                # Predict the emotion id of the token just emitted.
                inp = self.aligner_fusion(torch.cat([y_pred, emotion_feature], dim=-1))
                lm_output, cache_align, _ = self.aligner.forward_one_step(
                    inp,
                    masks=temp_mask,
                    cache=cache_align,
                )
                emo_logits = self.aligner_decoder(lm_output[:, -1])
                emo_flag = torch.argmax(emo_logits, dim=-1)
                speech_emos.append(emo_flag.item())
                # NOTE(review): the whole, growing `speech_emos` list is passed as the item and
                # update_res_by_sequence stores it in emo_state["prev"] by reference. From the
                # second step on, `item == prev` compares the list with itself and is always True,
                # so every token is counted into the first speech segment. Passing
                # `speech_emos[-1]` looks like the intent, but it can raise IndexError when more
                # emotion changes are predicted than there are segments. Confirm before changing.
                speech_side_speech4mask = update_res_by_sequence(speech_side_speech4mask, speech_emos, emo_state)
                emotion_feature = self.inneremo_emo_embedding(emo_flag).reshape(1, 1, -1)
                bias_dict, _, _, _, _ = get_attn_bias(
                    batch["prompt_text_lens4ab"],
                    batch["text_lens4ab"],
                    batch["prompt_codec_lens4ab"],
                    batch["tgt_codec_lens4ab"],
                    attention_bias=attention_bias,
                    atten_bias_func=atten_bias_func,
                    attention_range=attention_range,
                    text_side_prompt=text_side_prompt4mask,
                    text_side_tgt=text_side_tgt4mask,
                    speech_side_prompt=speech_side_prompt4mask,
                    speech_side_speech=speech_side_speech4mask
                )
                # Keep only the newest query row of each of the 7 bias maps ("0".."6").
                attn_bias = torch.stack([bias_dict[str(k)][-1:, :] for k in range(7)], 0).unsqueeze(0)
            return out_tokens
        
        def unfreeze_partial(self):
            """Freeze every parameter except those of the LLM encoder layers in ``cfg.partial_layers``.

            ``self.cfg.partial_layers`` is a ``-``-separated list of layer indices
            (e.g. ``"0-3-5"``). A parameter stays trainable when its name contains
            ``llm.encoders.<idx>.`` for one of those indices. Every other parameter gets
            ``requires_grad = False``. An empty string leaves all parameters untouched.
            The layer list, every parameter name and the trainable names are printed.
            """
            layer_spec = self.cfg.partial_layers
            if layer_spec == "":
                return
            train_layers = list(map(int, layer_spec.split("-")))
            print(train_layers)
            # The trailing dot keeps e.g. layer 1 from matching layer 10.
            prefixes = ["llm.encoders.%d." % idx for idx in train_layers]
            kept = []
            for pname, param in self.named_parameters():
                print(pname)
                if any(prefix in pname for prefix in prefixes):
                    kept.append(pname)
                else:
                    param.requires_grad = False
            print("trainable: " + str(kept))

        def unfreeze_lora(self):
            """Mark only LoRA parameters trainable via loralib (``bias='lora_only'``), then print parameter counts."""
            from loralib import mark_only_lora_as_trainable

            mark_only_lora_as_trainable(self, bias='lora_only')
            count_parameters(self)
            
        def init_infer_modules(self,
                               device,
                               model_dir="",
                               instruct=True,
                               fm_model=None,
                               text_frontend=None):
            """Load the inference front-end, flow model and HiFT vocoder, and set streaming constants.

            Args:
                device: device the front-end, flow model and vocoder are placed on.
                model_dir: directory with ``cosyvoice.yaml``, ``campplus.onnx``,
                    ``speech_tokenizer_v2.onnx``, ``spk2info.pt``, ``hift.pt`` and, unless
                    *fm_model* is given, ``flow.pt``. When empty, ``$COSYVOICE2HOME`` is
                    used, falling back to ``./pretrained_models/CosyVoice2-0.5B/``.
                instruct: forwarded to ``CosyVoiceFrontEnd``.
                fm_model: optional flow model used instead of the one built from the yaml
                    config. Its decoder is accessed as ``causal_masked_diff.decoder``.
                text_frontend: forwarded to ``CosyVoiceFrontEnd``.
            """
            import threading  # needed for self.lock below

            # An explicitly passed directory wins; before, the argument was always overwritten.
            if not model_dir:
                model_dir = os.environ.get("COSYVOICE2HOME", "./pretrained_models/CosyVoice2-0.5B/")
            with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
                configs = load_hyperpyyaml(f)
            self.device = device
            self.frontend = CosyVoiceFrontEnd('{}/campplus.onnx'.format(model_dir),
                                              '{}/speech_tokenizer_v2.onnx'.format(model_dir),
                                              '{}/spk2info.pt'.format(model_dir),
                                              instruct,
                                              configs['allowed_special'],
                                              text_frontend=text_frontend,
                                              device=device)
            if fm_model is not None:
                self.flow = fm_model
                self.flow.to(device).eval()
                self.our_fm = True
            else:
                self.flow = configs["flow"]
                count_parameters(self)
                # NOTE: torch.load unpickles the checkpoint; only point model_dir at trusted files.
                flow_state_dict = {k.replace('generator.', ''): v for k, v in torch.load('{}/flow.pt'.format(model_dir), map_location=self.device).items()}
                self.flow.load_state_dict(flow_state_dict, strict=True)
                self.flow.to(device).eval()
                self.our_fm = False

            self.hift = configs["hift"]
            self.hift.load_state_dict(torch.load('{}/hift.pt'.format(model_dir), map_location=device), strict=True)
            self.hift.to(device).eval()
            self.eval()
            self.model_dir = model_dir
            self.lock = threading.Lock()
            del configs

            # token hop lengths, expressed in multiples of the flow model's token frame rate
            self.token_min_hop_len = 2 * self.flow.input_frame_rate
            self.token_max_hop_len = 4 * self.flow.input_frame_rate
            self.token_overlap_len = 20
            # here we fix set flow.decoder.estimator.static_chunk_size = 0 for compatibility
            if self.our_fm:
                self.flow.causal_masked_diff.decoder.estimator.static_chunk_size = 0
                self.flow.causal_masked_diff.decoder.fp16 = False
            else:
                self.flow.decoder.estimator.static_chunk_size = 0
                self.flow.decoder.fp16 = False
            # mel fade in/out (22050 and 256 are presumably the sample rate and vocoder hop size)
            self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
            self.mel_window = np.hamming(2 * self.mel_overlap_len)
            # hift cache
            self.mel_cache_len = 20
            self.source_cache_len = int(self.mel_cache_len * 256)
            # speech fade in/out
            self.speech_window = np.hamming(2 * self.source_cache_len)
            # rtf and decoding related
            self.stream_scale_factor = 1
            assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
            # per-session state, keyed by session uuid
            self.tts_speech_token_dict = {}
            self.llm_end_dict = {}
            self.mel_overlap_dict = {}
            self.flow_cache_dict = {}
            self.hift_cache_dict = {}
            self.sampling = ras_sampling
            
        # def inference_st(self, 
        #                 batch,
        #                 **kwargs):
        #     this_tts_speech_token, last_text, last_st, tp_tokens = self.inference(batch)
        #     return this_tts_speech_token, last_text, last_st, tp_tokens
            
        def generate_speech(self,
                            this_tts_speech_token,
                            flow_prompt_speech_token,
                            prompt_speech_feat,
                            flow_embedding,
                            speed=1.0,
                            **kwargs):
            """Vocode a finished sequence of speech tokens into a waveform.

            Args:
                this_tts_speech_token: speech token ids. They are wrapped with
                    ``torch.tensor`` and given a leading batch dimension of 1.
                flow_prompt_speech_token, prompt_speech_feat, flow_embedding: prompt
                    conditioning forwarded to ``token2wav``.
                speed: speed factor forwarded to ``token2wav``.
                **kwargs: accepted and ignored.

            Returns:
                The waveform returned by ``token2wav``, moved to CPU.
            """
            batched_tokens = torch.tensor(this_tts_speech_token).unsqueeze(0)
            waveform = self.token2wav(
                token=batched_tokens,
                prompt_token=flow_prompt_speech_token,
                prompt_feat=prompt_speech_feat,
                embedding=flow_embedding,
                token_offset=0,
                finalize=True,
                speed=speed,
            )
            return waveform.cpu()
        
        def inference_whole(self,
                            text,
                            flow_embedding,
                            llm_embedding=torch.zeros(0, 192),
                            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
                            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
                            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
                            prompt_speech_feat=torch.zeros(1, 0, 80), speed=1.0,
                            prompt_sp=1,
                            return_speech=False,
                            **kwargs):
            """Generate speech tokens for *text* in one call and optionally vocode them.

            NOTE(review): this calls ``self.inference`` with ``text=`` / ``text_len=`` /
            ``prompt_*=`` / ``embedding=`` keywords and unpacks
            ``(tokens, (last_text, last_st))``. The ``inference`` method of this class
            takes a ``batch`` mapping and returns a plain token list. Confirm which
            ``inference`` implementation this is meant to target.

            Args:
                text: text token ids; dim 1 is used as the text length.
                flow_embedding: speaker embedding for the flow model (used only when
                    *return_speech* is True).
                llm_embedding: speaker embedding for the LLM.
                prompt_text, llm_prompt_speech_token: prompt inputs for the LLM; dim 1
                    gives their lengths.
                flow_prompt_speech_token, prompt_speech_feat: prompt inputs for ``token2wav``.
                speed: speed factor for ``token2wav``.
                prompt_sp: when not 1.0, the LLM prompt speech tokens are resampled by this
                    factor with ``resample_by_stride`` first.
                return_speech: also vocode the tokens with ``token2wav``.

            Returns:
                ``(tokens, None)`` when *return_speech* is False, with ``tokens`` as produced
                by ``self.inference``. Otherwise ``(token_tensor, speech)``, where
                ``token_tensor`` has a leading batch dim of 1 and ``speech`` is on CPU.
            """
            if prompt_sp != 1.0:
                llm_prompt_speech_token = resample_by_stride(llm_prompt_speech_token[0], prompt_sp).unsqueeze(0)

            dev = self.device

            def _len_of(seq):
                # (1,)-shaped int32 tensor holding dim 1 of *seq*.
                return torch.tensor([seq.shape[1]], dtype=torch.int32).to(dev)

            tokens, (_last_text, _last_st) = self.inference(
                text=text.to(dev),
                text_len=_len_of(text),
                prompt_text=prompt_text.to(dev),
                prompt_text_len=_len_of(prompt_text),
                prompt_speech_token=llm_prompt_speech_token.to(dev),
                prompt_speech_token_len=_len_of(llm_prompt_speech_token),
                embedding=llm_embedding.to(dev),
            )
            if not return_speech:
                return tokens, None

            token_batch = torch.tensor(tokens).unsqueeze(dim=0)
            speech = self.token2wav(
                token=token_batch,
                prompt_token=flow_prompt_speech_token,
                prompt_feat=prompt_speech_feat,
                embedding=flow_embedding,
                token_offset=0,
                finalize=True,
                speed=speed,
            ).cpu()
            return token_batch, speech
        
        def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, finalize=False, speed=1.0):
            """Convert speech tokens to a waveform: flow model -> mel -> HiFT vocoder.

            Args:
                token: speech tokens; dim 1 is the token count.
                prompt_token, prompt_feat, embedding: prompt conditioning for ``self.flow``;
                    dim 1 of the prompt tensors gives their lengths.
                token_offset: the first ``token_offset * self.flow.token_mel_ratio`` mel
                    frames are dropped before vocoding.
                finalize: forwarded to ``self.flow.inference``.
                speed: when not 1.0, the mel is linearly resampled to
                    ``int(frames / speed)`` frames.

            Returns:
                The waveform produced by ``self.hift.inference``.
            """
            dev = self.device

            def _len_tensor(x):
                # (1,)-shaped int32 tensor holding dim 1 of *x*.
                return torch.tensor([x.shape[1]], dtype=torch.int32).to(dev)

            mel, _ = self.flow.inference(
                token=token.to(dev),
                token_len=_len_tensor(token),
                prompt_token=prompt_token.to(dev),
                prompt_token_len=_len_tensor(prompt_token),
                prompt_feat=prompt_feat.to(dev),
                prompt_feat_len=_len_tensor(prompt_feat),
                embedding=embedding.to(dev),
                finalize=finalize,
            )
            mel = mel[:, :, token_offset * self.flow.token_mel_ratio:]
            if speed != 1.0:
                mel = F.interpolate(mel, size=int(mel.shape[2] / speed), mode='linear')
            # Non-streaming: the vocoder always starts from an empty source cache.
            speech, _ = self.hift.inference(speech_feat=mel, cache_source=torch.zeros(1, 1, 0))
            return speech
                
        def inference_preset(self, input_list, speed=1.0, **kwargs):
            """Run the LLM over each item of *input_list*, then vocode all tokens in one shot.

            For every item, ``self.llm_job`` runs (in a thread that is joined immediately)
            and appends its tokens to a shared per-session list. The accumulated tokens
            are then vocoded with ``token2wav``. The flow prompt / embedding of the LAST
            item are used for that step.

            Args:
                input_list: non-empty sequence of dicts with keys ``text``, ``prompt_text``,
                    ``llm_prompt_speech_token``, ``speed``, ``llm_embedding``,
                    ``flow_prompt_speech_token``, ``prompt_speech_feat`` and ``flow_embedding``.
                    When ``item["speed"] != 1``, the LLM prompt tokens are resampled by it.
                speed: speed factor for ``token2wav``.
                **kwargs: accepted and ignored.

            Yields:
                dict: a single ``{'tts_speech': Tensor}`` with the waveform on CPU.

            Raises:
                ValueError: if *input_list* is empty (there is no flow prompt to use).
            """
            import threading
            import uuid

            if not input_list:
                raise ValueError("input_list must contain at least one item")
            this_uuid = str(uuid.uuid1())
            self.tts_speech_token_dict[this_uuid] = []
            try:
                for item in input_list:
                    text = item["text"]
                    prompt_text = item["prompt_text"]
                    llm_prompt_speech_token = item["llm_prompt_speech_token"]
                    if item["speed"] != 1:
                        llm_prompt_speech_token = resample_by_stride(llm_prompt_speech_token[0], item["speed"]).unsqueeze(0)
                    llm_embedding = item["llm_embedding"]
                    flow_prompt_speech_token = item["flow_prompt_speech_token"]
                    prompt_speech_feat = item["prompt_speech_feat"]
                    flow_embedding = item["flow_embedding"]
                    with self.lock:
                        self.llm_end_dict[this_uuid] = False
                        self.hift_cache_dict[this_uuid] = None
                    p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid,))
                    p.start()
                    p.join()
                    print(len(self.tts_speech_token_dict[this_uuid]))
                this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
                # token2wav has no `uuid` parameter; passing one raised TypeError on every call.
                this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                 prompt_token=flow_prompt_speech_token,
                                                 prompt_feat=prompt_speech_feat,
                                                 embedding=flow_embedding,
                                                 token_offset=0,
                                                 finalize=True,
                                                 speed=speed)
                yield {'tts_speech': this_tts_speech.cpu()}
            finally:
                # Always drop this session's entries, including on error or an early close().
                with self.lock:
                    self.tts_speech_token_dict.pop(this_uuid, None)
                    self.llm_end_dict.pop(this_uuid, None)
                    self.hift_cache_dict.pop(this_uuid, None)

except:
    pass
        
def count_parameters(model):
    """Print the LoRA, first-encoder-layer and total parameter counts of *model*.

    A parameter counts as LoRA when ``'lora'`` appears in its name. The "1-layer"
    figure sums the non-LoRA parameters whose name contains ``'llm.encoders.0'``.
    """
    total = lora = first_layer = 0
    for pname, tensor in model.named_parameters():
        n = tensor.numel()
        total += n
        if 'lora' in pname:
            lora += n
        elif 'llm.encoders.0' in pname:
            first_layer += n
    print(f"parameter number: lora-{lora} / 1-layer-{first_layer} / total-{total}")
        
def resample_by_stride(tensor, scale):
    """Resample a sequence along its first dimension by nearest-index lookup.

    NOTE: this module defines ``resample_by_stride`` twice. The later definition is
    the one bound at import time, so this body is kept identical to it.

    The output has ``max(1, int(len(tensor) * float(scale)))`` elements, taken at
    evenly spaced (rounded) positions from the first to the last element. So
    ``scale < 1`` drops elements and ``scale > 1`` repeats them. An empty input is
    returned as an empty tensor.

    Args:
        tensor: sequence indexed along its first dimension.
        scale: length multiplier; anything ``float()`` accepts (e.g. ``1.2`` or ``"1.2"``).

    Returns:
        The resampled elements.
    """
    length = len(tensor)
    if length == 0:
        # linspace(0, -1, 1) would yield index 0, which does not exist in an empty tensor.
        return tensor[:0]
    new_length = max(1, int(length * float(scale)))
    indices = torch.linspace(0, length - 1, new_length).round().long()  # evenly spaced source indices
    return tensor[indices]

# def masked_cross_entropy_loss(logits: torch.Tensor, lm_target: torch.Tensor, tgt_label: int) -> torch.Tensor:
#     """
#     Compute cross-entropy loss only at positions where lm_target == tgt_label.

#     Args:
#         logits (Tensor): Logits of shape (B, T, n_class), raw model outputs.
#         lm_target (Tensor): Target labels of shape (B, T).
#         tgt_label (int): The specific label to compute loss on.

#     Returns:
#         torch.Tensor: Scalar loss, computed only on tgt_label positions.
#     """
#     # 获取 tgt_label 的 mask
#     mask = (lm_target == tgt_label)  # Shape: (B, T)

#     # 计算交叉熵损失（不使用 reduction='mean' 以便手动处理）
#     loss = F.cross_entropy(logits.view(-1, logits.size(-1)), lm_target.view(-1), reduction='none')
#     loss = loss.view(lm_target.shape)  # 重新 reshape 回 (B, T)

#     # 仅在 mask 位置计算损失
#     masked_loss = loss * mask  # 仅保留 tgt_label 的损失

#     # 计算加权平均，避免除零
#     return masked_loss.sum() / (mask.sum() + 1e-8)  # 避免 NaN 产生

import torch
import torch.nn.functional as F

def masked_cross_entropy_loss(logits: torch.Tensor, lm_target: torch.Tensor, tgt_label: int) -> torch.Tensor:
    """Cross-entropy averaged over only the positions whose target equals *tgt_label*.

    Only the selected rows of *logits* are passed to ``F.cross_entropy``, so no
    work is spent on the other positions.

    Args:
        logits: raw scores; indexing it with the selected positions must give one row
            of class scores per position (e.g. (B, T, n_class) with *lm_target* (B, T)).
        lm_target: integer class targets.
        tgt_label: the label whose positions contribute to the loss.

    Returns:
        Scalar mean loss over the selected positions. When no position matches, a
        fresh ``0.0`` tensor on ``logits.device`` with ``requires_grad=True``.
    """
    hits = torch.nonzero(lm_target == tgt_label, as_tuple=True)
    if hits[0].numel() == 0:
        # A mean over an empty selection would be NaN; return a differentiable zero so
        # callers can still call backward().
        return torch.tensor(0.0, device=logits.device, requires_grad=True)
    return F.cross_entropy(logits[hits], lm_target[hits], reduction='mean')

def resample_by_stride(tensor, scale):
    """Resample a sequence along its first dimension by nearest-index lookup.

    The output has ``max(1, int(len(tensor) * float(scale)))`` elements, taken at
    evenly spaced (rounded) positions from the first to the last element. So
    ``scale < 1`` drops elements and ``scale > 1`` repeats them. An empty input is
    returned as an empty tensor.

    Args:
        tensor: sequence indexed along its first dimension.
        scale: length multiplier; anything ``float()`` accepts (e.g. ``1.2`` or ``"1.2"``).

    Returns:
        The resampled elements.
    """
    length = len(tensor)
    if length == 0:
        # linspace(0, -1, 1) would yield index 0, which does not exist in an empty tensor.
        return tensor[:0]
    new_length = max(1, int(length * float(scale)))
    indices = torch.linspace(0, length - 1, new_length).round().long()  # evenly spaced source indices
    return tensor[indices]


def make_encoder_attention_mask(lengths: torch.Tensor, max_len: int = 0, as_float: bool = False) -> torch.Tensor:
    """Build a key-padding attention mask for ``torch.nn.functional.scaled_dot_product_attention``.

    By default the mask is boolean: ``True`` at valid (attendable) key positions and
    ``False`` at padded ones, which is SDPA's convention for boolean masks. With
    ``as_float=True`` an additive float mask is returned instead: ``0.0`` at valid
    keys and ``-65500.0`` at padded keys. That is a large finite negative that fp16
    can still represent.

    Args:
        lengths: 1-D tensor of shape [B] with the valid length of each sequence.
        max_len: padded length T. If <= 0 (default), ``lengths.max()`` is used
            (0 for an empty batch).
        as_float: return the additive float mask instead of the boolean one.

    Returns:
        torch.Tensor: mask of shape [B, 1, T, T]. Every query row of a sequence
        carries the same key mask.
    """
    neg_value = -65500.0
    batch_size = lengths.size(0)
    if max_len <= 0:
        max_len = int(lengths.max().item()) if lengths.numel() > 0 else 0

    # [B, T]: True where the key position lies past the sequence length (padding).
    positions = torch.arange(max_len, device=lengths.device)
    pad_mask = positions.unsqueeze(0) >= lengths.unsqueeze(-1)

    # [B, 1, T, T]: the same key mask repeated for every query row.
    pad_mask = pad_mask[:, None, None, :].expand(batch_size, 1, max_len, max_len)

    if as_float:
        # Fill on a float tensor; a bool tensor cannot hold the negative value.
        return torch.zeros(pad_mask.shape, device=lengths.device).masked_fill(pad_mask, neg_value)
    return ~pad_mask


def extract_with_target_limit(tensor, target=151663, max_sil=10, type_num=1, max_len=30):
    """Cut the tail window of a 1-D token sequence that ends at its last non-*target* token.

    The window is built as follows:
    - Walk backwards from the last token that is not *target* until *type_num*
      distinct non-target values have been seen.
    - Also include the run of identical tokens directly before the last new value.
      If fewer than *type_num* distinct values exist, the window starts at the first
      non-target token.
    - *target* tokens inside the window are kept.
    - Up to *max_sil* consecutive *target* tokens after the last non-target token are
      appended.
    - If the window is longer than *max_len*, its start is moved forward so it is
      exactly *max_len* long. The clipped window may then begin with *target*.

    Returns:
        ``(window, start, end)`` with ``window == tensor[start:end]``, or ``None``
        when every token equals *target*.
    """
    content_pos = torch.nonzero(tensor != target, as_tuple=False).squeeze(-1)
    if content_pos.numel() == 0:
        return None  # nothing but target tokens

    last = content_pos[-1].item()

    seen = set()
    start = None
    cursor = last
    while cursor >= 0:
        tok = tensor[cursor].item()
        if tok != target:
            seen.add(tok)
            start = cursor  # walking right-to-left, so this is the smallest position so far
            if len(seen) == type_num:
                # Pull in the run of identical tokens directly before this one.
                while cursor > 0 and tensor[cursor - 1].item() == tok:
                    cursor -= 1
                start = cursor
                break
        cursor -= 1

    if start is None:
        return None

    stop = last + 1
    trailing = 0
    while stop < len(tensor) and tensor[stop].item() == target and trailing < max_sil:
        stop += 1
        trailing += 1

    start = max(start, stop - max_len)
    return tensor[start:stop], start, stop

def extract_with_targetset_limit(tensor, target_list=(151663, 151664), max_sil=10, type_num=1, max_len=30):
    """Cut the tail window of a 1-D token sequence; any id in *target_list* counts as a target.

    The window is built as follows:
    - Walk backwards from the last non-target token until *type_num* distinct
      non-target values have been seen.
    - Also include the run of identical tokens directly before the last new value.
    - Target tokens inside the window are kept.
    - Up to *max_sil* consecutive target tokens after the last non-target token are
      appended.
    - If the window is longer than *max_len*, its start is moved forward so it is
      exactly *max_len* long. The clipped window may then begin with a target.

    Args:
        tensor: 1-D token tensor.
        target_list: iterable of target ids, or a single int id. The default is a
            tuple so it cannot be mutated between calls.
        max_sil: maximum number of trailing targets appended.
        type_num: number of distinct non-target values to collect.
        max_len: maximum window length.

    Returns:
        ``(window, start, end)`` with ``window == tensor[start:end]``, or ``None``
        when every token is a target.
    """
    if isinstance(target_list, int):
        target_list = (target_list,)
    target_list = list(target_list)
    target_set = set(target_list)

    # Step 1: locate the last non-target token
    mask = ~torch.isin(tensor, torch.tensor(target_list, device=tensor.device))
    non_target_indices = torch.nonzero(mask, as_tuple=False).squeeze(-1)
    if non_target_indices.numel() == 0:
        return None  # all targets

    last_non_target_idx = non_target_indices[-1].item()

    # Step 2: walk back collecting type_num distinct non-target values
    unique_tokens = set()
    non_target_pos = []
    idx = last_non_target_idx
    while idx >= 0:
        token = tensor[idx].item()
        if token not in target_set:
            unique_tokens.add(token)
            non_target_pos.append(idx)
            if len(unique_tokens) == type_num:
                # include the run of identical tokens directly before this one
                while idx > 0 and tensor[idx - 1].item() == token:
                    idx -= 1
                    non_target_pos.append(idx)
                break
        idx -= 1

    if not non_target_pos:
        return None

    start_idx = min(non_target_pos)

    # Step 3: append up to max_sil trailing targets
    target_count = 0
    end_idx = last_non_target_idx + 1
    while end_idx < len(tensor) and tensor[end_idx].item() in target_set and target_count < max_sil:
        end_idx += 1
        target_count += 1

    # Step 4: enforce max_len
    if end_idx - start_idx > max_len:
        start_idx = end_idx - max_len

    return tensor[start_idx:end_idx], start_idx, end_idx


def get_attn_bias(text_side_prompt_lens,
                  text_side_tgt_lens,
                  speech_side_prompt_lens,
                  speech_side_speech_lens,
                  attention_bias=1, atten_bias_func=5, attention_range=1,
                  text_side_prompt=None, text_side_tgt=None, speech_side_prompt=None, speech_side_speech=None):
    """Build the attention-bias tensors named in *attention_bias*.

    When *text_side_prompt* is None, the four segment lists are built from the length
    arguments, one ``(name, length)`` pair per segment named ``t_0{i}`` / ``t_1{i}`` /
    ``s_0{i}`` / ``s_1{i}``. The length adjustments are:
    - text-prompt lengths get +1, plus +1 more for segment 0 (sos);
    - target-text lengths get +1 (emotion);
    - the speech-prompt length of segment 0 gets +1 (bos).
    Otherwise the four given segment lists are used as they are.

    Args:
        text_side_prompt_lens, text_side_tgt_lens, speech_side_prompt_lens,
        speech_side_speech_lens: per-segment lengths. Only read when
            *text_side_prompt* is None.
        attention_bias: ``_``-separated keys into *atten_bias_func*, e.g. ``"0_1_2"``.
        atten_bias_func: mapping key -> callable taking the four segment lists plus
            ``top=`` and ``low=`` keywords.
        attention_range: element 0 is passed as ``low``, element 1 as ``top``.
        text_side_prompt, text_side_tgt, speech_side_prompt, speech_side_speech:
            optional pre-built segment lists.

    NOTE(review): the integer defaults of attention_bias / atten_bias_func /
    attention_range cannot work (``int.split`` and int indexing fail), so callers
    must pass all three explicitly.

    Returns:
        ``(bias_dict, text_side_prompt, text_side_tgt, speech_side_prompt,
        speech_side_speech)``. ``bias_dict`` maps each key to
        ``torch.tensor(atten_bias_func[key](...))``.
    """
    if text_side_prompt is None:
        text_side_prompt, text_side_tgt, speech_side_prompt, speech_side_speech = [], [], [], []
        per_segment = zip(text_side_prompt_lens, text_side_tgt_lens,
                          speech_side_prompt_lens, speech_side_speech_lens)
        for seg, (tp_len, tt_len, sp_len, ss_len) in enumerate(per_segment):
            if seg == 0:
                # Only the first segment carries the leading sos (text) and bos (speech).
                tp_len = tp_len + 1
                sp_len = sp_len + 1
            text_side_prompt.append((f"t_0{seg}", tp_len + 1))
            text_side_tgt.append((f"t_1{seg}", tt_len + 1))  # +1 for emotion
            speech_side_prompt.append((f"s_0{seg}", sp_len))
            speech_side_speech.append((f"s_1{seg}", ss_len))

    bias_by_key = {
        key: torch.tensor(atten_bias_func[key](
            text_side_prompt,
            text_side_tgt,
            speech_side_prompt,
            speech_side_speech,
            top=float(attention_range[1]),
            low=float(attention_range[0]),
        ))
        for key in attention_bias.split("_")
    }
    return bias_by_key, text_side_prompt, text_side_tgt, speech_side_prompt, speech_side_speech
            
def update_res_by_sequence(
    prev_res,
    item,
    state: dict
):
    """Count one more occurrence of *item* into a list of ``(name, count)`` segments.

    If *item* equals the previous item (``state["prev"]``), the current segment
    (``state["cur_idx"]``) is incremented. Otherwise the next segment becomes current
    and is incremented.

    *prev_res* is not modified; a new list of tuples is returned. *state*, however,
    IS updated in place (``cur_idx`` and ``prev``). *item* is stored by reference,
    so passing a list the caller later mutates makes the next comparison see the
    mutated list.

    Args:
        prev_res: sequence of ``(name, count)`` pairs, e.g. ``[('a', 1), ('b', 2)]``.
        item: the element being processed.
        state: ``{'cur_idx': int, 'prev': Any}``; missing keys default to -1 / None.

    Returns:
        list[tuple]: the updated ``(name, count)`` pairs.

    Raises:
        IndexError: when a change of item would advance past the last segment.
    """
    segments = [list(pair) for pair in prev_res]  # fresh, mutable copies
    idx = state.get("cur_idx", -1)
    last_item = state.get("prev", None)

    if not (item == last_item):
        idx += 1
        if idx >= len(segments):
            raise IndexError(f"cur_idx={idx} 超出 res 长度={len(segments)}")
    segments[idx][1] += 1

    state["cur_idx"] = idx
    state["prev"] = item
    return [(seg[0], seg[1]) for seg in segments]