#!/usr/bin/env bash
# Environment setup: activate the `pytorch` conda environment and switch to
# the project directory. Bash is required because this script uses arrays.

CONDA_PATH=/mnt/xxxxxxxxxxxxxxxxx/miniconda3

# Sourcing conda.sh defines the `conda` shell function for this shell, which
# is all `conda activate` needs. The former `conda init` step was removed: it
# edits the user's shell startup files on every run, and because it sat in an
# `&&` chain a failure there silently skipped the activation.
# shellcheck source=/dev/null
. "${CONDA_PATH}/etc/profile.d/conda.sh" || {
    printf 'error: cannot source %s\n' "${CONDA_PATH}/etc/profile.d/conda.sh" >&2
    exit 1
}
conda activate pytorch || {
    printf 'error: cannot activate conda environment "pytorch"\n' >&2
    exit 1
}

# `-m src.train` is resolved relative to the working directory, so a failed
# cd must abort instead of launching from wherever the script was started.
cd /mnt/xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx || {
    printf 'error: cannot cd to project directory\n' >&2
    exit 1
}


# Sweep over model sizes: for each size, launch `src.train` through torchrun
# with a fixed set of config overrides plus any extra arguments given to this
# script (forwarded verbatim after the built-in overrides).
#
# Required environment (read, never written):
#   NODE_COUNT     - passed to --nnodes and to trainer.num_nodes
#   NODE_RANK      - passed to --node_rank
#   PROC_PER_NODE  - passed to --nproc_per_node
#   MASTER_ADDR    - passed to --master_addr
#
# Exit status: non-zero if any model size failed; the remaining sizes are
# still attempted.

# Fail early with a clear message; an unset value would otherwise make
# torchrun consume the following flag as the missing argument.
: "${NODE_COUNT:?NODE_COUNT must be set}"
: "${NODE_RANK:?NODE_RANK must be set}"
: "${PROC_PER_NODE:?PROC_PER_NODE must be set}"
: "${MASTER_ADDR:?MASTER_ADDR must be set}"

model_sizes=("65m" "8m" "2m" "280k")
failed_sizes=()

for model_size in "${model_sizes[@]}"; do
    task_name="dit${model_size}_uspto50k_cano"
    args=(
        experiment=retro
        data=uspto_50k
        data.perm_prob_sampler=null
        data.perm_types=null
        +model/sde=dfm
        "+model/net=graph_dit/${model_size}"
        "task_name=${task_name}"
        # Single-quoted so the literal text ${paths.root_dir} reaches the
        # training program unexpanded (presumably resolved by its config
        # system - confirm).
        'paths.log_dir=${paths.root_dir}/logs/grid_model_size/dit'
        "trainer.num_nodes=${NODE_COUNT}"
    )

    # "$@" (quoted) keeps each caller-supplied override as a single word,
    # even when it contains spaces or glob characters.
    if ! torchrun \
        --nnodes "$NODE_COUNT" \
        --node_rank "$NODE_RANK" \
        --nproc_per_node "$PROC_PER_NODE" \
        --master_addr "$MASTER_ADDR" \
        --master_port 6000 \
        -m src.train "${args[@]}" "$@"; then
        printf 'error: training failed for model size %s\n' "$model_size" >&2
        failed_sizes+=("$model_size")
    fi

done

# Surface any failure in the exit status instead of reporting only the
# outcome of the last model size.
if (( ${#failed_sizes[@]} > 0 )); then
    printf 'error: failed model sizes: %s\n' "${failed_sizes[*]}" >&2
    exit 1
fi

