from typing import Any, Optional, Union, Type
import copy
import math
import torch
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
import torch.nn as nn
import gin
import torch.nn.functional as F
import argparse
from tqdm import tqdm
from pathlib import Path
import numpy as np
import os
import torch.nn.functional as F
import soundfile as sf
from torchaudio.transforms import Resample
from funasr import AutoModel
from jiwer import wer
import whisperx
import torchaudio
import json
import regex as re
import time
import csv
import librosa as lib
from resemblyzer import VoiceEncoder, preprocess_wav
from torchaudio.transforms import Resample
from models.ecapa_tdnn import ECAPA_TDNN_SMALL
import audeer
import audonnx
import pandas as pd
import librosa
import numpy as np
import numpy.polynomial.polynomial as poly
import onnxruntime as ort
import soundfile as sf
from requests import session
from tqdm import tqdm
import string
# Print NumPy arrays in full instead of summarising long ones with "...".
np.set_printoptions(threshold=np.inf)
# Show every DataFrame column when printing, but cap the displayed rows at 10.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 10)
# NOTE(review): neither constant is referenced in the visible part of this file;
# presumably a sample rate in Hz and a window length — confirm before relying on them.
SAMPLING_RATE = 16000
INPUT_LENGTH = 1

def wer_cal(ref, hyp):
    """Return the word error rate of ``hyp`` measured against ``ref``.

    Both strings are lower-cased and run through ``normalize_text`` with
    punctuation removal before being scored by ``jiwer.wer``.
    """
    cleaned_ref, cleaned_hyp = (
        normalize_text(raw.lower(), remove_punctuation=True)
        for raw in (ref, hyp)
    )
    return wer(cleaned_ref, cleaned_hyp)

def normalize_text(text, remove_punctuation=False):
    """Clean ``text`` for scoring.

    Steps, in order:
      1. drop control characters and the non-breaking space;
      2. map full-width / curly punctuation to its ASCII counterpart;
      3. if ``remove_punctuation`` is true, delete all ASCII punctuation and
         any remaining CJK punctuation.

    Returns the cleaned string.
    """
    full_to_ascii = {
        '，': ',', '。': '.', '；': ';', '：': ':', '？': '?', '！': '!',
        '（': '(', '）': ')', '【': '[', '】': ']', '“': '"', '”': '"', '‘': "'", '’': "'"
    }

    cleaned = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x9f\xa0]', '', text)
    cleaned = cleaned.translate(str.maketrans(full_to_ascii))

    if not remove_punctuation:
        return cleaned

    without_ascii_punct = cleaned.translate(str.maketrans('', '', string.punctuation))
    # Catches CJK marks that have no ASCII mapping above (e.g. book-title marks).
    return re.sub(r"[，。！？；：“”‘’（）【】《》、]", "", without_ascii_punct)

def save_audio(out_path, speech, sample_rate):
    """Write ``speech`` to ``out_path`` via ``soundfile``, creating parent folders.

    Args:
        out_path: Destination file path.
        speech: Audio samples, passed unchanged to ``sf.write``.
        sample_rate: Sample rate passed to ``sf.write``.
    """
    parent_dir = os.path.dirname(out_path)
    # A bare filename has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError, so only create the directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    sf.write(out_path, speech, sample_rate)

def remove_punctuation_and_whitespace(text):
    """Return ``text`` with all Unicode punctuation and whitespace removed.

    The ``\\p{P}`` class needs the third-party ``regex`` module, which this
    file imports under the name ``re``.
    """
    punct_or_space = r'[\p{P}\s]+'
    return re.sub(punct_or_space, '', text, flags=re.UNICODE)

def load_wav(wav, target_sr):
    """Load an audio file as a mono tensor at ``target_sr``.

    Channels are averaged into one (the channel dimension is kept). If the
    file's rate differs from ``target_sr`` it must be higher; the audio is
    then resampled with ``torchaudio.transforms.Resample``.

    Raises:
        AssertionError: If the file's sample rate is lower than ``target_sr``.
    """
    waveform, source_sr = torchaudio.load(wav)
    mono = waveform.mean(dim=0, keepdim=True)
    if source_sr == target_sr:
        return mono
    assert source_sr > target_sr, 'wav sample rate {} must be greater than {}'.format(source_sr, target_sr)
    resampler = torchaudio.transforms.Resample(orig_freq=source_sr, new_freq=target_sr)
    return resampler(mono)

