import warnings
warnings.filterwarnings('ignore')
import logging
import os
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["OPENBLAS_NUM_THREADS"] = "4"
os.environ["MKL_NUM_THREADS"] = "4"
os.environ["VECLIB_MAXIMUM_THREADS"] = "4"
os.environ["NUMEXPR_NUM_THREADS"] = "4"
import torch
import argparse
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
from fairseq.checkpoint_utils import load_model_ensemble, load_model_ensemble_and_task
from fairseq.utils import import_user_module_via_dir
from funasr import AutoModel
import torchaudio
import json
import zipfile
from copy import copy
import numpy as np
import random
import subprocess
logger = logging.getLogger(__name__)
import re
import time
import matplotlib.pyplot as plt
from collections import defaultdict, Counter
import torch.multiprocessing as mp
from pathlib import Path
import traceback
import base64
from examples.wescon.utils.en_normalization.expend import normalize as en_normalize
from examples.wescon.utils.zh_normalization.text_normlization import TextNormalizer as ZHTextNormalizer
from g2p_en import G2p
import librosa

def energy_calculate(audio):
    """Return the mean frame-level RMS energy of an audio file.

    The file is decoded by ``librosa.load`` at 16 kHz; NaN RMS values are
    excluded before averaging.
    """
    samples, _sr = librosa.load(audio, sr=16000)
    frame_rms = librosa.feature.rms(y=samples)
    finite_rms = frame_rms[~np.isnan(frame_rms)]
    return np.mean(finite_rms)

def extract_existing_ids(zips):
    """Collect integer ids from ``<digits>.json`` members of the given zip files.

    Paths that do not exist are ignored. For every member whose name ends in
    ``.json`` and whose stem (last path component without extension) consists of
    digits, ``int(stem)`` is added to the result.

    Returns:
        set[int] of the ids found (empty if nothing matched).
    """
    found = set()
    for zip_path in filter(os.path.exists, zips):
        with zipfile.ZipFile(zip_path, 'r') as archive:
            json_stems = (
                Path(name).stem for name in archive.namelist() if name.endswith('.json')
            )
            found.update(int(stem) for stem in json_stems if stem.isdigit())
    return found

def contains_english(text):
    """Return True if *text* contains at least one ASCII Latin letter (a-z, A-Z)."""
    first_letter = re.search(r'[a-zA-Z]', text)
    return first_letter is not None

def contains_chinese(text):
    """Return True if *text* contains a CJK Unified Ideograph (U+4E00..U+9FFF)."""
    first_hanzi = re.search(r'[\u4e00-\u9fff]', text)
    return first_hanzi is not None

def clean_last_char(s):
    """Remove one trailing non-word character from *s*.

    The last character is kept when it is a CJK ideograph (U+4E00..U+9FA5), an
    ASCII letter or an ASCII digit; otherwise (punctuation, whitespace, ...) it
    is dropped. At most one character is removed. Empty / falsy input is
    returned unchanged.

    Digits used to be stripped too, which truncated numbers ("2024" -> "202")
    before they were normalized into words.
    """
    if not s:
        return s
    if re.fullmatch(r'[\u4e00-\u9fa5a-zA-Z0-9]', s[-1]):
        return s
    return s[:-1]

def load_wav(wav, target_sr):
    """Load an audio file as a single-channel waveform at ``target_sr``.

    Channels are averaged (``mean`` over dim 0 with ``keepdim=True``), so the
    result has one row. Audio with a higher sample rate is resampled down;
    upsampling is refused.

    Args:
        wav: path accepted by ``torchaudio.load``.
        target_sr: desired sample rate in Hz.

    Returns:
        Tensor of shape ``(1, num_samples)`` sampled at ``target_sr``.

    Raises:
        ValueError: if the file's sample rate is lower than ``target_sr``.
    """
    speech, sample_rate = torchaudio.load(wav)
    speech = speech.mean(dim=0, keepdim=True)
    if sample_rate != target_sr:
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if sample_rate < target_sr:
            raise ValueError('wav sample rate {} must be greater than {}'.format(sample_rate, target_sr))
        speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
    return speech

