#!/bin/bash
# finetune_lora.sh
#
# LoRA fine-tuning launcher: validates the model/data paths, mirrors all
# output into LOG_FILE, and starts finetune_lora.py under torchrun.
#
# Usage: finetune_lora.sh [-m MODEL_PATH] [-d DATA_PATH] [-h]
#
# Optional environment overrides:
#   NNODES, NODE_RANK, MASTER_ADDR, MASTER_PORT   torchrun rendezvous settings
#   GPUS_PER_NODE     processes per node; defaults to torch.cuda.device_count()
#   USE_ZERO=1        pass --deepspeed "$DS_CONFIG_PATH" to the trainer
#   DS_CONFIG_PATH    DeepSpeed config used when USE_ZERO=1
#   MODULES_TO_SAVE   comma-separated names forwarded as --modules_to_save
#
# NOTE: the shebang must be the very first line; a comment above it turns it
# into an ordinary comment and the script may then run under plain `sh`.

export CUDA_DEVICE_MAX_CONNECTIONS=8
export TOKENIZERS_PARALLELISM=false
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
export PYTHONUNBUFFERED=1
export KIMIA_EXPORT_DEBUG=1

PER_DEVICE_BATCH_SIZE=2
GRADIENT_ACCUMULATION_STEPS=16

# Placeholders: edit before use (model and data can also be passed via -m / -d).
MODEL_BASE_PATH=YOUR_BASE_MODEL_PATH
DATA_PATH=YOUR_DATA_PATH

OUT_DIR=OUTPUT_PATH
LOG_FILE=LOG_PATH

SPLIT_BASE_DIR="${OUT_DIR}/ckpts"
MODULES_TO_SAVE=${MODULES_TO_SAVE:-"model.vq_adaptor,model.ced_processor"}

# An explicit GPUS_PER_NODE skips the torch probe. If the probe fails (e.g.
# torch not importable) the value is empty and evaluates as 0 below.
GPUS_PER_NODE=${GPUS_PER_NODE:-$(python -c 'import torch; print(torch.cuda.device_count())')}
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
MASTER_PORT=${MASTER_PORT:-6002}

TOTAL_GPUS=$((GPUS_PER_NODE * NNODES))
GLOBAL_BATCH_SIZE=$((PER_DEVICE_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * TOTAL_GPUS))

# DeepSpeed flag as an array so it expands to zero words when disabled.
USE_ZERO=${USE_ZERO:-0}
DS_FLAG=()
DS_CONFIG_PATH=${DS_CONFIG_PATH:-"finetune_codes/ds_config_zero3.json"}
if [[ "$USE_ZERO" == "1" ]]; then
  DS_FLAG=(--deepspeed "$DS_CONFIG_PATH")
fi

# LoRA hyperparameters, each overridable from the environment. r and alpha
# use LORA_-prefixed names to avoid clashing with short generic variables.
R=${LORA_R:-64}
ALPHA=${LORA_ALPHA:-64}
DROPOUT=${DROPOUT:-0.05}
TARGET=${TARGET:-"q_proj,k_proj,v_proj,o_proj"}
INCLUDE_MLP=${INCLUDE_MLP:-false}

#######################################
# Map a truthy/falsy flag to a Python-style boolean literal.
# Arguments: $1 - flag value; 1, true, yes (any letter case) count as true
# Outputs:   "True" or "False" on stdout
#######################################
to_py_bool() (
  # Subshell body: nocasematch only affects this call, never the caller.
  shopt -s nocasematch
  case "$1" in
    1|true|yes) echo True ;;
    *) echo False ;;
  esac
)

INCLUDE_MLP_ARG=$(to_py_bool "$INCLUDE_MLP")


# Print a one-line synopsis of the supported command-line options to stdout.
usage() {
  local prog
  prog=$(basename "$0")
  printf 'Usage: %s [-m MODEL_PATH] [-d DATA_PATH] [-h]\n' "$prog"
}
#######################################
# Abort with a message and the usage text when an option lacks its value.
# Arguments: $1 - option name; $2 - number of remaining args (incl. option)
#######################################
_require_value() {
  if (( $2 < 2 )); then
    echo "Error: $1 requires a value" >&2
    usage >&2
    exit 1
  fi
}

#######################################
# Parse command-line overrides for the model and data paths.
# Globals:   MODEL_BASE_PATH, DATA_PATH (written)
# Arguments: the script's command-line arguments
# Outputs:   usage on stdout for -h; errors and usage on stderr otherwise
# Returns:   0 after consuming all args; exits 0 for -h, 1 on bad input
#######################################
parse_args() {
  # Loop on the argument count, not on "$1" being non-empty, so an empty
  # argument cannot silently end parsing and drop the options after it.
  while (( $# > 0 )); do
    case "$1" in
      -m|--model_path) _require_value "$1" "$#"; MODEL_BASE_PATH=$2; shift ;;
      -d|--data)       _require_value "$1" "$#"; DATA_PATH=$2; shift ;;
      -h|--help)       usage; exit 0 ;;
      '')              ;;  # ignore empty args (e.g. an unset wrapper variable)
      *) echo "Error: Unknown argument $1" >&2; usage >&2; exit 1 ;;
    esac
    shift
  done
}

