diff core_motif.py
--- a/core_motif.py
+++ b/core_motif.py
@@ -266,12 +266,11 @@ class SFSPolicy(nn.Module):
         return motif_gs
 
 
-    def gumbel_softmax(self, logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1, \
-                    g_ratio: float = 1e-3) -> torch.Tensor:
+    def gumbel_softmax(self, logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> torch.Tensor:
         gumbels = (
             -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
         )  # ~Gumbel(0,1)
-        gumbels = (logits + gumbels * g_ratio) / tau  # ~Gumbel(logits,tau)
+        gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
         y_soft = gumbels.softmax(dim)
         
         if hard:
@@ -339,7 +338,7 @@ class SFSPolicy(nn.Module):
             first_stack = []
             first_ac_stack = []
             for i, node_emb_i in enumerate(torch.split(att_emb, att_len, dim=0)):
-                ac_first_hot_i = self.gumbel_softmax(ac_first_prob[i], tau=self.tau, hard=True, dim=0).transpose(0,1)
+                ac_first_hot_i = self.gumbel_softmax(log_ac_first_prob[i], tau=self.tau, hard=True, dim=0).transpose(0,1)
                 ac_first_i = torch.argmax(ac_first_hot_i, dim=-1)
                 first_stack.append(torch.matmul(ac_first_hot_i, node_emb_i))
                 first_ac_stack.append(ac_first_i)
@@ -361,7 +360,7 @@ class SFSPolicy(nn.Module):
                                     for i, log_ac_first_prob_i in enumerate(log_ac_first_prob)], dim=0).contiguous()
             
         else:            
-            ac_first_hot = self.gumbel_softmax(ac_first_prob, tau=self.tau, hard=True, dim=0).transpose(0,1)
+            ac_first_hot = self.gumbel_softmax(log_ac_first_prob, tau=self.tau, hard=True, dim=0).transpose(0,1)
             ac_first = torch.argmax(ac_first_hot, dim=-1)
             emb_first = torch.matmul(ac_first_hot, att_emb)
             ac_first_prob = torch.cat([ac_first_prob, ac_first_prob.new_zeros(
@@ -384,12 +383,12 @@ class SFSPolicy(nn.Module):
         ac_second_prob = F.softmax(logit_second, dim=-1) + 1e-8
         log_ac_second_prob = ac_second_prob.log()
         
-        ac_second_hot = self.gumbel_softmax(ac_second_prob, tau=self.tau, hard=True, g_ratio=1e-3)                                    
+        ac_second_hot = self.gumbel_softmax(log_ac_second_prob, tau=self.tau, hard=True)                                    
         emb_second = torch.matmul(ac_second_hot, cand_graph_emb)
         ac_second = torch.argmax(ac_second_hot, dim=-1)
 
         # Print gumbel otuput
-        ac_second_gumbel = self.gumbel_softmax(ac_second_prob, tau=self.tau, hard=False, g_ratio=1e-3)                                    
+        ac_second_gumbel = self.gumbel_softmax(log_ac_second_prob, tau=self.tau, hard=False)                                    
         
         # ===============================  
         # step 4 : where to add on motif
@@ -431,7 +430,7 @@ class SFSPolicy(nn.Module):
             third_stack = []
             third_ac_stack = []
             for i, node_emb_i in enumerate(torch.split(emb_cat_ac3, ac3_att_len, dim=0)):
-                ac_third_hot_i = self.gumbel_softmax(ac_third_prob[i], tau=self.tau, hard=True, dim=-1)
+                ac_third_hot_i = self.gumbel_softmax(log_ac_third_prob[i], tau=self.tau, hard=True, dim=-1)
                 ac_third_i = torch.argmax(ac_third_hot_i, dim=-1)
                 third_stack.append(torch.matmul(ac_third_hot_i, node_emb_i))
                 third_ac_stack.append(ac_third_i)
@@ -452,7 +451,7 @@ class SFSPolicy(nn.Module):
                                     for i, log_ac_third_prob_i in enumerate(log_ac_third_prob)], dim=0).contiguous()
 
         else:
-            ac_third_hot = self.gumbel_softmax(ac_third_prob, tau=self.tau, hard=True, dim=-1)
+            ac_third_hot = self.gumbel_softmax(log_ac_third_prob, tau=self.tau, hard=True, dim=-1)
             ac_third = torch.argmax(ac_third_hot, dim=-1)
             emb_third = torch.matmul(ac_third_hot, emb_cat_ac3)
             
@@ -522,7 +521,7 @@ class SFSPolicy(nn.Module):
         ac_second_prob = F.softmax(logit_second, dim=-1) + 1e-8
         log_ac_second_prob = ac_second_prob.log()
         
-        ac_second_hot = self.gumbel_softmax(ac_second_prob, tau=self.tau, hard=True, g_ratio=1e-3)                                    
+        ac_second_hot = self.gumbel_softmax(log_ac_second_prob, tau=self.tau, hard=True)                                    
         emb_second = torch.matmul(ac_second_hot, cand_graph_emb)
         ac_second = torch.argmax(ac_second_hot, dim=-1)
 
