diff freedpp/main.py
@@ -75,7 +75,7 @@ def init_models(args, env, checkpoint=None):
     emb_size = s = args['emb_size']
     N = len(env.fragments)
     mlp_kwargs = {'norm_layer': nn.LayerNorm}
-    critic_args = (4 * s, (2 * s, s, s, 1))
+    critic_args = (s + env.actions_dim[0] + N + env.actions_dim[2], (2 * s, s, s, 1))
     m = N if args['action_mechanism'] == 'pi' else 1
     actor_args = ((s, (s, s, 1)), (s, (s, s, m)), (s, (s, s, 1)))
     actor_kwargs = ({}, {}, {})

diff freedpp/train/nn/action.py
@@ -7,12 +7,13 @@ from freedpp.utils import lmap, lzip
 
 
 class ActionBatch(object):
-    def __init__(self, batch_size, index, onehot, embedding, logits):
+    def __init__(self, batch_size, index, onehot, embedding, logits, critic_input):
         self.batch_size = batch_size
         self.index = index.flatten().tolist()
         self.onehot = onehot
         self.embedding = embedding
         self.logits = logits
+        self.critic_input = critic_input
 
     def entropy(self):
         bs = self.batch_size
@@ -29,6 +30,7 @@ class StepActionBatch(object):
         self.actions = actions
         self._embedding = torch.cat(lmap(attrgetter('embedding'), self.actions), dim=-1)
         self._index = lzip(*map(attrgetter('index'), self.actions))
+        self._critic_input = torch.cat(lmap(attrgetter('critic_input'), self.actions), dim=-1)
 
     def entropy(self):
         return sum(map(methodcaller('entropy'), self.actions))
@@ -40,3 +42,7 @@ class StepActionBatch(object):
     @property
     def index(self):
         return self._index
+
+    @property
+    def critic_input(self):
+        return self._critic_input

diff freedpp/train/nn/actor.py
@@ -101,7 +101,8 @@ class Actor(nn.Module):
         molecule = molecule.readout.repeat_interleave(sections, dim=0)
         logits, mergers = self.molecule_attachment_ranker(molecule, attachments)
         index, onehot, logits, (attachment, merger) = self.sample_and_pad(self.actions_dim[0], sections.tolist(), logits, attachments, mergers)
-        return ActionBatch(batch_size, index, onehot, attachment, logits), merger
+        critic_input = logits.view(batch_size, -1).softmax(dim=1)
+        return ActionBatch(batch_size, index, onehot, attachment, logits, critic_input), merger
 
     def select_fragment_PI(self, condition, *mask_args, **mask_kwargs):
         batch_size = condition.size(0)
@@ -114,7 +115,7 @@ class Actor(nn.Module):
         self.encode_fragments(index)
         fragment = (onehot[:, None, :] @ torch.stack(self.fragments_gcn)[None, :, :]).squeeze(1)
         merger = self.fragment_ranker.merger(condition, fragment)
-        return ActionBatch(batch_size, index, onehot, fragment, logits), merger
+        return ActionBatch(batch_size, index, onehot, fragment, logits, onehot), merger
 
     def select_fragment_SFPS(self, condition, *mask_args, **mask_kwargs):
         batch_size, num_frags = condition.size(0), len(self.fragments)
@@ -130,7 +131,7 @@ class Actor(nn.Module):
         self.encode_fragments(index)
         fragment = (onehot[:, None, :] @ torch.stack(self.fragments_gcn)[None, :, :]).squeeze(1)
         merger = (onehot[:, None, :] @ mergers).squeeze(1)
-        return ActionBatch(batch_size, index, onehot, fragment, logits), merger
+        return ActionBatch(batch_size, index, onehot, fragment, logits, onehot), merger
 
     def select_fragment_attachment(self, condition, fragment_index, *mask_args, **mask_kwargs):
         batch_size = condition.size(0)
@@ -142,7 +143,8 @@ class Actor(nn.Module):
             mask = ~ self.acceptable_sites(*mask_args, **mask_kwargs)
             logits.masked_fill_(mask, float("-inf"))
         index, onehot, logits, (attachment, ) = self.sample_and_pad(self.actions_dim[2], sections.tolist(), logits, attachments)
