import os
os.environ["OMP_NUM_THREADS"] = "4"
os.environ["OPENBLAS_NUM_THREADS"] = "4"
os.environ["MKL_NUM_THREADS"] = "4"
os.environ["VECLIB_MAXIMUM_THREADS"] = "4"
os.environ["NUMEXPR_NUM_THREADS"] = "4"
from examples.wescon.models.cosyvoice2.cli.cosyvoice import CosyVoice2
from examples.wescon.models.cosyvoice2.utils.file_utils import load_wav, save_wav
import string
punctuation_string = string.punctuation
import re
import numpy as np
from tqdm import tqdm
import multiprocessing as mp
import argparse
import json
import torch
import time
import torchaudio
from copy import deepcopy
from collections import defaultdict
from funasr import AutoModel

def convert_to_token_based_encoding(entry, tokens_per_second=25, min_silence_ms=40):
    """Convert a timestamped alignment into token-indexed segments.

    Args:
        entry: Mapping with a ``'text'`` string of whitespace-separated units
            and a ``'timestamp'`` sequence of ``(start_ms, end_ms)`` pairs,
            one pair per unit. Extra units or pairs beyond the shorter of the
            two are ignored (``zip`` semantics).
        tokens_per_second: Token rate used to map milliseconds to token
            indices.
        min_silence_ms: Minimum gap, in milliseconds, that is emitted as a
            ``'<SIL>'`` segment. The default of 40 ms is one token at the
            default rate of 25 tokens per second.

    Returns:
        A list of dicts with keys ``'char'``, ``'start_ms'``, ``'end_ms'``,
        ``'start_token'`` and ``'end_token'``, in temporal order. A
        ``'<SIL>'`` segment precedes a unit whenever the gap since the end of
        the previous unit (or since 0 ms for the first unit) is at least
        ``min_silence_ms``.
    """

    def to_token(ms):
        # Millisecond position -> nearest token index.
        return round(ms / 1000 * tokens_per_second)

    def make_segment(label, seg_start_ms, seg_end_ms):
        return {
            'char': label,
            'start_ms': seg_start_ms,
            'end_ms': seg_end_ms,
            'start_token': to_token(seg_start_ms),
            'end_token': to_token(seg_end_ms),
        }

    converted = []
    # Starting at 0 lets the leading silence share the same gap check as the
    # silences between consecutive units.
    prev_end = 0
    for unit, (start_ms, end_ms) in zip(entry['text'].split(), entry['timestamp']):
        if start_ms - prev_end >= min_silence_ms:
            converted.append(make_segment('<SIL>', prev_end, start_ms))
        converted.append(make_segment(unit, start_ms, end_ms))
        prev_end = end_ms
    return converted

def npywrite(destpath, arr):
    """Save ``arr`` with ``np.save`` unless the target file already exists.

    Missing parent directories are created. Existing files are never
    overwritten.

    Args:
        destpath: Destination path. ``np.save`` appends ``.npy`` when the
            name does not already end with it.
        arr: Array data to save.
    """
    destpath = os.path.abspath(destpath)
    # np.save writes to "<destpath>.npy" when the extension is missing, so the
    # existence check has to look at that path as well.
    saved_path = destpath if destpath.endswith(".npy") else destpath + ".npy"
    if os.path.exists(destpath) or os.path.exists(saved_path):
        return
    # exist_ok avoids a crash when another worker creates the directory
    # between a check and the makedirs call.
    os.makedirs(os.path.dirname(destpath), exist_ok=True)
    np.save(destpath, arr)
    
def jsonwrite(destpath, arr):
    """Write ``arr`` as UTF-8 JSON to ``destpath`` unless it already exists.

    Missing parent directories are created. Existing files are never
    overwritten.

    Args:
        destpath: Destination file path.
        arr: JSON-serializable object to dump.
    """
    if os.path.exists(destpath):
        return
    destpath = os.path.abspath(destpath)
    # exist_ok avoids a crash when another worker creates the directory
    # between a check and the makedirs call.
    os.makedirs(os.path.dirname(destpath), exist_ok=True)
    with open(destpath, "w", encoding="utf-8") as f:
        json.dump(arr, f, ensure_ascii=False, indent=4)

