from string import Template

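# Raw string: the prompt contains LaTeX-style backslashes (e.g. \pi), which are invalid escapes in a normal string.
prompt_template = Template(r'''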
## Role: AI Research Assistant (Imitation Learning)

## Overall Objective:
Collaborate to discover novel reward functions for Adversarial Imitation Learning (AIL) that improve **training stability** and **final policy performance**. Performance is reported as a single scalar score (higher is better).

## Background: Adversarial Imitation Learning Setting
You have a policy \pi and expert transitions (s,a) stored in a dataset D_E.
The typical learning loop involves:
1. Sampling transitions (s,a) into a dataset D_{\pi} using the current policy \pi.
2. Training a discriminator D(s,a) to distinguish between expert transitions (D_E) and policy transitions (D_{\pi}) using a standard binary cross-entropy loss:
   `L = -E_{(s,a) ~ D_E} [log(D(s,a))] - E_{(s,a) ~ D_{\pi}} [log(1 - D(s,a))]`
3. The discriminator's output logits, `l(s,a)`, approximate the log-density ratio: `l(s,a) ≈ log(ρ^E(s,a) / ρ^{\pi}(s,a))`.
4. Policy transitions (s,a) in D_{\pi} are assigned rewards based on these logits using a reward function `r(s,a) = f(l(s,a))`. Examples include:
    *   **GAIL:** `r(s,a) = -log(1 - D(s,a)) = softplus(l(s,a))` (Smooth rectifier: near 0 for negative logits, linear for positive).
    *   **AIRL:** `r(s,a) = log D(s,a) - log(1-D(s,a)) = l(s,a)` (Linear everywhere).
    *   **FAIRL:** `r(s,a) = -l(s,a) * exp(l(s,a))` (Approaches 0 as l → -∞, peaks at 1/e at l=-1, crosses 0 at l=0, then decreases rapidly and is unbounded below).
    *   **LOGD:** `r(s,a) = log D(s,a) = -softplus(-l(s,a))` (Linear for negative logits, near 0 for positive).
5. The policy \pi is updated using reinforcement learning (e.g., PPO, SAC) with these calculated rewards.
6. Steps 1-5 are repeated.
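
For reference, the loss in step 2 and the four rewards in step 4 can be written directly on logits. This is a minimal sketch, assuming `D(s,a) = sigmoid(l(s,a))`; the function names are illustrative and not part of any library.

```python
import jax
import jax.numpy as jnp


def discriminator_loss(expert_logits, policy_logits):
    # Binary cross-entropy in logit form, assuming D = sigmoid(l):
    #   -log D(s,a)      = softplus(-l)
    #   -log(1 - D(s,a)) = softplus(l)
    expert_term = jnp.mean(jax.nn.softplus(-expert_logits))
    policy_term = jnp.mean(jax.nn.softplus(policy_logits))
    return expert_term + policy_term


def gail_reward(logits):
    # -log(1 - D) = softplus(l)
    return jax.nn.softplus(logits)


def airl_reward(logits):
    # log D - log(1 - D) = l
    return logits


def fairl_reward(logits):
    # -l * exp(l); peaks at 1/e when l = -1
    return -logits * jnp.exp(logits)


def logd_reward(logits):
    # log D = -softplus(-l)
    return -jax.nn.softplus(-logits)
```

Working on logits with `softplus` avoids computing `log(1 - D)` explicitly, which can lose precision when `D` saturates near 0 or 1.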

## Your Task in This Interaction:
You will be presented with two reward functions, `f_1` and `f_2` (defined on logits `l`), along with their observed scores. Your goal is to propose a *new* function `f_3` (distinct from GAIL, AIRL, FAIRL, and LOGD) that aims to achieve a higher score.

**Instructions:**

1.  **Analyze `f_1` and `f_2`:**
    *   Consider their mathematical shapes and properties (e.g., monotonicity, bounds, smoothness).
    *   Consider their behavior when the logits are near zero, positive, and negative. What signal do they provide?
    *   Relate these properties to the provided performance data. Why might one function have performed better/worse?

2.  **Design `f_3 = reward_fn(logits)`:**
    *   Based on your analysis, propose a *new* function `f_3`.
    *   **Aim for diversity:** `f_3` may be either a novel function or a variation on the provided examples.

3.  **Implementation Requirements:**
    *   **Input:** `logits` (a JAX array).
    *   **Output:** `reward` (a JAX array of the same shape).
    *   **Language:** Python, using JAX.
    *   **Function Name:** `reward_fn`.
    *   **Clarity:** Ensure the code is clean, well-commented (if necessary), and easily extractable. Include necessary imports (`jax.numpy as jnp`, `jax.nn` etc.).
    *   **Enclose in Code Block:** Use a code block with the language `python`.
    *   **Jittable:** Ensure the function is jittable by JAX.
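
As a hedged illustration of the jittable requirement: Python `if` statements on array values fail under `jax.jit`, so piecewise shapes are usually expressed with `jnp.where`. The function below is a hypothetical example of the pattern only, not a recommended reward.

```python
import jax
import jax.numpy as jnp


def example_reward_fn(logits):
    # Hypothetical piecewise shape, for illustration only:
    # linear for negative logits, tanh-saturated for non-negative logits.
    # jnp.where selects elementwise, so no Python branching on traced values.
    return jnp.where(logits < 0.0, logits, jnp.tanh(logits))


# Compiling with jax.jit checks that the function traces without errors.
jitted_fn = jax.jit(example_reward_fn)
example_out = jitted_fn(jnp.array([-2.0, 0.0, 3.0]))
```

Note that `jnp.where` evaluates both branches for every element, so each branch should be finite over the whole input range.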

## Response Format:

```python
import jax.numpy as jnp
# from jax import nn # Uncomment or add other imports if needed

def reward_fn(logits):
    """
    [Brief description of the function's logic/intent]
    """
    # [Your implementation here]
    reward = ...
    return reward
```

## Pair of Reward Functions:
Function 1:
```python
$fn1
```
Score: $score1

Function 2:
```python
$fn2
```
Score: $score2
'''
)
