def ste_project(
    logits: torch.Tensor,
    *,
    mode: str = "SG",  # "SG" | "ST" | "STL"
    tau: float = 1.0,
    dim: int = -1,
) -> torch.Tensor:
    """Project logits to a hard one-hot with a straight-through gradient.

    The forward value is a one-hot tensor (1.0 at the argmax along ``dim``,
    0.0 elsewhere). The backward pass ignores the argmax and differentiates
    a smooth surrogate instead; ``mode`` picks the surrogate:

    SG  : softmax((logits + gumbel_noise) / tau)   (Gumbel-Softmax, stochastic)
    ST  : softmax(logits / tau)                    (plain softmax, no noise)
    STL : (logits - mean) / max(std, 1e-3)         (logits standardized along
                                                    ``dim``; argmax is taken
                                                    on the raw logits)

    Args:
        logits: Unnormalized scores.
        mode: One of ``"SG"``, ``"ST"`` or ``"STL"``.
        tau: Softmax temperature for SG/ST; values below 1e-6 are raised to
            1e-6. Unused by STL.
        dim: Axis along which the one-hot is formed.

    Returns:
        A tensor shaped like ``logits`` whose forward value is the one-hot
        and whose gradient w.r.t. ``logits`` is that of the surrogate.

    Raises:
        ValueError: If ``mode`` is not one of the supported names.
    """
    if mode not in {"SG", "ST", "STL"}:
        raise ValueError(f"Unknown ste mode: {mode}")

    if mode == "STL":
        centered = logits - logits.mean(dim=dim, keepdim=True)
        n_classes = logits.size(dim) if logits.dim() > 0 else 1
        if n_classes > 1:
            spread = logits.std(dim=dim, keepdim=True)
        else:
            # The unbiased std of a single element is NaN, and clamp_min does
            # not repair NaN; treat the spread as zero so the floor applies.
            spread = torch.zeros_like(centered)
        surrogate = centered / spread.clamp_min(1e-3)
        # Keep the forward decision on the raw logits.
        idx = logits.argmax(dim=dim, keepdim=True)
    else:
        t = max(tau, 1e-6)
        if mode == "SG":
            # Gumbel(0, 1) noise, sampled in float32 for a stable double log.
            # rand_like draws from [0, 1), so only the lower bound needs a
            # clamp to keep log(u) finite.
            u = torch.rand_like(logits, dtype=torch.float32).clamp_min_(1e-20)
            g = -torch.log(-torch.log(u))
            z = (logits + g.to(logits.dtype)) / t
        else:
            z = logits / t
        surrogate = torch.softmax(z, dim=dim)
        idx = surrogate.argmax(dim=dim, keepdim=True)

    y_hard = torch.zeros_like(logits).scatter_(dim, idx, 1.0)
    # Straight-through estimator: the parenthesized term is zero in the
    # forward pass, so the output is y_hard itself (no rounding residue from
    # subtracting and re-adding the surrogate), while its gradient is the
    # identity w.r.t. the surrogate.
    return y_hard + (surrogate - surrogate.detach())
