import torch
from torch.distributions.one_hot_categorical import OneHotCategorical


class GumbelSoftmax(OneHotCategorical):
    """One-hot categorical distribution with a Gumbel-Softmax (Concrete) relaxation.

    Exact sampling and log-probabilities are inherited from ``OneHotCategorical``.
    :meth:`gumbel_softmax_sample` additionally draws a continuous relaxation of a
    one-hot sample: ``softmax((logits + gumbel_noise) / temperature)``.

    Args:
        logits: Unnormalized log-probabilities; the last dimension indexes the
            classes. Pass ``None`` when specifying ``probs`` instead.
        probs: Event probabilities (mutually exclusive with ``logits``).
        temperature: Softmax temperature used by :meth:`gumbel_softmax_sample`;
            must be strictly positive.
        validate_args: Forwarded unchanged to ``OneHotCategorical``.

    Raises:
        ValueError: If ``temperature`` is not strictly positive, or if the parent
            constructor rejects the ``logits``/``probs`` combination.
    """

    def __init__(self, logits, probs=None, temperature=1, validate_args=None):
        super(GumbelSoftmax, self).__init__(
            logits=logits, probs=probs, validate_args=validate_args
        )
        # A zero/negative temperature would divide by zero or flip the softmax,
        # silently producing inf/NaN samples; reject it up front.
        if not bool(torch.all(torch.as_tensor(temperature) > 0)):
            raise ValueError(f"temperature must be > 0, got {temperature}")
        # Offset that keeps both logs in sample_gumbel finite when the uniform
        # draw is exactly 0 (torch.rand samples from [0, 1)).
        self.eps = 1e-20
        self.temperature = temperature

    def sample_gumbel(self, sample_shape=torch.Size()):
        """Draw i.i.d. standard Gumbel(0, 1) noise.

        Args:
            sample_shape: Leading sample dimensions. The result has shape
                ``sample_shape + batch_shape + event_shape``, which equals
                ``self.logits.shape`` for the default empty ``sample_shape``.

        Returns:
            Noise with the dtype and device of ``self.logits``; it carries no
            autograd history.
        """
        shape = self._extended_shape(sample_shape)
        # Fresh uniform noise without cloning logits, so the logits' autograd
        # graph is not dragged into the noise computation.
        u = torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device)
        # 1e-20 underflows to 0 in float16, which would let log() return -inf;
        # never use an offset below the dtype's smallest normal number.
        eps = max(self.eps, torch.finfo(u.dtype).tiny)
        return -torch.log(-torch.log(u + eps) + eps)

    def gumbel_softmax_sample(self, sample_shape=torch.Size()):
        """Draw a sample from the Gumbel-Softmax distribution.

        The returned sample is a probability distribution that sums to 1 across
        classes (last dimension) and is differentiable with respect to
        ``logits``.

        Args:
            sample_shape: Leading sample dimensions; see :meth:`sample_gumbel`.

        Returns:
            Tensor of shape ``sample_shape + batch_shape + event_shape``.
        """
        y = self.logits + self.sample_gumbel(sample_shape)
        return torch.softmax(y / self.temperature, dim=-1)
