import tensorflow as tf
import numpy as np
import glob, os, pickle
from data.dataset_utils import resize, random_resize, flip

class Cifar10:
    """CIFAR-10 from the python-pickle batch files, held in memory.

    Training images come from every ``data_batch_*`` file in ``pathData``
    (sorted by name); the ``test_batch`` file serves as the validation set.
    Only ``merge_train_valid=True`` is supported: no validation split is
    held out of the training batches.

    Labels are replicated under the keys ``output_1`` ... ``output_<num_layers>``
    so that every output head of a multi-output tf.keras model receives them.
    """

    def __init__(self, pathData,
            num_layers=1, batch_size=1, merge_train_valid=False):
        """Load all batches and compute per-channel training statistics.

        Raises:
            NotImplementedError: if ``merge_train_valid`` is False.
            ValueError: if no ``data_batch_*`` file exists in ``pathData``.
        """
        self.num_layers = num_layers
        self.batch_size = batch_size
        if not merge_train_valid:
            # Holding out a validation split from the training batches is not implemented.
            raise NotImplementedError

        filenamesTrain = sorted(glob.glob(os.path.join(pathData, 'data_batch_*')))
        if not filenamesTrain:
            raise ValueError('no data_batch_* files found in {}'.format(pathData))
        filenameTest = [os.path.join(pathData, 'test_batch')]

        self.imagesTrain, self.labelsTrain = self._load_files(filenamesTrain)
        # Per-channel mean/std over all training pixels (last axis is the channel).
        self.mean = np.mean(self.imagesTrain, axis=(0, 1, 2))
        self.std = np.std(self.imagesTrain, axis=(0, 1, 2))
        self.imagesValid, self.labelsValid = self._load_files(filenameTest)

        self.numClasses = len(np.unique(self.labelsTrain))
        assert(self.numClasses == 10)
        self.shape = tuple(self.imagesTrain.shape[1:])

    @staticmethod
    def _load_files(filenames):
        """Concatenate pickled CIFAR batches into (images, labels).

        Each batch stores ``b'data'`` rows laid out channel-first (3x32x32);
        they are reshaped and transposed to (N, 32, 32, 3). The per-batch
        ``b'labels'`` lists are concatenated into one 1-D array.
        """
        images = []
        labels = []
        for filename in filenames:
            # NOTE: pickle can execute arbitrary code; only load trusted files.
            with open(filename, 'rb') as myFile:
                data = pickle.load(myFile, encoding='bytes')
            image = data[b'data'].reshape(-1, 3, 32, 32)
            images.append(image.transpose(0, 2, 3, 1))  # NCHW -> NHWC
            labels.append(np.array(data[b'labels']))
        return np.vstack(images), np.hstack(labels)

    def generate_dataset(self, images, labels, training=False):
        """Build a batched tf.data pipeline of normalized float32 images.

        Images are standardized with the training per-channel mean/std. When
        ``training`` is True the data is also shuffled, padded to
        (H + 8, W + 8) and randomly cropped back to (H, W), randomly flipped
        left-right, and prefetched.
        """
        # name is defined by tf.keras modules
        labelsDict = {'output_' + str(index_layer + 1): labels
                      for index_layer in range(self.num_layers)}
        dataset = tf.data.Dataset.from_tensor_slices((images, labelsDict))

        # float32 copies of the statistics, captured once by the map function.
        mean = np.asarray(self.mean, dtype=np.float32)
        std = np.asarray(self.std, dtype=np.float32)

        def normalize(image, label):
            # Cast first: TF arithmetic does not promote dtypes, so combining the
            # integer pixel tensor with float statistics directly would coerce
            # the statistics to the integer dtype instead of normalizing.
            image = (tf.cast(image, tf.float32) - mean) / std
            return image, label

        def augment(image, label):
            # Runs after normalize, so the zero padding equals the channel mean.
            image = tf.image.resize_with_crop_or_pad(image, self.shape[0] + 8, self.shape[1] + 8)
            image = tf.image.random_crop(image, self.shape)
            image = tf.image.random_flip_left_right(image)
            return image, label

        if training:
            dataset = dataset.shuffle(len(images))
            dataset = dataset.map(normalize, num_parallel_calls=tf.data.experimental.AUTOTUNE)
            dataset = dataset.map(augment, num_parallel_calls=tf.data.experimental.AUTOTUNE)
            dataset = dataset.batch(self.batch_size)
            dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        else:
            dataset = dataset.map(normalize)
            dataset = dataset.batch(self.batch_size)
        return dataset

    def get_train_set(self):
        """Return the shuffled, augmented training pipeline."""
        return self.generate_dataset(self.imagesTrain, self.labelsTrain, training=True)

    def get_valid_set(self):
        """Return the evaluation pipeline built from ``test_batch``."""
        return self.generate_dataset(self.imagesValid, self.labelsValid)

    def get_num_classes(self):
        return self.numClasses

    def get_shape(self):
        """Return the per-image shape (height, width, channels)."""
        return self.shape

    def get_len_train(self):
        return len(self.imagesTrain)

    def get_len_valid(self):
        return len(self.imagesValid)

