#!/usr/bin/env bash
# Environment setup: load conda's shell functions, activate the `pytorch`
# env, and move to the project root so `python -m src.train` resolves.
#
# `conda init` is intentionally NOT called: it rewrites the user's shell rc
# files (a persistent side effect on every run) and is unnecessary here,
# because sourcing conda.sh already defines the `conda` shell function.
# Each step aborts the script on failure; training in the wrong environment
# or directory would only fail later with a less obvious error.
CONDA_PATH=/mnt/xxxxxxxxxxxxxxxxx/miniconda3
# shellcheck source=/dev/null
. "${CONDA_PATH}/etc/profile.d/conda.sh" \
    || { printf 'error: cannot source %s\n' "${CONDA_PATH}/etc/profile.d/conda.sh" >&2; exit 1; }
conda activate pytorch \
    || { printf 'error: cannot activate conda env %s\n' 'pytorch' >&2; exit 1; }
cd /mnt/xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx \
    || { printf 'error: cannot cd to project directory\n' >&2; exit 1; }


# Launch one torchrun training job per model size, sequentially.
#
# Required environment (presumably exported by the multi-node launcher —
# confirm for your cluster): NODE_COUNT, NODE_RANK, PROC_PER_NODE,
# MASTER_ADDR. They are checked up front so a missing value fails with a
# clear message instead of torchrun misparsing its flags.
#
# Any extra arguments passed to this script are forwarded verbatim
# ("$@", properly quoted) to src.train as extra overrides for every run.
#
# A failing size does not stop the remaining sizes. If any run fails, the
# script exits non-zero at the end, listing the failed sizes.
: "${NODE_COUNT:?NODE_COUNT must be set}"
: "${NODE_RANK:?NODE_RANK must be set}"
: "${PROC_PER_NODE:?PROC_PER_NODE must be set}"
: "${MASTER_ADDR:?MASTER_ADDR must be set}"

model_sizes=("8m" "2m" "280k")
failed_sizes=()

for model_size in "${model_sizes[@]}"; do
    task_name="dit${model_size}_usptofull_cano"
    args=(
        experiment=retro
        data=uspto_full
        data.perm_prob_sampler=null
        data.perm_types=null
        +model/sde=dfm
        model.sde.train.lambda_train_e=20.0
        "+model/net=graph_dit/${model_size}"
        "task_name=${task_name}"
        # Single quotes keep ${paths.root_dir} literal so the training app's
        # config system (presumably Hydra/OmegaConf) resolves it, not the shell.
        'paths.log_dir=${paths.root_dir}/logs/uspto_full_sota/dit'
        "trainer.num_nodes=${NODE_COUNT}"
        trainer.val_check_interval=512
    )

    if ! torchrun \
        --nnodes "${NODE_COUNT}" \
        --node_rank "${NODE_RANK}" \
        --nproc_per_node "${PROC_PER_NODE}" \
        --master_addr "${MASTER_ADDR}" \
        --master_port 6000 \
        -m src.train "${args[@]}" "$@"; then
        printf 'warning: training failed for model size %s\n' "${model_size}" >&2
        failed_sizes+=("${model_size}")
    fi
done

if (( ${#failed_sizes[@]} > 0 )); then
    printf 'error: training failed for model size(s): %s\n' "${failed_sizes[*]}" >&2
    exit 1
fi

