import numpy as np
import tensorflow as tf
#import tensorflow_datasets as tfds
from tensorflow.keras.datasets import mnist
from deel.datasets.util_generator_dataset import simple_generator, otp_generator


# NOTE(review): the string literal below is a disabled draft of a
# tensorflow_datasets `GeneratorBasedBuilder` for MNIST class subsets. It is a
# bare string expression, so none of it is executed (and the `tfds` import at
# the top of the file is commented out). If it is ever revived, fix first:
#   - indentation mixes 2- and 4-space levels, so it would not parse as a class;
#   - `__init__` calls `len(self.selected_classes)` before `_split_generators`
#     applies the `range(10)` default, so `selected_classes=None` would fail;
#   - the "Select only" print uses an undefined name `selected_classes`
#     (should presumably be `self.selected_classes`);
#   - `self.X_test` is printed but never assigned (only `X_test` is a local);
#   - `_generate_examples` only yields a placeholder `('key', {})` example.
'''class MNISTSubDataset(tfds.core.GeneratorBasedBuilder):
  """Short description of my dataset."""

    VERSION = tfds.core.Version('0.1.0')

    def __init__(self, selected_classes=None,gtValues=None, *args, **kwargs):
        """
        """
        super(MNISTSubDataset, self).__init__(*args, **kwargs)
        self.selected_classes = selected_classes
        self.gtValues = gtValues
        self.nb_classes = len(self.selected_classes)

  def _info(self):
    return tfds.core.DatasetInfo(
        builder=self,
        # This is the description that will appear on the datasets page.
        description=("This is a subclass  dataset for MNIST able to equilibrate classes. "),
        # tfds.features.FeatureConnectors
        features=tfds.features.FeaturesDict({
            #"image_description": tfds.features.Text(),
            "image": tfds.features.Image(),
            # Here, labels can be of 5 distinct values.
            "label": tfds.features.ClassLabel(num_classes=self.nb_classes),
        }),
        # If there's a common (input, target) tuple from the features,
        # specify them here. They'll be used if as_supervised=True in
        # builder.as_dataset.
        supervised_keys=("image", "label")
    )

  def _split_generators(self, dl_manager):
    if self.selected_classes is None:
        self.selected_classes = range(10)
    #if self.nb_classes == 2:
    #    self.nb_classes = 1 ## binary
    if self.gtValues is None:
        index_selected_class = {self.selected_classes[i]:i for i in range(len(self.selected_classes))}
    else:
        assert len(self.gtValues)==self.nb_classes
        index_selected_class = {self.selected_classes[i]:self.gtValues[i] for i in range(len(self.selected_classes))}


    print(index_selected_class)

    # the data, shuffled and split between train and test sets
    (X_all, y_all), (X_test, y_test) = mnist.load_data()
    X_all = X_all.reshape((-1, 28, 28, 1))
    X_test = X_test.reshape((-1, 28, 28, 1))

    print("Select only "+str(self.nb_classes)+" classes:"+str(selected_classes))

    print(y_all.shape)
    select_all = [y in self.selected_classes for y in y_all]
    X_all = X_all[select_all]
    y_all = y_all[select_all]
    #print(y_all)
    y_all = [index_selected_class[y] for y in y_all]
    y_all = np.asarray(y_all)
    #y_all = np.reshape(y_all,(-1,1))
    print("X shape "+str(X_all.shape))
    max_train = int(len(X_all)*0.9)

    select_test = [y in self.selected_classes for y in y_test]
    X_test = X_test[select_test]
    y_test = y_test[select_test]
    #print(y_test.shape)
    y_test = [index_selected_class[y] for y in y_test]
    y_test = np.asarray(y_test)

    #y_test = np.reshape(y_test,(-1,1))
    print(y_test.shape)

    #X_test = X_test.reshape(-1, img_rows, img_cols, nb_channel)
    X_all = X_all.astype('float32')
    X_test = X_test.astype('float32')
    

    self.X_train = X_all[:max_train]
    self.X_valid = X_all[max_train:]
    self.Y_train = y_all[:max_train]
    self.Y_valid = y_all[max_train:]
    self.Y_test = y_test
    print(self.X_train.shape[0], 'train samples')
    print(self.X_valid.shape[0], 'valid samples')
    print(self.X_test.shape[0], 'test samples')

  def _generate_examples(self):
    # Yields examples from the dataset
    yield 'key', {}
'''
def _select_classes(X, y, classes, label_map):
    """Keep only the samples whose label is in ``classes`` and remap their labels.

    Args:
        X: sample array whose first axis is aligned with ``y``.
        y: 1-D numpy array of original digit labels.
        classes: list of digit labels to keep.
        label_map: dict mapping every kept digit to its new label value.

    Returns:
        ``(X_selected, y_selected)``, where ``y_selected`` is a numpy array of
        the remapped labels (in the original sample order).
    """
    # Vectorized membership test instead of a Python loop over every sample.
    mask = np.isin(y, classes)
    y_selected = np.asarray([label_map[label] for label in y[mask]])
    return X[mask], y_selected


