import torch
import torch.nn.functional as F


# def gumbel_softmax(categorical_probs, hard=False, eps=1e-9):
#     logits = categorical_probs.clamp(min=1e-9).log()
#     return F.gumbel_softmax(logits, hard=hard)


# def sample_categorical(categorical_probs, method="hard"):
#     if method == "hard":
#         gumbel_norm = 1e-10 - (torch.rand_like(categorical_probs) + 1e-10).log()
#         return (categorical_probs / gumbel_norm).argmax(dim=-1)
#     else:
#         raise ValueError(f"Method {method} for sampling categorical variables is not valid.")

def gumbel_softmax(categorical_probs, hard=False, eps=1e-9, tau=1.0, dim=-1):
    """Draw a differentiable Gumbel-softmax sample from category probabilities.

    The probabilities are turned into logits with ``log(p + eps)`` and passed
    to :func:`torch.nn.functional.gumbel_softmax`.

    Args:
        categorical_probs: Tensor of non-negative category weights; ``dim``
            indexes the categories.
        hard: If True, the forward value is one-hot, but autograd
            differentiates it as if it were the soft sample.
        eps: Added to the probabilities before the log, so zero-probability
            categories get a finite logit and ``log`` has a finite gradient.
        tau: Softmax temperature; smaller values give sharper samples.
            The default of 1.0 reproduces the previous behaviour.
        dim: Dimension along which the softmax is taken (default: last).

    Returns:
        Tensor with the same shape as ``categorical_probs``.
    """
    # Keep the eps offset (rather than log(p) directly): log(0) = -inf would
    # give an infinite derivative and NaN gradients through the softmax.
    logits = (categorical_probs + eps).log()
    return F.gumbel_softmax(logits, tau=tau, hard=hard, dim=dim)

def sample_categorical(categorical_probs, method="hard"):
    """Pick one category index per distribution along the last dimension.

    Args:
        categorical_probs: Tensor of category weights; the last dimension
            indexes categories. Weights should be non-negative and need not
            be normalised. Categories with weight exactly 0 are never drawn.
        method: ``"hard"`` draws a random sample with the Gumbel-max trick;
            ``"top1"`` deterministically returns the most probable category.

    Returns:
        Integer tensor of category indices with the last dimension removed.

    Raises:
        ValueError: If ``method`` is not recognised.

    Note:
        The noise is generated in the input's dtype, so its tail is limited by
        that dtype's resolution. Pass float64 probabilities if exact sampling
        statistics matter.
    """
    if method == "hard":
        # Smallest positive normal value of the input dtype. Clamping with it
        # keeps both logs finite without a fixed constant that could
        # underflow to 0 in low precision (1e-9 does in float16).
        tiny = torch.finfo(categorical_probs.dtype).tiny
        u = torch.rand_like(categorical_probs).clamp(min=tiny)
        gumbel_noise = -torch.log((-torch.log(u)).clamp(min=tiny))  # standard Gumbel(0, 1)
        # No epsilon on the probabilities: log(0) = -inf, so a zero-probability
        # category can never win the argmax, however large its noise draw.
        logits = torch.log(categorical_probs) + gumbel_noise
        return logits.argmax(dim=-1)
    if method == "top1":
        return categorical_probs.argmax(dim=-1)
    raise ValueError(f"Method {method} for sampling categorical variables is not valid.")
