#!/bin/bash
# finetune_ds.sh
#
# Fine-tune a model by launching finetune.py under torchrun, optionally with
# DeepSpeed ZeRO-3.
#
# Usage: finetune_ds.sh [-m MODEL_PATH] [-d DATA_PATH] [-h]
#
# Environment overrides (default in parentheses):
#   GPUS_PER_NODE   processes per node (torch.cuda.device_count())
#   NNODES          number of nodes (1)
#   NODE_RANK       rank of this node (0)
#   MASTER_ADDR     rendezvous host (127.0.0.1)
#   MASTER_PORT     rendezvous port (6001)
#   USE_ZERO        1 enables DeepSpeed ZeRO-3 (0)
#   DS_CONFIG_PATH  DeepSpeed config (finetune_codes/ds_config_zero3.json)
#
# NOTE: the shebang must stay on the very first line. If anything precedes it,
# it is not honoured and the bash-only syntax in this script ([[ ]], arrays,
# process substitution) may be run by a shell that does not support it.

export CUDA_DEVICE_MAX_CONNECTIONS=8
export TOKENIZERS_PARALLELISM=false
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
export PYTHONUNBUFFERED=1

PER_DEVICE_BATCH_SIZE=2
GRADIENT_ACCUMULATION_STEPS=8

# Placeholders: edit these, or pass -m / -d on the command line.
MODEL_BASE_PATH=YOUR_BASE_MODEL_PATH
DATA_PATH=YOUR_DATA_PATH

OUT_DIR=OUTPUT_PATH
LOG_FILE=LOG_PATH

# Honour an explicit GPUS_PER_NODE; otherwise ask torch how many devices it sees.
GPUS_PER_NODE=${GPUS_PER_NODE:-$(python -c 'import torch; print(torch.cuda.device_count())')}
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
MASTER_PORT=${MASTER_PORT:-6001}

# The torch query prints nothing if torch cannot be imported, and a non-numeric
# value would otherwise only surface as an arithmetic error below. Fail clearly.
for _count_var in GPUS_PER_NODE NNODES; do
  if ! [[ "${!_count_var}" =~ ^[1-9][0-9]*$ ]]; then
    echo "Error: $_count_var must be a positive integer, got '${!_count_var}'" >&2
    exit 1
  fi
done
unset _count_var

TOTAL_GPUS=$((GPUS_PER_NODE * NNODES))
GLOBAL_BATCH_SIZE=$((PER_DEVICE_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * TOTAL_GPUS))

# Extra launcher arguments; empty unless ZeRO-3 is requested.
USE_ZERO=${USE_ZERO:-0}
DS_FLAG=()
DS_CONFIG_PATH=${DS_CONFIG_PATH:-finetune_codes/ds_config_zero3.json}
if [[ "$USE_ZERO" == "1" ]]; then
  DS_FLAG=(--deepspeed "$DS_CONFIG_PATH")
fi

# Print the one-line command synopsis to stdout.
usage() {
  printf '%s\n' "Usage: $(basename "$0") [-m MODEL_PATH] [-d DATA_PATH] [-h]"
}
#######################################
# Parse command-line options, overriding the MODEL_BASE_PATH / DATA_PATH
# defaults set earlier in the script.
# Globals:   MODEL_BASE_PATH, DATA_PATH (written)
# Arguments: the script's command-line arguments
# Exits:     0 after printing usage for -h/--help;
#            1 on an unknown argument or an option missing its value.
#######################################
parse_args() {
  # Loop on the argument count, not on "$1" being non-empty, so an empty
  # argument is reported instead of silently ending option parsing.
  while (( $# > 0 )); do
    case "$1" in
      -m|--model_path|-d|--data)
        if (( $# < 2 )); then
          echo "Error: Missing value for $1" >&2
          usage >&2
          exit 1
        fi
        case "$1" in
          -m|--model_path) MODEL_BASE_PATH=$2 ;;
          *)               DATA_PATH=$2 ;;
        esac
        shift  # consume the option's value
        ;;
      -h|--help)
        usage
        exit 0
        ;;
      *)
        echo "Error: Unknown argument $1" >&2
        usage >&2
        exit 1
        ;;
    esac
    shift
  done
}
parse_args "$@"

# Fail fast on missing inputs before creating any output directories.
# Diagnostics go to stderr so they are not mixed into normal output.
if [[ ! -f "$DATA_PATH" ]]; then
  echo "Error: DATA_PATH file does not exist: $DATA_PATH" >&2
  exit 1
