#!/usr/bin/env bash
# Launch IPFM training for a named experiment configuration.
#
# Usage: <this script> <config-name>
#   config-name  one of the configurations handled in the `case` below;
#                currently only 'cifar10-uncond-D128-alpha10-nfe2'.
#
# Outputs are written under "${exp_dir}/${exp_name}/<config-name>".
# An unknown or missing config name prints an error to stderr and exits 1.
set -euo pipefail

# `${1:-}` so a missing argument falls through to the "Invalid dataset"
# branch instead of aborting under `set -u`.
dataset=${1:-}
exp_dir="ipfm_experiments"
exp_name="run_ipfm"

case "$dataset" in
  cifar10-uncond-D128-alpha10-nfe2)
    # NOTE(review): the config name says "alpha10" while --alpha is 1.0;
    # presumably "alpha10" encodes 1.0 -- confirm this is intended.
    torchrun --standalone --nproc_per_node=2 ipfm_train.py \
      --aug_dim "128" \
      --alpha 1.0 \
      --loss_ipfm_or_sid "ipfm" \
      --tmax 800 \
      --init_sigma 2.5 \
      --n_generator_steps 2 \
      --batch 256 \
      --batch-gpu 128 \
      --outdir "${exp_dir}/${exp_name}/${dataset}" \
      --data 'datasets/cifar10-32x32.zip' \
      --arch ncsnpp \
      --edm_model 'downloads/pfgmpp_ckpts/cifar10_ncsnpp_D_128.pkl' \
      --metrics fid50k_full \
      --emas_for_eval G_ema999 \
      --tick 25 \
      --snap 25 \
      --dump 25 \
      --lr 1e-5 \
      --glr 1e-5 \
      --fp16 0 \
      --ls 1 \
      --lsg 100 \
      --duration 30 \
      --data_stat 'https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz'
    ;;
  *)
    # Previously referenced the misspelled `${datset}`, which always
    # expanded to an empty string; report the actual argument on stderr.
    printf 'Invalid dataset specified: %s\n' "$dataset" >&2
    exit 1
    ;;
esac
