# coding=utf-8
# Copyright 2022 The Mixed Fl Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data utilities for working with Federated CelebA dataset."""

import tensorflow as tf

# Batch size `preprocess_img_dataset` uses when the caller does not pass one.
DEFAULT_BATCH_SIZE = 32
# Buffer size `preprocess_img_dataset` passes to `Dataset.shuffle`.
SHUFFLE_BUFFER = 10000


def preprocess_img_dataset(images_ds,
                           label_attribute,
                           batch_size=DEFAULT_BATCH_SIZE,
                           shuffle=True,
                           num_epochs=1):
  """Returns a preprocessed dataset, making the designated attribute the label.

  The preprocessing casts the raw image to float32 and divides it by 255, so
  pixel values in [0, 255] end up in [0.0, 1.0]. It also makes the attribute
  designated by `label_attribute` the label (cast to float32, with a trailing
  axis of size 1 appended), and removes all other attribute information.

  Args:
    images_ds: Raw CelebA dataset of OrderedDict elements (i.e., in the format
      described at
      https://www.tensorflow.org/federated/api_docs/python/tff/simulation/datasets/celeba/load_data)
      to be processed.
    label_attribute: The attribute to use as the label. Must be one of the
      attribute keys in the example's OrderedDict, and must not be the
      `'image'` key itself.
    batch_size: Batch size of output dataset. If None, don't batch. The final
      batch of each epoch may be smaller than `batch_size`.
    shuffle: If True, shuffle the dataset (fixed seed, reshuffled each epoch).
    num_epochs: The number of epochs to repeat the raw dataset in the processed
      dataset.

  Returns:
    A preprocessed dataset of `(image, label)` tuples, batched unless
      `batch_size` is None, and possibly shuffled/repeated.

  Raises:
    ValueError: If the `label_attribute` provided is not a valid attribute key
      in the CelebA dataset, or if it names the image feature itself.
  """
  image_key = 'image'

  if label_attribute not in list(images_ds.element_spec.keys()):
    # Format with a 1-tuple so that a tuple-valued `label_attribute` still
    # yields the intended ValueError rather than a string-formatting TypeError.
    raise ValueError(
        'The `label_attribute` specified (%s) is not a valid attribute key in '
        'the CelebA dataset. ' % (label_attribute,))
  if label_attribute == image_key:
    # The image is the model input; using it as the label would silently
    # produce (image, image) pairs instead of (image, binary attribute) pairs.
    raise ValueError(
        'The `label_attribute` specified (%s) is the image feature itself, '
        'not a label attribute of the CelebA dataset.' % (label_attribute,))

  # Prevent multiple reads across the wire in the case of multiple epochs.
  if num_epochs > 1:
    images_ds = images_ds.cache()

  @tf.function
  def _preprocess(example):
    # Scale pixels from [0, 255] to [0.0, 1.0].
    image = tf.cast(example[image_key], tf.float32) / 255.0
    # Append a trailing axis so each label has shape [..., 1].
    label = tf.expand_dims(
        tf.cast(example[label_attribute], tf.float32), axis=-1)
    return image, label

  # Element order is preserved: `map` is deterministic by default, so running
  # it in parallel only affects throughput.
  images_ds = images_ds.map(_preprocess, num_parallel_calls=tf.data.AUTOTUNE)

  if shuffle:
    images_ds = images_ds.shuffle(
        buffer_size=SHUFFLE_BUFFER, reshuffle_each_iteration=True, seed=124578)
  if batch_size is not None:
    images_ds = images_ds.batch(batch_size, drop_remainder=False)
  # Shuffle and batch come before repeat so we shuffle and batch within each
  # epoch, but process complete epochs before repeating.
  images_ds = images_ds.repeat(num_epochs)

  # Note: .prefetch is an optimization which will begin preparing later elements
  # while current elements are being processed. It consumes more memory but
  # should save time. See documentation at:
  # https://www.tensorflow.org/api_docs/python/tf/data/Dataset#prefetch
  return images_ds.prefetch(tf.data.AUTOTUNE)
