import warnings
warnings.filterwarnings('ignore')
import logging
import os
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["OPENBLAS_NUM_THREADS"] = "4"
os.environ["MKL_NUM_THREADS"] = "4"
os.environ["VECLIB_MAXIMUM_THREADS"] = "4"
os.environ["NUMEXPR_NUM_THREADS"] = "4"
import torch
import argparse
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
from fairseq.checkpoint_utils import load_model_ensemble, load_model_ensemble_and_task
from fairseq.utils import import_user_module_via_dir
from funasr import AutoModel
import torchaudio
import json
import zipfile
from copy import copy
import numpy as np
import random
import subprocess
logger = logging.getLogger(__name__)
import re
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
import time
import matplotlib.pyplot as plt
from collections import defaultdict, Counter
import torch.multiprocessing as mp
from pathlib import Path
import traceback
import base64
from examples.wescon.utils.attn_mask import *
from examples.wescon.utils.en_normalization.expend import normalize as en_normalize
from examples.wescon.utils.en_normalization.expend import normalize as en_normalize
from examples.wescon.utils.zh_normalization.text_normlization import TextNormalizer as ZHTextNormalizer
from examples.wescon.models.cosyvoice2.cli.frontend import CosyVoiceTextFrontEnd
from g2p_en import G2p
import librosa

def energy_calculate(audio):
    """Return the mean RMS energy of an audio file.

    Args:
        audio: Path (or anything ``librosa.load`` accepts) of the audio to
            measure. It is loaded at a 16 kHz sampling rate.

    Returns:
        The mean of the values produced by ``librosa.feature.rms``, with NaN
        entries excluded from the average.
    """
    samples, _ = librosa.load(audio, sr=16000)
    rms = librosa.feature.rms(y=samples)
    # Drop NaN values so they cannot poison the average.
    finite_rms = rms[~np.isnan(rms)]
    return np.mean(finite_rms)

def extract_existing_ids(zips):
    """Collect the numeric ids of the ``.json`` members of several zip files.

    Args:
        zips: Iterable of zip archive paths. Paths that do not exist on disk
            are skipped.

    Returns:
        A set of ints, one for every ``.json`` member whose file stem
        consists only of digits (e.g. ``dir/0012.json`` contributes ``12``).
    """
    found = set()
    for archive_path in zips:
        if not os.path.exists(archive_path):
            continue
        with zipfile.ZipFile(archive_path, 'r') as archive:
            stems = (
                Path(member.filename).stem
                for member in archive.infolist()
                if member.filename.endswith('.json')
            )
            found.update(int(stem) for stem in stems if stem.isdigit())
    return found

def contains_english(text):
    """Return True when ``text`` has at least one ASCII letter (a-z or A-Z)."""
    first_letter = re.search(r'[a-zA-Z]', text)
    return first_letter is not None

def contains_chinese(text):
    """Return True when ``text`` has a character in the range U+4E00-U+9FFF."""
    first_han = re.search(r'[\u4e00-\u9fff]', text)
    return first_han is not None

def clean_last_char(s):
    """Drop the final character of ``s`` unless it is a letter.

    A "letter" here is an ASCII letter or a character in U+4E00-U+9FA5.
    Only one trailing character is ever removed. Falsy inputs (empty string,
    ``None``) are returned unchanged.
    """
    if not s:
        return s
    ends_with_letter = re.match(r'[\u4e00-\u9fa5a-zA-Z]$', s[-1]) is not None
    return s if ends_with_letter else s[:-1]

def load_wav(wav, target_sr):
    """Load an audio file as a single-channel waveform at ``target_sr``.

    Args:
        wav: Path (or anything ``torchaudio.load`` accepts) of the audio file.
        target_sr: Desired sampling rate. Only down-sampling is allowed.

    Returns:
        The waveform averaged over its first dimension (kept as a size-1
        dimension), resampled when the file's rate differs from
        ``target_sr``.

    Raises:
        AssertionError: If the file's sampling rate is lower than
            ``target_sr``.
    """
    waveform, source_sr = torchaudio.load(wav)
    mono = waveform.mean(dim=0, keepdim=True)
    if source_sr == target_sr:
        return mono
    assert source_sr > target_sr, 'wav sample rate {} must be greater than {}'.format(source_sr, target_sr)
    resampler = torchaudio.transforms.Resample(orig_freq=source_sr, new_freq=target_sr)
    return resampler(mono)

