#!/bin/bash
#
# SFT Training Script - Based on LLaMA-Factory
# Usage: ./train_sft_llamafactory.sh
#

set -e

# ============================================================================
# Model Configuration - Modify here to switch models
# ============================================================================
BASE_MODEL="Qwen/Qwen3-4B-Instruct-2507"

# Last path component of BASE_MODEL with its case preserved (used for logging).
# Example: "Qwen/Qwen3-4B-Instruct-2507" -> "Qwen3-4B-Instruct-2507"
# A value without any "/" is kept unchanged.
MODEL_DISPLAY_NAME="${BASE_MODEL##*/}"

# Lower-cased form of that same component (used to build output paths).
# Example: "Qwen/Qwen3-4B-Instruct-2507" -> "qwen3-4b-instruct-2507"
MODEL_ID=$(tr '[:upper:]' '[:lower:]' <<< "$MODEL_DISPLAY_NAME")

# Start-up banner, followed by one blank line.
cat << BANNER
========================================================================
$MODEL_DISPLAY_NAME SFT Fine-tuning (LLaMA-Factory)
========================================================================

BANNER

# Configuration parameters
# These values are echoed in the summary below and substituted into the
# generated LLaMA-Factory YAML config further down in this script.
MODEL_NAME="$BASE_MODEL"
TRAIN_DATA="data/pku_saferlhf_paired.jsonl"  # checked after cd'ing to the project root
DATA_FORMAT="jsonl"  # "json" or "jsonl"
DIRECTION="increase"  # "decrease" or "increase"
OUTPUT_DIR="model/${MODEL_ID}/lora/${DIRECTION}/sft"
NUM_EPOCHS=5
BATCH_SIZE=16   # per-device batch size (used for both train and eval)
GRAD_ACCUM=4
# Written with an explicit decimal point ("5.0e-4" rather than "5e-4"): the value
# is pasted verbatim into the YAML config, and YAML 1.1 loaders such as PyYAML
# only resolve exponent notation as a float when the mantissa contains a ".".
# Without it the learning rate can reach the trainer as a string.
LEARNING_RATE=5.0e-4
LORA_R=8
LORA_ALPHA=16
MAX_SEQ_LENGTH=2048
EVAL_SPLIT=0.1  # Proportion of validation set split from training set (10%)

# DeepSpeed configuration (for reducing single GPU memory)
# Options: "none" (not used), "z2" (ZeRO-2), "z3" (ZeRO-3)
# ZeRO-2: Optimizer state + gradient sharding, memory reduced by ~8x, faster
# ZeRO-3: Optimizer state + gradient + parameter sharding, maximum memory reduction (N times, N=GPU count), slightly slower
# Note: Using DeepSpeed requires installing deepspeed library: pip install deepspeed
USE_DEEPSPEED="z2"  # Set to "none" to disable DeepSpeed, use standard DDP

# Disable FP8 (if GPU does not support sm90)
# NOTE(review): the effect of these two variables is not visible in this
# script — confirm they are honoured by the installed NCCL / PyTorch builds.
export NCCL_FP8_DISABLE=1
export TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=0

# Skip version check (if datasets version is incompatible)
# LLaMA-Factory requires datasets>=2.16.0,<=4.0.0, but newer version may be installed
export DISABLE_VERSION_CHECK=1
# Suppress warning messages after distributed training interruption
export PYTHONWARNINGS="ignore::UserWarning"
export NCCL_DEBUG=WARN  # Only show warnings and above

# Print a human-readable summary of the run settings, followed by a blank line.
cat << SUMMARY
Configuration:
  Model: $MODEL_NAME
  Data: $TRAIN_DATA
  Data format: $DATA_FORMAT
  Output: $OUTPUT_DIR
  Direction: $DIRECTION
  Epochs: $NUM_EPOCHS
  Batch size: $BATCH_SIZE
  Gradient accumulation: $GRAD_ACCUM
  Learning rate: $LEARNING_RATE
  LoRA r: $LORA_R, alpha: $LORA_ALPHA
  Eval split: $EVAL_SPLIT (from training set)

SUMMARY

# Locate the directory containing this script, treat its parent directory as
# the project root, and work from there so relative paths resolve consistently.
SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
PROJECT_ROOT=$(dirname "$SCRIPT_DIR")
cd "$PROJECT_ROOT"

# Abort early when the training data file does not exist
# (the path is resolved against the current directory).
[ -f "$TRAIN_DATA" ] || {
    echo "❌ Error: Training data not found: $TRAIN_DATA"
    exit 1
}

# Abort with installation instructions when llamafactory-cli is not on PATH.
if ! command -v llamafactory-cli > /dev/null 2>&1; then
    cat << 'INSTALL_HELP'
❌ Error: llamafactory-cli not found
Please install LLaMA-Factory first:
  git clone https://github.com/hiyouga/LLaMA-Factory.git
  cd LLaMA-Factory
  pip install -e .[metrics]
INSTALL_HELP
    exit 1
fi

# The generated training config lives inside the run's output directory
# (it is written later in this script and is not removed afterwards).
CONFIG_FILE="${OUTPUT_DIR}/train_config.yaml"

