#!/usr/bin/env bash
# Launch distributed evaluation (`-m src.eval`) of the retro experiment on
# USPTO-50k using the permutation file pred_rc_uspto50k.json, via torchrun.
#
# Required environment (presumably provided by the cluster launcher):
#   NODE_COUNT     number of nodes; also forwarded as trainer.num_nodes
#   NODE_RANK      rank of this node
#   PROC_PER_NODE  processes to start on this node
#   MASTER_ADDR    address of the rendezvous (rank-0) node
# Optional overrides:
#   SAMPLING_STEPS  value for model.sde.test_sampling_steps (default: 50)
#   CKPT_PATH       checkpoint passed as ckpt_path (default: placeholder)
# Any extra command-line arguments are appended verbatim after the built-in
# overrides passed to `src.eval`.

die() { printf '%s\n' "$*" >&2; exit 1; }

CONDA_PATH=/mnt/xxxxxxxxxxxxxxxxx/miniconda3
# Abort rather than running torchrun outside the intended environment.
. "${CONDA_PATH}/etc/profile.d/conda.sh" && conda init && conda activate pytorch \
    || die "failed to set up conda environment 'pytorch' from ${CONDA_PATH}"
# `-m src.eval` is resolved relative to this directory, so it must succeed.
cd /mnt/xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx \
    || die "cannot cd to project directory"

# Fail early with a clear message instead of letting an empty expansion
# shift torchrun's flags (e.g. `--nnodes --node_rank ...`).
: "${NODE_COUNT:?NODE_COUNT must be set}"
: "${NODE_RANK:?NODE_RANK must be set}"
: "${PROC_PER_NODE:?PROC_PER_NODE must be set}"
: "${MASTER_ADDR:?MASTER_ADDR must be set}"

sampling_steps=${SAMPLING_STEPS:-50}
ckpt_path=${CKPT_PATH:-your/ckpt/path/here.ckpt}

task_name=pred_rc_uspto50k_step${sampling_steps}
args=(
    experiment=retro
    data=uspto_50k
    data.perm_types=data/uspto_50k/test_json/pred_rc_uspto50k.json
    +model/sde=dfm
    "model.sde.test_sampling_steps=${sampling_steps}"
    +model/net=graph_dit/8m
    "ckpt_path=${ckpt_path}"
    "task_name=${task_name}"
    # Escaped so the literal ${paths.root_dir} reaches the program (presumably
    # resolved there as a config interpolation), not expanded by the shell.
    "paths.log_dir=\${paths.root_dir}/logs/sampling/uspto_50k_sota_pred"
    "trainer.num_nodes=${NODE_COUNT}"
    # Quoted: an unquoted leading '~' is subject to tilde expansion.
    '~logger'
)
torchrun \
    --nnodes "${NODE_COUNT}" \
    --node_rank "${NODE_RANK}" \
    --nproc_per_node "${PROC_PER_NODE}" \
    --master_addr "${MASTER_ADDR}" \
    --master_port 6000 \
    -m src.eval "${args[@]}" "$@"