def get_attn_bias(text_side_prompt_lens,
                  text_side_tgt_lens,
                  speech_side_prompt_lens,
                  speech_side_speech_lens,
                  attention_bias=1, atten_bias_func=5, attention_range=1,
                  text_side_prompt=None, text_side_tgt=None, speech_side_prompt=None, speech_side_speech=None):
    """Build one attention-bias tensor per requested bias function.

    Args:
        text_side_prompt_lens: Per-segment prompt text lengths.
        text_side_tgt_lens: Per-segment target text lengths.
        speech_side_prompt_lens: Per-segment prompt speech lengths.
        speech_side_speech_lens: Per-segment generated speech lengths.
        attention_bias: "_"-separated string of keys into ``atten_bias_func``.
            The integer default is a placeholder and is not usable (the value
            must support ``.split``).
        atten_bias_func: Mapping from key to a callable invoked as
            ``fn(text_prompt, text_tgt, speech_prompt, speech_speech,
            top=..., low=...)``. The integer default is a placeholder.
        attention_range: Indexable ``[low, top]`` pair forwarded (as floats)
            to every bias callable. The integer default is a placeholder.
        text_side_prompt: Optional pre-built list of ``(label, length)``
            segments. When it is ``None`` all four segment lists are rebuilt
            from the ``*_lens`` arguments, ignoring the other three
            ``*_side_*`` keyword arguments.
        text_side_tgt: See ``text_side_prompt``.
        speech_side_prompt: See ``text_side_prompt``.
        speech_side_speech: See ``text_side_prompt``.

    Returns:
        A tuple ``(attn_bias_dict, text_side_prompt, text_side_tgt,
        speech_side_prompt, speech_side_speech)`` where ``attn_bias_dict``
        maps each requested key to ``torch.tensor(<callable result>)``.
    """
    if text_side_prompt is None:
        text_side_prompt, text_side_tgt = [], []
        speech_side_prompt, speech_side_speech = [], []
        per_segment_lens = zip(
            text_side_prompt_lens,
            text_side_tgt_lens,
            speech_side_prompt_lens,
            speech_side_speech_lens,
        )
        for seg_idx, (tp_len, tt_len, sp_len, ss_len) in enumerate(per_segment_lens):
            is_first = seg_idx == 0
            # The first segment carries one extra position on each side:
            # sos on the text side, bos on the speech side.
            tp_len = tp_len + 1 if is_first else tp_len
            sp_len = sp_len + 1 if is_first else sp_len
            text_side_prompt.append((f"t_0{seg_idx}", tp_len + 1))
            text_side_tgt.append((f"t_1{seg_idx}", tt_len + 1))  # +1 for emotion
            speech_side_prompt.append((f"s_0{seg_idx}", sp_len))
            speech_side_speech.append((f"s_1{seg_idx}", ss_len))

    attn_bias_dict = {
        func_key: torch.tensor(atten_bias_func[func_key](
            text_side_prompt,
            text_side_tgt,
            speech_side_prompt,
            speech_side_speech,
            top=float(attention_range[1]),
            low=float(attention_range[0]),
        ))
        for func_key in attention_bias.split("_")
    }
    return attn_bias_dict, \
        text_side_prompt, text_side_tgt, speech_side_prompt, speech_side_speech

