# Entry point of the multi-process refactored version
import torch.multiprocessing as mp
import argparse
import os
import json
import pandas as pd
import csv
import time
import librosa as lib
import numpy as np
from tqdm import tqdm
import traceback

def process_entry(task_queue, result_queue, args):
    """Worker loop: load the evaluation models once, then score tasks until stopped.

    Each task pulled from ``task_queue`` is a tuple
    ``(key, line_i, lines, generated_path, ref_home, speed_dyn)``; ``None`` is
    the stop signal. For every task exactly one dict is put on
    ``result_queue`` holding the metric values, the generated path, the task
    identifiers and the source/target audio paths with their flags.

    When scoring a task raises, the traceback is printed and sentinel values
    are reported instead (``cer=100.0``, every other metric ``-1``) together
    with one-second silent segments, so the parent still receives a result.

    Args:
        task_queue: queue the parent fills with task tuples and ``None`` stops.
        result_queue: queue receiving one result dict per task.
        args: parsed command-line namespace; only ``args.out_home`` is read.
    """
    import torch
    # NOTE(review): F, Resample, sf and ECAPA_TDNN_SMALL are not referenced in
    # this function; the imports are kept in case they are needed for their
    # side effects — confirm before removing.
    import torch.nn.functional as F
    from torchaudio.transforms import Resample
    from funasr import AutoModel
    from resemblyzer import VoiceEncoder, preprocess_wav
    import soundfile as sf
    import audeer
    import audonnx
    import whisperx
    from models.ecapa_tdnn import ECAPA_TDNN_SMALL
    from eval_inner_emo_en import (
        init_model, wer_and_align, emo_sim, spk_sim,
        dnsmos_pro, save_audio_4_autopcp
    )

    device = "cuda"
    compute_type = "float16"  # change to "int8" if low on GPU mem (may reduce accuracy)
    asr_model = whisperx.load_model("large-v2", device, compute_type=compute_type)
    asr_align, asr_align_metadata = whisperx.load_align_model(language_code="en", device=device)
    resemb_model = VoiceEncoder(device=torch.device("cuda"))
    resemb_model.eval()
    emo2vec_model = AutoModel(model=os.environ.get("emo2vec_home", "./pretrained_models/modelscope/iic/emotion2vec_plus_large/"), disable_update=True)
    dns_model = torch.jit.load(os.path.join(os.environ.get("dns_home", "./pretrained_models/DNSMOSPro/runs/NISQA/"), 'model_best.pt'), map_location=torch.device('cpu'))

    # Keep retrying the speaker / audonnx model setup until it succeeds, as
    # before, but report each failure and back off briefly instead of spinning
    # silently. Catching Exception (not a bare except) lets Ctrl-C and
    # SystemExit still terminate the worker.
    while True:
        try:
            wavlm_ecapa = init_model("wavlm_large", os.path.join(os.environ.get("wavlm_home", "./pretrained_models/model_temp/speaker/"), r"wavlm_large_finetune.pth")).cuda()
            wavlm_ecapa.eval()
            url = 'https://zenodo.org/record/6221127/files/w2v2-L-robust-12.6bc4a7fd-1.1.0.zip'
            audeer_home = os.environ.get("audeer_home", "./pretrained_models/model_temp/")
            cache_root = audeer.mkdir(audeer_home + '/cache')
            model_root = audeer.mkdir(audeer_home + '/model')
            archive_path = audeer.download_url(url, cache_root, verbose=True)
            audeer.extract_archive(archive_path, model_root)
            audonnx_model = audonnx.load(model_root)
            break
        except Exception:
            traceback.print_exc()
            time.sleep(5)

    while True:
        task = task_queue.get()
        if task is None:
            break

        # Unpack outside the try block so every name used after it is bound
        # even when scoring fails.
        key, line_i, lines, generated_path, ref_home, speed_dyn = task
        emotion_refs = None
        try:
            text_ref = [l["lines"] for l in lines]
            emotion_refs = [lib.load(os.path.join(ref_home, l["prompt_speech"]), sr=16000)[0] for l in lines]
            if speed_dyn:
                librosa_rate = [1 / float(l["speed_refined"]) for l in lines]
                emotion_refs = [
                    lib.effects.time_stretch(ref, rate=rate)
                    for ref, rate in zip(emotion_refs, librosa_rate)
                ]
            spk_ref = os.path.join(ref_home, lines[0]["target_speaker"])

            segment_infos, segments_16k, cer, generated_wav = wer_and_align(text_ref, generated_path, asr_model, asr_align, asr_align_metadata, device=device)
            emo2vec_score, audonnx_score = emo_sim(emotion_refs, segments_16k, emo2vec_model, audonnx_model)
            resemb_score, wavlm_score = spk_sim([lib.load(spk_ref, sr=16000)[0]] * len(emotion_refs), segments_16k, resemb_model, wavlm_ecapa)
            dnsmos_mean, dnsmos_var = dnsmos_pro(dns_model, generated_wav)
        except Exception:
            traceback.print_exc()
            cer = 100.0
            emo2vec_score = -1
            audonnx_score = -1
            resemb_score = -1
            wavlm_score = -1
            dnsmos_mean = -1
            dnsmos_var = -1
            segments_16k = [np.zeros(16000) for _ in lines]
            if emotion_refs is None:
                # The references themselves failed to load; substitute silence
                # so the speed_dyn branch below can still write its files
                # instead of raising NameError and killing the worker.
                emotion_refs = [np.zeros(16000) for _ in lines]

        tgt_audio_path = save_audio_4_autopcp(segments_16k, os.path.join(args.out_home, "wav_segs", f"{key}"), f"{line_i}")
        if speed_dyn:
            src_audio_path = save_audio_4_autopcp(emotion_refs, os.path.join(args.out_home, "ref_segs_speeddyn", f"{key}"), f"{line_i}")
        else:
            src_audio_path = [os.path.join(ref_home, l["prompt_speech"]) for l in lines]
        flags = [f"{key}_{line_i}"] * len(src_audio_path)

        result_queue.put({
            "cer": cer,
            "emo2vec_score": emo2vec_score,
            "audonnx_score": audonnx_score,
            "resemb_score": resemb_score,
            "wavlm_score": wavlm_score,
            "dnsmos_mean": dnsmos_mean,
            # NOTE(review): the failure sentinel -1 is scaled as well, so a
            # failed task reports dnsmos_var == -100.0 here.
            "dnsmos_var": dnsmos_var * 100.0,
            "generated_path": generated_path,
            "keys": key,
            "line_i": line_i,
            "src_audio": src_audio_path,
            "tgt_audio": tgt_audio_path,
            "flag": flags
        })

