#!/bin/bash
#
# Launch distributed diffusion-model training with torchrun (train_dist.py).
#
# Each section below builds one `torchrun ... train_dist.py` invocation from
# the shared settings. Only the SD3 run is currently enabled; the SD1.5 and
# SDXL runs are kept commented out for reference (uncomment the whole section,
# including its echo, to enable one). To launch through Slurm, uncomment
# NODE/NAME and the matching `srun` line directly above a torchrun command.
#
# The script stops at the first failing command (set -e), so the final
# completion message is only printed if every enabled run succeeded.

set -euo pipefail

export CUDA_VISIBLE_DEVICES=0,1
# NODE='node-name'
# NAME='train'

# One torchrun worker per visible GPU, derived from CUDA_VISIBLE_DEVICES so
# the two settings cannot drift apart ("0,1" -> 2 workers).
IFS=',' read -r -a gpu_ids <<<"$CUDA_VISIBLE_DEVICES"
NPROC=${#gpu_ids[@]}

# Define common parameters
DATASET="coco"
BATCHSIZE=4
NROWS=1
NUM_EVAL=10
NUMITER=10
# Define path
LOGDIR="./data"
DATADIR="/mnt/data1/jaayeon/data"

# Arguments shared by every run, in the order train_dist.py originally got
# them; model-specific flags are appended after these in each section.
common_args=(
  --datadir "$DATADIR" --logdir "$LOGDIR"
  --dataset "$DATASET" --batch_size "$BATCHSIZE" --epochs 2
  --num_iter "$NUMITER" --num_eval "$NUM_EVAL" --nrows "$NROWS"
)

# --- SD1.5 (disabled) ---
# echo "Train SD1.5 start"
# srun -w "$NODE" --job-name="$NAME" \
# torchrun --nproc-per-node="$NPROC" train_dist.py --model sd1.5 "${common_args[@]}" \
#   --n_dc_tokens 4 --apply_dc True True False --dweight 0.01

# --- SDXL (disabled) ---
# echo "Train SDXL start"
# srun -w "$NODE" --job-name="$NAME" \
# torchrun --nproc-per-node="$NPROC" train_dist.py --model sdxl "${common_args[@]}" \
#   --n_dc_tokens 4 --apply_dc True True False --dweight 1

# --- SD3 ---
echo "Train SD3 start"
# srun -w "$NODE" --job-name="$NAME" \
torchrun --nproc-per-node="$NPROC" train_dist.py --model sd3 "${common_args[@]}" \
  --n_dc_tokens 4 --n_dc_layers 5 --use_dc_t --dweight 0

echo "All epochs completed!"