from .classification_dataset import classificationdatasets


def get_dataloader(task, imbalance_type, imbalance_ratio):
    """Return the ``(train_dataset, test_dataset)`` pair for *task*.

    Only the ``'agnews'`` task is handled here; it is delegated to
    ``classificationdatasets`` with all three arguments forwarded unchanged
    as keyword arguments.

    Args:
        task: Name of the task to load. Must be ``'agnews'``.
        imbalance_type: Forwarded as-is to ``classificationdatasets``.
        imbalance_ratio: Forwarded as-is to ``classificationdatasets``.

    Returns:
        A 2-tuple ``(train_dataset, test_dataset)`` unpacked from the result
        of ``classificationdatasets``.
        NOTE(review): despite the function name, nothing here wraps the
        result in a data loader -- confirm what callers expect.

    Raises:
        ValueError: If *task* is not a supported task name.
    """
    if task == 'agnews':
        # Unpack explicitly so a result that is not exactly a pair fails
        # here rather than somewhere in the caller.
        train_dataset, test_dataset = classificationdatasets(
            task=task,
            imbalance_type=imbalance_type,
            imbalance_ratio=imbalance_ratio,
        )
        return train_dataset, test_dataset

    # Previously an unknown task fell through to the return statement with
    # both names unbound, surfacing as an opaque UnboundLocalError.
    raise ValueError(
        f"Unsupported task {task!r}; supported tasks: 'agnews'"
    )
    