#!/bin/bash
#
# Launch multi-GPU training of cm_train_v3.py with mpirun, running each rank
# inside a Singularity container (see the mpirun invocation at the end).
#
# The commented-out section below is a multi-node variant that relies on
# scheduler-provided variables (JOB_ID, NHOSTS, SGE_JOB_HOSTLIST); it is
# currently disabled.

# Fail fast: abort on any failing command, unset variable, or failing
# pipeline stage instead of continuing with a half-configured launch.
set -euo pipefail

# GPUs visible to processes started from this script (devices 0-7).
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

# source /etc/profile.d/modules.sh
# module load hpcx/2.12
# module load singularitypro/3.11
# module load cuda/11.6/11.6.2
# module load nccl/2.11/2.11.4-1

# # Singularity image file
# SINGULARITY_IMAGE="LSoundEDM/lsoundedm_latest.sif"
# JOB_NAME=$JOB_ID

# GPU_INFO=$(nvidia-smi --query-gpu=gpu_name --format=csv)
# if [[ $GPU_INFO =~ "V100" ]]; then
#     NUM_GPUS_PER_NODE=4
# elif [[ $GPU_INFO =~ "A100" ]]; then
#     NUM_GPUS_PER_NODE=8
# else
#     readonly PROC_ID=$!
#     kill ${PROC_ID}
# fi

# # get number of GPUs
# GPUS_IN_ONE_NODE=$(nvidia-smi --list-gpus | wc -l)
# NUM_GPU=$(expr ${NHOSTS} \* ${GPUS_IN_ONE_NODE})
# echo "NUM_GPU = ${NUM_GPU}"

# Launch one MPI rank per visible GPU, each running the training wrapper
# inside the Singularity container.
#
# Options and arguments are kept in bash arrays rather than one big string:
# nesting "..." inside a double-quoted `bash -c "..."` string makes every
# inner quote close the outer one, so the intended quoting is silently lost
# and any value containing whitespace would be split by the inner shell.

# Number of ranks = number of comma-separated entries in CUDA_VISIBLE_DEVICES
# (8 for the device list exported at the top of this script).
IFS=',' read -r -a gpu_ids <<<"${CUDA_VISIBLE_DEVICES:-}"
num_procs=${#gpu_ids[@]}
if ((num_procs == 0)); then
  printf 'error: CUDA_VISIBLE_DEVICES is empty; cannot size mpirun -np\n' >&2
  exit 1
fi

# Container image; can be overridden with CONTAINER_IMAGE=... in the environment.
container_image=${CONTAINER_IMAGE:-/path/to/container}

# MPI options: -x sets MASTER_ADDR (the host running this script) for every rank.
# Multi-node variant:
# MPIOPTS="-np $NUM_GPU -N ${NUM_GPUS_PER_NODE} -x MASTER_ADDR=${HOSTNAME} -hostfile $SGE_JOB_HOSTLIST"
mpi_opts=(-np "$num_procs" -x "MASTER_ADDR=${HOSTNAME}")

# Arguments forwarded to cm_train_v3.py via python_accelerate.sh.
train_args=(
  --train_file "path/to/train.csv"
  --validation_file "path/to/val.csv"
  --text_encoder_name "google/flan-t5-large"
  --tango
  --unet_model_config "configs/diffusion_model_config.json"
  --freeze_text_encoder
  --ctm_unet_model_config "configs/diffusion_model_config.json"
  --gradient_accumulation_steps 1
  --per_device_train_batch_size 6
  --num_train_epochs 40
  --lr 0.00008
  --d_lr 0.00008
  --loss_norm feature_space
  --match_point zs
  --loss_distance l2
  --unet_mode full
  --target_cfg 3.0
  --w_min 2.0
  --w_max 5.0
  --num_heun_step 39
  --start_scales 40
  --end_scales 40
  --mixed_precision bf16
  --cfg_distill False
  # NOTE(review): "unform" looks like a typo for "uniform", but it must match
  # the flag name cm_train_v3.py actually defines — confirm before renaming.
  --unform_sampled_cfg_distill True
  --discriminator_weight 1.0
  --tango_data_augment
  --augment_num 2
  --discriminator_start_itr 39000
  --d_architecture CMBDisc
  --vqgan_n_layers 1
  --gan_target z_target
  --vqgan_use_spectral_norm False
  --d_cond_type text_encoder
  --c_dim 1024
  --cmap_dim 128
  --mbdisc_ndf 32
  --n_bins 64
  --increase_ch False
  --diffusion_training True
  --text_column caption
  --audio_column file_name
  --checkpointing_steps best
  --output_dir "/path/to/output"
  --with_tracking
  --seed 5031
)

# `bash -c '"$@"' bash CMD ARGS...` runs CMD with ARGS exactly as given ("bash"
# only fills $0), so nothing is re-parsed by the inner shell. The wrapper path
# is relative to the working directory inside the container.
mpirun "${mpi_opts[@]}" \
  singularity exec --bind /data:/data --nv "$container_image" \
  /bin/bash -c '"$@"' bash \
  ./python_accelerate.sh cm_train_v3.py "${train_args[@]}"
