import warnings
warnings.filterwarnings('ignore')
import logging
import os
import torch
import argparse
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
from funasr import AutoModel
import torchaudio
import json
import zipfile
from copy import copy
import numpy as np
import random
import subprocess
logger = logging.getLogger(__name__)
import re
import time
import matplotlib.pyplot as plt
from collections import defaultdict, Counter
import torch.multiprocessing as mp
from pathlib import Path
import traceback
import base64
from examples.wescon.utils.zh_normalization.text_normlization import TextNormalizer as ZHTextNormalizer
from examples.wescon.utils.en_normalization.expend import normalize as en_normalize
from g2p_en import G2p
import librosa

def contains_english(text):
    """Return True when *text* contains at least one ASCII Latin letter (a-z or A-Z)."""
    found = re.search(r'[a-zA-Z]', text)
    return found is not None

def contains_chinese(text):
    """Return True when *text* contains a character in the CJK range U+4E00..U+9FFF."""
    found = re.search(r'[\u4e00-\u9fff]', text)
    return found is not None

def clean_last_char(s):
    """Strip the final character of *s* unless it is a CJK ideograph or an ASCII letter.

    The CJK check here uses U+4E00..U+9FA5. Falsy input ("" or None) is returned unchanged.
    """
    if not s:
        return s
    # Keep the string intact when it already ends in a Chinese character or English letter.
    ends_in_letter = re.match(r'[\u4e00-\u9fa5a-zA-Z]$', s[-1]) is not None
    return s if ends_in_letter else s[:-1]


def preprocess_prompt_infos(prompt_infos, flag, min_tokens=25 * 1.5, max_tokens=25 * 15):
    """Index usable prompt clips by emotion, language and speaker.

    Args:
        prompt_infos: nested mapping ``{emotion: {speaker: [item, ...]}}``. Each
            item must provide ``"filepath"``, ``"text"`` and ``"speech_token"``
            (a whitespace-separated token string).
        flag: split selector; only items whose ``filepath`` contains this
            substring are kept (e.g. ``"train"``).
        min_tokens, max_tokens: inclusive bounds on the number of speech
            tokens. The defaults correspond to 1.5 s .. 15 s assuming 25
            tokens per second.

    Returns:
        ``(processed, all_prompts)``:
        ``processed[emotion][lang][speaker]`` is the list of kept items, where
        ``lang`` is ``"chinese"`` if the text contains a CJK character and
        ``"english"`` otherwise. ``processed`` is a defaultdict, so unknown
        emotions/speakers read as empty. ``all_prompts`` is a flat list of every
        kept item. Emotion keys are taken verbatim from ``prompt_infos`` and are
        not normalized.
    """
    processed = defaultdict(lambda: {"chinese": defaultdict(list), "english": defaultdict(list)})
    all_prompts = []
    for emo, speakers in prompt_infos.items():
        for spk, items in speakers.items():
            for item in items:
                if flag not in item["filepath"]:
                    continue
                # Count actual tokens: counting separators undercounts by one,
                # and an empty token string must count as zero.
                token_count = len(item["speech_token"].split())
                if not (min_tokens <= token_count <= max_tokens):
                    continue
                lang = "chinese" if contains_chinese(item["text"]) else "english"
                processed[emo][lang][spk].append(item)
                all_prompts.append(item)
    return processed, all_prompts

