import os
import json
import random
import numpy as np
import soundfile as sf
from scipy.signal import convolve, windows
import argparse
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor
import logging
import pandas as pd
import librosa
import re
from datasets import load_from_disk

# --- Logging setup ---
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# --- 1. Configuration ---
# BRIR dataset: metadata JSON and the directory holding the raw BRIR files.
BRIR_METADATA_PATH = 'data/brir/brir_dataset_metadata_sft.json'
BRIR_DIR = 'data/brir/files'

# Dry-source dataset locations.
CLOTHO_BASE_PATH = '/data4/clotho_v2'
FSD50K_BASE_PATH = '/data2/wl/FSD50K/raw/FSD50K'
EMILIA_ZH_PATH = '/data2/wl/Emilia-Dataset-50k/ZH'
EMILIA_EN_PATH = '/data2/wl/Emilia-Dataset-50k/EN'
MUSICCAPS_CLEANED_PATH = '/data2/wl/MusicCaps-cleaned'

# Output locations (scene WAVs and per-scene JSON metadata).
OUTPUT_AUDIO_DIR = '/data2/wl/RL_benchmark/audio'
OUTPUT_METADATA_DIR = '/data2/wl/RL_benchmark/data/metadata'

# Synthesis parameters.
TARGET_FS = 44_100                  # sample rate (Hz) for loading, resampling and writing
TOTAL_SCENES_TO_GENERATE = 10_000   # default for --num_scenes
SCENE_DURATION = 5                  # scene length in seconds

# --- Helper functions ---
def get_next_start_index(metadata_dir: str) -> int:
    """Return the next unused scene ID, based on existing scene metadata files.

    Scans ``metadata_dir`` for files named ``scene_<id>.json`` and returns
    ``max_id + 1`` so a rerun appends new scenes instead of overwriting old
    ones. Returns 0 when the directory does not exist or contains no scene
    metadata files.
    """
    if not os.path.exists(metadata_dir):
        logging.info(f"元数据目录 {metadata_dir} 未找到。从索引 0 开始。")
        return 0
    # Scene files are written with ``{scene_id:06d}``: zero-padded to *at least*
    # six digits, so IDs >= 1_000_000 have seven or more digits and must still
    # be counted. ``fullmatch`` rejects look-alikes such as
    # "scene_000001.json.bak".
    pattern = re.compile(r"scene_(\d{6,})\.json")
    max_id = -1
    for filename in os.listdir(metadata_dir):
        match = pattern.fullmatch(filename)
        if match:
            max_id = max(max_id, int(match.group(1)))
    next_id = max_id + 1
    logging.info(f"找到最大场景ID: {max_id}. 下一个场景将从 {next_id} 开始。")
    return next_id

def load_audio(path, target_sr):
    """Load an audio file as a mono waveform resampled to ``target_sr``.

    Returns the sample array produced by ``librosa.load``, or ``None`` after
    logging the error if loading fails for any reason.
    """
    try:
        samples, _sr = librosa.load(path, sr=target_sr, mono=True)
    except Exception as err:
        logging.error(f"加载音频文件 {path} 出错: {err}")
        return None
    return samples

def load_brir(path, brir_len_samples=8192):
    """Read a raw two-channel BRIR and force it to exactly ``brir_len_samples`` rows.

    The file is read as a flat float32 buffer and reshaped to ``(n, 2)``, i.e.
    consecutive value pairs form one frame (presumably left/right; confirm
    against the BRIR export tool). Shorter responses are zero-padded at the
    end and longer ones truncated. Returns ``None`` after logging if anything
    fails, e.g. a missing file or an odd number of float32 values.
    """
    try:
        frames = np.fromfile(path, dtype=np.float32).reshape((-1, 2))
        missing = brir_len_samples - frames.shape[0]
        if missing > 0:
            frames = np.pad(frames, ((0, missing), (0, 0)))
        return frames[:brir_len_samples, :]
    except Exception as err:
        logging.error(f"加载BRIR文件 {path} 出错: {err}")
        return None

def load_brir_metadata(path):
    """Parse the BRIR metadata JSON file at ``path`` and return the decoded object.

    Any failure is logged together with its traceback and then re-raised
    unchanged to the caller.
    """
    logging.info(f"正在从 {path} 加载 BRIR 元数据...")
    try:
        with open(path, 'r', encoding='utf-8') as fh:
            parsed = json.load(fh)
    except Exception as err:
        logging.error(f"加载BRIR元数据文件 {path} 时出错: {err}", exc_info=True)
        raise
    logging.info("BRIR 元数据加载成功。")
    return parsed

