# coding=utf-8
# Copyright 2022 The Mixed Fl Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data utilities for working with Federated EMNIST dataset."""

import tensorflow as tf

# Default value of the `batch_size` argument of `preprocess_img_dataset`.
DEFAULT_BATCH_SIZE = 32
# Buffer size passed to `Dataset.shuffle` when shuffling is enabled.
SHUFFLE_BUFFER = 10000


def preprocess_img_dataset(images_ds,
                           batch_size = DEFAULT_BATCH_SIZE,
                           shuffle = True,
                           num_epochs = 1):
  """Builds the input pipeline for a raw EMNIST dataset.

  Every raw element is turned into an `(image, label)` tuple:

  * Image: pixel values are rescaled from [0.0, 1.0] (background at 1.0) to
    [-1.0, 1.0] (background at -1.0), and a trailing axis of size 1 is added.
  * Label: the raw range [0, 61] (digits in [0, 9], uppercase letters in
    [10, 35], lowercase letters in [36, 61]) is collapsed to [0, 35] by
    shifting every label >= 36 down by 26, which makes the letter labels case
    agnostic.

  The stages are applied in this order: cache (only when `num_epochs > 1`),
  map, shuffle, batch, repeat, prefetch.

  Args:
    images_ds: The raw EMNIST dataset of OrderedDict elements with 'pixels'
      and 'label' entries, i.e., in the format described at
      https://www.tensorflow.org/federated/api_docs/python/tff/simulation/datasets/emnist/load_data
    batch_size: Batch size of the output dataset. If None, the output is left
      unbatched. The final batch of an epoch may be smaller than `batch_size`.
    shuffle: If True, shuffle the elements (with a fixed seed, reshuffling on
      every iteration) before batching.
    num_epochs: The number of times the raw dataset is repeated in the output
      dataset.

  Returns:
    A prefetched dataset of `(image, label)` tuples, batched unless
    `batch_size` is None, shuffled if requested, and repeated `num_epochs`
    times.
  """

  def _to_image_and_label(example):
    """Converts one raw example into a rescaled image and a folded label."""
    # Add the trailing axis first, then flip and stretch the value range so
    # that 1.0 maps to -1.0 and 0.0 maps to 1.0.
    pixels = tf.expand_dims(example['pixels'], 2)
    pixels = -2.0 * (pixels - 0.5)
    # Labels in [36, 61] are shifted onto [10, 35]; labels below 36 are left
    # untouched. The result is a label set in [0, 35].
    raw_label = example['label']
    folded_label = tf.where(raw_label >= 36, raw_label - 26, raw_label)
    return pixels, folded_label

  # Caching the raw data keeps later epochs from re-reading it over the wire;
  # with a single epoch there is nothing to reuse, so the cache is skipped.
  dataset = images_ds.cache() if num_epochs > 1 else images_ds
  dataset = dataset.map(_to_image_and_label)

  # Shuffling and batching happen before `repeat`, so they operate within an
  # epoch and every epoch is consumed in full before the next one starts.
  if shuffle:
    dataset = dataset.shuffle(
        buffer_size=SHUFFLE_BUFFER, reshuffle_each_iteration=True, seed=124578)
  if batch_size is not None:
    dataset = dataset.batch(batch_size, drop_remainder=False)
  dataset = dataset.repeat(num_epochs)

  # Prefetching prepares upcoming elements while the current ones are being
  # consumed, trading some memory for time. See:
  # https://www.tensorflow.org/api_docs/python/tf/data/Dataset#prefetch
  return dataset.prefetch(tf.data.experimental.AUTOTUNE)