def generate2zip(all_keys, all_infos,
                 max_p, p_id, save_home,
                 half, gpu_id, args):
    """Synthesize speech tokens for one worker's shard of scripts and store the results.

    Intended as the target of one spawned process per shard. The worker loads its
    own model copy on ``cuda:{gpu_id}`` and takes the contiguous slice of
    ``all_keys`` assigned to ``p_id``; the last worker also takes the remainder.
    For every line of a multi-line script it runs ``model.inference_st``
    segment by segment. Once a segment has succeeded, its returned speech tokens
    and its last text token are appended to the next segment's prompts.

    Files written under ``save_home``:
        ``codec_{p_id}.zip``: one stored ``{desc_key}_{line_i}.json`` (UTF-8) per
            line with at least one successful segment. It holds a JSON list with
            one dict per segment: the input segment fields plus ``st`` / ``tt`` /
            ``tp`` space-separated token strings.
        ``{p_id}.log``: ``{desc_key}_{line_i}<TAB><same JSON>`` per saved line.
        ``temp_signal/{desc_key}_{line_i}.wav``: audio from ``model.generate_speech``.

    Args:
        all_keys: ordered keys into ``all_infos``; sharded across ``max_p`` workers.
        all_infos: mapping key -> dict whose ``"script"`` list holds lines. Each
            line's ``"text"`` is a list of segment dicts (fields read below).
        max_p: total number of workers.
        p_id: this worker's index in ``range(max_p)``.
        save_home: output directory.
        half: if True, cast the model to fp16 after initialising it.
        gpu_id: CUDA device index as visible to this process.
        args: parsed CLI namespace. Reads ``checkpoint_path``, ``seed`` and
            ``speech_home``.

    Failures while preparing a line or a single segment are skipped; segment
    errors are printed. They do not stop the worker.
    """
    # Honour the device chosen by the launcher. This used to be hard-coded to
    # cuda:0, which put every worker on the same GPU.
    torch.cuda.set_device(gpu_id)
    device = torch.device(f"cuda:{gpu_id}")
    # Accepted emotion labels (English and Chinese spellings) -> canonical label.
    emo_dict = {
        "Sad": "sad",
        "Happy": "happy",
        "Angry": "angry",
        "Surprise": "surprised",
        "Neutral": "neutral",
        "快乐": "happy",
        "高兴": "happy",
        "悲伤": "sad",
        "伤心": "sad",
        "惊喜": "surprised",
        "惊讶": "surprised",
        "生气": "angry",
        "愤怒": "angry",
        "中立": "neutral",
        "自然": "neutral",
        "neutral": "neutral",
        "surprised": "surprised",
        "sad": "sad",
        "happy": "happy",
        "angry": "angry",
    }
    import_user_module_via_dir(r"examples/wescon")
    model, cfg, task = load_model_ensemble_and_task([args.checkpoint_path])
    model = model[0]
    model.eval()
    model = model.to(device).float()
    model.init_infer_modules(device, text_frontend=task.text_frontend)
    if half:
        model = model.half()
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    g2p_model = G2p()
    zh_norm = ZHTextNormalizer()

    # Contiguous shard for this worker; the last one also takes the remainder.
    slice_len = len(all_keys) // max_p
    start = p_id * slice_len
    end = len(all_keys) if p_id == max_p - 1 else (p_id + 1) * slice_len
    zip_path = os.path.join(save_home, f"codec_{p_id}.zip")
    os.makedirs(os.path.join(save_home, "temp_signal"), exist_ok=True)
    costs = []
    cost_lens = []
    # Explicit UTF-8: the JSON is dumped with ensure_ascii=False, and the logs are
    # read back as UTF-8 when merged.
    with open(os.path.join(save_home, f"{p_id}.log"), "w", encoding="utf-8") as log_wf, \
            zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as zf, \
            torch.no_grad(), torch.cuda.amp.autocast():
        tot_hours = 0.0
        progress_bar = tqdm(all_keys[start:end], desc=f"Process-{p_id}")
        for key_idx, desc_key in enumerate(progress_bar):
            items = all_infos[desc_key].get("script", [])
            # Only scripts with more than one line are synthesized.
            if not (items and len(items) > 1):
                continue
            for line_i, line_item in enumerate(items):
                try:
                    item = line_item["text"]
                    # Not used further. An emotion label missing from emo_dict
                    # raises KeyError, which skips the whole line.
                    emotions = [emo_dict[seg["emotion_refined"]] for seg in item]
                    # The whole line's language is decided by its first segment.
                    lang = "chinese" if contains_chinese(item[0]["lines"]) else "english"
                except Exception:
                    continue
                try:
                    spk_ref = load_wav(
                        os.path.join(args.speech_home, item[0]["target_speaker"]),
                        target_sr=16000,
                    )
                    spk_ref = model.frontend.speaker_infos(spk_ref, model.sample_rate)
                except Exception as e:
                    # A missing or broken reference used to crash the whole worker.
                    print(f"speaker reference error:{e} {desc_key}_{line_i}")
                    continue

                speech_tokens = []
                used_info = []
                last_text, last_st = None, None
                line_t0 = time.time()
                for idx, seg in enumerate(item):
                    try:
                        if contains_chinese(seg["lines"]) and contains_english(seg["lines"]) and len(seg["lines"]) > 50:
                            print(f"both contain and too long:{item}")
                            continue

                        # Text pretreatment: drop one trailing punctuation mark from all
                        # but the last segment, then apply language-specific normalization.
                        last_item = idx == len(item) - 1
                        if not last_item:
                            seg["lines"] = clean_last_char(seg["lines"])
                        if lang == "chinese":
                            seg["lines"] = zh_norm.normalize_sentence(seg["lines"])
                        else:
                            seg["lines"] = en_normalize(seg["lines"])

                        model_input = {
                            "text": model.frontend._extract_text_token(
                                model.frontend.text_normalize(seg["lines"], split=False)
                            )[0],
                            "prompt_text": torch.LongTensor([list(map(int, seg["prompt_text_cs2token"].split(" ")))]),
                            "llm_prompt_speech_token": torch.LongTensor([list(map(int, seg["prompt_speech_cs2token"].split(" ")))]),
                        }
                        if last_st is not None:
                            # Chain the previous successful segment into both prompts. The
                            # .to() calls guard against inference_st returning tensors on
                            # another device than these freshly built prompts.
                            prompt_speech = model_input["llm_prompt_speech_token"]
                            model_input["llm_prompt_speech_token"] = torch.cat(
                                [prompt_speech, last_st.unsqueeze(0).to(prompt_speech.device)], dim=1
                            )
                            # Fixed: this used the non-existent key "prompt_text_cs2token",
                            # so every chained segment raised KeyError and was dropped.
                            prompt_text = model_input["prompt_text"]
                            model_input["prompt_text"] = torch.cat(
                                [prompt_text, last_text.unsqueeze(0).to(prompt_text.device)], dim=1
                            )
                        seg_last_text = model_input["text"][0, -1:]

                        model_input = {
                            k: v.to(device) if isinstance(v, torch.Tensor) else v
                            for k, v in model_input.items()
                        }

                        # Speed reference: prompt speech tokens per prompt "phone"
                        # (characters for Chinese, g2p symbols for English).
                        if lang == "chinese":
                            prompt_phone_len = len(seg["prompt_text"])
                        else:
                            prompt_phone_len = len(g2p_model(seg["prompt_text"]))
                        # NOTE(review): once chained, the speech prompt also holds the
                        # previous segment's tokens, but prompt_phone_len does not. Confirm
                        # this is the intended speed reference.
                        prompt_speed = model_input["llm_prompt_speech_token"].size(1) / prompt_phone_len
                        cur_speed = seg["sp_scale"]
                        # prompt_text 1xN   llm_prompt_speech_token 1xN
                        speech_token, _, new_last_st, tp_tokens = model.inference_st(
                            **model_input,
                            return_speech=True,  # was `key_idx % 1 == 0`, which is always True
                            prompt_sp=cur_speed / prompt_speed,
                            last_item=last_item,
                            lang=lang,
                        )
                        record = copy(seg)
                        record["prompt_wav"] = seg["prompt_speech"]
                        record["prompt_energy"] = "none"
                        record["st/tt"] = cur_speed
                        record["st"] = " ".join(speech_token.cpu().numpy().flatten().astype(str))
                        record["tt"] = " ".join(model_input["text"].cpu().numpy().flatten().astype(str))
                        record["tp"] = " ".join(tp_tokens.cpu().numpy().flatten().astype(str))
                    except Exception as e:
                        print("error:" + str(e) + str(seg))
                    else:
                        # Commit only fully successful segments, so the saved JSON, the
                        # concatenated audio and the chaining state stay consistent.
                        speech_tokens.append(speech_token)
                        used_info.append(record)
                        last_st, last_text = new_last_st, seg_last_text
                        # Hard-coded assumption: 25 speech tokens per second.
                        tot_hours += len(speech_token) / 25.0 / 60.0 / 60.0

                if not speech_tokens:
                    # Nothing succeeded for this line. torch.cat([]) used to raise
                    # here and kill the worker.
                    continue
                costs.append(time.time() - line_t0)
                line_tokens = torch.cat(speech_tokens)
                cost_lens.append(len(line_tokens))
                print(f"{key_idx} {len(costs)} {np.mean(costs)} {np.mean(cost_lens)}", flush=True)

                utt_id = f"{desc_key}_{line_i}"
                info_str = json.dumps(used_info, ensure_ascii=False)
                zf.writestr(f"{utt_id}.json", info_str.encode("utf-8"))
                print(f"{utt_id}\t{info_str}", file=log_wf, flush=True)
                progress_bar.set_description(f"Process-{p_id} Hours: {tot_hours:.4f}")
                speech = model.generate_speech(line_tokens, **spk_ref)
                torchaudio.save(os.path.join(save_home, "temp_signal", f"{utt_id}.wav"), speech, model.sample_rate)
                    
