import torch.multiprocessing as mp
import argparse
import os
import json
import pandas as pd
import csv
import time
import librosa as lib
import numpy as np
from tqdm import tqdm
import traceback


def process_entry(task_queue, result_queue, args):
    """Worker loop: load the evaluation models once, then score tasks until a ``None`` sentinel.

    Each task is the tuple ``(key, line_i, lines, generated_path, ref_home,
    speed_dyn)`` enqueued by ``main``. Exactly one result dict is put on
    ``result_queue`` per task, even when scoring or saving fails, so the
    parent's fixed-count collection loop never waits for a result that will
    not arrive.

    Args:
        task_queue: queue of task tuples; ``None`` makes the worker return.
        result_queue: queue that receives one metrics dict per task.
        args: parsed CLI namespace; only ``args.out_home`` is read here.
    """
    import torch
    import torch.nn.functional as F
    from torchaudio.transforms import Resample
    from funasr import AutoModel
    from resemblyzer import VoiceEncoder, preprocess_wav
    import soundfile as sf
    import audeer
    import audonnx
    from models.ecapa_tdnn import ECAPA_TDNN_SMALL
    from eval_inner_emo_zh import (
        init_model, wer_and_align, emo_sim, spk_sim,
        dnsmos_pro, save_audio_4_autopcp
    )

    # Initialise the models (once per worker process).
    asr_model = AutoModel(model="paraformer-zh", disable_update=True)
    resemb_model = VoiceEncoder(device=torch.device("cuda"))
    resemb_model.eval()
    emo2vec_model = AutoModel(model=os.environ.get("emo2vec_home", "./pretrained_models/modelscope/iic/emotion2vec_plus_large/"), disable_update=True)
    dns_model = torch.jit.load(os.path.join(os.environ.get("dns_home", "./pretrained_models/DNSMOSPro/runs/NISQA/"), 'model_best.pt'), map_location=torch.device('cpu'))

    # The WavLM speaker model and the audonnx emotion model (fetched from
    # zenodo) are retried until both load, e.g. to ride out a flaky download.
    # Each failure is printed, and retries are spaced out so a persistent
    # error does not busy-spin. Catching Exception (not a bare except) keeps
    # Ctrl+C working.
    audeer_home = os.environ.get("audeer_home", "./pretrained_models/model_temp/")
    while True:
        try:
            wavlm_ecapa = init_model("wavlm_large", os.path.join(os.environ.get("wavlm_home", "./pretrained_models/model_temp/speaker/"), r"wavlm_large_finetune.pth")).cuda()
            wavlm_ecapa.eval()
            url = 'https://zenodo.org/record/6221127/files/w2v2-L-robust-12.6bc4a7fd-1.1.0.zip'
            cache_root = audeer.mkdir(audeer_home + '/cache')
            model_root = audeer.mkdir(audeer_home + '/model')
            archive_path = audeer.download_url(url, cache_root, verbose=True)
            audeer.extract_archive(archive_path, model_root)
            audonnx_model = audonnx.load(model_root)
            break
        except Exception:
            traceback.print_exc()
            time.sleep(5)

    while True:
        task = task_queue.get()
        if task is None:
            break
        # Unpack outside the try so the error/saving paths below can always
        # rely on these names.
        key, line_i, lines, generated_path, ref_home, speed_dyn = task
        # Reset per task: if loading the prompts fails, the error path must
        # neither crash on nor reuse references from a previous task.
        emotion_refs = None
        try:
            text_ref = [l["lines"] for l in lines]

            emotion_refs = [lib.load(os.path.join(ref_home, l["prompt_speech"]), sr=16000)[0] for l in lines]
            if speed_dyn:
                # Stretch each reference prompt by 1 / speed_refined.
                librosa_rate = [1 / float(l["speed_refined"]) for l in lines]
                emotion_refs = [lib.effects.time_stretch(ref, rate=rate) for ref, rate in zip(emotion_refs, librosa_rate)]

            spk_ref = os.path.join(ref_home, lines[0]["target_speaker"])

            segment_infos, segments_16k, cer, generated_wav = wer_and_align(text_ref, generated_path, asr_model)
            emo2vec_score, audonnx_score = emo_sim(emotion_refs, segments_16k, emo2vec_model, audonnx_model)
            resemb_score, wavlm_score = spk_sim([lib.load(spk_ref, sr=16000)[0]] * len(emotion_refs), segments_16k, resemb_model, wavlm_ecapa)
            dnsmos_mean, dnsmos_var = dnsmos_pro(dns_model, generated_wav)
        except Exception:
            traceback.print_exc()
            # Failure sentinels; get_error_num() counts rows holding these.
            cer = 100.0
            emo2vec_score = -1
            audonnx_score = -1
            resemb_score = -1
            wavlm_score = -1
            dnsmos_mean = -1
            dnsmos_var = -1
            segments_16k = [np.zeros(16000) for _ in range(len(lines))]
            if emotion_refs is None:
                # Prompts never loaded: silent placeholders keep the
                # speed-dyn reference dump aligned with `lines`.
                emotion_refs = [np.zeros(16000) for _ in range(len(lines))]

        try:
            tgt_audio_path = save_audio_4_autopcp(segments_16k, os.path.join(args.out_home, "wav_segs", f"{key}"), f"{line_i}")
            if speed_dyn:
                src_audio_path = save_audio_4_autopcp(emotion_refs, os.path.join(args.out_home, "ref_segs_speeddyn", f"{key}"), f"{line_i}")
            else:
                src_audio_path = [os.path.join(ref_home, l["prompt_speech"]) for l in lines]
        except Exception:
            traceback.print_exc()
            # Report no audio pairs for this task instead of letting the worker
            # die, which would leave the parent waiting for this result.
            tgt_audio_path, src_audio_path = [], []
        # One flag per source clip, aligned with the pair rows written by main().
        flags = [f"{key}_{line_i}"] * len(src_audio_path)

        result_queue.put({
            "cer": cer,
            "emo2vec_score": emo2vec_score,
            "audonnx_score": audonnx_score,
            "resemb_score": resemb_score,
            "wavlm_score": wavlm_score,
            "dnsmos_mean": dnsmos_mean,
            # Scaled x100; a failed task therefore reports -100 here.
            "dnsmos_var": dnsmos_var * 100.0,
            "generated_path": generated_path,
            "keys": key,
            "line_i": line_i,
            "src_audio": src_audio_path,
            "tgt_audio": tgt_audio_path,
            "flag": flags
        })