# wer(ref, hyp)
def cer_cal(ref, hyp):
    """Return the character error rate (CER) of ``hyp`` against ``ref``.

    Punctuation and whitespace are stripped from both strings first. The CER
    is the Levenshtein distance between the remaining character sequences
    divided by the number of reference characters.

    Args:
        ref: Reference text.
        hyp: Hypothesis text.

    Returns:
        The CER as a float. When the reference has no characters left after
        stripping, returns 0.0 if the hypothesis is empty too and 1.0
        otherwise (instead of dividing by zero).
    """
    ref_chars = remove_punctuation_and_whitespace(ref)
    hyp_chars = remove_punctuation_and_whitespace(hyp)

    if not ref_chars:
        return 0.0 if not hyp_chars else 1.0

    # Rolling-row Levenshtein: ``previous`` is the row for the preceding
    # reference character, ``current`` the row being filled in.
    previous = list(range(len(hyp_chars) + 1))
    for i, ref_char in enumerate(ref_chars, start=1):
        current = [i]
        for j, hyp_char in enumerate(hyp_chars, start=1):
            if ref_char == hyp_char:
                current.append(previous[j - 1])
            else:
                current.append(1 + min(
                    previous[j - 1],  # substitution
                    current[j - 1],   # insertion
                    previous[j],      # deletion
                ))
        previous = current

    return previous[-1] / len(ref_chars)

def wer_and_align(refs, inp_path, 
                  asr_model, align_model, metadata, batch_size=1, device=None):
    """Transcribe an audio file and cut out the span matching each reference line.

    The file is loaded as 16 kHz mono, transcribed with ``asr_model`` and
    word-aligned with ``whisperx.align``. For every reference in ``refs`` (in
    order) a greedy search picks the contiguous window of recognised words
    with the lowest ``wer_cal`` score among windows that do not overlap words
    already assigned to an earlier reference.

    Args:
        refs: Reference text lines; each is split on whitespace.
        inp_path: Path of the audio file to evaluate.
        asr_model: Object exposing ``transcribe(audio, batch_size=...)``.
        align_model: Alignment model passed to ``whisperx.align``.
        metadata: Alignment metadata passed to ``whisperx.align``.
        batch_size: Batch size for ``asr_model.transcribe``.
        device: Device passed to ``whisperx.align``.

    Returns:
        A tuple ``(results, segments, cer_score, inp_wav_np)``:
        ``results`` holds one dict per matched reference (``ref``, ``match``,
        ``start_time``, ``end_time``, ``wer_score``); ``segments`` holds the
        matching waveform slices; ``cer_score`` is ``wer_cal`` of all
        references joined against the full transcript (a word-level score
        despite the name); ``inp_wav_np`` is the flattened 16 kHz waveform.

    NOTE(review): a reference with no free window left is skipped, so
    ``results``/``segments`` can be shorter than ``refs`` — callers that index
    them by reference position should confirm this cannot happen.
    """
    inp_wav = load_wav(inp_path, target_sr=16000)[0].numpy()
    result = asr_model.transcribe(inp_wav, batch_size=batch_size)
    result = whisperx.align(result["segments"], align_model, metadata, inp_wav, device, return_char_alignments=False)["segments"]

    # Flatten the per-segment word lists into one sequence of word dicts.
    result_words = []
    for seg in result:
        result_words.extend(seg["words"])

    words = [w["word"] for w in result_words]
    # used_range[i] is True once word i has been assigned to some reference.
    used_range = [False] * len(words)
    results = []
    segments = []
    inp_wav_np = inp_wav.flatten()

    for ref in refs:
        ref_words = ref.strip().split()
        ref_len = len(ref_words)

        best_score = float("inf")
        best_start = -1
        best_end = -1
        best_text = ""

        # Exhaustive search over windows [start, end) of at most
        # ref_len + 9 words.
        for start in range(len(words)):
            for end in range(start + 1, min(len(words) + 1, start + ref_len + 10)):

                # Skip windows overlapping words already claimed by an earlier reference.
                if any(used_range[start:end]):
                    continue

                hyp_words = [result_words[i]["word"] for i in range(start, end)]
                score = wer_cal(" ".join(ref_words), " ".join(hyp_words))

                # Strict "<" keeps the first window found when scores tie.
                if score < best_score:
                    best_score = score
                    best_start = start
                    best_end = end
                    best_text = " ".join(hyp_words)

        if best_start != -1:
            for i in range(best_start, best_end):
                used_range[i] = True

            # NOTE(review): assumes every aligned word dict carries "start"
            # and "end" — verify against the whisperx alignment output.
            start_time = result_words[best_start]["start"]
            end_time = result_words[best_end - 1]["end"]

            results.append({
                "ref": ref,
                "match": best_text,
                "start_time": start_time,
                "end_time": end_time,
                "wer_score": best_score
            })

            # Timestamps are converted to sample indices at the 16 kHz load rate.
            segments.append(
                inp_wav_np[int(start_time * 16000): int(end_time * 16000)]
            )

    # Overall score: all references joined vs. the complete transcript.
    est_txt = " ".join(words)
    cer_score = wer_cal(" ".join(refs), est_txt)

    return results, segments, cer_score, inp_wav_np

