# coding=utf-8
# Copyright 2023 The Uncertainty Baselines Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provides utilities to preprocess images for the Inception networks."""

import tensorflow as tf

# Default output height/width used by the `preprocess_*` functions below.
IMAGE_SIZE = 224
# Padding term for the center crop: the crop side is
# image_size / (image_size + CROP_PADDING) of the shorter image side.
CROP_PADDING = 32


def distorted_bounding_box_crop(image_bytes,
                                bbox,
                                seed,
                                min_object_covered=0.1,
                                aspect_ratio_range=(0.75, 1.33),
                                area_range=(0.05, 1.0),
                                max_attempts=100):
  """Samples a randomly distorted crop window and returns the cropped image.

  The window is drawn by `tf.image.stateless_sample_distorted_bounding_box`;
  see that op for more documentation.

  Args:
    image_bytes: Either a `tf.string` `Tensor` holding an encoded JPEG, or a
        `Tensor` of any other dtype, which is treated as an already decoded
        image.
    bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
        where each coordinate is [0, 1) and the coordinates are arranged
        as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
        image.
    seed: the random seed forwarded to the stateless sampling op.
    min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
        area of the image must contain at least this fraction of any bounding
        box supplied.
    aspect_ratio_range: An optional list of `float`s. The cropped area of the
        image must have an aspect ratio = width / height within this range.
    area_range: An optional list of `float`s. The cropped area of the image
        must contain a fraction of the supplied image within this range.
    max_attempts: An optional `int`. Number of attempts at generating a cropped
        region of the image of the specified constraints. After `max_attempts`
        failures, return the entire image.
  Returns:
    The cropped image `Tensor`. The sampled bounding box itself is not
    returned.
  """
  with tf.name_scope('distorted_bounding_box_crop'):
    # Anything that is not a string tensor is treated as a decoded image.
    is_decoded = image_bytes.dtype != tf.string
    if is_decoded:
      image_shape = tf.shape(image_bytes)
    else:
      # Read the shape from the JPEG header without decoding the pixels.
      image_shape = tf.image.extract_jpeg_shape(image_bytes)

    begin, size, _ = tf.image.stateless_sample_distorted_bounding_box(
        image_shape,
        bounding_boxes=bbox,
        seed=seed,
        min_object_covered=min_object_covered,
        aspect_ratio_range=aspect_ratio_range,
        area_range=area_range,
        max_attempts=max_attempts,
        use_image_if_no_bounding_boxes=True)

    # Both tensors are laid out as (y, x, channels); channels are not needed.
    top, left, _ = tf.unstack(begin)
    crop_height, crop_width, _ = tf.unstack(size)

    if not is_decoded:
      # Decode only the sampled window of the encoded JPEG.
      window = tf.stack([top, left, crop_height, crop_width])
      return tf.image.decode_and_crop_jpeg(image_bytes, window, channels=3)

    return tf.image.crop_to_bounding_box(
        image_bytes,
        offset_height=top,
        offset_width=left,
        target_height=crop_height,
        target_width=crop_width)


def _at_least_x_are_equal(a, b, x):
  """Checks that `a` and `b` agree elementwise in at least `x` positions.

  Args:
    a: First `Tensor` to compare.
    b: Second `Tensor` to compare.
    x: Minimum number of equal elements required.

  Returns:
    A boolean `Tensor`, true when the number of equal elements is >= `x`.
  """
  num_equal = tf.reduce_sum(tf.cast(tf.equal(a, b), tf.int32))
  return tf.greater_equal(num_equal, x)


def _resize_image(image, image_size, method=None):
  """Resizes a single image to `image_size` x `image_size`.

  Args:
    image: A single image `Tensor`. It is wrapped in a one-element batch for
      `tf.image.resize` and unwrapped again afterwards.
    image_size: Target height and width.
    method: Resize method handed to `tf.image.resize`. If None, use bicubic.

  Returns:
    The resized image `Tensor`.
  """
  chosen_method = 'bicubic' if method is None else method
  target_size = [image_size, image_size]
  resized_batch = tf.image.resize([image], target_size, method=chosen_method)
  return resized_batch[0]


def _decode_and_random_crop(image_bytes, image_size, seed, resize_method=None):
  """Makes a random crop of the image and resizes it to `image_size`.

  A crop window is sampled with `distorted_bounding_box_crop`. If the sampled
  crop has the same shape as the full image, the padded center crop is used
  instead.

  Args:
    image_bytes: `Tensor` representing an image binary of arbitrary size, or
      an already decoded image.
    image_size: Image height/width dimension of the output.
    seed: the random seed forwarded to the crop sampling.
    resize_method: Resize method. If None, use bicubic.

  Returns:
    A decoded, cropped and resized image `Tensor`.
  """
  # A single box covering the whole image.
  bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
  image = distorted_bounding_box_crop(
      image_bytes,
      bbox,
      seed=seed,
      min_object_covered=0.1,
      aspect_ratio_range=(3. / 4, 4. / 3.),
      area_range=(0.08, 1.0),
      max_attempts=10)
  decoded = image_bytes.dtype != tf.string
  original_shape = (tf.shape(image_bytes) if decoded
                    else tf.image.extract_jpeg_shape(image_bytes))
  # All three dimensions matching means the "crop" is the entire image.
  bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3)

  image = tf.cond(
      bad,
      # Forward `resize_method` so both branches resize with the same method.
      lambda: _decode_and_center_crop(image_bytes, image_size, resize_method),
      lambda: _resize_image(image, image_size, resize_method))

  return image


