diff core_motif.py
@@ -134,8 +134,8 @@ class GCNActorCritic(nn.Module):
         ac_space = env.action_space
         self.env = env
         self.pi = SFSPolicy(ob_space, ac_space, env, args)
-        self.q1 = GCNQFunction(ac_space, args)
-        self.q2 = GCNQFunction(ac_space, args, override_seed=True)
+        self.q1 = GCNQFunction(ac_space, args, env)
+        self.q2 = GCNQFunction(ac_space, args, env, override_seed=True)
 
         # PER based model
         if args.active_learning == 'freed_bu':
@@ -164,7 +164,7 @@ class GCNActorCritic(nn.Module):
         return a
 
 class GCNQFunction(nn.Module):
-    def __init__(self, ac_space, args, override_seed=False):
+    def __init__(self, ac_space, args, env, override_seed=False):
         super().__init__()
         if override_seed:
             seed = args.seed + 1
@@ -176,17 +176,82 @@ class GCNQFunction(nn.Module):
         self.emb_size = args.emb_size
         self.max_action2 = len(ATOM_VOCAB)
         self.max_action_stop = 2
+        self.env = env
 
-        self.d = 2 * args.emb_size + len(FRAG_VOCAB) + 80 
+        self.d = 8 * args.emb_size 
         self.out_dim = 1
         
         self.qpred_layer = nn.Sequential(
                             nn.Linear(self.d, int(self.d//2), bias=False),
                             nn.ReLU(inplace=False),
                             nn.Linear(int(self.d//2), self.out_dim, bias=True))
-    
-    def forward(self, graph_emb, ac_first_prob, ac_second_hot, ac_third_prob):
-        emb_state_action = torch.cat([graph_emb, ac_first_prob, ac_second_hot, ac_third_prob], dim=-1).contiguous()
+        self.cand = self.create_candidate_motifs()
+        self.cand_g = dgl.batch([x['g'] for x in self.cand]).to(self.device)
+        self.cand_ob_len = self.cand_g.batch_num_nodes().tolist()
+        self.max_action = 40
+        self.ac3_att_len = torch.LongTensor([len(x['att']) 
+                                for x in self.cand]).to(self.device)
+        self.ac3_att_mask = torch.cat([torch.LongTensor([i]*len(x['att'])) 
+                                for i, x in enumerate(self.cand)], dim=0).to(self.device)
+
+    def create_candidate_motifs(self):
+        motif_gs = [self.env.get_observation_mol(mol) for mol in FRAG_VOCAB_MOL]  # one observation graph dict per vocab fragment
+        return motif_gs
+
+    def reconstruct_action(self, mol, cands, ac1_hot, ac2_hot, ac3_hot):
+        g, node_emb, graph_emb = mol
+        g.ndata['node_emb'] = node_emb
+        cand_g, cand_node_emb, cand_graph_emb = cands
+        # Rebuild the continuous action embeddings (ac1: attach point on mol, ac2: motif, ac3: attach point on motif) from stored one-hots.
+        ob_len = g.batch_num_nodes().tolist()
+        att_mask = g.ndata['att_mask']  # per-node attachment-point mask
+
+        att_mask_split = torch.split(att_mask, ob_len, dim=0)
+        att_len = [int(torch.sum(x, dim=0)) for x in att_mask_split]  # torch.split needs python ints, not 0-dim tensors
+
+        # att_len[i] = number of attachment points in the i-th graph of the batch
+        cand_att_mask = cand_g.ndata['att_mask']
+        # (cand_att_mask_split / cand_att_len were computed here but never used; dead code removed)
+
+
+
+        att_emb = torch.masked_select(node_emb, att_mask.unsqueeze(-1))
+        att_emb = att_emb.view(-1, 2*self.emb_size)
+
+        embs = []
+        for node_emb_i in torch.split(att_emb, att_len, dim=0):
+            emb = torch.cat([node_emb_i, node_emb_i.new_zeros(self.max_action - node_emb_i.size(0), node_emb_i.size(1))], dim=0)
+            embs.append(emb)
+
+        embs = torch.stack(embs)
+        ac1 = torch.einsum("bme,bm->be", embs, ac1_hot)  # select the chosen attach-point embedding per batch element
+
+        ac2 = torch.einsum("ne,bn->be", cand_graph_emb, ac2_hot)  # cand_graph_emb == cands[2]
+
+        cand_att_emb = torch.masked_select(cand_node_emb, cand_att_mask.unsqueeze(-1))
+        cand_att_emb = cand_att_emb.view(-1, 2*self.emb_size)
+
+        # ac_second: integer motif index recovered from the one-hot
+        ac_second = torch.argmax(ac2_hot, dim=-1)
+        ac3_att_mask = self.ac3_att_mask.repeat(g.batch_size, 1) # bs x (num cands * num att size)
+        ac3_att_mask = torch.where(ac3_att_mask==ac_second.view(-1,1),
+                            1, 0).view(g.batch_size, -1) # (num_cands * num_nodes)
+        ac3_att_mask = ac3_att_mask.bool()
+
+        ac3_cand_emb = torch.masked_select(cand_att_emb.view(1, -1, 2*self.emb_size),
+                                ac3_att_mask.view(g.batch_size, -1, 1)).view(-1, 2*self.emb_size)
+        # Tensor.split also needs python ints, so unwrap the 0-dim LongTensors
+        att_len_cand = [int(self.ac3_att_len[i]) for i in ac_second]
+        embs = ac3_cand_emb.split(att_len_cand)
+        embs = [torch.cat([emb, emb.new_zeros(self.max_action - emb.size(0), emb.size(1))], dim=0) for emb in embs]
+        embs = torch.stack(embs)
+        ac3 = torch.einsum("bme,bm->be", embs, ac3_hot)
+
+        return ac1, ac2, ac3
+
+    def forward(self, mol, cands, ac_first_prob, ac_second_hot, ac_third_prob):
+        ac1, ac2, ac3 = self.reconstruct_action(mol, cands, ac_first_prob, ac_second_hot, ac_third_prob)
+        # Q(s, a): concatenate the graph embedding (mol[2]) with the three action embeddings.
+        emb_state_action = torch.cat([mol[2], ac1, ac2, ac3], dim=-1).contiguous()
         qpred = self.qpred_layer(emb_state_action)
         return qpred
 
@@ -284,14 +349,14 @@ class SFSPolicy(nn.Module):
             ret = y_soft
         return ret
 
-    def forward(self, graph_emb, node_emb, g, cands, deterministic=False):
+    def forward(self, mol, cands, deterministic=False):
         """
         graph_emb : bs x hidden_dim
         node_emb : (bs x num_nodes) x hidden_dim)
         g: batched graph
         att: indexs of attachment points, list of list
         """
-        
+        g, node_emb, graph_emb = mol
         g.ndata['node_emb'] = node_emb
         cand_g, cand_node_emb, cand_graph_emb = cands 
 
@@ -338,11 +403,13 @@ class SFSPolicy(nn.Module):
         if g.batch_size != 1:  
             first_stack = []
             first_ac_stack = []
+            ac_first_hot_stack = []
             for i, node_emb_i in enumerate(torch.split(att_emb, att_len, dim=0)):
                 ac_first_hot_i = self.gumbel_softmax(ac_first_prob[i], tau=self.tau, hard=True, dim=0).transpose(0,1)
                 ac_first_i = torch.argmax(ac_first_hot_i, dim=-1)
                 first_stack.append(torch.matmul(ac_first_hot_i, node_emb_i))
                 first_ac_stack.append(ac_first_i)
+                ac_first_hot_stack.append(ac_first_hot_i)
 
             emb_first = torch.stack(first_stack, dim=0).squeeze(1)
             ac_first = torch.stack(first_ac_stack, dim=0).squeeze(1)
@@ -359,6 +426,14 @@ class SFSPolicy(nn.Module):
                                         max(self.max_action - log_ac_first_prob_i.size(0),0),1)]
                                             , 0).contiguous().view(1,self.max_action)
                                     for i, log_ac_first_prob_i in enumerate(log_ac_first_prob)], dim=0).contiguous()
