import os
import json
import logging
import sys
import argparse

# Add the project root directory to Python's module search path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import config_rl as config
import utils
from jinja2 import Template
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

# Root logger: INFO and above, each record prefixed with timestamp and level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)


def generate_and_save_rtsd(metadata_filepath: str):
    """
    Generate the RTSD text for one metadata file and store it in the cache.

    The metadata JSON is rendered through ``config.RTSD_BASE_TEMPLATE``; the
    result is inserted into ``config.RTSD_REFINEMENT_PROMPT`` and sent to
    ``utils.call_llm``. A non-empty response is saved as
    ``<config.RTSD_CACHE_DIR>/<metadata file stem>.txt``.

    The text is written to a sibling ``.tmp`` file first and then moved into
    place with ``os.replace``. The mere existence of the ``.txt`` file is used
    elsewhere in this file to decide that a metadata file is already done, so
    a failed or interrupted write must never leave a partial ``.txt`` behind.

    Any ``Exception`` is logged (with traceback) and swallowed so that one bad
    file does not abort a batch run; the function always returns ``None``.

    :param metadata_filepath: full path of the metadata JSON file
    """
    try:
        # 1. Load the metadata.
        with open(metadata_filepath, "r", encoding="utf-8") as f:
            metadata = json.load(f)

        # 2. Build the programmatic base description.
        base_template = Template(config.RTSD_BASE_TEMPLATE)
        base_description = base_template.render(metadata)

        # 3. Ask the LLM to refine it.
        refinement_template = Template(config.RTSD_REFINEMENT_PROMPT)
        final_prompt = refinement_template.render(base_description=base_description)

        rtsd_content = utils.call_llm(prompt=final_prompt)

        if not rtsd_content:
            logging.warning(
                f"RTSD generation failed for {metadata_filepath}, LLM returned empty string."
            )
            return

        # 4. Save to the cache directory.
        os.makedirs(config.RTSD_CACHE_DIR, exist_ok=True)
        base_filename = os.path.basename(metadata_filepath)
        rtsd_filename = os.path.splitext(base_filename)[0] + ".txt"
        rtsd_filepath = os.path.join(config.RTSD_CACHE_DIR, rtsd_filename)

        # Write-then-rename: the final path only ever holds complete content.
        tmp_filepath = rtsd_filepath + ".tmp"
        try:
            with open(tmp_filepath, "w", encoding="utf-8") as f:
                f.write(rtsd_content)
            os.replace(tmp_filepath, rtsd_filepath)
        except BaseException:
            # Best-effort cleanup of the partial temp file, then let the
            # original error propagate to the handler below (or the caller).
            try:
                os.remove(tmp_filepath)
            except OSError:
                pass
            raise

    except Exception as e:
        logging.error(f"Error processing {metadata_filepath}: {e}", exc_info=True)


def main(max_workers: int = 512):
    """
    Concurrently generate RTSD files for every metadata file that lacks one.

    Lists the ``.json`` files in ``config.METADATA_DIR``, skips those whose
    ``<stem>.txt`` already exists in ``config.RTSD_CACHE_DIR``, and submits the
    rest to a thread pool running ``generate_and_save_rtsd``. After the pool
    drains, files that still have no cached RTSD are reported in a warning.

    Returns early (after logging) when the metadata directory is missing,
    contains no ``.json`` files, or everything is already cached.

    :param max_workers: upper bound on worker threads; adjust it to the API
        rate limit. Must be positive, otherwise ``ThreadPoolExecutor`` raises
        ``ValueError``.
    """
    logging.info("Starting RTSD generation process...")

    try:
        metadata_files = [
            os.path.join(config.METADATA_DIR, f)
            for f in os.listdir(config.METADATA_DIR)
            if f.endswith(".json")
        ]
    except FileNotFoundError:
        logging.error(
            f"Metadata directory not found: {config.METADATA_DIR}. Please create it and add metadata files."
        )
        return

    if not metadata_files:
        logging.warning(f"No metadata .json files found in {config.METADATA_DIR}.")
        return

    def _rtsd_path_for(meta_path: str) -> str:
        # Cache file path for a metadata file: same stem, ".txt" extension.
        stem = os.path.splitext(os.path.basename(meta_path))[0]
        return os.path.join(config.RTSD_CACHE_DIR, stem + ".txt")

    # Only process metadata files that do not have a cached RTSD yet.
    files_to_process = [
        meta_path
        for meta_path in metadata_files
        if not os.path.exists(_rtsd_path_for(meta_path))
    ]

    if not files_to_process:
        logging.info("All metadata files already have corresponding RTSD. Nothing to do.")
        return

    logging.info(f"Found {len(metadata_files)} total metadata files. Need to process {len(files_to_process)} of them.")

    # Never request more threads than there are tasks to run.
    worker_count = min(max_workers, len(files_to_process))
    with ThreadPoolExecutor(max_workers=worker_count) as executor:
        future_to_path = {
            executor.submit(generate_and_save_rtsd, filepath): filepath
            for filepath in files_to_process
        }

        for future in tqdm(
            as_completed(future_to_path),
            total=len(future_to_path),
            desc="Generating New RTSDs",
        ):
            # The worker logs its own errors; this surfaces anything that
            # still escaped it instead of dropping it silently.
            error = future.exception()
            if error is not None:
                logging.error(
                    "Unhandled error while processing %s: %r",
                    future_to_path[future],
                    error,
                )

    # A file counts as done only if its cache file now exists.
    still_missing = [
        meta_path
        for meta_path in files_to_process
        if not os.path.exists(_rtsd_path_for(meta_path))
    ]
    if still_missing:
        logging.warning(
            "%d of %d metadata files did not produce an RTSD; re-run to retry them.",
            len(still_missing),
            len(files_to_process),
        )

    logging.info("RTSD generation process completed.")
    logging.info(f"Cached RTSD files are saved in: {config.RTSD_CACHE_DIR}")

