
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import typing as tp
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple
from typing import Dict, Optional, Union
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
from .utils.common import IGNORE_ID
from .label_smoothing_loss import LabelSmoothingLoss
from ...utils.common import th_accuracy
from .transformer.encoder import ConformerEncoder, TransformerEncoder
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from hyperpyyaml import load_hyperpyyaml
import numpy as np
from .cli.frontend import CosyVoiceFrontEnd
from .flow.flow import MaskedDiffWithXvec
from .flow.length_regulator import InterpolateRegulator
from .flow.flow_matching import ConditionalCFM
from .flow.decoder import ConditionalDecoder
from .hifigan.generator import HiFTGenerator
from .hifigan.f0_predictor import ConvRNNF0Predictor
from omegaconf import DictConfig
from .llm.llm import Qwen2Encoder, Qwen2LM
from .utils.common import fade_in_out
import uuid
from contextlib import nullcontext
import threading
import time
from .utils.common import ras_sampling

@dataclass
class Cosyvoice2Config(FairseqDataclass):
    """Hyper-parameters of the ``cosyvoice2_origin`` fairseq model.

    The integer defaults of 12 look like placeholders; presumably the real
    sizes come from the experiment config (TODO confirm).
    """

    # Not read by CosyVoice2Origin.
    text_encoder_input_size: int = 12
    # Width of the llm_embedding / speech_embedding vectors fed to the LLM.
    llm_input_size: int = 12
    # Width of the LLM hidden states consumed by the llm_decoder projection.
    llm_output_size: int = 12
    # Not read by CosyVoice2Origin.
    text_token_size: int = 12
    # Speech-token vocabulary size. The model uses this value itself as the
    # end-of-speech id and sizes its output layer as speech_token_size + 3.
    speech_token_size: int = 12
    # Forwarded to LabelSmoothingLoss(normalize_length=...).
    length_normalized_loss: bool = False
    # Label-smoothing factor, forwarded to LabelSmoothingLoss(smoothing=...).
    lsm_weight: float = 0.0
    # Not read by CosyVoice2Origin.
    spk_embed_dim: int = 192

    # Forwarded to Qwen2Encoder(lora_rank=...); -1 presumably disables LoRA.
    lora_rank: int = -1

    # Copied onto the model as ``model.partial_train``.
    partial_train: bool = False
    # "-"-separated LLM encoder layer indices that unfreeze_partial() keeps
    # trainable (e.g. "0-1-2"); "" makes unfreeze_partial() a no-op.
    partial_layers: str = ""
    # Checkpoint location handed to Qwen2Encoder.
    qwen_pretrained_path: str = ""