def get_error_num(file_path):
    """Count rows of a metrics CSV in which any metric equals its failure sentinel.

    A row is abnormal when ``cer == 100.0`` or any of the score columns equals
    ``-1``. Sentinel columns missing from the CSV are ignored.

    Args:
        file_path: path to the metrics CSV.

    Returns:
        The number of abnormal rows as an ``int``. Returns 100 when
        ``file_path`` does not exist, meaning the metrics still need to be
        computed.
    """
    if not os.path.exists(file_path):
        return 100  # need to compute
    abnormal_values = {
        "cer": 100.0,
        "emo2vec_score": -1,
        "audonnx_score": -1,
        "resemb_score": -1,
        "wavlm_score": -1,
        "dnsmos_mean": -1,
        "dnsmos_var": -1
    }

    df = pd.read_csv(file_path)

    # Vectorised OR over all sentinel columns instead of a per-row iterrows loop.
    abnormal = pd.Series(False, index=df.index)
    for column, value in abnormal_values.items():
        if column in df.columns:
            abnormal |= df[column] == value
    return int(abnormal.sum())

def _collect_results(result_queue, processes, total_tasks, poll_timeout=60.0):
    """Receive ``total_tasks`` results from ``result_queue`` with a progress bar.

    Raises:
        RuntimeError: if every worker has exited while results are still
            missing (e.g. a worker crashed while loading models). Without this
            check the parent would block on ``result_queue.get()`` forever.
    """
    import queue  # for the queue.Empty raised by a timed get()

    results = []
    with tqdm(total=total_tasks) as pbar:
        while len(results) < total_tasks:
            workers_alive = any(p.is_alive() for p in processes)
            try:
                # A process that put items on a multiprocessing queue flushes
                # them before it terminates, so once no worker is alive a short
                # wait is enough to drain whatever they left behind.
                res = result_queue.get(timeout=poll_timeout if workers_alive else 1.0)
            except queue.Empty:
                if workers_alive:
                    continue
                raise RuntimeError(
                    f"all workers exited after {len(results)}/{total_tasks} results"
                ) from None
            results.append(res)
            pbar.update(1)
    return results


