#!/usr/bin/env python3
"""
简单的运行脚本，用于启动reward模型评分
"""

import sys
import os
import argparse
import subprocess
import logging
from reward_scorer import RewardModelScorer

# 为主调度脚本设置日志
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'scoring.log')),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

def main():
    parser = argparse.ArgumentParser(description="Reward模型评分工具")
    parser.add_argument("--model", type=str, help="指定单个模型进行评分")
    parser.add_argument("--list-models", action="store_true", help="列出所有可用模型")
    
    args = parser.parse_args()
    
    scorer = RewardModelScorer()
    
    if args.list_models:
        print("可用的reward模型:")
        for i, model in enumerate(scorer.available_models, 1):
            print(f"{i}. {model}")
        return
    
    if args.model:
        # 这部分由子进程执行
        if args.model not in scorer.available_models:
            print(f"错误: 模型 '{args.model}' 不存在")
            print("可用模型:", scorer.available_models)
            sys.exit(1)
        
        # scorer内部已经配置了日志，这里不需要额外配置
        qa_data = scorer.load_qa_data()
        qa_pairs = scorer.extract_qa_pairs(qa_data)
        results = scorer.score_with_model(args.model, qa_pairs)
        scorer.save_results(results)
    else:
        # 主进程作为调度器，为每个模型创建一个独立的子进程
        logger.info("开始以独立进程方式执行所有模型的评分任务...")
        script_path = os.path.abspath(__file__)
        
        for i, model_name in enumerate(scorer.available_models, 1):
            
            # 检查结果文件是否已存在
            result_file_path = os.path.join(scorer.output_dir, f"{model_name}.json")
            if os.path.exists(result_file_path):
                logger.info(f"结果文件 '{result_file_path}' 已存在，跳过模型 {model_name}")
                continue
                
            logger.info(f"\n{'='*50}")
            logger.info(f"调度任务: 模型 {i}/{len(scorer.available_models)}: {model_name}")
            logger.info(f"{'='*50}")
            
            try:
                # 使用 sys.executable 确保使用相同的Python解释器
                process = subprocess.run(
                    [sys.executable, script_path, "--model", model_name],
                    check=True,
                    capture_output=True,
                    text=True,
                    encoding='utf-8'
                )
                logger.info(f"模型 {model_name} 子进程成功完成。")
                if process.stdout:
                    logger.info(f"--- 子进程 STDOUT ---\n{process.stdout}\n--- END STDOUT ---")
                if process.stderr:
                    logger.warning(f"--- 子进程 STDERR ---\n{process.stderr}\n--- END STDERR ---")

            except subprocess.CalledProcessError as e:
                logger.error(f"模型 {model_name} 评分失败，子进程返回错误码 {e.returncode}")
                logger.error(f"--- 子进程 STDOUT ---\n{e.stdout}\n--- END STDOUT ---")
                logger.error(f"--- 子进程 STDERR ---\n{e.stderr}\n--- END STDERR ---")

        logger.info(f"\n{'='*50}")
        logger.info("所有模型评分任务调度完成！")
        logger.info(f"{'='*50}")

if __name__ == "__main__":
    main()
