"""Dataset transformations."""

import gin
import tensorflow as tf


def split_image(image3, tile_size):
    """Split an image into equally sized, non-overlapping tiles.

    Tiles are emitted column strip by column strip: first every tile of the
    left-most strip (``tile_size[1]`` columns wide), top to bottom, then the
    tiles of the next strip, and so on. ``unsplit_image`` inverts this.

    Args:
      image3: rank-3 tensor laid out as [rows, cols, channels].
      tile_size: (tile_rows, tile_cols). The image's rows must be a multiple
        of ``tile_size[0]`` and its cols a multiple of ``tile_size[1]``. This
        is not validated: if only the total size divides evenly, the reshape
        still succeeds but tiles straddle neighbouring strips.

    Returns:
      A rank-4 tensor of shape [num_tiles, tile_rows, tile_cols, channels].
    """
    image_shape = tf.shape(image3)
    # [R, C, ch] -> [R, C / tile_cols, tile_cols, ch]: cut every row into
    # tile_cols-wide pieces.
    tile_rows = tf.reshape(
        image3, [image_shape[0], -1, tile_size[1], image_shape[2]])
    # -> [C / tile_cols, R, tile_cols, ch]: one full-height strip per entry.
    column_strips = tf.transpose(tile_rows, [1, 0, 2, 3])
    # Each strip is row-major with rows of tile_cols pixels, so a contiguous
    # run of tile_rows rows is one tile. The axes must therefore be
    # (tile_rows, tile_cols); swapping them garbles non-square tiles.
    return tf.reshape(
        column_strips, [-1, tile_size[0], tile_size[1], image_shape[2]])


def unsplit_image(tiles4, image_shape):
    """Reassemble tiles produced by ``split_image`` into one image.

    Args:
      tiles4: rank-4 tensor [num_tiles, tile_rows, tile_cols, channels],
        ordered as ``split_image`` emits them (column strip by column strip,
        top to bottom within each strip).
      image_shape: [rows, cols, channels] of the image to rebuild, e.g.
        ``tf.shape(original_image)``.

    Returns:
      A rank-3 tensor of shape ``image_shape``.
    """
    # The tile width lives on axis 2 of [tiles, tile_rows, tile_cols, ch];
    # axis 1 is the tile height and would mis-slice non-square tiles.
    tile_width = tf.shape(tiles4)[2]
    # -> [C / tile_cols, R, tile_cols, ch]: stack each strip's tiles.
    column_strips = tf.reshape(
        tiles4, [-1, image_shape[0], tile_width, image_shape[2]])
    # -> [R, C / tile_cols, tile_cols, ch], then merge the two column axes.
    rowwise_tiles = tf.transpose(column_strips, [1, 0, 2, 3])
    return tf.reshape(rowwise_tiles,
                      [image_shape[0], image_shape[1], image_shape[2]])


@gin.configurable
def scramble_image(image, tile_size, seed=0):
    """Return ``image`` with its tiles randomly permuted.

    The image is cut into tiles with ``split_image``, the tiles are shuffled
    along their leading axis by ``tf.random.shuffle`` and stitched back with
    ``unsplit_image`` into a tensor shaped like ``tf.shape(image)``.

    Args:
      image: rank-3 image tensor [rows, cols, channels].
      tile_size: (tile_rows, tile_cols), forwarded to ``split_image``.
      seed: operation-level seed handed to ``tf.random.shuffle``.
        NOTE(review): an op-level seed alone does not make repeated eager
        calls reuse one permutation; confirm that is intended.

    Returns:
      The scrambled image.
    """
    shuffled_tiles = tf.random.shuffle(
        split_image(image, tile_size), seed=seed)
    return unsplit_image(shuffled_tiles, tf.shape(image))


def _feature_list_specs(features_list):
    """Describe the two features enumerated by ``features_list``.

    Args:
      features_list: non-empty sequence of dicts, each mapping the same two
        feature names to a hashable value. Feature order follows the key
        order of the first dict.

    Returns:
      ``(f0_name, f0_len, f1_name, f1_len)``, where ``f*_len`` is the number
      of distinct values that feature takes across ``features_list``.
      ``_encode``/``_decode`` use these counts as radices.
      NOTE(review): that presumes each feature's values are 0..len-1.

    Raises:
      ValueError: if ``features_list`` is empty.
      NotImplementedError: if the dicts do not hold exactly two features.
    """
    if not features_list:
        raise ValueError('features_list must not be empty')
    features_names = list(features_list[0].keys())
    # An explicit raise rather than ``assert``: asserts vanish under
    # ``python -O``. Without this check, one feature fails with an unrelated
    # IndexError and extra features are silently ignored by the encoding.
    if len(features_names) != 2:
        raise NotImplementedError('Implement first')

    f0_name, f1_name = features_names
    f0_len = len({feature_spec[f0_name] for feature_spec in features_list})
    f1_len = len({feature_spec[f1_name] for feature_spec in features_list})
    return f0_name, f0_len, f1_name, f1_len


