# File: generate_benchmark.py (updated)

import os
import json
import random
import logging
import argparse
import math
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import sys
import time # 引入time模块

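# Add the project root (the parent of this script's directory) to Python's module search path so that config_rl and utils can be imported.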
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import config_rl as config
import utils

# --- Pre-computation and Helper Functions ---
def convert_azimuth_to_direction(azimuth: float) -> str:
    """
    Map an azimuth in degrees to one of eight Chinese direction labels.

    Sectors are 45 degrees wide and centred on 0 (正前方, front), 90 (正右方,
    right), 180 (正后方, behind) and 270 (正左方, left), with the diagonal
    labels in between. Each sector includes its lower boundary, e.g. 22.5 maps
    to 右前方.

    The angle is first wrapped into [0, 360), so equivalent angles such as
    -90 and 270, or 450 and 90, get the same label.

    Args:
        azimuth: Angle in degrees (int or float).

    Returns:
        The direction label, or "未知方位" for non-finite input (NaN / inf).
    """
    if not math.isfinite(azimuth):
        return "未知方位"
    # Azimuth is periodic; wrap into [0, 360). (A tiny negative float can wrap
    # to exactly 360.0, which the final fallthrough still maps to 正前方.)
    azimuth = azimuth % 360

    # (exclusive upper bound, label), in ascending order.
    sector_bounds = (
        (22.5, "正前方"),
        (67.5, "右前方"),
        (112.5, "正右方"),
        (157.5, "右后方"),
        (202.5, "正后方"),
        (247.5, "左后方"),
        (292.5, "正左方"),
        (337.5, "左前方"),
    )
    for upper_bound, label in sector_bounds:
        if azimuth < upper_bound:
            return label
    # [337.5, 360] wraps back around to the front sector.
    return "正前方"

def precompute_for_task(task_type: str, metadata: dict) -> dict:
    """
    Derive task-specific hints from scene metadata for the prompt template.

    Only "reasoning_spatial_relationship" produces data. Its source events are
    ordered by ``distance``, and the result contains:
      - 'closest_source' / 'farthest_source': the ``class`` of the nearest and
        farthest event;
      - 'relative_position_example': a sentence relating the two nearest
        events, using their azimuths converted to direction labels.

    Any other task type, or fewer than two source events, yields ``{}``.

    ``metadata`` is left untouched. A sorted copy is used so that
    ``source_events`` keeps its original order for callers that store the
    metadata verbatim (generate_single_question writes it out as
    ``source_metadata``).

    Args:
        task_type: Task identifier.
        metadata: Scene metadata; for the spatial task each entry of
            "source_events" must provide 'distance', 'class' and 'azimuth'.

    Returns:
        A dict of precomputed values (possibly empty).
    """
    precomputed_data = {}
    if task_type != "reasoning_spatial_relationship":
        return precomputed_data

    # `or []` also tolerates an explicit null in the JSON metadata.
    sources = sorted(metadata.get("source_events") or [], key=lambda event: event['distance'])
    if len(sources) < 2:
        return {}

    precomputed_data['closest_source'] = sources[0]['class']
    precomputed_data['farthest_source'] = sources[-1]['class']
    source_a, source_b = sources[0], sources[1]
    direction_a = convert_azimuth_to_direction(source_a['azimuth'])
    direction_b = convert_azimuth_to_direction(source_b['azimuth'])
    precomputed_data['relative_position_example'] = \
        f"声源 '{source_a['class']}' (位于{direction_a}) 相对于 声源 '{source_b['class']}' (位于{direction_b}) 的位置。"
    return precomputed_data

def check_scene_suitability(task_type: str, metadata: dict) -> bool:
    """
    Decide whether a scene can be used for the given task type.

    Three task types ("integration_attribute_binding",
    "reasoning_spatial_relationship", "reasoning_counterfactual") require a
    scene whose ``source_count`` (missing counts as 0) is at least 2.
    Every other task type accepts any scene.

    Returns:
        True if the scene is usable for ``task_type``, otherwise False.
    """
    requires_two_sources = task_type in (
        "integration_attribute_binding",
        "reasoning_spatial_relationship",
        "reasoning_counterfactual",
    )
    if not requires_two_sources:
        return True
    return not metadata.get("source_count", 0) < 2

def load_available_scenes(metadata_dir: str, rtsd_dir: str) -> list[str]:
    """
    List the scene IDs that have both a metadata file and an RTSD file.

    A scene ID is a file name with its final extension removed: ``<id>.json``
    in ``metadata_dir`` and ``<id>.txt`` in ``rtsd_dir``. Only the last
    extension is stripped, so IDs containing dots (e.g. "scene.v2") survive
    intact. This matters because callers rebuild the paths as
    ``f"{scene_id}.json"`` / ``f"{scene_id}.txt"``.

    Args:
        metadata_dir: Directory containing the ``.json`` metadata files.
        rtsd_dir: Directory containing the ``.txt`` RTSD files.

    Returns:
        The sorted list of scene IDs present in both directories. Sorting gives
        a deterministic order across runs. An empty list is returned, and an
        error logged, when either path is missing or not a directory.
    """
    logging.info("Scanning for available scenes...")
    try:
        metadata_files = {os.path.splitext(name)[0] for name in os.listdir(metadata_dir) if name.endswith('.json')}
        rtsd_files = {os.path.splitext(name)[0] for name in os.listdir(rtsd_dir) if name.endswith('.txt')}
    except (FileNotFoundError, NotADirectoryError) as e:
        logging.error(f"Directory not found: {e.filename}. Please check your config_rl.py paths.")
        return []
    available_scenes = sorted(metadata_files & rtsd_files)
    logging.info(f"Found {len(available_scenes)} scenes with both metadata and RTSD files.")
    return available_scenes

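# --- MODIFICATION START: resume support ---
# The former create_task_queue has been replaced by create_resume_task_queue, which tops up an existing dataset instead of generating one from scratch.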

def analyze_existing_dataset(path: str) -> dict[str, int]:
    """
    Count how many questions of each task type an existing .jsonl dataset holds.

    Each non-blank line is parsed as JSON and must be an object. Its string
    ``task_type`` field is counted. Blank lines are ignored silently. Lines
    that are not valid JSON objects (e.g. a record truncated by an interrupted
    write) are skipped with a warning rather than aborting the resume.
    Objects without a non-empty string ``task_type`` are not counted.

    Args:
        path: Path to the .jsonl dataset.

    Returns:
        Mapping of task type to count. Empty if the file does not exist.
    """
    if not os.path.exists(path):
        return {}

    logging.info(f"Analyzing existing dataset at: {path}")
    counts: dict[str, int] = {}
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # blank lines carry no record
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                data = None
            # Valid JSON that is not an object (e.g. `[1, 2]`) has no .get().
            if not isinstance(data, dict):
                logging.warning(f"Skipping malformed line in dataset: {line.strip()}")
                continue
            task_type = data.get('task_type')
            # Require a non-empty string: an unhashable value (e.g. a list)
            # could not be used as a dict key.
            if isinstance(task_type, str) and task_type:
                counts[task_type] = counts.get(task_type, 0) + 1

    total_existing = sum(counts.values())
    logging.info(f"Found {total_existing} existing questions in the dataset.")
    for task, count in counts.items():
        logging.info(f"  - {task}: {count} questions")
    return counts

def create_resume_task_queue(target_total_questions: int, existing_counts: dict) -> list[str]:
    """
    Build a shuffled queue of the task types still needed to reach the target.

    The target total is first split across the perception / integration /
    reasoning categories using ``config.TASK_DISTRIBUTION``. Perception and
    integration are rounded up; reasoning receives whatever is left. Each
    category total is then split as evenly as possible across its sub-tasks
    with integer arithmetic: the first ``total % n`` tasks, in config order,
    get one extra question.

    The per-task targets therefore add up to exactly
    ``target_total_questions``, as long as the perception and integration
    shares alone do not exceed it. The split is also deterministic, so a
    resumed run recomputes the same targets.

    Args:
        target_total_questions: Desired total number of questions in the dataset.
        existing_counts: Mapping of task type -> number already present, as
            returned by analyze_existing_dataset.

    Returns:
        A randomly shuffled list with one entry per question still to
        generate. It is empty when every task type already meets its target.
    """
    # math.ceil on a raw float product can overshoot on representation noise
    # (e.g. 100 * 0.07 == 7.000000000000001, whose ceil is 8). The tolerance
    # keeps exact multiples exact while still rounding real fractions up.
    eps = 1e-9

    # 1. Per-category targets; reasoning takes the remainder so the three sum to the total.
    target_p_count = max(0, math.ceil(target_total_questions * config.TASK_DISTRIBUTION['perception'] - eps))
    target_i_count = max(0, math.ceil(target_total_questions * config.TASK_DISTRIBUTION['integration'] - eps))
    target_r_count = max(0, target_total_questions - target_p_count - target_i_count)

    # 2. Integer per-task targets. Rounding each task up independently
    #    (e.g. 30 questions over 4 tasks -> 8 each = 32) would overshoot the total.
    all_target_counts: dict[str, int] = {}
    for category_total, category_tasks in (
        (target_p_count, config.PERCEPTION_TASKS),
        (target_i_count, config.INTEGRATION_TASKS),
        (target_r_count, config.REASONING_TASKS),
    ):
        if not category_tasks:
            continue  # an empty category would otherwise divide by zero
        base, remainder = divmod(category_total, len(category_tasks))
        for index, task in enumerate(category_tasks):
            all_target_counts[task] = base + (1 if index < remainder else 0)

    # 3. Queue only the shortfall relative to what already exists.
    queue: list[str] = []
    logging.info("Calculating tasks needed to reach target distribution:")
    for task_type in config.ALL_TASKS:
        target_count = all_target_counts.get(task_type, 0)
        current_count = existing_counts.get(task_type, 0)
        needed_count = max(0, target_count - current_count)

        if needed_count > 0:
            logging.info(f"  - {task_type}: Need to generate {needed_count} more (current: {current_count}, target: {target_count})")
            queue.extend([task_type] * needed_count)

    random.shuffle(queue)
    return queue

# --- MODIFICATION END ---


# --- Main Worker Function ---
def _attempt_scene(task_type: str, scene_id: str):
    """
    Try to produce one question of ``task_type`` from a single scene.

    Returns a final result dict (success or error) when this scene settles the
    outcome. Returns None when the scene should be skipped and another one
    drawn: either the scene fails check_scene_suitability, or the LLM answers
    with ``config.UNABLE_TO_GENERATE_KEYWORD``. Exceptions are left to the
    caller.
    """
    metadata_file = os.path.join(config.METADATA_DIR, f"{scene_id}.json")
    rtsd_file = os.path.join(config.RTSD_CACHE_DIR, f"{scene_id}.txt")
    with open(metadata_file, 'r', encoding='utf-8') as handle:
        metadata = json.load(handle)
    with open(rtsd_file, 'r', encoding='utf-8') as handle:
        rtsd_content = handle.read()

    if not check_scene_suitability(task_type, metadata):
        return None

    hints = precompute_for_task(task_type, metadata)
    template = utils.load_prompt_template(task_type)
    prompt_text = template.render(rtsd=rtsd_content, precomputed_data=hints)
    reply = utils.call_llm(prompt_text)

    if not reply:
        return {"status": "error", "task": task_type, "reason": "LLM returned empty string."}
    if reply.strip() == config.UNABLE_TO_GENERATE_KEYWORD:
        logging.warning(f"Task {task_type} is unsuitable for scene {scene_id} (LLM judged). Redrawing scene...")
        return None

    question = utils.parse_llm_json_output(reply)
    if not question:
        return {"status": "error", "task": task_type, "reason": "Failed to parse LLM JSON response."}

    record = {
        "scene_id": scene_id,
        "task_type": task_type,
        "question_data": question,
        "source_metadata": metadata,
    }
    return {"status": "success", "data": record}


def generate_single_question(task_type: str, available_scenes: list[str]) -> dict:
    """
    Generate one question of ``task_type`` from a randomly drawn scene.

    Up to 10 scenes are drawn at random. A scene that is unsuitable (by
    metadata or by the LLM's own verdict) triggers a new draw. An empty or
    unparsable LLM reply, or any exception, ends the attempt immediately
    with an error result.

    Returns:
        {"status": "success", "data": {...}} on success, otherwise
        {"status": "error", "task": task_type, "reason": <str>}.
    """
    max_scene_draws = 10
    draws_used = 0
    while draws_used < max_scene_draws:
        draws_used += 1
        try:
            outcome = _attempt_scene(task_type, random.choice(available_scenes))
        except Exception as e:
            logging.error(f"An unexpected error occurred while processing task {task_type}: {e}", exc_info=True)
            return {"status": "error", "task": task_type, "reason": str(e)}
        if outcome is not None:
            return outcome

    return {"status": "error", "task": task_type, "reason": f"Failed to find a suitable scene after {max_scene_draws} attempts."}

# --- Main Execution ---

def _ensure_trailing_newline(path: str) -> None:
    """
    Append '\\n' to ``path`` if it is non-empty and does not end with one.

    A record cut short by an interrupted write leaves no trailing newline.
    Appending to such a file would glue the next record onto the partial line,
    making both unparsable. Missing or empty files are left alone.
    """
    try:
        with open(path, 'rb') as f:
            f.seek(0, os.SEEK_END)
            if f.tell() == 0:
                return
            f.seek(-1, os.SEEK_END)
            ends_with_newline = f.read(1) == b'\n'
    except FileNotFoundError:
        return
    if not ends_with_newline:
        with open(path, 'ab') as f:
            f.write(b'\n')


def _log_final_summary(successful_generations_this_run: int, total_existing_questions: int) -> None:
    """Log the end-of-run banner and the per-task distribution re-read from disk."""
    logging.info("\n" + "="*50)
    logging.info("🎉 Benchmark Generation Completed! 🎉")
    logging.info(f"Successfully generated {successful_generations_this_run} new questions in this run.")
    logging.info("\n--- Final Dataset Distribution ---")
    # Re-read the file so the summary reflects what was actually persisted.
    final_counts = analyze_existing_dataset(config.FINAL_BENCHMARK_PATH)
    if final_counts:
        # Sort by task type so the output order is stable.
        for task_type in sorted(final_counts):
            logging.info(f"  - {task_type}: {final_counts[task_type]} questions")
        logging.info("------------------------------------")
        logging.info(f"  Total Verified Questions in File: {sum(final_counts.values())}")
    else:
        logging.info("No questions found in the final dataset file.")
    logging.info(f"Dataset now contains a total of {total_existing_questions + successful_generations_this_run} questions.")
    logging.info(f"Dataset saved to: {config.FINAL_BENCHMARK_PATH}")
    logging.info("="*50)


def main():
    """
    Generate a benchmark dataset, or top up an existing one, to ``--num_questions``.

    The steps are:
      1. Count the existing records in ``config.FINAL_BENCHMARK_PATH`` per task type.
      2. Queue the missing questions.
      3. Generate them in a thread pool, appending each success to the file as
         one JSON line.

    Failed tasks are retried in further batches. By default retries continue
    indefinitely; ``--max_rounds`` bounds the number of batches.
    """
    # No-op if the root logger already has handlers (e.g. configured by an
    # imported project module). Otherwise the first logging.info() call would
    # install a WARNING-level default, and all INFO progress output below
    # would be dropped.
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser(description="Generate or resume a benchmark dataset for spatial audio understanding.")
    parser.add_argument(
        '--num_questions',
        type=int,
        default=config.DEFAULT_TOTAL_QUESTIONS,
        help=f"The target total number of questions in the dataset after this run. Default: {config.DEFAULT_TOTAL_QUESTIONS}"
    )
    parser.add_argument(
        '--max_rounds',
        type=int,
        default=0,
        help="Maximum number of generation batches (the first batch plus retries) before giving up on tasks that keep failing. 0 (default) retries indefinitely."
    )
    args = parser.parse_args()

    os.makedirs(config.OUTPUT_DIR, exist_ok=True)

    available_scenes = load_available_scenes(config.METADATA_DIR, config.RTSD_CACHE_DIR)
    if not available_scenes:
        logging.error("No available scenes to process. Exiting.")
        return

    # Resume: only generate what is missing relative to the target distribution.
    existing_counts = analyze_existing_dataset(config.FINAL_BENCHMARK_PATH)
    task_queue = create_resume_task_queue(args.num_questions, existing_counts)

    if not task_queue:
        logging.info("Dataset already meets or exceeds the target counts for all task types. Nothing to do.")
        return

    total_existing_questions = sum(existing_counts.values())

    logging.info(f"Starting benchmark generation. Need to generate {len(task_queue)} new questions to reach the target of {args.num_questions}.")

    successful_generations_this_run = 0
    rounds_run = 0

    # Make sure new records start on their own line even if a previous run was
    # interrupted mid-write.
    _ensure_trailing_newline(config.FINAL_BENCHMARK_PATH)

    # Append mode keeps the records from previous runs.
    with open(config.FINAL_BENCHMARK_PATH, 'a', encoding='utf-8') as outfile:
        while task_queue:
            rounds_run += 1
            logging.info(f"--- Starting generation batch of {len(task_queue)} tasks ---")
            tasks_to_retry = []

            with ThreadPoolExecutor(max_workers=config.MAX_WORKERS) as executor:
                future_to_task = {executor.submit(generate_single_question, task, available_scenes): task for task in task_queue}

                progress_desc = f"Generating (Total: {total_existing_questions + successful_generations_this_run}/{args.num_questions})"
                progress = tqdm(as_completed(future_to_task), total=len(task_queue), desc=progress_desc)

                for future in progress:
                    result = future.result()

                    if result['status'] == 'success':
                        outfile.write(json.dumps(result['data'], ensure_ascii=False) + '\n')
                        # Flush per record so an interrupted run loses at most
                        # the record in flight, and the next resume counts
                        # everything written so far.
                        outfile.flush()
                        successful_generations_this_run += 1

                        new_total = total_existing_questions + successful_generations_this_run
                        progress.set_description(f"Generating (Total: {new_total}/{args.num_questions})")
                    else:
                        logging.warning(f"Task {result['task']} failed: {result['reason']}")
                        tasks_to_retry.append(result['task'])
                    # Refresh on both outcomes so the "Failed" count is not stale.
                    progress.set_postfix({"Succeeded_this_run": successful_generations_this_run, "Failed": len(tasks_to_retry)})

            task_queue = tasks_to_retry
            if task_queue:
                # A task type for which no scene ever works would otherwise be
                # retried forever; --max_rounds lets the caller bound that.
                if args.max_rounds > 0 and rounds_run >= args.max_rounds:
                    logging.error(f"Giving up on {len(task_queue)} tasks after {rounds_run} rounds (--max_rounds={args.max_rounds}).")
                    break
                logging.warning(f"{len(task_queue)} tasks failed and will be retried in the next batch.")
                time.sleep(5)

    _log_final_summary(successful_generations_this_run, total_existing_questions)

if __name__ == "__main__":
    # Run the generator only when executed as a script, not when imported.
    main()