import jax
import jax.numpy as jnp

def reward_fn(logits: jnp.ndarray) -> jnp.ndarray:
    """Compute the LOGD reward, ``log D``, from discriminator logits.

    LOGD is a heuristic variant of GAIL. Let ``D = sigmoid(logits)`` be
    the discriminator output probability. The reward is then
    ``log(D) = -softplus(-logits)``. This is evaluated with
    ``jax.nn.log_sigmoid``, which is defined as exactly that expression and
    avoids forming ``sigmoid(logits)`` and taking its log separately. The
    separate form would underflow to ``log(0)`` for very negative logits.

    NOTE(review): assumes the discriminator's probability is
    ``sigmoid(logits)``. Confirm which class it scores against the caller.

    Args:
        logits: Raw (pre-sigmoid) discriminator outputs. The transform is
            elementwise, so any shape is accepted.

    Returns:
        Array with the same shape as ``logits``, holding the per-element
        reward ``log(sigmoid(logits))``. Mathematically this is never
        positive. It is a reward derived from the discriminator (presumably
        consumed by the policy), not a loss for the discriminator itself.
    """
    log_d = jax.nn.log_sigmoid(logits)
    return log_d