from typing import Dict, Type

from .base_dataset import BaseDataset
from .humaneval import HumanEvalDataset
from .gsm8k import Gsm8kDataset
from .mbpp import MBPPDataset
from .mmlu import MMLUDataset
from .arc_c import ARCCDataset
from .hellaswag import HellaSwagDataset
from .math import MATHDataset
from .arc_e import ARCEDataset
from .gpqa import GPQADataset


# Registry of the dataset classes imported above, keyed by a short name.
# Keys are written in lowercase because get_dataset() lowercases the
# requested name before looking it up in this mapping.
DATASET_REGISTRY: Dict[str, Type[BaseDataset]] = dict(
    humaneval=HumanEvalDataset,
    gsm8k=Gsm8kDataset,
    mbpp=MBPPDataset,
    mmlu=MMLUDataset,
    arc_c=ARCCDataset,
    arc_e=ARCEDataset,
    hellaswag=HellaSwagDataset,
    math=MATHDataset,
    gpqa=GPQADataset,
)


def get_dataset(name: str) -> Type[BaseDataset]:
    """Return the dataset class registered under *name*.

    The lookup is case-insensitive: *name* is lowercased before it is
    matched against the keys of ``DATASET_REGISTRY``.

    Args:
        name: Registry key of the dataset, in any letter case.

    Returns:
        The dataset class itself (not an instance) stored in the registry.

    Raises:
        KeyError: If the lowercased name has no entry in the registry. The
            message contains the lowercased form of the name.
    """
    key = name.lower()
    dataset_cls = DATASET_REGISTRY.get(key)
    if dataset_cls is None:
        # Report the normalized key, i.e. exactly what was looked up.
        raise KeyError(f"Unknown dataset: {key}")
    return dataset_cls