# - Provides a unified run_inference_on_dataset for running inference evaluation on any dataset;
# - Provides aggregate_inference_results for batch testing of the generated_predictions.jsonl files produced under multiple epochs (checkpoints).
 
 # scripts/python_version/inference_runner.py

import json
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import yaml
from config_utils import read_yaml_config, write_yaml_config
from dataset_utils import sort_checkpoint_folders
from tqdm import tqdm

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)

def make_llmfac_inference_eachepoch_file(
    train_file: str,
    inference_file: str,
    folder_name: str,
    inference_type: str = "lora",
    llama_factory_root: Optional[str] = None,
    model_name: str = "qwen2_vl-7b",
    eval_datasets: Optional[List[str]] = None,
    eval_batch_size: int = 32,
    use_vllm: bool = True,
    batch_size: int = 32768,
    hack_vllm_inference: bool = False,
    only_eval_last_ckp: bool = False,
) -> str:
    """Generate a shell script that runs inference once per training checkpoint.

    The function reads the training and inference YAML configs, finds the
    ``checkpoint-*`` directories under the training output directory, writes
    one inference YAML per checkpoint into ``<output>/inference_yamls`` and
    finally writes ``<output>/inference_eachepoch.sh`` containing one
    background inference command (followed by ``wait`` and progress echoes)
    per checkpoint. ``<output>`` is
    ``<root>/saves/<model_name>/<inference_type>/sft/predict/<folder_name>``.

    Args:
        train_file: Path to the training YAML config; its ``output_dir`` key
            (relative to the LLaMA-Factory root) is where checkpoints are
            looked up.
        inference_file: Path to the base inference YAML config.
        folder_name: Name of the output folder created under the predict
            directory.
        inference_type: ``"lora"`` points ``adapter_name_or_path`` at each
            checkpoint; any other value points ``model_name_or_path`` at the
            checkpoint and drops ``adapter_name_or_path``.
        llama_factory_root: LLaMA-Factory root directory. Falls back to a
            hard-coded default path when not given.
        model_name: Model name used as a path component of the output folder.
        eval_datasets: Dataset names. When non-empty they override the
            config's ``eval_dataset`` (comma-joined) and
            ``per_device_eval_batch_size``.
        eval_batch_size: Value written to ``per_device_eval_batch_size`` when
            ``eval_datasets`` is given.
        use_vllm: If true, each command calls ``scripts/vllm_batch_infer.py``;
            otherwise it calls ``llamafactory-cli train <yaml>``.
        batch_size: Passed as ``--batch_size`` to the vLLM script.
        hack_vllm_inference: Currently unused; kept for interface
            compatibility.
        only_eval_last_ckp: If true, only the last checkpoint (after sorting)
            is evaluated.

    Returns:
        Path of the generated shell script, as a string.

    Raises:
        Exception: Any error raised while reading configs, listing
            checkpoints or writing files is logged and re-raised unchanged.
    """
    # Avoid a shared mutable default; an omitted argument means "no override".
    eval_datasets = eval_datasets or []

    try:
        # Base paths.
        current_date = datetime.now().strftime('%m%d')
        root_dir = Path(llama_factory_root or '/home/liyu/workspaces/llama-factory')
        base_output_dir = root_dir / f'saves/{model_name}/{inference_type}/sft/predict'

        # Load configs.
        train_config = read_yaml_config(train_file)
        logger.info("Read train config from %s", train_file)

        inference_config = read_yaml_config(inference_file)
        if eval_datasets:
            inference_config['eval_dataset'] = ",".join(eval_datasets)
            inference_config['per_device_eval_batch_size'] = eval_batch_size

        # Directory that holds the training checkpoints.
        train_output_dir = root_dir / train_config['output_dir']

        # Collect checkpoint directories, then let the project helper order them.
        checkpoint_dirs = sorted(
            entry for entry in os.listdir(train_output_dir)
            if entry.startswith("checkpoint-") and (train_output_dir / entry).is_dir()
        )
        checkpoint_dirs = sort_checkpoint_folders(checkpoint_dirs)

        if only_eval_last_ckp:
            checkpoint_dirs = checkpoint_dirs[-1:]

        checkpoint_count = len(checkpoint_dirs)
        if not checkpoint_count:
            logger.warning("No checkpoint-* directories found in %s", train_output_dir)

        output_folder = base_output_dir / folder_name
        yaml_dir = output_folder / "inference_yamls"
        output_folder.mkdir(parents=True, exist_ok=True)
        yaml_dir.mkdir(parents=True, exist_ok=True)

        command_prefix = 'DISABLE_VERSION_CHECK=1'
        vllm_script = 'scripts/vllm_batch_infer.py'
        datasets_str = "+".join(eval_datasets)
        # The vLLM script takes the datasets on the command line; when the
        # caller gave none, reuse the comma-separated list from the config
        # instead of emitting an empty ``--dataset`` value.
        vllm_datasets = ",".join(eval_datasets) or str(inference_config.get('eval_dataset') or '')

        commands = []
        # 1-based position so the last message reads "Epoch N of N complete".
        for position, checkpoint_dir in enumerate(checkpoint_dirs, start=1):
            checkpoint_path = str(train_output_dir / checkpoint_dir)
            checkpoint_output_dir = output_folder / checkpoint_dir
            checkpoint_output_dir.mkdir(parents=True, exist_ok=True)

            # Per-checkpoint inference config (top-level keys only are changed,
            # so a shallow copy is sufficient).
            current_inference = inference_config.copy()
            if inference_type == "lora":
                current_inference['adapter_name_or_path'] = checkpoint_path
            else:
                current_inference['model_name_or_path'] = checkpoint_path
                current_inference.pop('adapter_name_or_path', None)
            current_inference.update({
                'output_dir': str(checkpoint_output_dir),
                'finetuning_type': inference_type,
            })

            yaml_path = yaml_dir / f"{checkpoint_dir}.yaml"
            write_yaml_config(current_inference, yaml_path)

            log_file = checkpoint_output_dir / f"{current_date}_inference.log"

            if use_vllm:
                # Built only on the vLLM path so that the llamafactory-cli path
                # does not require 'model_name_or_path' to be present.
                predictions_path = os.path.join(
                    current_inference['output_dir'], 'generated_predictions.jsonl'
                )
                vllm_args = [
                    f"--model_name_or_path {current_inference['model_name_or_path']}",
                ]
                if inference_type == "lora":
                    # Without this the command never references the checkpoint
                    # and every run would evaluate the same base model.
                    # NOTE(review): assumes scripts/vllm_batch_infer.py accepts
                    # --adapter_name_or_path like LLaMA-Factory's
                    # scripts/vllm_infer.py does -- confirm.
                    vllm_args.append(f"--adapter_name_or_path {checkpoint_path}")
                vllm_args.extend([
                    f"--dataset {vllm_datasets}",
                    f"--save_name {predictions_path}",
                    "--template qwen2_vl",
                    "--is_sampled True",
                    f"--batch_size {batch_size}",
                ])
                inference_command = f"python {vllm_script} " + " ".join(vllm_args)
            else:
                inference_command = f"llamafactory-cli train {yaml_path}"

            # Run in the background, then block on it so checkpoints are
            # processed one after another.
            commands.extend([
                f"{command_prefix} nohup {inference_command} > {log_file} 2>&1 &",
                "wait $!",
                f"echo Epoch {position} of {checkpoint_count} complete",
                # NOTE(review): this only prints the message; nothing is
                # written to completed_checkpoints.txt -- confirm intent.
                f"echo {checkpoint_dir} to {output_folder / 'completed_checkpoints.txt'}",
                "echo '======================================'",
                f'echo "datasets: {datasets_str}"',
            ])

        # Write the shell script.
        script_path = output_folder / "inference_eachepoch.sh"
        script_path.write_text('\n'.join(commands), encoding='utf-8')

        # Tell the user how to run it.
        print(f"Execute script with: bash {script_path}")

        return str(script_path)

    except Exception as e:
        logger.error("Failed to generate inference scripts: %s", e)
        raise
