import torch
import torch.nn.functional as F


def gumbel_softmax(categorical_probs, hard=False, eps=1e-9, *, tau=1.0, dim=-1):
    """Draw a Gumbel-softmax sample from a tensor of category probabilities.

    The probabilities are clamped from below at ``eps`` before taking the log.
    This keeps zero-probability entries at a finite logit (``log(eps)``)
    instead of ``-inf``. The resulting logits are passed to
    :func:`torch.nn.functional.gumbel_softmax`.

    Args:
        categorical_probs: Tensor of non-negative category weights along
            ``dim``. Since ``softmax(log p)`` is ``p / p.sum()``, the weights
            need not be normalized.
        hard: If True, return one-hot samples. Per the PyTorch docs, these use
            a straight-through gradient from the soft sample.
        eps: Lower bound applied to ``categorical_probs`` before the log.
        tau: Gumbel-softmax temperature, forwarded to PyTorch.
        dim: Dimension holding the categories, forwarded to PyTorch.

    Returns:
        Tensor with the same shape as ``categorical_probs``.
    """
    # Honour the caller's eps; the previous version hard-coded 1e-9 here.
    logits = categorical_probs.clamp(min=eps).log()
    return F.gumbel_softmax(logits, tau=tau, hard=hard, dim=dim)


def sample_categorical(categorical_probs, method="hard"):
    """Sample one category index along the last dimension of ``categorical_probs``.

    The method divides each weight by independent ``Exp(1)`` noise and takes
    the argmax (the "exponential race"). The winner is index ``i`` with
    probability ``p_i / sum(p)``, so the weights need not be normalized. This
    is equivalent to the Gumbel-max trick.

    Args:
        categorical_probs: Tensor of non-negative category weights; categories
            lie along the last dimension.
        method: Sampling method. Only ``"hard"`` is supported.

    Returns:
        Integer tensor of sampled indices, with the last dimension of
        ``categorical_probs`` reduced away.

    Raises:
        ValueError: If ``method`` is not ``"hard"``.
    """
    if method != "hard":
        raise ValueError(
            f"Method {method} for sampling categorical variables is not valid."
        )

    # Uniform draws in [0, 1), taken in float64 for finer resolution.
    uniform = torch.rand_like(categorical_probs, dtype=torch.float64)
    # -log(U) is Exp(1). The inner 1e-10 avoids log(0) when U == 0; the outer
    # 1e-10 keeps the divisor strictly positive when U is close to 1.
    exp_noise = 1e-10 - torch.log(uniform + 1e-10)
    scores = categorical_probs / exp_noise
    return torch.argmax(scores, dim=-1)
