# Hardware requested for the training cluster.
resources:
  # Run on spot instances. The run script copies the newest checkpoint found
  # under /artifacts back to local disk before it starts training.
  use_spot: true
  disk_size: 1000
  # Quoted so the value is always loaded as a plain string.
  accelerators: 'A100-80GB:8'

# Number of nodes to launch. The run script does not read this key directly:
# it counts the lines of $SKYPILOT_NODE_IPS to get the node count for torchrun.
num_nodes: 1

# Cloud storage buckets attached to the cluster's filesystem.
file_mounts:
  # Holds the converted HF-format LLaMA weights (llama-hf/) and the training
  # checkpoints (chatbot/); both setup and run read and write here.
  /artifacts:
    store: gcs
    mode: MOUNT
    name: skypilot-chatbot  # Change to your own bucket
  # Training data; the run script reads /data/sharegpt/... from this mount.
  /data:
    store: gcs
    mode: MOUNT
    name: model-weights  # Change to your own bucket
  # Raw LLaMA weights. The setup script reads /llama/${MODEL_SIZE}b/ whenever
  # the converted weights are not yet cached under /artifacts, so this mount
  # must be uncommented for the first conversion.
  # /llama:
  #   store: gcs
  #   mode: MOUNT
  #   name: llama-ckpts  # Change to the bucket that contains the LLaMA weights

# Local directory used as the task's working directory. The setup script
# installs the package from ~/sky_workdir (`pip install -e .`) and the run
# script uses relative paths (scripts/, fastchat/train/) — presumably resolved
# against the synced copy of this directory; confirm against the SkyPilot docs.
workdir: .

