# 文件名: generate_sft_data.py

import os
import json
import logging
import argparse
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import sys
import time

# 确保可以导入项目内的模块
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import config_rl as config
import utils

import re
# Format check for the teacher LLM output (applied to the stripped response).
# The whole string must start with `|<think>|` and end with `|</answer>|`, and
# contain `|</think>|` followed later by `|<answer>|`. DOTALL lets the
# reasoning/answer sections span multiple lines. The non-greedy gaps do not
# forbid extra text or repeated tags between the markers.
# NOTE(review): the tags are matched with literal surrounding `|` characters
# (e.g. `|<think>|`); confirm this is the format the teacher prompt requests.
VALID_SFT_RESPONSE_PATTERN = re.compile(
    r"^\|<think>\|.*?\|</think>\|.*?\|<answer>\|.*?\|</answer>\|$",
    flags=re.DOTALL,
)

def format_options_to_string(options_dict: dict) -> str:
    """Render an options mapping as newline-separated ``key: value`` lines.

    For example ``{"A": "red", "B": "blue"}`` becomes the two lines
    ``A: red`` and ``B: blue``. Options appear in the dict's iteration order;
    an empty mapping produces an empty string.
    """
    rendered_lines = []
    for option_key, option_text in options_dict.items():
        rendered_lines.append(f"{option_key}: {option_text}")
    return "\n".join(rendered_lines)

