from typing import Callable, Tuple
import tensorflow as tf


@tf.custom_gradient
def apply_gradients_gelu(x: tf.Variable) -> Tuple[tf.Variable, Callable]:
    """
    Apply the identity to ``x`` in the forward pass, but scale the incoming
    gradient by the derivative of the exact (erf-based) GELU in the backward
    pass.

    GELU(x)  = x * 0.5 * (1 + erf(x / sqrt(2)))
    GELU'(x) = 0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x**2 / 2) / sqrt(2 * pi)

    source: https://www.wolframalpha.com/input?i=d%28f%28x+*+0.5+*+%281+%2B+erf%28x%2Fsqrt%282%29%29%29%29%29+%2Fdx

    Args:
        x: Floating-point tensor/variable. It is returned unchanged.

    Returns:
        ``(x, grad)`` as required by ``tf.custom_gradient``. ``grad`` maps an
        upstream gradient ``dX`` to ``dX * GELU'(x)``.
    """
    import math

    # Python scalars are converted to x's dtype when combined with it. This
    # avoids the float32-only tensor that ``tf.math.sqrt(2.0)`` would create.
    inv_sqrt2 = 1.0 / math.sqrt(2.0)
    inv_sqrt_2pi = 1.0 / math.sqrt(2.0 * math.pi)

    def grad(dX):
        # Standard normal CDF Phi(x) = 0.5 * (1 + erf(x / sqrt(2))).
        cdf = 0.5 * (1.0 + tf.math.erf(x * inv_sqrt2))
        # Standard normal PDF phi(x). Note the exponent is -x^2 / 2, not -x^2.
        pdf = inv_sqrt_2pi * tf.math.exp(-0.5 * tf.math.square(x))
        # d/dx [x * Phi(x)] = Phi(x) + x * phi(x)
        dGeLU = cdf + x * pdf
        return dX * dGeLU

    return x, grad


@tf.custom_gradient
def apply_gradients_softmax(x: tf.Variable) -> Tuple[tf.Variable, Callable]:
    """
    Pass ``x`` through unchanged, and pass the upstream gradient through
    unchanged as well.

    NOTE(review): despite the name, no softmax derivative is applied here.
    Both the forward and the backward pass are the identity. Confirm this
    straight-through behavior is intended.

    Args:
        x: Tensor/variable. It is returned unchanged.

    Returns:
        ``(x, grad)`` as required by ``tf.custom_gradient``. ``grad`` returns
        its argument as-is.
    """

    def _passthrough(upstream):
        # Identity backward pass: hand the incoming gradient straight back.
        return upstream

    return x, _passthrough