# Provisioning script. Creates the `chatbot` conda env, installs PyTorch, a
# pinned transformers commit, this package and flash-attn, then makes the
# HF-format LLaMA weights available at ~/hf-output/llama-${MODEL_SIZE}b.
# Converted weights are cached under /artifacts/llama-hf/ and marked with a
# `complete` file so that later runs copy the cache instead of re-converting.
setup: |
  # Setup the environment
  conda create -n chatbot python=3.10 -y
  conda activate chatbot

  # Install pytorch
  pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116

  # Install huggingface with the LLaMA commit
  cd ~
  # Skip the clone when setup is re-run on a cluster that already has it.
  [ -d ~/transformers ] || git clone https://github.com/huggingface/transformers.git
  cd transformers
  git checkout 41a2f3529c6b56866c317031375ffd3e7b8bea01
  pip install .
  cd ~/sky_workdir

  # Install fastchat
  pip install -e .
  pip install flash-attn

  mkdir -p /artifacts/llama-hf/llama-${MODEL_SIZE}B
  if [ ! -f /artifacts/llama-hf/llama-${MODEL_SIZE}B/complete ]; then
    # No cached conversion yet: the raw weights must come from the /llama
    # mount. Fail early with a clear message if that mount is not enabled.
    if [ ! -d /llama/${MODEL_SIZE}b ]; then
      echo "Raw LLaMA weights not found at /llama/${MODEL_SIZE}b; enable the /llama file mount." >&2
      exit 1
    fi
    mkdir -p ~/llama-${MODEL_SIZE}b
    gsutil -m rsync -r /llama/${MODEL_SIZE}b/ ~/llama-${MODEL_SIZE}b || exit 1
    cd ~/transformers
    python src/transformers/models/llama/convert_llama_weights_to_hf.py \
      --input_dir $HOME/llama-${MODEL_SIZE}b \
      --model_size ${MODEL_SIZE}B \
      --output_dir ~/hf-output || exit 1
    # Put the tokenizer files next to the model weights.
    mv ~/hf-output/tokenizer/* ~/hf-output/llama-${MODEL_SIZE}b || exit 1
    # Abort on a failed upload so that the `complete` marker below is only
    # ever written for a fully populated cache.
    gsutil -m rsync -r ~/hf-output/llama-${MODEL_SIZE}b/ /artifacts/llama-hf/llama-${MODEL_SIZE}B || exit 1
    touch /artifacts/llama-hf/llama-${MODEL_SIZE}B/complete
  else
    # Cached conversion exists: copy it to local disk.
    mkdir -p ~/hf-output/llama-${MODEL_SIZE}b
    gsutil -m cp -r /artifacts/llama-hf/llama-${MODEL_SIZE}B/* ~/hf-output/llama-${MODEL_SIZE}b || exit 1
  fi

# Training entry point. Steps:
#   1. Resolve hyper-parameters from the environment (the `:-` defaults only
#      apply when a variable is unset) and pick the training script.
#   2. Copy the newest checkpoint under CKPT_PATH, if any, to ~/.checkpoints.
#   3. Start scripts/sync_local_checkpoint.sh in the background.
#   4. Launch training with torchrun, writing checkpoints to local disk.
#   5. Upload the remaining non-checkpoint outputs and exit with the training
#      status (or the upload status if only the upload failed).
run: |
  conda activate chatbot
  SEQ_LEN=${SEQ_LEN:-512}
  GC_SCALE=${GC_SCALE:-1}
  DATE=${DATE:-20230303}
  USE_FLASH_ATTN=${USE_FLASH_ATTN:-0}
  if [ "$USE_FLASH_ATTN" -eq 1 ]; then
    TRAIN_SCRIPT=fastchat/train/train_mem.py
    USE_FLASH_SUFFIX="-flash"
  else
    TRAIN_SCRIPT=fastchat/train/train.py
    USE_FLASH_SUFFIX=""
  fi
  echo "Training with seq_len=${SEQ_LEN} and gc_scale=${GC_SCALE}"
  # Per-device batch size is 2048 * GC_SCALE / SEQ_LEN (integer division).
  PER_DEVICE_BATCH_SIZE=$((2048 * $GC_SCALE / $SEQ_LEN))
  # One IP per line: the line count is the node count, the first line the master.
  NUM_NODES=$(echo "$SKYPILOT_NODE_IPS" | wc -l)
  HOST_ADDR=$(echo "$SKYPILOT_NODE_IPS" | head -n1)

  # Do the periodic syncing manually, to avoid the degradation of
  # the training for saving checkpoints.
  mkdir -p ~/.checkpoints
  LOCAL_CKPT_PATH=~/.checkpoints
  CKPT_PATH=/artifacts/chatbot/${MODEL_SIZE}b/sharegpt-${DATE}-seq-${SEQ_LEN}${USE_FLASH_SUFFIX}
  # Newest entry = highest number after the '-' in the name. On a first run
  # CKPT_PATH is missing or empty, so last_ckpt is empty and there is nothing
  # to restore; skip the copy instead of rsyncing from "${CKPT_PATH}//".
  last_ckpt=$(ls ${CKPT_PATH} 2>/dev/null | grep -E '[0-9]+' | sort -t'-' -k1,1 -k2,2n | tail -1)
  if [ -n "${last_ckpt}" ]; then
    mkdir -p ~/.checkpoints/${last_ckpt}
    gsutil -m rsync -r ${CKPT_PATH}/${last_ckpt}/ ~/.checkpoints/${last_ckpt}
  fi

  bash scripts/sync_local_checkpoint.sh ${LOCAL_CKPT_PATH} ${CKPT_PATH} > sync.log 2>&1 &

  # gradient_accumulation_steps is computed so that the global batch is
  # 128 * 512 / SEQ_LEN sequences per optimizer step (integer division).
  torchrun \
    --nnodes=$NUM_NODES \
    --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
    --master_port=12375 \
    --master_addr=$HOST_ADDR \
    --node_rank=${SKYPILOT_NODE_RANK} \
    $TRAIN_SCRIPT \
    --model_name_or_path ~/hf-output/llama-${MODEL_SIZE}b \
    --data_path /data/sharegpt/sharegpt_20230322_clean_lang_split.json \
    --bf16 True \
    --output_dir $LOCAL_CKPT_PATH \
    --num_train_epochs 3 \
    --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
    --per_device_eval_batch_size $PER_DEVICE_BATCH_SIZE \
    --gradient_accumulation_steps $((128 * 512 / $SEQ_LEN / $PER_DEVICE_BATCH_SIZE / $NUM_NODES / $SKYPILOT_NUM_GPUS_PER_NODE)) \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1200 \
    --save_total_limit 10 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length ${SEQ_LEN} \
    --gradient_checkpointing True \
    --lazy_preprocess True
  # Remember the training result so the upload below cannot mask a failure.
  status=$?

  # Sync any files not in the checkpoint-* folders
  gsutil -m rsync -r -x 'checkpoint-*' $LOCAL_CKPT_PATH/ $CKPT_PATH/ || status=$?

  exit $status


# Environment variables read by the setup and run scripts. Values are quoted
# so that they are always loaded as strings rather than YAML integers.
envs:
  # Model size in billions of parameters; used in weight/checkpoint paths and
  # passed to the conversion script as --model_size ${MODEL_SIZE}B.
  MODEL_SIZE: '13'
  # Passed to training as --model_max_length; also divides the batch-size math.
  SEQ_LEN: '2048'
  # Multiplier in the per-device batch size: 2048 * GC_SCALE / SEQ_LEN.
  GC_SCALE: '4'
  # Only used in the checkpoint directory name; the training data path in the
  # run script is hard-coded and does not follow this value.
  DATE: '20230322'
  # 1 selects fastchat/train/train_mem.py and adds a "-flash" suffix to the
  # checkpoint directory; any other integer selects fastchat/train/train.py.
  USE_FLASH_ATTN: '1'