-        return ActionBatch(batch_size, index, onehot, attachment, logits)
+        critic_input = logits.view(batch_size, -1).softmax(dim=1)
+        return ActionBatch(batch_size, index, onehot, attachment, logits, critic_input)
 
     def sample_and_pad(self, size, sections, logits, *options):
         batch_size = len(sections)

diff freedpp/train/nn/critic.py
@@ -57,15 +57,8 @@ class Critic(nn.Module):
         for i, attachments in zip(index, get_attachments(fragments).split(sections)):
             self.fragments_attachments[i] = attachments
 
-    def forward(self, state, action, from_index=False):
+    def forward(self, state, action):
         state = self.encoder(state)
-        if from_index:
-            action = lzip(*action)
-            ac1, ac2, ac3 = lmap(list, action)
-            molecule_attachment = self.get_molecule_attachment(state, ac1)
-            fragment = self.get_fragment(ac2)
-            fragment_attachment = self.get_fragment_attachment(ac2, ac3)
-            action = torch.cat([molecule_attachment, fragment, fragment_attachment], dim=1)
         input = torch.cat([state.readout, action], dim=1)
         values = lmap(methodcaller('forward', input), self.nets)
         return values, reduce(torch.minimum, values, torch.tensor(float("+inf")).to(input.device))

diff freedpp/train/replay_buffer.py
@@ -74,7 +74,7 @@ class ReplayBuffer:
 
         state = construct_batch([self.state[idx] for idx in idxs], device=device)
         next_state = construct_batch([self.next_state[idx] for idx in idxs], device=device)
-        action = [self.action[idx] for idx in idxs]
+        action = torch.stack([self.action[idx] for idx in idxs], dim=0).to(device)
 
         reward = torch.FloatTensor([self.reward[idx] for idx in idxs])[:, None].to(device)
         terminated = torch.FloatTensor([self.terminated[idx] for idx in idxs])[:, None].to(device)

diff freedpp/train/sac.py
@@ -81,12 +81,12 @@ class SAC:
         action, state = data['action'], data['state']
         next_state = data['next_state']
         reward, done = data['reward'], data['done']
-        q_values, _ = self.critic(state, action, from_index=True)
+        q_values, _ = self.critic(state, action)
 
         with torch.no_grad():
             action = self.actor(next_state)
             entropy = action.entropy()
-            _, q_target = self.critic_target(next_state, action.index, from_index=True)
+            _, q_target = self.critic_target(next_state, action.critic_input)
             alpha = self.log_alpha.exp().item()
             target = reward + self.gamma * (1 - done) * (q_target + alpha * entropy)
 
@@ -100,7 +100,7 @@ class SAC:
         weight = data['weight'] if data.get('weight') is not None else 1
         state = data['state']
         action = self.actor(state)
-        _, q_value = self.critic(state, action.embedding)
+        _, q_value = self.critic(state, action.critic_input)
         alpha = self.log_alpha.exp().item()
         entropy = action.entropy()
         loss_policy = - q_value
@@ -193,10 +193,10 @@ class SAC:
         state, action = data['state'], data['action']
         done, reward = data['done'], data['reward']
         if not batched:
-            action = [data['action']]
+            action = data['action'][None, :].to(device)
             state = construct_batch([state], device=device)
         predicted = self.prioritizer(state)
-        _, q_value = self.critic(state, action, from_index=True)
+        _, q_value = self.critic(state, action)
         priority = (1 - done) * (q_value - predicted).abs() + done * (reward - predicted).abs()
         return priority.squeeze(dim=1).tolist()
 
@@ -242,7 +242,7 @@ class SAC:
             done = terminated or truncated
             experience = {
                 'state': state, 'next_state': next_state, 'reward': reward, 'terminated': terminated,
-                'truncated': truncated, 'done': done, 'action': action.index[0]
+                'truncated': truncated, 'done': done, 'action': action.critic_input[0].detach().cpu()
             }
             if self.prioritizer:
                 priority = self.compute_priority(experience, batched=False)
