"""Useful transformations for datasets."""

import colorsys

import matplotlib
matplotlib.use('agg')  # Switch backend
import seaborn as sns  # pylint: disable=wrong-import-position
import tensorflow as tf  # pylint: disable=wrong-import-position
import tensorflow_datasets as tfds  # pylint: disable=wrong-import-position


class UpdateMNISTColors:
    """Adds color to MNIST symbols.

    Non-zero (foreground) pixels of each single-channel image are painted with
    one of ``num_colors`` palette colors and zero (background) pixels become
    0 in all three output channels. The color index is a deterministic
    function of the image's pixel values (see ``transformation_fn``).

    Args:
        num_colors: Number of palette colors; also the number of classes of
            the ``'label_color'`` feature. Must be at least 1.

    Raises:
        ValueError: If ``num_colors`` is smaller than 1.
    """

    def __init__(self, num_colors=10):
        if num_colors < 1:
            # Without at least one color the traced function would later fail
            # on a modulo-by-zero; fail early with a clear message instead.
            raise ValueError(
                f'num_colors must be at least 1, got {num_colors}.')
        self._num_colors = num_colors
        palette = sns.color_palette('hls', num_colors)
        # NOTE(review): seaborn documents `color_palette` as returning RGB
        # tuples, so feeding them to `colorsys.hls_to_rgb` re-interprets RGB
        # values as HLS and yields a different color set than the plain 'hls'
        # palette. Kept unchanged so existing generated data keeps the same
        # colors — confirm whether this double conversion was intended.
        palette = [colorsys.hls_to_rgb(*x) for x in palette]
        # Scale each [0, 1] channel to 0..255, truncating to an integer.
        palette = [(int(255 * r), int(255 * g), int(255 * b))
                   for r, g, b in palette]
        # Shape (num_colors, 3), one uint8 RGB row per color.
        self._palette = tf.constant(palette, dtype=tf.uint8)

    def feature_dict(self):
        """Returns the extra TFDS feature produced by this transformation.

        Returns:
            A dict mapping ``'label_color'`` to a ``tfds.features.ClassLabel``
            with ``num_colors`` classes, matching the ``'label_color'`` entry
            emitted by the function from ``transformation_fn``.
        """
        return {'label_color': tfds.features.ClassLabel(
            num_classes=self._num_colors)}

    def transformation_fn(self):
        """Builds a ``tf.function`` that colors a single MNIST datapoint.

        The returned function takes a dict with ``'image'`` and ``'label'``
        entries. The image is assumed to be a ``(28, 28, 1)`` single-channel
        tensor: the hash key below is built with that shape and the foreground
        mask is tiled to 3 channels. It returns a dict with:

        * ``'image'``: uint8 image with 3 channels; every channel of a non-zero
          input pixel holds the chosen palette color, zero pixels stay 0.
        * ``'label'``: the input label cast to ``tf.int32``.
        * ``'label_color'``: ``tf.int32`` palette index in
          ``[0, num_colors)``.

        The palette index is computed by projecting the pixel values onto a
        fixed-seed Gaussian key, scaling, truncating to an integer, taking the
        absolute value and reducing modulo ``num_colors``.

        Returns:
            The coloring function, suitable for ``tf.data.Dataset.map``.
        """
        # Fixed seed: every call builds the same key, so a given image always
        # maps to the same color.
        hash_key = tf.random.Generator.from_seed(1234).normal(
            shape=(28, 28, 1), dtype=tf.float64)

        @tf.function
        # pylint: disable=unexpected-keyword-arg, no-value-for-parameter
        def color_fn(datapoint):
            image = datapoint['image']
            # 1 on the drawn digit (non-zero pixel), 0 on the background,
            # replicated across the three RGB channels.
            foreground = tf.cast(tf.not_equal(image, 0), dtype=tf.uint8)
            foreground = tf.tile(foreground, [1, 1, 3])

            # Project the image onto the fixed random key. Scaling by 1e8
            # before truncating to int64 lets the fractional digits of the
            # projection influence the resulting bucket.
            projection = tf.reduce_sum(
                hash_key * tf.cast(image, dtype=tf.float64))
            key = tf.math.abs(tf.cast(1e8 * projection, dtype=tf.int64))
            key = tf.math.floormod(key, self._num_colors)
            key = tf.cast(key, dtype=tf.int32)

            # Select the palette row for this key (uint8, shape (3,)) and
            # broadcast it over the foreground mask.
            color = tf.gather(self._palette, key)
            img = foreground * color

            label = tf.cast(datapoint['label'], dtype=tf.int32)
            return {'image': img, 'label': label, 'label_color': key}

        return color_fn
