from typing import Any, Optional, Union, Type
import copy
import math
import torch
import torch.nn as nn
import gin
import torch.nn.functional as F
import argparse
from tqdm import tqdm
from pathlib import Path
import numpy as np
import os
import torch.nn.functional as F
import soundfile as sf
from torchaudio.transforms import Resample
from funasr import AutoModel
from jiwer import wer
import torchaudio
import json
import regex as re
import time
import csv
import librosa as lib
from resemblyzer import VoiceEncoder, preprocess_wav
from torchaudio.transforms import Resample
from models.ecapa_tdnn import ECAPA_TDNN_SMALL
import audeer
import audonnx
import pandas as pd
import librosa
import numpy as np
import numpy.polynomial.polynomial as poly
import onnxruntime as ort
import soundfile as sf
from requests import session
from tqdm import tqdm
from funasr.utils.postprocess_utils import rich_transcription_postprocess
# Print whole arrays and every DataFrame column; cap printed DataFrame rows at 10.
np.set_printoptions(threshold=np.inf)
pd.set_option('display.max_columns', None, 'display.max_rows', 10)

SAMPLING_RATE = 16000
INPUT_LENGTH = 1

def save_audio(out_path, speech, sample_rate):
    """Write ``speech`` to ``out_path`` with soundfile, creating parent directories first.

    Args:
        out_path: destination file path.
        speech: audio samples accepted by ``soundfile.write``.
        sample_rate: sample rate written into the file header.
    """
    home_path = os.path.dirname(out_path)
    # dirname() is "" for a bare filename; os.makedirs("") would raise FileNotFoundError.
    if home_path:
        os.makedirs(home_path, exist_ok=True)
    sf.write(out_path, speech, sample_rate)

def remove_punctuation_and_whitespace(text):
    """Delete every Unicode punctuation character (category P) and every whitespace character.

    The property class in the pattern needs the third-party ``regex`` engine, which this
    module imports under the name ``re``.
    """
    pattern = r'[\p{P}\s]+'
    return re.sub(pattern, '', text, flags=re.UNICODE)

def load_wav(wav, target_sr):
    """Load ``wav`` as a mono waveform at ``target_sr``.

    Channels are averaged with ``keepdim=True``, so the result keeps a leading channel axis of
    size 1 (assuming torchaudio's (channels, time) layout). Only downsampling is permitted: an
    AssertionError is raised when the file's rate is below ``target_sr``.
    """
    waveform, orig_sr = torchaudio.load(wav)
    mono = waveform.mean(dim=0, keepdim=True)
    if orig_sr == target_sr:
        return mono
    assert orig_sr > target_sr, 'wav sample rate {} must be greater than {}'.format(orig_sr, target_sr)
    resampler = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)
    return resampler(mono)

def normalize_text(text, remove_punctuation=False):
    """Clean ``text`` before word-level scoring.

    Removes C0/C1 control characters and U+00A0 (the no-break space is deleted, not turned into
    a space). Maps common full-width CJK punctuation to ASCII equivalents. When
    ``remove_punctuation`` is true, it also strips all ASCII punctuation and the remaining CJK
    punctuation marks.

    Args:
        text: input string.
        remove_punctuation: also drop punctuation after normalization.

    Returns:
        The normalized string.
    """
    # ``string`` was referenced below but never imported, so remove_punctuation=True raised NameError.
    import string

    text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x9f\xa0]', '', text)

    text = text.translate(str.maketrans({
        '，': ',', '。': '.', '；': ';', '：': ':', '？': '?', '！': '!',
        '（': '(', '）': ')', '【': '[', '】': ']', '“': '"', '”': '"', '‘': "'", '’': "'"
    }))

    if remove_punctuation:
        text = text.translate(str.maketrans('', '', string.punctuation))
        # CJK marks with no ASCII mapping above (e.g. 《》、) are removed here.
        text = re.sub(r"[，。！？；：“”‘’（）【】《》、]", "", text)

    return text

def wer_cal(ref, hyp):
    """Return jiwer's word error rate of ``hyp`` against ``ref``.

    Both strings are lower-cased and passed through normalize_text with punctuation removal
    before scoring.
    """
    ref_norm, hyp_norm = (
        normalize_text(sentence.lower(), remove_punctuation=True) for sentence in (ref, hyp)
    )
    return wer(ref_norm, hyp_norm)

def cer_cal(ref, hyp):
    """Return the character error rate of ``hyp`` against ``ref``.

    Punctuation and whitespace are removed from both strings with
    remove_punctuation_and_whitespace. The character-level Levenshtein distance is then divided
    by the reference length.

    An empty reference (after cleaning) previously raised ZeroDivisionError. Its length is now
    treated as 1, so the result is the number of hypothesis characters (0.0 when both are
    empty).
    """
    # Whitespace is already stripped by the regex, so no extra strip()/replace() is needed.
    ref_chars = remove_punctuation_and_whitespace(ref)
    hyp_chars = remove_punctuation_and_whitespace(hyp)

    # Two-row Levenshtein DP: prev[j] = distance between ref_chars[:i-1] and hyp_chars[:j].
    prev = list(range(len(hyp_chars) + 1))
    for i, r_ch in enumerate(ref_chars, start=1):
        curr = [i] + [0] * len(hyp_chars)
        for j, h_ch in enumerate(hyp_chars, start=1):
            if r_ch == h_ch:
                curr[j] = prev[j - 1]
            else:
                # substitution, insertion, deletion
                curr[j] = 1 + min(prev[j - 1], curr[j - 1], prev[j])
        prev = curr

    return prev[-1] / max(len(ref_chars), 1)

