import logging
import numpy as np
import torchvision.datasets


# Name of the CIFAR-10 class that the CIFAR10 wrapper in this module filters
# out when it is constructed with match_stl10_classes=True.
CLASS_NOT_IN_STL10 = 'frog'


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class CIFAR10(torchvision.datasets.CIFAR10):
    '''
    CIFAR-10 dataset that can optionally drop the class missing from STL10.

    After construction ``targets`` is always a NumPy array (rather than the
    list produced by the parent class) and ``labels`` is an alias for it.
    '''

    def __init__(self, root, train, match_stl10_classes=False, transform=None,
                 target_transform=None, download=False, shift_indices=False):
        '''
        Parameters
        ----------
        root : str
            Location of folder where dataset is or should be stored.
        train : bool
            Whether this is the train or test split.
        match_stl10_classes : bool, default False
            Whether to eliminate the class that doesn't exist in STL10
            (``CLASS_NOT_IN_STL10``). The matching examples are removed from
            ``data`` and ``targets``, and the class is removed from
            ``classes`` and ``class_to_idx``.
        transform : Transform, optional
            Optional transform to apply to images.
        target_transform : Transform, optional
            Optional transform to apply to labels.
        download : bool, default False
            Whether to download the dataset if it is missing.
        shift_indices : bool, default False
            If match_stl10_classes is True, then make the labels contiguous
            from 0 to 8. ``class_to_idx`` is shifted the same way so that it
            stays consistent with ``targets``. Ignored (with a warning) when
            match_stl10_classes is False.

        Notes
        -----
        With ``match_stl10_classes=True`` and ``shift_indices=False`` the
        labels keep their original values, so there is a gap at the removed
        class's index. ``classes`` is shortened by one entry in that case and
        therefore cannot be indexed by label for classes after the gap; use
        ``class_to_idx`` instead.
        '''

        super().__init__(
            root=root, train=train, transform=transform,
            target_transform=target_transform,
            download=download
        )
        # An ndarray is required for the boolean-mask filtering and the
        # in-place label shift below.
        self.targets = np.array(self.targets)
        if match_stl10_classes:
            target_not_in_stl10 = self.class_to_idx[CLASS_NOT_IN_STL10]
            examples_in_stl10_mask = self.targets != target_not_in_stl10
            self.data = self.data[examples_in_stl10_mask]
            self.targets = self.targets[examples_in_stl10_mask]
            self.classes.remove(CLASS_NOT_IN_STL10)
            del self.class_to_idx[CLASS_NOT_IN_STL10]
            if shift_indices:
                # Close the gap left by the removed class in the labels...
                self.targets[self.targets > target_not_in_stl10] -= 1
                # ...and apply the same shift to the name -> index mapping,
                # otherwise it would still point at the pre-shift indices.
                self.class_to_idx = {
                    name: idx - 1 if idx > target_not_in_stl10 else idx
                    for name, idx in self.class_to_idx.items()
                }
        elif shift_indices:
            logger.warning('shift_indices=True was ignored because match_stl10_classes was False')

        # Expose the labels under a second attribute name.
        # NOTE(review): presumably this mirrors the ``labels`` attribute of
        # torchvision's STL10 dataset -- confirm against callers.
        self.labels = self.targets


def get_dataset_cifar_stl_style_pretrain(dataset, data_dir, transform, train=True, download=True):
    '''
    Build a CIFAR10 dataset with the STL10-missing class removed and labels
    made contiguous.

    Parameters
    ----------
    dataset
        Unused by this function.
    data_dir : str
        Passed to CIFAR10 as ``root``.
    transform : Transform
        Transform to apply to images.
    train : bool, default True
        Whether to load the train or test split.
    download : bool, default True
        Whether to download the dataset if it is missing.

    Returns
    -------
    CIFAR10
        Dataset constructed with ``match_stl10_classes=True`` and
        ``shift_indices=True``.
    '''
    cifar_kwargs = dict(
        root=data_dir,
        train=train,
        transform=transform,
        download=download,
        match_stl10_classes=True,
        shift_indices=True,
    )
    return CIFAR10(**cifar_kwargs)