def get_error_num(file_path):
    """Count the rows of a metrics CSV that carry at least one failure sentinel.

    A row is counted when any of the known metric columns present in the file
    equals its sentinel (``cer == 100.0`` or one of the other scores
    ``== -1``). A row is counted once even if several columns match.

    Args:
        file_path: path of the metrics CSV to inspect.

    Returns:
        The number of flagged rows, or ``100`` when the file does not exist
        (meaning the metrics still need to be computed).
    """
    if not os.path.exists(file_path):
        return 100  # need to compute

    sentinels = {
        "cer": 100.0,
        "emo2vec_score": -1,
        "audonnx_score": -1,
        "resemb_score": -1,
        "wavlm_score": -1,
        "dnsmos_mean": -1,
        "dnsmos_var": -1,
    }

    table = pd.read_csv(file_path)

    # Build one boolean mask over the rows, OR-ing in each column's matches.
    flagged = pd.Series(False, index=table.index)
    for name, sentinel in sentinels.items():
        if name in table.columns:
            flagged = flagged | (table[name] == sentinel)
    return int(flagged.sum())

def _next_result(result_queue, processes, poll_seconds=30):
    """Return the next worker result, raising if every worker has exited first."""
    import queue  # local import: only the Empty exception type is needed

    while True:
        try:
            return result_queue.get(timeout=poll_seconds)
        except queue.Empty:
            if any(p.is_alive() for p in processes):
                continue
            # Every worker is gone. Allow for one result that was still in
            # flight when the timeout fired, then give up instead of blocking
            # forever on a queue nobody will fill.
            try:
                return result_queue.get(timeout=1)
            except queue.Empty:
                raise RuntimeError(
                    "all worker processes exited before every task reported a result"
                ) from None