def load_info_from_json(json_items, lang,
                        g2p_model, zh_norm, frontend,
                        scaler_switch=True, max_tgt=500, device=None):
    """Turn the entries of one script line into a batched model input dict.

    Args:
        json_items: Iterable of dicts. Each one is read for the keys
            ``"lines"``, ``"prompt_text"``, ``"emotion_refined"``,
            ``"prompt_speech_cs2token"``, ``"prompt_text_cs2token"`` and,
            when ``scaler_switch`` is true, ``"sp_scale"``.
        lang: ``"chinese"`` selects ``zh_norm.normalize_sentence`` and uses
            the character count of the prompt text as its length; any other
            value selects ``en_normalize`` and ``g2p_model``.
        g2p_model: Callable applied to the prompt text for non-Chinese input;
            the length of its result is used as the prompt length.
        zh_norm: Object exposing ``normalize_sentence`` (Chinese input only).
        frontend: Object exposing ``text_normalize`` and
            ``_extract_text_token``.
        scaler_switch: When true the prompt speech tokens are resampled by
            ``sp_scale / (num_prompt_tokens / prompt_length)``; otherwise a
            factor of ``1.0`` is used.
        max_tgt: Accepted for interface compatibility; not read by this
            function.
        device: Optional device that every tensor entry is moved to.

    Returns:
        A dict holding the batched tensors (each with a leading batch
        dimension of size one) plus the plain-Python segment lists and
        lengths used to build the attention biases.
    """
    text_pad_idx = 50257
    emo_dict = {
        "sad": 0,
        "happy": 1,
        "angry": 2,
        "surprised": 3,
        "neutral": 4,
    }
    attention_bias = "0_1_2_3_4_5_6"
    attention_range = [0.1, 5.0]
    atten_bias_func = {
        "0": lower_triangle,
        "1": tgt_st_paired_emo_and_all_tt,
        "2": all_paired_emo,
        "3": all_st_paired_emo,
        "4": tgt_st_paired_emo,
        "5": st_paired_emo_tgt_st_all_tt,
        "6": all_st_paired_emo_and_all_tt,
    }
    use_chinese = lang == "chinese"

    prompt_codec, prompt_codec_lens = [], []
    prompt_text, prompt_text_lens = [], []
    tgt_text, text_lens = [], []
    emos, tgt_emos, tgt_emos_prompt = [], [], []
    tgt_codec_lens = []

    for entry in json_items:
        # Text pre-treatment: remove the two filler characters, trim, then
        # strip every leading character that is neither Chinese nor an
        # ASCII letter.
        stripped = entry["lines"].replace("啊", "").replace("哦", "").strip()
        cleaned = re.sub(r'^[^\u4e00-\u9fa5A-Za-z]+', '', stripped)
        if use_chinese:
            normalized = zh_norm.normalize_sentence(cleaned)
        else:
            normalized = en_normalize(cleaned)
        frontend_text = frontend.text_normalize(normalized, split=False)
        text_token = frontend._extract_text_token(frontend_text)[0].flatten()
        tgt_text.append(text_token)
        text_lens.append(text_token.size(0))

        # Prompt length: characters for Chinese, g2p output items otherwise.
        if use_chinese:
            prompt_phone_len = len(entry["prompt_text"])
        else:
            prompt_phone_len = len(g2p_model(entry["prompt_text"]))

        emo = entry["emotion_refined"]

        prompt_speech_ids = torch.tensor(
            np.array(entry['prompt_speech_cs2token'].split(" "), dtype=int)
        )
        # Speech tokens per text unit of the prompt; computed even when the
        # scaler is disabled, exactly like the division it feeds.
        prompt_speed = prompt_speech_ids.size(0) / prompt_phone_len
        scaler = float(entry["sp_scale"]) / prompt_speed if scaler_switch else 1.0
        resampled_prompt = resample_by_stride(prompt_speech_ids, scaler).to(torch.long)
        prompt_codec.append(resampled_prompt)
        prompt_codec_lens.append(resampled_prompt.size(0))

        prompt_text_ids = torch.tensor(
            np.array(entry['prompt_text_cs2token'].split(" "), dtype=int)
        ).to(torch.long)
        prompt_text.append(prompt_text_ids)
        prompt_text_lens.append(prompt_text_ids.size(0))

        emo_id = emo_dict[emo]
        emos.append(torch.tensor([emo_id], dtype=torch.long))
        # No target speech exists yet at inference time.
        tgt_codec_lens.append(0)
        # One emotion id per (resampled) prompt speech token.
        tgt_emos_prompt.append(
            torch.tensor([emo_id] * resampled_prompt.size(0), dtype=torch.long)
        )
        tgt_emos.append(emo_id)

    prompt_codec = torch.cat(prompt_codec, dim=0)
    prompt_text = pad_sequence(prompt_text, batch_first=True, padding_value=text_pad_idx)
    tgt_text = pad_sequence(tgt_text, batch_first=True, padding_value=text_pad_idx)
    tgt_emos_prompt = torch.cat(tgt_emos_prompt, dim=0)
    emos = torch.cat(emos, dim=0)

    (attn_bias,
     text_side_prompt,
     text_side_tgt,
     speech_side_prompt,
     speech_side_speech) = get_attn_bias(
        prompt_text_lens,
        text_lens,
        prompt_codec_lens,
        tgt_codec_lens,
        attention_bias=attention_bias,
        atten_bias_func=atten_bias_func,
        attention_range=attention_range
    )
    # Stack the seven bias matrices in key order "0".."6".
    all_bias = torch.stack([attn_bias[str(bias_idx)] for bias_idx in range(7)], 0)

    model_input = {
        "prompt_codecs": prompt_codec.unsqueeze(0),
        "tgt_texts": tgt_text.unsqueeze(0),
        "prompt_texts": prompt_text.unsqueeze(0),
        "tgt_emos_prompt": tgt_emos_prompt.unsqueeze(0),
        "tgt_text_lens": torch.tensor([text_lens], dtype=torch.long),
        "prompt_text_lens": torch.tensor([prompt_text_lens], dtype=torch.long),
        "iter_num": torch.tensor([prompt_text.size(0)], dtype=torch.long),
        "prompt_codecs_lens": torch.tensor([prompt_codec_lens], dtype=torch.long),
        "emos": emos.unsqueeze(0),
        "prompt_each_codec_lens": torch.tensor([prompt_codec_lens], dtype=torch.long),
        "all_bias": all_bias.unsqueeze(0)
    }
    if device is not None:
        model_input = {
            name: (value.to(device) if value is not None else value)
            for name, value in model_input.items()
        }

    # Plain-Python bookkeeping; added after the device move because these
    # entries are lists, not tensors.
    model_input.update({
        "text_side_prompt": text_side_prompt,
        "text_side_tgt": text_side_tgt,
        "speech_side_prompt": speech_side_prompt,
        "speech_side_speech": speech_side_speech,
        "prompt_text_lens4ab": prompt_text_lens,
        "text_lens4ab": text_lens,
        "prompt_codec_lens4ab": prompt_codec_lens,
        "tgt_codec_lens4ab": tgt_codec_lens,
        "tgt_emos": tgt_emos,
    })
    return model_input

