
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import typing as tp
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple
from typing import Dict, Optional, Union
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
from .utils.common import IGNORE_ID
from .label_smoothing_loss import LabelSmoothingLoss
from ...utils.common import th_accuracy
from .transformer.encoder import ConformerEncoder, TransformerEncoder
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from hyperpyyaml import load_hyperpyyaml
import numpy as np
from .cli.frontend import CosyVoiceFrontEnd
from .flow.flow_our import CausalMaskedDiffWithXvec
from .flow.length_regulator import InterpolateRegulator
from .flow.flow_matching import ConditionalCFM
from .flow.decoder import ConditionalDecoder
from .hifigan.generator import HiFTGenerator
from .hifigan.f0_predictor import ConvRNNF0Predictor
from omegaconf import DictConfig
from .llm.llm import Qwen2Encoder, Qwen2LM
from .utils.common import fade_in_out
import uuid
from contextlib import nullcontext
import threading
import time
from .utils.common import ras_sampling

@dataclass
class FMConfig(FairseqDataclass):
    """Fairseq configuration dataclass for the ``cosyvoice2_flow`` model.

    It is passed as ``dataclass=`` to ``register_model("cosyvoice2_flow", ...)``
    below, and an instance is handed to ``CosyVoice2Flow`` as ``cfg``.
    """

    # Presumably a filesystem path to pretrained Qwen weights (inferred from the
    # field name only -- TODO confirm). Defaults to the empty string.
    # NOTE(review): CosyVoice2Flow, as visible in this file, never reads it.
    qwen_pretrained_path: str = field(
        default=""
    )

try:
    @register_model("cosyvoice2_flow", dataclass=FMConfig)
    class CosyVoice2Flow(BaseFairseqModel):
        """Fairseq model wrapper around ``CausalMaskedDiffWithXvec``.

        The wrapper owns a single ``CausalMaskedDiffWithXvec`` instance and
        copies several of its configuration attributes onto itself, so callers
        can read them directly from the registered fairseq model.
        """

        def __init__(self, cfg: FMConfig):
            """Construct the wrapped flow module.

            Args:
                cfg: Fairseq model config. NOTE(review): no field of ``cfg`` is
                    read here; the flow module is built with its own defaults.
            """
            super().__init__()
            self.causal_masked_diff = CausalMaskedDiffWithXvec()
            # Presumably the output audio rate in Hz -- TODO confirm.
            self.sample_rate = 24000
            # Inference-time setting (original note: "infer"); not read by the
            # code visible in this file.
            self.mel_cache_len = 20
            # Mirror the wrapped module's configuration onto the wrapper.
            self.input_size = self.causal_masked_diff.input_size
            self.output_size = self.causal_masked_diff.output_size
            self.decoder_conf = self.causal_masked_diff.decoder_conf
            self.mel_feat_conf = self.causal_masked_diff.mel_feat_conf
            self.vocab_size = self.causal_masked_diff.vocab_size
            self.output_type = self.causal_masked_diff.output_type
            self.input_frame_rate = self.causal_masked_diff.input_frame_rate
            self.only_mask_loss = self.causal_masked_diff.only_mask_loss
            self.token_mel_ratio = self.causal_masked_diff.token_mel_ratio
            self.pre_lookahead_len = self.causal_masked_diff.pre_lookahead_len

        @classmethod
        def build_model(cls, cfg: FMConfig, task):
            """Build a new model instance (fairseq factory hook).

            ``cls`` is used instead of the hard-coded class name so that
            subclasses build instances of themselves. ``task`` is unused.
            """
            return cls(cfg)

        def forward(
                self,
                batch: dict,
        ) -> Dict[str, Optional[torch.Tensor]]:
            """Run the wrapped flow module on a training batch.

            Args:
                batch: Mapping that must contain the keys ``"codecs_d"``,
                    ``"speech_token_len"``, ``"fbanks"``, ``"fbank_lengthes"``
                    and ``"spk_feats_denoiseds"``. ``batch["codecs_d"]`` must
                    have a ``.device`` attribute (presumably a tensor).

            Returns:
                Whatever ``self.causal_masked_diff(inp, device)`` returns.
            """
            device = batch["codecs_d"].device
            # Rename the dataloader's keys to the names the flow module expects.
            inp = {
                "codecs_d": batch["codecs_d"],
                "speech_token_len": batch["speech_token_len"],
                "speech_feat": batch["fbanks"],
                "speech_feat_len": batch["fbank_lengthes"],
                "spk_feats_denoiseds": batch["spk_feats_denoiseds"],
            }
            return self.causal_masked_diff(inp, device)

        @torch.inference_mode()
        def inference(self,
                      token,
                      token_len,
                      prompt_token,
                      prompt_token_len,
                      prompt_feat,
                      prompt_feat_len,
                      embedding,
                      finalize):
            """Delegate to ``causal_masked_diff.inference`` under inference mode.

            All arguments are forwarded positionally and unchanged. Their exact
            shapes and meanings are defined by ``CausalMaskedDiffWithXvec``
            (not visible here). Going by the names, they are presumably target
            and prompt speech tokens with their lengths, prompt features with
            their length, a speaker embedding, and a flag marking the final
            chunk -- TODO confirm.
            """
            return self.causal_masked_diff.inference(token,
                                                     token_len,
                                                     prompt_token,
                                                     prompt_token_len,
                                                     prompt_feat,
                                                     prompt_feat_len,
                                                     embedding,
                                                     finalize)
except Exception as err:
    # Registration is best-effort: e.g. fairseq raises ValueError when the
    # name "cosyvoice2_flow" is already registered (module imported twice).
    # Warn instead of silently swallowing, so a missing model is diagnosable.
    import warnings

    warnings.warn(
        f"CosyVoice2Flow was not defined/registered with fairseq: {err!r}"
    )
        
def count_parameters(model, *, lora_key="lora", layer_key="llm.encoders.0"):
    """Print and return parameter-count statistics for ``model``.

    Args:
        model: Any object exposing ``named_parameters()`` (e.g. a
            ``torch.nn.Module``).
        lora_key: A parameter counts as LoRA when this substring occurs in
            its name.
        layer_key: A non-LoRA parameter counts toward the "one layer" figure
            when this substring occurs in its name.

    Returns:
        Tuple ``(lora_params, one_layer, total_params)`` of element counts, in
        the same order as they are printed.
    """
    total_params = 0
    lora_params = 0
    one_layer = 0
    # Single pass; named_parameters() yields the same de-duplicated set as
    # parameters(), so the total matches a separate parameters() sweep.
    for name, param in model.named_parameters():
        count = param.numel()
        total_params += count
        if lora_key in name:
            lora_params += count
        elif layer_key in name:
            # Only non-LoRA parameters of the selected layer are counted.
            one_layer += count
    print(f"parameter number: lora-{lora_params} / 1-layer-{one_layer} / total-{total_params}")
    return lora_params, one_layer, total_params
