#!/usr/bin/env bash
# Task launcher: each function below runs one job (mostly via an `srun`
# prefix defined here). The task is chosen by the first argument:
#   bash <this_script> <task_name>      e.g. train_svd_lgm

# Node/GPU layout used to build `prefix_cmd`.
n_nodes=1
n_gpus_per_node=4 # number of GPUs per node
# -x SH-IDC1-10-140-37-112   (srun node-exclusion flag, kept for reference)
prefix_cmd="srun -p saturn-v --nodes $n_nodes --ntasks-per-node $n_gpus_per_node --cpus-per-task 16 --gres=gpu:$n_gpus_per_node --quotatype=auto --async"
multi_cmd="srun -p video-aigc-1 --nodes 1 --ntasks-per-node 8 --cpus-per-task 16 --gres=gpu:8 --quotatype=auto --async"

# NOTE(review): despite its name, gpu4_cmd requests 2 nodes x 4 GPUs (8 GPUs total).
gpu4_cmd="srun -p video-aigc-1 --nodes 2 --ntasks-per-node 4 --cpus-per-task 16 --gres=gpu:4 --quotatype=auto --async"
gpu2_cmd="srun -p video-aigc-1 --nodes 1 --ntasks-per-node 2 --cpus-per-task 16 --gres=gpu:2 --quotatype=auto"
test_cmd="srun -p video-aigc-1 --nodes 1 --ntasks-per-node 1 --cpus-per-task 16 --gres=gpu:1 --quotatype=auto"
cpu_cmd="srun -p video-aigc-1 --nodes 1 --ntasks-per-node 1 --cpus-per-task 1 --quotatype=auto"


