import os
from timm.data import ImageDataset
from timm.data.dataset_factory import _search_split

from cheem.timm_custom.data.readers import AutoSplitReaderImageFolder
from .datasets import ClassIncrementalDataset

def get_imagenet_data(root, search_split=True, dev_percent=0., offset_task_labels=False, **kwargs):
    """Build ImageNet datasets, each wrapped in a ``ClassIncrementalDataset``.

    Args:
        root: Dataset root. When ``search_split`` is true and ``root`` is a
            directory, split-specific sub-folders ("train" / "validation") are
            looked up inside it with timm's ``_search_split``. Otherwise
            ``root`` itself is used for both the train and the test split.
        search_split: Whether to look for split-specific sub-folders in ``root``.
        dev_percent: Amount of the training folder to hold out as a ``"val"``
            split. It is forwarded to ``AutoSplitReaderImageFolder``. Any value
            ``> 0`` enables the held-out split. Whether it is a fraction or a
            percentage is defined by that reader (TODO confirm).
        offset_task_labels: Forwarded to ``ClassIncrementalDataset``.
        **kwargs: Forwarded to every ``ImageDataset`` constructor.

    Returns:
        dict: Maps split name to a ``ClassIncrementalDataset`` that wraps a
        single ``"imagenet"`` dataset. The keys are ``"train"`` and ``"test"``,
        plus ``"val"`` when ``dev_percent > 0``.
    """
    # Default both roots to ``root`` so they are always bound. Before this,
    # search_split=False or a non-directory root raised UnboundLocalError.
    # This is also the value _search_split falls back to when it finds no
    # matching sub-folder.
    train_root = val_root = root
    if search_split and os.path.isdir(root):
        # Look for split-specific sub-folders in root.
        train_root = _search_split(root, "train")
        val_root = _search_split(root, "validation")

    datasets = {}
    if dev_percent > 0.:
        # Carve a "val" split out of the training folder. The val reader reuses
        # the train reader's val_samples and class_to_idx, so both splits share
        # one label mapping. The held-out samples are presumably excluded from
        # the train reader (TODO confirm in AutoSplitReaderImageFolder).
        train_reader = AutoSplitReaderImageFolder(
            root=train_root, split="train", dev_percent=dev_percent
        )
        val_reader = AutoSplitReaderImageFolder(
            root=train_root,
            samples=train_reader.val_samples,
            split="val",
            dev_percent=dev_percent,
            class_to_idx=train_reader.class_to_idx,
        )
        # NOTE(review): with an explicit reader instance, ImageDataset likely
        # ignores its root argument; verify against the timm version in use.
        datasets["train"] = ImageDataset(root, reader=train_reader, **kwargs)
        datasets["val"] = ImageDataset(root, reader=val_reader, **kwargs)
    else:
        datasets["train"] = ImageDataset(train_root, split="train", **kwargs)

    # The official ImageNet validation split serves as the test set.
    datasets["test"] = ImageDataset(val_root, split="validation", **kwargs)

    return {
        split: ClassIncrementalDataset({"imagenet": ds}, offset_task_labels=offset_task_labels)
        for split, ds in datasets.items()
    }