+            # TODO: check shapes
+            print('ac_first_hot_i before', [ac_first_hot_i.shape for ac_first_hot_i in ac_first_hot_stack])
+            ac_first_hot = torch.cat([
+                                    torch.cat([ac_first_hot_i.T, ac_first_hot_i.new_zeros(
+                                        max(self.max_action - ac_first_hot_i.size(1),0),1)]
+                                            , 0).contiguous().view(1,self.max_action)
+                                    for i, ac_first_hot_i in enumerate(ac_first_hot_stack)], dim=0).contiguous()
+            print('ac_first_hot_i after', ac_first_hot.shape)
             
         else:            
             ac_first_hot = self.gumbel_softmax(ac_first_prob, tau=self.tau, hard=True, dim=0).transpose(0,1)
@@ -370,6 +445,12 @@ class SFSPolicy(nn.Module):
             log_ac_first_prob = torch.cat([log_ac_first_prob, log_ac_first_prob.new_zeros(
                             max(self.max_action - log_ac_first_prob.size(0),0),1)]
                                 , 0).contiguous().view(1,self.max_action)
+            # TODO: check shapes
+            print('ac_first_hot before', ac_first_hot.shape)
+            ac_first_hot = torch.cat([ac_first_hot.T, ac_first_hot.new_zeros(
+                            max(self.max_action - ac_first_hot.size(1),0),1)]
+                                , 0).contiguous().view(1,self.max_action)
+            print('ac_first_hot after', ac_first_hot.shape)
 
         # =============================== 
         # step 2 : which motif to add - Using Descriptors
