

from datasets import load_dataset, get_dataset_config_names


# Per-benchmark settings, keyed by the accepted ``eval_benchmark`` value:
#   (local dataset path, label used in the config-name log line,
#    exclusive upper bound of the test prompt ids, number of train prompt ids)
# NOTE(review): the numeric bounds look like (possibly rounded) split sizes,
# but nothing here verifies them against the loaded data — confirm.
_BENCHMARK_SPECS = {
    'mmluPro': ("/home/peizhengqi/TIGER-Lab/MMLU-Pro", 'MMLUPro', 12000, 70),
    'gpqa-main': ("/home/peizhengqi/Idavidrein/gpqa", 'GPQA', 448, 540),
    'gsm8k': ("/home/peizhengqi/HF_datasets/gsm8k", 'GSM8K', 1300, 7470),
    'math-500': (
        "/home/peizhengqi/HF_datasets/ankner/math-500", 'MATH500', 500, 7500,
    ),
}


def get_dataset(eval_benchmark, TEST_OFFSET, SKIP_STEP):
    """Load every config of a benchmark and build its prompt-id lists.

    Args:
        eval_benchmark: Benchmark name; one of ``'mmluPro'``, ``'gpqa-main'``,
            ``'gsm8k'`` or ``'math-500'``.
        TEST_OFFSET: First test prompt id (start of the test id range).
        SKIP_STEP: Stride between consecutive test prompt ids.

    Returns:
        A 3-tuple ``(datasets_by_config, test_prompt_ids, train_prompt_ids)``:
        a dict mapping each config name to the result of ``load_dataset`` for
        that config, the list ``range(TEST_OFFSET, bound, SKIP_STEP)`` for the
        benchmark's test bound, and the list ``range(0, train_count)``.

    Raises:
        ValueError: If ``eval_benchmark`` is not a supported benchmark name
            (previously this surfaced as an ``UnboundLocalError`` at the
            ``return`` statement). ``range`` also raises ``ValueError`` when
            ``SKIP_STEP`` is 0.
    """
    # Validate before any loading so a typo in the benchmark name fails
    # immediately with a message that lists the valid choices.
    try:
        path, label, test_stop, train_count = _BENCHMARK_SPECS[eval_benchmark]
    except (KeyError, TypeError):
        raise ValueError(
            f"Unknown eval_benchmark {eval_benchmark!r}; "
            f"expected one of {sorted(_BENCHMARK_SPECS)}"
        ) from None

    config_names = get_dataset_config_names(path)
    print(f'{label}_configNames:{config_names}')
    datasets_by_config = {
        config: load_dataset(path, config) for config in config_names
    }

    test_prompt_ids = list(range(TEST_OFFSET, test_stop, SKIP_STEP))
    train_prompt_ids = list(range(train_count))

    return datasets_by_config, test_prompt_ids, train_prompt_ids