def merge_bpe_tokens(tokens_with_timestamps):
    """Merge SentencePiece-style sub-word tokens into words with time spans.

    A token beginning with '▁' (U+2581) starts a new word; any other token is appended to the
    current word.

    Args:
        tokens_with_timestamps: iterable of ``(token, start, end)`` triples.

    Returns:
        List of ``[word, start, end]`` lists. ``start`` comes from the word's first token and
        ``end`` from its last token.
    """
    merged = []
    current_word = ''
    start_time = None
    last_end = None

    for token, start, end in tokens_with_timestamps:
        clean_token = token.lstrip('▁')

        if token.startswith('▁'):
            # Flush the previous word, then start a new one.
            if current_word:
                merged.append([current_word, start_time, last_end])
            current_word = clean_token
            start_time = start
        else:
            if start_time is None:
                # The very first token has no '▁' marker: it still opens a word, so record
                # its start instead of leaving start_time as None.
                start_time = start
            current_word += clean_token

        last_end = end

    if current_word:
        merged.append([current_word, start_time, last_end])

    return merged


def wer_and_align(refs, inp_path, asr_model, asr_kwargs=None):
    """Transcribe ``inp_path`` and locate each reference sentence inside the transcript.

    For each reference, in order, every span of not-yet-claimed hypothesis words (length 1 up to
    ``ref_len + 9``) is scored. Scoring uses CER when the ASR language tag is "zh" and WER
    otherwise. The best-scoring span is claimed and its audio is cut out using the merged token
    timestamps (divided by 1000, presumably ms -> s; the waveform is indexed at 16 kHz).

    Args:
        refs: reference transcripts.
        inp_path: audio file to transcribe (loaded mono at 16 kHz).
        asr_model: object whose ``inference(...)[0]`` has "text" and "timestamp" entries.
        asr_kwargs: extra keyword arguments forwarded to ``asr_model.inference`` (default none).

    Returns:
        ``(results, segments, cer_score, inp_wav)``:
        - ``results``: one dict per matched ref with "ref", "match", "start_time", "end_time",
          and "cer_score".
        - ``segments``: the matching 16 kHz audio slices.
        - ``cer_score``: the whole-utterance CER/WER.
        - ``inp_wav``: the flattened numpy waveform.
        Unmatched refs produce no entry in ``results``/``segments``.
    """
    if asr_kwargs is None:
        asr_kwargs = {}
    inp_wav = load_wav(inp_path, target_sr=16000)[0]
    est_info = asr_model.inference(
        data_in=inp_wav,
        language="auto",  # "zh", "en", "yue", "ja", "ko", "nospeech"
        use_itn=False,
        ban_emo_unk=False,
        output_timestamp=True,
        **asr_kwargs,
    )[0]
    est_txt = rich_transcription_postprocess(est_info["text"])
    est_words = est_txt.split()

    lang_tag = re.search(r"<\|(.*?)\|>", est_info["text"])
    # Without a <|lang|> tag, fall back to word-level (WER) alignment instead of failing on None.group().
    lang = lang_tag.group(1) if lang_tag else ""
    is_zh = lang == "zh"

    # Each entry is [word, start, end] (see merge_bpe_tokens).
    timestamps = merge_bpe_tokens(est_info['timestamp'])
    # NOTE(review): est_words and timestamps are indexed in parallel; confirm the ASR model's
    # timestamp tokens segment words the same way as est_txt.split(). Spans are restricted to
    # words that have a timestamp so the lookups below cannot go out of range.
    n_words = min(len(est_words), len(timestamps))

    used = [False] * n_words
    results = []
    segments = []
    inp_wav = inp_wav.numpy().flatten()
    for ref in refs:
        if is_zh:
            ref_str = ref.replace(" ", "")
            ref_len = len(ref_str)
        else:
            # ref_str was previously unbound (or stale from the prior ref) on this branch.
            ref_str = ref
            ref_len = len(ref_str.split(" "))
        best_score = float("inf")
        best_start = -1
        best_end = -1

        for start in range(n_words):
            for end in range(start + 1, min(n_words + 1, start + ref_len + 10)):
                # Once a span touches a claimed word, every longer span from this start does too.
                if used[end - 1]:
                    break
                hyp_span = est_words[start:end]
                if is_zh:
                    score = cer_cal(ref_str, "".join(hyp_span))
                else:
                    # Score the candidate span, not the entire word list.
                    score = wer_cal(ref_str, " ".join(hyp_span))

                if score < best_score:
                    best_score, best_start, best_end = score, start, end

        if best_start == -1:
            continue

        for i in range(best_start, best_end):
            used[i] = True

        # Index 0 is the word text; 1 and 2 are its start and end times.
        start_time = timestamps[best_start][1] / 1000
        end_time = timestamps[best_end - 1][2] / 1000
        joiner = "" if is_zh else " "

        results.append({
            "ref": ref,
            "match": joiner.join(est_words[best_start:best_end]),
            "start_time": start_time,
            "end_time": end_time,
            "cer_score": best_score,
        })
        segments.append(inp_wav[int(start_time * 16000):int(end_time * 16000)])

    if is_zh:
        cer_score = cer_cal("".join(refs), est_txt)
    else:
        cer_score = wer_cal(" ".join(refs), est_txt)
    return results, segments, cer_score, inp_wav