try:
    @register_model("cosyvoice2_origin", dataclass=Cosyvoice2Config)
    class CosyVoice2Origin(BaseFairseqModel):
        """Qwen2-based speech-token language model (the CosyVoice2 "LLM" stage).

        Training (:meth:`forward`) uses teacher forcing over the sequence
        ``[sos_eos, text embeddings, task_id, speech-token embeddings]`` and a
        label-smoothed cross-entropy on the speech tokens followed by an
        end-of-speech id equal to ``speech_token_size``.

        Inference first requires :meth:`init_infer_modules`. That method loads
        the frontend, the flow-matching model (``self.flow``) and the HiFT
        vocoder (``self.hift``), and sets up the per-session dictionaries used
        by :meth:`inference_whole` and :meth:`inference_preset`.
        """

        def __init__(self, cfg: Cosyvoice2Config):
            """Build the LLM-stage modules described by ``cfg``.

            ``self.flow`` and ``self.hift`` stay ``None`` until
            :meth:`init_infer_modules` is called.
            """
            super().__init__()
            self.partial_train = cfg.partial_train
            self.cfg = cfg
            self.lora_rank = cfg.lora_rank
            self.llm_input_size = self.cfg.llm_input_size
            self.speech_token_size = self.cfg.speech_token_size
            # Row indices into llm_embedding, which has only 2 rows.
            # fill_token is not used by the methods of this class.
            self.sos_eos = 0
            self.task_id = 1
            self.fill_token = 2
            self.llm_embedding = torch.nn.Embedding(2, self.cfg.llm_input_size)
            self.llm = Qwen2Encoder(
                cfg.qwen_pretrained_path,
                lora_rank=self.cfg.lora_rank
            )
            print(self.llm.model.preset_config)
            # speech_token_size + 3 classes: the speech tokens, end-of-speech
            # (id == speech_token_size), and two more ids above it, which
            # inference() skips.
            self.llm_decoder = nn.Linear(self.cfg.llm_output_size, self.cfg.speech_token_size + 3)
            self.criterion_ce = LabelSmoothingLoss(
                size=self.cfg.speech_token_size + 3,
                padding_idx=IGNORE_ID,
                smoothing=self.cfg.lsm_weight,
                normalize_length=self.cfg.length_normalized_loss,
            )
            self.speech_embedding = torch.nn.Embedding(self.cfg.speech_token_size + 3, self.cfg.llm_input_size)
            # Loaded by init_infer_modules().
            self.hift = None
            self.flow = None
            self.sample_rate = 24000
            self.mel_cache_len = 20

        @classmethod
        def build_model(cls, cfg: Cosyvoice2Config, task):
            """Fairseq factory hook: build a new model instance (``task`` is unused)."""
            return cls(cfg)

        def encode(
                self,
                text: torch.Tensor,
                text_lengths: torch.Tensor,
        ):
            """Encode text with ``self.text_encoder`` and project it.

            NOTE(review): ``__init__`` creates neither ``self.text_encoder`` nor
            ``self.text_encoder_affine_layer``. Calling this without assigning
            them elsewhere raises AttributeError.

            Returns:
                ``(encoder_out, encoder_out_lens)``; the lengths come from the
                encoder mask.
            """
            encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
            encoder_out_lens = encoder_mask.squeeze(1).sum(1)
            encoder_out = self.text_encoder_affine_layer(encoder_out)
            return encoder_out, encoder_out_lens

        def pad_unpad_sequence(self, sos_eos_emb, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
            """Assemble per-sample LLM inputs and re-pad them into one batch.

            Each sample becomes ``[sos_eos_emb, text[:len], task_id_emb,
            speech[:len]]``, concatenated along the time axis.

            Returns:
                ``(lm_input, lm_input_len)``: the batch padded with ``IGNORE_ID``
                (batch-first) and an int32 tensor of per-sample lengths.
            """
            text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
            speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
            lm_input = [
                torch.concat([sos_eos_emb.squeeze(dim=0), text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
                for i in range(len(text_token))
            ]
            lm_input_len = torch.tensor([seq.size(0) for seq in lm_input], dtype=torch.int32)
            lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
            return lm_input, lm_input_len

        def forward(
                self,
                batch: dict,
        ) -> Dict[str, Optional[torch.Tensor]]:
            """Compute the teacher-forced training loss for one batch.

            Args:
                batch: dict with
                    ``text_token`` (B, N) text ids (its device is used for all inputs),
                    ``text_token_len`` (B,),
                    ``codecs_d`` (B, N) speech-token ids,
                    ``speech_token_len`` (B,),
                    ``spk_feats_denoiseds`` speaker features (read, but not used yet).

            Returns:
                ``{'loss': ..., 'acc': ...}``. ``loss`` is the label-smoothed
                cross-entropy and ``acc`` is the token accuracy ignoring ``IGNORE_ID``.
            """
            text_token = batch['text_token']
            device = text_token.device
            text_token_len = batch['text_token_len'].to(device)
            speech_token = batch['codecs_d'].to(device)
            speech_token_len = batch['speech_token_len'].to(device)
            # NOTE(review): read (so a missing key still raises) but not used below.
            embedding = batch['spk_feats_denoiseds'].to(device)

            # 1. Targets. The positions of sos_eos and the text are ignored.
            # The position of task_id predicts the first speech token, and the
            # sequence ends with the end-of-speech id (speech_token_size).
            lm_target = [
                torch.tensor(
                    [IGNORE_ID] * (text_token_len[i] + 1)
                    + speech_token[i, :speech_token_len[i]].tolist()
                    + [self.speech_token_size]
                )
                for i in range(text_token.size(0))
            ]
            lm_target = pad_sequence(
                lm_target, batch_first=True, padding_value=IGNORE_ID
            ).to(device)

            # 2. Embed the text with the Qwen token embedding table.
            text = self.llm.model.model.embed_tokens(text_token)

            # 3. Special-token embeddings.
            sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
            task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)

            # 4. Embed the speech tokens.
            speech_token = self.speech_embedding(speech_token)

            # 5. Build the padded LLM input.
            lm_input, lm_input_len = self.pad_unpad_sequence(
                sos_eos_emb, text, text_token_len, task_id_emb, speech_token, speech_token_len
            )

            # 6. LLM forward pass, projection to the speech vocabulary, loss.
            lm_output = self.llm(
                lm_input, lm_input_len.to(device)
            )
            logits = self.llm_decoder(lm_output)
            loss = self.criterion_ce(logits, lm_target)
            acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
            return {'loss': loss, 'acc': acc}

        def sampling_ids(
                self,
                weighted_scores: torch.Tensor,
                decoded_tokens: List,
                sampling: int,
                ignore_eos: bool = True,
                max_trials: int = 100,
        ):
            """Draw the next token id with ``self.sampling``.

            ``self.sampling`` is assigned by :meth:`init_infer_modules`. When
            ``ignore_eos`` is true, draws equal to end-of-speech
            (``speech_token_size``) are rejected and sampling is retried.

            Raises:
                RuntimeError: ``ignore_eos`` is true and every one of
                    ``max_trials`` draws was end-of-speech. Without this bound
                    the retry loop could spin forever.
            """
            for _ in range(max_trials):
                top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
                if (not ignore_eos) or (self.speech_token_size not in top_ids):
                    return top_ids
            raise RuntimeError(
                'sampling reached max_trials {} and still got eos when ignore_eos is True'.format(max_trials)
            )

        def repetition_penalty(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, penalty=1.2):
            """Penalize the scores of ``input_ids``.

            Negative scores are multiplied by ``penalty`` and non-negative
            scores are divided by it. Returns a new tensor.
            """
            score = torch.gather(scores, 1, input_ids)
            # A negative score must be multiplied (not divided) to lower its probability.
            score = torch.where(score < 0, score * penalty, score / penalty)
            return scores.scatter(1, input_ids, score)

        @torch.inference_mode()
        def inference(
                self,
                text: torch.Tensor,
                text_len: torch.Tensor,
                prompt_text: torch.Tensor,
                prompt_text_len: torch.Tensor,
                prompt_speech_token: torch.Tensor,
                prompt_speech_token_len: torch.Tensor,
                embedding: torch.Tensor,
                sampling: int = 25,
                max_token_text_ratio: float = 20,
                min_token_text_ratio: float = 2,
                a_h_n_sa_su_text_lens=None,
                a_h_n_sa_su_st_lens=None,
                text_control=None
        ):
            """Autoregressively generate speech-token ids, yielding them one by one.

            The prompt is ``[sos_eos, prompt_text + text, task_id, prompt
            speech tokens]``. Generation stops at end-of-speech or after
            ``len(text) * max_token_text_ratio`` steps. Ids greater than
            ``speech_token_size`` are skipped and not yielded.

            NOTE(review): ``embedding`` and ``min_token_text_ratio`` are
            currently ignored. The speaker-embedding slot is empty and the
            minimum length is 0.
            """
            device = text.device
            text = torch.concat([prompt_text, text], dim=1)
            # Out-of-place add, so the caller's text_len tensor is not modified.
            text_len = text_len + prompt_text_len
            text = self.llm.model.model.embed_tokens(text)

            # 2. Empty speaker-embedding slot (the embedding argument is not used).
            embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)

            # 3. Concatenate the LLM prompt.
            sos_eos_emb = self.llm_embedding.weight[self.sos_eos].reshape(1, 1, -1)
            task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
            if prompt_speech_token_len != 0:
                prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
            else:
                prompt_speech_token_emb = torch.zeros(
                    1, 0, self.llm_input_size, dtype=text.dtype
                ).to(device)
            lm_input = torch.concat([sos_eos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)

            # 4. Length limits. The minimum is forced to 0, so min_token_text_ratio has no effect.
            min_len = 0
            max_len = int((text_len - prompt_text_len) * max_token_text_ratio)

            # 5. Step-by-step decoding with the LLM cache.
            out_tokens = []
            cache = None
            for i in range(max_len):
                y_pred, cache = self.llm.forward_one_step(lm_input,
                                                          masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
                                                          cache=cache,
                                                          a_h_n_sa_su_text_lens=a_h_n_sa_su_text_lens,
                                                          a_h_n_sa_su_st_lens=a_h_n_sa_su_st_lens,
                                                          text_control=text_control)
                logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
                top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=i < min_len).item()
                if top_ids == self.speech_token_size:
                    break
                if top_ids > self.speech_token_size:
                    # lm_input is deliberately left unchanged for the next step.
                    continue
                # Stream mode consumes the tokens one at a time.
                yield top_ids
                out_tokens.append(top_ids)
                lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)

        def unfreeze_partial(self):
            """Freeze every parameter outside the LLM layers listed in ``cfg.partial_layers``.

            ``cfg.partial_layers`` is a "-"-separated list of layer indices.
            Parameters whose name contains ``llm.encoders.<idx>.`` keep their
            ``requires_grad``; all others are frozen. An empty string leaves
            every parameter untouched. Prints the layer list, every parameter
            name, and the trainable names.
            """
            if self.cfg.partial_layers == "":
                return
            train_layers = [int(layer) for layer in self.cfg.partial_layers.split("-")]
            print(train_layers)
            layer_prefixes = ["llm.encoders.%d." % idx for idx in train_layers]
            parameters_names = []
            for name, param in self.named_parameters():
                print(name)
                if any(prefix in name for prefix in layer_prefixes):
                    parameters_names.append(name)
                else:
                    param.requires_grad = False
            print("trainable: " + str(parameters_names))

        def unfreeze_lora(self):
            """Make only LoRA parameters (and LoRA biases) trainable, then print the parameter counts."""
            import loralib as lora
            lora.mark_only_lora_as_trainable(self, bias='lora_only')
            count_parameters(self)

        def init_infer_modules(self,
                               device,
                               model_dir="./pretrained_models/CosyVoice2-0.5B/",
                               instruct=True,
                               fm_model=None,
                               token_hop_len=25):
            """Load the inference-only components and streaming settings.

            Reads ``cosyvoice.yaml`` (hyperpyyaml), ``campplus.onnx``,
            ``speech_tokenizer_v2.onnx``, ``spk2info.pt``, ``hift.pt`` and,
            when ``fm_model`` is None, ``flow.pt`` from ``model_dir``.

            SECURITY: hyperpyyaml can instantiate Python objects and
            ``torch.load`` unpickles, so only point ``model_dir`` at trusted files.

            Args:
                device: device the flow and HiFT modules are moved to.
                model_dir: directory containing the files listed above.
                instruct: forwarded to ``CosyVoiceFrontEnd``.
                fm_model: optional ready-made flow model. If None,
                    ``configs["flow"]`` is used with the ``flow.pt`` weights.
                token_hop_len: number of new speech tokens decoded per chunk
                    by ``inference_whole(stream=True)``. That method reads
                    ``self.token_hop_len``, which nothing else sets.
                    NOTE(review): 25 follows upstream CosyVoice2; tune if needed.
            """
            with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
                configs = load_hyperpyyaml(f)
            print(configs)
            self.device = device
            self.frontend = CosyVoiceFrontEnd('{}/campplus.onnx'.format(model_dir),
                                              '{}/speech_tokenizer_v2.onnx'.format(model_dir),
                                              '{}/spk2info.pt'.format(model_dir),
                                              instruct,
                                              configs['allowed_special'])
            if fm_model is not None:
                self.flow = fm_model
                self.flow.to(device).eval()
                self.our_fm = True
            else:
                self.flow = configs["flow"]
                count_parameters(self)
                # Checkpoint keys carry a 'generator.' prefix that the module does not use.
                flow_state_dict = {k.replace('generator.', ''): v for k, v in torch.load('{}/flow.pt'.format(model_dir), map_location=self.device).items()}
                self.flow.load_state_dict(flow_state_dict, strict=True)
                self.flow.to(device).eval()
                self.our_fm = False

            self.hift = configs["hift"]
            self.hift.load_state_dict(torch.load('{}/hift.pt'.format(model_dir), map_location=device), strict=True)
            self.hift.to(device).eval()
            self.eval()
            self.model_dir = model_dir
            self.lock = threading.Lock()
            del configs

            # Streaming chunk sizes (in speech tokens).
            self.token_hop_len = token_hop_len
            self.token_min_hop_len = 2 * self.flow.input_frame_rate
            self.token_max_hop_len = 4 * self.flow.input_frame_rate
            self.token_overlap_len = 20
            # Static chunking is disabled (set to 0) for compatibility.
            if self.our_fm:
                self.flow.causal_masked_diff.decoder.estimator.static_chunk_size = 0
                self.flow.causal_masked_diff.decoder.fp16 = False
            else:
                self.flow.decoder.estimator.static_chunk_size = 0
                self.flow.decoder.fp16 = False
            # Mel fade-in/out window.
            self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
            self.mel_window = np.hamming(2 * self.mel_overlap_len)
            # HiFT cache sizes.
            # NOTE(review): 256 samples per mel frame looks like a 22.05 kHz
            # setting while sample_rate is 24000. Check before relying on stream mode.
            self.mel_cache_len = 20
            self.source_cache_len = int(self.mel_cache_len * 256)
            # Waveform fade-in/out window.
            self.speech_window = np.hamming(2 * self.source_cache_len)
            self.stream_scale_factor = 1
            assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
            # The LLM runs on its own CUDA stream when CUDA is available.
            self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
            # Per-session state, keyed by the uuid string of each inference call.
            self.tts_speech_token_dict = {}
            self.llm_end_dict = {}
            self.mel_overlap_dict = {}
            self.flow_cache_dict = {}
            self.hift_cache_dict = {}
            self.sampling = ras_sampling

        def _collect_llm_tokens(self, text, prompt_text, llm_prompt_speech_token, llm_embedding,
                                session_id, a_h_n_sa_su_text_lens=None,
                                a_h_n_sa_su_st_lens=None, text_control=None):
            """Run :meth:`inference` on ``self.device`` and append each token to ``tts_speech_token_dict[session_id]``."""
            with self.llm_context:
                for token in self.inference(text=text.to(self.device),
                                            text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
                                            prompt_text=prompt_text.to(self.device),
                                            prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
                                            prompt_speech_token=llm_prompt_speech_token.to(self.device),
                                            prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
                                            embedding=llm_embedding.to(self.device),
                                            a_h_n_sa_su_text_lens=a_h_n_sa_su_text_lens,
                                            a_h_n_sa_su_st_lens=a_h_n_sa_su_st_lens,
                                            text_control=text_control):
                    self.tts_speech_token_dict[session_id].append(token)

        def llm_job(self, text,
                    prompt_text,
                    llm_prompt_speech_token,
                    llm_embedding,
                    uuid,
                    a_h_n_sa_su_text_lens=None,
                    a_h_n_sa_su_st_lens=None,
                    text_control=None):
            """Thread target: generate tokens for session ``uuid``, then mark it finished.

            ``llm_end_dict[uuid]`` is set to True even if generation raises.
            Otherwise the streaming loop in :meth:`inference_whole` would keep
            waiting for more tokens.
            """
            try:
                self._collect_llm_tokens(text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid,
                                         a_h_n_sa_su_text_lens, a_h_n_sa_su_st_lens, text_control)
            finally:
                self.llm_end_dict[uuid] = True

        def llm_job_preset(self, text,
                           prompt_text,
                           llm_prompt_speech_token,
                           llm_embedding,
                           uuid,
                           a_h_n_sa_su_text_lens=None,
                           a_h_n_sa_su_st_lens=None,
                           text_control=None):
            """Like :meth:`llm_job`, but does not touch ``llm_end_dict``."""
            self._collect_llm_tokens(text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid,
                                     a_h_n_sa_su_text_lens, a_h_n_sa_su_st_lens, text_control)

        def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, token_offset, finalize=False, speed=1.0):
            """Convert speech tokens to a waveform (flow model, then HiFT vocoder).

            Mel frames for the first ``token_offset`` tokens are dropped.
            When ``hift_cache_dict[uuid]`` holds a previous chunk, its cached
            mel frames are prepended and the outputs are cross-faded. With
            ``finalize=False`` the tail is stored as the new cache and trimmed
            from the returned audio. ``speed != 1.0`` stretches the mel and is
            only allowed when there is no cache (non-stream mode).
            """
            tts_mel, _ = self.flow.inference(token=token.to(self.device),
                                             token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
                                             prompt_token=prompt_token.to(self.device),
                                             prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
                                             prompt_feat=prompt_feat.to(self.device),
                                             prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
                                             embedding=embedding.to(self.device),
                                             finalize=finalize)
            tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
            # Prepend the HiFT cache from the previous chunk, if any.
            if self.hift_cache_dict[uuid] is not None:
                hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
                tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
            else:
                hift_cache_source = torch.zeros(1, 1, 0)
            if finalize is False:
                # Keep the tail as the cache for the next chunk.
                tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
                if self.hift_cache_dict[uuid] is not None:
                    tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
                self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
                                              'source': tts_source[:, :, -self.source_cache_len:],
                                              'speech': tts_speech[:, -self.source_cache_len:]}
                tts_speech = tts_speech[:, :-self.source_cache_len]
            else:
                if speed != 1.0:
                    assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
                    tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
                tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
                if self.hift_cache_dict[uuid] is not None:
                    tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
            return tts_speech

        def inference_whole(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
                            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
                            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
                            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
                            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0,
                            a_h_n_sa_su_text_lens=None,
                            a_h_n_sa_su_st_lens=None,
                            text_control=None,
                            **kwargs):
            """Synthesize ``text``, yielding ``{'tts_speech': cpu_tensor}`` chunks.

            The LLM runs in a background thread (:meth:`llm_job`) that appends
            tokens to this call's session entry.

            With ``stream=True``, audio is produced every ``self.token_hop_len``
            tokens (plus ``self.flow.pre_lookahead_len`` look-ahead), and a
            final chunk follows once the LLM finishes. Otherwise a single chunk
            is yielded after generation. ``speed`` only applies in non-stream
            mode, and ``kwargs`` are ignored.

            The default tensor arguments are shared between calls; they are only read.
            Session state is released even if the consumer stops iterating
            early or an exception occurs.
            """
            # this_uuid identifies the session state that belongs to this call.
            this_uuid = str(uuid.uuid1())
            with self.lock:
                self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
                self.hift_cache_dict[this_uuid] = None
            p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid, a_h_n_sa_su_text_lens, a_h_n_sa_su_st_lens, text_control,))
            p.start()
            try:
                if stream is True:
                    token_offset = 0
                    while True:
                        time.sleep(0.1)
                        chunk_len = self.token_hop_len + self.flow.pre_lookahead_len
                        if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= chunk_len:
                            this_tts_speech_token = torch.LongTensor(self.tts_speech_token_dict[this_uuid][:token_offset + chunk_len]).unsqueeze(dim=0)
                            this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                             prompt_token=flow_prompt_speech_token,
                                                             prompt_feat=prompt_speech_feat,
                                                             embedding=flow_embedding,
                                                             uuid=this_uuid,
                                                             token_offset=token_offset,
                                                             finalize=False)
                            token_offset += self.token_hop_len
                            yield {'tts_speech': this_tts_speech.cpu()}
                        if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < chunk_len:
                            break
                    p.join()
                    # Decode the remaining tokens in one final chunk.
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     uuid=this_uuid,
                                                     token_offset=token_offset,
                                                     finalize=True)
                    yield {'tts_speech': this_tts_speech.cpu()}
                else:
                    # Wait for all tokens, then decode them at once.
                    p.join()
                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                     prompt_token=flow_prompt_speech_token,
                                                     prompt_feat=prompt_speech_feat,
                                                     embedding=flow_embedding,
                                                     uuid=this_uuid,
                                                     token_offset=0,
                                                     finalize=True,
                                                     speed=speed)
                    yield {'tts_speech': this_tts_speech.cpu()}
            finally:
                # Join first, so the LLM thread never appends to a popped session.
                p.join()
                with self.lock:
                    self.tts_speech_token_dict.pop(this_uuid, None)
                    self.llm_end_dict.pop(this_uuid, None)
                    self.hift_cache_dict.pop(this_uuid, None)

        def inference_preset(self, input_list, speed=1.0, **kwargs):
            """Generate tokens for several segments and vocode them as one utterance.

            Each item of ``input_list`` is a mapping with the keys ``text``,
            ``prompt_text``, ``llm_prompt_speech_token``, ``llm_embedding``,
            ``flow_prompt_speech_token``, ``prompt_speech_feat``,
            ``flow_embedding``, and an optional ``speed`` (default 1).

            When an item's ``speed`` is not 1, its LLM prompt tokens are
            resampled by that factor. The tokens of all segments accumulate in
            one session. The flow/vocoder stage uses the flow prompts of the
            last item and the call-level ``speed``. Yields one
            ``{'tts_speech': cpu_tensor}``.

            Raises:
                ValueError: ``input_list`` is empty.
            """
            items = list(input_list)
            if not items:
                raise ValueError("input_list must contain at least one item")
            this_uuid = str(uuid.uuid1())
            with self.lock:
                self.tts_speech_token_dict[this_uuid] = []
            try:
                for item in items:
                    text = item["text"]
                    prompt_text = item["prompt_text"]
                    llm_prompt_speech_token = item["llm_prompt_speech_token"]
                    item_speed = item.get("speed", 1)
                    if item_speed != 1:
                        llm_prompt_speech_token = resample_by_stride(llm_prompt_speech_token[0], item_speed).unsqueeze(0)
                    llm_embedding = item["llm_embedding"]
                    flow_prompt_speech_token = item["flow_prompt_speech_token"]
                    prompt_speech_feat = item["prompt_speech_feat"]
                    flow_embedding = item["flow_embedding"]
                    with self.lock:
                        self.llm_end_dict[this_uuid] = False
                        self.hift_cache_dict[this_uuid] = None
                    p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid,))
                    p.start()
                    p.join()
                    print(len(self.tts_speech_token_dict[this_uuid]))
                this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
                this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                 prompt_token=flow_prompt_speech_token,
                                                 prompt_feat=prompt_speech_feat,
                                                 embedding=flow_embedding,
                                                 uuid=this_uuid,
                                                 token_offset=0,
                                                 finalize=True,
                                                 speed=speed)
                yield {'tts_speech': this_tts_speech.cpu()}
            finally:
                with self.lock:
                    self.tts_speech_token_dict.pop(this_uuid, None)
                    self.llm_end_dict.pop(this_uuid, None)
                    self.hift_cache_dict.pop(this_uuid, None)

