#!/bin/bash

# =========================================================
# Automatically adjusts gradient accumulation to maintain Global Batch Size = 64
#
# Usage: 
#   bash scripts/run_sft.sh <MODEL_PATH> <EXP_NAME> [N_GPUS] [MASTER_PORT] [DATASET_PATH] [NUM_EPOCHS]
#
# Arguments:
#   1. MODEL_PATH    : Path to base model or Step-0 model
#   2. EXP_NAME      : Name of the experiment (for output dir)
#   3. N_GPUS        : (Optional) Number of GPUs, between 1 and 64. Default: 8
#                      (Global batch size is exactly 64 only when N_GPUS divides 64.)
#   4. MASTER_PORT   : (Optional) Distributed port. Default: 29500
#   5. DATASET_PATH  : (Optional) Path to jsonl dataset. Default: ./data/processed/final_train.jsonl
#   6. NUM_EPOCHS    : (Optional) Number of training epochs. Default: 3
# =========================================================

# Positional arguments. The first two are required; the remaining four fall
# back to their defaults when omitted or passed as an empty string.
MODEL_PATH=${1:-}
EXP_NAME=${2:-}
N_GPUS=${3:-8}
MASTER_PORT=${4:-29500}
DATASET_PATH=${5:-"./data/processed/final_train.jsonl"}
NUM_EPOCHS=${6:-3}

# Abort early when a required argument is missing or empty. The usage text is
# a diagnostic, so it goes to stderr rather than being mixed into stdout.
if [[ -z "$MODEL_PATH" || -z "$EXP_NAME" ]]; then
    echo "❌ Usage: bash scripts/run_sft.sh <MODEL_PATH> <EXP_NAME> [N_GPUS] [MASTER_PORT] [DATASET_PATH] [NUM_EPOCHS]" >&2
    exit 1
fi

# Batch-size plan:
#   global batch = N_GPUS * PER_DEVICE_BATCH_SIZE * GRAD_ACCUM
# GRAD_ACCUM is derived so that the product lands on TARGET_GLOBAL_BATCH_SIZE.
TARGET_GLOBAL_BATCH_SIZE=64
PER_DEVICE_BATCH_SIZE=1

# N_GPUS is used in integer arithmetic below, so it must be a plain decimal
# number. Anything else (empty, negative, "auto", "4.0", ...) is mapped to 0
# and rejected together with a literal zero, which would otherwise cause a
# division by zero.
requested_gpus=$N_GPUS
if [[ "$N_GPUS" =~ ^[0-9]+$ ]]; then
    # Force base 10 so a zero-padded value such as "08" is not read as octal.
    N_GPUS=$((10#$N_GPUS))
else
    N_GPUS=0
fi
if (( N_GPUS < 1 )); then
    echo "❌ Error: N_GPUS must be a positive integer (got '$requested_gpus')" >&2
    exit 1
fi

# Samples consumed by one forward/backward pass across all processes.
samples_per_micro_step=$((N_GPUS * PER_DEVICE_BATCH_SIZE))

if (( samples_per_micro_step > TARGET_GLOBAL_BATCH_SIZE )); then
    echo "❌ Error: N_GPUS ($N_GPUS) x Per-Device Batch Size ($PER_DEVICE_BATCH_SIZE) cannot be larger than Target Batch Size ($TARGET_GLOBAL_BATCH_SIZE)" >&2
    exit 1
fi

GRAD_ACCUM=$((TARGET_GLOBAL_BATCH_SIZE / samples_per_micro_step))

# Integer division truncates: e.g. 6 GPUs -> GRAD_ACCUM=10 -> global batch 60.
# Surface that instead of silently training with a different batch size.
effective_global_batch_size=$((samples_per_micro_step * GRAD_ACCUM))
if (( effective_global_batch_size != TARGET_GLOBAL_BATCH_SIZE )); then
    echo "⚠️  Warning: Target Batch Size ($TARGET_GLOBAL_BATCH_SIZE) is not divisible by $samples_per_micro_step samples per step; effective global batch size will be $effective_global_batch_size" >&2
fi

# Everything for this experiment lives under ./experiments/<EXP_NAME>;
# the "runs" subdirectory holds the logs and is created up front.
OUTPUT_DIR="./experiments/${EXP_NAME}"
LOG_DIR="${OUTPUT_DIR}/runs"
mkdir -p "$LOG_DIR"

# Print a one-shot summary of the resolved configuration before launching.
banner_rule="======================================================="
printf '%s\n' \
    "$banner_rule" \
    "🚀 Starting Smart SFT Training" \
    "   Model:        $MODEL_PATH" \
    "   Dataset:      $DATASET_PATH" \
    "   Output:       $OUTPUT_DIR" \
    "   GPUs:         $N_GPUS" \
    "   Epochs:       $NUM_EPOCHS" \
    "   Strategy:     Global BS=$TARGET_GLOBAL_BATCH_SIZE | Grad Accum=$GRAD_ACCUM" \
    "$banner_rule"

# Launch swift.cli.sft through torch.distributed.run with N_GPUS processes on
# this node. Combined stdout/stderr is shown on the terminal and also written
# to ${LOG_DIR}/train.log via tee.
python -m torch.distributed.run --nproc_per_node "$N_GPUS" --master_port "$MASTER_PORT" \
    -m swift.cli.sft \
    --train_type 'full' \
    --torch_dtype 'bfloat16' \
    --model "$MODEL_PATH" \
    --model_type 'qwen2' \
    --template 'qwen2_5' \
    --dataset "$DATASET_PATH" \
    --max_length 32768 \
    --learning_rate 2e-05 \
    --num_train_epochs "$NUM_EPOCHS" \
    --per_device_train_batch_size "$PER_DEVICE_BATCH_SIZE" \
    --gradient_accumulation_steps "$GRAD_ACCUM" \
    --eval_steps 500 \
    --save_steps 500 \
    --attn_impl 'flash_attn' \
    --adam_beta1 0.9 \
    --adam_beta2 0.999 \
    --adam_epsilon 1e-8 \
    --report_to 'tensorboard' \
    --agent_template 'hermes' \
    --lazy_tokenize true \
    --deepspeed 'zero3' \
    --gradient_checkpointing true \
    --output_dir "$OUTPUT_DIR" \
    --logging_dir "$LOG_DIR" \
    --add_version False \
    --ignore_args_error True \
    2>&1 | tee "${LOG_DIR}/train.log"
# A pipeline's exit status is that of its last command (tee), which would hide
# a crashed training run. PIPESTATUS[0] is the python launcher's own status
# and must be read immediately, before any other command overwrites it.
TRAIN_STATUS=${PIPESTATUS[0]}

if (( TRAIN_STATUS != 0 )); then
    echo "❌ Training command failed with exit code $TRAIN_STATUS (see ${LOG_DIR}/train.log)" >&2
    exit "$TRAIN_STATUS"
fi

echo "✅ Training command finished."