
from .tabular import TabularDataset
from .image import CelebADataset

def load_dataset(dataset_name, data_path, balanced, fair_c_ratio=None):
    """Load the dataset identified by ``dataset_name``.

    Args:
        dataset_name: One of ``'adult'``, ``'compas'``, ``'celeba'``,
            ``'titanic'``, ``'sp'`` or ``'credit'``.
        data_path: Passed as ``root`` to ``TabularDataset``. Not used for
            ``'celeba'`` (``CelebADataset`` is constructed without it).
        balanced: Forwarded as ``balanced`` to the dataset constructor.
        fair_c_ratio: Forwarded to ``TabularDataset`` only for ``'adult'``
            and ``'compas'``; ignored for every other dataset.

    Returns:
        A ``TabularDataset`` instance, or a ``CelebADataset`` for ``'celeba'``.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset key.
    """
    # Short key -> (dataset_name expected by TabularDataset,
    #               whether fair_c_ratio is forwarded for that dataset).
    tabular_datasets = {
        'adult': ('Adult', True),
        'compas': ('Compas', True),
        'titanic': ('Titanic', False),
        'sp': ('Student_Performance', False),
        'credit': ('Credit', False),
    }

    if dataset_name == 'celeba':
        return CelebADataset(balanced=balanced)

    # Explicit raise rather than `assert`: asserts are stripped under
    # `python -O`, which would make an unknown name silently return None.
    if dataset_name not in tabular_datasets:
        supported = ', '.join(sorted([*tabular_datasets, 'celeba']))
        raise ValueError(
            f"Unknown dataset {dataset_name!r}; expected one of: {supported}"
        )

    tabular_name, forwards_ratio = tabular_datasets[dataset_name]
    if forwards_ratio:
        return TabularDataset(root=data_path, dataset_name=tabular_name,
                              balanced=balanced, fair_c_ratio=fair_c_ratio)
    # Other tabular datasets are constructed without fair_c_ratio so that
    # TabularDataset's own default applies, as in the original code.
    return TabularDataset(root=data_path, dataset_name=tabular_name,
                          balanced=balanced)

