from bbbc021datamodule import BBBC021DataModule
from dataset import CellDataModule
from rohbandatamodule import RohbanDataModule


def setup_data(cfg):
    """Create the datamodule selected by ``cfg.dataset`` and its training loader.

    Supported names are ``"bbbc021"`` (BBBC021DataModule), ``"rxrx1"``
    (CellDataModule) and ``"rohban"`` (RohbanDataModule). The chosen class is
    constructed with the full ``cfg`` object.

    Args:
        cfg: Configuration object exposing a ``dataset`` attribute. It is
            forwarded unchanged to the datamodule constructor.

    Returns:
        A ``(datamodule, train_dataloader)`` tuple. The loader is whatever the
        datamodule's ``get_train_loader()`` returns.

    Raises:
        ValueError: If ``cfg.dataset`` is not one of the supported names.
    """
    dataset_name = cfg.dataset

    # Choose the class first, then construct it in a single place.
    if dataset_name == "bbbc021":
        module_cls = BBBC021DataModule
    elif dataset_name == "rxrx1":
        module_cls = CellDataModule
    elif dataset_name == "rohban":
        module_cls = RohbanDataModule
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    datamodule = module_cls(cfg)
    return datamodule, datamodule.get_train_loader()