#########################################################################
def preprocess_audio(audio_array, feature_extractor, max_duration=30.0, frame_hop=320):
    """Truncate or zero-pad audio to ``max_duration`` seconds and run the feature extractor.

    Args:
        audio_array: 1-D numpy waveform.
        feature_extractor: callable with a ``sampling_rate`` attribute.
        max_duration: target length in seconds.
        frame_hop: samples per model output frame (320 was hard-coded before; presumably the
            encoder's downsampling factor — confirm for the model in use).

    Returns:
        ``(inputs, valid_frames)``: ``inputs`` is the extractor output. ``valid_frames`` is the
        number of output frames backed by real (non-padded, non-truncated) audio.
    """
    max_length = int(feature_extractor.sampling_rate * max_duration)
    if len(audio_array) > max_length:
        audio_array = audio_array[:max_length]
    # Count frames after truncation so valid_frames never exceeds what the clipped input yields.
    valid_frames = len(audio_array) // frame_hop
    if len(audio_array) < max_length:
        audio_array = np.pad(audio_array, (0, max_length - len(audio_array)))

    inputs = feature_extractor(
        audio_array,
        sampling_rate=feature_extractor.sampling_rate,
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    return inputs, valid_frames


def predict_emotion(audio_array, model, feature_extractor, max_duration=30.0):
    """Return the mean-pooled last hidden state of ``model`` for ``audio_array`` as a 1-D tensor.

    The model and inputs are moved to CUDA when available (otherwise CPU). Pooling covers only
    the ``valid_frames`` reported by preprocess_audio. The returned tensor stays on that device.
    """
    inputs, valid_frames = preprocess_audio(audio_array, feature_extractor, max_duration)

    target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(target)
    batch = {name: tensor.to(target) for name, tensor in inputs.items()}

    with torch.no_grad():
        hidden = model(**batch, output_hidden_states=True).hidden_states[-1]

    return hidden[:, :valid_frames, :].mean(dim=1).flatten()

def emo_sim(refs, hyps, emo2vec, feature_extractor, sr=16000):
    """Mean cosine similarity between emotion embeddings of paired reference/hypothesis clips.

    Each clip is cut with slice_audio and each chunk is embedded with predict_emotion. The chunk
    embeddings are averaged and the clip embeddings are compared by cosine similarity.

    Args:
        refs: reference waveforms.
        hyps: hypothesis waveforms; ``hyps[i]`` is compared with ``refs[i]`` (IndexError if
            ``hyps`` is shorter than ``refs``).
        emo2vec: model forwarded to predict_emotion.
        feature_extractor: feature extractor forwarded to predict_emotion.
        sr: unused; kept for interface compatibility.

    Returns:
        Mean similarity as computed by np.mean (NaN if ``refs`` is empty).
    """
    def _clip_embedding(audio):
        # predict_emotion may return a CUDA tensor; bring chunks to CPU before averaging
        # (np.mean over CUDA tensors fails).
        chunk_embs = [
            predict_emotion(chunk, emo2vec, feature_extractor).detach().float().cpu()
            for chunk in slice_audio(audio)
        ]
        return torch.stack(chunk_embs).mean(dim=0)

    emo2vec_sims = []
    for i, ref in enumerate(refs):
        hyp_emb = _clip_embedding(hyps[i])
        # Previously the reference loop embedded the last *hyp* chunk, so ref audio was never used.
        ref_emb = _clip_embedding(ref)
        simi = F.cosine_similarity(hyp_emb.unsqueeze(0), ref_emb.unsqueeze(0)).item()
        emo2vec_sims.append(float(simi))
    return np.mean(emo2vec_sims)

################################################################
# Speaker-model names; init_model builds an SSL front-end for all but 'ecapa_tdnn' (fbank fallback).
MODEL_LIST = [
    'ecapa_tdnn',
    'hubert_large',
    'wav2vec2_xlsr',
    'unispeech_sat',
    "wavlm_base_plus",
    "wavlm_large",
]

def init_model(model_name, checkpoint=r"./pretrained_models/model_temp/speaker/wavlm_large_finetune.pth"):
    """Build an ECAPA_TDNN_SMALL speaker model for ``model_name`` and optionally load weights.

    Known names map to ``(feat_dim, feat_type, config_path)``. Any other name (including
    'ecapa_tdnn') builds the 40-dim fbank variant. If ``checkpoint`` is not None, its 'model'
    entry is loaded onto CPU storage with ``strict=False``, so missing or unexpected keys are
    ignored silently.
    """
    ssl_specs = {
        'unispeech_sat': (1024, 'unispeech_sat', 'config/unispeech_sat.th'),
        'wavlm_base_plus': (768, 'wavlm_base_plus', None),
        'wavlm_large': (1024, 'wavlm_large', None),
        'hubert_large': (1024, 'hubert_large_ll60k', None),
        'wav2vec2_xlsr': (1024, 'wav2vec2_xlsr', None),
    }
    spec = ssl_specs.get(model_name)
    if spec is None:
        model = ECAPA_TDNN_SMALL(feat_dim=40, feat_type='fbank')
    else:
        feat_dim, feat_type, config_path = spec
        model = ECAPA_TDNN_SMALL(feat_dim=feat_dim, feat_type=feat_type, config_path=config_path)

    if checkpoint is None:
        return model
    state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict['model'], strict=False)
    return model

