import json
import os
import argparse
from tqdm import tqdm

def filter_json_by_audiolist(input_json, output_json, record_file):
    """
    根据一个包含音频文件名的文本文件，筛选一个JSON文件。
    """
    # 1. 读取有效的音频文件名列表
    print(f"正在从 '{record_file}' 读取有效的音频文件名...")
    try:
        with open(record_file, 'r') as f:
            # 使用集合(set)进行快速查找
            valid_filenames = {line.strip() for line in f if line.strip()}
    except FileNotFoundError:
        print(f"错误: 记录文件 '{record_file}' 未找到。请先运行 process_audios.py。")
        return
        
    print(f"成功加载 {len(valid_filenames)} 个有效文件名。")

    # 2. 读取并筛选JSON文件
    print(f"正在读取并筛选JSON文件: '{input_json}'...")
    try:
        with open(input_json, 'r') as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"错误: 输入的JSON文件 '{input_json}' 未找到。")
        return
    except json.JSONDecodeError:
        print(f"错误: JSON文件 '{input_json}' 格式无效。")
        return
        
    filtered_data = []
    for item in tqdm(data, desc="筛选JSON条目"):
        # 从 "audio_id" 字段（如 "./test-mini-audios/file.wav"）中提取文件名
        audio_filename = os.path.basename(item.get('audio_id', ''))
        
        # 如果提取出的文件名在我们的有效列表中，则保留该条目
        if audio_filename in valid_filenames:
            filtered_data.append(item)
            
    # 3. 将筛选后的结果写入新的JSON文件
    print(f"正在将筛选后的 {len(filtered_data)} 个条目写入 '{output_json}'...")
    with open(output_json, 'w', encoding='utf-8') as f:
        # indent=4 格式化输出，ensure_ascii=False 保证中文等字符正常显示
        json.dump(filtered_data, f, indent=4, ensure_ascii=False)
        
    print("\n--- 处理完成 ---")
    print(f"原始JSON条目数: {len(data)}")
    print(f"筛选后JSON条目数: {len(filtered_data)}")
    print(f"结果已保存至: {output_json}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="根据文件名列表筛选JSON文件。")
    parser.add_argument('--input_json', type=str, required=True, help='原始JSON文件路径。')
    parser.add_argument('--output_json', type=str, required=True, help='筛选后的JSON文件输出路径。')
    parser.add_argument('--record_file', type=str, default='short_audios.txt', help='记录短音频文件名的文本文件路径。')
    
    args = parser.parse_args()
    
    filter_json_by_audiolist(args.input_json, args.output_json, args.record_file)

"""
python filter_mmau_json.py \
    --input_json ./mmau-test-mini.json \
    --output_json ./mmau-test-mini-short.json
"""