from .polymer import PolymerRegDataset
from .networkx import GenericGraphFromNetworkx
# from .ogbg import PygGraphPropPredDataset
from ogb.graphproppred import PygGraphPropPredDataset
from .unlabeled import UnlabelGraphDataset, UnlabelPPI

def get_dataset(args, load_path, load_unlabeled_name="None"):
    """Construct the dataset selected by ``args.dataset`` or ``load_unlabeled_name``.

    When ``load_unlabeled_name`` is the string ``"None"`` (the default) or
    ``None``, a labeled dataset is chosen by the prefix of ``args.dataset``:

    * ``plym...`` -> ``PolymerRegDataset``
    * ``ogbg...`` -> ``PygGraphPropPredDataset`` (from ``ogb.graphproppred``)
    * ``nx...``   -> ``GenericGraphFromNetworkx``

    Otherwise an unlabeled dataset is chosen by substring match on
    ``load_unlabeled_name``. In that case ``args.dataset`` is not consulted.

    * contains ``"PPI"`` -> ``UnlabelPPI``
    * contains ``"QM9"`` -> ``UnlabelGraphDataset``

    Args:
        args: Object exposing a string attribute ``dataset``. It is presumably
            the parsed command-line namespace; confirm against the caller.
        load_path: Passed unchanged as the second positional argument to the
            selected dataset constructor.
        load_unlabeled_name: Name of an unlabeled dataset. Pass ``"None"`` or
            ``None`` to load the labeled dataset named by ``args.dataset``.

    Returns:
        The constructed dataset instance.

    Raises:
        ValueError: If the requested labeled or unlabeled dataset is not
            supported.
    """
    # The default sentinel is the *string* "None". Also accept a real None so
    # callers passing it don't hit a TypeError on the `in` checks below.
    if load_unlabeled_name is None or load_unlabeled_name == 'None':
        name = args.dataset
        if name.startswith('plym'):
            return PolymerRegDataset(name, load_path)
        if name.startswith('ogbg'):
            return PygGraphPropPredDataset(name, load_path)
        if name.startswith('nx'):
            return GenericGraphFromNetworkx(name, load_path)
        # Without this, an unrecognized name would silently return None and
        # the failure would surface later at an unrelated call site.
        raise ValueError('Dataset {} not supported'.format(name))
    if 'PPI' in load_unlabeled_name:
        return UnlabelPPI(load_unlabeled_name, load_path)
    if 'QM9' in load_unlabeled_name:
        return UnlabelGraphDataset(load_unlabeled_name, load_path)
    raise ValueError('Unlabeled dataset {} not supported'.format(load_unlabeled_name))