def slice_audio(audio, dim=0):
    """Split long audio into near-equal chunks for embedding.

    Audio with at most 8 * 16000 = 128000 samples (by ``len``) is returned as ``[audio]``. Longer
    audio is split with np.array_split along axis ``dim`` into ``len(audio) // 128000 + 1``
    pieces.
    """
    max_samples = 8 * 16000
    if len(audio) <= max_samples:
        return [audio]
    n_pieces = len(audio) // max_samples + 1
    return np.array_split(audio, n_pieces, axis=dim)

def cos_simi_resemb(model, generated_wav, tgt_wav, sr=16000):
    """Cosine similarity of Resemblyzer speaker embeddings for two waveforms.

    Each waveform goes through ``preprocess_wav(wav, sr)`` and is cut with slice_audio. Every
    chunk is embedded with ``model.embed_utterance``, and the chunk embeddings are averaged
    before comparison.
    """
    def _clip_embedding(wav):
        processed = preprocess_wav(wav, sr)
        chunk_embeds = [model.embed_utterance(chunk) for chunk in slice_audio(processed)]
        return np.mean(chunk_embeds, axis=0)

    with torch.no_grad():
        gen_embed = _clip_embedding(generated_wav)
        tgt_embed = _clip_embedding(tgt_wav)
        similarity = F.cosine_similarity(
            torch.as_tensor(gen_embed).unsqueeze(0),
            torch.as_tensor(tgt_embed).unsqueeze(0),
        ).item()
    return float(similarity)

def cos_simi_wavlm_ecapa(model, generated_wav, tgt_wav, sr=16000, device="cuda"):
    """Cosine similarity of WavLM-ECAPA speaker embeddings for two waveforms.

    Each waveform is cut with slice_audio. Each chunk is converted to a float (1, T) tensor,
    resampled to 16 kHz when ``sr`` differs, moved to ``device`` and embedded with ``model``.
    The chunk embeddings are averaged before comparison.

    Args:
        model: callable mapping a (1, T) waveform tensor to an embedding tensor.
        generated_wav: numpy waveform at ``sr``.
        tgt_wav: numpy waveform at ``sr``.
        sr: input sample rate.
        device: device for the forward pass (default "cuda", as before).

    Returns:
        Cosine similarity as a Python float.
    """
    # Build the resampler once rather than once per chunk; None when no resampling is needed.
    resampler = Resample(orig_freq=sr, new_freq=16000) if sr != 16000 else None

    def _mean_embedding(wav):
        chunk_embs = []
        for chunk in slice_audio(wav):
            x = torch.from_numpy(chunk).unsqueeze(0).float()
            if resampler is not None:
                x = resampler(x)
            chunk_embs.append(model(x.to(device)))
        return torch.mean(torch.stack(chunk_embs), dim=0)

    with torch.no_grad():
        emb_gen = _mean_embedding(generated_wav)
        emb_tgt = _mean_embedding(tgt_wav)
        sim = F.cosine_similarity(emb_gen, emb_tgt).detach().cpu().item()
    return float(sim)

def spk_sim(refs, hyps, resemb_model, wavlm_model):
    """Average speaker similarity of ``hyps[i]`` against ``refs[i]`` under two encoders.

    Returns:
        ``(mean Resemblyzer cosine, mean WavLM-ECAPA cosine)``.
    """
    resem_scores, wavlm_scores = [], []
    for idx, ref in enumerate(refs):
        hyp = hyps[idx]
        resem_scores.append(cos_simi_resemb(resemb_model, hyp, ref))
        wavlm_scores.append(cos_simi_wavlm_ecapa(wavlm_model, hyp, ref))
    return np.mean(resem_scores), np.mean(wavlm_scores)

##############################################################
def _get_conv_layer(
    in_channels: int,
    out_channels: int,
    kernel_size: [int, int] = (3, 3),
    padding: [int, int] = (1, 1),
    activation_fn = nn.ReLU,
    max_pool_size = 3,
    dropout = 0.3,
    bn: bool = False,
):
    """Returns a CBAD layer: Convolution, Batch normalization, Activation, and Dropout."""
    layers = [nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        padding=padding
    )]
    if bn:
        layers.append(nn.BatchNorm2d(out_channels))
    layers.append(activation_fn())
    if max_pool_size is not None:
        layers.append(nn.MaxPool2d(max_pool_size))
    if dropout is not None:
        layers.append(nn.Dropout(dropout))
    return nn.Sequential(*layers)


