import os
import pickle
from config import conf
from runner import Runner
import torch
from utils import check_validity

# Sample molecules from a trained checkpoint, print the validity report from
# check_validity, and pickle the raw generated molecule dicts to disk.

runner = Runner(conf)

# Sampling hyperparameters forwarded to Runner.generate below.
# NOTE(review): the temperature list is passed as [node, dist, angle, torsion];
# the per-component meaning is inferred from the variable names -- confirm
# against Runner.generate.
node_temp = 0.5
dist_temp = 0.3
angle_temp = 0.4
torsion_temp = 1.0
min_atoms = 2
max_atoms = 35
focus_th = 0.5  # presumably a threshold on focus selection -- TODO confirm in Runner
num_gen = 10000

epoch = 99
checkpoint_path = 'new_dfs/model_{}.pth'.format(epoch)
output_dir = 'rand_gen'
output_path = '{}/{}_mols.mol_dict'.format(output_dir, epoch)

# Make sure the output directory exists *before* the long generation run;
# otherwise a missing folder only surfaces at the final open() and every
# generated sample is lost.
os.makedirs(output_dir, exist_ok=True)

# load_state_dict copies tensors into the model's existing parameters, so the
# checkpoint does not need to be restored onto the device it was saved from.
# Loading to CPU lets a GPU-saved checkpoint load on CPU-only machines too.
runner.model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

mol_dicts = runner.generate(
    num_gen,
    temperature=[node_temp, dist_temp, angle_temp, torsion_temp],
    max_atoms=max_atoms,
    min_atoms=min_atoms,
    focus_th=focus_th,
    add_final=True,
)

# check_validity returns a 3-tuple; only the first element is reported here.
results, _, _ = check_validity(mol_dicts)
print(results)

# NOTE(review): pickle is only safe to re-load because this file is produced
# locally; never unpickle such files from untrusted sources.
with open(output_path, 'wb') as f:
    pickle.dump(mol_dicts, f)