import torch


def sample_gumbel(shape, rng, eps=1e-20):
    """Draw standard Gumbel(0, 1) noise of the given shape.

    Uses inverse-CDF sampling: with ``U ~ Uniform[0, 1)`` drawn by
    ``torch.rand`` from ``rng``, returns ``-log(-log(U + eps) + eps)``.

    Args:
        shape: size of the tensor to produce.
        rng: generator forwarded to ``torch.rand``.
        eps: small constant added inside both logarithms so neither ever
            receives an exact zero.

    Returns:
        A tensor of Gumbel noise created by ``torch.rand`` (default dtype,
        default device).
    """
    uniform = torch.rand(shape, generator=rng)
    # Inner term: -log(U + eps) + eps, strictly positive.
    inner = torch.log(uniform + eps).neg() + eps
    return torch.log(inner).neg()


def gumbel_softmax_sample(logits, temperature, rng):
    """Return a relaxed (soft) Gumbel-softmax sample over the last dimension.

    Gumbel noise is drawn with ``sample_gumbel`` using ``rng`` and cast to
    the tensor type of ``logits`` (``type_as``). It is then added to the
    logits, the sum is divided by ``temperature``, and the result is
    normalized with a softmax along ``dim=-1``.

    Args:
        logits: unnormalized scores; the softmax runs over the last dim.
        temperature: divisor applied to the perturbed logits before softmax.
        rng: generator passed through to ``sample_gumbel``.

    Returns:
        Tensor with the same shape as ``logits`` whose last dim sums to 1.
    """
    noise = sample_gumbel(logits.size(), rng).type_as(logits)
    scaled = (logits + noise) / temperature
    return scaled.softmax(dim=-1)


def gumbel_softmax(logits, temperature, rng):
    """Draw a straight-through Gumbel-softmax sample.

    In the forward pass the result equals a hard one-hot vector at the
    argmax of a relaxed Gumbel-softmax sample. In the backward pass,
    gradients flow through the soft sample (straight-through estimator).

    Args:
        logits: unnormalized scores of shape ``[*, n_class]``.
        temperature: softmax temperature passed to ``gumbel_softmax_sample``.
        rng: generator used to draw the Gumbel noise.

    Returns:
        Tensor of shape ``[*, n_class]`` whose values are one-hot along the
        last dim, but which carries the gradient of the soft sample.
    """
    y_soft = gumbel_softmax_sample(logits, temperature, rng)
    # keepdim=True yields indices of shape [*, 1], ready for scatter along the
    # last dim. No flatten/reshape is needed, so any rank or stride layout works.
    _, index = y_soft.max(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
    # Forward value is y_hard; the detached difference contributes no
    # gradient, so backprop sees only y_soft.
    return (y_hard - y_soft).detach() + y_soft
