from __future__ import annotations
import torch


def _sample_gumbel(shape, device, eps: float = 1e-9):
    """Return Gumbel(0, 1) noise of the given ``shape`` on ``device``.

    Applies the inverse-CDF transform ``-log(-log(U))`` to ``U`` drawn from
    ``torch.rand`` (uniform on [0, 1)). ``eps`` is added to the argument of
    each logarithm so neither can be evaluated at exactly zero.
    """
    uniform = torch.rand(shape, device=device)
    # First transform: -log(U); eps keeps U == 0 from hitting log(0).
    inner = torch.log(uniform + eps).neg()
    # Second transform: -log(-log(U)); eps guards a (rounded) zero inner term.
    return torch.log(inner + eps).neg()


def gumbel_topk_binary(logits: torch.Tensor, k: int, tau: float = 1.0) -> torch.Tensor:
    """Draw a k-hot mask via Gumbel-top-k with a straight-through gradient.

    Forward value: a 0/1 tensor with ones at the ``k`` largest entries of
    ``logits + g``, where ``g`` is Gumbel noise from ``_sample_gumbel``.
    Backward: gradients flow through ``softmax((logits + g) / tau)``, so the
    surrogate gradient corresponds to the same noise draw that produced the
    hard mask (standard Gumbel-softmax straight-through construction).

    Args:
        logits: 1-D tensor of unnormalised scores.
        k: Number of entries set to 1. Must not exceed ``logits.size(0)``
            (``torch.topk`` raises otherwise).
        tau: Temperature of the soft surrogate; clamped below at 1e-8 to
            avoid division by zero.

    Returns:
        Tensor with the same shape and dtype as ``logits`` whose value is the
        hard k-hot mask and whose gradient is that of the soft surrogate.

    Raises:
        ValueError: If ``logits`` is not 1-D.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if logits.dim() != 1:
        raise ValueError(
            f"logits must be 1-D, got shape {tuple(logits.shape)}"
        )
    g = _sample_gumbel(logits.size(0), device=logits.device)
    perturbed = logits + g
    topk = torch.topk(perturbed, k=k, dim=0).indices
    y_hard = torch.zeros_like(logits)
    y_hard[topk] = 1.0
    # The soft surrogate must use the *perturbed* logits: softmax(logits / tau)
    # would ignore the noise that selected y_hard, making the gradient
    # independent of the sample actually drawn. The noise is float32, so cast
    # back to keep the output dtype equal to logits.dtype (e.g. for fp16).
    # NOTE(review): y_soft sums to 1 while y_hard sums to k; confirm whether
    # callers expect the surrogate to be rescaled for k > 1.
    y_soft = torch.softmax(perturbed / max(tau, 1e-8), dim=0).to(logits.dtype)
    # Straight-through: value equals y_hard, gradient flows via y_soft.
    return y_hard + y_soft - y_soft.detach()