def _resolve_cuda_device(gpu_id):
    """Map a worker's assigned GPU id to a ``torch.device``, else ``cuda:0``."""
    try:
        index = int(gpu_id)
    except (TypeError, ValueError):
        index = 0
    # An id outside the visible range (e.g. when CUDA_VISIBLE_DEVICES already
    # restricts this process to one card) falls back to the first device.
    if not 0 <= index < torch.cuda.device_count():
        index = 0
    return torch.device(f"cuda:{index}")


def generate2zip(all_keys, all_infos,
                 max_p, p_id, save_home,
                 half, gpu_id, args):
    """Synthesize speech for this worker's shard of script keys.

    The keys are split into ``max_p`` contiguous shards; worker ``p_id``
    handles its own shard (the last worker also takes the remainder). For
    every line of every script with more than one line, the model generates
    speech tokens, vocodes them, and the waveform is saved to
    ``<save_home>/temp_signal/<key>_<line index>.wav``.

    Args:
        all_keys: Ordered list of keys into ``all_infos``.
        all_infos: Mapping from key to a dict whose ``"script"`` entry is a
            list of line dicts; each line dict has a ``"text"`` list of
            entries consumed by ``load_info_from_json``.
        max_p: Total number of worker processes.
        p_id: Index of this worker in ``[0, max_p)``.
        save_home: Output directory.
        half: When true the model is converted to half precision.
        gpu_id: CUDA device index assigned to this worker. Falls back to
            device 0 when the index is not visible to the process.
        args: Parsed CLI namespace; ``checkpoint_path`` and ``speech_home``
            are read.

    Side effects:
        Creates (empty) ``<save_home>/<p_id>.log`` and
        ``<save_home>/codec_<p_id>.zip``. Nothing is written into them here;
        the ``__main__`` driver in this file opens and then deletes both.
    """
    # Use the GPU assigned to this worker instead of always using cuda:0, and
    # make it the current device so code that relies on the default CUDA
    # device follows along.
    device = _resolve_cuda_device(gpu_id)
    if torch.cuda.is_available():
        torch.cuda.set_device(device)
    # Maps every accepted emotion spelling to the canonical label that
    # load_info_from_json understands.
    emo_dict = {
        "Sad": "sad",
        "Happy": "happy",
        "Angry": "angry",
        "Surprise": "surprised",
        "Neutral": "neutral",
        "快乐": "happy",
        "高兴": "happy",
        "悲伤": "sad",
        "伤心": "sad",
        "惊喜": "surprised",
        "惊讶": "surprised",
        "生气": "angry",
        "愤怒": "angry",
        "中立": "neutral",
        "自然": "neutral",
        "neutral": "neutral",
        "surprised": "surprised",
        "sad": "sad",
        "happy": "happy",
        "angry": "angry",
    }
    import_user_module_via_dir(r"examples/wescon")
    models, _cfg, _task = load_model_ensemble_and_task([args.checkpoint_path])
    model = models[0]
    model.eval()
    model = model.to(device).float()
    model.init_infer_modules(device, text_frontend=CosyVoiceTextFrontEnd())
    if half:
        model = model.half()

    g2p_model = G2p()
    zh_norm = ZHTextNormalizer()

    # Contiguous shard for this worker; the last worker absorbs the remainder.
    slice_len = len(all_keys) // max_p
    shard_start = p_id * slice_len
    shard_end = len(all_keys) if p_id == max_p - 1 else (p_id + 1) * slice_len

    zip_path = os.path.join(save_home, f"codec_{p_id}.zip")
    signal_dir = os.path.join(save_home, "temp_signal")
    os.makedirs(signal_dir, exist_ok=True)

    costs = []
    cost_lens = []
    # The log and zip are opened only so that the files exist afterwards.
    with open(os.path.join(save_home, f"{p_id}.log"), "w"), \
         zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED), \
         torch.no_grad():
        tot_hours = 0.0
        progress_bar = tqdm(all_keys[shard_start:shard_end], desc=f"Process-{p_id}")
        for key_idx, desc_key in enumerate(progress_bar):
            items = all_infos[desc_key].get("script", [])
            if not (items and len(items) > 1):
                continue
            for line_i, line_item in enumerate(items):
                try:
                    # Copy each entry with its emotion replaced by the
                    # canonical label, so aliases such as "Happy" or "快乐"
                    # do not fail later in load_info_from_json. The input
                    # dicts are left untouched.
                    item = [
                        {**entry, "emotion_refined": emo_dict[entry["emotion_refined"]]}
                        for entry in line_item["text"]
                    ]
                    is_chinese = contains_chinese(item[0]["lines"])
                except (KeyError, IndexError, TypeError) as err:
                    # Best effort: malformed lines and unknown emotions are
                    # skipped, but no longer silently.
                    logger.warning(
                        "Process-%s: skipping %s line %s: %r",
                        p_id, desc_key, line_i, err,
                    )
                    continue
                lang = "chinese" if is_chinese else "english"

                spk_ref = load_wav(
                    os.path.join(args.speech_home, item[0]["target_speaker"]),
                    target_sr=16000,
                )
                spk_ref = model.frontend.speaker_infos(spk_ref, model.sample_rate)

                model_input = load_info_from_json(
                    item, lang, g2p_model, zh_norm, model.frontend,
                    scaler_switch=True, device=device,
                )
                # Separate timer name so the shard bounds are not shadowed.
                tic = time.time()
                speech_token = model.inference(model_input)
                print(len(speech_token))
                speech = model.generate_speech(speech_token, **spk_ref)
                cost = time.time() - tic
                costs.append(cost)
                cost_lens.append(len(speech_token))
                print(f"{key_idx} {len(costs)} {np.mean(costs)} {np.mean(cost_lens)}", flush=True)
                # Token count / 25 is treated as seconds of audio.
                tot_hours += len(speech_token) / 25.0 / 60.0 / 60.0
                wav_path = os.path.join(signal_dir, f"{desc_key}_{line_i}.wav")
                torchaudio.save(wav_path, speech, model.sample_rate)
                progress_bar.set_description(f"Process-{p_id} Hours: {tot_hours:.4f}")
                print(wav_path)
                        
                
