import os
from sample_eval_pipeline import create_and_run_sample_eval_pipeline, create_only_eval_pipeline

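# This script provides an example of how to use `create_and_run_sample_eval_pipeline`
# (from `sample_eval_pipeline`) to launch sampling and evaluation jobs on SLURM.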

# --- Output locations --------------------------------------------------------
# Pass path components separately so os.path.join inserts the platform separator.
job_directory = os.path.join('outputs', 'sample_eval_array_job')
output_dir = os.path.join(job_directory, 'out')

# --- SLURM settings (fill in before running) ---------------------------------
project = ""  # slurm project
partition = ""  # slurm partition

# --- Sampling settings -------------------------------------------------------
n_samples_per_condition = 100
edge_conditional_set = 'test'

# --- Per-job wall-clock limits -----------------------------------------------
hours_per_sample = 1
minutes_per_sample = 0
hours_per_eval = 1
minutes_per_eval = 0

loss_0_repetition = 1

# --- Parallelism -------------------------------------------------------------
num_parallel = 64
# Number of parallel workers grouped into each array job.
# NOTE(review): the meaning of 8 (e.g. workers/GPUs per node) is not visible
# here; confirm against sample_eval_pipeline.
parallel_per_array = 8
# Validate before deriving num_arrays. An explicit raise is used instead of
# `assert` so the check is not stripped when Python runs with -O.
if num_parallel % parallel_per_array != 0:
    raise ValueError(
        f"num_parallel ({num_parallel}) must be a multiple of "
        f"{parallel_per_array}"
    )
num_arrays = num_parallel // parallel_per_array

n_conditions = 78
# NOTE(review): 4992 == n_conditions * num_parallel (78 * 64); confirm whether
# this is meant to be derived from those values rather than hard-coded.
total_cond_eval = 4992

# --- Model selection ---------------------------------------------------------
run_id = ""  # wandb run id
epochs = [100]  # list of epochs to evaluate
sampling_step_counts = [100]


# Zero-argument callables handed to the pipeline. Each returns 0 here.
def condition_first_random_func():
    return 0


def random_seed():
    return 0


inpaint_on_one_reactant = False

create_and_run_sample_eval_pipeline(
    job_directory, output_dir, project, partition, run_id, epochs,
    hours_per_sample, minutes_per_sample, hours_per_eval, minutes_per_eval,
    num_arrays, n_conditions, total_cond_eval,
    n_samples_per_condition, edge_conditional_set, sampling_step_counts,
    shuffle=False,
    random_seed=random_seed,
    condition_first_random_func=condition_first_random_func,
    loss_0_repetition=loss_0_repetition,
    inpaint_on_one_reactant=inpaint_on_one_reactant,
)