#!/bin/bash
#SBATCH -o ./log/gan.%j.out
#SBATCH --partition=GPUA800
#SBATCH --job-name=cifar
#SBATCH --ntasks=1
#SBATCH --gres=gpu:2
#SBATCH --qos=normal
#SBATCH --cpus-per-task=2
#SBATCH --time 120:00:00
#SBATCH --mem 64G

# Reproduce SiDA distillation of pretrained EDM models.

# The configuration to run is selected by the first positional argument;
# it is matched against the known dataset names further down.
dataset="${1}"

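# Example usage:
# To set specific GPUs and run the script for 'cifar10-uncond':
# export CUDA_VISIBLE_DEVICES=0,1,2,3
# sh run_sid_sida.sh 'cifar10-uncond'
#
# Accepted dataset names: cifar10-uncond, cifar10-cond, imagenet64-cond,
# cifar10-uncond-sid, cifar10-cond-sid, imagenet64-cond-sid.
# Any other value (or no argument) prints an error and exits with status 1.

# Tip: Decrease --batch-gpu to reduce memory consumption on limited GPU resources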

if [ "$dataset" = 'cifar10-uncond' ]; then
    # Launch the SiDA training script (fsim_train.py) for unconditional
    # CIFAR-10 with 2 processes on this node and the 'Forward-Clean' divergence.
    #
    # Resuming: append --resume with either a specific checkpoint file, e.g.
    #   --resume 'image_experiment/sid-train-runs/cifar10-uncond/training-state-????.pt'
    # or a folder, in which case the latest checkpoint found in that folder is
    # loaded automatically (useful for seamless restarts on a cluster).
    #
    # Optional parameters such as --data_stat are computed automatically by the
    # training code when they are not provided explicitly.
    #
    # Options are grouped per line by topic; their order on the command line is
    # unchanged.
    torchrun --standalone --nproc_per_node=2 fsim_train.py \
        --tmax 800 --init_sigma 2.5 \
        --batch 256 --batch-gpu 128 \
        --data './checkpoints/cifar10-32x32.zip' \
        --outdir './image_experiment/fsim-train-runs/cifar10-uncond-forward-clean' \
        --divergence 'Forward-Clean' \
        --nosubdir 0 --arch ddpmpp \
        --edm_model '/gpfs/share/home/2206192113/cvpr_code/Uni-Instruct/checkpoints/edm-cifar10-32x32-uncond-vp.pkl' \
        --detector_url '/gpfs/share/home/2206192113/cvpr_code/Uni-Instruct/checkpoints/inception-2015-12-05.pt' \
        --tick 10 --snap 50 --dump 200 \
        --lr 1e-5 --glr 1e-5 --fp16 0 \
        --ls 1 --lsg 100 --lsd 1 --lsg_gan 0.01 \
        --duration 300 \
        --data_stat '/gpfs/share/home/2206192113/cvpr_code/Uni-Instruct/checkpoints/cifar10-32x32.npz' \
        --use_gan 1 \
        --metrics fid50k_full,is50k \
        --save_best_and_last 1

    #--sid_model 'https://huggingface.co/UT-Austin-PML/SiD/resolve/main/imagenet64/alpha1.2/network-snapshot-1.200000-939176.pkl'

elif [ "$dataset" = 'cifar10-cond' ]; then
    # Launch fsim_train.py for class-conditional CIFAR-10 (--cond 1) with
    # 4 processes on this node and the 'f-distill-Jensen-Shannon' divergence.
    #
    # --edm_model, --detector_url and --data_stat are passed as empty strings
    # in this template.
    # NOTE(review): presumably these must be filled in with real paths/URLs
    # before running -- confirm against fsim_train.py.
    # NOTE(review): the #SBATCH header of this file requests --gres=gpu:2 while
    # this branch starts 4 processes -- confirm the allocation when submitting
    # through sbatch.
    #
    # Optional extras. To enable one, add it to the command below and make sure
    # every line except the last one ends with a backslash:
    #   --resume './image_experiment/fsim-train-runs/cifar10-cond'
    #   --sid_model 'https://huggingface.co/UT-Austin-PML/SiD/resolve/main/cifar10-cond/alpha1.2/network-snapshot-1.200000-713312.pkl'
    #
    # The final option deliberately has no trailing backslash, so the command
    # ends here and does not run on into whatever line follows it.
    torchrun --standalone --nproc_per_node=4 fsim_train.py \
        --cond 1 \
        --tmax 800 --init_sigma 2.5 \
        --batch 256 --batch-gpu 32 \
        --data './datasets/cifar10-32x32.zip' \
        --outdir './image_experiment/fsim-train-runs/cifar10-cond-f-distill-jensen-shannon' \
        --divergence 'f-distill-Jensen-Shannon' \
        --variance_reduction 1 \
        --nosubdir 0 --arch ddpmpp \
        --edm_model '' \
        --detector_url '' \
        --tick 10 --snap 50 --dump 200 \
        --lr 1e-5 --glr 1e-5 --fp16 0 \
        --ls 1 --lsg 100 --lsd 1 --lsg_gan 0.01 \
        --duration 300 \
        --data_stat '' \
        --use_gan 1 \
        --metrics fid50k_full \
        --save_best_and_last 1

    
elif [ "$dataset" = 'imagenet64-cond' ]; then
    # Launch fsim_train.py for class-conditional ImageNet 64x64 (--cond 1,
    # --arch adm) with 4 processes on this node and the 'Chi-Square' divergence.
    #
    # --data, --edm_model, --detector_url and --data_stat are passed as empty
    # strings in this template.
    # NOTE(review): presumably these must be filled in with real paths/URLs
    # before running -- confirm against fsim_train.py.
    # NOTE(review): the #SBATCH header of this file requests --gres=gpu:2 while
    # this branch starts 4 processes -- confirm the allocation when submitting
    # through sbatch.
    #
    # Optional extra (add it to the command and keep a trailing backslash on
    # every line except the last):
    #   --sid_model 'https://huggingface.co/UT-Austin-PML/SiD/resolve/main/imagenet64/alpha1.2/network-snapshot-1.200000-939176.pkl'
    #
    # --duration is given exactly once; the command previously repeated it with
    # the same value, which made it easy to edit one copy and not the other.
    torchrun --standalone --nproc_per_node=4 fsim_train.py \
        --cond 1 \
        --tmax 800 --init_sigma 2.5 \
        --batch 8192 --batch-gpu 8 \
        --data '' \
        --outdir './image_experiment/fsim-train-runs/imagenet64-cond-chi-square' \
        --divergence 'Chi-Square' \
        --nosubdir 0 --arch adm \
        --edm_model '' \
        --detector_url '' \
        --tick 20 --snap 50 --dump 200 \
        --lr 4e-6 --glr 4e-6 --fp16 1 \
        --ls 1 --lsg 100 --lsd 1 --lsg_gan 0.01 \
        --duration 300 \
        --data_stat '' \
        --use_gan 1 \
        --metrics fid50k_full \
        --save_best_and_last 1 \
        --dropout 0.1 --augment 0 --ema 2

elif [ "$dataset" = 'cifar10-uncond-sid' ]; then
    # Launch the SiDA training script (fsim_train.py) for unconditional
    # CIFAR-10 with 2 processes on this node and the 'Forward-KL' divergence.
    # --resume and --sid_model are both passed as empty strings here.
    #
    # Resuming: point --resume at either a specific checkpoint file, e.g.
    #   --resume 'image_experiment/sid-train-runs/cifar10-uncond/training-state-????.pt'
    # or at a folder, in which case the latest checkpoint found in that folder
    # is loaded automatically (useful for seamless restarts on a cluster).
    #
    # Optional parameters such as --data_stat are computed automatically by the
    # training code when they are not provided explicitly.
    #
    # Options are grouped per line by topic; their order on the command line is
    # unchanged.
    torchrun --standalone --nproc_per_node=2 fsim_train.py \
        --tmax 800 --init_sigma 2.5 \
        --batch 256 --batch-gpu 128 \
        --data './checkpoints/cifar10-32x32.zip' \
        --outdir './image_experiment/fsim-train-runs/cifar10-uncond-forward-kl' \
        --resume "" \
        --divergence 'Forward-KL' \
        --nosubdir 0 --arch ddpmpp \
        --edm_model './checkpoints/edm-cifar10-32x32-uncond-vp.pkl' \
        --detector_url './checkpoints/inception-2015-12-05.pt' \
        --tick 10 --snap 50 --dump 200 \
        --lr 1e-5 --glr 1e-5 --fp16 0 \
        --ls 1 --lsg 100 --lsd 1 --lsg_gan 0.01 \
        --duration 100 \
        --data_stat './checkpoints/cifar10-32x32.npz' \
        --use_gan 1 \
        --metrics fid50k_full,is50k \
        --save_best_and_last 1 \
        --sid_model ''

elif [ "$dataset" = 'cifar10-cond-sid' ]; then
    # Launch fsim_train.py for class-conditional CIFAR-10 (--cond 1) with
    # 2 processes on this node and the 'Forward-KL' divergence.
    #
    # --resume is passed as an empty string. To resume a previous run, replace
    # it with a checkpoint location, e.g.:
    #   --resume './image_experiment/fsim-train-runs/cifar10-cond'
    #
    # --sid_model must always be followed by a value: a bare flag at the end of
    # the command leaves the option without its argument. It is set to the
    # empty string here, mirroring the 'cifar10-uncond-sid' configuration.
    # TODO(review): confirm the intended SiD checkpoint for this configuration
    # (a path or URL to a network-snapshot-*.pkl) and whether fsim_train.py
    # accepts an empty value.
    torchrun --standalone --nproc_per_node=2 fsim_train.py \
        --cond 1 \
        --tmax 800 --init_sigma 2.5 \
        --batch 256 --batch-gpu 64 \
        --data './checkpoints/cifar10-32x32.zip' \
        --outdir './image_experiment/fsim-train-runs/cifar10-cond-forward-kl' \
        --divergence 'Forward-KL' \
        --resume '' \
        --variance_reduction 0 \
        --nosubdir 0 --arch ddpmpp \
        --edm_model './checkpoints/edm-cifar10-32x32-cond-vp.pkl' \
        --detector_url './checkpoints/inception-2015-12-05.pt' \
        --tick 10 --snap 50 --dump 200 \
        --lr 1e-5 --glr 1e-5 --fp16 0 \
        --ls 1 --lsg 100 --lsd 1 --lsg_gan 0.01 \
        --duration 300 \
        --data_stat './checkpoints/cifar10-32x32.npz' \
        --use_gan 1 \
        --metrics fid50k_full \
        --save_best_and_last 1 \
        --sid_model ''

    
elif [ "$dataset" = 'imagenet64-cond-sid' ]; then
    # Launch fsim_train.py for class-conditional ImageNet 64x64 (--cond 1,
    # --arch adm) with 4 processes on this node and the 'Forward-KL'
    # divergence, passing a SiD network snapshot URL through --sid_model.
    #
    # --data, --edm_model, --detector_url and --data_stat are passed as empty
    # strings in this template.
    # NOTE(review): presumably these must be filled in with real paths/URLs
    # before running -- confirm against fsim_train.py.
    # NOTE(review): the #SBATCH header of this file requests --gres=gpu:2 while
    # this branch starts 4 processes -- confirm the allocation when submitting
    # through sbatch.
    #
    # --duration is given exactly once; the command previously repeated it with
    # the same value, which made it easy to edit one copy and not the other.
    torchrun --standalone --nproc_per_node=4 fsim_train.py \
        --cond 1 \
        --tmax 800 --init_sigma 2.5 \
        --batch 8192 --batch-gpu 8 \
        --data '' \
        --outdir './image_experiment/fsim-train-runs/imagenet64-cond-forward-kl' \
        --divergence 'Forward-KL' \
        --nosubdir 0 --arch adm \
        --edm_model '' \
        --detector_url '' \
        --tick 20 --snap 50 --dump 200 \
        --lr 4e-6 --glr 4e-6 --fp16 1 \
        --ls 1 --lsg 100 --lsd 1 --lsg_gan 0.01 \
        --duration 300 \
        --data_stat '' \
        --use_gan 1 \
        --metrics fid50k_full \
        --save_best_and_last 1 \
        --dropout 0.1 --augment 0 --ema 2 \
        --sid_model 'https://huggingface.co/UT-Austin-PML/SiD/resolve/main/imagenet64/alpha1.2/network-snapshot-1.200000-939176.pkl'
    
else
    # No branch matched (unknown name, or no argument given): report the
    # rejected value and the accepted names on stderr, then fail so that the
    # job is marked as unsuccessful.
    printf '%s\n' "Invalid dataset specified: '${dataset}'" >&2
    printf '%s\n' "Valid names: cifar10-uncond, cifar10-cond, imagenet64-cond, cifar10-uncond-sid, cifar10-cond-sid, imagenet64-cond-sid" >&2
    exit 1
fi