#########################################################################
def emo_sim(refs, hyps, emo2vec, audonnx_model, sr=16000):
    """Average emotion-embedding cosine similarity between paired waveforms.

    For each pair ``(refs[i], hyps[i])`` both waveforms are cut into chunks
    with ``slice_audio``, every chunk is embedded, the chunk embeddings are
    averaged, and the cosine similarity of the two averages is taken. This is
    done once with ``emo2vec`` and once with ``audonnx_model``.

    Args:
        refs: Reference waveforms (presumably 1-D arrays at ``sr`` Hz — TODO confirm).
        hyps: Hypothesis waveforms; ``hyps[i]`` is compared with ``refs[i]``.
        emo2vec: Model exposing ``generate(...)`` whose first result has a
            ``"feats"`` embedding.
        audonnx_model: Callable ``(signal, sr)`` returning a mapping with a
            ``"hidden_states"`` entry.
        sr: Sample rate passed to ``audonnx_model``.

    Returns:
        ``(mean emo2vec similarity, mean audonnx similarity)`` over all pairs.
        With an empty ``refs`` both values are NaN (``np.mean`` of an empty list).
    """
    def _emo2vec_embed(chunk):
        # Utterance-level emo2vec embedding of one chunk.
        return emo2vec.generate(chunk, granularity="utterance", extract_embedding=True, disable_pbar=True)[0]["feats"]

    def _audonnx_embed(chunk):
        # audonnx hidden-state embedding of one chunk.
        return audonnx_model(chunk.flatten(), sr)["hidden_states"][0]

    def _mean_embedding(audio, embed):
        # Embed every chunk of ``audio`` and average the embeddings.
        return np.mean([embed(chunk) for chunk in slice_audio(audio)], axis=0)

    def _cosine(vec_a, vec_b):
        # Cosine similarity of two embeddings, each given a leading batch axis.
        a = torch.as_tensor(np.asarray(vec_a), dtype=torch.float32).unsqueeze(0)
        b = torch.as_tensor(np.asarray(vec_b), dtype=torch.float32).unsqueeze(0)
        return float(F.cosine_similarity(a, b).item())

    emo2vec_sims = []
    audonnx_sims = []
    for i, ref in enumerate(refs):
        hyp = hyps[i]
        emo2vec_sims.append(_cosine(
            _mean_embedding(hyp, _emo2vec_embed),
            _mean_embedding(ref, _emo2vec_embed),
        ))
        # The reference is embedded chunk by chunk here as well; the previous
        # code fed the whole reference to the model once per chunk.
        audonnx_sims.append(_cosine(
            _mean_embedding(hyp, _audonnx_embed),
            _mean_embedding(ref, _audonnx_embed),
        ))
    return np.mean(emo2vec_sims), np.mean(audonnx_sims)

################################################################
MODEL_LIST = ['ecapa_tdnn', 'hubert_large', 'wav2vec2_xlsr', 'unispeech_sat', "wavlm_base_plus", "wavlm_large"]

def init_model(model_name, checkpoint=r"./pretrained_models/model_temp/speaker/wavlm_large_finetune.pth"):
    """Build an ``ECAPA_TDNN_SMALL`` speaker model and optionally load weights.

    Args:
        model_name: Front-end selector. Names without a dedicated entry
            (including ``'ecapa_tdnn'``) fall back to the 40-dim ``fbank``
            configuration.
        checkpoint: Path of a checkpoint whose ``'model'`` entry holds the
            state dict, or ``None`` to keep the freshly built weights.

    Returns:
        The constructed model; checkpoint weights are loaded with ``strict=False``.
    """
    frontends = {
        'unispeech_sat': dict(feat_dim=1024, feat_type='unispeech_sat', config_path='config/unispeech_sat.th'),
        'wavlm_base_plus': dict(feat_dim=768, feat_type='wavlm_base_plus', config_path=None),
        'wavlm_large': dict(feat_dim=1024, feat_type='wavlm_large', config_path=None),
        'hubert_large': dict(feat_dim=1024, feat_type='hubert_large_ll60k', config_path=None),
        'wav2vec2_xlsr': dict(feat_dim=1024, feat_type='wav2vec2_xlsr', config_path=None),
    }
    fallback = dict(feat_dim=40, feat_type='fbank')
    model = ECAPA_TDNN_SMALL(**frontends.get(model_name, fallback))

    if checkpoint is None:
        return model

    # Load tensors onto the storage they were deserialised to (no device remap).
    state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict['model'], strict=False)
    return model