@@ -430,13 +511,15 @@ class SFSPolicy(nn.Module):
         if g.batch_size != 1:
             third_stack = []
             third_ac_stack = []
+            ac_third_hot_stack = []
             for i, node_emb_i in enumerate(torch.split(emb_cat_ac3, ac3_att_len, dim=0)):
                 ac_third_hot_i = self.gumbel_softmax(ac_third_prob[i], tau=self.tau, hard=True, dim=-1)
                 ac_third_i = torch.argmax(ac_third_hot_i, dim=-1)
                 third_stack.append(torch.matmul(ac_third_hot_i, node_emb_i))
                 third_ac_stack.append(ac_third_i)
+                ac_third_hot_stack.append(ac_third_hot_i)
 
-                del ac_third_hot_i
+                # del ac_third_hot_i
             emb_third = torch.stack(third_stack, dim=0).squeeze(1)
             ac_third = torch.stack(third_ac_stack, dim=0)
             ac_third_prob = torch.cat([
@@ -450,6 +533,14 @@ class SFSPolicy(nn.Module):
                                         self.max_action - log_ac_third_prob_i.size(0))]
                                             , 0).contiguous().view(1,self.max_action)
                                     for i, log_ac_third_prob_i in enumerate(log_ac_third_prob)], dim=0).contiguous()
+            # TODO: check shapes
+            print('ac_third_hot_i before', [ac_third_hot_i.shape for ac_third_hot_i in ac_third_hot_stack])
+            ac_third_hot = torch.cat([
+                                    torch.cat([ac_third_hot_i, ac_third_hot_i.new_zeros(
+                                        self.max_action - ac_third_hot_i.size(0))]
+                                            , 0).contiguous().view(1,self.max_action)
+                                    for i, ac_third_hot_i in enumerate(ac_third_hot_stack)], dim=0).contiguous()
+            print('ac_third_hot_i after', ac_third_hot.shape)
 
         else:
             ac_third_hot = self.gumbel_softmax(ac_third_prob, tau=self.tau, hard=True, dim=-1)
@@ -462,6 +553,12 @@ class SFSPolicy(nn.Module):
             log_ac_third_prob = torch.cat([log_ac_third_prob, log_ac_third_prob.new_zeros(
                                         1, self.max_action - log_ac_third_prob.size(1))]
                                 , -1).contiguous()
+            # TODO: check shapes
+            print('ac_third_hot before', ac_third_hot.shape)
+            ac_third_hot = torch.cat([ac_third_hot, ac_third_hot.new_zeros(
+                                        1, self.max_action - ac_third_hot.size(1))]
+                                , -1).contiguous()
+            print('ac_third_hot after', ac_third_hot.shape)
 
         # ==== concat everything ====
 
@@ -469,107 +566,16 @@ class SFSPolicy(nn.Module):
         log_ac_prob = torch.cat([log_ac_first_prob, 
                             log_ac_second_prob, log_ac_third_prob], dim=1).contiguous()
         ac = torch.stack([ac_first, ac_second, ac_third], dim=1)