def generate_aligned_log(zip_path, merged_log_path, output_tsv_path):
    """Write a TSV that maps every logged utterance to its bytes in the zip.

    Each non-empty line of ``merged_log_path`` must be ``<utt_id>\\t<json>``.
    For every utterance found in the archive, the raw payload is read
    directly from ``zip_path`` (entries are assumed to be stored
    uncompressed), checked against the JSON in the log, and one line
    ``<utt_id>\\t<zip_path>:<offset>:<size>\\t<codec lens>\\t<text lens>`` is
    written to ``output_tsv_path``.

    Args:
        zip_path: Archive whose member stems are the utterance ids.
        merged_log_path: Tab-separated log described above.
        output_tsv_path: Destination TSV (overwritten).

    Raises:
        zipfile.BadZipFile: If a member's local file header is malformed.
        AssertionError: If a zip payload and its log record differ.
    """
    zip_item_info = {}
    with zipfile.ZipFile(zip_path, mode="r") as zf, open(zip_path, "rb") as raw_zip:
        for entry in tqdm(zf.infolist(), desc="Indexing zip"):
            utt_id = Path(entry.filename).stem  # e.g., '0_0'
            # The payload starts after the 30-byte fixed local header, the
            # encoded file name and the extra field. Read both lengths from
            # the header itself: len(entry.filename) counts characters, not
            # bytes, and the extra field is not always empty.
            raw_zip.seek(entry.header_offset)
            local_header = raw_zip.read(30)
            if len(local_header) != 30 or local_header[:4] != b"PK\x03\x04":
                raise zipfile.BadZipFile(
                    f"Bad local file header for {entry.filename!r} in {zip_path}"
                )
            name_len = int.from_bytes(local_header[26:28], "little")
            extra_len = int.from_bytes(local_header[28:30], "little")
            offset = entry.header_offset + 30 + name_len + extra_len
            zip_item_info[utt_id] = (offset, entry.file_size)

    with open(output_tsv_path, "w", encoding="utf-8") as wf, \
         open(merged_log_path, "r", encoding="utf-8") as rf, \
         open(zip_path, "rb") as zip_f:
        for line in tqdm(rf, desc="Processing merged.log"):
            record = line.strip()
            if not record:
                continue
            # Split on the first tab only; the remainder is the JSON payload.
            utt_id, info = record.split("\t", 1)
            if utt_id not in zip_item_info:
                print(f"[Warning] {utt_id} not found in zip!")
                continue

            offset, file_size = zip_item_info[utt_id]
            item_zip_loc = f"{zip_path}:{offset}:{file_size}"

            zip_f.seek(offset)
            byte_data = zip_f.read(file_size)
            json_data = json.loads(byte_data.decode("utf-8"))
            json_data_temp = json.loads(info)
            # Explicit raise so the check survives ``python -O``; the
            # exception type is kept for backward compatibility.
            if not (json_data == json_data_temp):
                raise AssertionError(f"{str(json_data)}\n{str(json_data_temp)}")
            codec_lens = [str(len(x['st'].split())) for x in json_data]
            # NOTE(review): the text lengths are derived from the same 'st'
            # field as the codec lengths, which looks like a copy-paste slip.
            # Kept as is because the intended text field is not visible
            # here; confirm against the record schema.
            text_lens = [str(len(x['st'].split())) for x in json_data]
            codec_lens = "|".join(codec_lens)
            text_lens = "|".join(text_lens)
            wf.write(f"{utt_id}\t{item_zip_loc}\t{codec_lens}\t{text_lens}\n")