class Cityscapes:
    """Cityscapes semantic segmentation, streamed from PNG files on disk.

    Images are read from ``<pathData>/leftImg8bit/<split>/<city>/*_leftImg8bit.png``
    and labels from the matching
    ``<pathData>/gtFine/<split>/<city>/*_gtFine_labelIds.png``. Raw label ids
    are remapped through the lookup table loaded from ``ignore_map_path``.
    Labels are replicated under ``output_1`` ... ``output_<num_layers>`` for
    multi-output tf.keras models.

    Sample counts are only known after the corresponding ``get_*_set`` method
    has built its pipeline.
    """

    def __init__(self, pathData,
            num_layers=1, batch_size=1, crop=[512, 1024], merge_train_valid=False,
            ignore_map_path='data/cityscapes_ignore.npy'):
        self.pathData = pathData
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.crop = crop  # (height, width), see get_shape
        self.merge_train_valid = merge_train_valid
        self.numSamplesTest = None
        self.numSamplesValid = None
        self.numSamplesTrain = None
        self.numClasses = 19

        # Lookup table indexed by raw label id (applied in read_images_labels).
        # The default path is relative to the working directory. TODO: better use yaml file
        self.ignore_map = tf.convert_to_tensor(np.load(ignore_map_path), dtype=tf.int32)

    def generate_dataset(self, split='val', training=False):
        """Build a batched tf.data pipeline for one or more splits.

        Args:
            split: name of a ``leftImg8bit`` sub-directory ('train', 'val',
                'test'), or a list of such names whose files are combined.
            training: if True, shuffle, flip-augment, drop the last incomplete
                batch and prefetch.

        Returns:
            (dataset, numSamples): the batched pipeline and the number of image
            files matched.
        """
        autotune = tf.data.experimental.AUTOTUNE

        def read_images_labels(imagePath):
            # .../leftImg8bit/<split>/<city>/<stem>_leftImg8bit.png
            #   -> <pathData>/gtFine/<split>/<city>/<stem>_gtFine_labelIds.png
            splittedPath = tf.strings.split(imagePath, '/')
            labelPath = tf.strings.join([self.pathData, 'gtFine', splittedPath[-3], splittedPath[-2],
                    tf.strings.split(splittedPath[-1], '_leftImg8bit.png')[-2] + '_gtFine_labelIds.png'], separator='/')
            image = tf.cast(tf.image.decode_png(tf.io.read_file(imagePath), channels=3), tf.float32)
            label = tf.cast(tf.image.decode_png(tf.io.read_file(labelPath), channels=1), tf.int32)
            label = tf.gather(self.ignore_map, label)
            return image, label

        def normalize(image, label):
            # Maps pixel values 0..255 to -1..1.
            return image / 127.5 - 1.0, label

        def resize_train(image, label):
            # random_resize(image, label, self.crop) is an imported alternative.
            return resize(image, label, self.crop)

        def resize_valid(image, label):
            return resize(image, label, self.crop)

        def augment(image, label):
            return flip(image, label)

        def duplicate_label(image, label):
            # name is defined by tf.keras modules
            labels = {'output_' + str(index_layer + 1): label
                      for index_layer in range(self.num_layers)}
            return image, labels

        splits = [split] if isinstance(split, str) else list(split)
        pathImages = [self.pathData + '/leftImg8bit/' + s + '/*/*_leftImg8bit.png' for s in splits]
        # list_files shuffles by default; keep a deterministic order so that
        # evaluation/test outputs are reproducible, and shuffle explicitly in
        # the training branch instead.
        dataset = tf.data.Dataset.list_files(pathImages, shuffle=False)
        numSamples = len(list(dataset))
        if training:
            dataset = dataset.shuffle(numSamples)
            dataset = dataset.map(read_images_labels, num_parallel_calls=autotune)
            dataset = dataset.map(resize_train, num_parallel_calls=autotune)
            dataset = dataset.map(normalize, num_parallel_calls=autotune)
            dataset = dataset.map(augment, num_parallel_calls=autotune)
            dataset = dataset.map(duplicate_label, num_parallel_calls=autotune)
            dataset = dataset.batch(self.batch_size, drop_remainder=True)
            dataset = dataset.prefetch(buffer_size=autotune)
        else:
            # Parallel map keeps element order (deterministic by default).
            dataset = dataset.map(read_images_labels, num_parallel_calls=autotune)
            dataset = dataset.map(resize_valid)
            dataset = dataset.map(normalize)
            dataset = dataset.map(duplicate_label)
            dataset = dataset.batch(self.batch_size)

        return dataset, numSamples

    def get_train_set(self):
        """Return the training pipeline (train split, plus val when merged).

        With ``merge_train_valid`` both splits are shuffled jointly and
        augmented like the training data.
        """
        if self.merge_train_valid:
            dataset, self.numSamplesTrain = self.generate_dataset(['train', 'val'], training=True)
            assert(self.numSamplesTrain == 3475)
        else:
            dataset, self.numSamplesTrain = self.generate_dataset('train', training=True)
            assert(self.numSamplesTrain == 2975)
        return dataset

    def get_valid_set(self):
        """Return the validation pipeline, or None when val is merged into train."""
        if self.merge_train_valid:
            return None
        dataset, self.numSamplesValid = self.generate_dataset('val')
        assert(self.numSamplesValid == 500)
        return dataset

    def get_test_set(self):
        dataset, self.numSamplesTest = self.generate_dataset('test')
        assert(self.numSamplesTest == 1525)
        return dataset

    def get_num_classes(self):
        return self.numClasses

    def get_shape(self):
        """Return the network input shape (height, width, 3)."""
        return tuple(self.crop) + (3,)

    def get_len_train(self):
        if self.numSamplesTrain is None:
            raise RuntimeError('run self.get_train_set() first')
        return self.numSamplesTrain

    def get_len_valid(self):
        if self.numSamplesValid is None:
            raise RuntimeError('run self.get_valid_set() first')
        return self.numSamplesValid

    def get_len_test(self):
        if self.numSamplesTest is None:
            raise RuntimeError('run self.get_test_set() first')
        return self.numSamplesTest

