from enum import Enum

from src.dataset.Birdsnap import Birdsnap
from src.dataset.Food101 import Food101
from src.dataset.CIFAR10 import CIFAR10
from src.dataset.NihCxr import NihCxr


class SupportedDataset(Enum):
    """Registry of the datasets that ``create_dataset`` knows how to build.

    Every member's value is a metadata dict with the keys ``dataloader``
    (the object ``create_dataset`` calls to build the dataset),
    ``image_size`` (a 2-tuple), ``channels`` and ``labels_count``.
    Some members additionally carry ``training_size``.
    """

    CIFAR10_Enum = {
        'dataloader': CIFAR10,
        'image_size': (32, 32),  # used for the model's FC layer
        'channels': 3,
        'training_size': 50000,
        'labels_count': 10,
    }

    # NOTE(review): 'dataloader' is a plain string here, not a callable class,
    # so this entry cannot actually be instantiated by create_dataset.
    # 'training_size' is the same value as CIFAR10_Enum's (50000) -- confirm
    # it is intended for ImageNet.
    ImageNet_Enum = {
        'dataloader': 'ImageNet',
        'image_size': (224, 224),  # used for the model's FC layer
        'channels': 3,
        'training_size': 50000,
        'labels_count': 1000,
    }

    Food101_Enum = {
        'dataloader': Food101,
        'labels_count': 101,
        'image_size': (224, 224),
        'channels': 3,
    }

    NihCxr_Enum = {
        'dataloader': NihCxr,
        'labels_count': 14,
        'image_size': (1024, 1024),
        'channels': 3,
    }

    Birdsnap_Enum = {
        'dataloader': Birdsnap,
        'labels_count': 500,
        'image_size': (224, 224),
        'channels': 3,
    }


# Lookup from a dataset's short string name to its SupportedDataset member.
MAP_DATASET_TO_ENUM = {
    'CIFAR10': SupportedDataset.CIFAR10_Enum,
    'Food101': SupportedDataset.Food101_Enum,
    'NihCxr': SupportedDataset.NihCxr_Enum,
    'ImageNet': SupportedDataset.ImageNet_Enum,
    'Birdsnap': SupportedDataset.Birdsnap_Enum,
}


def create_dataset(dataset_args, train_data_args, val_data_args):
    """Instantiate the dataloader registered for ``dataset_args['name']``.

    Args:
        dataset_args: Mapping whose ``'name'`` entry must be a
            ``SupportedDataset`` member. It may also hold an optional
            ``'split_ratio'``. The whole mapping is forwarded to the
            dataloader as its third positional argument.
        train_data_args: Forwarded unchanged as the first positional argument.
        val_data_args: Forwarded unchanged as the second positional argument.

    Returns:
        The result of calling the member's ``dataloader`` with
        ``(train_data_args, val_data_args, dataset_args)``. When
        ``split_ratio`` is truthy, it is also passed as
        ``split_ratio=...``.

    Raises:
        KeyError: if ``dataset_args`` has no ``'name'`` entry.
        ValueError: if ``dataset_args['name']`` is not a ``SupportedDataset``
            member, or the member's ``dataloader`` is not callable.
    """
    dataset = dataset_args['name']
    # isinstance behaves the same on every Python version. A bare
    # ``x in SupportedDataset`` raises TypeError for non-members before 3.12
    # and compares against member *values* from 3.12 on.
    if not isinstance(dataset, SupportedDataset):
        raise ValueError(f'Unsupported Dataset: {dataset!r}')

    loader = dataset.value['dataloader']
    if not callable(loader):
        # Placeholder entries (e.g. a dataloader given only as a name string)
        # would otherwise fail with an opaque "'str' object is not callable".
        raise ValueError(f'Unsupported Dataset: no dataloader implemented for {dataset.name}')

    split_ratio = dataset_args.get('split_ratio')
    # Truthiness is intentional: a falsy split_ratio (None, 0, 0.0) is
    # treated as absent and is not forwarded to the dataloader.
    if split_ratio:
        return loader(train_data_args, val_data_args, dataset_args, split_ratio=split_ratio)
    return loader(train_data_args, val_data_args, dataset_args)