def main():
    """Fan evaluation tasks out to worker processes and write the metric files.

    Reads the JSON file given by ``--tsv``, enqueues one task per script line
    of every entry, collects one result per task and writes into
    ``--out_home``: ``metrics_1st.csv``, ``metrics_1st.log``,
    ``eval4autopcp.tsv`` and ``eval4autopcp_flag.tsv``.

    Returns early without doing any work when an existing ``metrics_all.csv``
    in ``--out_home`` has fewer than 50 rows flagged by ``get_error_num``.

    Raises:
        RuntimeError: if all workers exit before every task produced a result.
    """
    parser = argparse.ArgumentParser(description='Parallel Inference')
    parser.add_argument('--tsv', type=str)
    parser.add_argument('--out_home', type=str)
    parser.add_argument('--ref_home', type=str)
    parser.add_argument('--num_workers', type=int, default=4)
    args = parser.parse_args()

    if get_error_num(os.path.join(args.out_home, "metrics_all.csv")) < 50:
        return

    with open(args.tsv, "r", encoding="utf-8") as rf:
        infos = json.load(rf)

    # Output directories whose path contains "nosp" disable speed adjustment.
    speed_dyn = "nosp" not in args.out_home

    task_queue = mp.Queue()
    result_queue = mp.Queue()

    processes = [mp.Process(target=process_entry, args=(task_queue, result_queue, args)) for _ in range(args.num_workers)]
    for p in processes:
        p.start()
        # NOTE(review): presumably staggers worker start-up (model loading) — confirm.
        time.sleep(10)

    total_tasks = 0
    for key in infos:
        utt_items = infos[key].get("script", [])
        for line_i, line_item in enumerate(utt_items):
            task_queue.put((key, line_i, line_item["text"], os.path.join(args.out_home, "temp_signal", f"{key}_{line_i}.wav"), args.ref_home, speed_dyn))
            total_tasks += 1

    for _ in range(args.num_workers):
        task_queue.put(None)  # Stop signal

    info_dict = {
        "cer": [], "emo2vec_score": [], "audonnx_score": [], "resemb_score": [], "wavlm_score": [],
        "dnsmos_mean": [], "dnsmos_var": [], "generated_path": [], "keys": [], "line_i": []
    }
    src_audio_paths, tgt_audio_paths, flags = [], [], []

    for _ in tqdm(range(total_tasks)):
        res = _next_result(result_queue, processes)
        for k in info_dict:
            info_dict[k].append(res[k])
        src_audio_paths += res["src_audio"]
        tgt_audio_paths += res["tgt_audio"]
        flags += res["flag"]

    # Every result is in and every worker has been sent its stop signal, so
    # the workers can be reaped before the outputs are written.
    for p in processes:
        p.join()

    df = pd.DataFrame(info_dict)
    df.to_csv(os.path.join(args.out_home, "metrics_1st.csv"), index=False, encoding='utf-8')
    with open(os.path.join(args.out_home, "metrics_1st.log"), "w") as log_file:
        with pd.option_context('display.max_columns', None):
            print(df.describe(include='all').to_string(), file=log_file)

    output_tsv_path = os.path.join(args.out_home, "eval4autopcp.tsv")
    output_flags_path = os.path.join(args.out_home, "eval4autopcp_flag.tsv")
    with open(output_tsv_path, mode="w", newline="", encoding="utf-8") as f1, \
            open(output_flags_path, mode="w", newline="", encoding="utf-8") as f2:
        writer = csv.writer(f1, delimiter="\t")
        writer.writerow(["src_audio", "tgt_audio"])
        for src, tgt, flag in zip(src_audio_paths, tgt_audio_paths, flags):
            writer.writerow([src, tgt])
            print(flag, file=f2)

if __name__ == '__main__':
    # Force the 'spawn' start method before any worker process is created.
    # The workers load models onto CUDA; PyTorch's multiprocessing notes state
    # that using CUDA in subprocesses requires 'spawn' or 'forkserver'.
    mp.set_start_method('spawn', force=True)
    main()
