diff freedpp/train/sac.py
@@ -184,8 +184,13 @@ class SAC:
     @torch.no_grad()
     def polyak_averaging(self):
+        """Soft-update the target critic towards the online critic.
+
+        Each target parameter is overwritten in place with
+        ``polyak * p_targ + (1 - polyak) * p``; the online critic is not modified.
+        """
         for p, p_targ in zip(self.critic.parameters(), self.critic_target.parameters()):
-            p_targ.data.mul_(self.polyak)
-            p_targ.data.add_((1 - self.polyak) * p.data)
+            # Fused in-place update: same result as the old two-step form, no temporary tensor.
+            p_targ.data.mul_(self.polyak).add_(p.data, alpha=1 - self.polyak)
 
     @torch.no_grad()
     def compute_priority(self, data, batched=True):