fi
if [[ ! -d "$MODEL_BASE_PATH" ]]; then
  echo "Error: MODEL_BASE_PATH does not exist: $MODEL_BASE_PATH" >&2
  exit 1
fi
# With ZeRO-3 enabled the DeepSpeed config is required; check it up front,
# before anything is launched.
if [[ "$USE_ZERO" == "1" && ! -f "$DS_CONFIG_PATH" ]]; then
  echo "Error: DeepSpeed config does not exist: $DS_CONFIG_PATH" >&2
  exit 1
fi

SPLIT_BASE_DIR="${OUT_DIR}/ckpts"
if ! mkdir -p "$OUT_DIR" "$SPLIT_BASE_DIR" "$(dirname "$LOG_FILE")"; then
  echo "Error: failed to create output/log directories" >&2
  exit 1
fi

# Write a timestamped separator, then mirror all further stdout and stderr of
# this script (and of the commands it runs) into LOG_FILE.
printf '\n===================== %s =====================\n' "$(date '+%Y-%m-%d %H:%M:%S')" | tee -a "$LOG_FILE"
exec > >(tee -a "$LOG_FILE") 2>&1

echo "Model Path: $MODEL_BASE_PATH"
echo "Data Path:  $DATA_PATH"
if [[ "$USE_ZERO" == "1" ]]; then
  echo "DeepSpeed ZeRO-3: ON"
  echo "DeepSpeed Config: $DS_CONFIG_PATH"
else
  echo "DeepSpeed ZeRO-3: OFF"
fi
echo "Total GPUs: $TOTAL_GPUS | Global BS: $GLOBAL_BATCH_SIZE"

#######################################
# Launch finetune.py under torchrun with the hyper-parameters below.
# Globals (read): GPUS_PER_NODE, NNODES, NODE_RANK, MASTER_ADDR, MASTER_PORT,
#   MODEL_BASE_PATH, DATA_PATH, OUT_DIR, PER_DEVICE_BATCH_SIZE,
#   GRADIENT_ACCUMULATION_STEPS, SPLIT_BASE_DIR, DS_FLAG
# Returns: torchrun's exit status, so a following `$?` reflects training.
#######################################
launch_training() {
  # An array keeps each launcher flag/value a separate word without relying
  # on unquoted word-splitting of a flat string.
  local -a distributed_args=(
    --nproc_per_node "$GPUS_PER_NODE"
    --nnodes "$NNODES"
    --node_rank "$NODE_RANK"
    --master_addr "$MASTER_ADDR"
    --master_port "$MASTER_PORT"
  )

  # Every line up to the final "${DS_FLAG[@]}" must end in a backslash.
  # Otherwise the DeepSpeed flag is executed as a separate command instead of
  # being passed to torchrun, and its status replaces torchrun's in $?.
  torchrun "${distributed_args[@]}" finetune.py \
    --model_name_or_path "$MODEL_BASE_PATH" \
    --data_path "$DATA_PATH" \
    --eval_ratio 0.0 \
    --bf16 True \
    --fp16 False \
    --output_dir "$OUT_DIR" \
    --num_train_epochs 5 \
    --per_device_train_batch_size "$PER_DEVICE_BATCH_SIZE" \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps "$GRADIENT_ACCUMULATION_STEPS" \
    --save_strategy "no" \
    --learning_rate 1e-3 \
    --weight_decay 0.1 \
    --adam_beta2 0.95 \
    --warmup_ratio 0.01 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --report_to "none" \
    --model_max_length 512 \
    --gradient_checkpointing False \
    --lazy_preprocess True \
    --dataloader_num_workers 4 \
    --export_split_base_dir "$SPLIT_BASE_DIR" \
    --export_split_every_n_epochs 1 \
    --export_split_keep_last_k 5 \
    --remove_unused_columns False \
    --ddp_find_unused_parameters False \
    "${DS_FLAG[@]}"
}
launch_training

# Capture the status of the command immediately preceding this block (the
# training launch), report the outcome, and exit with that same status.
EXIT_CODE=$?
if (( EXIT_CODE != 0 )); then
  echo "[FAILED] Exit code $EXIT_CODE."
else
  echo "[SUCCESS] Training finished. Split ckpts -> $SPLIT_BASE_DIR"
fi
exit "$EXIT_CODE"
