# coding=utf-8
# Copyright 2021 The Meta-Dataset Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3
"""Module responsible for decoding image/feature examples."""
import gin.tf
import tensorflow.compat.v1 as tf


def read_single_example(example_string):
    """Parses a serialized Example record into its raw image and label features.

    Args:
      example_string: string Tensor holding one serialized `tf.train.Example`.

    Returns:
      A dict with two entries: 'image', the still-encoded image bytes as a
      tf.string Tensor, and 'label', a tf.int64 Tensor.
    """
    feature_spec = {
        'image': tf.FixedLenFeature([], dtype=tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    return tf.parse_single_example(example_string, features=feature_spec)


def read_example_and_parse_image(example_string):
    """Parses an Example record and decodes its image into a 3-channel tensor.

    Args:
      example_string: string Tensor holding one serialized `tf.train.Example`.

    Returns:
      The parsed feature dict, whose 'image' entry has been replaced by the
      decoded image (static shape `[None, None, 3]`). The 'label' entry is left
      as parsed.
    """
    example = read_single_example(example_string)
    decoded = tf.image.decode_image(example['image'], channels=3)
    # decode_image leaves the static shape unknown; pin rank 3 with 3 channels
    # so that later ops (e.g. resizing) see a known rank.
    decoded.set_shape([None, None, 3])
    example['image'] = decoded
    return example


@gin.configurable
class ImageDecoder(object):
    """Decodes serialized image Examples into normalized float32 images.

    Each image is resized to `image_size x image_size` and rescaled to
    [-1, 1]. It is then optionally perturbed according to `data_augmentation`.
    """
    out_type = tf.float32

    def __init__(self, image_size=None, data_augmentation=None):
        """Class constructor.

        Args:
          image_size: int, desired image size. The extracted image will be
            resized to `[image_size, image_size]`. Must be set (e.g. via gin)
            before any example is decoded.
          data_augmentation: A DataAugmentation object with parameters for
            perturbing the images, or None to disable augmentation. The
            attributes read are `enable_gaussian_noise`, `gaussian_noise_std`,
            `enable_jitter` and `jitter_amount`.
        """
        self.image_size = image_size
        self.data_augmentation = data_augmentation

    def __call__(self, example_string):
        """Processes a single example string, returning only the image.

        Extracts and processes the image, and ignores the label. We assume that
        the image has three channels.

        Args:
          example_string: str, an Example protocol buffer.

        Returns:
          image_rescaled: the image, resized to `image_size x image_size` and
          rescaled to [-1, 1]. Note that Gaussian data augmentation may cause
          values to go beyond this range.

        Raises:
          ValueError: if `image_size` has not been set.
        """
        image, _ = self.decode_with_label(example_string)
        return image

    def decode_with_label(self, example_string):
        """Processes a single example string, returning the image and its label.

        Extracts and processes the image, and also returns the example's label.
        We assume that the image has three channels.

        Args:
          example_string: str, an Example protocol buffer.

        Returns:
          A tuple `(image_rescaled, label)` where:
          image_rescaled: the image, resized to `image_size x image_size` and
            rescaled to [-1, 1]. Note that Gaussian data augmentation may cause
            values to go beyond this range.
          label: the example's label, cast to tf.int32.

        Raises:
          ValueError: if `image_size` has not been set.
        """
        # Fail early with a clear message; otherwise resize_images raises an
        # opaque error while trying to convert `[None, None]` to a Tensor.
        if self.image_size is None:
            raise ValueError(
                'ImageDecoder.image_size must be set (e.g. via gin) before '
                'decoding images.')

        ex_decoded = read_example_and_parse_image(example_string)
        image_resized = tf.image.resize_images(
            ex_decoded['image'], [self.image_size, self.image_size],
            method=tf.image.ResizeMethod.BILINEAR,
            align_corners=True)
        image_resized = tf.cast(image_resized, tf.float32)
        image = 2 * (image_resized / 255.0 - 0.5)  # Rescale to [-1, 1].

        if self.data_augmentation is not None:
            image = self._augment(image)
        return image, tf.cast(ex_decoded['label'], dtype=tf.int32)

    def _augment(self, image):
        """Applies the enabled perturbations from `self.data_augmentation`.

        Args:
          image: float32 image tensor of shape `[image_size, image_size, 3]`.

        Returns:
          The perturbed image, with the same shape as the input.
        """
        augmentation = self.data_augmentation
        if augmentation.enable_gaussian_noise:
            image = image + tf.random_normal(
                tf.shape(image)) * augmentation.gaussian_noise_std

        if augmentation.enable_jitter:
            # Reflect-pad each spatial side by `jitter_amount`, then randomly
            # crop back to the original size. The net effect is a random
            # translation of up to `jitter_amount` pixels.
            j = augmentation.jitter_amount
            paddings = tf.constant([[j, j], [j, j], [0, 0]])
            image = tf.pad(image, paddings, 'REFLECT')
            image = tf.image.random_crop(image,
                                         [self.image_size, self.image_size, 3])
        return image


@gin.configurable
class FeatureDecoder(object):
    """Decodes serialized Examples that carry precomputed feature vectors."""
    out_type = tf.float32

    def __init__(self, feat_len):
        """Class constructor.

        Args:
          feat_len: The expected length of the feature vectors.
        """
        self.feat_len = feat_len

    def __call__(self, example_string):
        """Processes a single example string, returning only the feature vector.

        The label is parsed alongside the embedding but is not returned.

        Args:
          example_string: str, an Example protocol buffer.

        Returns:
          feat: The feature tensor, a float32 vector of length `feat_len`.
        """
        # Both features are declared without defaults, so parsing requires each
        # of them to be present, even though only the embedding is used.
        schema = {
            'image/embedding':
                tf.FixedLenFeature([self.feat_len], dtype=tf.float32),
            'image/class/label':
                tf.FixedLenFeature([], tf.int64),
        }
        parsed = tf.parse_single_example(example_string, features=schema)
        return parsed['image/embedding']


@gin.configurable
class StringDecoder(object):
    """Decoder that returns an Example's image bytes without decoding them."""
    out_type = tf.string

    def __init__(self):
        """Class constructor; this decoder takes no configuration."""

    def __call__(self, example_string):
        """Processes a single example string, returning the encoded image.

        The label is parsed but discarded.

        Args:
          example_string: str, an Example protocol buffer.

        Returns:
          img_string: tf.Tensor of type tf.string with the still-encoded image.
        """
        return read_single_example(example_string)['image']