# --- 2. Data preparation ---
def _collect_emilia_speech(dataset_path, lang_key):
    """Load one Emilia split from disk and build its speech-source entries.

    Returns ``(dataset, sources)``. Each source points at a dataset row via
    ``index``, so the audio is only decoded later, when a scene uses it.
    """
    ds = load_from_disk(dataset_path)
    # Read the 'json' column on its own. Row indexing (``ds[i]``) formats every
    # column of the row, including the 'mp3' audio, which makes a text-only
    # pass over the whole split needlessly slow.
    sources = [
        {"dataset": f"emilia_{lang_key}", "index": i, "captions": [meta['text']]}
        for i, meta in enumerate(ds['json'])
    ]
    return ds, sources


def prepare_audio_source_lists():
    """Build the dry-source pools consumed by ``synthesize_scene``.

    Returns a dict with:
      * ``general``: Clotho (development split) and FSD50K (dev) clips that
        exist on disk, each ``{"dataset", "path", "captions"}``. Captions are
        Clotho's five captions or the FSD50K label list.
      * ``speech``: Emilia EN/ZH entries ``{"dataset", "index", "captions"}``
        that index into the loaded datasets.
      * ``music``: MusicCaps entries ``{"dataset", "path", "captions"}``.
      * ``loaded_emilia``: the loaded Emilia datasets keyed by ``'en'``/``'zh'``.

    Each dataset is loaded best-effort: a failure is logged and that dataset
    is left out of its pool.
    """
    logging.info("准备所有类型的音源列表...")

    # General audio: Clotho + FSD50K, keeping only files that exist on disk.
    general_sources = []
    try:  # Clotho
        df = pd.read_csv(os.path.join(CLOTHO_BASE_PATH, 'clotho_captions_development.csv'))
        clotho_audio_dir = os.path.join(CLOTHO_BASE_PATH, 'development')
        for _, row in df.iterrows():
            fp = os.path.join(clotho_audio_dir, row['file_name'])
            if os.path.exists(fp):
                general_sources.append({
                    "dataset": "clotho",
                    "path": fp,
                    "captions": [row[f'caption_{i}'] for i in range(1, 6)],
                })
    except Exception as e:
        logging.error(f"处理 Clotho 元数据出错: {e}")
    try:  # FSD50K
        df_fsd = pd.read_csv(os.path.join(FSD50K_BASE_PATH, 'FSD50K.ground_truth/dev.csv'))
        fsd_labels_map = {str(r['fname']): r['labels'].split(',') for _, r in df_fsd.iterrows()}
        fsd_audio_dir = os.path.join(FSD50K_BASE_PATH, 'FSD50K.dev_audio')
        for fname, labels in fsd_labels_map.items():
            fp = os.path.join(fsd_audio_dir, f"{fname}.wav")
            if os.path.exists(fp):
                general_sources.append({"dataset": "fsd50k", "path": fp, "captions": labels})
    except Exception as e:
        logging.error(f"处理 FSD50K 元数据出错: {e}")
    logging.info(f"找到 {len(general_sources)} 个通用音频源 (Clotho, FSD50K)。")

    # Speech: Emilia EN then ZH.
    speech_sources = []
    loaded_emilia_datasets = {}
    for lang_key, dataset_path, tag, lang_name in (
        ('en', EMILIA_EN_PATH, 'EN', '英文'),
        ('zh', EMILIA_ZH_PATH, 'ZH', '中文'),
    ):
        try:
            ds, sources = _collect_emilia_speech(dataset_path, lang_key)
        except Exception as e:
            logging.error(f"加载 Emilia {tag} 数据集出错: {e}")
            continue
        speech_sources.extend(sources)
        loaded_emilia_datasets[lang_key] = ds
        logging.info(f"加载了 {len(ds)} 条{lang_name}语音源。")

    # Music: MusicCaps (cleaned).
    music_sources = []
    try:
        ds_music = load_from_disk(MUSICCAPS_CLEANED_PATH)
        # Column access again avoids formatting whole rows one at a time.
        music_sources = [
            {"dataset": "musiccaps", "path": fp, "captions": [caption]}
            for fp, caption in zip(ds_music['file_path'], ds_music['caption'])
        ]
        logging.info(f"加载了 {len(ds_music)} 条音乐源。")
    except Exception as e:
        logging.error(f"加载 MusicCaps 清洗数据集出错: {e}")

    return {
        "general": general_sources,
        "speech": speech_sources,
        "music": music_sources,
        "loaded_emilia": loaded_emilia_datasets,
    }

