diff freedpp/train/nn/actor.py
@@ -109,7 +109,7 @@ class Actor(nn.Module):
         if self.fragmentation == 'brics':
             mask = ~ self.acceptable_fragments(*mask_args, **mask_kwargs)
             logits.masked_fill_(mask, float("-inf"))
-        onehot = F.gumbel_softmax(logits, tau=self.tau, hard=True, dim=1)
+        onehot = self.gumbel_softmax(logits, tau=self.tau, hard=True, dim=1)
         index = torch.argmax(onehot, dim=1)
         self.encode_fragments(index)
         fragment = (onehot[:, None, :] @ torch.stack(self.fragments_gcn)[None, :, :]).squeeze(1)
@@ -125,7 +125,7 @@ class Actor(nn.Module):
         if self.fragmentation == 'brics':
             mask = ~ self.acceptable_fragments(*mask_args, **mask_kwargs)
             logits.masked_fill_(mask, float("-inf"))
-        onehot = F.gumbel_softmax(logits, tau=self.tau, hard=True, dim=1)
+        onehot = self.gumbel_softmax(logits, tau=self.tau, hard=True, dim=1)
         index = torch.argmax(onehot, dim=1)
         self.encode_fragments(index)
         fragment = (onehot[:, None, :] @ torch.stack(self.fragments_gcn)[None, :, :]).squeeze(1)
@@ -149,7 +149,7 @@ class Actor(nn.Module):
         options = torch.stack(options, dim=2)
         options = self.pad(options.split(sections), size).view(batch_size, size, -1, options.size(2))
         logits = self.pad(logits.split(sections), size, value=float("-inf")).view(batch_size, size)
-        onehot = F.gumbel_softmax(logits, tau=self.tau, hard=True, dim=1)
+        onehot = self.gumbel_softmax(logits, tau=self.tau, hard=True, dim=1)
         index = torch.argmax(onehot, dim=1, keepdim=True)
         options = onehot[None, :, None, :] @ options.permute(3, 0, 1, 2)
         options = [opt for opt in options.squeeze(2)]
@@ -189,3 +189,27 @@ class Actor(nn.Module):
         self.fragments_attachments = dict()
         d, N, device = self.emb_size, len(self.fragments), self.fragments_gcn[0].device
         self.fragments_gcn = [torch.zeros(d, device=device) for _ in range(N)]
+
+    def gumbel_softmax(self, logits: torch.Tensor, tau: float = 1, hard: bool = False,
+                       eps: float = 1e-10, dim: int = -1, g_ratio: float = 1e-3) -> torch.Tensor:
+        # eps is kept for signature parity with F.gumbel_softmax but is unused.
+        # Sample Gumbel(0, 1) noise as -log(Exponential(1)).
+        gumbels = (
+            -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
+        )
+        # Unlike F.gumbel_softmax, normalize the logits to probabilities first and
+        # damp the Gumbel noise by g_ratio, so the probabilities dominate the sample.
+        tau = tau / g_ratio
+        logits = logits.softmax(dim=dim)
+        gumbels = (logits + gumbels * g_ratio) / tau
+        y_soft = gumbels.softmax(dim)
+
+        if hard:
+            # Straight-through estimator: hard one-hot forward, soft gradients backward.
+            index = y_soft.max(dim, keepdim=True)[1]
+            y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
+            ret = y_hard - y_soft.detach() + y_soft
+        else:
+            # Reparametrization trick: return the relaxed sample.
+            ret = y_soft
+        return ret
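
For reference, a minimal standalone sketch of what the patched sampler changes, assuming a free-function copy of the method above (gumbel_softmax_v2 and the demo values are illustrative, not part of the patch; eps and memory_format are dropped for brevity). With a small g_ratio the noise is damped, so the hard sample tracks argmax(logits) far more often than the stock F.gumbel_softmax.

import torch
import torch.nn.functional as F

def gumbel_softmax_v2(logits, tau=1.0, hard=False, dim=-1, g_ratio=1e-3):
    # Same math as the Actor.gumbel_softmax method added in the diff above.
    gumbels = -torch.empty_like(logits).exponential_().log()  # Gumbel(0, 1) noise
    probs = logits.softmax(dim=dim)
    y_soft = ((probs + gumbels * g_ratio) * (g_ratio / tau)).softmax(dim)
    if hard:
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
        return y_hard - y_soft.detach() + y_soft  # straight-through
    return y_soft

torch.manual_seed(0)
logits = torch.randn(512, 10)
stock = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=1)
patched = gumbel_softmax_v2(logits, tau=1.0, hard=True, dim=1)
# Fraction of samples agreeing with the noiseless argmax:
print((stock.argmax(1) == logits.argmax(1)).float().mean().item())    # well below 1.0
print((patched.argmax(1) == logits.argmax(1)).float().mean().item())  # close to 1.0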