@gin.configurable
def _dense_layer(
    in_dim: int,
    out_dim: int,
    use_ln: bool,
    use_activation: bool,
    activation_fn: Any = nn.ReLU,
) -> nn.Sequential:
    """Build Linear(in_dim, out_dim), optionally followed by LayerNorm and then an activation."""
    modules = [nn.Linear(in_dim, out_dim)]
    modules += [nn.LayerNorm(out_dim)] if use_ln else []
    modules += [activation_fn()] if use_activation else []
    return nn.Sequential(*modules)


@gin.configurable
class Encoder(nn.Module):
    """Four CBAD conv stages (1->32->32->64->64 channels) followed by global max pooling.

    Only the second stage pools (``max_pool_size``), and the last stage has no dropout.
    """

    def __init__(self, bn: bool = True, max_pool_size: int = 3, activation_fn: Any = nn.ReLU):
        super(Encoder, self).__init__()
        shared = dict(bn=bn, activation_fn=activation_fn)
        self.encoder = nn.Sequential(
            _get_conv_layer(1, 32, max_pool_size=None, **shared),
            _get_conv_layer(32, 32, max_pool_size=max_pool_size, **shared),
            _get_conv_layer(32, 64, max_pool_size=None, **shared),
            _get_conv_layer(64, 64, max_pool_size=None, dropout=None, **shared),
        )
        self._flatten = nn.Flatten()

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Return (batch, 64) embeddings; input presumably (batch, 1, time, n_features)."""
        features = self.encoder(spec)
        # Max over every remaining spatial position -> (batch, 64, 1, 1).
        pooled = F.max_pool2d(features, kernel_size=features.shape[2:])
        return self._flatten(pooled)


@gin.configurable
class Head(nn.Module):
    """MLP head: two hidden 64-unit dense layers with activation, then a 2-unit linear output."""

    def __init__(self, use_ln: bool = False, activation_fn: Any = nn.ReLU, in_dim: int = 64):
        super(Head, self).__init__()
        hidden = [
            _dense_layer(width_in, 64, use_ln, True, activation_fn)
            for width_in in (in_dim, 64)
        ]
        self.head = nn.Sequential(*hidden, _dense_layer(64, 2, False, False))

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Map embeddings to two raw outputs per row."""
        return self.head(embeddings)


@gin.configurable
class DnsmosPro(nn.Module):
    """DNSMOS Pro: encoder + head producing a predicted MOS mean and variance per input.

    The head's raw outputs ``x0, x1`` are mapped to ``mean = 2*x0 + 3`` and
    ``var = 4*softplus(x1)``.
    """

    def __init__(self, encoder_cls: Type[nn.Module] = Encoder, head_cls: Type[nn.Module] = Head):
        super(DnsmosPro, self).__init__()
        self._encoder = encoder_cls()
        self._head = head_cls()
        self._softplus = nn.Softplus()

    def encoder(self, speech_spectrum: torch.Tensor) -> torch.Tensor:
        """Return the encoder embedding only."""
        return self._encoder(speech_spectrum)

    def forward(self, speech_spectrum: torch.Tensor) -> torch.Tensor:
        """Return a (batch, 2) tensor whose columns are [mean, variance]."""
        raw = self._head(self._encoder(speech_spectrum))
        mean = 2 * raw[:, 0:1] + 3
        variance = 4 * self._softplus(raw[:, 1:2])
        return torch.cat((mean, variance), dim=1)


@gin.configurable
class DnsmosEncoder(nn.Module):
    """Convolutional encoder of the classic DNSMOS model.

    Three Conv(3x3)-ReLU-MaxPool(2)-Dropout(0.3) stages and a final Conv(3x3, 64)-ReLU, followed
    by global max pooling over the last two axes.
    """

    def __init__(self):
        super(DnsmosEncoder, self).__init__()
        layers = []
        for in_ch, out_ch in ((1, 32), (32, 32), (32, 32)):
            layers += [
                nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=(3, 3), padding=(1, 1)),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Dropout(0.3),
            ]
        layers += [
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), padding=(1, 1)),
            nn.ReLU(),
        ]
        # Same module order and Sequential indices (0..13) as before, so state_dict keys match.
        self.encoder = nn.Sequential(*layers)

    def forward(self, speech_spectrum):
        """Return (batch, 64) embeddings; input presumably (batch, 1, time, n_features)."""
        batch = speech_spectrum.shape[0]
        # Spatial dims are divided by 8 (three MaxPool2d(2)) -> (batch, 64, time/8, n_features/8).
        features = self.encoder(speech_spectrum)
        embeddings = F.max_pool2d(features, kernel_size=features.size()[2:])  # (batch, 64, 1, 1)
        return embeddings.view(batch, -1)  # (batch, 64)


