from src.dataset.base_dataset import BaseDataset
from src.dataset.peliqan_dataset import PELIQANDataset


# Public API of this package: the factory plus the common dataset base class.
__all__ = ["dataset_factory", "BaseDataset"]


def dataset_factory(dataset_name: str, data_path: str) -> BaseDataset:
    """Instantiate the dataset registered under *dataset_name*.

    Args:
        dataset_name: Identifier of the dataset to build. Only ``"peliqan"``
            is currently recognised (exact, case-sensitive match).
        data_path: Location of the dataset's data; forwarded to the
            ``PELIQANDataset`` constructor as ``qa_data_path``.

    Returns:
        The constructed dataset (annotated as ``BaseDataset``).

    Raises:
        ValueError: If *dataset_name* is not a recognised dataset identifier.
    """
    # Reject unrecognised names up front so the happy path reads linearly.
    if dataset_name != "peliqan":
        raise ValueError(f"Unknown dataset_name: {dataset_name}")
    return PELIQANDataset(qa_data_path=data_path)
