import logging

import numpy as np
import torch.utils.data as data
from PIL import Image
from torchvision.datasets import MNIST, FashionMNIST, EMNIST
from torchvision.transforms import transforms

# Module-level logging setup. NOTE: this runs at import time and operates on
# the root logger (``getLogger()`` with no name), so it affects logging for the
# whole process, not only this module.
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Lower-case image file suffixes. Not referenced in the visible part of this
# file -- presumably kept for use by other code; verify before removing.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')




class emnist_truncated(data.Dataset):
    """EMNIST ``letters`` split, optionally restricted to a subset of samples.

    The samples are taken either from a freshly constructed torchvision
    ``EMNIST`` object or from an already loaded dataset object passed in as
    ``cache_data_set``. Labels returned by ``__getitem__`` are shifted down by
    one relative to the stored targets.

    Args:
        root: Directory handed to ``EMNIST`` when no cached dataset is given.
        cache_data_set: Optional pre-loaded dataset object exposing ``data``
            and ``targets``. When given, ``EMNIST`` is not constructed.
        dataidxs: Optional indices of the samples to keep. ``None`` keeps
            every sample.
        train: Forwarded to ``EMNIST`` as ``train``.
        transform: Optional callable applied to each image in ``__getitem__``.
        target_transform: Stored on the instance only.
            NOTE(review): it is not applied anywhere in this class.
        download: Stored and printed only.
            NOTE(review): ``EMNIST`` is always constructed with
            ``download=True`` regardless of this value; kept as-is because
            callers may rely on the automatic download.

    Attributes:
        data: The selected images. A NumPy array when ``dataidxs`` is given;
            otherwise whatever the underlying dataset's ``data`` attribute is.
        targets: NumPy array of the (unshifted) labels of the selected samples.
    """

    def __init__(self, root, cache_data_set=None,dataidxs=None, train=True, transform=None, target_transform=None, download=False):
        self.root = root
        self.dataidxs = dataidxs
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.download = download
        self.data, self.targets = self.__build_truncated_dataset__(cache_data_set)

    def __build_truncated_dataset__(self,cache_data_set):
        """Return ``(data, targets)`` for the samples selected by ``self.dataidxs``.

        Args:
            cache_data_set: Pre-loaded dataset object to read from, or ``None``
                to construct ``EMNIST`` from ``self.root``.

        Returns:
            tuple: ``(data, targets)`` where ``targets`` is a NumPy array and
            ``data`` is a NumPy array when ``self.dataidxs`` is set.
        """
        print("download = " + str(self.download))
        # Identity check: ``==`` could dispatch to a user-defined ``__eq__``.
        if cache_data_set is None:
            emnist_dataobj = EMNIST(
                self.root, train=self.train, download=True, split='letters', transform=self.transform)
        else:
            emnist_dataobj = cache_data_set

        data = emnist_dataobj.data
        targets = np.array(emnist_dataobj.targets)

        if self.dataidxs is not None:
            subset = data[self.dataidxs]
            # A cached dataset may already hold NumPy data, which has no
            # ``.numpy()`` method; only convert when the method exists.
            data = subset.numpy() if hasattr(subset, "numpy") else np.asarray(subset)
            targets = targets[self.dataidxs]

        return data, targets

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the stored label minus one.
        """
        img, target = self.data[index], self.targets[index]
        # ``transform`` defaults to None, so it must be guarded; calling it
        # unconditionally raised TypeError for default-constructed datasets.
        if self.transform is not None:
            img = self.transform(img)
        # Shift the label down by one (presumably because the 'letters' split
        # uses 1-based labels -- confirm against the torchvision EMNIST docs).
        return img, target - 1

    def __len__(self):
        """Return the number of selected samples."""
        return len(self.data)
