import jax
import jax.numpy as jnp

# Width of the rectangle kernel used by the surrogate threshold gradients in
# `step` and `jumprelu`: an element contributes to the threshold gradient only
# when |x - threshold| < BANDWIDTH / 2.
BANDWIDTH: float = 1e-1


def rectangle(x):
    """Return the indicator of the open interval (-0.5, 0.5), elementwise.

    Both bounds are strict, so exactly +/-0.5 map to 0. NaN also maps to 0.
    The result has the same dtype as ``x``.
    """
    above_lower = -0.5 < x
    below_upper = x < 0.5
    return jnp.logical_and(above_lower, below_upper).astype(x.dtype)


@jax.custom_gradient
def step(x, threshold):
    """Elementwise step ``x > threshold`` with a surrogate threshold gradient.

    Forward: 1 where ``x > threshold`` and 0 elsewhere, in ``x``'s dtype.

    Backward (given upstream cotangent ``dy``):
      * w.r.t. ``x``: zeros, since the step is flat away from the jump.
      * w.r.t. ``threshold``:
        ``-(1 / BANDWIDTH) * rectangle((x - threshold) / BANDWIDTH) * dy``.
        This is a rectangle-kernel surrogate for the derivative of the jump.
        It is summed over every axis along which ``threshold`` was broadcast
        against ``x``, so the result has exactly ``threshold``'s shape.
        It is then cast to ``threshold``'s dtype.

    Args:
        x: input array.
        threshold: array (or scalar) broadcastable against ``x``.

    Returns:
        The step output, plus the custom gradient closure consumed by
        ``jax.custom_gradient``.
    """
    y = (x > threshold).astype(x.dtype)

    def _sum_to_shape(g, shape):
        # Undo NumPy-style broadcasting: sum over the leading axes that
        # ``shape`` lacks, and over axes where ``shape`` has size 1 but ``g``
        # does not. This replaces a hard-coded axis list so any input rank
        # yields a cotangent of the correct shape.
        lead = g.ndim - len(shape)
        axes = tuple(range(lead)) + tuple(
            lead + i
            for i, size in enumerate(shape)
            if size == 1 and g.shape[lead + i] != 1
        )
        if axes:
            g = jnp.sum(g, axis=axes, keepdims=True)
        return g.reshape(shape)

    def grad(dy):
        dx = jnp.zeros_like(x)

        dth = -(1.0 / BANDWIDTH) * rectangle((x - threshold) / BANDWIDTH) * dy
        # custom_vjp requires each cotangent to match its primal's
        # shape/dtype, so reduce to threshold's shape and cast to its dtype.
        dth = _sum_to_shape(dth, jnp.shape(threshold)).astype(
            jnp.result_type(threshold)
        )
        return dx, dth

    return y, grad


@jax.custom_gradient
def jumprelu(x, threshold):
    """JumpReLU: ``x`` where ``x > threshold``, else 0, with a surrogate threshold gradient.

    Backward (given upstream cotangent ``dy``):
      * w.r.t. ``x``: ``dy`` passed through where ``x > threshold``, and
        zero elsewhere.
      * w.r.t. ``threshold``:
        ``-(threshold / BANDWIDTH) * rectangle((x - threshold) / BANDWIDTH) * dy``.
        This is a rectangle-kernel surrogate for the derivative of the jump.
        It is summed over every axis along which ``threshold`` was broadcast
        against ``x``, so the result has exactly ``threshold``'s shape.
        It is then cast to ``threshold``'s dtype.

    Args:
        x: input array.
        threshold: array (or scalar) broadcastable against ``x``.

    Returns:
        The JumpReLU output, plus the custom gradient closure consumed by
        ``jax.custom_gradient``.
    """
    y = x * (x > threshold)

    def _sum_to_shape(g, shape):
        # Undo NumPy-style broadcasting: sum over the leading axes that
        # ``shape`` lacks, and over axes where ``shape`` has size 1 but ``g``
        # does not. This replaces a hard-coded axis list so any input rank
        # yields a cotangent of the correct shape.
        lead = g.ndim - len(shape)
        axes = tuple(range(lead)) + tuple(
            lead + i
            for i, size in enumerate(shape)
            if size == 1 and g.shape[lead + i] != 1
        )
        if axes:
            g = jnp.sum(g, axis=axes, keepdims=True)
        return g.reshape(shape)

    def grad(dy):
        # custom_vjp requires each cotangent to match its primal's
        # shape/dtype, so reduce both gradients to their primal's shape and
        # cast them to the primal's dtype.
        dx = _sum_to_shape((x > threshold) * dy, jnp.shape(x)).astype(
            jnp.result_type(x)
        )

        dth = -(threshold / BANDWIDTH) * rectangle((x - threshold) / BANDWIDTH) * dy
        dth = _sum_to_shape(dth, jnp.shape(threshold)).astype(
            jnp.result_type(threshold)
        )
        return dx, dth

    return y, grad