def slice_audio(audio, dim=0, max_len=8 * 16000):
    """Split ``audio`` into near-equal chunks of at most ``max_len`` samples.

    Args:
        audio: Array-like audio.
        dim: Axis that holds the samples; both the length check and the
            split use this axis.
        max_len: Largest length returned as a single chunk. The default,
            ``8 * 16000``, is 8 seconds at 16 kHz.

    Returns:
        ``[audio]`` when the length along ``dim`` does not exceed ``max_len``;
        otherwise the list produced by ``np.array_split`` with
        ``length // max_len + 1`` parts.
    """
    length = np.shape(audio)[dim]
    if length <= max_len:
        return [audio]
    return np.array_split(audio, length // max_len + 1, axis=dim)

def cos_simi_resemb(model, generated_wav, tgt_wav, sr=16000):
    """Cosine similarity of two waveforms using ``model.embed_utterance``.

    Both waveforms go through ``preprocess_wav`` (with ``sr`` as the source
    rate), are chunked with ``slice_audio``, embedded chunk by chunk and
    averaged; the similarity of the two averaged embeddings is returned as a
    Python float.
    """
    with torch.no_grad():
        hyp_audio = preprocess_wav(generated_wav, sr)
        ref_audio = preprocess_wav(tgt_wav, sr)

        hyp_mean = np.mean(
            [model.embed_utterance(chunk) for chunk in slice_audio(hyp_audio)], axis=0
        )
        ref_mean = np.mean(
            [model.embed_utterance(chunk) for chunk in slice_audio(ref_audio)], axis=0
        )

        similarity = F.cosine_similarity(torch.tensor([hyp_mean]), torch.tensor([ref_mean]))
        return float(similarity.item())

def cos_simi_wavlm_ecapa(model, generated_wav, tgt_wav, sr=16000):
    """Cosine similarity of two waveforms using an embedding ``model``.

    Each waveform is chunked with ``slice_audio``; every chunk is turned into
    a ``(1, samples)`` float tensor, resampled to 16 kHz when ``sr`` differs,
    and embedded. Chunk embeddings are averaged per waveform and compared.

    Args:
        model: Callable mapping a ``(1, samples)`` tensor to an embedding.
        generated_wav: NumPy waveform of the generated audio.
        tgt_wav: NumPy waveform of the target audio.
        sr: Sample rate of both inputs.

    Returns:
        The cosine similarity as a Python float.
    """
    # Run on the device the model lives on rather than forcing CUDA; fall
    # back to CUDA-if-available when the model exposes no parameters.
    parameters = getattr(model, "parameters", None)
    first_param = next(parameters(), None) if callable(parameters) else None
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the resampler once; it is identical for every chunk.
    resampler = Resample(orig_freq=sr, new_freq=16000) if sr != 16000 else None

    def _mean_embedding(wav):
        # Average of the per-chunk embeddings of ``wav``.
        chunk_embeddings = []
        for chunk in slice_audio(wav):
            batch = torch.from_numpy(chunk).unsqueeze(0).float()
            if resampler is not None:
                batch = resampler(batch)
            chunk_embeddings.append(model(batch.to(device)))
        return torch.mean(torch.stack(chunk_embeddings), dim=0)

    with torch.no_grad():
        gen_embedding = _mean_embedding(generated_wav)
        tgt_embedding = _mean_embedding(tgt_wav)
        sim = F.cosine_similarity(gen_embedding, tgt_embedding).detach().cpu().item()
        return float(sim)

def spk_sim(refs, hyps, resemb_model, wavlm_model):
    """Mean speaker similarity between paired waveforms under two models.

    ``hyps[i]`` is compared with ``refs[i]`` using ``cos_simi_resemb`` and
    ``cos_simi_wavlm_ecapa``.

    Returns:
        ``(mean resemblyzer similarity, mean WavLM-ECAPA similarity)``.
    """
    resem_scores, wavlm_scores = [], []
    for idx, ref_wav in enumerate(refs):
        hyp_wav = hyps[idx]
        resem_scores.append(cos_simi_resemb(resemb_model, hyp_wav, ref_wav))
        wavlm_scores.append(cos_simi_wavlm_ecapa(wavlm_model, hyp_wav, ref_wav))
    return np.mean(resem_scores), np.mean(wavlm_scores)

##############################################################
def _get_conv_layer(
    in_channels: int,
    out_channels: int,
    kernel_size: [int, int] = (3, 3),
    padding: [int, int] = (1, 1),
    activation_fn = nn.ReLU,
    max_pool_size = 3,
    dropout = 0.3,
    bn: bool = False,
):
    """Build one convolutional stage as an ``nn.Sequential``.

    Module order: Conv2d, optional BatchNorm2d (``bn``), activation, optional
    MaxPool2d (skipped when ``max_pool_size`` is None), optional Dropout
    (skipped when ``dropout`` is None).
    """
    modules = [
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
    ]
    modules += [nn.BatchNorm2d(out_channels)] if bn else []
    modules += [activation_fn()]
    modules += [] if max_pool_size is None else [nn.MaxPool2d(max_pool_size)]
    modules += [] if dropout is None else [nn.Dropout(dropout)]
    return nn.Sequential(*modules)


@gin.configurable
def _dense_layer(
    in_dim: int,
    out_dim: int,
    use_ln: bool,
    use_activation: bool,
    activation_fn: Any = nn.ReLU,
) -> nn.Sequential:
    """Build a Linear block with optional LayerNorm and optional activation.

    Module order: Linear, LayerNorm (when ``use_ln``), activation (when
    ``use_activation``).
    """
    modules = [nn.Linear(in_dim, out_dim)]
    modules.extend([nn.LayerNorm(out_dim)] if use_ln else [])
    modules.extend([activation_fn()] if use_activation else [])
    return nn.Sequential(*modules)


@gin.configurable
class Encoder(nn.Module):
    """Four-stage convolutional encoder with global max pooling.

    Channel progression is 1 -> 32 -> 32 -> 64 -> 64. Only the second stage
    applies max pooling, and the last stage has no dropout.
    """

    def __init__(self, bn: bool = True, max_pool_size: int = 3, activation_fn: Any = nn.ReLU):
        super().__init__()
        shared = dict(bn=bn, activation_fn=activation_fn)
        stages = [
            _get_conv_layer(1, 32, max_pool_size=None, **shared),
            _get_conv_layer(32, 32, max_pool_size=max_pool_size, **shared),
            _get_conv_layer(32, 64, max_pool_size=None, **shared),
            _get_conv_layer(64, 64, max_pool_size=None, dropout=None, **shared),
        ]
        self.encoder = nn.Sequential(*stages)
        self._flatten = nn.Flatten()

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Encode a spectrogram batch shaped (batch, 1, max_seq_len, n_features)."""
        feature_maps = self.encoder(spec)  # 64 channels
        # Global max pooling: the kernel spans everything after the channel axis.
        pooled = F.max_pool2d(feature_maps, kernel_size=feature_maps.shape[2:])
        return self._flatten(pooled)


@gin.configurable
class Head(nn.Module):
    """Three dense layers mapping ``in_dim`` features to two outputs.

    The two hidden layers are 64 wide and use the configured LayerNorm and
    activation settings; the output layer is a plain linear layer.
    """

    def __init__(self, use_ln: bool = False, activation_fn: Any = nn.ReLU, in_dim: int = 64):
        super().__init__()
        hidden_dim = 64
        hidden_1 = _dense_layer(in_dim, hidden_dim, use_ln, True, activation_fn)
        hidden_2 = _dense_layer(hidden_dim, hidden_dim, use_ln, True, activation_fn)
        output = _dense_layer(hidden_dim, 2, False, False)
        self.head = nn.Sequential(hidden_1, hidden_2, output)

    def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Apply the dense stack to ``embeddings``."""
        return self.head(embeddings)


@gin.configurable
class DnsmosPro(nn.Module):
    """Encoder plus head producing a (mean, variance) pair per input."""

    def __init__(self, encoder_cls: Type[nn.Module] = Encoder, head_cls: Type[nn.Module] = Head):
        super().__init__()
        self._encoder = encoder_cls()
        self._head = head_cls()
        self._softplus = nn.Softplus()

    def encoder(self, speech_spectrum: torch.Tensor) -> torch.Tensor:
        """Return only the encoder embeddings for ``speech_spectrum``."""
        return self._encoder(speech_spectrum)

    def forward(self, speech_spectrum: torch.Tensor) -> torch.Tensor:
        """Return a (batch, 2) tensor: column 0 is the mean, column 1 the variance."""
        raw = self._head(self._encoder(speech_spectrum))
        raw_mean = raw[:, 0].unsqueeze(1)
        raw_var = raw[:, 1].unsqueeze(1)
        # Affine map for the mean; softplus keeps the variance positive.
        mean_predictions = 2 * raw_mean + 3
        var_predictions = 4 * self._softplus(raw_var)
        return torch.cat((mean_predictions, var_predictions), dim=1)


@gin.configurable
class DnsmosEncoder(nn.Module):
    """Convolutional encoder: three conv/ReLU/pool/dropout stages, then a final conv/ReLU.

    Channel progression is 1 -> 32 -> 32 -> 32 -> 64. Each of the first three
    stages applies 2x2 max pooling and dropout with p=0.3.
    """

    def __init__(self):
        super().__init__()
        layers = []
        # Same module order (and therefore the same Sequential indices) as
        # writing the fourteen modules out by hand.
        for in_channels, out_channels in ((1, 32), (32, 32), (32, 32)):
            layers += [
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), padding=(1, 1)),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Dropout(0.3),
            ]
        layers += [
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3), padding=(1, 1)),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*layers)

    def forward(self, speech_spectrum):
        """Encode input shaped (batch, 1, max_seq_len, n_features) into (batch, 64)."""
        batch = speech_spectrum.shape[0]
        # 64 channels; the three poolings have shrunk both trailing dimensions.
        feature_maps = self.encoder(speech_spectrum)
        # Global max pooling over the remaining trailing dimensions.
        embeddings = F.max_pool2d(feature_maps, kernel_size=feature_maps.size()[2:])
        return embeddings.view(batch, -1)


