import os

# Cap native thread pools (OpenMP / BLAS / numexpr) in every process, including
# the spawned workers that re-import this module. These variables are typically
# read once when the numeric libraries are loaded, so they must be set before
# numpy / pandas / torch are first imported or the limits may be ignored.
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["OPENBLAS_NUM_THREADS"] = "4"
os.environ["MKL_NUM_THREADS"] = "4"
os.environ["VECLIB_MAXIMUM_THREADS"] = "4"
os.environ["NUMEXPR_NUM_THREADS"] = "4"

import argparse
import glob
import json
import random
import re
import string
import time
from collections import defaultdict
from copy import deepcopy

import ffmpeg
import numpy as np
import pandas as pd
import soundfile as sf
import torch
import torch.multiprocessing as mp
import torchaudio
from tqdm import tqdm

from examples.wescon.models.cosyvoice2.cli.cosyvoice import CosyVoice2
from examples.wescon.models.cosyvoice2.utils.file_utils import load_wav

punctuation_string = string.punctuation

def contains_chinese(s):
    """Return True when ``s`` holds at least one character in U+4E00..U+9FFF.

    That range is the CJK Unified Ideographs block; characters outside it
    (including CJK extension blocks and punctuation) do not count.
    """
    return any('\u4e00' <= ch <= '\u9fff' for ch in s)

def convert2wav(path):
    """Convert an audio file to 16 kHz mono WAV next to the source file.

    The output path is ``path`` with its final extension replaced by ``.wav``.
    If that file already exists it is returned as-is without re-encoding.

    Args:
        path: Path to the input audio file (any format ffmpeg can decode).

    Returns:
        The path of the WAV file.
    """
    # splitext strips only the final extension; ``path.split(".")[0]`` would cut
    # at the first dot anywhere in the path (e.g. "./data/a.mp3" -> ".wav").
    save_path = os.path.splitext(path)[0] + ".wav"
    if os.path.exists(save_path):
        return save_path
    stream = ffmpeg.input(path)
    stream = ffmpeg.output(stream, save_path, ar=16000, ac=1)
    ffmpeg.run(stream)
    return save_path

def find_spk(text):
    """List the speaker ids written as ``subject N`` in ``text``.

    Only ids 0 through 50 are recognised. Each match is returned as a string
    in order of appearance, with ``re.findall`` single-capture-group semantics.
    """
    subject_re = re.compile(r'\bsubject ([0-9]|[1-4][0-9]|50)\b')
    return subject_re.findall(text)

from pathlib import Path
def get_all_wavs(root, suffix):
    """Return every path under ``root`` whose name ends with ``.{suffix}``.

    Despite the name, this works for any extension (the script also uses it
    for ``"txt"``). Entries directly inside ``root`` are matched by name, and
    each sub-directory of ``root`` is searched recursively with ``rglob``.

    Args:
        root: Directory to scan.
        suffix: Extension without the leading dot, e.g. ``"wav"``.

    Returns:
        A sorted, de-duplicated list of path strings. Sorting makes the result
        independent of filesystem and ``set`` iteration order, so a manifest
        built from it has the same row order on every run.
    """
    ext = f".{suffix}"
    found = set()
    for entry in Path(root).iterdir():
        if str(entry).endswith(ext):
            found.add(str(entry))
        if entry.is_dir():
            found.update(str(p) for p in entry.rglob(f"*{ext}"))
    return sorted(found)

def _read_manifest_rows(tsv):
    """Return the rows of the tab-separated manifest ``tsv`` that have exactly 6 fields.

    Fields are ``filepath, sample_rate, duration, speaker, emotion, text``.
    Shorter rows are malformed. Longer rows would fail the 6-way unpacking in
    ``extract_features``, which happens outside its per-row ``try`` and would
    abort the whole worker. Both kinds are therefore skipped here.
    """
    with open(tsv, "r", encoding="utf-8") as file:
        rows = [line.strip().split("\t") for line in file]
    return [parts for parts in rows if len(parts) == 6]


