import jax
import jax.numpy as jnp

def reward_fn(logits):
    """Turn discriminator logits into the GAIL surrogate reward.

    The value returned is ``softplus(logits)``, which equals
    ``-log(1 - sigmoid(logits))``. ``jax.nn.softplus`` evaluates this in a
    numerically stable way, so large positive or negative logits do not
    overflow.

    NOTE(review): the sign convention presumably has ``sigmoid(logits)`` as
    the discriminator's probability that the input came from the expert.
    Under that convention, more expert-like samples get a larger reward.
    TODO confirm against the discriminator's training loss.

    Args:
        logits: Raw (pre-sigmoid) discriminator outputs. The function is
            applied element-wise, so any array shape is accepted.

    Returns:
        Element-wise reward with the same shape as ``logits``. It is
        intended as the learning signal for the imitating policy, not for
        the discriminator. Every value is strictly positive.
    """
    # Element-wise -log(1 - D), where D = sigmoid(logits), in stable form.
    surrogate_reward = jax.nn.softplus(logits)
    return surrogate_reward