from .imagenet import ImageNetDataset, ImageNetOntology, ImageNetAnnotations
from .broden import BrodenDataset, BrodenOntology, BrodenAnnotations
import os


def get_dataset(path, style='imagenet'):
    """
    Loads a dataset together with its ontology and annotations.

    Parameters
    ----------
    path: str
        The main directory of the dataset to load. The
        ontology is built from the parent directory of
        ``path``; a trailing path separator is ignored.
    style: str
        What kind of dataset to expect. One of
        ``'imagenet'``, ``'broden'`` or ``'broden_vanilla'``
        (the Broden dataset with ``vanilla_nd=True`` passed
        to the ontology).

    Returns
    -------
    images
        The image dataset built from ``path``
        (``ImageNetDataset`` or ``BrodenDataset``).
    ontology
        The ontology built from the parent directory of
        ``path`` (``ImageNetOntology`` or ``BrodenOntology``).
    annotations
        The annotations built from ``path``
        (``ImageNetAnnotations`` or ``BrodenAnnotations``).

    The three objects are returned as a tuple
    ``(images, ontology, annotations)``.

    Raises
    ------
    NotImplementedError
        If the dataset style has not been recognized.
    """
    supported_styles = ('imagenet', 'broden', 'broden_vanilla')
    # Validate first so an unknown style fails before any path handling.
    if style not in supported_styles:
        raise NotImplementedError(
            'Unknown dataset style {!r}; expected one of: {}'.format(
                style, ', '.join(supported_styles)))

    # normpath strips a trailing separator; without it dirname('a/b/')
    # would return 'a/b' (the dataset itself) instead of its parent 'a'.
    data_directory = os.path.dirname(os.path.normpath(path))

    if style == 'imagenet':
        ontology = ImageNetOntology(data_directory)
        images = ImageNetDataset(path)
        annotations = ImageNetAnnotations(path)
        return images, ontology, annotations

    # Both remaining styles are Broden; they differ only in the ontology flag.
    images = BrodenDataset(path)
    if style == 'broden_vanilla':
        ontology = BrodenOntology(data_directory, vanilla_nd=True)
    else:
        ontology = BrodenOntology(data_directory)
    annotations = BrodenAnnotations(path)
    return images, ontology, annotations
