# Names of every dataset this module knows about, grouped by benchmark family.
# NOTE(review): get_dataset() also accepts 'Laion_Retrieval_subset_5000_01_sum',
# which is not listed here — confirm whether it should be.
DATASETS = (
    ['Winoground']
    # ARO datasets
    + ['VG_Relation', 'VG_Attribution', 'COCO_Order', 'Flickr30k_Order']
    # Retrieval datasets (test split, then validation split)
    + ['COCO_Retrieval', 'Flickr30k_Retrieval', 'COCO_Retrieval_Val', 'Flickr30k_Retrieval_Val']
    # Crepe datasets
    + ['CrepeAtom', 'CrepeNegate', 'Crepe3Swap', 'Crepe5Swap']
    # EqBen datasets
    + ['EqBen_All', 'EqBen_Val']
    # VL_CheckList attribute datasets
    + [f'VL_CheckList_Attribute_{kind}' for kind in ('action', 'color', 'material', 'state', 'size')]
    # VL_CheckList relation datasets
    + [f'VL_CheckList_Relation_{kind}' for kind in ('action', 'spatial')]
    # VL_CheckList object datasets
    + [f'VL_CheckList_Object_Location_{where}' for where in ('center', 'margin', 'mid')]
    + [f'VL_CheckList_Object_Size_{scale}' for scale in ('large', 'medium', 'small')]
    # LAION retrieval subsets: a "_00" and a "_01" variant for each subset size
    + [
        f'Laion_Retrieval_subset_{size}_{variant}'
        for size in (100, 500, 1000, 2000, 5000)
        for variant in ('00', '01')
    ]
    # LAION "_sum" variants
    + [f'Laion_Retrieval_subset_{size}_00_sum' for size in (1000, 2000, 5000)]
)

from .datasets import Laion

def _import_loader(module_name, loader_name):
    """Import and return ``loader_name`` from this package's ``module_name`` submodule.

    The import is deferred until a dataset is actually requested, so e.g. the
    ``retrieval`` submodule is only loaded when a retrieval dataset is asked for.

    Raises ImportError if the submodule does not define ``loader_name``.
    """
    import importlib

    module = importlib.import_module(f".{module_name}", __package__)
    try:
        return getattr(module, loader_name)
    except AttributeError:
        # Mirror the error a ``from .module import name`` statement would raise.
        raise ImportError(f"cannot import name {loader_name!r} from {module.__name__!r}") from None