def mnist_dataset(batch_size, to_categorical, selected_classes=None, gtValues=None,
                  valid_ratio=0.1):
    """Build train/valid/test batch generators over a subset of MNIST digits.

    Images are reshaped to ``(N, 28, 28, 1)``, cast to float32 and rescaled
    from [0, 255] to [-1, 1]. The validation set is the last ``valid_ratio``
    fraction of the (filtered) Keras MNIST training split, taken in stored
    order; no shuffling happens before the split.

    Args:
        batch_size: batch size passed to ``simple_generator`` and echoed in the
            returned dict under ``'batch_size'``.
        to_categorical: if truthy, labels are one-hot encoded with
            ``tf.keras.utils.to_categorical`` using ``len(selected_classes)``
            columns.
        selected_classes: iterable of digit labels to keep. Defaults to all
            ten digits.
        gtValues: optional sequence giving the label value assigned to each
            entry of ``selected_classes`` (same length and order). Defaults to
            the position of each class in ``selected_classes``.
        valid_ratio: fraction of the filtered training data held out for
            validation, in ``[0, 1)``. Defaults to 0.1.

    Returns:
        dict with keys ``'train'``, ``'valid'``, ``'test'`` (the objects
        returned by ``simple_generator``; the test one is built with
        ``shuffle=False``), ``'trainSize'``, ``'validSize'``, ``'testSize'``
        (sample counts) and ``'batch_size'``.

    Raises:
        ValueError: if ``gtValues`` does not have one entry per selected
            class, or if ``valid_ratio`` is outside ``[0, 1)``.
    """
    if selected_classes is None:
        selected_classes = range(10)
    classes = list(selected_classes)
    if not 0.0 <= valid_ratio < 1.0:
        raise ValueError("valid_ratio must be in [0, 1), got " + str(valid_ratio))

    # Reported size only: a two-class problem is treated as binary (1 output).
    nb_classes = len(classes)
    if nb_classes == 2:
        nb_classes = 1  ## binary

    if gtValues is None:
        index_selected_class = {cls: i for i, cls in enumerate(classes)}
    else:
        # Validate against the real number of classes: comparing against the
        # binary-adjusted nb_classes made gtValues unusable for 2 classes.
        if len(gtValues) != len(classes):
            raise ValueError("gtValues must have one value per selected class: got "
                             + str(len(gtValues)) + " values for "
                             + str(len(classes)) + " classes")
        index_selected_class = dict(zip(classes, gtValues))

    print(index_selected_class)

    # Keras' fixed MNIST train/test split; validation is carved out below.
    (X_all, y_all), (X_test, y_test) = mnist.load_data()
    X_all = X_all.reshape((-1, 28, 28, 1))
    X_test = X_test.reshape((-1, 28, 28, 1))

    print("Select only "+str(nb_classes)+" classes:"+str(selected_classes))

    # Filter before the float conversion so only the kept samples are copied.
    X_all, y_all = _select_classes(X_all, y_all, classes, index_selected_class)
    X_test, y_test = _select_classes(X_test, y_test, classes, index_selected_class)
    print(y_test.shape)

    X_all = X_all.astype('float32')
    X_test = X_test.astype('float32')
    X_all /= 255
    X_test /= 255
    X_all = 2*X_all - 1  # -1 1 range
    X_test = 2*X_test - 1  # -1 1 range

    max_train = int(len(X_all) * (1 - valid_ratio))
    X_train = X_all[:max_train]
    X_valid = X_all[max_train:]
    Y_train = y_all[:max_train]
    Y_valid = y_all[max_train:]
    Y_test = y_test
    if to_categorical:
        Y_test = tf.keras.utils.to_categorical(Y_test, len(classes))
        Y_train = tf.keras.utils.to_categorical(Y_train, len(classes))
        Y_valid = tf.keras.utils.to_categorical(Y_valid, len(classes))

    print(X_train.shape[0], 'train samples')
    print(X_valid.shape[0], 'valid samples')
    print(X_test.shape[0], 'test samples')

    dtset = {'train': simple_generator(batch_size, X_train, Y_train), 'trainSize': X_train.shape[0],
             'valid': simple_generator(batch_size, X_valid, Y_valid), 'validSize': X_valid.shape[0],
             'test': simple_generator(batch_size, X_test, Y_test, shuffle=False), 'testSize': X_test.shape[0],
             'batch_size': batch_size}
    return dtset
    