# Prepend the anaconda lib dir. `${VAR:+:$VAR}` adds the separator only when
# LD_LIBRARY_PATH is non-empty: a trailing empty entry would make the dynamic
# loader search the current working directory.
export LD_LIBRARY_PATH=~/anaconda3/lib${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
# prefix_cmd_docker="export OMP_NUM_THREADS=1"
# prefix_cmd_vast="export OMP_NUM_THREADS=1"
# export http_proxy=http://192.168.48.17:18000
# export https_proxy=http://192.168.48.17:18000
# export HF_HOME=/mnt/pfs/share/pretrained_model/.cache/huggingface
# export TORCH_HOME=/mnt/pfs/share/pretrained_model/.cache/torch

# Launch train.py with the demo NeRF config through `test_cmd`
# (the single-GPU srun prefix defined at the top of this file).
train() {
    local cfg="configs/demo/nerf.yaml"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    ${test_cmd} python train.py --config "${cfg}"
}
# Launch train.py with the nerfacc NeRF config via the single-GPU `test_cmd`.
train_nerfacc() {
    local cfg="configs/demo/nerf/nerfacc_nerf.yaml"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    ${test_cmd} python train.py --config "${cfg}"
}
# Launch train.py with the LEAP NeRF config via the single-GPU `test_cmd`.
train_leap() {
    local cfg="configs/demo/nerf/leap.yaml"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    ${test_cmd} python train.py --config "${cfg}"
}
# Launch train.py with the volume/nerfacc config via the single-GPU `test_cmd`.
# NOTE(review): the file name is spelled `nerfaccc` (three c's), unlike
# `nerfacc_nerf.yaml` used elsewhere -- confirm it matches the file on disk.
train_volume() {
    local cfg="configs/demo/nerf/volume_nerfaccc.yaml"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    ${test_cmd} python train.py --config "${cfg}"
}
# Launch train.py with the 3DGS "fix" config via the single-GPU `test_cmd`.
train_gs_fix() {
    local cfg="configs/demo/3dgs/3dgs_fix.yaml"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    ${test_cmd} python train.py --config "${cfg}"
}
# Launch train.py with the 3DGS "split" config via the single-GPU `test_cmd`.
train_gs_split() {
    local cfg="configs/demo/3dgs/3dgs_split.yaml"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    ${test_cmd} python train.py --config "${cfg}"
}
# Launch train.py with the 3DGS "gen" config via the single-GPU `test_cmd`.
# CUDA_LAUNCH_BLOCKING=1 is exported so kernel launches run synchronously,
# which makes CUDA errors surface at the failing call (debug setting).
train_gs_gen() {
    local cfg="configs/demo/3dgs/3dgs_gen.yaml"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    export CUDA_LAUNCH_BLOCKING=1
    ${test_cmd} python train.py --config "${cfg}"
}
# Launch train.py with configs/demo/3dgs/3dgs.yaml via the single-GPU `test_cmd`.
train_na_gs() {
    local cfg="configs/demo/3dgs/3dgs.yaml"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    ${test_cmd} python train.py --config "${cfg}"
}
# Launch train.py with configs/demo/3dgs/van_3dgs.yaml via the single-GPU `test_cmd`.
train_van_gs() {
    local cfg="configs/demo/3dgs/van_3dgs.yaml"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    ${test_cmd} python train.py --config "${cfg}"
}
# Launch CIFAR-10 training (configs/demo/cifar10.yaml) through `prefix_cmd`,
# the saturn-v srun prefix sized by n_nodes / n_gpus_per_node.
train_cifar10() {
    local cfg="configs/demo/cifar10.yaml"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    ${prefix_cmd} python train.py --config "${cfg}"
}
# Launch CIFAR-10 training with configs/demo/cifar10_x0.yaml through `prefix_cmd`.
train_cifar10_x0() {
    local cfg="configs/demo/cifar10_x0.yaml"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    ${prefix_cmd} python train.py --config "${cfg}"
}
# Launch CIFAR-10 training with configs/demo/cifar10_x0_deep.yaml through `prefix_cmd`.
train_cifar10_x0_deep() {
    local cfg="configs/demo/cifar10_x0_deep.yaml"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    ${prefix_cmd} python train.py --config "${cfg}"
}
# Run test/test_position_map.py on the single-GPU `test_cmd` allocation.
test_position_map() {
    local script="test/test_position_map.py"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    ${test_cmd} python "${script}"
}
# Launch SVD+LGM position-map training through `multi_cmd`
# (1 node, 8 tasks, 8 GPUs as defined at the top of this file).
train_position_map() {
    local cfg="configs/demo/mv/svd_lgm_position_map.yaml"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    ${multi_cmd} python train.py --config "${cfg}"
}
# Launch SVD+LGM MoE training through `gpu4_cmd`
# (2 nodes x 4 GPUs as defined at the top of this file).
train_moe() {
    local cfg="configs/demo/mv/svd_lgm_moe.yaml"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    ${gpu4_cmd} python train.py --config "${cfg}"
}
# Single-GPU smoke run of CIFAR-10 training (same config as train_cifar10,
# but on the `test_cmd` allocation instead of `prefix_cmd`).
# PYTHONPATH is exported here like in every other train.py launcher, so
# project imports resolve the same way as in train_cifar10.
function test_cifar10(){
    export PYTHONPATH=$PROJECT_DIR:$PYTHONPATH
    $test_cmd python train.py --config configs/demo/cifar10.yaml
}
# Run src/models/nerf/mlp.py directly on the single-GPU `test_cmd` allocation.
test_mlp() {
    local script="src/models/nerf/mlp.py"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    ${test_cmd} python "${script}"
}
# Run test.py on the single-GPU `test_cmd` allocation.
# NOTE: defining a function named `test` shadows the `test` builtin for the
# rest of this script (`[ ... ]` is unaffected). The name is kept because it
# is a task name selectable from the command line.
test() {
    local script="test.py"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    ${test_cmd} python "${script}"
}
# Run test/test_svd.py on the single-GPU `test_cmd` allocation.
test_svd() {
    local script="test/test_svd.py"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    ${test_cmd} python "${script}"
}
# Run test/test_svd_inference.py on the single-GPU `test_cmd` allocation.
test_svd_inference() {
    local script="test/test_svd_inference.py"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    ${test_cmd} python "${script}"
}
# Run src/data/multiview.py on the CPU-only `cpu_cmd` allocation.
test_dataset() {
    local script="src/data/multiview.py"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    ${cpu_cmd} python "${script}"
}
# Run src/data/cifar10.py on the CPU-only `cpu_cmd` allocation.
test_cifar10_dataset() {
    local script="src/data/cifar10.py"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    ${cpu_cmd} python "${script}"
}
# Run scripts/get_fid.py on the single-GPU `test_cmd` allocation.
test_fid() {
    local script="scripts/get_fid.py"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    ${test_cmd} python "${script}"
}

# Launch SVD+LRM training (configs/demo/mv/svd_lrm.yaml) on the single-GPU
# `test_cmd` allocation, with the anaconda lib dir prepended to LD_LIBRARY_PATH.
train_svd_lrm() {
    local cfg="configs/demo/mv/svd_lrm.yaml"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    export LD_LIBRARY_PATH=~/anaconda3/lib:${LD_LIBRARY_PATH}
    ${test_cmd} python train.py --config "${cfg}"
}
# Launch SVD+LGM training (configs/demo/mv/svd_lgm.yaml) through `gpu4_cmd`
# (2 nodes x 4 GPUs), with the anaconda lib dir prepended to LD_LIBRARY_PATH.
train_svd_lgm() {
    local cfg="configs/demo/mv/svd_lgm.yaml"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    export LD_LIBRARY_PATH=~/anaconda3/lib:${LD_LIBRARY_PATH}
    ${gpu4_cmd} python train.py --config "${cfg}"
}
# Same config as train_svd_lgm, but on the single-GPU `test_cmd` allocation
# (quick check before a multi-node run).
train_svd_lgm_test() {
    local cfg="configs/demo/mv/svd_lgm.yaml"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    export LD_LIBRARY_PATH=~/anaconda3/lib:${LD_LIBRARY_PATH}
    ${test_cmd} python train.py --config "${cfg}"
}
# Launch LGM training (configs/demo/mv/lgm.yaml) on the single-GPU `test_cmd`
# allocation. CUDA_LAUNCH_BLOCKING=1 is given as a prefix assignment, so it
# applies to the launched command only, not to the rest of the script.
train_lgm() {
    local cfg="configs/demo/mv/lgm.yaml"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    export LD_LIBRARY_PATH=~/anaconda3/lib:${LD_LIBRARY_PATH}
    CUDA_LAUNCH_BLOCKING=1 ${test_cmd} python train.py --config "${cfg}"
}
# Launch configs/demo/mv/svd_only_lgm.yaml training through `multi_cmd`
# (1 node, 8 GPUs), with the anaconda lib dir prepended to LD_LIBRARY_PATH.
train_only_lgm() {
    local cfg="configs/demo/mv/svd_only_lgm.yaml"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    export LD_LIBRARY_PATH=~/anaconda3/lib:${LD_LIBRARY_PATH}
    ${multi_cmd} python train.py --config "${cfg}"
}
# Launch SVD training (configs/demo/mv/svd.yaml) through `gpu4_cmd`
# (2 nodes x 4 GPUs), with the anaconda lib dir prepended to LD_LIBRARY_PATH.
train_svd() {
    local cfg="configs/demo/mv/svd.yaml"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    export LD_LIBRARY_PATH=~/anaconda3/lib:${LD_LIBRARY_PATH}
    ${gpu4_cmd} python train.py --config "${cfg}"
}
# Run src/models/network/encoder.py directly on the single-GPU `test_cmd` allocation.
test_encoder() {
    local script="src/models/network/encoder.py"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    export LD_LIBRARY_PATH=~/anaconda3/lib:${LD_LIBRARY_PATH}
    ${test_cmd} python "${script}"
}
# Run src/models/network/lifting.py directly on the single-GPU `test_cmd` allocation.
test_lifting() {
    local script="src/models/network/lifting.py"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    export LD_LIBRARY_PATH=~/anaconda3/lib:${LD_LIBRARY_PATH}
    ${test_cmd} python "${script}"
}
# Run src/data/dtu.py directly on the single-GPU `test_cmd` allocation.
test_dtu() {
    local script="src/data/dtu.py"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    export LD_LIBRARY_PATH=~/anaconda3/lib:${LD_LIBRARY_PATH}
    ${test_cmd} python "${script}"
}
# Run test/test_sdxl.py on the single-GPU `test_cmd` allocation.
test_sdxl() {
    local script="test/test_sdxl.py"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    export LD_LIBRARY_PATH=~/anaconda3/lib:${LD_LIBRARY_PATH}
    ${test_cmd} python "${script}"
}
# Launch SDXL training (configs/demo/sd/sdxl.yaml) on the single-GPU `test_cmd` allocation.
train_sdxl() {
    local cfg="configs/demo/sd/sdxl.yaml"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    export LD_LIBRARY_PATH=~/anaconda3/lib:${LD_LIBRARY_PATH}
    ${test_cmd} python train.py --config "${cfg}"
}
# Launch SD training (configs/demo/sd/sd.yaml) on the single-GPU `test_cmd` allocation.
train_sd() {
    local cfg="configs/demo/sd/sd.yaml"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    export LD_LIBRARY_PATH=~/anaconda3/lib:${LD_LIBRARY_PATH}
    ${test_cmd} python train.py --config "${cfg}"
}
# Fine-tune stabilityai/stable-diffusion-2-1-base with LoRA using
# train_text_to_image_lora.py, launched directly with `accelerate` (no srun).
#
# Training data source:
#   - if DATASET_NAME is set in the environment, it is passed as --dataset_name;
#   - otherwise TRAIN_DIR (set below) is passed as --train_data_dir.
# DATASET_NAME is not set anywhere in this file, so without this fallback the
# script received an empty `--dataset_name=` and TRAIN_DIR was never used.
function train_lora(){
    export MODEL_NAME="stabilityai/stable-diffusion-2-1-base"
    export TRAIN_DIR="data/sdxl/dog"
    local -a data_args
    if [[ -n "${DATASET_NAME:-}" ]]; then
        data_args=(--dataset_name="$DATASET_NAME")
    else
        data_args=(--train_data_dir="$TRAIN_DIR")
    fi
    accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
        --pretrained_model_name_or_path="$MODEL_NAME" \
        "${data_args[@]}" --caption_column="text" \
        --resolution=512 --random_flip \
        --train_batch_size=1 \
        --num_train_epochs=100 --checkpointing_steps=5000 \
        --learning_rate=1e-05 --lr_scheduler="constant" --lr_warmup_steps=0 \
        --seed=42 \
        --output_dir="outputs" \
        --validation_prompt="cute dragon creature" --report_to="tensorboard"
}
# Run test/test_split.py on the single-GPU `test_cmd` allocation.
test_split() {
    local script="test/test_split.py"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    export LD_LIBRARY_PATH=~/anaconda3/lib:${LD_LIBRARY_PATH}
    ${test_cmd} python "${script}"
}
# Run src/data/gobjaverse.py on the CPU-only `cpu_cmd` allocation.
test_gobjaverse() {
    local script="src/data/gobjaverse.py"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    export LD_LIBRARY_PATH=~/anaconda3/lib:${LD_LIBRARY_PATH}
    ${cpu_cmd} python "${script}"
}
# Run test/test_svd_train.py via `accelerate launch` on the single-GPU
# `test_cmd` allocation, starting from the stable-video-diffusion-img2vid
# weights. The script's CLI flags are collected in a local array, one per line.
test_svd_train() {
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    export LD_LIBRARY_PATH=~/anaconda3/lib:${LD_LIBRARY_PATH}
    local -a svd_args=(
        --pretrained_model_name_or_path=stabilityai/stable-video-diffusion-img2vid
        --per_gpu_batch_size=1
        --gradient_accumulation_steps=1
        --max_train_steps=50000
        --width=512
        --height=512
        --num_frames=14
        --checkpointing_steps=10000
        --checkpoints_total_limit=1
        --learning_rate=1e-5
        --lr_warmup_steps=0
        --seed=123
        --mixed_precision="fp16"
        --validation_steps=200
    )
    ${test_cmd} accelerate launch test/test_svd_train.py "${svd_args[@]}"
}
# Run test/test_lgm.py on the single-GPU `test_cmd` allocation.
test_lgm() {
    local script="test/test_lgm.py"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    export LD_LIBRARY_PATH=~/anaconda3/lib:${LD_LIBRARY_PATH}
    ${test_cmd} python "${script}"
}

# Run src/models/unet/adaptor.py directly on the single-GPU `test_cmd` allocation.
test_adaptor() {
    local script="src/models/unet/adaptor.py"
    export PYTHONPATH="${PROJECT_DIR}:${PYTHONPATH}"
    export LD_LIBRARY_PATH=~/anaconda3/lib:${LD_LIBRARY_PATH}
    ${test_cmd} python "${script}"
}

# Dispatch: run the shell function named by the first argument, forwarding any
# remaining arguments to it, e.g. `bash <this_script> train_svd_lgm`.
# The name is checked with `declare -F` so only defined functions can run;
# previously an unquoted `$1` was word-split and executed as an arbitrary
# command, and a missing argument exited 0 silently.
if [[ $# -eq 0 ]] || ! declare -F -- "$1" >/dev/null; then
    if [[ $# -gt 0 ]]; then
        printf 'unknown task: %s\n' "$1" >&2
    fi
    printf 'usage: %s <task> [args...]\navailable tasks:\n' "${0##*/}" >&2
    declare -F | awk '{print "  " $3}' >&2
    exit 2
fi
task=$1
shift
"$task" "$@"