import os
import numpy as np
import json
import zipfile
from tqdm import tqdm
from pathlib import Path
import multiprocessing
import random
import re
import traceback
from g2p_en import G2p
from collections import defaultdict, Counter

def contains_english(text):
    """Return True when *text* holds at least one ASCII letter (a-z or A-Z)."""
    first_letter = re.search(r'[a-zA-Z]', text)
    return first_letter is not None

def contains_chinese(text):
    """Return True when *text* holds at least one character in U+4E00..U+9FFF."""
    first_cjk = re.search(r'[\u4e00-\u9fff]', text)
    return first_cjk is not None

def find_all_zip_files(folder_path):
    """Recursively collect the paths of every ``.zip`` file under *folder_path*.

    The extension check is case-insensitive. Paths are returned in the order
    produced by ``os.walk``.
    """
    return [
        os.path.join(dirpath, name)
        for dirpath, _subdirs, names in os.walk(folder_path)
        for name in names
        if name.lower().endswith('.zip')
    ]

def load_prompt_infos(prompt_json):
    """Load the grouped prompt JSON and index its items by file path.

    The JSON is read as a two-level mapping (emotion -> speaker -> list of
    items). Each item is registered under its ``"filepath"`` value; a later
    item with the same path replaces an earlier one.

    Returns:
        A ``(path2info, prompt_infos)`` tuple: the path-keyed index and the
        parsed JSON object as loaded.
    """
    with open(prompt_json, "r", encoding="utf-8") as handle:
        grouped = json.load(handle)

    by_path = {}
    for speakers in grouped.values():
        for utterances in speakers.values():
            by_path.update((utt["filepath"], utt) for utt in utterances)
    return by_path, grouped

def map_speed(mark, real_min, real_mean, real_max):
    """Convert a relative speed mark into a real speed, piecewise-linearly.

    The mark is first clamped to [0.5, 2.0]. A mark of 0.5 maps to
    *real_min*, 1 maps to *real_mean* and 2 maps to *real_max*; values in
    between are linearly interpolated on their respective segment.
    """
    upper_bounded = min(mark, 2.0)
    clamped = max(upper_bounded, 0.5)

    if clamped <= 1:
        # Lower segment: 0.5 -> real_min, 1 -> real_mean.
        fraction = (clamped - 0.5) / (1 - 0.5)
        return real_min + fraction * (real_mean - real_min)

    # Upper segment: 1 -> real_mean, 2 -> real_max.
    fraction = (clamped - 1) / (2 - 1)
    return real_mean + fraction * (real_max - real_mean)

def prompt_speed_analysis(prompt_json, g2p_model):
    """Measure the speech-token rate of every prompt and summarise it.

    For each prompt item the speed is ``speech tokens / text units``, where
    the text unit count is the number of characters for text containing
    Chinese characters, and the length of ``g2p_model(text)`` otherwise.
    Speeds are grouped by emotion and by language ("chinese" / "english").

    Prompts whose text unit count is zero have no defined speed; they are
    reported and left out of both the per-prompt table and the statistics
    (instead of silently inheriting the previous prompt's speed).

    Args:
        prompt_json: Parsed prompt data, a mapping emotion -> speaker -> list
            of items with ``"speech_token"``, ``"text"`` and ``"filepath"``.
        g2p_model: Callable turning a text into a sequence whose length is
            used as the text unit count for non-Chinese text.

    Returns:
        A dict with two keys:
        ``"prompt_speed"`` maps each prompt file path to its speed;
        ``"emotion_speed_range"`` maps ``"<emotion>-<lang>"`` to
        ``[mean, 5th percentile, 95th percentile, min, max]``.
    """
    # Speeds grouped by emotion, then split into Chinese / English.
    emotion_speed = defaultdict(lambda: {'chinese': [], 'english': []})
    prompt_speed = {}
    for emotion, speakers in prompt_json.items():
        for items in speakers.values():
            for item in items:
                # Tokens are space-separated, so separators + 1 = token count.
                speech_token_len = item["speech_token"].count(" ") + 1
                # The language is decided from the text content.
                if contains_chinese(item["text"]):
                    lang = 'chinese'
                    text_len = len(item["text"])
                else:
                    lang = 'english'
                    text_len = len(g2p_model(item["text"]))

                if text_len <= 0:
                    # No text units: a speed cannot be computed for this item.
                    print(f"Skipping prompt with empty text: {item['filepath']}")
                    continue

                speed = speech_token_len / text_len
                emotion_speed[emotion][lang].append(speed)
                if speed > 30:
                    # Unusually high ratio: dump the item for manual inspection.
                    print(item, text_len, speech_token_len)
                prompt_speed[item["filepath"]] = speed

    # Per-group statistics.
    emotion_speed_range = {}
    for emo, by_lang in emotion_speed.items():
        for lang in ['chinese', 'english']:
            speeds = by_lang[lang]
            if not speeds:
                continue
            arr = np.array(speeds)
            mean = np.mean(arr)
            median = np.median(arr)
            # NOTE: despite the "Q1"/"Q3" labels printed below, these are the
            # 5th and 95th percentiles, not quartiles.
            q1 = np.percentile(arr, 5)
            q3 = np.percentile(arr, 95)
            print(f"Emotion: {emo} - {lang}")
            print(f"  Mean   : {mean:.4f}")
            print(f"  Median : {median:.4f}")
            print(f"  Q1 (5%): {q1:.4f}")
            print(f"  Q3 (95%): {q3:.4f}")
            print("-" * 40)
            # Stored layout: [mean, p5, p95, min, max].
            emotion_speed_range[f"{emo}-{lang}"] = [mean, q1, q3, np.min(arr), np.max(arr)]
    return {"prompt_speed": prompt_speed, "emotion_speed_range": emotion_speed_range}

