diff freed/train/replay_buffer.py
@@ -96,6 +96,41 @@ class ReplayBuffer:
 
         return batch
 
+
+    def get_batch(self, idx, device='cpu', batch_size=32):
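+        """Return the idx-th consecutive batch of stored transitions.
+
+        Deterministic counterpart to sample_batch: batch idx covers buffer
+        slots [idx * batch_size, (idx + 1) * batch_size), so repeated runs
+        see identical data. Assumes the buffer already holds at least
+        (idx + 1) * batch_size transitions.
+        """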
+        idxs = np.arange(idx * batch_size, (idx + 1) * batch_size, dtype=np.int32)
+
+        state = construct_batch([self.state[i] for i in idxs], device=device)
+        next_state = construct_batch([self.next_state[i] for i in idxs], device=device)
+        action = [self.action[i] for i in idxs]
+
+        reward = torch.FloatTensor([self.reward[i] for i in idxs])[:, None].to(device)
+        terminated = torch.FloatTensor([self.terminated[i] for i in idxs])[:, None].to(device)
+        truncated = torch.FloatTensor([self.truncated[i] for i in idxs])[:, None].to(device)
+        done = torch.FloatTensor([self.done[i] for i in idxs])[:, None].to(device)
+        priority = torch.FloatTensor([self.priority[i] for i in idxs])[:, None].to(device)
+
+        batch = {
+            'ids': idxs,
+            'state': state,
+            'next_state': next_state,
+            'reward': reward,
+            'terminated': terminated,
+            'truncated': truncated,
+            'done': done,
+            'action': action,
+            'priority': priority,
+        }
+
+        return batch
+
     def save(self, path):
         self._mkdirs(path)
         base = int(math.log(self.max_size, 10))

diff freed/train/sac.py
@@ -6,6 +6,7 @@ from itertools import chain
 import numpy as np
 import os
 import json
+import pickle
 
 import torch
 import torch.nn as nn
@@ -50,6 +51,11 @@ class SAC:
 
         self.env = env
         self.replay_buffer = replay_buffer
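+        # Override the buffer passed in with a pre-recorded one so repeated
+        # runs replay identical transitions; assumes the file was saved
+        # earlier, e.g. via pickle.dump(replay_buffer, f).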
+        with open('replay_buffer.pickle', 'rb') as f:
+            self.replay_buffer = pickle.load(f)
         self.device = device
 
         self.gamma = gamma
@@ -167,6 +173,8 @@
         actor_items['alpha_loss'].backward()
         self.alpha_optimizer.step()
 
+    @log_time
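+    # Timed separately from the outer update() loop to expose per-step cost.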
     def _update(self, data):
         prioritizer_items = dict()
         if self.prioritizer:
@@ -209,13 +217,19 @@
     @log_time
     def update(self):
         log_items = defaultdict(list)
-        for _ in range(self.update_num):
-            batch = self.replay_buffer.sample_batch(device=self.device, batch_size=self.batch_size)
+        for i in range(self.update_num):
+            # batch = self.replay_buffer.sample_batch(device=self.device, batch_size=self.batch_size)
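+            # Fetch batch i deterministically instead of sampling at random.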
+            batch = self.replay_buffer.get_batch(i, device=self.device, batch_size=self.batch_size)
             if self.prioritizer:
                 batch['weight'] = self.compute_batch_weight(batch)
+            else:
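+                # get_batch always attaches stored priorities; drop them when
+                # prioritized replay is disabled.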
+                del batch['priority']
             items = self._update(data=batch)
-            for name, value in items.items():
-                log_items[name].append(value.item() if torch.is_tensor(value) else value)
+            # for name, value in items.items():
+            #     log_items[name].append(value.item() if torch.is_tensor(value) else value)
         return log_items
 
     @log_time
@@ -277,7 +291,8 @@
 
     @log_time
     def train_epoch(self):
-        rewards_info = self.collect_experience()
+        # rewards_info = self.collect_experience()
+        rewards_info = dict()  # placeholder: collection is skipped, but callers still unpack it
         update_info = self.update()
         return rewards_info, update_info
     
@@ -308,9 +323,10 @@
         for epoch in range(self.epoch, self.epochs):
             self.epoch = epoch
             rewards_info, update_info = self.train_epoch()
-            log_info(path, rewards_info, epoch, additional_info=update_info, writer=self.writer)
-            if epoch % self.save_freq == 0:
-                self.save_model()
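+            # Per-epoch logging and checkpointing are disabled; only the final model is saved.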
+            # log_info(path, rewards_info, epoch, additional_info=update_info, writer=self.writer)
+            # if epoch % self.save_freq == 0:
+            #     self.save_model()
         
         self.save_model()
         self.epoch += 1
\ No newline at end of file

diff freed/train/utils.py
@@ -1,4 +1,5 @@
 import time
+import os
 from copy import deepcopy
 from operator import attrgetter
 from functools import wraps
@@ -19,6 +20,9 @@ def log_time(method):
         res = method(sac, *args, **kwargs)
         t1 = time.time()
         sac.writer.add_scalar(f'time_{method.__name__}', t1 - t0, sac.epoch)
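+        # Also append the raw timing to a per-method text file for offline analysis.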
+        with open(os.path.join(sac.model_dir, f'time_{method.__name__}.txt'), 'a') as f:
+            f.write(f'{t1 - t0}\n')
         return res
 
     return wrapper
