# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# https://github.com/deepmind/deepmind-research/blob/master/byol/utils/augmentations.py
"""Data preprocessing and augmentation."""

import functools
from typing import Any, Mapping, Text

import jax
import jax.numpy as jnp
import dm_pix


def random_crop_single(image, rng, pad_size):
    """Randomly crop image."""
    padded_im = jnp.pad(
        image, ((pad_size, pad_size), (pad_size, pad_size), (0, 0)), "edge"
    )
    return dm_pix.random_crop(rng, padded_im, image.shape)


def random_crop(images, rng, pad_size):
    """Randomly crop images."""
    crop_fn = functools.partial(random_crop_single, pad_size=pad_size)
    keys = jax.random.split(rng, images.shape[0])
    return jax.vmap(crop_fn)(images, keys)


# typing
JaxBatch = Mapping[Text, jnp.ndarray]
ConfigDict = Mapping[Text, Any]

augment_config = dict(
    view1=dict(
        # Turn off random flip to maintain spatial consistency
        random_flip=False,  # True,  # Random left/right flip
        color_transform=dict(
            apply_prob=1.0,
            # Range of jittering
            brightness=0.4,
            contrast=0.4,
            saturation=0.2,
            hue=0.1,
            # Probability of applying color jittering
            color_jitter_prob=0.8,
            # Probability of converting to grayscale
            to_grayscale_prob=0.2,
            # Shuffle the order of color transforms
            shuffle=True,
        ),
        gaussian_blur=dict(
            apply_prob=1.0,
            # Kernel size ~ image_size / blur_divider
            blur_divider=10.0,
            # Kernel distribution
            sigma_min=0.1,
            sigma_max=2.0,
        ),
        solarize=dict(apply_prob=0.0, threshold=0.5),
    ),
    view2=dict(
        random_flip=False,  # True,
        color_transform=dict(
            apply_prob=1.0,
            brightness=0.4,
            contrast=0.4,
            saturation=0.2,
            hue=0.1,
            color_jitter_prob=0.8,
            to_grayscale_prob=0.2,
            shuffle=True,
        ),
        gaussian_blur=dict(
            apply_prob=0.1, blur_divider=10.0, sigma_min=0.1, sigma_max=2.0
        ),
        solarize=dict(apply_prob=0.2, threshold=0.5),
    ),
)


def postprocess(inputs: JaxBatch, rng: jnp.ndarray):
    """Apply the image augmentations to crops in inputs (view1 and view2)."""

    def _postprocess_image(
        images: jnp.ndarray,
        rng: jnp.ndarray,
        presets: ConfigDict,
    ) -> JaxBatch:
        """Applies augmentations in post-processing.

        Args:
          images: an NHWC tensor (with C=3), with float values in [0, 1].
          rng: a single PRNGKey.
          presets: a dict of presets for the augmentations.

        Returns:
          A batch of augmented images with shape NHWC, with keys view1, view2
          and labels.
        """
        flip_rng, color_rng, blur_rng, solarize_rng = jax.random.split(rng, 4)
        out = images
        if presets["random_flip"]:
            out = random_flip(out, flip_rng)
        if presets["color_transform"]["apply_prob"] > 0:
            out = color_transform(out, color_rng, **presets["color_transform"])
        if presets["gaussian_blur"]["apply_prob"] > 0:
            out = gaussian_blur(out, blur_rng, **presets["gaussian_blur"])
        if presets["solarize"]["apply_prob"] > 0:
            out = solarize(out, solarize_rng, **presets["solarize"])
        out = jnp.clip(out, 0.0, 1.0)
        return jax.lax.stop_gradient(out)

    rng1, rng2 = jax.random.split(rng, num=2)
    view1 = _postprocess_image(inputs["view1"], rng1, augment_config["view1"])
    view2 = _postprocess_image(inputs["view2"], rng2, augment_config["view2"])
    return dict(view1=view1, view2=view2)


def _maybe_apply(apply_fn, inputs, rng, apply_prob):
    should_apply = jax.random.uniform(rng, shape=()) <= apply_prob
    return jax.lax.cond(should_apply, inputs, apply_fn, inputs, lambda x: x)


def _depthwise_conv2d(inputs, kernel, strides, padding):
    """Computes a depthwise conv2d in Jax.

    Args:
      inputs: an NHWC tensor with N=1.
      kernel: a [H", W", 1, C] tensor.
      strides: a 2d tensor.
      padding: "SAME" or "VALID".

    Returns:
      The depthwise convolution of inputs with kernel, as [H, W, C].
    """
    return jax.lax.conv_general_dilated(
        inputs,
        kernel,
        strides,
        padding,
        feature_group_count=inputs.shape[-1],
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )


