import jax
import numpy as np


@jax.jit
def separable_power_law(x, y, sigma=0.5):
    """Evaluate a product of two one-dimensional power-law factors.

    Computes, elementwise,
    ``(1 + x)**(-1/sigma) * (1 + y)**(-1/sigma) * (1 - 1/sigma)**2``.
    Algebraically, this equals the mixed partial derivative d^2/(dx dy) of
    ``(1 + x)**(1 - 1/sigma) * (1 + y)**(1 - 1/sigma)``.

    Args:
        x: First input (scalar or array; combined elementwise with ``y``).
        y: Second input (scalar or array).
        sigma: Exponent parameter. It must be nonzero because ``1 / sigma``
            is taken. It is passed to ``jax.jit`` as an ordinary (non-static)
            argument.

    Returns:
        The elementwise result as a JAX array.
    """
    exponent = -1 / sigma
    factor_x = (1 + x) ** exponent
    factor_y = (1 + y) ** exponent
    coefficient = (1 - 1 / sigma) ** 2
    # Multiply in the same order as the closed form: (fx * fy) * coefficient.
    return factor_x * factor_y * coefficient


@jax.jit
def inseparable_power_law(x, y, sigma=0.5):
    """Evaluate a power law in the combined argument ``x + y + 1``.

    Computes, elementwise,
    ``(x + y + 1)**(-1/sigma - 1) * (1/sigma - 1) / sigma``.
    Algebraically, this equals the mixed partial derivative d^2/(dx dy) of
    ``(x + y + 1)**(1 - 1/sigma)``.

    Args:
        x: First input (scalar or array; combined elementwise with ``y``).
        y: Second input (scalar or array).
        sigma: Exponent parameter. It must be nonzero because ``1 / sigma``
            is taken. It is passed to ``jax.jit`` as an ordinary (non-static)
            argument.

    Returns:
        The elementwise result as a JAX array.
    """
    # Use the ``**`` operator (which dispatches to JAX) instead of
    # ``np.power``. Inside ``jax.jit`` the inputs are tracers, and a NumPy
    # ufunc would try to convert them to a concrete ndarray, raising
    # ``TracerArrayConversionError``.
    return (x + y + 1) ** (-1 / sigma - 1) * (1 / sigma - 1) / sigma