def _build_dataset_specs():
    """Return the registry mapping each dataset name to how its loader is called.

    Each value is ``(module, loader_name, preprocess_positional, forwarded, fixed)``:

    - ``module`` / ``loader_name``: where the loader lives inside this package.
    - ``preprocess_positional``: if True, ``image_preprocess`` is passed as the
      first positional argument instead of as a keyword.
    - ``forwarded``: names of get_dataset()'s shared arguments passed as keywords.
    - ``fixed``: dataset-specific keyword arguments (split, hard_neg_type, ...).
    """
    # Which shared arguments each loader family receives as keywords.
    preprocess_only = ("image_preprocess",)
    with_download = ("image_preprocess", "download")
    retrieval = ("image_preprocess", "image_perturb_fn", "download")
    vl_checklist = ("download",)  # image_preprocess goes in positionally for VL_CheckList

    specs = {
        "Winoground": ("datasets", "get_winoground", False, preprocess_only, {}),
        # ARO
        "VG_Relation": ("datasets", "get_visual_genome_relation", False, with_download, {}),
        "VG_Attribution": ("datasets", "get_visual_genome_attribution", False, with_download, {}),
        "COCO_Order": ("datasets", "get_coco_order", False, with_download, {}),
        "Flickr30k_Order": ("datasets", "get_flickr30k_order", False, with_download, {}),
        # COCO / Flickr30k retrieval
        "COCO_Retrieval": ("retrieval", "get_coco_retrieval", False, retrieval, {"split": "test"}),
        "COCO_Retrieval_Val": ("retrieval", "get_coco_retrieval", False, retrieval, {"split": "val"}),
        "Flickr30k_Retrieval": ("retrieval", "get_flickr30k_retrieval", False, retrieval, {"split": "test"}),
        "Flickr30k_Retrieval_Val": ("retrieval", "get_flickr30k_retrieval", False, retrieval, {"split": "val"}),
        # Crepe
        "CrepeAtom": ("datasets", "get_crepe_productivity", False, with_download, {"hard_neg_type": "atom"}),
        "CrepeNegate": ("datasets", "get_crepe_productivity", False, with_download, {"hard_neg_type": "negate"}),
        "Crepe3Swap": ("datasets", "get_crepe_productivity", False, with_download, {"hard_neg_type": "swap_3"}),
        "Crepe5Swap": ("datasets", "get_crepe_productivity", False, with_download, {"hard_neg_type": "swap_5"}),
        # EqBen
        "EqBen_All": ("datasets", "get_eqben_all", False, with_download, {}),
        "EqBen_Val": ("datasets", "get_eqben_val", False, with_download, {}),
    }

    # LAION retrieval subsets: the dataset name is "Laion_Retrieval_" + split.
    for size in (100, 500, 1000, 2000, 5000):
        for variant in ("00", "01"):
            split = f"subset_{size}_{variant}"
            specs[f"Laion_Retrieval_{split}"] = ("retrieval", "get_laion_retrieval", False, retrieval, {"split": split})
    # "_sum" names load exactly the same split as their base name.
    for split in ("subset_1000_00", "subset_2000_00", "subset_5000_00", "subset_5000_01"):
        specs[f"Laion_Retrieval_{split}_sum"] = specs[f"Laion_Retrieval_{split}"]

    # VL_CheckList
    for subsplit in ("action", "color", "material", "size", "state"):
        specs[f"VL_CheckList_Attribute_{subsplit}"] = (
            "datasets", "get_vl_checklist", True, vl_checklist,
            {"split": "Attribute", "subsplit": subsplit},
        )
    for subsplit in ("spatial", "action"):
        specs[f"VL_CheckList_Relation_{subsplit}"] = (
            "datasets", "get_vl_checklist", True, vl_checklist,
            {"split": "Relation", "subsplit": subsplit},
        )
    for subsplit, subsubsplits in (("Location", ("center", "margin", "mid")), ("Size", ("large", "medium", "small"))):
        for subsubsplit in subsubsplits:
            specs[f"VL_CheckList_Object_{subsplit}_{subsubsplit}"] = (
                "datasets", "get_vl_checklist", True, vl_checklist,
                {"split": "Object", "subsplit": subsplit, "subsubsplit": subsubsplit},
            )
    return specs


# Dataset name -> loader spec; see _build_dataset_specs() for the tuple layout.
_DATASET_SPECS = _build_dataset_specs()


def get_dataset(dataset_name, image_preprocess=None, image_perturb_fn=None, download=True, *args, **kwargs):
    """
    Helper function that returns a dataset object with an evaluation function.

    dataset_name: Name of the dataset; must be a key of _DATASET_SPECS.
    image_preprocess: Preprocessing function for images.
    image_perturb_fn: A function that takes in a PIL image and returns a PIL image. This is for ARO's perturbation experiments.
        Only forwarded to the retrieval loaders (COCO, Flickr30k and LAION retrieval).
    download: Whether to allow downloading images if they are not found. Forwarded to every loader except Winoground.
    *args, **kwargs: Passed through unchanged to the selected loader.

    Raises ValueError if dataset_name is not a known dataset.
    """
    try:
        spec = _DATASET_SPECS.get(dataset_name)
    except TypeError:
        # Unhashable names cannot be registered; report them like any other unknown name.
        spec = None
    if spec is None:
        raise ValueError(f"Unknown dataset {dataset_name}")
    module_name, loader_name, preprocess_positional, forwarded, fixed = spec

    shared = {"image_preprocess": image_preprocess, "image_perturb_fn": image_perturb_fn, "download": download}
    leading = (image_preprocess,) if preprocess_positional else ()
    loader = _import_loader(module_name, loader_name)
    # A key present both here and in **kwargs raises TypeError ("got multiple values"),
    # the same as a hand-written call with an explicit keyword would.
    return loader(*leading, *args, **{name: shared[name] for name in forwarded}, **fixed, **kwargs)