def _encode(features, features_list):
    """Pack the two features of ``features`` into one joint integer code.

    The code is ``f0 + f0_len * f1``: feature 0 is the low digit (radix
    ``f0_len``, its distinct-value count in ``features_list``) and feature 1
    the high digit. Only ``+`` and ``*`` are applied, so this works
    elementwise on tensors as well as on plain numbers.
    NOTE(review): the code is unique only if feature 0's values lie in
    0..f0_len-1.

    Args:
      features: mapping that contains both feature names.
      features_list: see ``_feature_list_specs``.

    Returns:
      The joint code.
    """
    low_name, low_radix, high_name, _unused_high_radix = (
        _feature_list_specs(features_list))
    low_digit = features[low_name]
    high_digit = features[high_name]
    return low_digit + low_radix * high_digit


def _decode(features_code, features_list):
    """Invert ``_encode``: split joint codes back into the two features.

    ``_encode`` builds ``code = f0 + f0_len * f1``, a two-digit mixed-radix
    number whose low digit is feature 0 (base ``f0_len``) and whose high
    digit is feature 1. Hence::

        f0 = code mod f0_len
        f1 = (code div f0_len) mod f1_len

    The floor-based ops keep both digits non-negative and wrap codes outside
    ``[0, f0_len * f1_len)`` (e.g. after adding a shift) back into range.

    Args:
      features_code: integer tensor (or tensor-convertible value) of codes.
      features_list: see ``_feature_list_specs``.

    Returns:
      Dict mapping the two feature names to their decoded values.
    """
    f0_name, f0_len, f1_name, f1_len = _feature_list_specs(features_list)
    # Low digit -> feature 0, high digit -> feature 1. The reverse assignment
    # swaps the features, so _decode(_encode(x)) would not equal x.
    feature0 = tf.math.floormod(features_code, f0_len)
    feature1 = tf.math.floormod(
        tf.math.floordiv(features_code, f0_len), f1_len)
    return {f0_name: feature0, f1_name: feature1}


@gin.configurable
def features_shift(features, features_list, shift):
    """Offset the joint feature code of ``features`` by ``shift``.

    The two features are packed into one integer with ``_encode``, ``shift``
    is added, and the sum is unpacked with ``_decode``. ``_decode``'s floor
    modulo wraps the shifted code back into range.

    Args:
      features: mapping holding both feature names from ``features_list``.
      features_list: see ``_feature_list_specs``.
      shift: amount added to the joint code (presumably an integer).

    Returns:
      Dict mapping the two feature names to their new values.
    """
    joint_code = _encode(features, features_list)
    shifted_code = joint_code + shift
    return _decode(shifted_code, features_list)


@gin.configurable
def features_scramble(features, features_list, seed=0):
    """Remap every (feature0, feature1) combination through a random bijection.

    All ``f0_len * f1_len`` joint codes (see ``_encode``) are shuffled with
    ``tf.random.shuffle`` into a ``StaticHashTable``. The code of
    ``features`` is then looked up and decoded again with ``_decode``.

    Args:
      features: mapping holding both feature names from ``features_list``.
        The values are presumably integer tensors, because
        ``StaticHashTable.lookup`` reads the key's ``dtype``.
      features_list: see ``_feature_list_specs``.
      seed: operation-level seed for the permutation.

    Returns:
      Dict mapping the two feature names to their scrambled values.
    """
    _, f0_len, _, f1_len = _feature_list_specs(features_list)
    features_code = _encode(features, features_list)
    # Build the table in the code's own dtype. tf.range defaults to int32,
    # and lookup() raises TypeError on a key dtype mismatch (e.g. int64
    # labels).
    keys = tf.range(f0_len * f1_len, dtype=features_code.dtype)
    vals = tf.random.shuffle(keys, seed=seed)
    # Codes missing from the table (feature values outside 0..len-1) map to
    # the default -1.
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(keys, vals), -1)
    new_features_code = table.lookup(features_code)
    return _decode(new_features_code, features_list)
