diff freed/args.py
@@ -107,6 +107,9 @@ def parse_args():
     parser.add_argument('--beta_start', type=float, default=0.4)
     parser.add_argument('--beta_frames', type=float, default=100000)
 
+    parser.add_argument('--num_epochs_explore', type=int, default=-1, help='number of initial epochs that use random actions instead of the actor')
+    parser.add_argument('--num_epochs_freeze', type=int, default=-1, help='number of initial epochs that collect experience without gradient updates')
+
     return parser.parse_args()
 
 
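Both new options default to -1, which leaves existing runs untouched: every epoch index
satisfies epoch >= -1, so neither the exploration phase nor the update freeze ever activates
(see the scheduling change in freed/train/sac.py below). A hypothetical invocation, assuming
the repository exposes a training entry point (the script name here is an assumption, not
part of this patch):

    # python train.py --num_epochs_explore 5 --num_epochs_freeze 10
    #   epochs 0-4: actions come from Environment.sample() instead of the actor
    #   epochs 0-9: experience is collected but no gradient updates are made
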
diff freed/env/environment.py
@@ -1,7 +1,7 @@
 from rdkit import Chem
 from operator import methodcaller, itemgetter
 from functools import partial
-
+import random
 import numpy as np
 
 from freed.env.state import State
@@ -84,3 +84,9 @@ class Environment(object):
         frag_attachment = frag.GetAtomWithIdx(frag_attachments[a3])
         new_mol = connect_mols(mol, frag, mol_attachment, frag_attachment)
         self.state = State(Chem.MolToSmiles(new_mol), self.num_steps + 1, **self.state_args)
+
+    def sample(self):
+        a1 = random.randint(0, len(self.state.attachment_ids) - 1)  # attachment point on the current mol
+        a2 = random.randint(0, len(self.fragments) - 1)  # fragment to attach
+        a3 = random.randint(0, len(self.fragments[a2].attachment_ids) - 1)  # attachment point on that fragment
+        return (a1, a2, a3)

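Environment.sample returns a uniformly random action in the same (a1, a2, a3) index format
the actor produces, so it can be passed straight to step. A minimal usage sketch, assuming
an already-constructed env:

    state = env.reset()
    action = env.sample()  # (mol attachment idx, fragment idx, fragment attachment idx)
    next_state, reward, terminated, truncated, info = env.step(action)

The off-policy branch added to SAC.assemble_molecule below relies on exactly this pairing.
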
diff freed/train/sac.py
@@ -33,6 +33,7 @@ class SAC:
                 update_num=256, save_freq=5000, train_alpha=True, max_norm=5.,
                 device='cpu', target_entropy=1.0, 
                 model_dir='.', mols_dir='.', beta_start=0.4, beta_frames=100000,
+                num_epochs_explore=-1, num_epochs_freeze=-1,
                 **kwargs
     ):
         super().__init__()
@@ -74,6 +75,8 @@ class SAC:
         self.beta_start = beta_start
         self.beta_frames = beta_frames
 
+        self.num_epochs_explore = num_epochs_explore
+        self.num_epochs_freeze = num_epochs_freeze
         set_requires_grad(self.critic_target.parameters(), False)
 
     def critic_loss(self, data):
@@ -232,17 +235,21 @@ class SAC:
         
         return smiles
     
-    def assemble_molecule(self):
+    def assemble_molecule(self, on_policy=True):
         state = self.env.reset()
         done = False
         cnt = 0
         while not done:
-            action = self.actor(construct_batch([state], device=self.device))
-            next_state, reward, terminated, truncated, info = self.env.step(action.index[0])
+            if on_policy:
+                action = self.actor(construct_batch([state], device=self.device))
+                action = action.index[0]
+            else:
+                action = self.env.sample()  # uniformly random action during the exploration phase
+            next_state, reward, terminated, truncated, info = self.env.step(action)
             done = terminated or truncated
             experience = {
                 'state': state, 'next_state': next_state, 'reward': reward, 'terminated': terminated,
-                'truncated': truncated, 'done': done, 'action': action.index[0]
+                'truncated': truncated, 'done': done, 'action': action
             }
             if self.prioritizer:
                 priority = self.compute_priority(experience, batched=False)
@@ -255,10 +262,10 @@ class SAC:
 
     @log_time
     @torch.no_grad()
-    def collect_experience(self):
+    def collect_experience(self, on_policy=True):
         smiles, steps = list(), 0
         while steps < self.steps_per_epoch:
-            smi, n = self.assemble_molecule()
+            smi, n = self.assemble_molecule(on_policy=on_policy)
             smiles.append(remove_attachments(smi))
             steps += n
 
@@ -276,9 +283,12 @@ class SAC:
         buf.update_buffer(buf.reward, ids, reward)
 
     @log_time
-    def train_epoch(self):
-        rewards_info = self.collect_experience()
-        update_info = self.update()
+    def train_epoch(self, update=True, on_policy=True):
+        rewards_info = self.collect_experience(on_policy=on_policy)
+        if update:
+            update_info = self.update()
+        else:
+            update_info = dict()
         return rewards_info, update_info
     
     def save_model(self):
@@ -307,7 +317,9 @@ class SAC:
         path = os.path.join(self.mols_dir, f'train_{suffix}.csv')
         for epoch in range(self.epoch, self.epochs):
             self.epoch = epoch
-            rewards_info, update_info = self.train_epoch()
+            on_policy = epoch >= self.num_epochs_explore  # act randomly during warm-up epochs
+            update = epoch >= self.num_epochs_freeze  # skip gradient updates while frozen
+            rewards_info, update_info = self.train_epoch(on_policy=on_policy, update=update)
             log_info(path, rewards_info, epoch, additional_info=update_info, writer=self.writer)
             if epoch % self.save_freq == 0:
                 self.save_model()

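Taken together, the two flags implement a simple warm-up schedule in the training loop.
A standalone sketch of the resulting phases, using illustrative values rather than the
-1 defaults:

    # Mirrors the gating logic added above; 5 and 10 are illustrative, not defaults.
    num_epochs_explore, num_epochs_freeze = 5, 10

    for epoch in range(15):
        on_policy = epoch >= num_epochs_explore  # False for epochs 0-4: random actions
        update = epoch >= num_epochs_freeze      # False for epochs 0-9: no gradient updates

With the -1 defaults both conditions hold from epoch 0, so behavior is identical to the
code before this patch.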