except Exception as exc:
    # Defining/registering the model is best-effort. For example, fairseq's
    # register_model raises ValueError when "cosyvoice2_origin" is registered
    # twice. Report the failure instead of hiding it: when this branch runs,
    # CosyVoice2Origin is not defined in this module.
    print(f"[cosyvoice2_origin] model definition/registration skipped: {exc!r}")
        
def count_parameters(model):
    """Print a one-line summary of ``model``'s parameter counts.

    Three element counts are reported:
      * ``lora``: parameters whose name contains ``"lora"``;
      * ``1-layer``: non-LoRA parameters whose name contains ``"llm.encoders.0"``;
      * ``total``: all parameters.
    Returns None.
    """
    n_total = n_lora = n_first_layer = 0
    for pname, tensor in model.named_parameters():
        count = tensor.numel()
        n_total += count
        if 'lora' in pname:
            n_lora += count
        elif 'llm.encoders.0' in pname:
            n_first_layer += count
    print(f"parameter number: lora-{n_lora} / 1-layer-{n_first_layer} / total-{n_total}")
        
def resample_by_stride(tensor, scale):
    """Resample ``tensor`` along its first dimension by nearest-index picking.

    The output has ``max(1, int(len(tensor) * scale))`` entries. Their
    indices are spread uniformly over ``[0, len(tensor) - 1]`` and rounded,
    so ``scale > 1`` repeats elements and ``scale < 1`` drops some.

    Args:
        tensor: tensor with at least one dimension; resampled along dim 0.
        scale: length multiplier.

    Returns:
        The resampled tensor. An empty input is returned as-is, because there
        is no element to pick; indexing it with ``[0]`` would raise IndexError.
    """
    length = len(tensor)
    if length == 0:
        return tensor
    new_length = max(1, int(length * scale))
    # Build the index tensor on the input's device to avoid a cross-device gather.
    indices = torch.linspace(0, length - 1, new_length, device=tensor.device).round().long()
    return tensor[indices]