def _gaussian_blur_single_image(image, kernel_size, padding, sigma):
    """Applies gaussian blur to a single image, given as NHWC with N=1."""
    radius = int(kernel_size / 2)
    kernel_size_ = 2 * radius + 1
    x = jnp.arange(-radius, radius + 1).astype(jnp.float32)
    blur_filter = jnp.exp(-(x**2) / (2.0 * sigma**2))
    blur_filter = blur_filter / jnp.sum(blur_filter)
    blur_v = jnp.reshape(blur_filter, [kernel_size_, 1, 1, 1])
    blur_h = jnp.reshape(blur_filter, [1, kernel_size_, 1, 1])
    num_channels = image.shape[-1]
    blur_h = jnp.tile(blur_h, [1, 1, 1, num_channels])
    blur_v = jnp.tile(blur_v, [1, 1, 1, num_channels])
    expand_batch_dim = len(image.shape) == 3
    if expand_batch_dim:
        image = image[jnp.newaxis, ...]
    blurred = _depthwise_conv2d(image, blur_h, strides=[1, 1], padding=padding)
    blurred = _depthwise_conv2d(blurred, blur_v, strides=[1, 1], padding=padding)
    blurred = jnp.squeeze(blurred, axis=0)
    return blurred


def _random_gaussian_blur(
    image, rng, kernel_size, padding, sigma_min, sigma_max, apply_prob
):
    """Applies a random gaussian blur."""
    apply_rng, transform_rng = jax.random.split(rng)

    def _apply(image):
        (sigma_rng,) = jax.random.split(transform_rng, 1)
        sigma = jax.random.uniform(
            sigma_rng, shape=(), minval=sigma_min, maxval=sigma_max, dtype=jnp.float32
        )
        return _gaussian_blur_single_image(image, kernel_size, padding, sigma)

    return _maybe_apply(_apply, image, apply_rng, apply_prob)


def rgb_to_hsv(r, g, b):
    """Converts R, G, B  values to H, S, V values.

    Reference TF implementation:
    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/adjust_saturation_op.cc
    Only input values between 0 and 1 are guaranteed to work properly, but this
    function complies with the TF implementation outside of this range.

    Args:
      r: A tensor representing the red color component as floats.
      g: A tensor representing the green color component as floats.
      b: A tensor representing the blue color component as floats.

    Returns:
      H, S, V values, each as tensors of shape [...] (same as the input without
      the last dimension).
    """
    vv = jnp.maximum(jnp.maximum(r, g), b)
    range_ = vv - jnp.minimum(jnp.minimum(r, g), b)
    sat = jnp.where(vv > 0, range_ / vv, 0.0)
    norm = jnp.where(range_ != 0, 1.0 / (6.0 * range_), 1e9)

    hr = norm * (g - b)
    hg = norm * (b - r) + 2.0 / 6.0
    hb = norm * (r - g) + 4.0 / 6.0

    hue = jnp.where(r == vv, hr, jnp.where(g == vv, hg, hb))
    hue = hue * (range_ > 0)
    hue = hue + (hue < 0)

    return hue, sat, vv


def hsv_to_rgb(h, s, v):
    """Converts H, S, V values to an R, G, B tuple.

    Reference TF implementation:
    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/adjust_saturation_op.cc
    Only input values between 0 and 1 are guaranteed to work properly, but this
    function complies with the TF implementation outside of this range.

    Args:
      h: A float tensor of arbitrary shape for the hue (0-1 values).
      s: A float tensor of the same shape for the saturation (0-1 values).
      v: A float tensor of the same shape for the value channel (0-1 values).

    Returns:
      An (r, g, b) tuple, each with the same dimension as the inputs.
    """
    c = s * v
    m = v - c
    dh = (h % 1.0) * 6.0
    fmodu = dh % 2.0
    x = c * (1 - jnp.abs(fmodu - 1))
    hcat = jnp.floor(dh).astype(jnp.int32)
    rr = (
        jnp.where(
            (hcat == 0) | (hcat == 5), c, jnp.where((hcat == 1) | (hcat == 4), x, 0)
        )
        + m
    )
    gg = (
        jnp.where(
            (hcat == 1) | (hcat == 2), c, jnp.where((hcat == 0) | (hcat == 3), x, 0)
        )
        + m
    )
    bb = (
        jnp.where(
            (hcat == 3) | (hcat == 4), c, jnp.where((hcat == 2) | (hcat == 5), x, 0)
        )
        + m
    )
    return rr, gg, bb


