#!/usr/bin/env bash
# Grid-evaluate the dit2m USPTO-50K checkpoint over several sampling-step
# counts: one distributed `python -m src.eval` run (via torchrun) per setting.
#
# Usage: <this script> [extra Hydra overrides...]
#   All arguments are appended verbatim to every `src.eval` invocation.
#
# Required environment (presumably exported by the cluster launcher):
#   NODE_COUNT     -> torchrun --nnodes and trainer.num_nodes
#   NODE_RANK      -> torchrun --node_rank
#   PROC_PER_NODE  -> torchrun --nproc_per_node
#   MASTER_ADDR    -> torchrun --master_addr
# Optional:
#   CKPT_PATH      checkpoint to evaluate (defaults to the placeholder below)
#
# Exits non-zero if any sampling-step setting fails; the remaining settings
# are still run.

CONDA_PATH=/mnt/xxxxxxxxxxxxxxxxx/miniconda3
# Sourcing conda.sh is what makes `conda activate` usable in a non-interactive
# script; `conda init` would only rewrite the user's shell rc files, so it is
# not run here. Abort rather than evaluate with the wrong Python environment.
# shellcheck source=/dev/null
if ! . "${CONDA_PATH}/etc/profile.d/conda.sh" || ! conda activate pytorch; then
    echo "error: failed to activate conda env 'pytorch' from ${CONDA_PATH}" >&2
    exit 1
fi
cd /mnt/xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx || {
    echo "error: cannot cd to the project directory" >&2
    exit 1
}

# Fail fast: unset values would otherwise shift torchrun's flag/value pairs.
: "${NODE_COUNT:?NODE_COUNT must be set}"
: "${NODE_RANK:?NODE_RANK must be set}"
: "${PROC_PER_NODE:?PROC_PER_NODE must be set}"
: "${MASTER_ADDR:?MASTER_ADDR must be set}"

task_name=dit2m_uspto50k_rc_all_pred
ckpt_path=${CKPT_PATH:-your/ckpt/path/here.ckpt}
all_steps=(100 50 20 10)
failed_steps=()

for sampling_steps in "${all_steps[@]}"; do
    # shellcheck disable=SC2016  # ${paths.root_dir} is for Hydra, not bash
    args=(
        experiment=retro
        data=uspto_50k
        data.perm_types=data/uspto_50k/test_json/pred_rc_uspto50k.json
        +model/sde=dfm
        "model.sde.test_sampling_steps=${sampling_steps}"
        +model/net=graph_dit/2m
        "ckpt_path=${ckpt_path}"
        "task_name=${task_name}_step${sampling_steps}"
        # Single-quoted so the shell passes ${paths.root_dir} through literally
        # for Hydra to interpolate.
        'paths.log_dir=${paths.root_dir}/logs/sampling/uspto_50k_sota_pred/grid_sampling_steps'
        "trainer.num_nodes=${NODE_COUNT}"
        # Quoted: unquoted, bash tilde-expands ~logger to the home directory of
        # a user named "logger" if one exists, losing this Hydra override.
        '~logger'
    )
    if ! torchrun \
        --nnodes "$NODE_COUNT" \
        --node_rank "$NODE_RANK" \
        --nproc_per_node "$PROC_PER_NODE" \
        --master_addr "$MASTER_ADDR" \
        --master_port 6000 \
        -m src.eval "${args[@]}" "$@"; then
        echo "warning: evaluation with ${sampling_steps} sampling steps failed" >&2
        failed_steps+=("$sampling_steps")
    fi
done

if (( ${#failed_steps[@]} > 0 )); then
    echo "error: failed sampling-step settings: ${failed_steps[*]}" >&2
    exit 1
fi