def _write_outputs(out_home, results):
    """Write the metrics CSV/log and the aligned autopcp pair/flag files under ``out_home``."""
    metric_keys = [
        "cer", "emo2vec_score", "audonnx_score", "resemb_score", "wavlm_score",
        "dnsmos_mean", "dnsmos_var", "generated_path", "keys", "line_i",
    ]
    info_dict = {k: [res[k] for res in results] for k in metric_keys}
    src_audio_paths = [path for res in results for path in res["src_audio"]]
    tgt_audio_paths = [path for res in results for path in res["tgt_audio"]]
    flags = [flag for res in results for flag in res["flag"]]

    df = pd.DataFrame(info_dict)
    df.to_csv(os.path.join(out_home, "metrics_1st.csv"), index=False, encoding='utf-8')
    # Explicit encoding: the summary can contain non-ASCII keys/paths.
    with open(os.path.join(out_home, "metrics_1st.log"), "w", encoding="utf-8") as log_file:
        with pd.option_context('display.max_columns', None):
            print(df.describe(include='all').to_string(), file=log_file)

    output_tsv_path = os.path.join(out_home, "eval4autopcp.tsv")
    output_flags_path = os.path.join(out_home, "eval4autopcp_flag.tsv")
    with open(output_tsv_path, mode="w", newline="", encoding="utf-8") as f1:
        with open(output_flags_path, mode="w", newline="", encoding="utf-8") as f2:
            writer = csv.writer(f1, delimiter="\t")
            writer.writerow(["src_audio", "tgt_audio"])
            for src, tgt, flag in zip(src_audio_paths, tgt_audio_paths, flags):
                writer.writerow([src, tgt])
                print(flag, file=f2)


def main():
    """Score every generated utterance listed in ``--tsv`` with a pool of workers.

    Writes, under ``--out_home``:
      * ``metrics_1st.csv`` - one metrics row per (key, line_i) task;
      * ``metrics_1st.log`` - ``DataFrame.describe`` summary of that table;
      * ``eval4autopcp.tsv`` / ``eval4autopcp_flag.tsv`` - aligned
        (src_audio, tgt_audio) rows and their ``{key}_{line_i}`` flags.

    Returns early, doing nothing, when ``metrics_all.csv`` already exists in
    ``--out_home`` with fewer than 50 abnormal rows.
    """
    parser = argparse.ArgumentParser(description='Parallel Inference')
    parser.add_argument('--tsv', type=str, required=True)
    parser.add_argument('--out_home', type=str, required=True)
    parser.add_argument('--ref_home', type=str, required=True)
    parser.add_argument('--num_workers', type=int, default=4)
    args = parser.parse_args()

    if get_error_num(os.path.join(args.out_home, "metrics_all.csv")) < 50:
        return

    # Despite the flag name, the --tsv file is parsed as JSON.
    with open(args.tsv, "r", encoding="utf-8") as rf:
        infos = json.load(rf)

    # Reference prompts are time-stretched by the workers unless this is a
    # "nosp" output directory.
    speed_dyn = "nosp" not in args.out_home

    # Build every task up front so malformed input fails before any worker is
    # spawned (otherwise idle workers would stay blocked on the task queue).
    tasks = [
        (key, line_i, line_item["text"],
         os.path.join(args.out_home, "temp_signal", f"{key}_{line_i}.wav"),
         args.ref_home, speed_dyn)
        for key, info in infos.items()
        for line_i, line_item in enumerate(info.get("script", []))
    ]

    task_queue = mp.Queue()
    result_queue = mp.Queue()

    processes = [mp.Process(target=process_entry, args=(task_queue, result_queue, args)) for _ in range(args.num_workers)]
    for p in processes:
        p.start()
        time.sleep(10)  # presumably staggers the workers' heavy model loading

    for task in tasks:
        task_queue.put(task)
    for _ in processes:
        task_queue.put(None)  # one stop signal per worker

    results = _collect_results(result_queue, processes, len(tasks))
    # The result queue is fully drained, so joining cannot deadlock on
    # buffered queue data.
    for p in processes:
        p.join()

    _write_outputs(args.out_home, results)

if __name__ == "__main__":
    # Workers build CUDA models (VoiceEncoder on "cuda", `.cuda()` on the WavLM
    # model), which requires the "spawn" start method rather than "fork";
    # force=True overrides any start method that was already configured.
    mp.set_start_method("spawn", force=True)
    main()