def generate2zip(all_keys, all_infos, prompt_infos, 
                 save_home, emotion_speed_range, 
                 args):
    """Attach prompt-speech and speed metadata to each script and dump the result as JSON.

    For every key in ``all_keys``, ``all_infos[key]["script"]`` is processed line
    by line. Each line's ``"text"`` is a list of segments carrying ``"lines"``
    (the text), ``"emotion"`` and optionally ``"speed"`` (default 1.0). The line's
    language is decided from its first segment. A prompt speaker is then drawn
    from the speakers that have prompts for every emotion in the line. Each
    segment is updated in place with:
      - its normalized text;
      - ``sp_scale``, the speed mark mapped through ``map_speed``;
      - the chosen prompt clip's fields (``prompt_speech``, sample number/rate,
        text and token strings);
      - ``target_speaker``, the filepath of a separately drawn prompt.

    A script is removed from ``all_infos`` when:
      - it has fewer than two lines;
      - no speaker covers a line's emotions;
      - a segment mixes Chinese and English and is longer than 50 characters;
      - any exception is raised while processing one of its lines.
    The remaining entries are written to ``<save_home>/<args.flag>.json``.

    Args:
        all_keys: keys of ``all_infos`` to process. This should be a separate
            list (e.g. ``list(all_infos)``) because entries are deleted while
            iterating.
        all_infos: mapping of script entries; mutated in place.
        prompt_infos: ``{emotion: {speaker: [prompt_item, ...]}}``; indexed via
            ``preprocess_prompt_infos``.
        save_home: output directory.
        emotion_speed_range: mapping ``"<emotion>-<lang>" -> [mean, p5, p95, ...]``
            as produced by ``prompt_speed_analysis``.
        args: namespace providing ``flag`` (split name).

    NOTE(review): prompts are indexed by the raw emotion keys of ``prompt_infos``,
    while lookups use the normalized labels from ``emo_dict``. ``prompt_infos``
    therefore presumably already uses normalized labels -- confirm.
    """
    emo_dict = {
        "Sad": "sad",
        "Happy": "happy",
        "Angry": "angry",
        "Surprise": "surprised",
        "Neutral": "neutral",
        "快乐": "happy",
        "高兴": "happy",
        "悲伤": "sad",
        "伤心": "sad",
        "惊喜": "surprised",
        "惊讶": "surprised",
        "生气": "angry",
        "愤怒": "angry",
        "中立": "neutral",
        "自然": "neutral",
        "neutral": "neutral",
        "surprised": "surprised",
        "sad": "sad",
        "happy": "happy",
        "angry": "angry",
        "anger": "angry",
        "surprise": "surprised",
    }
    audio_home = r"./datas/Emotional_Speech_Dataset/"
    processed_prompt_infos, all_prompt_infos = preprocess_prompt_infos(prompt_infos, args.flag)
    zh_norm = ZHTextNormalizer()

    def annotate_line(segments):
        """Fill the prompt/speed fields of every segment of one script line.

        Returns False when the line cannot be served (the caller then drops the
        whole script). May raise on malformed data.
        """
        emotions = [emo_dict[seg["emotion"]] for seg in segments]
        lang = "chinese" if contains_chinese(segments[0]["lines"]) else "english"

        # Speakers having at least one prompt for every emotion used in this line.
        valid_speakers = set(processed_prompt_infos[emotions[0]][lang])
        for emo in emotions[1:]:
            valid_speakers &= set(processed_prompt_infos[emo][lang])
        if not valid_speakers:
            return False

        # Sorted so the draw is reproducible under random.seed() (iteration order of a
        # set of str depends on hash randomization). A speaker's weight is the number of
        # matching prompts it has across this line's emotions.
        candidates = sorted(valid_speakers)
        weights = [
            sum(len(processed_prompt_infos[emo][lang][spk]) for emo in emotions)
            for spk in candidates
        ]
        prompt_spk = random.choices(candidates, weights=weights, k=1)[0]

        if args.flag == "train":
            pool = all_prompt_infos  # choose from all prompts
        else:
            # Choose from the same speaker's prompts of any emotion. ``.get`` avoids the
            # defaultdict inserting empty speaker lists, which would later make this
            # speaker look valid for emotions it has no prompts for.
            pool = []
            for per_lang in processed_prompt_infos.values():
                pool += per_lang[lang].get(prompt_spk, [])
        target_speaker = random.choice(pool)

        for seg in segments:
            text = seg["lines"]
            if contains_chinese(text) and contains_english(text) and len(text) > 50:
                print(f"both contain and too long:{segments}")
                return False
            emo = emo_dict[seg["emotion"]]
            prompt = random.choice(processed_prompt_infos[emo][lang][prompt_spk])

            # Text pretreatment.
            if lang == "chinese":
                seg["lines"] = zh_norm.normalize_sentence(text)
            else:
                seg["lines"] = en_normalize(text)

            # Speed: map the requested mark onto this emotion's observed range
            # (5th/95th percentile), anchored at the neutral mean.
            tgt_speed = float(seg.get("speed", 1.0))
            emo_range = emotion_speed_range[f"{emo}-{lang}"]
            seg["sp_scale"] = map_speed(
                tgt_speed,
                emo_range[1],
                emotion_speed_range[f"neutral-{lang}"][0],
                emo_range[2],
            )
            seg["prompt_speech"] = os.path.relpath(prompt["filepath"], audio_home)
            seg["prompt_speech_sample_num"] = prompt["duration"]
            seg["prompt_speech_sample_rate"] = prompt["sample_rate"]
            seg["prompt_text"] = prompt["text"]
            seg["prompt_text_cs2token"] = prompt["text_token"]
            seg["prompt_speech_cs2token"] = prompt["speech_token"]
            seg["target_speaker"] = target_speaker["filepath"]
        return True

    for desc_key in tqdm(all_keys):
        items = all_infos[desc_key].get("script", [])
        if not (items and len(items) > 1):
            del all_infos[desc_key]
            continue
        for line_item in items:
            try:
                accepted = annotate_line(line_item["text"])
            except Exception:
                # Best effort: report and drop this script, keep processing others.
                traceback.print_exc()
                accepted = False
            if not accepted:
                # Delete exactly once and skip the script's remaining lines;
                # deleting again would raise KeyError.
                del all_infos[desc_key]
                break

    with open(os.path.join(save_home, f"{args.flag}.json"), "w", encoding="utf-8") as log_wf:
        json.dump(all_infos, log_wf, indent=4, ensure_ascii=False)
                        
                    
