from .mnist_usps import MNIST_USPS_Dataset
from .mnist_invert import MNIST_INVERT_Dataset
from .har import HAR_Dataset
from .adult import ADULT_Dataset
from .bank import BANK_Dataset
from .credit import CREDIT_Dataset

class UnknownDatasetError(ValueError, AssertionError):
    """Raised by ``load_dataset`` for an unsupported ``dataset_name``.

    Subclasses ``ValueError`` (the idiomatic type for a bad argument value)
    and ``AssertionError`` (what the previous ``assert``-based check raised),
    so existing ``except AssertionError`` handlers keep working.
    """


def load_dataset(dataset_name, data_path):
    """Load the dataset registered under ``dataset_name``.

    Args:
        dataset_name: One of ``'credit'``, ``'adult'``, ``'har'``, ``'bank'``,
            ``'mnist_usps'`` or ``'mnist_invert'``.
        data_path: Passed unchanged as the ``root`` keyword argument to the
            selected dataset class's constructor.

    Returns:
        A new instance of the dataset class matching ``dataset_name``.

    Raises:
        UnknownDatasetError: If ``dataset_name`` is not one of the names above.
            Unlike the former ``assert``, this check is not stripped when
            Python runs with ``-O``, so an unknown name can never silently
            yield ``None``.
    """
    # Single source of truth for the supported names. Built inside the
    # function so the class names are resolved at call time.
    dataset_classes = {
        'credit': CREDIT_Dataset,
        'adult': ADULT_Dataset,
        'har': HAR_Dataset,
        'bank': BANK_Dataset,
        'mnist_usps': MNIST_USPS_Dataset,
        'mnist_invert': MNIST_INVERT_Dataset,
    }
    try:
        dataset_class = dataset_classes[dataset_name]
    except (KeyError, TypeError):
        # TypeError covers unhashable names (e.g. a list). The original
        # membership test rejected those too instead of crashing differently.
        raise UnknownDatasetError(
            f"Unknown dataset {dataset_name!r}; expected one of: "
            f"{', '.join(dataset_classes)}."
        ) from None
    return dataset_class(root=data_path)