# --- 3. Scene synthesis ---
def _pick_source_types(num_sources):
    """Return ``num_sources`` source types: one forced 'general', the rest drawn 7:2:1."""
    sound_type_options = ['general', 'speech', 'music']
    sound_type_weights = [0.7, 0.2, 0.1]
    # Every scene is guaranteed at least one general-audio source.
    source_types = ['general']
    if num_sources > 1:
        source_types.extend(
            random.choices(sound_type_options, weights=sound_type_weights, k=num_sources - 1)
        )
    # Shuffle so the forced 'general' is not always paired with the first position.
    random.shuffle(source_types)
    return source_types


def _load_dry_source(audio_info, audio_data_pools):
    """Load one dry source as a mono waveform at TARGET_FS; None if it has no audio reference."""
    if 'path' in audio_info:
        return load_audio(audio_info['path'], TARGET_FS)
    if 'index' in audio_info:
        # Speech entries point at a row of a pre-loaded Emilia dataset; the
        # language key is the suffix of the dataset tag ('emilia_en' -> 'en').
        lang_key = audio_info['dataset'].split('_')[-1]
        sample = audio_data_pools['loaded_emilia'][lang_key][audio_info['index']]
        wav = sample['mp3']['array']
        orig_sr = sample['mp3']['sampling_rate']
        if orig_sr != TARGET_FS:
            wav = librosa.resample(y=wav, orig_sr=orig_sr, target_sr=TARGET_FS)
        return wav
    return None


def synthesize_scene(scene_id, reverb_type, num_sources, brir_scene_info, audio_data_pools):
    """Synthesize one spatial scene and write its WAV file and JSON metadata.

    ``num_sources`` BRIR positions are sampled from ``brir_scene_info['sources']``.
    One source is always 'general'; the remaining ones are drawn with weights
    general:speech:music = 7:2:1. Each dry clip is convolved with its
    position's two-channel BRIR, summed into a ``SCENE_DURATION``-second mix
    at ``TARGET_FS``, and the mix is peak-normalized.

    Args:
        scene_id: Integer ID used in the output file names (``scene_{id:06d}``).
        reverb_type: Label stored verbatim in the scene metadata.
        num_sources: Number of sources to place in the scene.
        brir_scene_info: BRIR scene record with ``scene_id``, ``sources``
            (candidate positions) and ``room``.
        audio_data_pools: Pools as returned by ``prepare_audio_source_lists``.

    Returns:
        ``scene_id`` on success. ``None`` when the scene is skipped (too few
        BRIR positions, or no source could be loaded) or when any error occurs
        (the error is logged).
    """
    try:
        positions = brir_scene_info['sources']
        if len(positions) < num_sources:
            logging.warning(f"BRIR 场景 {brir_scene_info['scene_id']} 位置不足 ({len(positions)})，需要 {num_sources}。跳过。")
            return None
        selected_positions = random.sample(positions, num_sources)

        scene_len_samples = int(SCENE_DURATION * TARGET_FS)
        final_mix = np.zeros((scene_len_samples, 2))
        source_events_metadata = []

        source_types_for_scene = _pick_source_types(num_sources)

        for pos_info, chosen_sound_type in zip(selected_positions, source_types_for_scene):
            source_pool = audio_data_pools.get(chosen_sound_type, [])
            if not source_pool:
                logging.warning(f"干声类型 {chosen_sound_type} 的音源池为空，跳过此声源。")
                continue

            audio_info = random.choice(source_pool)

            source_wav = _load_dry_source(audio_info, audio_data_pools)
            if source_wav is None:
                continue

            brir = load_brir(os.path.join(BRIR_DIR, pos_info['brir_filename']))
            if brir is None:
                continue

            # Output sample n of a convolution depends only on input samples <= n.
            # A clip longer than the scene is cut to the first scene_len_samples
            # output samples anyway, so trimming the dry signal first gives the
            # same mix while skipping work on long music/speech clips.
            source_wav = source_wav[:scene_len_samples]

            spatialized_wav = np.stack([
                convolve(source_wav, brir[:, 0], mode='full'),
                convolve(source_wav, brir[:, 1], mode='full'),
            ], axis=1)

            n_out = len(spatialized_wav)
            if n_out < scene_len_samples:
                # Short events are placed at a random offset inside the scene.
                start = random.randint(0, scene_len_samples - n_out)
                final_mix[start:start + n_out, :] += spatialized_wav
            else:
                # Long events start at t=0 and are cut to the scene length.
                final_mix += spatialized_wav[:scene_len_samples, :]

            event_meta = {
                "dataset": audio_info['dataset'],
                # round() rather than int(): int() truncates toward zero, so a
                # stored angle such as 89.9999 would be labelled 89.
                "azimuth": int(round(pos_info['relative_azimuth_deg'])),
                "elevation": int(round(pos_info['relative_elevation_deg'])),
                "distance": round(pos_info['relative_distance_m'], 1),
            }
            if chosen_sound_type == 'speech':
                event_meta['class'] = 'speech'
                event_meta['transcript'] = audio_info['captions'][0]
            else:  # general or music
                event_meta['class'] = random.choice(audio_info['captions'])
                event_meta['transcript'] = None
            source_events_metadata.append(event_meta)

        if not source_events_metadata:
            # Every source was skipped (empty pool or load failure). Writing a
            # silent scene with no events would pollute the dataset.
            logging.warning(f"场景 {scene_id} 没有成功加载任何声源，跳过。")
            return None

        # --- Normalize and save ---
        peak = np.max(np.abs(final_mix))
        if peak > 1e-6:
            final_mix /= peak

        output_audio_path = os.path.join(OUTPUT_AUDIO_DIR, f"scene_{scene_id:06d}.wav")
        sf.write(output_audio_path, final_mix, TARGET_FS)

        scene_metadata = {
            "audio_path": os.path.relpath(output_audio_path, os.path.dirname(OUTPUT_METADATA_DIR)),
            "scene_type": "mixed",
            "reverb_type": reverb_type,
            "source_count": len(source_events_metadata),
            "source_events": source_events_metadata,
            "room_acoustics": brir_scene_info['room'],
        }

        with open(os.path.join(OUTPUT_METADATA_DIR, f"scene_{scene_id:06d}.json"), 'w', encoding='utf-8') as f:
            json.dump(scene_metadata, f, indent=2, ensure_ascii=False)

        return scene_id

    except Exception as e:
        logging.error(f"处理场景 {scene_id} 失败: {e}", exc_info=True)
        return None

