# Lint as: python3
# Copyright 2018 The SPL Authors.
#
# All rights reserved.
#
# This is the code for reproducing results of the paper.

"""Helper code which creates augmentations."""

import functools
from absl import logging
import tensorflow.compat.v2 as tf


def flip_augmentation(image):
  """Randomly mirrors the image horizontally.

  Args:
    image: image tensor passed straight to
      `tf.image.random_flip_left_right` (which accepts a 3-D [h, w, c] image
      or a 4-D [batch, h, w, c] batch).

  Returns:
    Dict with a single 'image' entry holding the possibly flipped image.
  """
  flipped = tf.image.random_flip_left_right(image)
  return dict(image=flipped)


def noop_augmentation(image):
  """Identity augmentation: passes the image through unchanged.

  Args:
    image: input image; it is neither copied nor modified.

  Returns:
    Dict with a single 'image' entry referring to the input object.
  """
  return dict(image=image)


def _get_augmenter_type_and_args(**kwargs):
  """Extracts augmenter type and args from **kwargs dict."""
  augment_type = kwargs['type'].lower()
  augment_args = {}
  for k, v in kwargs.items():
    if k.startswith(augment_type + '_'):
      augment_args[k[len(augment_type)+1:]] = v
  logging.info('Using augmentation %s with parameters %s',
               augment_type, augment_args)
  return augment_type, augment_args


def create_augmenter(**kwargs):
  """Builds the supervised-task augmenter selected by the hyperparameters.

  Args:
    **kwargs: hyperparameters dict. 'type' names the augmentation (compared
      case-insensitively); augmenter-specific arguments are parsed and logged,
      but neither supported augmenter uses them.

  Returns:
    augmenter_state: state of the augmenter, or None for a stateless
      augmenter. Both supported augmenters are stateless, so this is None.
    sup_augmenter: callable which performs augmentation of the data.

  Raises:
    ValueError: if the augmentation type is not recognized.
  """
  augment_type = _get_augmenter_type_and_args(**kwargs)[0]
  # An empty type, 'none' or 'noop' all select the identity augmentation.
  if not augment_type or augment_type in ('none', 'noop'):
    return None, noop_augmentation
  if augment_type == 'horizontal_flip':
    return None, flip_augmentation
  raise ValueError('Invalid augmentation type {0}'.format(augment_type))