# Make sure the output directory exists before anything is written into it.
mkdir -p "$OUTPUT_DIR"
printf '%s\n' "Creating LLaMA-Factory config file: $CONFIG_FILE"

# Check GPU count
# Precedence: SLURM_GPUS_ON_NODE, then SLURM_GPUS (both used verbatim), then
# the number of lines printed by `nvidia-smi --list-gpus`.
# NUM_GPUS is only used for the log line below.
if [ -n "${SLURM_GPUS_ON_NODE:-}" ]; then
    NUM_GPUS=$SLURM_GPUS_ON_NODE
elif [ -n "${SLURM_GPUS:-}" ]; then
    NUM_GPUS=$SLURM_GPUS
else
    # A pipeline's exit status is that of its last command, and `wc -l` always
    # succeeds, so a trailing `|| echo 1` fallback would never run. Count the
    # lines first (tolerating a missing or failing nvidia-smi), then fall back
    # to 1 explicitly when nothing was listed. The arithmetic expansion also
    # strips any padding some `wc` implementations emit.
    NUM_GPUS=$(( $( { nvidia-smi --list-gpus 2>/dev/null || true; } | wc -l ) ))
    if [ "$NUM_GPUS" -lt 1 ]; then
        NUM_GPUS=1
    fi
fi
echo "Detected $NUM_GPUS GPU(s)"
echo ""

# Determine DeepSpeed configuration (if enabled)
#
# Result: DEEPSPEED_LINE holds a ready-to-insert "deepspeed: <abs path>" YAML
# line, or stays empty when DeepSpeed is disabled or its config cannot be found
# (in which case a warning is printed and the run continues without DeepSpeed).

# Print the ZeRO stage number for a USE_DEEPSPEED value ("z2" -> 2, "z3" -> 3).
# Prints nothing and returns 1 for any other value.
_deepspeed_stage() {
    case "$1" in
        z2) echo "2" ;;
        z3) echo "3" ;;
        *) return 1 ;;
    esac
}

DEEPSPEED_LINE=""
if [ "$USE_DEEPSPEED" != "none" ]; then
    # Find the LLaMA-Factory checkout. An LLAMAFACTORY_DIR already set in the
    # environment wins; otherwise try $SCRATCH, then $HOME, then a directory
    # next to the location of the llamafactory-cli executable.
    LLAMAFACTORY_DIR="${LLAMAFACTORY_DIR:-}"
    if [ -z "$LLAMAFACTORY_DIR" ]; then
        if [ -n "${SCRATCH:-}" ] && [ -d "$SCRATCH/LLaMA-Factory" ]; then
            LLAMAFACTORY_DIR="$SCRATCH/LLaMA-Factory"
        elif [ -d "$HOME/LLaMA-Factory" ]; then
            LLAMAFACTORY_DIR="$HOME/LLaMA-Factory"
        elif CLI_PATH=$(command -v llamafactory-cli 2> /dev/null); then
            # Try to infer the checkout from where llamafactory-cli is installed
            if [ -d "$(dirname "$CLI_PATH")/../LLaMA-Factory" ]; then
                LLAMAFACTORY_DIR="$(cd "$(dirname "$CLI_PATH")/../LLaMA-Factory" && pwd)"
            fi
        fi
    fi

    if [ -n "$LLAMAFACTORY_DIR" ] && [ -d "$LLAMAFACTORY_DIR/examples/deepspeed" ]; then
        # Pick the DeepSpeed config file; unknown stage names fall back to ZeRO-3.
        # DS_STAGE keeps the stage actually used so the log line below stays
        # accurate even after the fallback.
        if ! DS_STAGE=$(_deepspeed_stage "$USE_DEEPSPEED"); then
            echo "⚠️  Warning: Unknown DeepSpeed stage '$USE_DEEPSPEED', using ZeRO-3"
            DS_STAGE=3
        fi
        DS_CONFIG="examples/deepspeed/ds_z${DS_STAGE}_config.json"

        # Only enable DeepSpeed when the config file exists; use its absolute path
        DS_CONFIG_ABS="$LLAMAFACTORY_DIR/$DS_CONFIG"
        if [ -f "$DS_CONFIG_ABS" ]; then
            DEEPSPEED_LINE="deepspeed: ${DS_CONFIG_ABS}"
            echo "  Using DeepSpeed ZeRO-${DS_STAGE} for memory optimization"
            echo "  DeepSpeed config: $DS_CONFIG_ABS"
        else
            echo "⚠️  Warning: DeepSpeed config file not found: $DS_CONFIG_ABS"
            echo "  Continuing without DeepSpeed"
        fi
    else
        echo "⚠️  Warning: LLaMA-Factory directory not found, cannot use DeepSpeed"
        echo "  Continuing without DeepSpeed"
    fi
fi

# Create the dataset config directory (kept under the llamafactory directory
# next to this script)
DATASET_CONFIG_DIR="${SCRIPT_DIR}/llamafactory/data"
mkdir -p "$DATASET_CONFIG_DIR"