def _zip_payload_offset(raw_f, entry):
    """Return the absolute offset in *raw_f* where *entry*'s data starts.

    A zip local file header is 30 fixed bytes followed by the file name and
    an extra field. Their byte lengths are stored little-endian at header
    offsets 26 and 28, and may differ from the central-directory values, so
    they are read from the local header itself.

    Raises:
        ValueError: if no valid local file header is found at the entry's
            recorded header offset.
    """
    raw_f.seek(entry.header_offset)
    header = raw_f.read(30)
    if len(header) != 30 or header[:4] != b"PK\x03\x04":
        raise ValueError("invalid local file header")
    name_len = int.from_bytes(header[26:28], "little")
    extra_len = int.from_bytes(header[28:30], "little")
    return entry.header_offset + 30 + name_len + extra_len


def load_zip_infos(args):
    """Build one index line per JSON entry stored in a zip archive.

    Every archive entry is read directly from the archive file by byte
    offset and parsed as a JSON list of utterances. For each entry a
    tab-separated line is produced with these columns, each per-utterance
    column being ``|``-joined:
    ``<zip_path>:<offset>:<size>``, codec lengths, text lengths, prompt codec
    lengths, prompt text lengths, prompt energies, emotions, prompt speed
    scalers.

    Entries that are empty, compressed, or fail to parse are reported and
    skipped; they contribute neither a line nor any duration.

    Args:
        args: A ``(zip_path, prompt_path2info, prompt_speed)`` tuple, where
            ``prompt_path2info`` maps prompt paths to prompt items and
            ``prompt_speed`` is the dict returned by ``prompt_speed_analysis``.

    Returns:
        A ``(infos, tot_hours)`` tuple: the list of index lines and the
        summed duration of the emitted entries in hours.
    """
    emo_dict = {
        "Sad": "sad",
        "Happy": "happy",
        "Angry": "angry",
        "Surprise": "surprised",
        "Neutral": "neutral",
        "快乐": "happy",
        "高兴": "happy",
        "悲伤": "sad",
        "伤心": "sad",
        "惊喜": "surprised",
        "惊讶": "surprised",
        "生气": "angry",
        "愤怒": "angry",
        "中立": "neutral",
        "自然": "neutral",
        "neutral": "neutral",
        "surprised": "surprised",
        "sad": "sad",
        "happy": "happy",
        "angry": "angry",
    }
    zip_path, prompt_path2info, prompt_speed = args
    speed_by_prompt = prompt_speed["prompt_speed"]
    speed_range = prompt_speed["emotion_speed_range"]

    with zipfile.ZipFile(zip_path, "r") as zip_f:
        info_list = zip_f.infolist()

    infos = []
    tot_hours = 0.0
    # One raw handle serves every entry instead of reopening the archive.
    with open(zip_path, "rb") as raw_f:
        for entry in tqdm(info_list, desc=f"Processing {os.path.basename(zip_path)}"):
            file_size = entry.file_size
            # Used in messages until the real payload offset is known.
            item_zip_loc = f"{zip_path}:{entry.filename}"

            try:
                if entry.compress_type != zipfile.ZIP_STORED:
                    # Raw offset reads only make sense for uncompressed data.
                    raise ValueError("entry is compressed; only stored entries can be read by offset")

                offset = _zip_payload_offset(raw_f, entry)
                item_zip_loc = f"{zip_path}:{offset}:{file_size}"

                raw_f.seek(offset)
                byte_data = raw_f.read(file_size)
                if len(byte_data) <= 1:
                    print(f"Empty data: {item_zip_loc}")
                    continue

                # Parse the JSON payload.
                json_data = json.loads(byte_data.decode("utf-8"))

                # Per-utterance length / prompt statistics.
                codec_lens, text_lens, emos, prompt_scalers = [], [], [], []
                prompt_codec_lens, prompt_text_lens, prompt_energy = [], [], []
                entry_hours = 0.0

                for x in json_data:
                    codec_len = len(x['st'].split(" "))
                    codec_lens.append(str(codec_len))
                    text_lens.append(str(len(x['tt'].split(" "))))
                    prompt_item = prompt_path2info[x['prompt_wav']]

                    lang = "chinese" if contains_chinese(x['lines']) else "english"
                    prompt_sp = speed_by_prompt[x['prompt_wav']]

                    # Refined annotations take precedence when present.
                    if "emotion_refined" in x:
                        emo = emo_dict[x["emotion_refined"]]
                    else:
                        emo = emo_dict[x["emotion"]]
                    if "speed_refined" in x:
                        tgt_speed = float(x["speed_refined"])
                    else:
                        tgt_speed = float(x["speed"])

                    # Range entries are [mean, p5, p95, min, max]: the
                    # emotion's p5/p95 bound the mapping, the neutral mean is
                    # its midpoint.
                    cur_sp = map_speed(
                        tgt_speed,
                        speed_range[f"{emo}-{lang}"][1],
                        speed_range[f"neutral-{lang}"][0],
                        speed_range[f"{emo}-{lang}"][2],
                    )
                    scaler = cur_sp / prompt_sp

                    prompt_codec_lens.append(str(len(prompt_item["speech_token"].split(" "))))
                    prompt_text_lens.append(str(len(prompt_item["text_token"].split(" "))))
                    prompt_energy.append(str(prompt_item["energy"]))
                    emos.append(str(emo))
                    prompt_scalers.append(str(scaler))
                    # presumably 25 codec tokens per second - TODO confirm
                    entry_hours += codec_len / 25.0 / 60.0 / 60.0

                codec_lens_str = "|".join(codec_lens)
                text_lens_str = "|".join(text_lens)
                prompt_codec_lens = "|".join(prompt_codec_lens)
                prompt_text_lens = "|".join(prompt_text_lens)
                prompt_energy = "|".join(prompt_energy)
                emos = "|".join(emos)
                prompt_scalers = "|".join(prompt_scalers)

            except Exception as e:
                # Best effort per entry: report and move on to the next one.
                print(f"Error processing {item_zip_loc}: {e}")
                continue

            infos.append(
                f"{item_zip_loc}\t{codec_lens_str}\t{text_lens_str}\t{prompt_codec_lens}\t{prompt_text_lens}\t{prompt_energy}\t{emos}\t{prompt_scalers}"
            )
            # Only entries that made it into the index count towards the total.
            tot_hours += entry_hours
    return infos, tot_hours
    

