
from .tabular import TabularDataset

def load_dataset(dataset_name, data_path, missing_rate, split=None, mechanism='mcar'):
    """Load one of the supported tabular datasets.

    Args:
        dataset_name: Lower-case dataset key; one of ``'adult'`` or ``'kdd'``.
        data_path: Directory forwarded to ``TabularDataset`` as ``root``.
        missing_rate: Forwarded unchanged to ``TabularDataset``.
        split: Forwarded unchanged to ``TabularDataset`` (default ``None``).
        mechanism: Forwarded unchanged to ``TabularDataset`` (default ``'mcar'``).

    Returns:
        The ``TabularDataset`` built for the requested dataset.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset key.
    """
    # Public lower-case key -> dataset name expected by TabularDataset.
    # Lookup uses equality (dict hashing), unlike the identity test `is`,
    # so strings built at runtime (CLI args, config files) match correctly.
    implemented_datasets = {'adult': 'Adult', 'kdd': 'KDD'}

    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which would let unknown names fall through silently.
    if dataset_name not in implemented_datasets:
        raise ValueError(
            f"Unknown dataset {dataset_name!r}; "
            f"expected one of {sorted(implemented_datasets)}"
        )

    return TabularDataset(
        root=data_path,
        dataset_name=implemented_datasets[dataset_name],
        missing_rate=missing_rate,
        split=split,
        mechanism=mechanism,
    )

