import jax

def mish(x):
    """Apply the Mish activation elementwise: ``x * tanh(softplus(x))``.

    Args:
        x: Input scalar or array accepted by ``jax.nn.softplus``.

    Returns:
        An array of the same shape as ``x``, with each element multiplied
        by the tanh of its softplus.
    """
    # softplus(x) = log(1 + exp(x)) is strictly positive, so the gate
    # below is always in (0, 1).
    smooth = jax.nn.softplus(x)
    gate = jax.nn.tanh(smooth)
    return x * gate