if __name__ == "__main__":
    # Script entry point: index every zip under data_home and write a
    # shuffled train/dev split of the resulting lines.
    data_home = r"./datas/wesc/supervision"
    prompt_json = r"./datas/grouped_emo_speaker_data.json"
    info_home = r"./datas/wesc/supervision/infos"
    os.makedirs(info_home, exist_ok=True)

    prompt_path2info, prompt_infos = load_prompt_infos(prompt_json)
    zips = find_all_zip_files(data_home)

    g2p_model = G2p()
    prompt_speed = prompt_speed_analysis(prompt_infos, g2p_model)

    # One task per archive; each worker receives the shared prompt tables.
    tasks = [(zip_path, prompt_path2info, prompt_speed) for zip_path in zips]
    with multiprocessing.Pool(processes=os.cpu_count()) as pool:
        results = list(tqdm(pool.imap(load_zip_infos, tasks), total=len(zips)))

    all_infos = []
    tot_hours = 0
    for res, tot_hour in results:
        all_infos.extend(res)
        tot_hours += tot_hour

    random.shuffle(all_infos)
    print(f"Total processed files: {len(all_infos)}\t{tot_hours} hours")

    # Lines embed archive paths, which may be non-ASCII, so the encoding is
    # pinned to UTF-8 rather than left to the platform default.
    train_path = os.path.join(info_home, "train.tsv")
    dev_path = os.path.join(info_home, "dev.tsv")
    with open(train_path, "w", encoding="utf-8") as train_f, \
            open(dev_path, "w", encoding="utf-8") as dev_f:
        for idx, line in enumerate(all_infos):
            # Every 50th line (starting with the first) goes to the dev set.
            if idx % 50 == 0:
                print(line.strip(), file=dev_f)
            else:
                print(line.strip(), file=train_f)
                