def _parse_args():
    """Parse the command line; options the pipeline cannot run without are required."""
    parser = argparse.ArgumentParser(description='Inference')
    parser.add_argument('--checkpoint_path', type=str, required=True)
    parser.add_argument('--save-home', type=str, required=True)
    parser.add_argument('--half', action="store_true")
    parser.add_argument('--max_p', type=int, required=True)
    parser.add_argument('--devices', type=str, required=True)
    parser.add_argument('--seed', type=int)
    parser.add_argument('--tgt-json', type=str, required=True)
    parser.add_argument('--speech-home', type=str, required=True)
    return parser.parse_args()


def _copy_zip_members(merged_zip, source_zip_path):
    """Append every member of the archive at ``source_zip_path`` to ``merged_zip``."""
    with zipfile.ZipFile(source_zip_path, 'r') as part_zip:
        for file_info in part_zip.infolist():
            merged_zip.writestr(file_info, part_zip.read(file_info.filename))


def main():
    """Run the workers, then merge their zips and logs and build ``info.tsv``."""
    args = _parse_args()

    with open(args.tgt_json, "r", encoding="utf-8") as rf:
        all_lines = json.load(rf)

    all_keys = list(all_lines.keys())

    save_home = args.save_home
    os.makedirs(save_home, exist_ok=True)

    max_p = args.max_p
    # ``--devices`` is a ";"-separated list of GPU indices, e.g. "0;1;2".
    gpu_list = [int(token) for token in args.devices.split(";")]
    gpu_num = len(gpu_list)
    mp.set_start_method('spawn', force=True)

    process_list = []
    for p_id in range(max_p):
        # Round-robin the workers over the listed GPUs.
        gpu_id = gpu_list[p_id % gpu_num]
        print(f" {p_id} GPU {gpu_id}")
        p = mp.Process(
            target=generate2zip,
            args=(all_keys, all_lines, max_p, p_id, save_home, args.half, gpu_id, args)
        )
        p.start()
        process_list.append(p)

    for p in process_list:
        p.join()

    # Report crashed workers so a later merge failure is not the only clue.
    failed_workers = [p_id for p_id, p in enumerate(process_list) if p.exitcode != 0]
    if failed_workers:
        logger.error("Worker process(es) %s exited with a non-zero status", failed_workers)

    merged_zip_path = os.path.join(save_home, "codec_all.zip")
    with zipfile.ZipFile(merged_zip_path, 'w', zipfile.ZIP_STORED) as merged_zip:
        for p_id in range(max_p):
            part_zip_path = os.path.join(save_home, f"codec_{p_id}.zip")
            _copy_zip_members(merged_zip, part_zip_path)
            os.remove(part_zip_path)
        # An archive left by an earlier run is merged in but kept on disk.
        processed_zip_path = os.path.join(save_home, "codec_processed.zip")
        if os.path.exists(processed_zip_path):
            _copy_zip_members(merged_zip, processed_zip_path)

    merged_log_path = os.path.join(save_home, "merged.log")
    with open(merged_log_path, "w", encoding="utf-8") as merged_log:
        for p_id in range(max_p):
            log_path = os.path.join(save_home, f"{p_id}.log")
            with open(log_path, "r", encoding="utf-8") as lf:
                merged_log.write(lf.read())
            os.remove(log_path)
        # Unlike the processed zip, the processed log is removed once merged.
        log_path = os.path.join(save_home, f"processed.log")
        if os.path.exists(log_path):
            with open(log_path, "r", encoding="utf-8") as lf:
                merged_log.write(lf.read())
            os.remove(log_path)

    out_tsv = os.path.join(save_home, "info.tsv")
    generate_aligned_log(merged_zip_path, merged_log_path, out_tsv)


if __name__ == "__main__":
    main()
    
    