@gin.configurable
class DnsmosHead(nn.Module):
    """Classic DNSMOS regression head: 64 -> 64 -> 64 -> 1 MLP with ReLU between layers."""

    def __init__(self):
        super(DnsmosHead, self).__init__()
        layers = [
            nn.Linear(64, 64), nn.ReLU(),
            nn.Linear(64, 64), nn.ReLU(),
            nn.Linear(64, 1),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, embeddings):
        """Map (batch, 64) embeddings to (batch, 1) predictions."""
        return self.head(embeddings)


@gin.configurable
class DnsmosClassic(nn.Module):
    """Classic DNSMOS: DnsmosEncoder followed by DnsmosHead."""

    def __init__(self):
        super(DnsmosClassic, self).__init__()
        self.encoder = DnsmosEncoder()
        self.head = DnsmosHead()

    def forward(self, speech_spectrum: torch.Tensor) -> torch.Tensor:
        """Return one prediction per input spectrogram."""
        return self.head(self.encoder(speech_spectrum))


@gin.configurable
class Mosnet(nn.Module):
    """MOSNet-style CNN-BLSTM frame-level MOS predictor.

    Twelve 3x3 conv layers come in four channel widths (16, 32, 64, 128). Each width ends with a
    conv of stride (1, 3) that shrinks the last axis (257 -> 86 -> 29 -> 10 -> 4). They feed a
    single-layer bidirectional LSTM and a per-frame MLP.
    """

    def __init__(self):
        super(Mosnet, self).__init__()
        conv_layers = []
        in_ch = 1
        for out_ch in (16, 32, 64, 128):
            # Two stride-1 convs, then one that downsamples the frequency axis by 3.
            for stride in ((1, 1), (1, 1), (1, 3)):
                conv_layers += [
                    nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=(3, 3),
                              padding=(1, 1), stride=stride),
                    nn.ReLU(),
                ]
                in_ch = out_ch
        # Same module order/indices as the original explicit list, so state_dict keys match.
        self.mean_conv = nn.Sequential(*conv_layers)

        # dropout omitted: nn.LSTM applies it only between stacked layers, so with num_layers=1
        # it was a no-op that only emitted a UserWarning.
        self.mean_rnn = nn.LSTM(
            input_size=512,
            hidden_size=128,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )

        self.mean_MLP = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1),
        )

    def encoder(self, speech_spectrum: torch.Tensor) -> torch.Tensor:
        """Return a (batch, 128) embedding: global max-pool of the conv features."""
        batch = speech_spectrum.shape[0]
        features = self.mean_conv(speech_spectrum)  # (batch, 128, time, freq')
        embeddings = F.max_pool2d(features, kernel_size=features.size()[2:])
        return embeddings.view(batch, -1)  # (batch, 128)

    def forward(self, speech_spectrum: torch.Tensor) -> torch.Tensor:
        """Return per-frame MOS predictions of shape (batch, time, 1).

        Input is expected as (batch, 1, time, 257) so the conv output is (batch, 128, time, 4).
        """
        batch = speech_spectrum.shape[0]
        n_frames = speech_spectrum.shape[2]
        features = self.mean_conv(speech_spectrum)  # (batch, 128, n_frames, 4)
        # NOTE(review): .view on this channels-first tensor does not gather one frame's 128x4
        # features into one row (that would need .permute(0, 2, 1, 3) first). Kept unchanged
        # because trained checkpoints may depend on it — confirm before changing.
        features = features.view((batch, n_frames, 512))
        features, _ = self.mean_rnn(features)  # (batch, n_frames, 256)
        return self.mean_MLP(features)  # (batch, n_frames, 1)


@gin.configurable
def stft(
    samples: np.ndarray,
    win_length: int = 320,
    hop_length: int = 160,
    n_fft: int = 320,
    use_log: bool = True,
    use_magnitude: bool = True,
    n_mels: Optional[int] = None,
) -> np.ndarray:
    """Compute a time-major (frames x bins) spectrogram of ``samples`` with librosa.

    Uses a linear STFT, or a mel spectrogram when ``n_mels`` is given. Optionally takes the
    magnitude, then log10 after clipping to [1e-7, 1e7].

    Raises:
        ValueError: if ``use_log`` is requested without ``use_magnitude``.
    """
    if use_log and not use_magnitude:
        raise ValueError('Log is only available if the magnitude is to be computed.')
    frame_kwargs = dict(y=samples, win_length=win_length, hop_length=hop_length, n_fft=n_fft)
    if n_mels is not None:
        spec = librosa.feature.melspectrogram(n_mels=n_mels, **frame_kwargs)
    else:
        spec = librosa.stft(**frame_kwargs)
    # librosa returns (bins, frames); transpose to (frames, bins).
    spec = spec.T
    if use_magnitude:
        spec = np.abs(spec)
    if use_log:
        spec = np.log10(np.clip(spec, 10 ** (-7), 10 ** 7))
    return spec