# --- 2. Added test function ---
def run_test():
    """
    Run a single self-contained test case of template rendering and the LLM call.

    Builds an in-memory metadata sample, renders the base description and the
    refinement prompt from the ``config`` templates, prints the prompt, calls
    ``utils.call_llm`` (a real API call) and prints the response. Nothing is
    written to disk; the result is only printed and logged.
    """
    logging.info("--- Running in Test Mode ---")

    # 1. Sample metadata with one speech source and one music source.
    speech_event = {
        "class": "speech",
        "dataset": "emilia_zh",
        "transcript": "你好，今天天气怎么样？",
        "azimuth": 45,
        "elevation": 0,
        "distance": 2.5,
    }
    music_event = {
        "class": "A vibrant pop song with a catchy synth melody and energetic female vocals.",
        "dataset": "musiccaps",
        "transcript": None,  # music has no transcript
        "azimuth": 270,
        "elevation": 10,
        "distance": 4.0,
    }
    room = {
        "size_category": "large",
        "reverb_category": "high",
        "dimensions_m": [10.1, 8.2, 4.3],
        "rt60_s": 1.5,
    }
    dummy_metadata = {
        "scene_type": "speech",  # exercise the speech scene branch
        "source_count": 2,
        "source_events": [speech_event, music_event],
        "room_acoustics": room,
    }

    pretty_metadata = json.dumps(dummy_metadata, indent=2, ensure_ascii=False)
    logging.info(f"Dummy metadata generated: {pretty_metadata}")

    # 2. Render the programmatic base description.
    description = Template(config.RTSD_BASE_TEMPLATE).render(dummy_metadata)

    # 3. Render and show the prompt that will be sent to the LLM.
    llm_prompt = Template(config.RTSD_REFINEMENT_PROMPT).render(
        base_description=description
    )

    closing_rule = "=" * 75 + "\n"

    print("\n" + "=" * 25 + " Generated Prompt to LLM " + "=" * 25)
    print(llm_prompt)
    print(closing_rule)

    # 4. Call the LLM (this consumes real API quota).
    logging.info("Calling LLM API...")
    response = utils.call_llm(prompt=llm_prompt)

    if response:
        print("\n" + "=" * 28 + " LLM Response " + "=" * 29)
        print(response)
        print(closing_rule)

        # 5. Report the generated RTSD in the log.
        logging.info(f"Test RTSD file generated:\n{response}")
        logging.info("--- Test Completed Successfully ---")
    else:
        logging.error("Test failed: LLM returned an empty response.")

if __name__ == "__main__":
    # Command-line interface: one optional flag that switches to test mode.
    cli = argparse.ArgumentParser(description="Generate RTSD for audio scenes.")
    cli.add_argument(
        "--test",
        action="store_true",  # True when --test is present on the command line
        help="Run a single test case instead of processing the whole dataset.",
    )
    options = cli.parse_args()

    # Choose between the single test case and the full batch run.
    entry_point = run_test if options.test else main
    entry_point()