def get_processed_ids(output_path: str) -> set:
    """Scan the JSONL output file and return the ids of items already generated.

    Used for resuming an interrupted run. Each record's id is
    ``f"{scene_id}_{task_type}"`` built from the record's ``metadata`` object,
    which is the same format the task queue uses to identify benchmark items.

    Lines that are blank, not valid JSON (e.g. a truncated last line), not a
    JSON object, or whose ``metadata`` is not an object containing both
    ``scene_id`` and ``task_type`` are skipped.

    Args:
        output_path: path of the JSON Lines output file.

    Returns:
        The set of id strings found; an empty set if the file does not exist.
    """
    if not os.path.exists(output_path):
        return set()

    processed_ids = set()
    with open(output_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue
            # Explicit type guards: `in` on a JSON number/null raises TypeError and
            # `in` on a string does substring matching, so only dicts are trusted.
            metadata = data.get('metadata') if isinstance(data, dict) else None
            if not isinstance(metadata, dict):
                continue
            if 'scene_id' in metadata and 'task_type' in metadata:
                processed_ids.add(f"{metadata['scene_id']}_{metadata['task_type']}")
    logging.info(f"Found {len(processed_ids)} already processed items in the output file.")
    return processed_ids

def process_single_item(benchmark_item: dict) -> dict:
    """Generate one SFT record (instruction + teacher chain-of-thought answer).

    Steps:
      1. Read ``question``, ``options`` and the correct ``answer`` key from
         ``benchmark_item['question_data']``.
      2. Load the cached RTSD text from ``{config.RTSD_CACHE_DIR}/{scene_id}.txt``.
      3. Render the teacher prompt ``config.SFT_TEACHER_PROMPT_NAME`` with the
         RTSD, question, options and the correct answer.
      4. Call the LLM and require the stripped response to match
         ``VALID_SFT_RESPONSE_PATTERN``.

    Args:
        benchmark_item: one benchmark record. Expected keys are ``scene_id``,
            ``task_type`` and ``question_data``. ``question_data`` holds
            ``question``, an ``options`` dict and ``answer``, which is a key
            of ``options``.

    Returns:
        A dict with ``item_id`` (``"{scene_id}_{task_type}"``) and ``status``.
        On success it also has ``data`` (``instruction``/``answer``/``metadata``).
        On failure it has ``reason``. Any ``Exception`` raised during processing
        is logged and turned into an error result rather than propagated.
    """
    try:
        # 1. Parse the input record.
        scene_id = benchmark_item['scene_id']
        task_type = benchmark_item['task_type']
        unique_id = f"{scene_id}_{task_type}"

        question_data = benchmark_item['question_data']
        question = question_data['question']
        options_dict = question_data['options']
        correct_answer_key = question_data['answer']
        correct_answer_text = options_dict[correct_answer_key]

        # 2. Read the cached RTSD for this scene.
        rtsd_path = os.path.join(config.RTSD_CACHE_DIR, f"{scene_id}.txt")
        with open(rtsd_path, 'r', encoding='utf-8') as f:
            rtsd_content = f.read()

        # 3. Build the teacher prompt (the teacher is told the correct answer).
        prompt_template = utils.load_prompt_template(config.SFT_TEACHER_PROMPT_NAME)
        options_str = format_options_to_string(options_dict)

        prompt = prompt_template.render(
            rtsd=rtsd_content,
            question=question,
            options=options_str,
            correct_answer_key=correct_answer_key,
            correct_answer_text=correct_answer_text
        )

        # 4. Call the LLM; a low temperature is used for more stable output formatting.
        llm_response = utils.call_llm(prompt, temperature=0.3)

        # 5. Strictly validate the whole (stripped) response against the regex.
        cleaned_response = llm_response.strip() if llm_response else ""
        if not cleaned_response or not VALID_SFT_RESPONSE_PATTERN.match(cleaned_response):
            logging.warning(f"LLM response for {unique_id} did not match the required regex format. Response: '{cleaned_response[:50]}...'")
            return {"status": "error", "reason": "Invalid format from LLM (Regex mismatch)", "item_id": unique_id}

        # 6. Build the final SFT record.
        instruction = f"{question}. Please choose the answer from the following options: {options_str}"

        final_data = {
            "instruction": instruction,
            "answer": cleaned_response,
            # Metadata is kept so resumed runs can recognise this item as done.
            "metadata": {"scene_id": scene_id, "task_type": task_type}
        }
        return {"status": "success", "data": final_data, "item_id": unique_id}

    except Exception as e:
        # Build the id with an f-string and a type guard: scene_id/task_type may be
        # non-str (e.g. int) and the item may not even be a dict. The previous
        # `str + "_" + ...` concatenation could raise out of this handler.
        if isinstance(benchmark_item, dict):
            item_id = f"{benchmark_item.get('scene_id', 'unknown')}_{benchmark_item.get('task_type', 'unknown')}"
        else:
            item_id = "unknown_unknown"
        logging.error(f"An unexpected error occurred while processing item {item_id}: {e}", exc_info=True)
        return {"status": "error", "reason": str(e), "item_id": item_id}


def main():
    """Command-line entry point: generate CoT SFT data for every pending benchmark item.

    Reads the benchmark JSON Lines file at ``config.BENCHMARK_FILE_PATH`` and
    optionally truncates it with ``--limit``. Items whose ``"{scene_id}_{task_type}"``
    id already appears in ``config.SFT_OUTPUT_PATH`` are skipped (resume support).
    The rest are run through ``process_single_item`` using ``config.MAX_WORKERS``
    threads.

    Each successful record is appended to the output file and flushed as soon as
    it completes. Failed items are retried in further rounds, at most
    ``--max-retries`` times. A negative value retries until every item succeeds.
    """
    parser = argparse.ArgumentParser(description="Generate SFT data with Chain-of-Thought from a benchmark file.")
    parser.add_argument(
        '--limit',
        type=int,
        default=None,
        help="Limit the number of benchmark questions to process (for testing)."
    )
    parser.add_argument(
        '--max-retries',
        type=int,
        default=5,
        help="Maximum number of retry rounds for failed items; a negative value retries indefinitely."
    )
    args = parser.parse_args()

    # 1. Load all benchmark items (one JSON object per line; blank lines are ignored).
    all_benchmark_items = []
    with open(config.BENCHMARK_FILE_PATH, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                all_benchmark_items.append(json.loads(line))

    # Compare against None so an explicit `--limit 0` is honoured instead of meaning "no limit".
    if args.limit is not None:
        all_benchmark_items = all_benchmark_items[:args.limit]

    # 2. Resume support: drop items whose id is already present in the output file.
    processed_ids = get_processed_ids(config.SFT_OUTPUT_PATH)
    task_queue = [item for item in all_benchmark_items if f"{item['scene_id']}_{item['task_type']}" not in processed_ids]

    if not task_queue:
        logging.info("All items from the benchmark file have already been processed. Nothing to do.")
        return

    logging.info(f"Total items to process in this run: {len(task_queue)}")

    # 3. Open the output in append mode and process in parallel, retrying failures in rounds.
    successful_generations = 0
    retry_rounds = 0
    with open(config.SFT_OUTPUT_PATH, 'a', encoding='utf-8') as outfile:
        while task_queue:
            tasks_to_retry = []
            logging.info(f"--- Starting generation batch of {len(task_queue)} tasks ---")

            with ThreadPoolExecutor(max_workers=config.MAX_WORKERS) as executor:
                future_to_item = {executor.submit(process_single_item, item): item for item in task_queue}

                progress = tqdm(as_completed(future_to_item), total=len(task_queue), desc="Generating SFT Data")
                for future in progress:
                    original_item = future_to_item[future]
                    try:
                        result = future.result()
                    except Exception as e:
                        # process_single_item turns its own errors into results; this guard only
                        # stops an unexpected worker exception from aborting the whole run.
                        item_id = f"{original_item['scene_id']}_{original_item['task_type']}"
                        logging.error(f"Worker for item {item_id} raised unexpectedly: {e}", exc_info=True)
                        result = {"status": "error", "reason": str(e), "item_id": item_id}

                    if result['status'] == 'success':
                        outfile.write(json.dumps(result['data'], ensure_ascii=False) + '\n')
                        # Flush each record so an interrupted run does not lose (and later
                        # regenerate) records still sitting in the write buffer.
                        outfile.flush()
                        successful_generations += 1
                    else:
                        logging.warning(f"Item {result['item_id']} failed: {result['reason']}")
                        tasks_to_retry.append(original_item)
                    progress.set_postfix({"Succeeded": successful_generations, "Failed": len(tasks_to_retry)})

            task_queue = tasks_to_retry
            if not task_queue:
                break
            # Bound the retries: permanent failures (missing RTSD file, malformed
            # item, ...) would otherwise make this loop spin forever.
            if 0 <= args.max_retries <= retry_rounds:
                break
            retry_rounds += 1
            logging.warning(f"{len(task_queue)} tasks failed and will be retried in the next batch.")
            time.sleep(5)

    logging.info("\n" + "="*50)
    logging.info("🎉 SFT Data Generation Completed! 🎉")
    logging.info(f"Successfully generated {successful_generations} new SFT data pairs in this run.")
    logging.info(f"All data saved to: {config.SFT_OUTPUT_PATH}")
    if task_queue:
        failed_ids = [f"{item['scene_id']}_{item['task_type']}" for item in task_queue]
        logging.error(f"{len(failed_ids)} items still failed after {retry_rounds} retry rounds and were skipped: {failed_ids}")
    logging.info("="*50)

if __name__ == "__main__":
    main()