import tensorflow as tf


@tf.custom_gradient
def one_hot_straight_through(x):
    """One-hot encode the argmax of ``x`` with a straight-through gradient.

    Forward pass: ``tf.one_hot(tf.argmax(x, axis=-1), depth=x.shape[-1])``,
    i.e. a hard one-hot selection of the largest entry along the last axis.

    Backward pass: the upstream gradient is returned unchanged, as if the
    forward op were the identity (straight-through estimator).

    Args:
        x: Tensor (or tensor-convertible value) of rank >= 1; the argmax and
            one-hot are taken over its last axis.

    Returns:
        The ``(y, grad)`` pair required by ``tf.custom_gradient``; callers see
        only ``y``, which has the same shape as ``x``. ``y`` has ``x``'s dtype
        when ``x`` is floating point, and ``float32`` otherwise.
    """
    x = tf.convert_to_tensor(x)

    # Prefer the static size so the output shape stays fully known; fall back
    # to the dynamic size when the last dimension is unknown at trace time
    # (tf.one_hot cannot take depth=None).
    depth = x.shape[-1]
    if depth is None:
        depth = tf.shape(x)[-1]

    # tf.one_hot defaults to float32. The identity gradient below must match
    # x's dtype, so float16/float64 inputs need a one-hot in their own dtype.
    # Non-float inputs keep the previous float32 output.
    out_dtype = x.dtype if x.dtype.is_floating else tf.float32

    y = tf.one_hot(tf.argmax(x, axis=-1), depth=depth, dtype=out_dtype)

    def grad(upstream):
        # Straight-through: pass the incoming gradient through unchanged.
        return upstream

    return y, grad