def map_speed(mark, real_min, real_mean, real_max):
    """Piecewise-linearly map a speed mark onto a measured speed range.

    *mark* is first clamped to [0.5, 2.0]. The mapping is linear between the
    anchor points 0.5 -> real_min, 1.0 -> real_mean and 2.0 -> real_max, so
    the two halves of the range may have different slopes.
    """
    clamped = max(min(mark, 2.0), 0.5)
    if clamped <= 1:
        fraction = (clamped - 0.5) / (1 - 0.5)
        return real_min + fraction * (real_mean - real_min)
    fraction = (clamped - 1) / (2 - 1)
    return real_mean + fraction * (real_max - real_mean)

def prompt_speed_analysis(prompt_json, g2p_model):
    """Compute per-emotion, per-language speaking-rate statistics of the prompt pool.

    "Speed" here is speech tokens per text unit:
      - per character for text containing a CJK character;
      - per ``g2p_model`` output symbol for all other text.
    Larger values therefore mean more tokens per unit of text. Items with
    zero text units are skipped. Items with speed > 30 are printed for
    inspection.

    Args:
        prompt_json: mapping ``{emotion_label: {speaker: [item, ...]}}``. Labels are
            normalized through the local ``emo_dict``; an unknown label with at
            least one item raises ``KeyError``. Each item needs ``"speech_token"``
            (a space-separated string) and ``"text"``.
        g2p_model: callable returning a sequence of symbols for an English string.

    Returns:
        dict mapping ``"<emotion>-<lang>"`` (lang is ``"chinese"``/``"english"``) to
        ``[mean, 5th percentile, 95th percentile, min, max]``. Only combinations
        with at least one sample are present.
    """
    emo_dict = {
        "Sad": "sad",
        "Happy": "happy",
        "Angry": "angry",
        "Surprise": "surprised",
        "Neutral": "neutral",
        "快乐": "happy",
        "高兴": "happy",
        "悲伤": "sad",
        "伤心": "sad",
        "惊喜": "surprised",
        "惊讶": "surprised",
        "生气": "angry",
        "愤怒": "angry",
        "中立": "neutral",
        "自然": "neutral",
        "neutral": "neutral",
        "surprised": "surprised",
        "sad": "sad",
        "happy": "happy",
        "angry": "angry",
        "anger": "angry",
        "surprise": "surprised",
    }
    emotion_speed = defaultdict(lambda: {'chinese': [], 'english': []})
    # Iterate with a separate name for the raw label: rebinding the loop variable
    # to the normalized label would break the lookup for the next speaker.
    for raw_emotion, speakers in prompt_json.items():
        for items in speakers.values():
            for item in items:
                emotion = emo_dict[raw_emotion]
                speech_token_len = item["speech_token"].count(" ") + 1
                if contains_chinese(item["text"]):
                    lang = 'chinese'
                    text_len = len(item["text"])
                else:
                    lang = 'english'
                    text_len = len(g2p_model(item["text"]))
                if text_len == 0:
                    # Speed is undefined; skip instead of reusing a stale value.
                    continue
                speed = speech_token_len / text_len
                emotion_speed[emotion][lang].append(speed)
                if speed > 30:
                    print(item, text_len, speech_token_len)

    emotion_speed_range = {}
    for emo, per_lang in emotion_speed.items():
        for lang in ('chinese', 'english'):
            speeds = per_lang[lang]
            if not speeds:
                continue
            arr = np.asarray(speeds)
            emotion_speed_range[f"{emo}-{lang}"] = [
                np.mean(arr),
                np.percentile(arr, 5),
                np.percentile(arr, 95),
                np.min(arr),
                np.max(arr),
            ]
    return emotion_speed_range


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Inference')
    parser.add_argument('--save-home', type=str, required=True)
    parser.add_argument('--prompt-json', type=str, required=True)
    parser.add_argument('--tgt-json', type=str, required=True)
    parser.add_argument('--flag', type=str, default="train")
    # Cap on the number of scripts processed (formerly a hard-coded debug
    # truncation to 200). Pass a value <= 0 to process every script.
    parser.add_argument('--max-scripts', type=int, default=200)
    # Optional seed for the random speaker/prompt draws, for reproducible output.
    parser.add_argument('--seed', type=int, default=None)
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)

    with open(args.prompt_json, "r", encoding="utf-8") as rf:
        prompt_infos = json.load(rf)

    with open(args.tgt_json, "r", encoding="utf-8") as rf:
        all_lines = json.load(rf)
    if args.max_scripts > 0:
        all_lines = dict(list(all_lines.items())[:args.max_scripts])

    # Snapshot the keys: generate2zip deletes rejected entries from all_lines.
    all_keys = list(all_lines.keys())

    save_home = args.save_home
    os.makedirs(save_home, exist_ok=True)
    g2p_model = G2p()
    emotion_speed_range = prompt_speed_analysis(prompt_infos, g2p_model)
    del g2p_model

    # The prompt JSON is read-only input. It is not rewritten here:
    # prompt_speed_analysis does not modify it, and overwriting an input file
    # only risks data loss.
    print("begin2")
    generate2zip(all_keys, all_lines, prompt_infos, save_home, emotion_speed_range, args)