@gin.configurable
class DnsmosHead(nn.Module):
    """Three-layer MLP (64 -> 64 -> 64 -> 1) with ReLU between layers."""

    def __init__(self):
        super().__init__()
        width = 64
        self.head = nn.Sequential(
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 1),
        )

    def forward(self, embeddings):
        """Map embeddings shaped (batch, 64) to one prediction per row."""
        return self.head(embeddings)


@gin.configurable
class DnsmosClassic(nn.Module):
    """``DnsmosEncoder`` followed by ``DnsmosHead``."""

    def __init__(self):
        super().__init__()
        self.encoder = DnsmosEncoder()
        self.head = DnsmosHead()

    def forward(self, speech_spectrum: torch.Tensor) -> torch.Tensor:
        """Return the head's prediction for the encoded ``speech_spectrum``."""
        return self.head(self.encoder(speech_spectrum))


@gin.configurable
class Mosnet(nn.Module):
    """CNN + bidirectional LSTM + MLP producing one score per time frame.

    The convolutional stack has four groups of three 3x3 convolutions
    (16, 32, 64 and 128 channels); the third convolution of each group uses
    stride (1, 3), which shrinks the last dimension only.
    """

    def __init__(self):
        super().__init__()

        # Same twelve Conv2d/ReLU pairs, in the same order, as listing them
        # explicitly, so Sequential indices (state-dict keys) are unchanged.
        conv_layers = []
        in_channels = 1
        for out_channels in (16, 32, 64, 128):
            for stride in (1, 1, (1, 3)):
                conv_layers += [
                    nn.Conv2d(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=(3, 3),
                        padding=(1, 1),
                        stride=stride,
                    ),
                    nn.ReLU(),
                ]
                in_channels = out_channels
        self.mean_conv = nn.Sequential(*conv_layers)

        # NOTE(review): with num_layers=1 PyTorch warns that ``dropout`` has
        # no effect; kept as-is so the constructed module is unchanged.
        self.mean_rnn = nn.LSTM(
            input_size=512,
            hidden_size=128,
            num_layers=1,
            batch_first=True,
            dropout=0.3,
            bidirectional=True
        )

        self.mean_MLP = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1)
        )

    def encoder(self, speech_spectrum: torch.Tensor) -> torch.Tensor:
        """Return globally max-pooled conv features, shaped (batch, 128)."""
        batch = speech_spectrum.shape[0]
        feature_maps = self.mean_conv(speech_spectrum)  # 128 channels
        pooled = F.max_pool2d(feature_maps, kernel_size=feature_maps.size()[2:])
        return pooled.view(batch, -1)

    def forward(self, speech_spectrum: torch.Tensor) -> torch.Tensor:
        """Return per-frame scores for input shaped (batch, 1, max_seq_len, 257)."""
        batch = speech_spectrum.shape[0]
        n_frames = speech_spectrum.shape[2]
        feature_maps = self.mean_conv(speech_spectrum)  # shape (batch, 128, max_seq_len, 4)
        # NOTE(review): this view reinterprets memory without first moving the
        # time axis ahead of the channel axis. Left untouched because existing
        # checkpoints may have been trained with this layout — confirm.
        sequence = feature_maps.view((batch, n_frames, 512))  # shape (batch, max_seq_len, 512)
        sequence, _ = self.mean_rnn(sequence)  # shape (batch, max_seq_len, 256)
        mos_mean = self.mean_MLP(sequence)  # shape (batch, max_seq_len, 1)
        return mos_mean


