diff sac_motif_freed_pe.py
@@ -474,7 +474,26 @@ class sac:
             q1_pi_targ = self.ac_targ.q1(o2_g_emb, ac2_first, ac2_second, ac2_third)
             q2_pi_targ = self.ac_targ.q2(o2_g_emb, ac2_first, ac2_second, ac2_third)
             q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ).squeeze()
-            backup = r + self.gamma * (1 - d) * q_pi_targ
+
+            ac_prob_sp = torch.split(a2_prob, self.action_dims, dim=1)
+            log_ac_prob_sp = torch.split(log_a2_prob, self.action_dims, dim=1)
+            
+            alpha = min(self.log_alpha.exp().item(), self.alpha_max)
+            alpha = max(alpha, self.alpha_min)
+            
+            ac_prob_comb = torch.einsum('by, bz->byz', ac_prob_sp[1], ac_prob_sp[2]).reshape(self.batch_size, -1) # (bs , 73 x 40)
+            ac_prob_comb = torch.einsum('bx, bz->bxz', ac_prob_sp[0], ac_prob_comb).reshape(self.batch_size, -1) # (bs , 40 x 73 x 40)
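+            # NOTE: flattened in row-major order, last sub-action varying fastest: (a1, b1, c1), (a1, b1, c2), ...; log_ac_prob_comb below must follow the same order.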
+            
+            log_ac_prob_comb = log_ac_prob_sp[0].reshape(self.batch_size, self.action_dims[0], 1, 1).repeat(
+                                        1, 1, self.action_dims[1], self.action_dims[2]).reshape(self.batch_size, -1)\
+                                + log_ac_prob_sp[1].reshape(self.batch_size, 1, self.action_dims[1], 1).repeat(
+                                        1, self.action_dims[0], 1, self.action_dims[2]).reshape(self.batch_size, -1)\
+                                + log_ac_prob_sp[2].reshape(self.batch_size, 1, 1, self.action_dims[2]).repeat(
+                                        1, self.action_dims[0], self.action_dims[1], 1).reshape(self.batch_size, -1)
+            entropy = -(ac_prob_comb * log_ac_prob_comb).sum(dim=1)
+
+            backup = r + self.gamma * (1 - d) * (q_pi_targ + alpha * entropy)
 
         # MSE loss against Bellman backup
 