-
-        return ac, (ac_prob, log_ac_prob), (ac_first_prob, ac_second_hot, ac_third_prob)
+        print('forward', ac_first_hot.shape, ac_second_hot.shape, ac_third_hot.shape)
+        return ac, (ac_prob, log_ac_prob), (ac_first_hot, ac_second_hot, ac_third_hot)
     
-    def sample(self, ac, graph_emb, node_emb, g, cands):
-        g.ndata['node_emb'] = node_emb
-        cand_g, cand_node_emb, cand_graph_emb = cands 
-
-        # Only acquire node embeddings with attatchment points
-        ob_len = g.batch_num_nodes().tolist()
-        att_mask = g.ndata['att_mask'] # used to select att embs from node embs
-        att_len = torch.sum(att_mask, dim=-1) # used to torch.split for att embs
-
-        cand_att_mask = cand_g.ndata['att_mask']
-        cand_att_mask_split = torch.split(cand_att_mask, self.cand_ob_len, dim=0)
-        cand_att_len = [torch.sum(x, dim=0) for x in cand_att_mask_split]
-
-        # =============================== 
-        # step 1 : where to add
-        # =============================== 
-        # select only nodes with attachment points
-        att_emb = torch.masked_select(node_emb, att_mask.unsqueeze(-1))
-        att_emb = att_emb.view(-1, 2*self.emb_size)
-        graph_expand = graph_emb.repeat(att_len, 1)
-        
-        att_emb = self.action1_layers[0](att_emb, graph_expand) + self.action1_layers[1](att_emb) \
-                    + self.action1_layers[2](graph_expand)
-        logits_first = self.action1_layers[3](att_emb).transpose(1,0)
-            
-        ac_first_prob = torch.softmax(logits_first, dim=-1) + 1e-8
-        
-        log_ac_first_prob = ac_first_prob.log()
-        ac_first_prob = torch.cat([ac_first_prob, ac_first_prob.new_zeros(1,
-                        max(self.max_action - ac_first_prob.size(1),0))]
-                            , 1).contiguous()
-        
-        log_ac_first_prob = torch.cat([log_ac_first_prob, log_ac_first_prob.new_zeros(1,
-                        max(self.max_action - log_ac_first_prob.size(1),0))]
-                            , 1).contiguous()
-        emb_first = att_emb[ac[0]].unsqueeze(0)
-        
-        # =============================== 
-        # step 2 : which motif to add     
-        # ===============================   
-        emb_first_expand = emb_first.repeat(1, self.motif_type_num, 1)
-        cand_expand = self.cand_desc.unsqueeze(0).repeat(g.batch_size, 1, 1)     
-        
-        emb_cat = self.action2_layers[0](cand_expand, emb_first_expand) + \
-                    self.action2_layers[1](cand_expand) + self.action2_layers[2](emb_first_expand)
-
-        logit_second = self.action2_layers[3](emb_cat).squeeze(-1)
-        ac_second_prob = F.softmax(logit_second, dim=-1) + 1e-8
-        log_ac_second_prob = ac_second_prob.log()
-        
-        ac_second_hot = self.gumbel_softmax(ac_second_prob, tau=self.tau, hard=True, g_ratio=1e-3)                                    
-        emb_second = torch.matmul(ac_second_hot, cand_graph_emb)
-        ac_second = torch.argmax(ac_second_hot, dim=-1)
-
-        # ===============================  
-        # step 3 : where to add on motif
-        # ===============================
-        # Select att points from candidates
-        
-        cand_att_emb = torch.masked_select(cand_node_emb, cand_att_mask.unsqueeze(-1))
-        cand_att_emb = cand_att_emb.view(-1, 2*self.emb_size)
-
-        ac3_att_mask = self.ac3_att_mask.repeat(g.batch_size, 1) # bs x (num cands * num att size)
-        # torch where currently does not support cpu ops    
-        
-        ac3_att_mask = torch.where(ac3_att_mask==ac[1], 
-                            1, 0).view(g.batch_size, -1) # (num_cands * num_nodes)
-        ac3_att_mask = ac3_att_mask.bool()
-
-        ac3_cand_emb = torch.masked_select(cand_att_emb.view(1, -1, 2*self.emb_size), 
-                                ac3_att_mask.view(g.batch_size, -1, 1)).view(-1, 2*self.emb_size)
-        
-        ac3_att_len = self.ac3_att_len[ac[1]]
-        emb_second_expand = emb_second.repeat(ac3_att_len,1)
-        emb_cat_ac3 = self.action3_layers[0](emb_second_expand, ac3_cand_emb) + self.action3_layers[1](emb_second_expand) \
-                  + self.action3_layers[2](ac3_cand_emb)
-
-        logits_third = self.action3_layers[3](emb_cat_ac3)
-        logits_third = logits_third.transpose(1,0)
-        ac_third_prob = torch.softmax(logits_third, dim=-1) + 1e-8
-        log_ac_third_prob = ac_third_prob.log()
-
-        # gumbel softmax sampling and zero-padding
-        emb_third = emb_cat_ac3[ac[2]].unsqueeze(0)
-        ac_third_prob = torch.cat([ac_third_prob, ac_third_prob.new_zeros(
-                                        1, self.max_action - ac_third_prob.size(1))] 
-                                , -1).contiguous()
-        log_ac_third_prob = torch.cat([log_ac_third_prob, log_ac_third_prob.new_zeros(
-                                        1, self.max_action - log_ac_third_prob.size(1))]
-                                , -1).contiguous()
-
-        # ==== concat everything ====
-        ac_prob = torch.cat([ac_first_prob, ac_second_prob, ac_third_prob], dim=1).contiguous()
-        log_ac_prob = torch.cat([log_ac_first_prob, 
-                            log_ac_second_prob, log_ac_third_prob], dim=1).contiguous()
-
-        return (ac_prob, log_ac_prob), (ac_first_prob, ac_second_hot, ac_third_prob)
-        
+    def sample(self, ac):
+        ac1, ac2, ac3 = ac
+        # F.one_hot requires int64 input; as_tensor avoids re-copying if components are already tensors.
+        ac1 = F.one_hot(torch.as_tensor(ac1).long(), num_classes=self.max_action)[None, :].to(self.device).float()
+        ac2 = F.one_hot(torch.as_tensor(ac2).long(), num_classes=self.cand_desc.size(0))[None, :].to(self.device).float()
+        ac3 = F.one_hot(torch.as_tensor(ac3).long(), num_classes=self.max_action)[None, :].to(self.device).float()
+        return ac1, ac2, ac3
 
 class GCNEmbed(nn.Module):
     def __init__(self, args):

