import jax, jax.numpy as jnp
def _pairs(key, B, K):
  """Sample ``B*K`` random index pairs ``(i, j)`` over ``range(B)`` with ``i != j``.

  Args:
    key: JAX PRNG key; it is split internally so the two draws are independent.
    B: number of items the indices refer to (indices lie in ``[0, B)``).
    K: multiplier on the number of pairs; ``B*K`` pairs are drawn in total.

  Returns:
    Tuple ``(i, j)`` of integer arrays of shape ``(B*K,)``. ``i`` is uniform
    over ``[0, B)`` and ``j`` is uniform over the remaining ``B - 1`` indices.
    When ``B == 1`` no distinct pair exists and both arrays are all zeros.
  """
  # Use independent subkeys: drawing twice from the same key returns the very
  # same samples, which would make every j equal to its i.
  key_i, key_off = jax.random.split(key)
  i = jax.random.randint(key_i, (B*K,), 0, B)
  # A non-zero cyclic offset guarantees j != i and keeps j inside [0, B).
  # max(B, 2) keeps the sampling range non-empty when B == 1.
  offset = jax.random.randint(key_off, (B*K,), 1, max(B, 2))
  return i, (i + offset) % B
def select_pairs(key, gR, gC, K, eps_r, eps_c, use_hard_mining=True, *,
                 temperature=0.1, keep_threshold=0.5):
  """Sample index pairs and label each pair from its differences in ``gR`` and ``gC``.

  With ``dR = gR[i] - gR[j]`` and ``dC = gC[i] - gC[j]`` for every sampled
  pair, the returned masks are:

  * ``dom_ij``: ``dR >= eps_r`` and ``dC <= -eps_c``.
  * ``dom_ji``: ``dR <= -eps_r`` and ``dC >= eps_c``.
  * ``inc``: ``dR`` and ``dC`` are both strictly positive or both strictly
    negative, or both lie strictly inside their margins
    (``|dR| < eps_r`` and ``|dC| < eps_c``).

  Args:
    key: JAX PRNG key forwarded to ``_pairs``.
    gR: per-item values indexed along the first axis; ``gR.shape[0]`` is the
      number of items. NOTE(review): the mask filtering below assumes ``gR``
      and ``gC`` are 1-D so that the masks align with ``i``/``j`` -- confirm.
    gC: per-item values indexed the same way as ``gR``.
    K: pair multiplier; ``gR.shape[0] * K`` pairs are sampled.
    eps_r: margin applied to ``dR``.
    eps_c: margin applied to ``dC``.
    use_hard_mining: if true, keep only pairs whose ``|dR|`` or ``|dC|`` is
      close to its margin (see the weighting below).
    temperature: scale of the exponential closeness weight; smaller values
      concentrate the weight nearer the margins.
    keep_threshold: a pair is kept when its weight, normalised by the largest
      weight in the sample, is strictly greater than this value.

  Returns:
    Tuple ``(i, j, dom_ij, dom_ji, inc)``: the two index arrays followed by
    the three boolean masks. When hard mining is active all five are filtered
    by the same mask, so their length depends on the data.
  """
  B = gR.shape[0]
  i, j = _pairs(key, B, K)
  dR = gR[i] - gR[j]
  dC = gC[i] - gC[j]

  dom_ij = (dR >= eps_r) & (dC <= -eps_c)
  dom_ji = (dR <= -eps_r) & (dC >= eps_c)
  same_sign = ((dR > 0) & (dC > 0)) | ((dR < 0) & (dC < 0))
  within_margin = (jnp.abs(dR) < eps_r) & (jnp.abs(dC) < eps_c)
  inc = same_sign | within_margin

  # Skip mining when no pairs were sampled: max() of an empty array raises.
  if use_hard_mining and i.size:
    # Weight is 1 when |d| sits exactly on its margin and decays
    # exponentially with distance from it; take the larger of the two.
    w_r = jnp.exp(-jnp.abs(jnp.abs(dR) - eps_r) / temperature)
    w_c = jnp.exp(-jnp.abs(jnp.abs(dC) - eps_c) / temperature)
    w = jnp.maximum(w_r, w_c)
    keep = (w / (w.max() + 1e-6)) > keep_threshold
    # Boolean-mask indexing yields a data-dependent length, so this branch
    # needs concrete (non-traced) arrays.
    i, j = i[keep], j[keep]
    dom_ij, dom_ji, inc = dom_ij[keep], dom_ji[keep], inc[keep]
  return i, j, dom_ij, dom_ji, inc