def _shard_bounds(n_items, max_p, p_id):
    """Return the half-open ``(start, end)`` slice of ``n_items`` owned by worker ``p_id``.

    Proportional boundaries make shard sizes differ by at most one. Together,
    the ``max_p`` shards cover every item exactly once.
    """
    return n_items * p_id // max_p, n_items * (p_id + 1) // max_p


def _speech_tokens(front_end, wav_16k):
    """Return the speech-token ids of a 16 kHz waveform as a numpy array of strings.

    Clips longer than 30 s are cut along dim 1 into roughly equal pieces of at
    most about 20 s each. The pieces are tokenized separately and concatenated.
    Pieces of 0.5 s or less (the split remainder) are dropped.
    """
    n_samples = wav_16k.size(1)
    if n_samples <= 16000 * 30:
        speech_token, _ = front_end._extract_speech_token(wav_16k)
        return speech_token.detach().cpu().numpy().flatten().astype("str")

    chunk_len = n_samples // (n_samples // (16000 * 20) + 1)
    pieces = []
    for chunk in torch.split(wav_16k, chunk_len, dim=1):
        if chunk.size(1) > 16000 * 0.5:
            speech_token, _ = front_end._extract_speech_token(chunk)
            pieces.append(speech_token.detach().cpu().numpy().flatten())
    return np.concatenate(pieces).astype("str")


