from data.cls.meta_dataset import MetaDatasets
from data.cls.miniimagenet import MiniImagenetDataSet
from data.cls.tieredimagenet import TieredImagenetDataset
from data.dataset_utils import DatasetEnum
from data.multi_dataset_sampler import MultipleDataset
from data.cls.cifar_fs import CifarFS_dataset

def get_dataset(ds_name, split):
    """Instantiate the single classification dataset called ``ds_name``.

    Args:
        ds_name: The ``.name`` of a ``DatasetEnum`` member. Accepted values are
            MINI_IMAGENET, CIFAR_FS, TIERED_IMAGENET, FUNGI, AIRCRAFT, BIRD
            and TEXTURE.
        split: Split identifier passed straight through to the dataset
            constructor.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: If ``ds_name`` is not one of the accepted names.
    """
    # Datasets that have their own dedicated class; checked in this order.
    dedicated_classes = (
        (DatasetEnum.MINI_IMAGENET, MiniImagenetDataSet),
        (DatasetEnum.CIFAR_FS, CifarFS_dataset),
        (DatasetEnum.TIERED_IMAGENET, TieredImagenetDataset),
    )
    for member, dataset_cls in dedicated_classes:
        if ds_name == member.name:
            return dataset_cls(split=split)

    # These sources share the generic MetaDatasets loader, keyed by name.
    meta_dataset_sources = (
        DatasetEnum.FUNGI.name,
        DatasetEnum.AIRCRAFT.name,
        DatasetEnum.BIRD.name,
        DatasetEnum.TEXTURE.name,
    )
    if ds_name in meta_dataset_sources:
        return MetaDatasets(ds_name=ds_name, split=split)

    raise ValueError("unknown dataset: {}, {}".format(ds_name, split))


def get_multi_dataset(ds_name, split="train"):
    """Build a ``MultipleDataset`` for a named mixture or a single dataset.

    Args:
        ds_name: The ``.name`` of a ``DatasetEnum`` member. The mixture names
            MetaDataset, MIX3A and MIX3B expand to several component
            datasets. Any name accepted by ``get_dataset`` produces a
            one-component ``MultipleDataset``.
        split: Split identifier forwarded to every component dataset.

    Returns:
        A ``MultipleDataset`` built from a list of ``(name, dataset)`` pairs,
        in the component order listed below.

    Raises:
        ValueError: If ``ds_name`` is neither a known mixture nor a name
            accepted by ``get_dataset``. Previously such names silently
            produced an empty ``MultipleDataset``.
    """
    # Component order is kept exactly as before in case MultipleDataset
    # depends on it.
    mixtures = {
        DatasetEnum.MetaDataset.name: (
            DatasetEnum.BIRD,
            DatasetEnum.TEXTURE,
            DatasetEnum.AIRCRAFT,
            DatasetEnum.FUNGI,
        ),
        DatasetEnum.MIX3A.name: (
            DatasetEnum.AIRCRAFT,
            DatasetEnum.BIRD,
            DatasetEnum.FUNGI,
        ),
        DatasetEnum.MIX3B.name: (
            DatasetEnum.MINI_IMAGENET,
            DatasetEnum.CIFAR_FS,
            DatasetEnum.BIRD,
        ),
    }

    if ds_name in mixtures:
        component_names = [member.name for member in mixtures[ds_name]]
    else:
        # A single dataset. get_dataset accepts exactly the single-dataset
        # names handled here before and raises ValueError for anything else,
        # instead of silently returning an empty mixture.
        component_names = [ds_name]

    multi_ds = [
        (name, get_dataset(ds_name=name, split=split))
        for name in component_names
    ]
    return MultipleDataset(multi_ds)

