#!/usr/bin/env bash
#
# Launch single-node CogVideo LoRA training for one video with torchrun.
#
# Usage: . train.sh <video_id>     (or: bash train.sh <video_id>)
#
# Run from the project root: the training script and config are referenced by
# relative paths (src/...), and the current directory is added to PYTHONPATH.
#
# Environment exported for the training run:
#   CPATH                     prefixed with the `animeshooter` conda env's
#                             targets/x86_64-linux/include dir (needed for
#                             DeepSpeed to compile), when conda is available
#   PYTHONPATH                current directory appended
#   OMP_NUM_THREADS=1
#   CUDA_VISIBLE_DEVICES=0,1  one torchrun worker is started per listed GPU
#   NODE_RANK=0
#
# The body is a subshell function (`train_main() ( ... )`, not `{ ... }`).
# Because the script may be sourced, this keeps the exports and strict mode
# from leaking into the caller's shell, and keeps `exit` from closing it.

train_main() (
  set -euo pipefail

  # Print each argument as a line on stderr and abort the subshell.
  die() {
    printf '%s\n' "$@" >&2
    exit 1
  }

  local -r video_id=${1:-}
  if [[ -z "$video_id" ]]; then
    die "Error: Please provide a video_id" "Usage: . train.sh <video_id>"
  fi

  # The relative paths passed to torchrun only resolve from the project root;
  # fail clearly here rather than inside every launched worker.
  [[ -f src/cogvideo_lora_train.py ]] \
    || die "Error: src/cogvideo_lora_train.py not found; run from the project root"

  # DeepSpeed compiles its ops and needs the headers shipped in the conda env.
  # conda's own error output is suppressed because we print our own warning;
  # without conda we skip CPATH instead of exporting a bogus path.
  local conda_base
  if conda_base=$(conda info --base 2>/dev/null) && [[ -n "$conda_base" ]]; then
    # ${CPATH:+:$CPATH} avoids a trailing empty element when CPATH was unset;
    # GCC treats an empty CPATH element as the current directory.
    export CPATH="$conda_base/envs/animeshooter/targets/x86_64-linux/include${CPATH:+:$CPATH}"
  else
    printf '%s\n' "Warning: 'conda info --base' failed; CPATH not updated" >&2
  fi

  local -r project_root=$PWD
  # ${PYTHONPATH:+...} avoids a stray leading ':' when PYTHONPATH was unset.
  export PYTHONPATH="${PYTHONPATH:+$PYTHONPATH:}$project_root"
  export OMP_NUM_THREADS=1
  export CUDA_VISIBLE_DEVICES=0,1
  export NODE_RANK=0

  # One torchrun worker per visible GPU. Counting the comma-separated ids
  # keeps --nproc_per_node in sync with CUDA_VISIBLE_DEVICES.
  local -r commas=${CUDA_VISIBLE_DEVICES//[!,]/}
  local -r nproc_per_node=$(( ${#commas} + 1 ))

  torchrun --nproc_per_node="$nproc_per_node" --nnodes=1 \
    --node_rank="$NODE_RANK" --master_port=29500 \
    src/cogvideo_lora_train.py \
    --config src/config/cogvideo_lora_config.yaml \
    --video_id "$video_id"
)

train_main "$@"