# Source code for cleverhans.compat
"""
Wrapper functions for writing code that is compatible with many versions
of TensorFlow.
"""
import warnings
import tensorflow as tf
# The following 2 imports are not used in this module. They are imported so that users of cleverhans.compat can
# get access to device_lib, app, and flags. A pylint bug makes these imports cause errors when using python3+tf1.8.
# Doing the sanitized import here once makes it possible to do "from cleverhans.compat import flags" throughout the
# library without needing to repeat the pylint boilerplate.
from tensorflow.python.client import device_lib # pylint: disable=no-name-in-module,unused-import
from tensorflow.python.platform import app, flags # pylint: disable=no-name-in-module,unused-import
def _wrap(f):
"""
Wraps a callable `f` in a function that warns that the function is deprecated.
"""
def wrapper(*args, **kwargs):
"""
Issues a deprecation warning and passes through the arguments.
"""
warnings.warn(str(f) + " is deprecated. Switch to calling the equivalent function in tensorflow. "
" This function was originally needed as a compatibility layer for old versions of tensorflow, "
" but support for those versions has now been dropped.")
return f(*args, **kwargs)
return wrapper
# Deprecated aliases for TensorFlow's reduction ops. Each name is the tf
# function of the same name passed through `_wrap`, so calling it emits a
# deprecation warning and then delegates to TensorFlow.
(reduce_sum,
 reduce_max,
 reduce_min,
 reduce_mean,
 reduce_prod,
 reduce_any) = (_wrap(tf.reduce_sum),
                _wrap(tf.reduce_max),
                _wrap(tf.reduce_min),
                _wrap(tf.reduce_mean),
                _wrap(tf.reduce_prod),
                _wrap(tf.reduce_any))
def reduce_function(op_func, input_tensor, axis=None, keepdims=None,
                    name=None, reduction_indices=None):
    """
    This function used to be needed to support tf 1.4 and earlier, but support for tf 1.4 and earlier is now dropped.
    :param op_func: the reduction to apply, e.g. tf.reduce_sum. It is called as
      ``op_func(input_tensor, axis=..., keepdims=..., name=...)``.
    :param input_tensor: The tensor to reduce. Should have numeric type.
    :param axis: The dimensions to reduce. If None (the default),
      reduces all dimensions. Must be in the range
      [-rank(input_tensor), rank(input_tensor)).
    :param keepdims: If true, retains reduced dimensions with length 1.
    :param name: A name for the operation (optional).
    :param reduction_indices: The old (deprecated) name for axis. It is
      translated to `axis` here instead of being forwarded, so `op_func`
      does not need to accept a `reduction_indices` keyword.
    :return: outputs same value as op_func.
    :raises ValueError: if both `axis` and `reduction_indices` are given.
    """
    warnings.warn("`reduce_function` is deprecated and may be removed on or after 2019-09-08.",
                  stacklevel=2)
    if reduction_indices is not None:
        # Resolve the deprecated alias locally; giving both is ambiguous.
        if axis is not None:
            raise ValueError("Cannot specify both 'axis' and 'reduction_indices'.")
        axis = reduction_indices
    return op_func(input_tensor, axis=axis, keepdims=keepdims, name=name)
def softmax_cross_entropy_with_logits(sentinel=None,
                                      labels=None,
                                      logits=None,
                                      dim=-1):
    """
    Wrapper around tf.nn.softmax_cross_entropy_with_logits_v2 to handle
    deprecated warning.

    `labels` is passed through tf.stop_gradient before the op is applied, so
    no gradient flows into it.

    :param sentinel: must be left as None; it exists only to force callers to
      pass the remaining arguments by name.
    :param labels: forwarded to the TF op as `labels`; required.
    :param logits: forwarded to the TF op as `logits`; required.
    :param dim: forwarded to the TF op as `dim` (default -1).
    :return: the loss returned by the TF op.
    :raises ValueError: if `sentinel` is given, or `labels`/`logits` is None.
    :raises RuntimeError: if no `softmax_cross_entropy_with_logits_v2` op is
      available in the installed TensorFlow.
    """
    # Make sure that all arguments were passed as named arguments.
    if sentinel is not None:
        name = "softmax_cross_entropy_with_logits"
        raise ValueError("Only call `%s` with "
                         "named arguments (labels=..., logits=..., ...)"
                         % name)
    if labels is None or logits is None:
        raise ValueError("Both labels and logits must be provided.")

    f = getattr(tf.nn, "softmax_cross_entropy_with_logits_v2", None)
    if f is None:
        # Newer TensorFlow releases expose this op only under tf.compat.v1.
        compat_v1 = getattr(getattr(tf, "compat", None), "v1", None)
        f = getattr(getattr(compat_v1, "nn", None),
                    "softmax_cross_entropy_with_logits_v2", None)
    if f is None:
        raise RuntimeError("This version of TensorFlow is no longer supported. See cleverhans/README.md")

    labels = tf.stop_gradient(labels)
    return f(labels=labels, logits=logits, dim=dim)