from generator.criteria_meta_reward import *
from generator.rule_filter import *
from generator.meta_reward import *
from generator.exploration import *
from evaluator.pairwise_evaluator import *

from core.router import *
import argparse
import torch
import random
import os
import numpy as np


def parse_args(argv=None):
    """Build the GRM-Omni inference command-line interface and parse it.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, which matches
            normal command-line use; passing an explicit list lets tests and
            other Python callers drive the parser directly.

    Returns:
        argparse.Namespace: the parsed options.

    Raises:
        SystemExit: raised by argparse when a required option
            (``--output_dir``, ``--inference_model``) is missing, a value is
            outside its ``choices``, or the dependency check below fails.
    """
    parser = argparse.ArgumentParser(
        description="🚀 GRM-Omni Inference",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # ========== 📂 Input / Output ==========
    io_group = parser.add_argument_group("📂 Input / Output")
    io_group.add_argument("--input_file", type=str,
                          help="input file")
    io_group.add_argument("--dpo_pool_file", type=str, default=None,
                          help="the data file for dpo data curation.")
    io_group.add_argument("--output_dir", type=str, required=True,
                          help="Directory to store inference results")
    io_group.add_argument("--max_input_size", type=int, default=-1,
                          help="the max input size for inference. set -1 means infinite!")

    # ========== Benchmark ==========
    bench_group = parser.add_argument_group("Evaluation Benchmark")
    bench_group.add_argument(
        "--benchmark",
        type=str,
        choices=[
            "rewardbench", "rewardbench_v2", "rmb", "vl_rewardbench",
            "multimodal_rewardbench", "genai_image",
            "genai_video", "videogen_bench", "audio_bench", "ppe_bench"
        ],
        default="rewardbench",
        help="Select benchmark dataset"
    )
    bench_group.add_argument(
        "--benchmark_dir",
        type=str,
        help="Path to custom benchmark data directory"
    )
    bench_group.add_argument("--suffix", type=str, default="",
                             help="Optional suffix to append to output files")
    bench_group.add_argument(
        "--modality",
        type=str,
        choices=["language", "image", "audio", "video", "omni"],
        default="language",
        help=(
            "Modality for evaluation:\n"
            "  language : text-only\n"
            "  image    : image-only\n"
            "  audio    : audio-only\n"
            "  video    : video-only\n"
            "  omni     : omni-modal (multi-modal)"
        )
    )

    # ========== ⚙️ Inference mode ==========
    method_group = parser.add_argument_group("⚙️ Inference Mode")
    # NOTE(review): --method has no default and is not required, so
    # args.method is None when the flag is omitted.
    method_group.add_argument(
        "--method",
        type=str,
        choices=["criteria_meta_reward", "meta_reward", "rule_filter", "pairwise_judge", "pointwise_judge", "exploration", "criteria_exploration"],
        help=(
            "Choose inference method\n"
        )
    )
    method_group.add_argument(
        "--manner",
        type=str,
        choices=["direct", "stepwise", "criteria_n"]
    )

    # ========== 🤖 Model paths ==========
    model_group = parser.add_argument_group("🤖 Model Paths")
    model_group.add_argument("--ranking_model", type=str,
                             help="Path to the model used for scoring")
    model_group.add_argument(
        "--ranking_model_modality",
        type=str,
        choices=["language", "image", "audio", "video", "omni", "vision"],
        default="language",
        help=(
            "Choose the scoring model modality."
        )
    )
    model_group.add_argument("--refinement_model", type=str,
                             help="Path to the model used for refine response")
    model_group.add_argument(
        "--refinement_model_modality",
        type=str,
        choices=["language", "image", "audio", "video", "omni", "vision"],
        default="language",
        help=(
            "Choose the refinement model modality."
        )
    )
    model_group.add_argument("--inference_model", type=str, required=True,
                             help="Path to the model used for inference")
    model_group.add_argument(
        "--inference_model_modality",
        type=str,
        choices=["language", "image", "audio", "video", "omni", "vision"],
        default="language",
        help=(
            "Choose the inference model modality."
        )
    )

    # ========== 🚀 Inference settings ==========
    infer_group = parser.add_argument_group("🚀 Inference Settings")
    infer_group.add_argument("--workers", type=int, default=8,
                             help="Number of worker nodes for Ray-based parallel inference")
    infer_group.add_argument("--tensor_parallel", type=int, default=1,
                             help="Number of tensor for parallel inference")
    infer_group.add_argument("--seed", type=int, default=0)

    # ========== 🔧 Sampling parameters ==========
    sample_group = parser.add_argument_group("🔧 Sampling Parameters")
    sample_group.add_argument("--batch_size", type=int, default=256)
    sample_group.add_argument("--temperature", type=float, default=0.7,
                              help="Sampling temperature")
    sample_group.add_argument("--top_p", type=float, default=0.8,
                              help="Top-p sampling parameter")
    sample_group.add_argument("--criteria_n", type=int, default=10)
    sample_group.add_argument("--criteria_step", type=int, default=3)
    sample_group.add_argument("--sampling_n", type=int, default=1)

    # argv=None makes argparse read sys.argv[1:], the original behaviour.
    args = parser.parse_args(argv)

    # Dependency check: the scoring model must be supplied for the
    # "refinement" method.
    # NOTE(review): "refinement" is not one of the --method choices above, so
    # argparse rejects it before this line and the check can never trigger.
    # Confirm which methods actually need --ranking_model and update the list.
    if args.method in ["refinement"] and not args.ranking_model:
        parser.error("--ranking_model is required when mode is refinement.")

    return args

def set_seed(seed: int=42):
    """Seed every random-number source used by this script.

    Covers Python's ``random`` module, NumPy's global RNG, and torch's CPU
    and CUDA generators (current device and all devices). It also exports
    ``PYTHONHASHSEED`` into the process environment and puts cuDNN into
    deterministic, non-benchmarking mode. The chosen seed is printed.

    Args:
        seed: Integer seed applied to every generator (default 42).
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    print(f"Random seed set as {seed}")


if __name__ == '__main__':
    # Parse the CLI and seed all RNGs before any generator/model is built.
    args = parse_args()
    set_seed(args.seed)

    # Dispatch --method to its implementation. "exploration" and
    # "criteria_exploration" share ExplorationSampler; presumably the sampler
    # inspects args.method itself to tell them apart — TODO confirm.
    if args.method == "criteria_meta_reward":
        generator = CriteriaMetaRewardGenerator(args=args)
    elif args.method == "meta_reward":
        generator = MetaRewardGenerator(args=args)
    elif args.method == "rule_filter":
        generator = RuleFilter(args=args)
    elif args.method in ("exploration", "criteria_exploration"):
        generator = ExplorationSampler(args=args)
    elif args.method == "pairwise_judge":
        generator = PairwiseEvaluator(args=args)
    else:
        # Reached for "pointwise_judge" (accepted by the parser but not wired
        # up here) and when --method is omitted (args.method is None).
        # parse_args defines `method`, not `mode`, so report args.method.
        raise ValueError("Not Support This Mode = {}".format(args.method))

    router = Router(args=args)
    generator.run(router)
