from src.data.config import DatasetCfg, DataLoaderCfg


def get_datamodule(
    dataset_cfg: DatasetCfg,
    train_loader_cfg: DataLoaderCfg,
    val_loader_cfg: DataLoaderCfg,
    test_loader_cfg: DataLoaderCfg,
):
    """Build the datamodule selected by ``dataset_cfg.name``.

    Supported names are ``"scannet"`` (-> ``ScanNetDataModule``) and
    ``"scanrefer"`` (-> ``ScanReferDataModule``). The datamodule class is
    imported inside its branch, so only the selected dataset's module is
    loaded by this call.

    Args:
        dataset_cfg: Dataset configuration. Its ``name`` attribute picks the
            datamodule, and the whole object is forwarded as ``dataset_cfg``.
        train_loader_cfg: Forwarded unchanged as ``train_loader_cfg``.
        val_loader_cfg: Forwarded unchanged as ``val_loader_cfg``.
        test_loader_cfg: Forwarded unchanged as ``test_loader_cfg``.

    Returns:
        A new instance of the selected datamodule class.

    Raises:
        NotImplementedError: If ``dataset_cfg.name`` is not a supported name.
    """
    requested = dataset_cfg.name

    # Imports are deferred to the matching branch. Presumably this avoids
    # pulling in every dataset's dependencies up front (TODO confirm).
    if requested == "scannet":
        from src.data.datamodules.scannet_datamodule import (
            ScanNetDataModule as datamodule_cls,
        )
    elif requested == "scanrefer":
        from src.data.datamodules.scanrefer_datamodule import (
            ScanReferDataModule as datamodule_cls,
        )
    else:
        raise NotImplementedError(
            f"Dataset {dataset_cfg.name} not implemented. Please implement it in src/data/datamodules."
        )

    # Every supported datamodule takes the same four keyword arguments.
    return datamodule_cls(
        train_loader_cfg=train_loader_cfg,
        val_loader_cfg=val_loader_cfg,
        test_loader_cfg=test_loader_cfg,
        dataset_cfg=dataset_cfg,
    )
