import os
import pickle

from . import datasets


class DatasetFactory:
    """Registry that maps dataset names to dataset classes.

    ``dict_datasets`` holds the mapping. Its keys are the names accepted by
    ``load_dataset``. Its values are classes from the package-local
    ``datasets`` module.
    """

    dict_datasets = {
        'MNIST': datasets.MNIST,
        'EMNIST': datasets.EMNIST,
        'KMNIST': datasets.KMNIST,
        'FashionMNIST': datasets.FashionMNIST,
        'OMNIGLOT': datasets.OMNIGLOT,
        'QuickDraw': datasets.QuickDraw,
        'dSprites': datasets.dSprites,
        'HDW': datasets.HDW,
        'Shapes3D': datasets.Shapes3D,
        'KKanji': datasets.KKanji,
        'SVHN': datasets.SVHN,
        'CIFAR10': datasets.CIFAR10,
        'CIFAR100': datasets.CIFAR100,
        'Caltech101': datasets.Caltech101,
        'CelebA': datasets.CelebA,
        'StanfordCars': datasets.StanfordCars,
        'MiniEcoset': datasets.MiniEcoset,
        'ImageNet': datasets.ImageNet,
        'MSCOCO': datasets.MSCOCO,
        'THINGS': datasets.THINGS,
    }

    @classmethod
    def load_dataset(cls, dataset_name, *args, **kwargs):
        """Instantiate the class registered under ``dataset_name``.

        Args:
            dataset_name: Key into ``dict_datasets``.
            *args, **kwargs: Forwarded unchanged to the dataset class
                constructor.

        Returns:
            A new instance of the registered dataset class.

        Raises:
            ValueError: If ``dataset_name`` is not a registered key. The
                message lists the available names.
        """
        registry = cls.dict_datasets
        if dataset_name not in registry:
            available = list(registry)
            raise ValueError(f'Unknown dataset: {dataset_name}. '
                             f'Select from {available}')

        dataset_cls = registry[dataset_name]
        return dataset_cls(*args, **kwargs)


def load_dataset(dataset_name,
                 path_input=None,
                 path_output=None,
                 size=None):
    """Load a dataset, preferring a previously pickled copy when one exists.

    The cache file is named ``{dataset_name}_{size}.pickle`` when ``size`` is
    not None, and ``{dataset_name}.pickle`` otherwise. This matches the naming
    used by ``save_dataset``. It is looked up in ``path_output``, which
    defaults to ``path_input``.

    If no cache file is found, or no directory is given at all, the dataset
    is built with ``DatasetFactory.load_dataset``. That call receives
    ``path_input`` and, when given, ``size`` as positional arguments. A freshly
    built dataset is NOT written to the cache here; call ``save_dataset`` for
    that.

    Args:
        dataset_name: Registered dataset name (a ``DatasetFactory`` key).
        path_input: Directory passed to the dataset constructor. It is also
            the default cache directory.
        path_output: Directory to search for the cached pickle.
        size: Optional size argument. It is part of the cache filename and is
            forwarded to the constructor.

    Returns:
        The unpickled cached dataset, or a newly constructed one.

    Raises:
        ValueError: From ``DatasetFactory.load_dataset`` for unknown names.
    """
    # Use ``is not None`` (not truthiness) so the filename agrees with
    # save_dataset; otherwise a size of 0 would be saved but never found.
    if size is not None:
        filename = f'{dataset_name}_{size}.pickle'
    else:
        filename = f'{dataset_name}.pickle'

    if path_output is None:
        path_output = path_input

    # With no directory at all there is no cache to consult; fall through to
    # constructing the dataset instead of failing in os.path.join(None, ...).
    if path_output is not None:
        path_pickle = os.path.join(path_output, filename)
        if os.path.exists(path_pickle):
            # NOTE(security): pickle.load can execute arbitrary code. Only
            # point path_output at trusted, locally produced cache files.
            with open(path_pickle, 'rb') as f:
                return pickle.load(f)

    if size is not None:
        return DatasetFactory.load_dataset(dataset_name, path_input, size)
    return DatasetFactory.load_dataset(dataset_name, path_input)


def save_dataset(dataset, path_output):
    """Pickle ``dataset`` into ``path_output`` for later use by load_dataset.

    The file is named after the dataset's class: ``{ClassName}_{size}.pickle``
    when ``dataset.size`` exists and is not None, and ``{ClassName}.pickle``
    otherwise.

    The write is atomic. Data goes to a temporary file in the same directory,
    which is then renamed over the target with ``os.replace``. A failed or
    interrupted dump therefore never leaves a truncated pickle that
    load_dataset would later pick up. On failure the temporary file is
    removed and the exception is re-raised.

    Args:
        dataset: Picklable dataset object.
        path_output: Existing directory to write the pickle into.
    """
    dataset_name = dataset.__class__.__name__
    # Tolerate objects without a ``size`` attribute; treat them as unsized.
    size = getattr(dataset, 'size', None)

    if size is not None:
        filename = f'{dataset_name}_{size}.pickle'
    else:
        filename = f'{dataset_name}.pickle'

    path_pickle = os.path.join(path_output, filename)
    # The per-process suffix avoids collisions between concurrent writers.
    # The same directory keeps os.replace a same-filesystem atomic rename.
    path_tmp = f'{path_pickle}.{os.getpid()}.tmp'

    try:
        with open(path_tmp, 'wb') as f:
            pickle.dump(dataset, f, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(path_tmp, path_pickle)
    except BaseException:
        # Best-effort cleanup of the partial temp file; keep the original error.
        try:
            os.remove(path_tmp)
        except OSError:
            pass
        raise


def get_available_datasets():
    """Return the registered dataset names, in registry order, as a new list."""
    return [name for name in DatasetFactory.dict_datasets]
