import json
import multiprocessing as mp
from tqdm import tqdm
from examples.wescon.models.cosyvoice2.cli.frontend import CosyVoiceTextFrontEnd
import random

def load_json_files(file_paths):
    """Read several JSON files and fold their contents into one dict.

    Files are processed in the order given; each file's parsed content is
    applied with ``dict.update``, so a key that appears in more than one
    file ends up with the value from the last file that contains it.

    Loading is best-effort: any error raised while opening, parsing, or
    merging a file is reported on stdout and that file is skipped.

    Args:
        file_paths: iterable of paths to UTF-8 encoded JSON files.

    Returns:
        A dict holding the merged content of every file that loaded.
    """
    combined = {}
    for path in file_paths:
        try:
            with open(path, mode='r', encoding='utf-8') as handle:
                combined.update(json.load(handle))
        except Exception as err:
            print(f"Error loading {path}: {err}")
    return combined

def process_json_chunk(chunk):
    """Normalize the ``t`` field and strip unused fields from each record.

    Used as the worker function for ``multiprocessing.Pool.map``; a fresh
    ``CosyVoiceTextFrontEnd`` is built on every call and ``chunk`` is
    modified in place.

    For every record whose text normalizes successfully:
      * ``record["t"]`` is replaced by ``text_frontend.text_normalize(...)``;
      * ``record["cl"]`` is removed (if present);
      * every entry of ``record["a"]`` loses its timing pair: ``start_ms`` /
        ``end_ms`` when ``start_ms`` is present, otherwise ``start_s`` /
        ``end_s``. Missing timing keys are tolerated.

    A record whose normalization raises is reported on stdout and left in
    ``chunk`` exactly as it was (un-normalized, with ``cl`` and timing
    fields still present).

    Args:
        chunk: dict mapping a record key to a record dict.

    Returns:
        The same ``chunk`` object, after in-place modification.
    """
    text_frontend = CosyVoiceTextFrontEnd()
    for key, record in tqdm(chunk.items(), desc="Processing chunk"):
        try:
            record["t"] = text_frontend.text_normalize(record["t"])
        except Exception as err:
            # Catch Exception rather than using a bare ``except`` so that
            # KeyboardInterrupt / SystemExit can still stop the worker.
            print(f"Text normalization failed for {key}: {err}")
            print(record)
            continue

        # ``pop`` with a default keeps one malformed record from raising
        # KeyError and aborting the entire pool.map call.
        record.pop("cl", None)
        for entry in record.get("a") or ():
            if "start_ms" in entry:
                timing_fields = ("start_ms", "end_ms")
            else:
                timing_fields = ("start_s", "end_s")
            for field in timing_fields:
                entry.pop(field, None)
    return chunk

if __name__ == "__main__":
    # Pipeline: merge the per-dataset alignment JSON files, shuffle the
    # records, normalize them in parallel worker processes, and write the
    # combined result as a single JSON file.
    file_paths = [
        "./datas/1st_stage_alignment/aishell/logs/info.json",
        "./datas/1st_stage_alignment/LibriSpeech_100/logs/info.json",
    ]
    output_path = "./datas/1st_stage_alignment/for_train/aishell_ls100_normed.json"

    print("Loading JSON files...")
    merged_data = load_json_files(file_paths)
    print(f"Total keys before processing: {len(merged_data)}")

    shuffled_items = list(merged_data.items())
    random.shuffle(shuffled_items)
    del merged_data  # release the merged dict before building the chunks

    total_items = len(shuffled_items)
    # Never start more workers than there are records (but at least one).
    num_processes = max(1, min(mp.cpu_count(), 40, total_items))

    # Split into contiguous, balanced chunks. The first ``remainder`` chunks
    # take one extra record so that no record is dropped when the total is
    # not an exact multiple of the worker count.
    base_size, remainder = divmod(total_items, num_processes)
    chunks = []
    start = 0
    for worker_idx in range(num_processes):
        stop = start + base_size + (1 if worker_idx < remainder else 0)
        chunks.append(dict(shuffled_items[start:stop]))
        start = stop
    del shuffled_items

    with mp.Pool(processes=num_processes) as pool:
        processed_chunks = pool.map(process_json_chunk, chunks)
    del chunks

    final_merged_data = {}
    for processed in processed_chunks:
        final_merged_data.update(processed)
    del processed_chunks

    print(f"Total keys after processing: {len(final_merged_data)}")

    with open(output_path, 'w', encoding='utf-8') as out_f:
        json.dump(final_merged_data, out_f, ensure_ascii=False, indent=4)
    del final_merged_data

    print(f"Processed JSON saved to {output_path}")