"""
Returns a dataset and DataLoader for the requested dataset.
"""

from datasets.wilds_dataset import get_wilds_dataloader
from datasets.domainnet_dataset import get_domainnet_dataloader
from torchvision import transforms

def get_dataloader(
    dataset_name: str = None,
    trial_num: int = None,
    path_to_dataset_dir: str = None,
    path_to_annotations_dir: str = None,
    domains: list[str] = None,
    classes: list[str] = None,
    transform: transforms.Compose = None,
    split_type: str = "train",
    batch_size: int = 32,
    num_workers: int = 4,
    pin_memory: bool = True,
    is_training: bool = True,
    create_val_from_train: bool = False,
):
    """Build the dataset/DataLoader pair by delegating to a backend factory.

    Dispatch is an exact string comparison: ``dataset_name == "domainnet"``
    selects ``get_domainnet_dataloader``; every other value (including the
    default ``None``) is handed to ``get_wilds_dataloader``. This function
    performs no validation of its own; all arguments are forwarded as-is.

    Parameters
    ----------
    dataset_name : str
        ``"domainnet"`` or a WILDS dataset name (e.g. ``"iwildcam"``,
        ``"camelyon17"``, ``"fmow"``). Forwarded to the WILDS backend only.
    trial_num : int
        Trial identifier passed to the backend; presumably used there as the
        seed for deterministic worker initialization.
    path_to_dataset_dir : str
        Root directory where the dataset files are stored.
    path_to_annotations_dir : str
        Directory with the DomainNet annotation files. Forwarded to the
        DomainNet backend only; ignored for WILDS datasets.
    domains : list[str]
        Domains to include (DomainNet) or to filter by (WILDS).
    classes : list[str]
        Ordered class names. Forwarded to the DomainNet backend only;
        ignored for WILDS datasets.
    transform : torchvision.transforms.Compose
        Transform applied to each loaded sample.
    split_type : str
        Which split to load, such as ``"train"``, ``"val"``, or ``"test"``.
    batch_size : int
        Mini-batch size for the DataLoader.
    num_workers : int
        Number of data-loading subprocesses.
    pin_memory : bool
        Whether the DataLoader should place batches in pinned memory.
    is_training : bool
        Training vs. evaluation mode; the backend is expected to shuffle
        only when this is True.
    create_val_from_train : bool
        Ask the backend to carve a validation split out of the training
        data, where it supports doing so.

    Returns
    -------
    dataset, loader
        Whatever the selected backend returns, expected to be the
        constructed dataset together with its DataLoader.
    """
    # Arguments understood by both backends; collected once so the two call
    # sites below only spell out what is specific to each backend.
    shared_kwargs = dict(
        trial_num=trial_num,
        path_to_dataset_dir=path_to_dataset_dir,
        domains=domains,
        transform=transform,
        split_type=split_type,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        is_training=is_training,
        create_val_from_train=create_val_from_train,
    )

    if dataset_name == "domainnet":
        # DomainNet additionally needs its annotation files and class list.
        return get_domainnet_dataloader(
            path_to_annotations_dir=path_to_annotations_dir,
            classes=classes,
            **shared_kwargs,
        )

    # Any other name is treated as a WILDS dataset, which the backend
    # selects via dataset_name.
    return get_wilds_dataloader(dataset_name=dataset_name, **shared_kwargs)