def pad_to_fit_window(samples, frame_length=32000, hop_length=16000):
    """Mirror-pad ``samples`` so a sliding window tiles it exactly.

    After padding, ``len(out) >= frame_length`` and
    ``(len(out) - frame_length) % hop_length == 0``. Windows starting at
    ``0, hop_length, 2*hop_length, ...`` therefore cover every input sample with no partial
    window at the end.

    Padding mirrors the tail (numpy "symmetric" mode: the last sample is repeated first). This
    equals the previous ``samples[-pad_len:][::-1]`` when the pad is no longer than the input,
    and keeps reflecting when it is longer.

    Changes from the previous version: inputs that already fit are returned unchanged (before, an
    extra mirrored frame was appended). Inputs shorter than about ``hop_length`` now reach a full
    frame (before, they stayed too short and yielded zero frames). Empty input becomes
    ``frame_length`` zeros.
    """
    samples = np.asarray(samples)
    total_len = len(samples)
    if total_len == 0:
        # Nothing to mirror; np.pad cannot reflect an empty axis.
        return np.zeros(frame_length, dtype=samples.dtype)
    if total_len <= frame_length:
        target_len = frame_length
    else:
        # Hops needed after the first window, rounded up so the tail is covered.
        extra_hops = -(-(total_len - frame_length) // hop_length)
        target_len = frame_length + extra_hops * hop_length
    pad_len = target_len - total_len
    if pad_len == 0:
        return samples
    return np.pad(samples, (0, pad_len), mode="symmetric")

def dnsmos_pro(dnsmos_pro_model, samples, frame_length=32000, hop_length=16000):
    """Score overlapping windows of ``samples`` with a DNSMOS Pro model.

    ``samples`` is padded with pad_to_fit_window, cut into ``frame_length`` windows every
    ``hop_length`` samples, and each window is turned into a spectrogram with stft. The batch is
    scored in one call, using column 0 of the model output as the MOS.

    Returns:
        ``(per-window MOS values joined by spaces, variance of those values)``. The first item is
        a string of per-window scores, not their mean.
    """
    padded = pad_to_fit_window(samples, frame_length=frame_length, hop_length=hop_length)
    n_windows = (len(padded) - frame_length) // hop_length + 1
    frames = np.stack([
        padded[offset:offset + frame_length]
        for offset in (k * hop_length for k in range(n_windows))
    ])
    specs = torch.stack([torch.FloatTensor(stft(frame)) for frame in frames]).unsqueeze(1)
    predictions = dnsmos_pro_model(specs)
    mos = predictions[:, 0].detach().numpy()
    return " ".join(mos.astype(str)), np.var(mos)

##############################################################
def save_audio_4_autopcp(hyps, save_home, line_i, sample_rate=16000):
    """Save each clip in ``hyps`` as ``{save_home}/{line_i}_{idx}.wav`` and return the paths in order."""
    paths = []
    for idx, hyp in enumerate(hyps):
        path = os.path.join(save_home, f"{line_i}_{idx}.wav")
        save_audio(path, hyp.flatten(), sample_rate=sample_rate)
        paths.append(path)
    return paths

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Inference')
    parser.add_argument('--tsv', type=str)
    parser.add_argument('--out_home', type=str)
    parser.add_argument('--ref_home', type=str)
    args = parser.parse_args()

    # ---- Load every scoring model once. ----
    asr_model = AutoModel(model="paraformer-zh", disable_update=True)

    resemb_model = VoiceEncoder(device=torch.device("cuda"))
    resemb_model.eval()

    wavlm_ecapa = init_model("wavlm_large", os.path.join(os.environ.get("wavlm_home", "./pretrained_models/model_temp/speaker/"), r"wavlm_large_finetune.pth")).cuda()
    wavlm_ecapa.eval()

    emo2vec_model = AutoModel(model=os.environ.get("emo2vec_home", "./pretrained_models/modelscope/iic/emotion2vec_plus_large/"), disable_update=True)

    dns_model = torch.jit.load(os.path.join(os.environ.get("dns_home", "./pretrained_models/DNSMOSPro/runs/NISQA/"), 'model_best.pt'), map_location=torch.device('cpu'))

    url = 'https://zenodo.org/record/6221127/files/w2v2-L-robust-12.6bc4a7fd-1.1.0.zip'
    audeer_home = os.environ.get("audeer_home", "./pretrained_models/model_temp/")
    cache_root = audeer.mkdir(audeer_home + '/cache')
    model_root = audeer.mkdir(audeer_home + '/model')
    archive_path = audeer.download_url(url, cache_root, verbose=True)
    audeer.extract_archive(archive_path, model_root)
    audonnx_model = audonnx.load(model_root)

    with open(args.tsv, "r", encoding="utf-8") as rf:
        infos = json.load(rf)

    src_audio_paths = []
    tgt_audio_paths = []
    flags = []

    info_dict = {
        "cer": [],
        "emo2vec_score": [],
        "audonnx_score": [],
        "resemb_score": [],
        "wavlm_score": [],
        "dnsmos_mean": [],
        "dnsmos_var": [],
        "generated_path": [],
        "keys": [],
        "line_i": [],
    }

    for key in tqdm(infos.keys()):
        with torch.no_grad():
            utt_items = infos[key].get("script", [])
            for line_i, line_item in enumerate(utt_items):
                generated_path = os.path.join(args.out_home, "temp_signal", f"{key}_{line_i}.wav")
                # Reset per item so the fallback below never reuses the previous item's lines
                # (or hits NameError on the first item).
                lines = []
                try:
                    start = time.time()
                    lines = line_item["text"]
                    spk_ref = os.path.join(args.ref_home, lines[0]["target_speaker"])
                    emotion_refs = [
                        lib.load(os.path.join(args.ref_home, line["prompt_speech"]), sr=16000)[0]
                        for line in lines
                    ]
                    text_ref = [line["lines"] for line in lines]
                    print(f"READ: {time.time()-start}")

                    start = time.time()
                    # wer_and_align's 4th parameter (extra ASR kwargs) was missing, so this call
                    # raised TypeError for every item.
                    segment_infos, segments_16k, cer, generated_wav = wer_and_align(
                        text_ref, generated_path, asr_model, {}
                    )
                    print(f"ASR: {time.time()-start}")

                    start = time.time()
                    # NOTE(review): emo_sim(refs, hyps, emo2vec, feature_extractor) runs
                    # predict_emotion (HF-style model + feature extractor) and returns ONE value,
                    # but it receives a funasr AutoModel and an audonnx model here and two values
                    # are unpacked. Isolated so its failure no longer wipes CER/SPK/DNS — confirm
                    # the intended emotion scorer.
                    try:
                        emo2vec_score, audonnx_score = emo_sim(
                            emotion_refs, segments_16k, emo2vec_model, audonnx_model
                        )
                    except Exception as emo_err:
                        print(f"[{key}_{line_i}] EMO failed: {emo_err!r}")
                        emo2vec_score = -1
                        audonnx_score = -1
                    print(f"EMO: {time.time()-start}")

                    start = time.time()
                    resemb_score, wavlm_score = spk_sim(
                        [lib.load(spk_ref, sr=16000)[0]] * len(emotion_refs),
                        segments_16k, resemb_model, wavlm_ecapa
                    )
                    print(f"SPK: {time.time()-start}")

                    start = time.time()
                    dnsmos_mean, dnsmos_var = dnsmos_pro(dns_model, generated_wav)
                    print(f"DNS: {time.time()-start}")
                except Exception as e:
                    print(f"[{key}_{line_i}] {e!r}")
                    cer = 100.0
                    emo2vec_score = -1
                    audonnx_score = -1
                    resemb_score = -1
                    wavlm_score = -1
                    dnsmos_mean = -1
                    dnsmos_var = -1
                    segments_16k = [np.zeros(16000) for _ in range(len(lines))]

                info_dict["cer"].append(cer)
                info_dict["emo2vec_score"].append(emo2vec_score)
                info_dict["audonnx_score"].append(audonnx_score)
                info_dict["resemb_score"].append(resemb_score)
                info_dict["wavlm_score"].append(wavlm_score)
                info_dict["generated_path"].append(generated_path)
                info_dict["dnsmos_mean"].append(dnsmos_mean)
                info_dict["dnsmos_var"].append(dnsmos_var)
                info_dict["keys"].append(key)
                info_dict["line_i"].append(line_i)
                print(f"INFO: {time.time()-start}")
                start = time.time()

                tgt_audio_path = save_audio_4_autopcp(segments_16k, os.path.join(
                    args.out_home, "wav_segs", f"{key}"
                ), f"{line_i}")
                src_audio_path = [
                    os.path.join(args.ref_home, line["prompt_speech"])
                    for line in lines
                ]
                # NOTE(review): wer_and_align drops refs it cannot align, so these lists can differ
                # in length. Truncating per item keeps later items aligned, but an unmatched ref in
                # the middle still shifts pairs within this item.
                n_pairs = min(len(src_audio_path), len(tgt_audio_path))
                src_audio_paths += src_audio_path[:n_pairs]
                tgt_audio_paths += tgt_audio_path[:n_pairs]
                # One flag per written row. Previously there was one flag per item, and zipping
                # with the pair lists truncated eval4autopcp.tsv.
                flags += [f"{key}_{line_i}"] * n_pairs
                print(f"WRITE: {time.time()-start}")
                start = time.time()

    df = pd.DataFrame(info_dict)
    df.to_csv(os.path.join(args.out_home, "metrics_1st.csv"), index=False, encoding='utf-8')
    with open(os.path.join(args.out_home, "metrics_1st.log"), "w", encoding="utf-8") as log_file:
        with pd.option_context('display.max_columns', None):
            print(df.describe(include='all').to_string(), file=log_file)

    output_tsv_path = os.path.join(args.out_home, "eval4autopcp.tsv")
    output_flags_path = os.path.join(args.out_home, "eval4autopcp_flag.tsv")
    with open(output_tsv_path, mode="w", newline="", encoding="utf-8") as f1:
        with open(output_flags_path, mode="w", newline="", encoding="utf-8") as f2:
            writer = csv.writer(f1, delimiter="\t")
            writer.writerow(["src_audio", "tgt_audio"])
            for src, tgt, flag in zip(src_audio_paths, tgt_audio_paths, flags):
                writer.writerow([src, tgt])
                print(flag, file=f2)