@gin.configurable
def stft(
    samples: np.ndarray,
    win_length: int = 320,
    hop_length: int = 160,
    n_fft: int = 320,
    use_log: bool = True,
    use_magnitude: bool = True,
    n_mels: Optional[int] = None,
) -> np.ndarray:
    """Compute a transposed (frames first) spectrogram of ``samples``.

    Uses ``librosa.stft`` when ``n_mels`` is None, otherwise
    ``librosa.feature.melspectrogram``. With ``use_magnitude`` the absolute
    value is taken; with ``use_log`` the result is clipped to [1e-7, 1e7] and
    converted with log10.

    Raises:
        ValueError: If ``use_log`` is requested without ``use_magnitude``.
    """
    if use_log and not use_magnitude:
        raise ValueError('Log is only available if the magnitude is to be computed.')

    framing = dict(win_length=win_length, hop_length=hop_length, n_fft=n_fft)
    if n_mels is not None:
        spectrum = librosa.feature.melspectrogram(y=samples, n_mels=n_mels, **framing)
    else:
        spectrum = librosa.stft(y=samples, **framing)

    spectrum = spectrum.T
    if use_magnitude:
        spectrum = np.abs(spectrum)
    if use_log:
        spectrum = np.log10(np.clip(spectrum, 10 ** (-7), 10 ** 7))
    return spectrum