def _decode_and_center_crop(image_bytes, image_size, resize_method=None):
  """Takes a padded, centered square crop and resizes it to `image_size`.

  The crop side is `image_size / (image_size + CROP_PADDING)` times the shorter
  side of the input image, truncated to an integer.

  Args:
    image_bytes: `Tensor` representing an image binary of arbitrary size, or
      an already decoded image.
    image_size: Image height/width dimension.
    resize_method: Resize method. If None, use bicubic.

  Returns:
    A decoded and cropped image Tensor.
  """
  # Anything that is not a string tensor is treated as a decoded image.
  is_decoded = image_bytes.dtype != tf.string
  if is_decoded:
    image_shape = tf.shape(image_bytes)
  else:
    image_shape = tf.image.extract_jpeg_shape(image_bytes)
  height, width = image_shape[0], image_shape[1]

  crop_fraction = image_size / (image_size + CROP_PADDING)
  shorter_side = tf.cast(tf.minimum(height, width), tf.float32)
  crop_size = tf.cast(crop_fraction * shorter_side, tf.int32)

  # The +1 rounds the centering offset up when the margin is odd.
  top = ((height - crop_size) + 1) // 2
  left = ((width - crop_size) + 1) // 2

  if is_decoded:
    cropped = tf.image.crop_to_bounding_box(
        image_bytes,
        offset_height=top,
        offset_width=left,
        target_height=crop_size,
        target_width=crop_size)
  else:
    # Decode only the crop window of the encoded JPEG.
    window = tf.stack([top, left, crop_size, crop_size])
    cropped = tf.image.decode_and_crop_jpeg(image_bytes, window, channels=3)

  return _resize_image(cropped, image_size, resize_method)


def preprocess_for_train(image_bytes,
                         use_bfloat16,
                         seed,
                         image_size=IMAGE_SIZE,
                         resize_method=None):
  """Preprocesses the given image for training.

  Applies a random crop resized to `image_size`, then a random left/right
  flip, each driven by its own seed split off from `seed`.

  Args:
    image_bytes: `Tensor` representing an image binary of arbitrary size, or
      an already decoded image.
    use_bfloat16: `bool` for whether to use bfloat16.
    seed: the random seed to use.
    image_size: output image height/width.
    resize_method: resize method. If none, use bicubic.

  Returns:
    A preprocessed image `Tensor` reshaped to `[image_size, image_size, 3]`,
    converted to bfloat16 if `use_bfloat16` else float32.
  """
  # Derive independent seeds for the crop and the flip.
  split_seeds = tf.random.experimental.stateless_split(seed, num=2)
  crop_seed = split_seeds[0]
  flip_seed = split_seeds[1]

  cropped = _decode_and_random_crop(
      image_bytes, image_size, seed=crop_seed, resize_method=resize_method)
  flipped = tf.image.stateless_random_flip_left_right(cropped, seed=flip_seed)
  flipped = tf.reshape(flipped, [image_size, image_size, 3])

  output_dtype = tf.bfloat16 if use_bfloat16 else tf.float32
  return tf.image.convert_image_dtype(flipped, dtype=output_dtype)


def preprocess_for_eval(image_bytes,
                        use_bfloat16,
                        image_size=IMAGE_SIZE,
                        resize_method=None):
  """Preprocesses the given image for evaluation.

  Applies the padded center crop resized to `image_size`; no random operation
  is involved.

  Args:
    image_bytes: `Tensor` representing an image binary of arbitrary size, or
      an already decoded image.
    use_bfloat16: `bool` for whether to use bfloat16.
    image_size: image size.
    resize_method: if None, use bicubic.

  Returns:
    A preprocessed image `Tensor` reshaped to `[image_size, image_size, 3]`,
    converted to bfloat16 if `use_bfloat16` else float32.
  """
  cropped = _decode_and_center_crop(image_bytes, image_size, resize_method)
  cropped = tf.reshape(cropped, [image_size, image_size, 3])

  output_dtype = tf.bfloat16 if use_bfloat16 else tf.float32
  return tf.image.convert_image_dtype(cropped, dtype=output_dtype)


def preprocess_image(image_bytes,
                     seed=None,
                     is_training=False,
                     use_bfloat16=False,
                     image_size=IMAGE_SIZE,
                     resize_method=None):
  """Preprocesses the given image.

  Dispatches to `preprocess_for_train` when `is_training` is true and to
  `preprocess_for_eval` otherwise.

  Args:
    image_bytes: `Tensor` representing an image binary of arbitrary size, or
      an already decoded image.
    seed: the random seed to use. Must be a shape (2,) array. If None, a seed
      is drawn with `tf.random.uniform`. Only the training path uses it.
    is_training: `bool` for whether the preprocessing is for training.
    use_bfloat16: `bool` for whether to use bfloat16.
    image_size: image size.
    resize_method: if None, use bicubic.

  Returns:
    A preprocessed image `Tensor`.
  """
  if seed is None:
    # `maxval` has to fit in the requested int32 dtype, so use the largest
    # int32 value as the (exclusive) upper bound.
    seed = tf.random.uniform((2,), maxval=tf.int32.max, dtype=tf.int32)
  if is_training:
    return preprocess_for_train(
        image_bytes,
        use_bfloat16,
        seed=seed,
        image_size=image_size,
        resize_method=resize_method)
  else:
    return preprocess_for_eval(image_bytes, use_bfloat16,
                               image_size, resize_method)

