"""
An interface to register the datasets.

The samples of each dataset within the trlm are made in a consistent format.

Note that at the end of each sample's `question`, we append the solution flag prompt defined in `trlm/identifier.py`. This prompts the model to wrap the final
solution in a specific identifier so that it can be extracted easily.
"""

import logging
from trlm.dataset import (
    gsm8k,
    math,
    aime2024,
    aime19832024,
    theoremqa,
    mmmu,
    scienceqa,
    humaneval,
    codealpaca,
)

# Registry of dataset classes, keyed by the lower-case dataset name that
# `get()` derives from `config["data_name"]`.
data_factory = dict(
    gsm8k=gsm8k.GSM8KDataset,
    math=math.MATHDataset,
    mmmu=mmmu.MMMUDataset,
    scienceqa=scienceqa.ScienceQADataset,
    aime2024=aime2024.AIME2024Dataset,
    aime19832024=aime19832024.AIME19832024Dataset,
    humaneval=humaneval.HumanEvalDataset,
    codealpaca=codealpaca.CodeAlpacaDataset,
    theoremqa=theoremqa.TheoremQADataset,
)


# Hugging Face dataset identifier for each supported dataset name; `get()`
# forwards the matching value to the dataset class as `hf_dataname`.
hf_datasets = dict(
    gsm8k="openai/gsm8k",
    math="DigitalLearningGmbH/MATH-lighteval",
    mmmu="lmms-lab/MMMU",
    scienceqa="lmms-lab/ScienceQA",
    aime2024="HuggingFaceH4/aime_2024",
    aime19832024="di-zhang-fdu/AIME_1983_2024",
    humaneval="openai/openai_humaneval",
    codealpaca="sahil2801/CodeAlpaca-20k",
    theoremqa="TIGER-Lab/TheoremQA",
)


def get(config: dict, split: str = "train"):
    """Build the dataset selected by ``config["data_name"]``.

    The name is lower-cased and looked up in the module-level
    ``hf_datasets`` and ``data_factory`` registries. The registered class is
    then instantiated with the keyword arguments ``split``, ``hf_dataname``
    and ``config``.

    Args:
        config: Configuration mapping. Must contain the key ``"data_name"``;
            the whole mapping is also forwarded to the dataset class.
        split: Name of the split to request; forwarded unchanged to the
            dataset class.

    Returns:
        The instance created by the registered dataset class. It must
        support ``len()``, which is used to log the number of samples.

    Raises:
        KeyError: If ``config`` has no ``"data_name"`` entry, or if the
            requested name is missing from either registry.
    """
    data_name = config["data_name"].lower()

    # Resolve both registries before logging or constructing anything so an
    # unsupported name fails early with an actionable message.
    try:
        hf_dataname = hf_datasets[data_name]
        dataset_cls = data_factory[data_name]
    except KeyError:
        supported = ", ".join(sorted(hf_datasets.keys() & data_factory.keys()))
        # Still a KeyError, so callers catching the original exception type
        # keep working; `from None` hides the uninformative inner lookup error.
        raise KeyError(
            f"Unknown dataset '{data_name}'. Supported datasets: {supported}"
        ) from None

    logging.info(
        "---> Loading %s data from %s dataset linked to HF %s",
        split,
        data_name,
        hf_dataname,
    )
    dataset = dataset_cls(split=split, hf_dataname=hf_dataname, config=config)
    logging.info("   - Obtained %s samples", len(dataset))

    return dataset
