from typing import Optional

import tensorflow as tf


def random_resized_crop(image, scale, ratio, seed):
    """Randomly crop a region of `image` and resize it back to the original spatial size.

    Adapted from https://keras.io/examples/vision/nnclr/#random-resized-crops.

    For each image, an area fraction is drawn uniformly from ``[scale[0], scale[1])`` and an
    aspect ratio (crop width / crop height) is drawn log-uniformly from
    ``[ratio[0], ratio[1])``. The crop's relative height and width are derived from these and
    clipped to ``[0, 1]``. The crop is then placed at a uniformly random offset that keeps it
    inside the image and resized (``tf.image.crop_and_resize`` default, bilinear) back to
    ``(height, width)``.

    Args:
        image: A 3-D ``[height, width, channels]`` or 4-D ``[batch, height, width, channels]``
            tensor.
        scale: Pair ``(min, max)`` of crop-area fractions of the full image.
        ratio: Pair ``(min, max)`` of aspect ratios; both must be positive (their log is taken).
        seed: Shape ``[2]`` integer tensor seeding the stateless RNG.

    Returns:
        A float32 tensor with the same rank and spatial size as ``image``. A 3-D input yields a
        3-D output; a 4-D input (including a batch of size 1) yields a 4-D output.
    """
    ndims = image.shape.ndims
    assert ndims == 3 or ndims == 4
    # Remember whether the caller passed a single image, so a genuine batch of size 1 is
    # not squeezed on the way out.
    unbatched = ndims == 3
    if unbatched:
        image = tf.expand_dims(image, axis=0)

    image_shape = tf.shape(image)
    batch_size = image_shape[0]
    height = image_shape[1]
    width = image_shape[2]

    # Stateless RNG ops are deterministic in (shape, seed). Reusing one seed for every draw
    # below would make the scales, ratios and offsets all functions of the *same* uniform
    # samples, i.e. perfectly correlated. Derive an independent sub-seed per draw instead.
    scale_seed, ratio_seed, height_seed, width_seed = tf.unstack(
        tf.random.stateless_uniform(
            [4, 2], seed, maxval=tf.dtypes.int32.max, dtype=tf.int32
        )
    )

    # Cast so integer ratio bounds also work with tf.math.log.
    log_ratio = (
        tf.math.log(tf.cast(ratio[0], tf.float32)),
        tf.math.log(tf.cast(ratio[1], tf.float32)),
    )

    random_scales = tf.random.stateless_uniform(
        (batch_size,), scale_seed, scale[0], scale[1]
    )
    random_ratios = tf.exp(
        tf.random.stateless_uniform(
            (batch_size,), ratio_seed, log_ratio[0], log_ratio[1]
        )
    )

    # h * w == scale and w / h == ratio (before clipping), in normalized [0, 1] units.
    new_heights = tf.clip_by_value(tf.sqrt(random_scales / random_ratios), 0, 1)
    new_widths = tf.clip_by_value(tf.sqrt(random_scales * random_ratios), 0, 1)
    # Offsets are bounded so the crop box stays inside the image.
    height_offsets = tf.random.stateless_uniform(
        (batch_size,), height_seed, 0, 1 - new_heights
    )
    width_offsets = tf.random.stateless_uniform(
        (batch_size,), width_seed, 0, 1 - new_widths
    )

    # crop_and_resize expects normalized boxes as [y1, x1, y2, x2].
    bounding_boxes = tf.stack(
        [
            height_offsets,
            width_offsets,
            height_offsets + new_heights,
            width_offsets + new_widths,
        ],
        axis=1,
    )

    cropped = tf.image.crop_and_resize(
        image, bounding_boxes, tf.range(batch_size), (height, width)
    )

    return cropped[0] if unbatched else cropped


def random_rot90(image, seed):
    """Rotate `image` counter-clockwise by a random multiple of 90 degrees.

    The number of quarter turns is drawn uniformly from {0, 1, 2, 3} with the stateless
    RNG, so the result is fully determined by `seed`.
    """
    num_quarter_turns = tf.random.stateless_uniform(
        shape=(), seed=seed, minval=0, maxval=4, dtype=tf.int32
    )
    return tf.image.rot90(image, k=num_quarter_turns)