def extract_feature(wav_path, front_end, align_model, text):
    """Extract speech tokens and a token-level alignment for one utterance.

    Args:
        wav_path: Path of the audio file, loaded at 16 kHz via ``load_wav``.
        front_end: Object exposing ``_extract_speech_token(audio)``, which
            returns a ``(tokens, length)`` pair.
        align_model: Model whose ``generate`` call returns a list; its first
            element is passed to ``convert_to_token_based_encoding``.
        text: Transcript aligned against the audio.

    Returns:
        A tuple ``(speech_token, codec_len, info_json)``: the tokens joined
        into one space-separated string, the number of tokens, and the
        converted alignment segments.
    """
    with torch.no_grad():
        # assumes load_wav returns a (channels, samples) tensor — TODO confirm
        audio_16k = load_wav(wav_path, target_sr=16000)
        num_samples = audio_16k.size(1)
        if num_samples > 16000 * 30:
            # Audio longer than 30 s is tokenized in equal chunks shorter than
            # 20 s. The full waveform stays in ``audio_16k`` because the
            # aligner below needs all of it.
            chunk_len = num_samples // (num_samples // (16000 * 20) + 1)
            chunk_tokens = []
            for chunk in torch.split(audio_16k, chunk_len, dim=1):
                # The split can leave a short remainder chunk; anything of
                # 0.5 s or less is skipped.
                if chunk.size(1) > 16000 * 0.5:
                    tokens, _ = front_end._extract_speech_token(chunk)
                    chunk_tokens.append(tokens.detach().cpu().numpy().flatten())
            token_ids = np.concatenate(chunk_tokens)
        else:
            tokens, _ = front_end._extract_speech_token(audio_16k)
            token_ids = tokens.detach().cpu().numpy().flatten()

        token_strs = token_ids.astype(str)
        codec_len = len(token_strs)
        speech_token = " ".join(token_strs)

        alignment = align_model.generate(input=(audio_16k[0], text), data_type=("sound", "text"), disable_pbar=True)[0]
        info_json = convert_to_token_based_encoding(alignment)
    return speech_token, codec_len, info_json

def get_line_count(tsv_path):
    """Return how many lines of ``tsv_path`` are longer than two characters."""
    count = 0
    with open(tsv_path, "r") as tsv_file:
        for line in tsv_file:
            # Filter out blank lines; the length includes the trailing newline.
            if len(line) > 2:
                count += 1
    return count

def process_task(args):
    """Load the slice of TSV rows assigned to this worker.

    The non-blank rows of ``args.tsv`` are divided into ``args.np`` contiguous
    slices of equal length, and the last worker also takes the remainder.
    Slice bounds index non-blank rows only, so they agree with the total
    returned by ``get_line_count`` even when the file contains blank lines.

    Args:
        args: Namespace with ``tsv`` (path to a tab-separated file whose rows
            are ``file_id``, ``wav_path``, ``text``), ``p`` (zero-based index
            of this worker) and ``np`` (total number of workers).

    Returns:
        A list of ``(file_id, wav_path, text)`` tuples for this worker.

    Raises:
        ValueError: If a non-blank row does not have exactly three fields.
    """
    total_lines = get_line_count(args.tsv)
    slice_len = total_lines // args.np
    start = slice_len * args.p
    end = total_lines if args.p == args.np - 1 else slice_len * (args.p + 1)
    print(f"{start}, {end}\n")
    infos = []
    row_idx = 0  # index among non-blank rows, matching get_line_count
    with open(args.tsv, "r") as rf:
        for line in rf:
            if len(line) <= 2:
                continue
            if row_idx >= end:
                break
            if row_idx >= start:
                file_id, wav_path, text = line.strip().split("\t")
                infos.append((file_id, wav_path, text))
            row_idx += 1
    print(f"Process {args.p}: handling lines {start} to {end}")
    return infos

if __name__ == "__main__":
    # Entry point for one worker: extract speech tokens and alignments for
    # this worker's slice of the TSV and write them to a per-worker JSON file.
    parser = argparse.ArgumentParser()
    # --tsv: tab-separated file with rows of file_id, wav_path, text.
    # --p:   zero-based index of this worker.
    # --np:  total number of workers.
    # All three are required so a missing one fails before the models load.
    parser.add_argument('--tsv', type=str, required=True, help="")
    parser.add_argument('--p', type=int, required=True, help="")
    parser.add_argument('--np', type=int, required=True, help="")
    args = parser.parse_args()

    # Speech-token front end and forced-alignment model.
    cosyvoice = CosyVoice2('./pretrained_models/CosyVoice2-0.5B', load_jit=True, load_onnx=False, load_trt=False, only_front=True)
    front_end = cosyvoice.frontend
    local_path_root = "./pretrained_models/paraformer/modelscope_models"
    align_model = AutoModel(model=f"{local_path_root}/speech_timestamp_prediction-v1-16k-offline")

    # Outputs go under a directory named after the TSV file (text before the
    # first dot of its basename).
    split = os.path.basename(args.tsv).split(".")[0]
    home = f"./datas/1st_stage_alignment/{split}"
    os.makedirs(os.path.join(home, "logs"), exist_ok=True)

    infos = process_task(args)
    all_infos = {}
    # UTF-8 is set explicitly because the JSON is dumped with
    # ensure_ascii=False and may contain non-ASCII transcripts.
    with open(f"{home}/logs/err_{args.p}.log", "w", encoding="utf-8") as wf_e, \
            open(f"{home}/logs/info_{args.p}.json", "w", encoding="utf-8") as wf_info:
        for file_id, wav_path, text in tqdm(infos):
            try:
                speech_token, codec_len, info_json = extract_feature(wav_path, front_end, align_model, text)
                all_infos[f"{split}-{file_id}"] = {
                    "t": text,
                    "c": speech_token,
                    "a": info_json,
                    "cl": codec_len,
                }
            except Exception as e:
                # Best effort: a failing utterance is logged and skipped so
                # the rest of the shard still gets processed.
                print(e)
                print(str(file_id) + str(e), file=wf_e, flush=True)
        json.dump(all_infos, wf_info, ensure_ascii=False, indent=4)
  