import jax
from flax import linen as nn
from functools import partial


class QuickGELUActivation(nn.Module):
    """Flax module computing the quick GELU approximation ``x * sigmoid(1.702 * x)``.

    Trades some accuracy for speed compared with the exact GELU.
    Reference: https://github.com/hendrycks/GELUs
    """

    @nn.compact
    def __call__(self, input: jax.Array) -> jax.Array:
        # Sigmoid gate; 1.702 is the fixed scale of this GELU approximation.
        gate = jax.nn.sigmoid(1.702 * input)
        return input * gate


def quick_gelu(x: jax.Array) -> jax.Array:
    """Apply the sigmoid-based GELU approximation ``x * sigmoid(1.702 * x)``.

    Plain-function form of ``QuickGELUActivation`` so it can be stored in
    ``ACT2FN`` alongside the other element-wise activations and called as
    ``fn(x)``.
    """
    return x * jax.nn.sigmoid(1.702 * x)


# Maps activation names to element-wise callables, all with the uniform
# signature ``fn(x) -> array``.
ACT2FN = {
    # Must be a function, not the QuickGELUActivation Module class: calling
    # the class would construct a Module instance instead of applying the
    # activation to ``x``.
    "quick_gelu": quick_gelu,
    # Exact (erf-based) GELU.
    "gelu": partial(jax.nn.gelu, approximate=False),
    # tanh approximation of GELU.
    "gelu_new": partial(jax.nn.gelu, approximate=True),
    "gelu_pytorch_tanh": partial(jax.nn.gelu, approximate=True),
    "relu": jax.nn.relu,
    "relu6": jax.nn.relu6,
    "sigmoid": jax.nn.sigmoid,
    "silu": jax.nn.silu,
    "tanh": jax.numpy.tanh,
}
