import jax.numpy as jnp


def ppo_loss(net, params, traj, adv_norm, targets, clip_eps, ent_coef, vf_coef):
    """Compute the clipped PPO objective and a dict of logging metrics.

    The network is evaluated with ``net.apply(params, traj.obs)``, which must
    return a pair ``(pi, v)``: ``pi`` exposes ``log_prob(action)`` and
    ``entropy()``, and ``v`` is the current value prediction.

    Args:
        net: Object providing ``apply(params, obs)``.
        params: Parameters forwarded to ``net.apply``.
        traj: Rollout data; the attributes read here are ``obs``, ``action``,
            ``log_prob`` (log-probabilities recorded at collection time) and
            ``value`` (value predictions recorded at collection time).
        adv_norm: Advantages multiplied into the probability ratio
            (presumably already normalised by the caller -- confirm there).
        targets: Regression targets for the value prediction.
        clip_eps: Clipping half-width, used both for the probability ratio
            (around 1) and for the value update (around ``traj.value``).
        ent_coef: Weight of the entropy bonus subtracted from the loss.
        vf_coef: Weight of the value loss.

    Returns:
        ``(total, aux)`` where ``total`` is
        ``policy_loss + vf_coef * value_loss - ent_coef * entropy`` and
        ``aux`` maps ``"train/..."`` metric names to the individual terms.
    """
    dist, value_pred = net.apply(params, traj.obs)
    new_logp = dist.log_prob(traj.action)
    old_logp = traj.log_prob
    old_value = traj.value

    # --- Policy term: pessimistic (element-wise minimum) of the raw and
    # clipped surrogate objectives, negated so it can be minimised.
    ratio = jnp.exp(new_logp - old_logp)
    ratio_clipped = jnp.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
    surrogate_raw = ratio * adv_norm
    surrogate_clipped = ratio_clipped * adv_norm
    policy_loss = -jnp.mean(jnp.minimum(surrogate_raw, surrogate_clipped))

    # --- Value term: keep the prediction within clip_eps of the rollout
    # value, then penalise whichever squared error is larger.
    value_step = (value_pred - old_value).clip(-clip_eps, clip_eps)
    value_pred_clipped = old_value + value_step
    sq_err_raw = (value_pred - targets) ** 2
    sq_err_clipped = (value_pred_clipped - targets) ** 2
    value_loss = 0.5 * jnp.mean(jnp.maximum(sq_err_raw, sq_err_clipped))

    # --- Entropy bonus and the combined objective.
    entropy = jnp.mean(dist.entropy())
    total = policy_loss + vf_coef * value_loss - ent_coef * entropy

    # --- Diagnostics only; these do not feed into `total`.
    kl_estimate = jnp.mean(old_logp - new_logp)
    outside_clip = jnp.abs(ratio - 1.0) > clip_eps
    clip_fraction = jnp.mean(outside_clip.astype(jnp.float32))

    aux = {
        "train/policy_loss": policy_loss,
        "train/value_loss": value_loss,
        "train/entropy": entropy,
        "train/approx_kl": kl_estimate,
        "train/clipfrac": clip_fraction,
    }
    return total, aux