"""
Data API
"""
import logging

from datasets import load_dataset, load_from_disk, concatenate_datasets
from prompts import gsm8k_prompt, asdiv_aug_prompt, math_500_prompt, aime_prompt

def _load_local_or_hub(data_name_or_path, local_split, hub_loader):
    """
    Load a dataset saved with ``save_to_disk``, falling back to a Hub loader.

    Args:
        data_name_or_path: path tried first with ``load_from_disk``
        local_split: split to select from the locally loaded object, or None
            to use the loaded object as-is
        hub_loader: zero-argument callable that returns the dataset when the
            local load (or the split lookup) fails

    Returns:
        dataset: the loaded dataset
    """
    try:
        dataset = load_from_disk(data_name_or_path)
        # The split lookup stays inside the try so a local copy that lacks the
        # expected split also falls back to the Hub.
        return dataset if local_split is None else dataset[local_split]
    except Exception as err:  # not bare: KeyboardInterrupt/SystemExit must propagate
        logging.getLogger(__name__).info(
            "Could not load %r from disk (%s); falling back to the Hub.",
            data_name_or_path,
            err,
        )
        return hub_loader()


def get_dataset(data_name_or_path, tokenizer, prompt_idx):
    """
    Load an evaluation dataset and format each question as a chat prompt.

    The dataset is chosen by a case-insensitive substring match on
    ``data_name_or_path`` ("gsm8k", "asdiv-aug", "math-500", "aime_2024",
    "aime2025"). A local ``load_from_disk`` copy is tried first; if that
    fails, the corresponding Hugging Face Hub dataset is downloaded.

    Args:
        data_name_or_path: dataset name or local path
        tokenizer: tokenizer whose ``apply_chat_template`` renders the prompts
        prompt_idx: which query prompt to use (forwarded to the prompt builder)

    Returns:
        dataset: the dataset with "prompt" (chat messages from the prompt
            builder), "formatted" (rendered chat-template string), "question"
            and "answer" columns added or overwritten

    Raises:
        ValueError: if ``data_name_or_path`` matches no supported dataset
    """
    name = data_name_or_path.lower()

    ### Load dataset ###
    # Each branch selects the loader, the source column names and the prompt
    # builder together, so loading and formatting cannot disagree.
    strip_degree = False
    if "gsm8k" in name:
        dataset = _load_local_or_hub(
            data_name_or_path,
            "test",
            lambda: load_dataset("openai/gsm8k", "socratic")["test"],
        )
        question_col, answer_col = "question", "answer"
        prompt_fn = gsm8k_prompt

    elif "asdiv-aug" in name:
        dataset = _load_local_or_hub(
            data_name_or_path,
            "test",
            lambda: load_dataset("xuyige/ASDiv-Aug")["test"],
        )
        question_col, answer_col = "question", "answer"
        prompt_fn = asdiv_aug_prompt

    elif "math-500" in name:
        dataset = _load_local_or_hub(
            data_name_or_path,
            "test",
            lambda: load_dataset("HuggingFaceH4/MATH-500")["test"],
        )
        question_col, answer_col = "problem", "answer"
        prompt_fn = math_500_prompt

    elif "aime_2024" in name:
        dataset = _load_local_or_hub(
            data_name_or_path,
            None,
            lambda: load_dataset("Maxwell-Jia/AIME_2024")["train"],
        )
        question_col, answer_col = "Problem", "Answer"
        prompt_fn = aime_prompt

    elif "aime2025" in name:
        dataset = _load_local_or_hub(
            data_name_or_path,
            None,
            lambda: concatenate_datasets([
                load_dataset("opencompass/AIME2025", "AIME2025-I")["test"],
                load_dataset("opencompass/AIME2025", "AIME2025-II")["test"],
            ]),
        )
        question_col, answer_col = "question", "answer"
        prompt_fn = aime_prompt
        # Answers here are stripped of the LaTeX degree marker "^\circ";
        # presumably so they compare as plain numbers — confirm with the scorer.
        strip_degree = True

    else:
        raise ValueError(f"Unsupported dataset: {data_name_or_path}")

    def preprocess_function(examples):
        """
        Build chat prompts for one batch of examples.

        Args:
            examples: a batch (column name -> list of values) from ``Dataset.map``

        Returns:
            formatted: dict with "prompt", "formatted", "question", "answer" lists
        """
        questions = examples[question_col]
        answers = examples[answer_col]
        prompt = [prompt_fn(q, prompt_idx) for q in questions]
        formatted = [
            tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            for messages in prompt
        ]
        if strip_degree:
            # Raw string: "\c" is not a valid escape in a normal string literal.
            answers = [ans.replace(r"^\circ", "") for ans in answers]
        return {
            "prompt": prompt,
            "formatted": formatted,
            "question": questions,
            "answer": answers,
        }

    # load_from_cache_file=False: presumably so prompt/template edits are
    # never served from a stale map cache.
    return dataset.map(preprocess_function, batched=True, load_from_cache_file=False)
