import json
import random
import os
import shutil
import argparse
import logging
from tqdm import tqdm
# --- MODIFICATION START: import the parallel-processing library ---
from concurrent.futures import ProcessPoolExecutor
# --- MODIFICATION END ---

import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import config_rl as config

# Configure root logging at module level (not inside the __main__ guard), so
# it is applied whenever this module is executed or imported; worker
# functions below log through the root logger via logging.warning.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Worker function for the process pool. It must be defined at module top
# level so it can be pickled by reference and located by child processes.
def shuffle_single_line(line: str) -> str:
    """
    Shuffle the options of a single JSONL record and remap its answer key.

    The record must contain ``question_data.options`` (a dict mapping option
    labels such as ``"A"`` to option text) and ``question_data.answer`` (the
    label of the correct option). The record's own labels are kept, in their
    original order; only the option texts are permuted among them, and
    ``answer`` is updated to the label that now holds the originally correct
    option. Any number of options is supported, not only four.

    Args:
        line: One raw line from the JSONL file (a trailing newline is fine).

    Returns:
        The re-serialized record without a trailing newline. If the line is
        not valid JSON or lacks the expected structure, a warning is logged
        and the original line is returned with surrounding whitespace
        stripped.
    """
    try:
        data = json.loads(line)
    except json.JSONDecodeError as e:
        logging.warning(f"处理行时出错: {e}. 保留原始行。")
        return line.strip()

    # Validate the shape explicitly: unexpected types (e.g. a list of options)
    # used to raise AttributeError/TypeError in the worker, which propagated
    # through executor.map and aborted processing of the whole file.
    question_data = data.get('question_data') if isinstance(data, dict) else None
    options = question_data.get('options') if isinstance(question_data, dict) else None
    if not isinstance(options, dict) or 'answer' not in question_data:
        logging.warning("跳过格式不符的行...")
        return line.strip()

    original_answer_key = question_data['answer']
    labels = list(options)
    # List membership compares with ==, so even an unhashable answer value
    # cannot raise here.
    if original_answer_key not in labels:
        logging.warning(f"处理行时出错: {original_answer_key!r}. 保留原始行。")
        return line.strip()

    # Randomly decide which original option fills each label slot, and track
    # the correct answer by its original label instead of by comparing texts.
    source_labels = labels.copy()
    random.shuffle(source_labels)

    question_data['options'] = {
        target: options[source] for target, source in zip(labels, source_labels)
    }
    question_data['answer'] = labels[source_labels.index(original_answer_key)]

    return json.dumps(data, ensure_ascii=False)


def _next_backup_path(filepath: str) -> str:
    """Return ``<filepath>.bak``, or the first unused ``<filepath>.bak.N`` if that is taken."""
    candidate = filepath + '.bak'
    n = 1
    while os.path.exists(candidate):
        candidate = f"{filepath}.bak.{n}"
        n += 1
    return candidate


def _write_lines_atomically(path: str, lines) -> None:
    """Write *lines* (strings without trailing newlines) to *path* via a temp file and ``os.replace``."""
    tmp_path = path + '.tmp'
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            f.writelines(line + '\n' for line in lines)
        try:
            # Best effort: keep the original file's permission bits.
            shutil.copymode(path, tmp_path)
        except OSError:
            pass
        os.replace(tmp_path, path)
    except BaseException:
        # Never leave a half-written temp file next to the data.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def shuffle_options_in_file(filepath: str, max_workers: int):
    """
    Shuffle the options of every record in a .jsonl file, in parallel.

    Steps: back up the file, read all lines into memory, run
    ``shuffle_single_line`` on each line in a process pool (results keep the
    input order), then atomically replace ``filepath`` with the result so a
    failure can never leave it half-written. Errors are logged, not raised.

    The backup goes to ``<filepath>.bak``; if that already exists (e.g. from
    an earlier run), ``<filepath>.bak.1``, ``.bak.2``, ... is used instead, so
    re-running the script never overwrites the untouched original with
    already-shuffled data.

    Args:
        filepath: Path of the JSONL file to rewrite in place.
        max_workers: Maximum number of worker processes; ``None`` lets
            ``ProcessPoolExecutor`` choose its default.
    """
    if not os.path.exists(filepath):
        logging.error(f"输入文件未找到: {filepath}")
        return

    backup_path = _next_backup_path(filepath)
    try:
        shutil.copy(filepath, backup_path)
        logging.info(f"原始文件已成功备份到: {backup_path}")
    except Exception as e:
        logging.error(f"创建备份文件失败: {e}")
        return

    try:
        # The whole file is held in memory at once.
        logging.info("Reading all lines from the source file into memory...")
        with open(filepath, 'r', encoding='utf-8') as f:
            lines_to_process = f.readlines()

        if not lines_to_process:
            logging.warning("文件为空，无需处理。")
            return

        # Per-line work is tiny, so hand lines to workers in batches; with the
        # default chunksize=1 the per-item inter-process traffic dominates.
        pool_size = max_workers if max_workers and max_workers > 0 else (os.cpu_count() or 1)
        chunksize = max(1, len(lines_to_process) // (pool_size * 4))

        logging.info(f"Starting parallel shuffling with up to {max_workers} worker processes...")
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            # executor.map yields results in input order, so line order is preserved.
            results_iterator = executor.map(shuffle_single_line, lines_to_process, chunksize=chunksize)
            processed_lines = list(tqdm(results_iterator, total=len(lines_to_process), desc="Shuffling Options"))

        logging.info("Writing shuffled data back to the file...")
        _write_lines_atomically(filepath, processed_lines)

        logging.info("🎉 选项随机化处理完成！")
        logging.info(f"文件 {filepath} 已被更新。")

    except Exception as e:
        # logging.exception also records the traceback for diagnosis.
        logging.exception(f"处理过程中发生严重错误: {e}")
        logging.info(f"您可以从备份文件 {backup_path} 中恢复原始数据。")


if __name__ == "__main__":
    # Command-line interface; --workers controls the size of the process pool.
    parser = argparse.ArgumentParser(
        description="Shuffle the multiple-choice options in the benchmark JSONL file in parallel."
    )
    parser.add_argument(
        '--filepath',
        type=str,
        default=config.BENCHMARK_FILE_PATH,
        help="Path to the benchmark.jsonl file to process."
    )
    parser.add_argument(
        '-w', '--workers',
        type=int,
        default=os.cpu_count(),  # all available CPU cores (None if undetectable -> executor default)
        help="Number of worker processes to use for parallel execution."
    )
    args = parser.parse_args()

    # Reject non-positive worker counts up front; otherwise the process pool
    # would fail only after a backup file had already been created.
    if args.workers is not None and args.workers < 1:
        parser.error("--workers must be a positive integer")

    shuffle_options_in_file(args.filepath, args.workers)