parse_args "$@"

# Fail fast on bad paths. These checks run before output is mirrored into
# LOG_FILE, so their messages only reach the terminal (on stderr).
[[ -f "$DATA_PATH" ]] || {
  echo "Error: DATA_PATH does not exist or is not a file: $DATA_PATH" >&2
  exit 1
}
[[ -d "$MODEL_BASE_PATH" ]] || {
  echo "Error: MODEL_BASE_PATH does not exist or is not a directory: $MODEL_BASE_PATH" >&2
  exit 1
}
if ! mkdir -p -- "$OUT_DIR" "$SPLIT_BASE_DIR" "$(dirname -- "$LOG_FILE")"; then
  echo "Error: could not create output/log directories for $OUT_DIR and $LOG_FILE" >&2
  exit 1
fi

# Append a timestamped separator to the log, then mirror all later stdout and
# stderr of this script (and the processes it starts) into LOG_FILE via tee.
printf '\n===================== %s =====================\n' \
  "$(date '+%Y-%m-%d %H:%M:%S')" | tee -a "$LOG_FILE"
exec > >(tee -a "$LOG_FILE") 2>&1

echo "Model: $MODEL_BASE_PATH"
echo "Data : $DATA_PATH"
echo "LoRA r/alpha/dropout: $R / $ALPHA / $DROPOUT"
echo "Targets: $TARGET | Include MLP: $INCLUDE_MLP_ARG"
echo "Modules to save: $MODULES_TO_SAVE"
if [[ "$USE_ZERO" == "1" ]]; then
  zero_state=ON
else
  zero_state=OFF
fi
echo "ZeRO-3: $zero_state"

# GPUS_PER_NODE is computed earlier in this script (torch probe). An empty or
# zero value would make torchrun fail with an unhelpful error, so stop here.
if ! [[ "$GPUS_PER_NODE" =~ ^[1-9][0-9]*$ ]]; then
  echo "[FAILED] GPUS_PER_NODE must be a positive integer, got '${GPUS_PER_NODE}' (no visible CUDA devices?)" >&2
  exit 1
fi

# torchrun launcher flags as an array so every value stays a single word.
DISTRIBUTED_ARGS=(
  --nproc_per_node "$GPUS_PER_NODE"
  --nnodes "$NNODES"
  --node_rank "$NODE_RANK"
  --master_addr "$MASTER_ADDR"
  --master_port "$MASTER_PORT"
)

echo "Global batch size: $GLOBAL_BATCH_SIZE (${PER_DEVICE_BATCH_SIZE} x ${GRADIENT_ACCUMULATION_STEPS} x ${TOTAL_GPUS} GPUs)"

torchrun "${DISTRIBUTED_ARGS[@]}" finetune_lora.py \
  --model_name_or_path "$MODEL_BASE_PATH" \
  --data_path "$DATA_PATH" \
  --eval_ratio 0.05 \
  --bf16 False \
  --fp16 False \
  --output_dir "$OUT_DIR" \
  --num_train_epochs 2 \
  --per_device_train_batch_size "$PER_DEVICE_BATCH_SIZE" \
  --per_device_eval_batch_size 1 \
  --gradient_accumulation_steps "$GRADIENT_ACCUMULATION_STEPS" \
  --save_strategy "no" \
  --learning_rate 5e-5 \
  --weight_decay 0.1 \
  --adam_beta2 0.95 \
  --warmup_ratio 0.01 \
  --lr_scheduler_type "cosine" \
  --logging_steps 1 \
  --report_to "none" \
  --model_max_length 1024 \
  --gradient_checkpointing False \
  --lazy_preprocess True \
  --dataloader_num_workers 4 \
  --export_split_base_dir "$SPLIT_BASE_DIR" \
  --export_split_every_n_epochs 1 \
  --export_split_keep_last_k 5 \
  --lora_r "$R" \
  --lora_alpha "$ALPHA" \
  --lora_dropout "$DROPOUT" \
  --target_modules "$TARGET" \
  --include_mlp "$INCLUDE_MLP_ARG" \
  --modules_to_save "$MODULES_TO_SAVE" \
  --remove_unused_columns False \
  "${DS_FLAG[@]}"

# Propagate torchrun's status as the script's own exit code.
EXIT_CODE=$?
if (( EXIT_CODE == 0 )); then
  echo "[SUCCESS] LoRA training finished. Split ckpts -> $SPLIT_BASE_DIR"
else
  echo "[FAILED] Exit code $EXIT_CODE." >&2
fi
exit "$EXIT_CODE"