def generate_aligned_log(zip_path, merged_log_path, output_tsv_path):
    """Write a TSV index locating each logged utterance's JSON payload inside ``zip_path``.

    Each non-empty ``utt_id<TAB>json`` line of ``merged_log_path`` is processed if
    the zip has a ``{utt_id}.json`` member. The payload is read directly from the
    archive bytes, checked against the logged JSON, and one row is written:

        utt_id <TAB> zip_path:offset:size <TAB> codec_lens <TAB> text_lens

    ``offset``/``size`` address the member's raw bytes. These equal the JSON text
    because members must be stored uncompressed. ``codec_lens`` and ``text_lens``
    are ``|``-joined per-segment token counts of each entry's ``st`` and ``tt``
    field. Log lines without a matching zip member are reported and skipped.

    Raises:
        ValueError: if a member is compressed, a local file header is malformed,
            or a zip payload differs from its log line.
    """
    zip_item_info = {}
    with zipfile.ZipFile(zip_path, mode="r") as zf, open(zip_path, "rb") as raw:
        for info in tqdm(zf.infolist(), desc="Indexing zip"):
            if info.compress_type != zipfile.ZIP_STORED:
                raise ValueError(f"{info.filename} is compressed; raw offsets require ZIP_STORED")
            # Member data starts after the local file header: 30 fixed bytes, then
            # the file name and the extra field. Their byte lengths are little-endian
            # uint16 values at offsets 26 and 28. They are read from the header
            # because len(info.filename) counts characters (wrong for non-ASCII
            # names) and ignores the extra field.
            raw.seek(info.header_offset)
            header = raw.read(30)
            if len(header) != 30 or header[:4] != b"PK\x03\x04":
                raise ValueError(f"bad local file header for {info.filename}")
            name_len = int.from_bytes(header[26:28], "little")
            extra_len = int.from_bytes(header[28:30], "little")
            offset = info.header_offset + 30 + name_len + extra_len
            # Strip only the extension, so the key is exactly the utt_id that was
            # logged ("{utt_id}.json"), even if it contains "/".
            utt_id = os.path.splitext(info.filename)[0]  # e.g., '0_0'
            zip_item_info[utt_id] = (offset, info.file_size)

    with open(output_tsv_path, "w", encoding="utf-8") as wf, \
         open(merged_log_path, "r", encoding="utf-8") as rf, \
         open(zip_path, "rb") as zip_f:
        for line in tqdm(rf, desc="Processing merged.log"):
            line = line.strip()
            if not line:
                continue
            # json.dumps escapes tabs, so the first tab separates utt_id from the JSON.
            utt_id, info = line.split("\t", 1)
            if utt_id not in zip_item_info:
                print(f"[Warning] {utt_id} not found in zip!")
                continue

            offset, file_size = zip_item_info[utt_id]
            zip_f.seek(offset)
            json_data = json.loads(zip_f.read(file_size).decode("utf-8"))
            json_data_temp = json.loads(info)
            if json_data != json_data_temp:
                raise ValueError(f"{str(json_data)}\n{str(json_data_temp)}")
            codec_lens = "|".join(str(len(x["st"].split())) for x in json_data)
            # Fixed: text lengths previously counted x["st"] again instead of "tt".
            text_lens = "|".join(str(len(x["tt"].split())) for x in json_data)
            wf.write(f"{utt_id}\t{zip_path}:{offset}:{file_size}\t{codec_lens}\t{text_lens}\n")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Inference')
    # These are all used unconditionally (here or in the workers), so fail at
    # parse time rather than deep inside a spawned process.
    parser.add_argument('--checkpoint_path', type=str, required=True)
    parser.add_argument('--save-home', type=str, required=True)
    parser.add_argument('--half', action="store_true")
    parser.add_argument('--max_p', type=int, required=True, help='number of worker processes')
    parser.add_argument('--devices', type=str, required=True, help='";"-separated CUDA device ids, e.g. "0;1"')
    parser.add_argument('--tgt-json', type=str, required=True)
    parser.add_argument('--speech-home', type=str, required=True)
    parser.add_argument('--seed', type=int, default=100)
    args = parser.parse_args()
    if args.max_p < 1:
        parser.error("--max_p must be >= 1")

    with open(args.tgt_json, "r", encoding="utf-8") as rf:
        all_lines = json.load(rf)
    all_keys = list(all_lines.keys())

    save_home = args.save_home
    os.makedirs(save_home, exist_ok=True)

    max_p = args.max_p
    gpu_list = [int(x) for x in args.devices.split(";")]
    gpu_num = len(gpu_list)
    # 'spawn' (not fork) so that each worker initialises CUDA and its model itself.
    mp.set_start_method('spawn', force=True)

    process_list = []
    for p_id in range(max_p):
        gpu_id = gpu_list[p_id % gpu_num]  # round-robin over the given devices
        print(f"run {p_id}, GPU {gpu_id}")
        p = mp.Process(
            target=generate2zip,
            args=(all_keys, all_lines, max_p, p_id, save_home, args.half, gpu_id, args)
        )
        p.start()
        process_list.append(p)

    for p_id, p in enumerate(process_list):
        p.join()
        if p.exitcode != 0:
            print(f"[Warning] process {p_id} exited with code {p.exitcode}; its output may be incomplete")

    # Merge the per-worker zips (deleted afterwards) and an optional
    # codec_processed.zip (kept) into codec_all.zip.
    merged_zip_path = os.path.join(save_home, "codec_all.zip")
    processed_zip_path = os.path.join(save_home, "codec_processed.zip")
    with zipfile.ZipFile(merged_zip_path, 'w', zipfile.ZIP_STORED) as merged_zip:
        for p_id in range(max_p):
            part_zip_path = os.path.join(save_home, f"codec_{p_id}.zip")
            if not os.path.exists(part_zip_path):
                # A worker that failed early never created its zip.
                print(f"[Warning] {part_zip_path} not found, skipped")
                continue
            with zipfile.ZipFile(part_zip_path, 'r') as part_zip:
                for file_info in part_zip.infolist():
                    merged_zip.writestr(file_info, part_zip.read(file_info.filename))
            os.remove(part_zip_path)
        if os.path.exists(processed_zip_path):
            with zipfile.ZipFile(processed_zip_path, 'r') as part_zip:
                for file_info in part_zip.infolist():
                    merged_zip.writestr(file_info, part_zip.read(file_info.filename))

    # Merge the per-worker logs and an optional processed.log, deleting all of them.
    # NOTE(review): processed.log is deleted while codec_processed.zip is kept.
    # Confirm this asymmetry is intended.
    merged_log_path = os.path.join(save_home, "merged.log")
    with open(merged_log_path, "w", encoding="utf-8") as merged_log:
        for p_id in range(max_p):
            log_path = os.path.join(save_home, f"{p_id}.log")
            if not os.path.exists(log_path):
                print(f"[Warning] {log_path} not found, skipped")
                continue
            with open(log_path, "r", encoding="utf-8") as lf:
                merged_log.write(lf.read())
            os.remove(log_path)
        log_path = os.path.join(save_home, "processed.log")
        if os.path.exists(log_path):
            with open(log_path, "r", encoding="utf-8") as lf:
                merged_log.write(lf.read())
            os.remove(log_path)

    out_tsv = os.path.join(save_home, "info.tsv")
    generate_aligned_log(merged_zip_path, merged_log_path, out_tsv)