""" Factory for loading different data sets """

import enum
from typing import Any, Dict, Callable

from torch.utils.data import DataLoader

from text2graph.data.pcqm4m_dataset import PCQM4MDataset
from text2graph.data.webnlg2020_dataset import WebNLG2020Dataset


class DatasetFactory(enum.Enum):
    """ Registry of the dataset classes that can be loaded.

        Each member name is the string expected under the 'type' key of a
        dataset config; the member value is the dataset class itself.
        `init_dataloader` looks a member up by name
        (`DatasetFactory[name].value`) and instantiates the resulting class.
    """
    # Dataset class imported from text2graph.data.pcqm4m_dataset.
    chem = PCQM4MDataset
    # Dataset class imported from text2graph.data.webnlg2020_dataset.
    webnlg = WebNLG2020Dataset


def init_dataloader(
    split_name: str,
    dataset_config: Dict[str, Any],
    collate_function: Callable,
    multiprocessing_flag: bool = True,
    shuffle: bool = True
) -> DataLoader:
    """ Builds the dataset selected by `dataset_config` and wraps it in a
        torch `DataLoader`.

        Args:
            split_name: forwarded to the dataset constructor as `split_name`.
            dataset_config: dataset settings. Required keys: 'type' (a
                member name of `DatasetFactory`), 'path', 'name',
                'file_prefix' and 'batch_size'; 'num_workers' is required
                when `multiprocessing_flag` is True. The optional
                'additional_arguments' mapping is forwarded to the dataset
                constructor as extra keyword arguments; a missing or None
                value means no extra arguments.
            collate_function: passed to the `DataLoader` as `collate_fn`.
            multiprocessing_flag: when False, `num_workers` is forced to 0
                regardless of the config.
            shuffle: passed to the `DataLoader` as `shuffle`.

        Returns:
            A `DataLoader` over the newly constructed dataset.

        Raises:
            KeyError: if a required config key is missing, or if
                `dataset_config['type']` is not a `DatasetFactory` member.
    """
    # Read the type outside the try block so that a missing 'type' key is
    # reported as such and not as an unknown dataset type.
    dataset_type = dataset_config['type']
    try:
        dataset_class = DatasetFactory[dataset_type].value
    except KeyError:
        known_types = ', '.join(member.name for member in DatasetFactory)
        # Same exception type as the plain enum lookup, but with a message
        # that tells the user which values are accepted.
        raise KeyError(
            'Unknown dataset type {!r}; expected one of: {}'.format(
                dataset_type, known_types
            )
        ) from None

    # An empty key in a config file typically loads as None; treat that the
    # same as the key being absent instead of failing on `**None`.
    extra_arguments = dataset_config.get('additional_arguments') or {}

    dataset = dataset_class(
        parent_path=dataset_config['path'],
        dataset_name=dataset_config['name'],
        split_name=split_name,
        file_prefix=dataset_config['file_prefix'],
        **extra_arguments
    )

    return DataLoader(
        dataset=dataset,
        collate_fn=collate_function,
        # 'num_workers' is only read from the config when worker processes
        # are actually requested.
        num_workers=dataset_config['num_workers'] if multiprocessing_flag else 0,
        batch_size=dataset_config['batch_size'],
        shuffle=shuffle
    )