# Chat template written into the config. Defaults to the previously hard-coded
# "gpt_oss"; export SFT_TEMPLATE before running to use a different one.
# NOTE(review): BASE_MODEL at the top of this script is a Qwen3 model while the
# default template is "gpt_oss" — this looks like a mismatch; confirm which
# template is intended.
SFT_TEMPLATE="${SFT_TEMPLATE:-gpt_oss}"

# Write the LLaMA-Factory training YAML.
# Arguments: $1 - path of the YAML file to create (overwritten if present)
# Globals (read): MODEL_NAME, LORA_R, LORA_ALPHA, DEEPSPEED_LINE,
#   DATASET_CONFIG_DIR, SFT_TEMPLATE, MAX_SEQ_LENGTH, EVAL_SPLIT, OUTPUT_DIR,
#   BATCH_SIZE, GRAD_ACCUM, LEARNING_RATE, NUM_EPOCHS
# Returns: non-zero when the file cannot be written
write_train_config() {
    # DEEPSPEED_LINE is either empty (leaves a blank line) or a complete
    # "deepspeed: <path>" entry.
    cat > "$1" << EOF
### model
model_name_or_path: ${MODEL_NAME}
trust_remote_code: true

### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: all
lora_rank: ${LORA_R}
lora_alpha: ${LORA_ALPHA}
lora_dropout: 0.05
${DEEPSPEED_LINE:-}

### dataset
dataset: rewrite_dataset
dataset_dir: ${DATASET_CONFIG_DIR}
template: ${SFT_TEMPLATE}
cutoff_len: ${MAX_SEQ_LENGTH}
overwrite_cache: false
preprocessing_num_workers: 32
val_size: ${EVAL_SPLIT}

### output
output_dir: ${OUTPUT_DIR}
logging_steps: 10
save_strategy: epoch
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: ${BATCH_SIZE}
per_device_eval_batch_size: ${BATCH_SIZE}
gradient_accumulation_steps: ${GRAD_ACCUM}
learning_rate: ${LEARNING_RATE}
num_train_epochs: ${NUM_EPOCHS}
lr_scheduler_type: cosine
warmup_ratio: 0.1
weight_decay: 0.01
bf16: true
fp16: false
gradient_checkpointing: true
ddp_find_unused_parameters: false

### eval
eval_strategy: epoch
EOF
}

# Create the YAML config file
write_train_config "$CONFIG_FILE"

# Path of the dataset registry file (dataset_info.json) that is written after
# the conversion step; LLaMA-Factory is pointed at its directory via dataset_dir.
DATASET_CONFIG_FILE="${DATASET_CONFIG_DIR}/dataset_info.json"
echo "Creating dataset config file: $DATASET_CONFIG_FILE"

# Convert the training data into the layout used for LLaMA-Factory.
# The converted file is stored under the llamafactory/data directory next to
# this script, with the direction encoded in its name.
CONVERTED_DATA_DIR="${SCRIPT_DIR}/llamafactory/data"
mkdir -p "$CONVERTED_DATA_DIR"
CONVERTED_DATA="${CONVERTED_DATA_DIR}/rewrite_dataset_${DIRECTION}.jsonl"

echo "Converting data format for LLaMA-Factory..."
CONVERT_SCRIPT="${SCRIPT_DIR}/convert_data_for_sft.py"

# The converter must sit next to this script; stop if it is missing.
[ -f "$CONVERT_SCRIPT" ] || {
    echo "❌ Error: convert_data_for_sft.py not found at $CONVERT_SCRIPT"
    exit 1
}
echo "  Using script: $CONVERT_SCRIPT"

# Arguments handed to the converter, kept in an array so each value stays a
# single word regardless of its content.
convert_args=(
    --input "$TRAIN_DATA"
    --output "$CONVERTED_DATA"
    --direction "$DIRECTION"
    --data_format "$DATA_FORMAT"
)
python3 "$CONVERT_SCRIPT" "${convert_args[@]}"

# Write the dataset registry file (dataset_info.json), replacing any existing
# one. It registers the converted file under the name "rewrite_dataset".
# NOTE(review): "file_dir" is written as before; confirm LLaMA-Factory actually
# reads this key. Likewise confirm that the converter's "messages" records match
# what the "sharegpt" formatting expects without extra tag settings.
printf '{
  "rewrite_dataset": {
    "file_name": "rewrite_dataset_%s.jsonl",
    "file_dir": "%s",
    "formatting": "sharegpt",
    "columns": {
      "messages": "messages"
    }
  }
}
' "$DIRECTION" "$CONVERTED_DATA_DIR" > "$DATASET_CONFIG_FILE"

printf '\n%s\n\n' "Starting training with LLaMA-Factory..."

# Run training with the generated config. With `set -e` in effect, a failing
# run stops the script here, before the completion banner is printed.
llamafactory-cli train "$CONFIG_FILE"

# Completion banner and a hint for the follow-up evaluation step.
cat << DONE

========================================================================
✅ Training completed!
========================================================================
Model saved to: $OUTPUT_DIR

To test the model, run:
  ./test_trained_model.sh --test-file data_generation/pku_saferlhf_test_final.jsonl
DONE