def adjust_brightness(rgb_tuple, delta):
    return jax.tree_map(lambda x: x + delta, rgb_tuple)


def adjust_contrast(image, factor):
    def _adjust_contrast_channel(channel):
        mean = jnp.mean(channel, axis=(-2, -1), keepdims=True)
        return factor * (channel - mean) + mean

    return jax.tree_map(_adjust_contrast_channel, image)


def adjust_saturation(h, s, v, factor):
    return h, jnp.clip(s * factor, 0.0, 1.0), v


def adjust_hue(h, s, v, delta):
    # Note: this method exactly matches TF"s adjust_hue (combined with the hsv/rgb
    # conversions) when running on GPU. When running on CPU, the results will be
    # different if all RGB values for a pixel are outside of the [0, 1] range.
    return (h + delta) % 1.0, s, v


def _random_brightness(rgb_tuple, rng, max_delta):
    delta = jax.random.uniform(rng, shape=(), minval=-max_delta, maxval=max_delta)
    return adjust_brightness(rgb_tuple, delta)


def _random_contrast(rgb_tuple, rng, max_delta):
    factor = jax.random.uniform(
        rng, shape=(), minval=1 - max_delta, maxval=1 + max_delta
    )
    return adjust_contrast(rgb_tuple, factor)


def _random_saturation(rgb_tuple, rng, max_delta):
    h, s, v = rgb_to_hsv(*rgb_tuple)
    factor = jax.random.uniform(
        rng, shape=(), minval=1 - max_delta, maxval=1 + max_delta
    )
    return hsv_to_rgb(*adjust_saturation(h, s, v, factor))


def _random_hue(rgb_tuple, rng, max_delta):
    h, s, v = rgb_to_hsv(*rgb_tuple)
    delta = jax.random.uniform(rng, shape=(), minval=-max_delta, maxval=max_delta)
    return hsv_to_rgb(*adjust_hue(h, s, v, delta))


def _to_grayscale(image):
    rgb_weights = jnp.array([0.2989, 0.5870, 0.1140])
    grayscale = jnp.tensordot(image, rgb_weights, axes=(-1, -1))[..., jnp.newaxis]
    return jnp.tile(grayscale, (1, 1, 3))  # Back to 3 channels.


def _color_transform_single_image(
    image,
    rng,
    brightness,
    contrast,
    saturation,
    hue,
    to_grayscale_prob,
    color_jitter_prob,
    apply_prob,
    shuffle,
):
    """Applies color jittering to a single image."""
    apply_rng, transform_rng = jax.random.split(rng)
    perm_rng, b_rng, c_rng, s_rng, h_rng, cj_rng, gs_rng = jax.random.split(
        transform_rng, 7
    )

    # Whether the transform should be applied at all.
    should_apply = jax.random.uniform(apply_rng, shape=()) <= apply_prob
    # Whether to apply grayscale transform.
    should_apply_gs = jax.random.uniform(gs_rng, shape=()) <= to_grayscale_prob
    # Whether to apply color jittering.
    should_apply_color = jax.random.uniform(cj_rng, shape=()) <= color_jitter_prob

    # Decorator to conditionally apply fn based on an index.
    def _make_cond(fn, idx):
        def identity_fn(x, unused_rng, unused_param):
            return x

        def cond_fn(args, i):
            def clip(args):
                return jax.tree_map(lambda arg: jnp.clip(arg, 0.0, 1.0), args)

            out = jax.lax.cond(
                should_apply & should_apply_color & (i == idx),
                args,
                lambda a: clip(fn(*a)),
                args,
                lambda a: identity_fn(*a),
            )
            return jax.lax.stop_gradient(out)

        return cond_fn

    random_brightness_cond = _make_cond(_random_brightness, idx=0)
    random_contrast_cond = _make_cond(_random_contrast, idx=1)
    random_saturation_cond = _make_cond(_random_saturation, idx=2)
    random_hue_cond = _make_cond(_random_hue, idx=3)

    def _color_jitter(x):
        rgb_tuple = tuple(jax.tree_map(jnp.squeeze, jnp.split(x, 3, axis=-1)))
        if shuffle:
            order = jax.random.permutation(perm_rng, jnp.arange(4, dtype=jnp.int32))
        else:
            order = range(4)
        for idx in order:
            if brightness > 0:
                rgb_tuple = random_brightness_cond((rgb_tuple, b_rng, brightness), idx)
            if contrast > 0:
                rgb_tuple = random_contrast_cond((rgb_tuple, c_rng, contrast), idx)
            if saturation > 0:
                rgb_tuple = random_saturation_cond((rgb_tuple, s_rng, saturation), idx)
            if hue > 0:
                rgb_tuple = random_hue_cond((rgb_tuple, h_rng, hue), idx)
        return jnp.stack(rgb_tuple, axis=-1)

    out_apply = _color_jitter(image)
    out_apply = jax.lax.cond(
        should_apply & should_apply_gs, out_apply, _to_grayscale, out_apply, lambda x: x
    )
    return jnp.clip(out_apply, 0.0, 1.0)


