from evaluator.lang_evaluator import *
from evaluator.image_evaluator import *
from evaluator.audio_evaluator import *
from evaluator.video_evaluator import *
from core.router import Router
import argparse


class _HelpFormatter(argparse.RawTextHelpFormatter,
                     argparse.ArgumentDefaultsHelpFormatter):
    """Help formatter that keeps explicit newlines and still shows defaults.

    Several options below describe their choices as small multi-line tables.
    ``ArgumentDefaultsHelpFormatter`` alone re-wraps help text and collapses
    those line breaks; mixing in ``RawTextHelpFormatter`` preserves them.
    """


def parse_args(argv=None):
    """Build the evaluation CLI parser and parse the command line.

    Args:
        argv: Optional sequence of argument strings. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, which matches the
            behavior of calling this function with no arguments.

    Returns:
        argparse.Namespace holding every option defined below.

    Raises:
        SystemExit: raised by argparse when required options are missing,
            a value is invalid, or ``--help`` is requested.
    """
    # Shared by --modality and --benchmark_modality so the two stay in sync.
    modality_choices = ["language", "image", "audio", "video", "omni"]

    parser = argparse.ArgumentParser(
        description="🚀 GRM-Omni Evaluation",
        formatter_class=_HelpFormatter
    )

    # ========== Input / output ==========
    io_group = parser.add_argument_group("Input / Output")
    io_group.add_argument("--output_dir", type=str, required=True,
                          help="Directory to save inference results")

    # ========== Evaluation mode ==========
    mode_group = parser.add_argument_group("Evaluation Mode")
    mode_group.add_argument(
        "--benchmark",
        type=str,
        choices=[
            "rewardbench", "rewardbench_v2", "vl_rewardbench",
            "multimodal_rewardbench", "genaibench_image",
            "genai_video", "videogen_bench", "audiobench"
        ],
        default="rewardbench",
        help="Select benchmark dataset"
    )
    mode_group.add_argument(
        "--benchmark_dir",
        type=str,
        help="Path to custom benchmark data directory"
    )
    mode_group.add_argument("--suffix", type=str, default="",
                            help="Optional suffix to append to output files")

    mode_group.add_argument(
        "--modality",
        type=str,
        choices=modality_choices,
        default="language",
        help=(
            "Modality for evaluation:\n"
            "  language : text-only\n"
            "  image    : image-only\n"
            "  audio    : audio-only\n"
            "  video    : video-only\n"
            "  omni     : omni-modal (multi-modal)"
        )
    )
    mode_group.add_argument(
        "--mode",
        type=str,
        choices=["preference", "criteria_meta_reward", "meta_reward"],
        default="preference",
        help=(
            "Evaluation mode:\n"
            "  preference          : directly generate SFT output\n"
            "  criteria_meta_reward: evaluate with criteria-based reward\n"
            "  meta_reward         : evaluate with meta reward"
        )
    )
    # Registered last in this group so the option order shown by --help is
    # unchanged from before.
    mode_group.add_argument(
        "--benchmark_modality",
        type=str,
        choices=modality_choices,
        default="language",
        help=(
            "Benchmark Modality for evaluation:\n"
            "  language : text-only\n"
            "  image    : image-only\n"
            "  audio    : audio-only\n"
            "  video    : video-only\n"
            "  omni     : omni-modal (multi-modal)"
        )
    )

    # ========== Model paths ==========
    model_group = parser.add_argument_group("Model Paths")
    model_group.add_argument("--evaluation_model", type=str, required=True,
                             help="Path to evaluation model checkpoint")

    # ========== Inference settings ==========
    infer_group = parser.add_argument_group("Inference Settings")
    infer_group.add_argument("--batch_size", type=int, default=32,
                             help="Batch size per GPU for inference")
    infer_group.add_argument("--workers", type=int, default=8,
                             help="Number of Ray workers for parallel inference")
    infer_group.add_argument("--max_tokens", type=int, default=8192,
                             help="Maximum tokens to generate per response")

    # ========== Sampling parameters ==========
    sample_group = parser.add_argument_group("Sampling Parameters")
    sample_group.add_argument("--temperature", type=float, default=0.7,
                              help="Sampling temperature (higher = more random)")
    sample_group.add_argument("--top_p", type=float, default=0.8,
                              help="Top-p nucleus sampling threshold")
    sample_group.add_argument("--sampling_n", type=int, default=1,
                              help="Number of responses to sample; if >1, scaling is applied")
    sample_group.add_argument(
        "--scaling",
        type=str,
        choices=["ranking", "voting"],
        default="ranking",
        help="Scaling strategy when multiple samples are generated:\n"
             "  ranking : select best response by reward model\n"
             "  voting  : majority vote among candidates"
    )

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)


if __name__ == '__main__':
    # Script entry point: parse CLI options, pick the evaluator that matches
    # the benchmark's modality, and run it against a Router built from the
    # same options.
    args = parse_args()

    # Factories are wrapped in lambdas so an evaluator class name is only
    # resolved (and the evaluator only constructed) for the selected modality.
    evaluator_factories = {
        "language": lambda: LanguageEvaluator(args=args),
        "image": lambda: ImageEvaluator(args=args),
        "video": lambda: VideoEvaluator(args=args),
        "audio": lambda: AudioEvaluator(args=args),
    }

    # Any modality without a factory (e.g. "omni", which the parser accepts)
    # is rejected here before a Router is created.
    if args.benchmark_modality not in evaluator_factories:
        raise Exception("Not Support This Benchmark Modality = {}".format(args.benchmark_modality))

    build_evaluator = evaluator_factories[args.benchmark_modality]
    evaluator = build_evaluator()

    router = Router(args=args)
    evaluator.eval(router)
