from dataset.MUTAG_dataset import Mutagenicity
from dataset.BA3_dataset import BA3Motif
from dataset.fluoride_carbonyl_dataset import FluorideCarbonyl
from dataset.mnistsp_dataset import MNIST75sp

def get_dataset(data_root_path, dataset):
    """Build the training, evaluation and testing splits of a named dataset.

    Args:
        data_root_path: Forwarded unchanged as the first positional argument
            to the selected dataset class.
        dataset: Dataset name; one of ``'BA3'``, ``'MUTAG'``, ``'FC'`` or
            ``'MNIST'`` (case-sensitive).

    Returns:
        A tuple ``(train_dataset, val_dataset, test_dataset, num_cls)`` where
        the three datasets are instances of the selected class built with
        ``mode="training"``, ``mode="evaluation"`` and ``mode="testing"``
        respectively, and ``num_cls`` is the class count associated with
        that dataset name.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Class count for every supported dataset name.
    # NOTE(review): 'MNIST' is mapped to 2 classes exactly as before; if
    # MNIST75sp carries all ten digit labels this should probably be 10 --
    # confirm against the dataset / model code.
    num_cls_by_name = {'BA3': 3, 'MUTAG': 2, 'FC': 2, 'MNIST': 2}

    # Fail fast with a clear message instead of hitting an UnboundLocalError
    # at the return statement when the name matches no branch.
    if dataset not in num_cls_by_name:
        raise ValueError(
            "Unknown dataset {!r}; expected one of {}".format(
                dataset, sorted(num_cls_by_name)
            )
        )

    # Pick the dataset class and any constructor keywords beyond ``mode``.
    if dataset == 'BA3':
        dataset_cls, extra_kwargs = BA3Motif, {}
    elif dataset == 'MUTAG':
        dataset_cls, extra_kwargs = Mutagenicity, {"target": "explainer"}
    elif dataset == 'FC':
        dataset_cls, extra_kwargs = FluorideCarbonyl, {}
    else:  # 'MNIST' -- the only name left after validation above
        dataset_cls, extra_kwargs = MNIST75sp, {}

    # Same construction order as before: training, evaluation, testing.
    train_dataset, val_dataset, test_dataset = [
        dataset_cls(data_root_path, **extra_kwargs, mode=mode)
        for mode in ("training", "evaluation", "testing")
    ]
    num_cls = num_cls_by_name[dataset]

    return train_dataset, val_dataset, test_dataset, num_cls