diff freed/main.py
@@ -77,7 +77,7 @@
     mlp_kwargs = {'norm_layer': nn.LayerNorm}
     critic_args = (4 * s, (2 * s, s, s, 1))
     m = N if args['action_mechanism'] == 'pi' else 1
-    actor_args = ((s // 2, (s // 4, 1)), (s // 2, (s // 2, s  // 2, m)), (s // 2, (s // 4, 1)))
+    actor_args = ((s, (s, s, 1)), (s, (s, s, m)), (s, (s, s, 1)))
     actor_kwargs = ({}, {}, {})
 
     set_seed(args['seed'])

diff freed/train/nn/actor.py
         if mechanism == 'pi':
             self.select_fragment = self.select_fragment_PI
-            merger_args = ((d, d, d // 2), (d, d, d // 2), (d // 2, d, d // 2))
+            merger_args = ((d, d, d), (d, d, d), (d, d, d))
         elif mechanism == 'sfps':
             self.select_fragment = self.select_fragment_SFPS
             self.fragments_ecfp = torch.FloatTensor(lmap(partial(ecfp, n=ecfp_size), map(attrgetter('smile'), fragments)))
-            merger_args = ((d, d, d // 2), (d // 2, ecfp_size, d // 2), (d // 2, d, d // 2))
+            merger_args = ((d, d, d), (d, ecfp_size, d), (d, d, d))
         else:
             raise ValueError(f"Unknown mechanism '{mechanism}'")
 
@@ -146,16 +146,13 @@
 
     def sample_and_pad(self, size, sections, logits, *options):
         batch_size = len(sections)
-        # options = torch.stack(options, dim=2)
-        # options = self.pad(options.split(sections), size).view(batch_size, size, -1, options.size(2))
-        options = [self.pad(opt.split(sections), size).view(batch_size, size, -1) for opt in options]
+        options = torch.stack(options, dim=2)
+        options = self.pad(options.split(sections), size).view(batch_size, size, -1, options.size(2))
         logits = self.pad(logits.split(sections), size, value=float("-inf")).view(batch_size, size)
         onehot = F.gumbel_softmax(logits, tau=self.tau, hard=True, dim=1)
         index = torch.argmax(onehot, dim=1, keepdim=True)
-        # options = onehot[None, :, None, :] @ options.permute(3, 0, 1, 2)
-        options = [onehot[:, None, :] @ opt for opt in options]
-        # options = [opt for opt in options.squeeze(2)]
-        options = [opt.squeeze(1) for opt in options]
+        options = onehot[None, :, None, :] @ options.permute(3, 0, 1, 2)
+        options = [opt for opt in options.squeeze(2)]
         return index, onehot, logits, options
 
     def pad(self, input, size, value=0):
