import jax, jax.numpy as jnp
def dominance_hinge(s_hi, s_lo, margin=0.1):
  """Hinge penalty asking ``s_hi`` to beat ``s_lo`` by at least ``margin``.

  Computes ``max(0, margin - (s_hi - s_lo))`` elementwise. The penalty is zero
  once the gap ``s_hi - s_lo`` reaches ``margin``. It grows linearly as the gap
  shrinks, and keeps growing if the gap turns negative.
  """
  gap = s_hi - s_lo
  shortfall = margin - gap
  return jnp.maximum(0.0, shortfall)
def tie_band(s1, s2, delta=0.05):
  """Penalty pulling two scores within ``delta`` of each other.

  Computes ``max(0, |s1 - s2| - delta)`` elementwise. Scores that already lie
  within the dead band ``delta`` incur no penalty. Beyond the band, the
  penalty grows linearly with the distance.
  """
  distance = jnp.abs(s1 - s2)
  excess = distance - delta
  return jnp.maximum(0.0, excess)
def slope_ratio_regularizer(apply_fn, params, r, c, feat, rho_min=0.5, rho_max=2.0, eps=1e-6):
  """Penalize the ratio of the score's r-slope to its negated c-slope.

  Computes ``g = (d apply_fn / d r) / (-(d apply_fn / d c))`` elementwise. It
  returns a hinge penalty that is zero when ``rho_min <= g <= rho_max`` and
  grows linearly outside that band. If the score rises with ``c``
  (``-d/dc < 0``), then ``g`` is negative and is penalized through the
  ``rho_min`` side.

  Args:
    apply_fn: callable ``apply_fn(params, r, c, feat)`` returning an array.
      The slopes are taken by differentiating the *sum* of its output. That
      equals the per-element derivative only if output i depends on r[i]/c[i]
      alone. Presumably this holds for the caller's model; TODO confirm.
    params: passed through to ``apply_fn``.
    r, c: floating-point inputs that are differentiated with respect to.
    feat: passed through to ``apply_fn`` unchanged.
    rho_min, rho_max: bounds of the admissible slope-ratio band.
    eps: magnitude by which the denominator is pushed away from zero.

  Returns:
    Array with the broadcast shape of ``r`` and ``c``, holding the
    per-element penalty.
  """
  df_dr = jax.grad(lambda rr: apply_fn(params, rr, c, feat).sum())(r)
  df_dc = jax.grad(lambda cc: apply_fn(params, r, cc, feat).sum())(c)
  denom = -df_dc
  # Move the denominator away from zero in the direction of its own sign.
  # Adding eps unconditionally cancels a denominator of exactly -eps, which
  # gives inf/NaN penalties and gradients. For denom >= 0 this is identical
  # to plain `denom + eps`.
  denom = jnp.where(denom >= 0, denom + eps, denom - eps)
  g = df_dr / denom
  return jnp.maximum(0.0, rho_min - g) + jnp.maximum(0.0, g - rho_max)