# --- 4. Main entry point ---
# Audio source pools for the current worker process. They are installed once
# per worker by ``_init_worker`` instead of being pickled into every task.
_WORKER_AUDIO_POOLS = None


def _init_worker(audio_data_pools):
    """Pool initializer: store the shared audio source pools in this worker process."""
    global _WORKER_AUDIO_POOLS
    _WORKER_AUDIO_POOLS = audio_data_pools


def _run_scene_task(task):
    """Worker entry point: synthesize one ``(scene_id, reverb_type, num_sources, brir_scene_info)`` task."""
    scene_id, reverb_type, num_sources, brir_scene_info = task
    return synthesize_scene(scene_id, reverb_type, num_sources, brir_scene_info, _WORKER_AUDIO_POOLS)


def main(args):
    """Plan the scene distribution, build one task per scene and run them in parallel.

    ``args`` must provide ``num_scenes`` (number of scenes to generate) and
    ``workers`` (``max_workers`` for the process pool). Scene IDs continue
    after the highest existing scene metadata file in ``OUTPUT_METADATA_DIR``.

    Planned distribution:
      * Reverb: 20% anechoic, the rest normal.
      * Source count: 60% with 1 source, 30% with 2, the rest with 3.
    Counts are rounded; "the rest" buckets absorb the rounding.
    """
    os.makedirs(OUTPUT_AUDIO_DIR, exist_ok=True)
    os.makedirs(OUTPUT_METADATA_DIR, exist_ok=True)

    brir_metadata = load_brir_metadata(BRIR_METADATA_PATH)
    audio_data_pools = prepare_audio_source_lists()

    start_index = get_next_start_index(OUTPUT_METADATA_DIR)
    total_scenes_to_gen = args.num_scenes

    # --- Step 1: reverb and source-count distribution ---
    num_anechoic = round(total_scenes_to_gen * 0.2)
    num_normal_reverb = total_scenes_to_gen - num_anechoic

    num_single_source = round(total_scenes_to_gen * 0.6)
    num_dual_source = round(total_scenes_to_gen * 0.3)
    num_triple_source = total_scenes_to_gen - num_single_source - num_dual_source

    logging.info(f"计划生成 {total_scenes_to_gen} 个混合场景。分布如下：")
    logging.info(f"  - 混响: {num_normal_reverb} (Normal), {num_anechoic} (Anechoic)")
    logging.info(f"  - 声源数: {num_single_source} (1), {num_dual_source} (2), {num_triple_source} (3)")
    logging.info(f"  - 新规则: 每个场景保证至少1个通用音源, 整体干声类型按 7:2:1 (通用:语音:音乐) 的比例生成。")
    logging.info(f"输出至目录: {OUTPUT_AUDIO_DIR}, 元数据目录: {OUTPUT_METADATA_DIR}")

    # --- Step 2: per-scene attribute lists, shuffled independently ---
    reverb_types = ['normal'] * num_normal_reverb + ['anechoic'] * num_anechoic
    source_counts = [1] * num_single_source + [2] * num_dual_source + [3] * num_triple_source

    random.shuffle(reverb_types)
    random.shuffle(source_counts)

    # --- Step 3: pre-filter BRIR scenes by room type ---
    brir_scenes = brir_metadata['scenes']
    normal_brir_pool_large_high = [s for s in brir_scenes if s['room']['reverb_type'] == 'normal' and s['room']['size_category'] == 'large']
    normal_brir_pool_small_medium = [s for s in brir_scenes if s['room']['reverb_type'] == 'normal' and s['room']['size_category'] in ['small', 'medium']]
    anechoic_brir_pool = [s for s in brir_scenes if s['room']['reverb_type'] == 'anechoic']

    # --- Step 4: build the task list ---
    tasks = []
    for i, (reverb_type, num_sources) in enumerate(zip(reverb_types, source_counts)):
        scene_id = start_index + i

        brir_scene_info = None
        if reverb_type == 'anechoic':
            if anechoic_brir_pool:
                brir_scene_info = random.choice(anechoic_brir_pool)
        else:  # normal: ~20% large rooms, otherwise (or if none) small/medium
            if random.random() < 0.2 and normal_brir_pool_large_high:
                brir_scene_info = random.choice(normal_brir_pool_large_high)
            if not brir_scene_info and normal_brir_pool_small_medium:
                brir_scene_info = random.choice(normal_brir_pool_small_medium)

        if not brir_scene_info:
            logging.warning(f"无法为任务 {i} 找到合适的BRIR场景，将从全部场景中随机选择。")
            brir_scene_info = random.choice(brir_scenes)  # Fallback
            # The fallback may pick a room of the other reverb type; label the
            # scene with the type actually used so its metadata stays consistent.
            reverb_type = brir_scene_info['room']['reverb_type']

        # Tasks carry only small per-scene data; the (large) audio pools reach
        # each worker once, through the executor initializer.
        tasks.append((scene_id, reverb_type, num_sources, brir_scene_info))

    logging.info(f"已创建 {len(tasks)} 个任务, 场景ID范围: {start_index} -> {start_index + len(tasks) - 1}")

    # --- Run in parallel ---
    with ProcessPoolExecutor(
        max_workers=args.workers,
        initializer=_init_worker,
        initargs=(audio_data_pools,),
    ) as executor:
        results = list(tqdm(executor.map(_run_scene_task, tasks), total=len(tasks), desc="合成SFT音频场景"))

    successful_scenes = [res for res in results if res is not None]
    logging.info(f"全部任务处理完毕。成功生成 {len(successful_scenes)} / {len(tasks)} 个场景。")

def synthesize_scene_wrapper(args_tuple):
    """Unpack one 5-tuple task and forward it to ``synthesize_scene``.

    A module-level, single-argument callable, so it can be pickled and passed
    to ``ProcessPoolExecutor.map``.
    """
    sid, rtype, n_sources, brir_info, pools = args_tuple
    return synthesize_scene(
        scene_id=sid,
        reverb_type=rtype,
        num_sources=n_sources,
        brir_scene_info=brir_info,
        audio_data_pools=pools,
    )

if __name__ == '__main__':
    # Command-line entry point: parse the scene count and worker count, then run.
    cli_parser = argparse.ArgumentParser(description="根据SFT阶段的精确分布合成空间音频场景。")
    cli_parser.add_argument(
        '--num_scenes',
        type=int,
        default=TOTAL_SCENES_TO_GENERATE,
        help='要生成的场景总数。',
    )
    cli_parser.add_argument(
        '-w', '--workers',
        type=int,
        default=os.cpu_count(),
        help='使用的并行进程数。',
    )
    args = cli_parser.parse_args()
    main(args)