import argparse
import multiprocessing
import importlib.util
import os
import sys

# Datasets whose CLI name differs from the suffix of the task class exported by
# `<dataset>/run_exp.py` (looked up as `Moa<suffix>`). Hyphenated names need an
# entry because `-` cannot appear in a Python class name.
DATASET_CLASS_SUFFIX = {
    'MMLU-PRO': 'MMLU',
    'human-eval': 'humaneval',
}


def dataset_class_suffix(dataset):
    """Return the suffix used to build the ``Moa<suffix>`` task class name.

    Args:
        dataset: the ``--dataset`` CLI value.

    Returns:
        The mapped suffix from ``DATASET_CLASS_SUFFIX``, or ``dataset`` itself
        when no mapping exists.
    """
    return DATASET_CLASS_SUFFIX.get(dataset, dataset)


def resolve_data_path(dataset, data_path):
    """Resolve a dataset module's ``data_path`` relative to its directory.

    Args:
        dataset: the dataset directory name (the ``--dataset`` value).
        data_path: ``run_exp.data_path``; may be ``None``, absolute, or
            relative to the dataset directory.

    Returns:
        ``None`` and absolute paths unchanged; relative paths joined onto
        ``dataset``.
    """
    if data_path is None or os.path.isabs(data_path):
        return data_path
    return os.path.join(dataset, data_path)


def build_parser():
    """Build the command-line parser for a MoA experiment run."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_tokens", type=int, default=8192)
    parser.add_argument("--agg_max_tokens", type=int, default=32768)
    parser.add_argument("--rm_model", type=str, default="Qwen2.5-Math-PRM-7B")
    parser.add_argument("--model_list", type=str, default="7_large")
    parser.add_argument("--model", type=str, default="Qwen2.5-7B-Instruct")
    parser.add_argument("--sc_posi", type=str, default="agg")
    parser.add_argument("--use_sc", action='store_true')
    parser.add_argument("--use_rm", action='store_true')
    parser.add_argument("--N", type=int, default=8)
    # Original note: "prior_x (remain x)" -- presumably 'all' or 'prior_<x>'; confirm in run_exp.
    parser.add_argument("--ref_sample", type=str, default='all')
    # Original note: expected values are raw_moa or rag_moa_<x>.
    parser.add_argument("--mode", type=str, default='raw_moa')
    parser.add_argument("--ppl_coef", type=float, default=0.0)
    parser.add_argument("--mor_batch", type=int, default=1)
    parser.add_argument("--ref_token_cut", action='store_true')
    parser.add_argument("--ref_clean_think", action='store_true')
    parser.add_argument("--question_bank", type=str, default='8d_32k')
    parser.add_argument(
        "--dataset",
        choices=['MMLU-PRO', 'AIME', 'GPQA', 'IFEval', 'LiveCodeBench',
                 'MATH', 'MBPP', 'MedMCQA', "human-eval"],
        type=str,
        default='MMLU-PRO',
    )
    parser.add_argument("--exp_suffix", type=str, default='')
    parser.add_argument(
        "--rm_qb_path",
        type=str,
        default='/fs-computility/mabasic/songzhende/moa_greedy_search/union_7d_rm_acc_bank_8k.json',
    )
    parser.add_argument(
        "--weight_method",
        type=str,
        default='nav_sum',
        help='nav_sum,linear,softem',
    )
    parser.add_argument(
        "--score_method",
        type=str,
        default='similarity',
        help='similarity,average',
    )
    # NOTE(review): the default (100) lies outside the suggested 0-2 range in the help text; confirm intent.
    parser.add_argument(
        "--temperature",
        type=float,
        default=100,
        help='Any value can work, suggest 0-2',
    )
    parser.add_argument("--max_process", type=int, default=8)
    parser.add_argument("--rtk", type=str, default='r', help='r/nr')
    parser.add_argument(
        "--rm_model_list",
        nargs='+', type=str,
        default=[
            "Qwen2.5-Math-PRM-7B",
            "Qwen2.5-Math-RM-72B",
            "Skywork-27B-Reward-Models",
            "LDL-Reward-Gemma-2-27B",
            "QRM-Gemma-2-27B",
            "Skywork-Reward-V2-Llama-3.1-8B-40M",
            "INF-ORM-Llama3.1-70B",
        ],
    )
    parser.add_argument("--qb_lis", nargs='+', type=str, default=['MATH', 'AIME'])
    return parser


def main(argv=None):
    """Parse CLI arguments, load the dataset's task class, and run it.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)

    # Put the dataset directory first on sys.path; presumably so that
    # `run_exp.py` can import its sibling modules by bare name.
    sys.path.insert(0, f'./{args.dataset}')
    dataset_functions = importlib.import_module(f"{args.dataset}.run_exp")

    data_path = resolve_data_path(args.dataset, dataset_functions.data_path)
    task_class = getattr(dataset_functions, f'Moa{dataset_class_suffix(args.dataset)}')

    task = task_class(
        args.model, args.model_list, dataset_functions.model_List_map, data_path, args.max_tokens,
        mode=args.mode,
        use_sc=args.use_sc,
        use_rm=args.use_rm,
        rm_model=args.rm_model,
        N=args.N,
        max_process=args.max_process,
        sc_posi=args.sc_posi,
        ref_sample=args.ref_sample,
        ppl_coef=args.ppl_coef,
        agg_max_tokens=args.agg_max_tokens,
        ref_token_cut=args.ref_token_cut,
        ref_clean_think=args.ref_clean_think,
        question_bank=args.question_bank,
        dataset=args.dataset,
        exp_suffix=args.exp_suffix,
        rm_qb_path=args.rm_qb_path,
        rm_model_list=args.rm_model_list,
        mor_batch=args.mor_batch,
        weight_method=args.weight_method,
        score_method=args.score_method,
        temperature=args.temperature,
        rtk=args.rtk,
        qb_lis=args.qb_lis,
    )
    task.run()


if __name__ == '__main__':
    # Must run exactly once, before any worker processes are created.
    multiprocessing.set_start_method('spawn')
    main()