import datasets



def get_dataset(name):
    """Load a benchmark dataset reduced to the columns "problem" and "answer".

    Args:
        name: Dataset identifier. Supported values are "aime-24" and
            "math-500".

    Returns:
        The object returned by ``datasets.load_dataset`` for the selected
        split, with source columns renamed where needed and every column
        other than "problem" and "answer" removed.

    Raises:
        ValueError: If ``name`` is not one of the supported identifiers.
            (Falling through silently would hand the caller ``None`` and
            defer the failure to some unrelated later line.)
    """
    # name -> (hub repository id, split to load, {source column: target column})
    sources = {
        "aime-24": (
            "Maxwell-Jia/AIME_2024",
            "train",
            {"Problem": "problem", "Answer": "answer"},
        ),
        "math-500": ("HuggingFaceH4/MATH-500", "test", {}),
    }
    if name not in sources:
        supported = ", ".join(repr(key) for key in sorted(sources))
        raise ValueError(f"Unknown dataset name {name!r}; expected one of: {supported}")

    repo_id, split, renames = sources[name]
    dataset = datasets.load_dataset(repo_id, split=split)

    # Normalize column names first so the pruning step below can rely on them.
    for source_col, target_col in renames.items():
        dataset = dataset.rename_column(source_col, target_col)

    keep = ("problem", "answer")
    extra_columns = [col for col in dataset.column_names if col not in keep]
    return dataset.remove_columns(extra_columns)


if __name__ == "__main__":
    # Manual check: load one dataset, then drop into the debugger so the
    # resulting `data` object can be inspected interactively. Swap the active
    # line with one of the commented alternatives to inspect another dataset.
    # data = get_dataset("math-500")
    # NOTE(review): "our-numina" is not among the names get_dataset handles in
    # the code visible here ("aime-24", "math-500") — confirm whether a branch
    # is missing or this name is stale.
    data = get_dataset("our-numina")
    # data = get_dataset("aime-24")
    breakpoint()