def _random_flip_single_image(image, rng):
    _, flip_rng = jax.random.split(rng)
    should_flip_lr = jax.random.uniform(flip_rng, shape=()) <= 0.5
    image = jax.lax.cond(should_flip_lr, image, jnp.fliplr, image, lambda x: x)
    return image


def random_flip(images, rng):
    rngs = jax.random.split(rng, images.shape[0])
    return jax.vmap(_random_flip_single_image)(images, rngs)


def color_transform(
    images,
    rng,
    brightness=0.8,
    contrast=0.8,
    saturation=0.8,
    hue=0.2,
    color_jitter_prob=0.8,
    to_grayscale_prob=0.2,
    apply_prob=1.0,
    shuffle=True,
):
    """Applies color jittering and/or grayscaling to a batch of images.

    Args:
      images: an NHWC tensor, with C=3.
      rng: a single PRNGKey.
      brightness: the range of jitter on brightness.
      contrast: the range of jitter on contrast.
      saturation: the range of jitter on saturation.
      hue: the range of jitter on hue.
      color_jitter_prob: the probability of applying color jittering.
      to_grayscale_prob: the probability of converting the image to grayscale.
      apply_prob: the probability of applying the transform to a batch element.
      shuffle: whether to apply the transforms in a random order.

    Returns:
      A NHWC tensor of the transformed images.
    """
    rngs = jax.random.split(rng, images.shape[0])
    jitter_fn = functools.partial(
        _color_transform_single_image,
        brightness=brightness,
        contrast=contrast,
        saturation=saturation,
        hue=hue,
        color_jitter_prob=color_jitter_prob,
        to_grayscale_prob=to_grayscale_prob,
        apply_prob=apply_prob,
        shuffle=shuffle,
    )
    return jax.vmap(jitter_fn)(images, rngs)


def gaussian_blur(
    images, rng, blur_divider=10.0, sigma_min=0.1, sigma_max=2.0, apply_prob=1.0
):
    """Applies gaussian blur to a batch of images.

    Args:
      images: an NHWC tensor, with C=3.
      rng: a single PRNGKey.
      blur_divider: the blurring kernel will have size H / blur_divider.
      sigma_min: the minimum value for sigma in the blurring kernel.
      sigma_max: the maximum value for sigma in the blurring kernel.
      apply_prob: the probability of applying the transform to a batch element.

    Returns:
      A NHWC tensor of the blurred images.
    """
    rngs = jax.random.split(rng, images.shape[0])
    kernel_size = images.shape[1] / blur_divider
    blur_fn = functools.partial(
        _random_gaussian_blur,
        kernel_size=kernel_size,
        padding="SAME",
        sigma_min=sigma_min,
        sigma_max=sigma_max,
        apply_prob=apply_prob,
    )
    return jax.vmap(blur_fn)(images, rngs)


def _solarize_single_image(image, rng, threshold, apply_prob):
    def _apply(image):
        return jnp.where(image < threshold, image, 1.0 - image)

    return _maybe_apply(_apply, image, rng, apply_prob)


def solarize(images, rng, threshold=0.5, apply_prob=1.0):
    """Applies solarization.

    Args:
      images: an NHWC tensor (with C=3).
      rng: a single PRNGKey.
      threshold: the solarization threshold.
      apply_prob: the probability of applying the transform to a batch element.

    Returns:
      A NHWC tensor of the transformed images.
    """
    rngs = jax.random.split(rng, images.shape[0])
    solarize_fn = functools.partial(
        _solarize_single_image, threshold=threshold, apply_prob=apply_prob
    )
    return jax.vmap(solarize_fn)(images, rngs)
