import jax
import jax.numpy as jnp

def reward_fn(logits):
    """Compute the FAIRL reward signal from discriminator logits.

    The reward is the elementwise expression ``-h * exp(h)``, where ``h``
    is the discriminator logit. The value is derived from the
    discriminator's output and is meant to be used as a reward signal.
    It is not a reward *for* the discriminator.

    Args:
        logits: Discriminator logits. Any array-like accepted by
            ``jnp.exp``; the computation is applied elementwise.

    Returns:
        An array with the same shape as ``logits`` holding
        ``-logits * exp(logits)``.
    """
    # NOTE(review): exp(logits) overflows to inf for large positive logits;
    # callers may want to clip logits upstream — confirm expected range.
    weight = jnp.exp(logits)
    return -(logits * weight)