
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of image dataset ops."""

import functools

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import stateless_random_ops

# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_random_ops import *
# pylint: enable=wildcard-import

from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export

def _AssertAtLeast3DImage(image):
  """Guard `image` with the shape checks produced by `_CheckAtLeast3DImage`.

  Requirements that can be decided from the static shape are enforced right
  away (raising `ValueError`). Requirements that can only be checked at run
  time become assert ops, which are attached to the returned tensor as control
  dependencies.

  Args:
    image: >= 3-D Tensor of size [*, height, width, depth]

  Raises:
    ValueError: if image.shape is not a [>= 3] vector.

  Returns:
    The result of `control_flow_ops.with_dependencies` applied to `image`, with
    the (possibly empty) list of runtime shape asserts as dependencies.
  """
  # Static-only checks are not required here; unknown dims become runtime asserts.
  shape_asserts = _CheckAtLeast3DImage(image, require_static=False)
  return control_flow_ops.with_dependencies(shape_asserts, image)


def _CheckAtLeast3DImage(image, require_static=True):
  """Assert that we are working with a properly shaped image.

  Checks that can be decided from the static shape of `image` are performed
  immediately. Checks that cannot be decided are returned as runtime assert
  ops for the caller to attach as control dependencies.

  Args:
    image: >= 3-D Tensor of size [*, height, width, depth]
    require_static: If `True`, requires that all dimensions of `image` are known
      and non-zero.

  Raises:
    ValueError: if image.shape is not a [>= 3] vector, if `require_static` is
      `True` and the shape of `image` is not fully defined, or if one of the
      inner three dimensions is statically known to be 0.

  Returns:
    An empty list, if the inner three dimensions of `image` are fully defined.
    Otherwise, a list of assert ops is returned. These ops check at run time
    that the inner three dimensions are positive and that the rank is at
    least 3.
  """
  try:
    if image.get_shape().ndims is None:
      # Unknown rank: use a 3-D placeholder shape for the static checks below.
      # Its dims are unknown, so the runtime rank assert further down is
      # always emitted for this case.
      image_shape = image.get_shape().with_rank(3)
    else:
      image_shape = image.get_shape().with_rank_at_least(3)
  except ValueError as err:
    # Chain the original shape error so its details are not lost.
    raise ValueError("'image' (shape %s) must be at least three-dimensional." %
                     image.shape) from err
  if require_static and not image_shape.is_fully_defined():
    raise ValueError('\'image\' must be fully defined.')
  if any(x == 0 for x in image_shape[-3:]):
    raise ValueError('inner 3 dims of \'image.shape\' must be > 0: %s' %
                     image_shape)
  if not image_shape[-3:].is_fully_defined():
    # Some inner dimension is unknown statically; defer the checks to run time.
    return [
        check_ops.assert_positive(
            array_ops.shape(image)[-3:],
            ["inner 3 dims of 'image.shape' "
             'must be > 0.']),
        check_ops.assert_greater_equal(
            array_ops.rank(image),
            3,
            message="'image' must be at least three-dimensional.")
    ]
  else:
    return []

def fix_image_flip_shape(image, result):
  """Propagate the static shape of `image` onto `result`.

  If the static shape of `image` equals `tensor_shape.unknown_shape()`
  (unknown rank), `result` is instead given the 3-D shape
  `(None, None, None)`.

  Args:
    image: original image tensor whose static shape is copied.
    result: flipped or transformed image whose static shape is set via
      `set_shape`.

  Returns:
    `result`, after `set_shape` has been called on it.
  """
  source_shape = image.get_shape()
  rank_is_unknown = source_shape == tensor_shape.unknown_shape()
  result.set_shape([None, None, None] if rank_is_unknown else source_shape)
  return result

@tf_export('image.stateless_random_flip_left_right', v1=[])
@dispatch.add_dispatch_support
def stateless_random_flip_left_right(image, seed):
  """Randomly flip an image horizontally (left to right) deterministically.

  The result depends only on `seed`. The same `seed` gives the same result no
  matter how many times the function is called, and regardless of global seed
  settings (e.g. `tf.random.set_seed`).

  Example usage:

  >>> image = np.array([[[1], [2]], [[3], [4]]])
  >>> seed = (2, 3)
  >>> tf.image.stateless_random_flip_left_right(image, seed).numpy().tolist()
  [[[2], [1]], [[4], [3]]]

  Args:
    image: 4-D Tensor of shape `[batch, height, width, channels]` or 3-D Tensor
      of shape `[height, width, channels]`.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)

  Returns:
    A tensor of the same type and shape as `image`.
  """
  seeded_uniform = functools.partial(
      stateless_random_ops.stateless_random_uniform, seed=seed)
  # Axis 1 of a [height, width, channels] image is the width (horizontal) axis.
  return _random_flip(
      image,
      flip_index=1,
      random_func=seeded_uniform,
      scope_name='stateless_random_flip_left_right')


@tf_export('image.stateless_random_flip_up_down', v1=[])
@dispatch.add_dispatch_support
def stateless_random_flip_up_down(image, seed):
  """Randomly flip an image vertically (upside down) deterministically.

  The result depends only on `seed`. The same `seed` gives the same result no
  matter how many times the function is called, and regardless of global seed
  settings (e.g. `tf.random.set_seed`).

  Example usage:

  >>> image = np.array([[[1], [2]], [[3], [4]]])
  >>> seed = (2, 3)
  >>> tf.image.stateless_random_flip_up_down(image, seed).numpy().tolist()
  [[[3], [4]], [[1], [2]]]

  Args:
    image: 4-D Tensor of shape `[batch, height, width, channels]` or 3-D Tensor
      of shape `[height, width, channels]`.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)

  Returns:
    A tensor of the same type and shape as `image`.
  """
  seeded_uniform = functools.partial(
      stateless_random_ops.stateless_random_uniform, seed=seed)
  # Axis 0 of a [height, width, channels] image is the height (vertical) axis.
  return _random_flip(
      image,
      flip_index=0,
      random_func=seeded_uniform,
      scope_name='stateless_random_flip_up_down')


def _random_flip(image, flip_index, random_func, scope_name):
  """Randomly (50% chance) flip an image along axis `flip_index`.

  A 3-D image is either flipped as a whole or left unchanged, based on a
  single random draw. For a 4-D batch, each image gets its own draw.

  Args:
    image: 4-D Tensor of shape `[batch, height, width, channels]` or 3-D Tensor
      of shape `[height, width, channels]`.
    flip_index: Dimension along which to flip the image, counted for a single
      3-D image (it is shifted by one for a batch). Vertical is 0, Horizontal
      is 1.
    random_func: partial function for calling either stateful or stateless
      random ops with `seed` parameter specified.
    scope_name: Name of the scope in which the ops are added.

  Returns:
    A tensor of the same type and shape as `image`.

  Raises:
    ValueError: if the shape of `image` not supported.
  """
  with ops.name_scope(None, scope_name, [image]) as scope:
    image = _AssertAtLeast3DImage(ops.convert_to_tensor(image, name='image'))
    shape = image.get_shape()

    def _flip_single_image():
      # One scalar draw in [0, 1) decides whether the whole image is mirrored.
      coin = random_func(shape=[], minval=0, maxval=1.0)
      flipped = control_flow_ops.cond(
          math_ops.less(coin, .5),
          lambda: array_ops.reverse(image, [flip_index]),
          lambda: image,
          name=scope)
      return fix_image_flip_shape(image, flipped)

    def _flip_image_batch():
      # One draw per batch element, rounded to a 0/1 mask that broadcasts over
      # the remaining three axes. The result selects the mirrored or the
      # original image arithmetically.
      batch = array_ops.shape(image)[0]
      coins = random_func(shape=[batch], minval=0, maxval=1.0)
      mask = math_ops.cast(
          math_ops.round(array_ops.reshape(coins, [batch, 1, 1, 1])),
          image.dtype)
      mirrored = array_ops.reverse(image, [flip_index + 1])
      return mask * mirrored + (1 - mask) * image

    if shape.ndims is None:
      # Rank not known statically: choose between the two paths at run time.
      return control_flow_ops.cond(
          math_ops.equal(array_ops.rank(image), 3),
          _flip_single_image,
          _flip_image_batch)
    if shape.ndims == 3:
      return _flip_single_image()
    if shape.ndims == 4:
      return _flip_image_batch()
    raise ValueError(
        '\'image\' (shape %s) must have either 3 or 4 dimensions.' % shape)




@tf_export("image.stateless_random_crop", v1=[])
@dispatch.add_dispatch_support
def stateless_random_crop(value, size, seed, name=None):
  """Randomly crops a tensor to a given size in a deterministic manner.

  Slices a shape `size` portion out of `value` at a uniformly chosen offset.
  Requires `value.shape >= size`.

  If a dimension should not be cropped, pass the full size of that dimension.
  For example, RGB images can be cropped with
  `size = [crop_height, crop_width, 3]`.

  Guarantees the same results given the same `seed` independent of how many
  times the function is called, and independent of global seed settings (e.g.
  `tf.random.set_seed`).

  Usage Example:

  >>> image = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
  >>> seed = (1, 2)
  >>> tf.image.stateless_random_crop(value=image, size=(1, 2, 3), seed=seed)
  <tf.Tensor: shape=(1, 2, 3), dtype=int32, numpy=
  array([[[1, 2, 3],
          [4, 5, 6]]], dtype=int32)>

  Args:
    value: Input tensor to crop.
    size: 1-D tensor with size the rank of `value`.
    seed: A shape [2] Tensor, the seed to the random number generator. Must have
      dtype `int32` or `int64`. (When using XLA, only `int32` is allowed.)
    name: A name for this operation (optional).

  Returns:
    A cropped tensor of the same rank as `value` and shape `size`.
  """
  with ops.name_scope(name, "random_crop", [value, size]) as scope:
    value = ops.convert_to_tensor(value, name="value")
    size = ops.convert_to_tensor(size, dtype=dtypes.int32, name="size")
    value_shape = array_ops.shape(value)
    # Runtime guard: every requested crop dimension must fit inside `value`.
    fits = math_ops.reduce_all(value_shape >= size)
    size_check = control_flow_ops.Assert(
        fits,
        ["Need value.shape >= size, got ", value_shape, size],
        summarize=1000)
    value_shape = control_flow_ops.with_dependencies([size_check], value_shape)
    # Number of valid start positions along each dimension.
    num_offsets = value_shape - size + 1
    raw_draw = stateless_random_ops.stateless_random_uniform(
        array_ops.shape(value_shape),
        dtype=size.dtype,
        maxval=size.dtype.max,
        seed=seed)
    # Map each draw into [0, num_offsets). The modulo is very slightly
    # non-uniform when num_offsets does not evenly divide the draw range.
    offset = raw_draw % num_offsets
    return array_ops.slice(value, offset, size, name=scope)


