#experiment is running in tremendous-ocelot

import datetime
from pathlib import Path

import torch

from steering import SourceTemperingSampler
# from rewards import pairwise_reward

# ---- Sampler settings (passed to SourceTemperingSampler below) ----------
beta, batch_size = 5.0, 1  # NOTE(review): batch_size is not referenced later in this file
n_chains = 10

# ---- Target size: number of residues to generate --------------------------
# Values used in earlier runs, kept for reference:
#   125 -> 8ok3 (ran for 7 steps)
#   160 -> 7pzt (ran for 13 steps)
num_residuals = 127  # 7r5b

# ---- Execution ------------------------------------------------------------
num_spt_steps = 30  # forwarded as n_iterations to sampler.sample()
device = "cuda"

# Single source of truth for the generated structure's location; the banner
# below is derived from it so the log always names the file actually written.
# The 7r5b target matches the num_residuals setting (127, "# 7r5b").
output_path = "reference_proteins/7r5b_gen_2_cp.cif"

print(f"Generating {Path(output_path).name}")

# Ensure the destination folder exists before the (long) sampling run, so a
# missing directory cannot throw away the result at the final write.
Path(output_path).parent.mkdir(parents=True, exist_ok=True)

sampler = SourceTemperingSampler(
    beta=beta,
    n_chains=n_chains,
    num_residuals=num_residuals,
    device=device,
)

# CUDA work is launched asynchronously: synchronize the target device before
# each clock read so the measured interval covers all GPU work issued by
# sample(). Skipped for non-CUDA devices, where torch.cuda.synchronize()
# requires a CUDA device and would fail.
on_cuda = torch.device(device).type == "cuda"

if on_cuda:
    torch.cuda.synchronize(device)
start = datetime.datetime.now()

# sample() returns a pair; only the second element (the object exposing
# to_CIF) is used below.
last_chain, ideal_protein = sampler.sample(n_iterations=num_spt_steps)

if on_cuda:
    torch.cuda.synchronize(device)
end = datetime.datetime.now()

print(f"Total Time to Run SPT: {end - start}")
print(f'Beta Value:{beta}')
print(f'Num Chains: {n_chains}')

ideal_protein.to_CIF(output_path)