def extract_features(tsv, max_p, p_id, save_dir):
    """Tokenize one shard of a manifest with the CosyVoice2 front end and dump it to JSON.

    Args:
        tsv: Path to a UTF-8 manifest whose tab-separated rows are
            ``filepath, sample_rate, duration, speaker, emotion, text``.
        max_p: Total number of workers the manifest is split across.
        p_id: Index of this worker, in ``range(max_p)``.
        save_dir: Directory that receives ``features_{p_id}.json``.

    The output JSON maps ``emotion -> speaker -> [entry, ...]``. Each entry
    holds the row's metadata, the normalized text, and space-joined text and
    speech token ids. A row that fails to process is reported and skipped, so
    one bad file does not cost the rest of the shard.
    """
    cosyvoice = CosyVoice2('./pretrained_models/CosyVoice2-0.5B', load_jit=True, load_onnx=False, load_trt=False, only_front=True)
    front_end = cosyvoice.frontend

    rows = _read_manifest_rows(tsv)
    start, end = _shard_bounds(len(rows), max_p, p_id)
    data = defaultdict(lambda: defaultdict(list))

    with torch.no_grad():
        for filepath, sample_rate, duration, speaker, emotion, text in tqdm(rows[start:end], desc=f"Process-{p_id}"):
            try:
                text = front_end.text_normalize(text.strip(), split=False)
                text_token, _ = front_end._extract_text_token(text)
                text_token = " ".join(text_token.detach().cpu().numpy().flatten().astype("str"))

                wav_16k = load_wav(filepath, target_sr=16000)
                speech_token = _speech_tokens(front_end, wav_16k)

                entry = {
                    "filepath": filepath,
                    "sample_rate": int(sample_rate),
                    "duration": int(duration),
                    "text": text,
                    "text_token": text_token,
                    "speech_token": " ".join(speech_token),
                }
            except Exception as e:
                # Best effort: report the offending file and continue with the shard.
                print(f"[Process-{p_id}] skipping {filepath}: {e!r}")
                continue
            data[emotion][speaker].append(entry)

    os.makedirs(save_dir, exist_ok=True)
    output_path = os.path.join(save_dir, f"features_{p_id}.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=4)
            
            
if __name__ == "__main__":
    # Normalize the ESD emotion labels (English and Chinese speakers) to one
    # lowercase vocabulary. Clips whose label is not listed here are skipped.
    emo_dict = {
        "Sad": "sad",
        "Happy": "happy",
        "Angry": "angry",
        "Surprise": "surprised",
        "Neutral": "neutral",
        "快乐": "happy",
        "伤心": "sad",
        "惊喜": "surprised",
        "生气": "angry",
        "中立": "neutral",
    }

    root = "./datas/Emotional_Speech_Dataset"

    # 1. Index the transcripts: utterance id -> (transcript, raw emotion label).
    #    "utf-8-sig" also accepts UTF-8 files that begin with a BOM. A BOM would
    #    otherwise shift the fixed id offset used below.
    trans_dict = {}
    for path in get_all_wavs(root, "txt"):
        if path.find("ReadMe") != -1:
            continue
        with open(path, "r", encoding="utf-8-sig") as rf:
            lines = [line.strip() for line in rf if line.strip() != ""]
        for line in lines:
            # Expected layout: "<11-char id>\t<transcript>\t<emotion>", e.g.
            # "0013_000431\tI never had a whooping cough why.\tAngry".
            # If the character after the id is not a tab (e.g. a space),
            # replace it with one.
            if len(line) > 11 and line[11] != "\t":
                line = line[:11] + "\t" + line[12:]
            fields = line.split("\t")
            if len(fields) != 3:
                print(f"Skipping malformed transcript line in {path}: {line!r}")
                continue
            name, trans, emo = fields
            trans_dict[name] = (trans, emo)

    # 2. Write the manifest consumed by extract_features, in UTF-8 to match
    #    the encoding its reader uses.
    input_file_path = os.path.join(root, "all_info.tsv")
    with open(input_file_path, "w", encoding="utf-8") as out_file:
        for path in tqdm(get_all_wavs(root, "wav")):
            name = os.path.splitext(os.path.basename(path))[0]
            if name not in trans_dict:
                print(f"No transcript for {name}, skipping")
                continue
            trans, emo = trans_dict[name]
            if emo not in emo_dict:
                print(f"Unknown emotion label {emo!r} for {name}, skipping")
                continue
            emo = emo_dict[emo]
            # sf.info parses only the header. ``frames`` counts samples per
            # channel, so the length is also correct for multichannel files.
            # (sf.read returns an array of shape (frames, channels), where
            # ``audio[0]`` is the first frame, not the first channel.)
            info = sf.info(path)
            sr, n_frames = info.samplerate, info.frames
            # Drop clips shorter than 32000 samples (2 s at 16 kHz).
            if n_frames < 16000 * 2:
                continue
            spk = os.path.basename(path).split("_")[0]
            print(
                "{}\t{}\t{}\t{}\t{}\t{}".format(path, sr, n_frames, spk, emo, trans), file=out_file
            )

    save_dir = "./datas/"
    os.makedirs(save_dir, exist_ok=True)

    # 3. Tokenize in parallel, with each worker handling one shard of the
    #    manifest. With "spawn", each worker starts a fresh interpreter that
    #    re-imports this module; this __main__ block does not run there.
    max_p = 16
    processes = []
    mp.set_start_method("spawn", force=True)
    for p_id in range(max_p):
        p = mp.Process(target=extract_features, args=(input_file_path, max_p, p_id, save_dir))
        p.start()
        processes.append(p)

    for p_id, p in enumerate(processes):
        p.join()
        if p.exitcode != 0:
            print(f"WARNING: Process-{p_id} exited with code {p.exitcode}; its shard will be missing")

    # 4. Merge this run's shards in process order, deleting each after reading.
    #    Only features_0 .. features_{max_p-1} are read, so leftovers from an
    #    earlier run with a different max_p are not silently mixed in.
    final_data = defaultdict(lambda: defaultdict(list))
    for p_id in range(max_p):
        shard_path = Path(save_dir) / f"features_{p_id}.json"
        if not shard_path.exists():
            continue
        with open(shard_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        for emo, spk_data in data.items():
            for spk, entries in spk_data.items():
                final_data[emo][spk].extend(entries)
        os.remove(shard_path)

    output_json_path = os.path.join(save_dir, "grouped_emo_speaker_data.json")
    with open(output_json_path, "w", encoding="utf-8") as f:
        json.dump(final_data, f, ensure_ascii=False, indent=4)

    print(f"✅ {output_json_path}")
