from timm.data import create_dataset, ImageDataset
from .readers.annotation_reader import AnnotationReader


def create_dataset_v2(
        name,
        root,
        images=None,
        labels=None,
        split='validation',
        search_split=True,
        class_map=None,
        load_bytes=False,
        is_training=False,
        download=False,
        batch_size=None,
        seed=42,
        repeats=0,
        **kwargs):
    """Create a dataset, adding an annotation-backed ``vdd-imagenet12`` option.

    When ``name == "vdd-imagenet12"`` an ``AnnotationReader`` is built from
    ``root``, ``images`` and ``labels`` and wrapped in a timm ``ImageDataset``.
    Every other ``name`` is delegated unchanged to ``timm.data.create_dataset``.

    Args:
        name: Dataset name. ``"vdd-imagenet12"`` selects the custom reader
            path; anything else is passed to ``timm.data.create_dataset``.
        root: Dataset root directory, used by both paths.
        images: Passed to ``AnnotationReader`` (``vdd-imagenet12`` only).
            Presumably an image list / annotation source -- confirm against
            ``AnnotationReader``.
        labels: Passed to ``AnnotationReader`` (``vdd-imagenet12`` only).
        split, search_split, is_training, download, batch_size, seed, repeats:
            Forwarded to ``timm.data.create_dataset``. They are not used on
            the ``vdd-imagenet12`` path.
        class_map: Forwarded to the dataset on both paths.
        load_bytes: Forwarded to the dataset on both paths.
        **kwargs: Extra keyword arguments forwarded to the dataset
            constructor / ``create_dataset``.

    Returns:
        The constructed dataset object.
    """
    if name == "vdd-imagenet12":
        # Custom annotation-driven reader. split / is_training / download /
        # batch_size / seed / repeats are not forwarded on this path.
        parser = AnnotationReader(root, images, labels)
        # NOTE(review): timm >= 0.9 renamed ImageDataset's ``parser`` argument
        # to ``reader``; confirm this keyword matches the pinned timm version.
        return ImageDataset(
            root, parser=parser, class_map=class_map, load_bytes=load_bytes, **kwargs)

    # ``seed`` is accepted by this wrapper and by timm's create_dataset, so it
    # must be passed through; previously it was silently dropped.
    return create_dataset(
        name=name, root=root, split=split, search_split=search_split,
        class_map=class_map, load_bytes=load_bytes, is_training=is_training,
        download=download, batch_size=batch_size, seed=seed, repeats=repeats,
        **kwargs)