class Camvid:
    """CamVid semantic segmentation held in memory, loaded from NumPy archives.

    Every archive must expose an ``images`` and a ``labels`` array. Label
    value 255 is excluded when counting classes. Masks are replicated under
    ``output_1`` ... ``output_<num_layers>`` for multi-output tf.keras models.
    """

    def __init__(self, pathTrain, pathValid, pathTest='',
            num_layers=1, batch_size=1, crop=[352, 480], merge_train_valid=False, crop_method='random'):
        """Load the splits and print their shapes.

        Args:
            pathTrain, pathValid: archives for the train / validation splits.
            pathTest: optional archive for the test split. When empty, no test
                split is loaded, ``get_test_set()`` returns None and
                ``get_len_test()`` returns 0.
            crop: (height, width) of the crops produced by the pipelines.
            merge_train_valid: append the validation split to the training
                split; the validation accessors then return None / 0.
            crop_method: 'random' or 'center' crop for the training pipeline.
        """
        self.batch_size = batch_size
        self.num_layers = num_layers
        self.crop = crop
        self.merge_train_valid = merge_train_valid
        self.cropMethod = crop_method

        self.imagesTrain, self.masksTrain = self._load_split(pathTrain)
        self.imagesValid, self.masksValid = self._load_split(pathValid)
        if pathTest:
            self.imagesTest, self.masksTest = self._load_split(pathTest)
        else:
            self.imagesTest, self.masksTest = None, None
        # Largest non-ignore label + 1.
        self.numClasses = np.max(self.masksTrain[self.masksTrain != 255]) + 1
        assert(self.numClasses == 11)
        if merge_train_valid:
            self.imagesTrain = np.concatenate((self.imagesTrain, self.imagesValid))
            self.masksTrain = np.concatenate((self.masksTrain, self.masksValid))
            self.imagesValid = []
            self.masksValid = []

        print('shape of dataset')
        if merge_train_valid:
            print('  train+valid:', self.imagesTrain.shape, self.masksTrain.shape)
        else:
            print('  train:', self.imagesTrain.shape, self.masksTrain.shape)
            print('  valid:', self.imagesValid.shape, self.masksValid.shape)
        if self.imagesTest is not None:
            print('  test:', self.imagesTest.shape, self.masksTest.shape)
        print('  classes:', self.numClasses)

    @staticmethod
    def _load_split(path):
        """Return the ``images`` and ``labels`` arrays stored at ``path``."""
        data = np.load(path)
        try:
            return data['images'], data['labels']
        finally:
            # For .npz archives np.load returns an NpzFile that keeps the file
            # open until closed; plain arrays have no close().
            close = getattr(data, 'close', None)
            if close is not None:
                close()

    def generate_dataset(self, images, masks, training=False):
        """Build a batched tf.data pipeline of (image, {output_i: mask}) pairs.

        Images are cast to float32 and divided by 255; masks are cast to
        int32. The training pipeline shuffles, crops with ``self.cropMethod``,
        flips image and mask together at random, drops the last incomplete
        batch and prefetches. The evaluation pipeline only center-crops.
        Cropping and flipping run once per sample, and the mask is replicated
        for each output head only at the end, so a single copy of the masks is
        embedded in the dataset.
        """
        autotune = tf.data.experimental.AUTOTUNE

        def normalize(image, mask):
            return tf.cast(image, tf.float32) / 255.0, mask

        def crop(image, mask, method=''):
            height, width = self.crop[0], self.crop[1]
            if method == 'random':
                offset_height = tf.random.uniform(shape=[], minval=0, maxval=tf.shape(image)[0] - height + 1, dtype=tf.int32)
                offset_width = tf.random.uniform(shape=[], minval=0, maxval=tf.shape(image)[1] - width + 1, dtype=tf.int32)
                image = tf.image.crop_to_bounding_box(image, offset_height, offset_width, height, width)
                mask = tf.image.crop_to_bounding_box(mask, offset_height, offset_width, height, width)
            elif method == 'center':
                image = tf.image.resize_with_crop_or_pad(image, height, width)
                mask = tf.image.resize_with_crop_or_pad(mask, height, width)
            else:
                # Includes 'resize' (random rescale): unsupported because the
                # zero padding of resize_with_crop_or_pad would not fill masks
                # with the ignore value 255.
                raise NotImplementedError
            return image, mask

        def crop_train(image, mask):
            return crop(image, mask, method=self.cropMethod)

        def crop_valid(image, mask):
            return crop(image, mask, method='center')

        def augment(image, mask):
            # One coin flip shared by image and mask keeps them aligned.
            do_flip = tf.random.uniform(shape=[]) < 0.5
            image = tf.cond(do_flip, lambda: tf.image.flip_left_right(image), lambda: image)
            mask = tf.cond(do_flip, lambda: tf.image.flip_left_right(mask), lambda: mask)
            return image, mask

        def duplicate_mask(image, mask):
            # name is defined by tf.keras modules
            labels = {'output_' + str(index_layer + 1): mask
                      for index_layer in range(self.num_layers)}
            return image, labels

        dataset = tf.data.Dataset.from_tensor_slices((images, masks.astype(np.int32)))
        if training:
            dataset = dataset.shuffle(len(images))
            dataset = dataset.map(crop_train, num_parallel_calls=autotune)
            dataset = dataset.map(normalize, num_parallel_calls=autotune)
            dataset = dataset.map(augment, num_parallel_calls=autotune)
            dataset = dataset.map(duplicate_mask, num_parallel_calls=autotune)
            dataset = dataset.batch(self.batch_size, drop_remainder=True)
            dataset = dataset.prefetch(buffer_size=autotune)
        else:
            dataset = dataset.map(crop_valid)
            dataset = dataset.map(normalize)
            dataset = dataset.map(duplicate_mask)
            dataset = dataset.batch(self.batch_size)

        return dataset

    def get_train_set(self):
        return self.generate_dataset(self.imagesTrain, self.masksTrain, training=True)

    def get_valid_set(self):
        """Return the validation pipeline, or None when val is merged into train."""
        if self.merge_train_valid:
            return None
        return self.generate_dataset(self.imagesValid, self.masksValid)

    def get_test_set(self):
        """Return the test pipeline, or None when no test archive was given."""
        if self.imagesTest is None:
            return None
        return self.generate_dataset(self.imagesTest, self.masksTest)

    def get_num_classes(self):
        return self.numClasses

    def get_shape(self):
        """Return the network input shape (crop height, crop width, channels)."""
        return tuple(self.crop) + (self.imagesTrain.shape[3],)

    def get_len_train(self):
        return len(self.imagesTrain)

    def get_len_valid(self):
        return len(self.imagesValid)

    def get_len_test(self):
        if self.imagesTest is None:
            return 0
        return len(self.imagesTest)