def pad_to_fit_window(samples, frame_length=32000, hop_length=16000):
    """Mirror-pad ``samples`` so a trailing partial frame becomes a full frame.

    The number of complete frames is computed, and whatever lies beyond the
    start of the next frame is extended to ``frame_length`` by reflecting the
    end of the signal (last sample first).

    Args:
        samples: Audio samples; padding is applied along the first axis.
        frame_length: Frame size in samples.
        hop_length: Hop between frame starts in samples.

    Returns:
        The padded array. Inputs shorter than the required padding are
        reflected repeatedly until the frame is full. An empty input is
        returned unchanged, since there is nothing to reflect.
    """
    total_len = len(samples)
    if total_len == 0:
        return samples

    # Clamp at zero: for very short inputs the floor division goes negative,
    # which used to yield a result shorter than a single frame.
    num_frames = max((total_len - frame_length) // hop_length + 1, 0)
    remaining = total_len - num_frames * hop_length
    if remaining >= frame_length:
        return samples

    pad_len = frame_length - remaining
    # "symmetric" mirrors the signal including its last sample, i.e. the same
    # values as samples[-pad_len:][::-1] whenever that slice is long enough.
    pad_width = [(0, pad_len)] + [(0, 0)] * (np.ndim(samples) - 1)
    return np.pad(samples, pad_width, mode="symmetric")

def dnsmos_pro(dnsmos_pro_model, samples, frame_length=32000, hop_length=16000):
    """Score ``samples`` frame by frame with ``dnsmos_pro_model``.

    The signal is padded with ``pad_to_fit_window``, cut into overlapping
    frames, converted to spectrograms with ``stft`` and passed to the model
    as one batch shaped (frames, 1, ...).

    Returns:
        A tuple ``(scores, variance)``: the per-frame values from column 0 of
        the model output joined into a space-separated string, and their
        variance.
    """
    padded = pad_to_fit_window(samples, frame_length=frame_length, hop_length=hop_length)
    frame_count = (len(padded) - frame_length) // hop_length + 1
    frame_starts = range(0, frame_count * hop_length, hop_length)
    windows = np.stack([padded[begin:begin + frame_length] for begin in frame_starts])

    spectrograms = [torch.FloatTensor(stft(window)) for window in windows]
    model_input = torch.stack(spectrograms).unsqueeze(1)  # add the channel axis

    scores = dnsmos_pro_model(model_input)[:, 0].detach().numpy()
    return " ".join(scores.astype(str)), np.var(scores)

##############################################################
def save_audio_4_autopcp(hyps, save_home, line_i, sample_rate=16000):
    """Write each segment in ``hyps`` to ``<save_home>/<line_i>_<index>.wav``.

    Segments are flattened before being handed to ``save_audio``.

    Returns:
        The list of written file paths, in segment order.
    """
    written_paths = []
    for seg_idx, segment in enumerate(hyps):
        target = os.path.join(save_home, f"{line_i}_{seg_idx}.wav")
        save_audio(target, segment.flatten(), sample_rate=sample_rate)
        written_paths.append(target)
    return written_paths

if __name__ == "__main__":
    # Evaluation entry point: for every generated utterance compute ASR,
    # emotion-similarity, speaker-similarity and DNSMOS metrics, then write
    # a metrics CSV, a summary log and a source/target path TSV.
    parser = argparse.ArgumentParser(description='Inference')
    # NOTE: despite its name, --tsv is read with json.load below.
    parser.add_argument('--tsv', type=str)
    parser.add_argument('--out_home', type=str)
    parser.add_argument('--ref_home', type=str)
    args = parser.parse_args()

    device = "cuda"
    compute_type = "float16"  # change to "int8" if low on GPU mem (may reduce accuracy)
    asr_model = whisperx.load_model("large-v2", device, compute_type=compute_type)
    asr_align, asr_align_metadata = whisperx.load_align_model(language_code="en", device=device)

    resemb_model = VoiceEncoder(device=torch.device("cuda"))
    resemb_model.eval()

    wavlm_ecapa = init_model("wavlm_large", os.path.join(os.environ.get("wavlm_home", "./pretrained_models/model_temp/speaker/"), r"wavlm_large_finetune.pth")).cuda()
    wavlm_ecapa.eval()

    emo2vec_model = AutoModel(model=os.environ.get("emo2vec_home", "./pretrained_models/modelscope/iic/emotion2vec_plus_large/"), disable_update=True)

    dns_model = torch.jit.load(os.path.join(os.environ.get("dns_home", "./pretrained_models/DNSMOSPro/runs/NISQA/"), 'model_best.pt'), map_location=torch.device('cpu'))

    # Download (if needed) and load the audonnx emotion model.
    url = 'https://zenodo.org/record/6221127/files/w2v2-L-robust-12.6bc4a7fd-1.1.0.zip'
    audeer_home = os.environ.get("audeer_home", "./pretrained_models/model_temp/")
    cache_root = audeer.mkdir(audeer_home + '/cache')
    model_root = audeer.mkdir(audeer_home + '/model')
    archive_path = audeer.download_url(url, cache_root, verbose=True)
    audeer.extract_archive(archive_path, model_root)
    audonnx_model = audonnx.load(model_root)

    with open(args.tsv, "r", encoding="utf-8") as rf:
        infos = json.load(rf)

    src_audio_paths = []
    tgt_audio_paths = []

    info_dict = {
        "cer": [],
        "emo2vec_score": [],
        "audonnx_score": [],
        "resemb_score": [],
        "wavlm_score": [],
        "dnsmos_mean": [],
        "dnsmos_var": [],
        "generated_path": [],
        "keys": [],
        "line_i": [],
    }

    for key in tqdm(infos.keys()):
        with torch.no_grad():
            utt_items = infos[key].get("script", [])
            for line_i, line_item in enumerate(utt_items):
                generated_path = os.path.join(args.out_home, "temp_signal", f"{key}_{line_i}.wav")
                # Read the script lines and start the timer before branching:
                # both are needed even when the generated file is missing.
                lines = line_item["text"]
                start = time.time()
                if os.path.exists(generated_path):
                    spk_ref = os.path.join(args.ref_home, lines[0]["target_speaker"])
                    emotion_refs = [
                        lib.load(os.path.join(args.ref_home, line["prompt_speech"]), sr=16000)[0]
                        for line in lines
                    ]
                    text_ref = [line["lines"] for line in lines]
                    print(f"READ: {time.time()-start}")

                    start = time.time()
                    segment_infos, segments_16k, cer, generated_wav = wer_and_align(text_ref, generated_path, asr_model, asr_align, asr_align_metadata, device=device)
                    print(f"ASR: {time.time()-start}")
                    # NOTE(review): the similarity functions pair references and
                    # segments by position; wer_and_align skips a reference it
                    # cannot match, which would break that pairing — confirm.
                    start = time.time()
                    emo2vec_score, audonnx_score = emo_sim(
                        emotion_refs, segments_16k, emo2vec_model, audonnx_model
                    )
                    print(f"EMO: {time.time()-start}")
                    start = time.time()
                    spk_ref_wav = lib.load(spk_ref, sr=16000)[0]
                    resemb_score, wavlm_score = spk_sim(
                        [spk_ref_wav] * len(emotion_refs),
                        segments_16k, resemb_model, wavlm_ecapa
                    )
                    print(f"SPK: {time.time()-start}")
                    start = time.time()
                    dnsmos_mean, dnsmos_var = dnsmos_pro(dns_model, generated_wav)
                    print(f"DNS: {time.time()-start}")
                else:
                    # Missing output: record sentinel scores and one second of
                    # silence (16000 zeros) per script line as placeholders.
                    cer = 100.0
                    emo2vec_score = -1
                    audonnx_score = -1
                    resemb_score = -1
                    wavlm_score = -1
                    dnsmos_mean = -1
                    dnsmos_var = -1
                    segments_16k = [np.zeros(16000) for _ in lines]

                info_dict["cer"].append(cer)
                info_dict["emo2vec_score"].append(emo2vec_score)
                info_dict["audonnx_score"].append(audonnx_score)
                info_dict["resemb_score"].append(resemb_score)
                info_dict["wavlm_score"].append(wavlm_score)
                info_dict["generated_path"].append(generated_path)
                info_dict["dnsmos_mean"].append(dnsmos_mean)
                info_dict["dnsmos_var"].append(dnsmos_var)
                info_dict["keys"].append(key)
                info_dict["line_i"].append(line_i)
                print(f"INFO: {time.time()-start}")
                start = time.time()

                tgt_audio_path = save_audio_4_autopcp(segments_16k, os.path.join(
                    args.out_home, "wav_segs", f"{key}"
                ), f"{line_i}")
                src_audio_path = [
                    os.path.join(args.ref_home, line["prompt_speech"])
                    for line in lines
                ]
                tgt_audio_paths += tgt_audio_path
                src_audio_paths += src_audio_path
                print(f"WRITE: {time.time()-start}")
                start = time.time()

    df = pd.DataFrame(info_dict)
    df.to_csv(os.path.join(args.out_home, "metrics_1st.csv"), index=False, encoding='utf-8')
    # The context manager closes the log even if describe()/print raises.
    with open(os.path.join(args.out_home, "metrics_1st.log"), "w") as log_file:
        with pd.option_context('display.max_columns', None):
            print(df.describe(include='all').to_string(), file=log_file)

    output_tsv_path = os.path.join(args.out_home, "eval4autopcp.tsv")
    with open(output_tsv_path, mode="w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, delimiter="\t")
        writer.writerow(["src_audio", "tgt_audio"])
        writer.writerows(zip(src_audio_paths, tgt_audio_paths))