# Registry of augmentation ops selectable by name in `augment_image`'s "augment_order".
# Each callable is invoked by `augment_image` as `fn(image, *args, seed=seed)` or
# `fn(image, seed=seed, **kwargs)`, so every entry must accept a `seed` keyword.
AUGMENT_OPS = dict(
    random_resized_crop=random_resized_crop,
    random_brightness=tf.image.stateless_random_brightness,
    random_contrast=tf.image.stateless_random_contrast,
    random_saturation=tf.image.stateless_random_saturation,
    random_hue=tf.image.stateless_random_hue,
    random_flip_left_right=tf.image.stateless_random_flip_left_right,
    random_flip_up_down=tf.image.stateless_random_flip_up_down,
    random_rot90=random_rot90,
)


def augment_image(
    image: tf.Tensor,
    seed: Optional[tf.Tensor] = None,
    **augment_kwargs,
) -> tf.Tensor:
    """Unified image augmentation function for TensorFlow.

    This function is primarily configured through `augment_kwargs`. There must be one kwarg called "augment_order",
    which is a sequence of strings specifying the augmentation operations to apply and the order in which to apply
    them. See the `AUGMENT_OPS` dictionary for a list of available operations.

    For each entry in "augment_order", there may be a corresponding kwarg with the same name. The value of this kwarg
    can be a dictionary of kwargs or a sequence of positional args to pass to the corresponding augmentation operation.
    This additional kwarg is required for all operations that take additional arguments other than the image and random
    seed. For example, the "random_resized_crop" operation requires a "scale" and "ratio" argument that can be specified
    either positionally or by name. "random_flip_left_right", on the other hand, does not take any additional arguments
    and so does not require an additional kwarg to configure it.

    Here is an example config:

    ```
    augment_kwargs = {
        "augment_order": ["random_resized_crop", "random_brightness", "random_contrast", "random_flip_left_right"],
        "random_resized_crop": {
            "scale": [0.8, 1.0],
            "ratio": [3/4, 4/3],
        },
        "random_brightness": [0.1],
        "random_contrast": [0.9, 1.1],
    }
    ```

    Args:
        image: A `Tensor` of shape [height, width, channels] with the image. Integer images (e.g. uint8 in [0, 255])
            are rescaled to float32 in [0, 1] via `tf.image.convert_image_dtype`. Float32 images are NOT rescaled and
            must already be in [0, 1]: the intermediate image is clipped to [0, 1] after every operation.
        seed (optional): A `Tensor` of shape [2] with the seed for the random number generator. If omitted, one is
            drawn with the stateful `tf.random.uniform`.
        **augment_kwargs: Keyword arguments for the augmentation operations. The order of operations is determined by
            the "augment_order" keyword argument. Other keyword arguments are passed to the corresponding augmentation
            operation. See above for a list of operations.

    Returns:
        The augmented image, converted back to the input's dtype (with saturation).

    Raises:
        ValueError: If "augment_order" is missing from `augment_kwargs`.
        KeyError: If "augment_order" names an operation that is not in `AUGMENT_OPS`.
    """
    if "augment_order" not in augment_kwargs:
        raise ValueError("augment_kwargs must contain an 'augment_order' key.")

    # Materialize once so the order can be validated and then iterated, even if a
    # one-shot iterable was passed.
    augment_order = list(augment_kwargs["augment_order"])

    # Fail fast, before building any ops, if an unsupported operation is requested.
    unknown_ops = [op for op in augment_order if op not in AUGMENT_OPS]
    if unknown_ops:
        raise KeyError(
            f"Unknown augmentation op(s) {unknown_ops}; "
            f"available ops are {sorted(AUGMENT_OPS)}."
        )

    # convert to float at the beginning to avoid each op converting back and
    # forth between uint8 and float32 internally
    orig_dtype = image.dtype
    image = tf.image.convert_image_dtype(image, tf.float32)

    if seed is None:
        seed = tf.random.uniform([2], maxval=tf.dtypes.int32.max, dtype=tf.int32)

    for op in augment_order:
        # Derive a fresh seed for each op so that ops do not share random draws.
        seed = tf.random.stateless_uniform(
            [2], seed, maxval=tf.dtypes.int32.max, dtype=tf.int32
        )
        op_fn = AUGMENT_OPS[op]
        # Ops without a config entry are called with just the image and seed.
        op_config = augment_kwargs.get(op, ())
        if hasattr(op_config, "items"):
            image = op_fn(image, seed=seed, **op_config)
        else:
            image = op_fn(image, *op_config, seed=seed)
        # float images are expected to be in [0, 1]
        image = tf.clip_by_value(image, 0, 1)

    # convert back to original dtype and scale
    return tf.image.convert_image_dtype(image, orig_dtype, saturate=True)