diff sac_motif_freed_pe.py
@@ -46,12 +46,8 @@ class ReplayBuffer:
     def __init__(self, obs_dim, act_dim, size, load=False, checkpoint=None):
         self.obs_buf = [] # o
         self.obs2_buf = [] # o2
-        self.act_buf = np.zeros((size, 3), dtype=np.int32) # ac
         self.rew_buf = np.zeros(size, dtype=np.float32) # r
         self.done_buf = np.zeros(size, dtype=np.float32) # d
-        
-        self.ac_prob_buf = []
-        self.log_ac_prob_buf = []
 
         self.ac_first_buf = []
         self.ac_second_buf = []
@@ -74,16 +70,13 @@ class ReplayBuffer:
         if load:
             self.load(checkpoint)
 
-    def store(self, obs, act, rew, next_obs, done, ac_prob, log_ac_prob, \
+    def store(self, obs, rew, next_obs, done, \
                 ac_first_prob, ac_second_hot, ac_third_prob, \
                 o_embeds, sampling_score):
         if self.size == self.max_size:
             self.obs_buf.pop(0)
             self.obs2_buf.pop(0)
             
-            self.ac_prob_buf.pop(0)
-            self.log_ac_prob_buf.pop(0)
-            
             self.ac_first_buf.pop(0)
             self.ac_second_buf.pop(0)
             self.ac_third_buf.pop(0)
@@ -93,15 +86,11 @@ class ReplayBuffer:
         self.obs_buf.append(obs)
         self.obs2_buf.append(next_obs)
         
-        self.ac_prob_buf.append(ac_prob)
-        self.log_ac_prob_buf.append(log_ac_prob)
-        
         self.ac_first_buf.append(ac_first_prob)
         self.ac_second_buf.append(ac_second_hot)
         self.ac_third_buf.append(ac_third_prob)
         self.o_embeds_buf.append(o_embeds)
         
-        self.act_buf[self.ptr] = act
         self.rew_buf[self.ptr] = rew
         self.done_buf[self.ptr] = done
         self.sampling_buf[self.ptr] = sampling_score
@@ -132,16 +121,12 @@ class ReplayBuffer:
             self.sampling_buf[done_location_np] = intr_rew
             self.done_location = []
 
-        self.act_buf = np.delete(self.act_buf, zero_ptrs, axis=0)
         self.rew_buf = np.delete(self.rew_buf, zero_ptrs)
         self.done_buf = np.delete(self.done_buf, zero_ptrs)
         self.sampling_buf = np.delete(self.sampling_buf, zero_ptrs)
 
         delete_multiple_element(self.obs_buf, zero_ptrs.tolist())
         delete_multiple_element(self.obs2_buf, zero_ptrs.tolist())
-
-        delete_multiple_element(self.ac_prob_buf, zero_ptrs.tolist())
-        delete_multiple_element(self.log_ac_prob_buf, zero_ptrs.tolist())
         
         delete_multiple_element(self.ac_first_buf, zero_ptrs.tolist())
         delete_multiple_element(self.ac_second_buf, zero_ptrs.tolist())
@@ -177,26 +162,19 @@ class ReplayBuffer:
 
         obs_batch = [self.obs_buf[idx] for idx in idxs]
         obs2_batch = [self.obs2_buf[idx] for idx in idxs]
-
-        ac_prob_batch = [self.ac_prob_buf[idx] for idx in idxs]
-        log_ac_prob_batch = [self.log_ac_prob_buf[idx] for idx in idxs]
         
-        ac_first_batch = torch.stack([self.ac_first_buf[idx] for idx in idxs]).squeeze(1)
-        ac_second_batch = torch.stack([self.ac_second_buf[idx] for idx in idxs]).squeeze(1)
-        ac_third_batch = torch.stack([self.ac_third_buf[idx] for idx in idxs]).squeeze(1)
-        o_g_emb_batch = torch.stack([self.o_embeds_buf[idx][2] for idx in idxs]).squeeze(1)
+        ac_first_batch = torch.stack([self.ac_first_buf[idx] for idx in idxs]).squeeze(1).to(device)
+        ac_second_batch = torch.stack([self.ac_second_buf[idx] for idx in idxs]).squeeze(1).to(device)
+        ac_third_batch = torch.stack([self.ac_third_buf[idx] for idx in idxs]).squeeze(1).to(device)
+        o_g_emb_batch = torch.stack([self.o_embeds_buf[idx][2] for idx in idxs]).squeeze(1).to(device)
 
-        act_batch = torch.as_tensor(self.act_buf[idxs], dtype=torch.float32).unsqueeze(-1).to(device)
         rew_batch = torch.as_tensor(self.rew_buf[idxs], dtype=torch.float32).to(device)
         done_batch = torch.as_tensor(self.done_buf[idxs], dtype=torch.float32).to(device)
 
         batch = dict(obs=obs_batch,
                      obs2=obs2_batch,
-                     act=act_batch,
                      rew=rew_batch,
                      done=done_batch,
-                     ac_prob=ac_prob_batch,
-                     log_ac_prob=log_ac_prob_batch,
                      
                      ac_first=ac_first_batch,
                      ac_second=ac_second_batch,
@@ -237,12 +215,9 @@ class ReplayBuffer:
                 self._save_json(items, os.path.join(path, buf_name), key)
         np.save(os.path.join(path, 'rew.npy'), self.rew_buf)
         np.save(os.path.join(path, 'done.npy'), self.done_buf)
-        np.save(os.path.join(path, 'ac.npy'), self.act_buf)
         np.save(os.path.join(path, 'weights.npy'), self.sampling_buf)
         self._save_json({'ptr': self.ptr, 'size': self.size}, path, 'size_and_ptr')
         self._save_json(self.done_location, path, 'done_location')
-        torch.save({'ac_prob': torch.cat(self.ac_prob_buf)}, os.path.join(path, 'ac_prob.pth'))
-        torch.save({'log_ac_prob': torch.cat(self.log_ac_prob_buf)}, os.path.join(path, 'log_ac_prob.pth'))
         torch.save({'ac_first': torch.cat(self.ac_first_buf)}, os.path.join(path, 'ac_first.pth'))
         torch.save({'ac_second': torch.cat(self.ac_second_buf)}, os.path.join(path, 'ac_second.pth'))
         torch.save({'ac_third': torch.cat(self.ac_third_buf)}, os.path.join(path, 'ac_third.pth'))
@@ -268,13 +243,10 @@ class ReplayBuffer:
             setattr(self, buf_name, [dict(zip(k, v)) for v in zip(*vs)])
         self.rew_buf = np.load(os.path.join(path, 'rew.npy'))
         self.done_buf = np.load(os.path.join(path, 'done.npy'))
-        self.act_buf = np.load(os.path.join(path, 'ac.npy'))
         self.sampling_buf = np.load(os.path.join(path, 'weights.npy'))
         for k, v in self._load_json(path, 'size_and_ptr').items():
             setattr(self, k, v)
         self.done_location = self._load_json(path, 'done_location')
-        self.ac_prob_buf = self._load_pth(path, 'ac_prob', self.size)
-        self.log_ac_prob_buf = self._load_pth(path, 'log_ac_prob', self.size)
         self.ac_first_buf = self._load_pth(path, 'ac_first', self.size)
         self.ac_second_buf = self._load_pth(path, 'ac_second', self.size)
         self.ac_third_buf = self._load_pth(path, 'ac_third', self.size)
@@ -458,21 +430,21 @@ class sac:
         self.ac.q2.train()
         o = data['obs']
 
-        _, _, o_g_emb = self.ac.embed(o)
-        q1 = self.ac.q1(o_g_emb, ac_first, ac_second, ac_third).squeeze()
-        q2 = self.ac.q2(o_g_emb.detach(), ac_first, ac_second, ac_third).squeeze()
+        o_emb = self.ac.embed(o)
+        cands = self.ac.embed(self.ac.pi.cand)
+        q1 = self.ac.q1(o_emb, cands, ac_first, ac_second, ac_third).squeeze()
+        q2 = self.ac.q2(o_emb, cands, ac_first, ac_second, ac_third).squeeze()
 
         # Target actions come from *current* policy
         o2 = data['obs2']
         r, d = data['rew'], data['done']
 
         with torch.no_grad():
-            o2_g, o2_n_emb, o2_g_emb = self.ac.embed(o2)
-            cands = self.ac.embed(self.ac.pi.cand)
-            a2, (a2_prob, log_a2_prob), (ac2_first, ac2_second, ac2_third) = self.ac.pi(o2_g_emb, o2_n_emb, o2_g, cands)
+            o2_emb = self.ac.embed(o2)
+            a2, (a2_prob, log_a2_prob), (ac2_first, ac2_second, ac2_third) = self.ac.pi(o2_emb, cands)
             # Target Q-values
-            q1_pi_targ = self.ac_targ.q1(o2_g_emb, ac2_first, ac2_second, ac2_third)
-            q2_pi_targ = self.ac_targ.q2(o2_g_emb, ac2_first, ac2_second, ac2_third)
+            q1_pi_targ = self.ac_targ.q1(o2_emb, cands, ac2_first, ac2_second, ac2_third)
+            q2_pi_targ = self.ac_targ.q2(o2_emb, cands, ac2_first, ac2_second, ac2_third)
             q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ).squeeze()
             backup = r + self.gamma * (1 - d) * q_pi_targ
 
@@ -491,14 +463,13 @@ class sac:
 
         with torch.no_grad():
             o_embeds = self.ac.embed(data['obs'])   
-            o_g, o_n_emb, o_g_emb = o_embeds
             cands = self.ac.embed(self.ac.pi.cand)
 
         _, (ac_prob, log_ac_prob), (ac_first, ac_second, ac_third) = \
-            self.ac.pi(o_g_emb, o_n_emb, o_g, cands)
+            self.ac.pi(o_embeds, cands)
 
-        q1_pi = self.ac.q1(o_g_emb, ac_first, ac_second, ac_third)
-        q2_pi = self.ac.q2(o_g_emb, ac_first, ac_second, ac_third)
+        q1_pi = self.ac.q1(o_embeds, cands, ac_first, ac_second, ac_third)
+        q2_pi = self.ac.q2(o_embeds, cands, ac_first, ac_second, ac_third)
         q_pi = torch.min(q1_pi, q2_pi)
 
         ac_prob_sp = torch.split(ac_prob, self.action_dims, dim=1)
@@ -644,16 +615,15 @@ class sac:
             with torch.no_grad():
                 cands = self.ac.embed(self.ac.pi.cand)
                 o_embeds = self.ac.embed([o])
-                o_g, o_n_emb, o_g_emb = o_embeds
 
                 if t > self.start_steps:
                     ac, (ac_prob, log_ac_prob), (ac_first, ac_second, ac_third) = \
-                    self.ac.pi(o_g_emb, o_n_emb, o_g, cands)
+                    self.ac.pi(o_embeds, cands)
                     print(ac, ' pi')
                 else:
-                    ac = self.env.sample_motif()[np.newaxis]
-                    (ac_prob, log_ac_prob), (ac_first, ac_second, ac_third) = \
-                    self.ac.pi.sample(ac[0], o_g_emb, o_n_emb, o_g, cands)
+                    ac = self.env.sample_motif()
+                    ac_first, ac_second, ac_third = self.ac.pi.sample(ac)
+                    ac = [ac]
                     print(ac, 'sample')
 
             # Step the env
@@ -661,7 +631,7 @@ class sac:
 
             if d and self.active_learning is not None:
                 ob_list.append(o)
-                o_embed_list.append(o_g_emb)
+            #     o_embed_list.append(o_g_emb)
 
             r_d = info['stop']
 
@@ -672,17 +642,18 @@ class sac:
             if any(o2['att']):
                 # # Acquire sampling scores
                 with torch.no_grad():
-                    q_pred = min(self.ac.q1(o_g_emb, ac_first, ac_second, ac_third),\
-                                self.ac.q2(o_g_emb, ac_first, ac_second, ac_third))
+                    q_pred = min(self.ac.q1(o_embeds, cands, ac_first, ac_second, ac_third),\
+                                self.ac.q2(o_embeds, cands, ac_first, ac_second, ac_third))
                     intr_rew = self.compute_intr_rew([o], q_pred)
                 
+                ac_first, ac_second, ac_third = ac_first.detach().cpu(), ac_second.detach().cpu(), ac_third.detach().cpu()
                 if type(ac) == np.ndarray:
-                    self.replay_buffer.store(o, ac, r, o2, r_d, 
-                                            ac_prob, log_ac_prob, ac_first, ac_second, ac_third,
+                    self.replay_buffer.store(o, r, o2, r_d, 
+                                            ac_first, ac_second, ac_third,
                                             o_embeds, intr_rew)
                 else:    
-                    self.replay_buffer.store(o, ac.detach().cpu().numpy(), r, o2, r_d, 
-                                            ac_prob, log_ac_prob, ac_first, ac_second, ac_third,
+                    self.replay_buffer.store(o, r, o2, r_d, 
+                                            ac_first, ac_second, ac_third,
                                             o_embeds, intr_rew)
 
             # Super critical, easy to overlook step: make sure to update 
@@ -763,10 +734,11 @@ class sac:
                     
                     # update uncertainty loss 
                     with torch.no_grad():
-                        _, _, o_g_pred = self.ac.embed(batch['obs'])
+                        o_g_emb = self.ac.embed(batch['obs'])
+                        cands = self.ac.embed(self.ac.pi.cand)
                         q_pred = torch.min(torch.stack(
-                                    [self.ac.q1(o_g_pred, batch['ac_first'], batch['ac_second'], batch['ac_third']),
-                                    self.ac.q2(o_g_pred, batch['ac_first'], batch['ac_second'], batch['ac_third'])], dim=0)
+                                    [self.ac.q1(o_g_emb, cands, batch['ac_first'], batch['ac_second'], batch['ac_third']),
+                                    self.ac.q2(o_g_emb, cands, batch['ac_first'], batch['ac_second'], batch['ac_third'])], dim=0)
                                     , dim=0)[0].squeeze()
                         priorities = self.compute_intr_rew(batch['obs'], q_pred) * \
                                         (-batch['done'].float().cpu().numpy()+1) + \
@@ -815,9 +787,8 @@ class sac:
             o, d = self.env.reset(), False
             while not d:
                 o_embeds = self.ac.embed([o])
-                o_g, o_n_emb, o_g_emb = o_embeds
                 ac, (ac_prob, log_ac_prob), (ac_first, ac_second, ac_third) = \
-                    self.ac.pi(o_g_emb, o_n_emb, o_g, cands)
+                    self.ac.pi(o_embeds, cands)
                 o2, r, d, info = self.env.step(ac)
                 o = o2
 
