#!/bin/bash
#
# DPO 训练脚本 - 基于 LLaMA-Factory
# 用法: ./train_dpo_llamafactory.sh
#

set -e

# ============================================================================
# Model selection - changing BASE_MODEL is all that is needed to switch models
# ============================================================================
BASE_MODEL="Qwen/Qwen3-4B-Instruct-2507"

# Display name used in log output: whatever follows the last "/" in BASE_MODEL
# (the whole value when it contains no "/").
MODEL_DISPLAY_NAME="${BASE_MODEL##*/}"

# Identifier used when building paths: the display name folded to lowercase.
MODEL_ID="$(printf '%s\n' "$MODEL_DISPLAY_NAME" | tr '[:upper:]' '[:lower:]')"

# Startup banner.
banner_rule="========================================================================"
printf '%s\n' \
    "$banner_rule" \
    "$MODEL_DISPLAY_NAME DPO Training (LLaMA-Factory)" \
    "$banner_rule" \
    ""

# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
MODEL_NAME="$BASE_MODEL"
DIRECTION="decrease"  # either "decrease" or "increase"

# Input / output locations, derived from the model id and the direction.
SFT_MODEL_PATH="model/${MODEL_ID}/lora/${DIRECTION}/sft"
OUTPUT_DIR="model/${MODEL_ID}/lora/${DIRECTION}/dpo"
TRAIN_DATA="data/pku_saferlhf_filtered.jsonl"
DATA_FORMAT="jsonl"  # either "json" or "jsonl"

# Optimisation hyper-parameters.
NUM_EPOCHS=4
BATCH_SIZE=4
GRAD_ACCUM=4
LEARNING_RATE=5e-5
BETA=0.1  # DPO beta
MAX_SEQ_LENGTH=2048
EVAL_SPLIT=0.1  # fraction of the training set held out for validation (10%)

# LoRA adapter shape.
LORA_R=8
LORA_ALPHA=16

# DeepSpeed (used to lower per-GPU memory); "none" disables it in favour of
# standard DDP.
USE_DEEPSPEED="z2"

# Disable FP8 (for GPUs without sm90 support).
export NCCL_FP8_DISABLE=1 TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=0

# Skip the version check (in case the installed datasets version is
# incompatible).
export DISABLE_VERSION_CHECK=1

# Quieten the warnings printed after an interrupted distributed run, and keep
# NCCL output at warning level and above.
export PYTHONWARNINGS="ignore::UserWarning"
export NCCL_DEBUG=WARN

# Print a summary of the run configuration. The SFT line shows the adapter
# path when one is configured and a placeholder otherwise.
sft_summary="(none - training from base model)"
if [[ -n "$SFT_MODEL_PATH" ]]; then
    sft_summary="$SFT_MODEL_PATH"
fi

cat << SUMMARY_EOF
Configuration:
  Base model: $MODEL_NAME
  SFT model path: $sft_summary
  Data: $TRAIN_DATA
  Data format: $DATA_FORMAT
  Output: $OUTPUT_DIR
  Direction: $DIRECTION
    - Chosen: $DIRECTION direction (preferred)
    - Rejected: opposite direction
    - Generating up to 3 pairs per sample from increase/unchange/decrease fields
  Epochs: $NUM_EPOCHS
  Batch size: $BATCH_SIZE
  Gradient accumulation: $GRAD_ACCUM
  Learning rate: $LEARNING_RATE
  DPO beta: $BETA
  LoRA r: $LORA_R, alpha: $LORA_ALPHA
  Eval split: $EVAL_SPLIT (from training set)

SUMMARY_EOF

# Resolve the directory containing this script and the project root (its
# parent), then run from the project root so relative paths resolve correctly.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"
cd "$PROJECT_ROOT"

# ---------------------------------------------------------------------------
# Pre-flight checks. Fatal diagnostics go to stderr so that they are not lost
# when stdout is redirected or piped; each failure exits with status 1.
# ---------------------------------------------------------------------------

# The training data file must exist.
if [[ ! -f "$TRAIN_DATA" ]]; then
    echo "❌ Error: Training data not found: $TRAIN_DATA" >&2
    exit 1
fi

# When an SFT adapter path is configured it must be an existing directory;
# an empty path means "train from the base model".
if [[ -n "$SFT_MODEL_PATH" ]]; then
    if [[ ! -d "$SFT_MODEL_PATH" ]]; then
        {
            echo "❌ Error: SFT model not found: $SFT_MODEL_PATH"
            echo "  Please train SFT model first using train_sft_llamafactory.sh"
            echo "  Or set SFT_MODEL_PATH to empty string to train from base model"
        } >&2
        exit 1
    fi
else
    echo "ℹ️  No SFT model specified, training from base model"
fi

# llamafactory-cli must be on PATH.
if ! command -v llamafactory-cli > /dev/null 2>&1; then
    {
        echo "❌ Error: llamafactory-cli not found"
        echo "Please install LLaMA-Factory first:"
        echo "  git clone https://github.com/hiyouga/LLaMA-Factory.git"
        echo "  cd LLaMA-Factory"
        echo "  pip install -e .[metrics]"
    } >&2
    exit 1
fi

# Make sure the output directory exists.
mkdir -p "$OUTPUT_DIR"

# Training config generated for this run; it is written into the output
# directory.
CONFIG_FILE="${OUTPUT_DIR}/train_config.yaml"
echo "Creating LLaMA-Factory config file: $CONFIG_FILE"

#######################################
# Print the number of GPUs available to this run.
# Precedence: SLURM_GPUS_ON_NODE, then SLURM_GPUS, then the number of lines
# printed by `nvidia-smi --list-gpus`. Falls back to 1 when nvidia-smi is
# missing, fails, or lists nothing.
# (A plain `nvidia-smi ... | wc -l || echo 1` never reaches the fallback,
# because the pipeline's status is that of `wc`, which succeeds and prints 0.)
# Globals:  SLURM_GPUS_ON_NODE, SLURM_GPUS (read)
# Outputs:  the GPU count on stdout
# Returns:  0
#######################################
detect_num_gpus() {
    local listing line_count

    if [[ -n "${SLURM_GPUS_ON_NODE:-}" ]]; then
        printf '%s\n' "$SLURM_GPUS_ON_NODE"
        return 0
    fi
    if [[ -n "${SLURM_GPUS:-}" ]]; then
        printf '%s\n' "$SLURM_GPUS"
        return 0
    fi

    # Test the command's own exit status (inside `if`, so `set -e` is not
    # triggered) instead of the status of a pipeline ending in `wc`.
    if listing="$(nvidia-smi --list-gpus 2> /dev/null)" && [[ -n "$listing" ]]; then
        line_count="$(wc -l <<< "$listing")"
        # Arithmetic expansion strips the padding some `wc` builds emit.
        printf '%s\n' "$((line_count))"
    else
        printf '1\n'
    fi
    return 0
}

NUM_GPUS="$(detect_num_gpus)"
echo "Detected $NUM_GPUS GPU(s)"
echo ""

#######################################
# Decide whether DeepSpeed is used and, if so, which stock LLaMA-Factory
# DeepSpeed config file the training config should reference.
#
# A LLaMA-Factory checkout is looked up in this order: the LLAMAFACTORY_DIR
# value already set, $SCRATCH/LLaMA-Factory, $HOME/LLaMA-Factory, and finally
# a LLaMA-Factory directory next to the parent of llamafactory-cli's location.
# Every "cannot use DeepSpeed" outcome is a warning, not an error: training
# continues without DeepSpeed.
#
# Globals:
#   USE_DEEPSPEED    (read)    "none", "z2" or "z3"; anything else falls back
#                              to ZeRO-3 with a warning
#   SCRATCH, HOME    (read)
#   LLAMAFACTORY_DIR (read/written)
#   DEEPSPEED_LINE   (written) "deepspeed: <abs path>" or "" when disabled
# Outputs:  progress and warning messages on stdout
# Returns:  0
#######################################
configure_deepspeed() {
    local stage ds_config cli_path

    DEEPSPEED_LINE=""
    if [[ "$USE_DEEPSPEED" == "none" ]]; then
        return 0
    fi

    # Locate the LLaMA-Factory checkout unless the caller provided one.
    LLAMAFACTORY_DIR="${LLAMAFACTORY_DIR:-}"
    if [[ -z "$LLAMAFACTORY_DIR" ]]; then
        if [[ -n "${SCRATCH:-}" && -d "$SCRATCH/LLaMA-Factory" ]]; then
            LLAMAFACTORY_DIR="$SCRATCH/LLaMA-Factory"
        elif [[ -d "$HOME/LLaMA-Factory" ]]; then
            LLAMAFACTORY_DIR="$HOME/LLaMA-Factory"
        elif cli_path="$(command -v llamafactory-cli)"; then
            if [[ -d "$(dirname "$cli_path")/../LLaMA-Factory" ]]; then
                LLAMAFACTORY_DIR="$(cd "$(dirname "$cli_path")/../LLaMA-Factory" && pwd)"
            fi
        fi
    fi

    if [[ -z "$LLAMAFACTORY_DIR" || ! -d "$LLAMAFACTORY_DIR/examples/deepspeed" ]]; then
        echo "⚠️  Warning: LLaMA-Factory directory not found, cannot use DeepSpeed"
        echo "  Continuing without DeepSpeed"
        return 0
    fi

    # Map the setting to a ZeRO stage. The stage actually selected is kept in
    # its own variable so the log line below is right for the fallback too.
    case "$USE_DEEPSPEED" in
        z2) stage=2 ;;
        z3) stage=3 ;;
        *)
            echo "⚠️  Warning: Unknown DeepSpeed stage '$USE_DEEPSPEED', using ZeRO-3"
            stage=3
            ;;
    esac

    ds_config="$LLAMAFACTORY_DIR/examples/deepspeed/ds_z${stage}_config.json"
    if [[ -f "$ds_config" ]]; then
        DEEPSPEED_LINE="deepspeed: ${ds_config}"
        echo "  Using DeepSpeed ZeRO-${stage} for memory optimization"
    else
        echo "⚠️  Warning: DeepSpeed config file not found: $ds_config"
        echo "  Continuing without DeepSpeed"
    fi
    return 0
}

configure_deepspeed

# Dataset config directory (kept inside the script's llamafactory directory).
DATASET_CONFIG_DIR="${SCRIPT_DIR}/llamafactory/data"
mkdir -p "$DATASET_CONFIG_DIR"

#######################################
# Render the LLaMA-Factory training config (YAML) into $CONFIG_FILE.
#
# The adapter_name_or_path / ref_model_adapters lines are emitted only when
# SFT_MODEL_PATH is non-empty, and the deepspeed line only when
# DEEPSPEED_LINE is non-empty; otherwise those positions are blank lines.
#
# NOTE(review): the chat template defaults to "llama3" although BASE_MODEL in
# this script names a Qwen model. The default is left unchanged here, because
# an existing SFT adapter may have been trained with it; export CHAT_TEMPLATE
# (for example CHAT_TEMPLATE=qwen3) to override. Confirm which template the
# model and the SFT stage actually expect.
# NOTE(review): values are written unquoted. Some YAML 1.1 loaders read a
# float without a decimal point (such as 5e-5) as a string - confirm that the
# LLaMA-Factory version in use parses learning_rate as a number.
#
# Globals (read): CONFIG_FILE, MODEL_NAME, SFT_MODEL_PATH, LORA_R, LORA_ALPHA,
#   BETA, DEEPSPEED_LINE, DATASET_CONFIG_DIR, CHAT_TEMPLATE (optional),
#   MAX_SEQ_LENGTH, EVAL_SPLIT, OUTPUT_DIR, BATCH_SIZE, GRAD_ACCUM,
#   LEARNING_RATE, NUM_EPOCHS
# Returns: the exit status of writing the file
#######################################
write_train_config() {
    local chat_template="${CHAT_TEMPLATE:-llama3}"

    cat > "$CONFIG_FILE" << EOF
### model
model_name_or_path: ${MODEL_NAME}
$(if [ -n "$SFT_MODEL_PATH" ]; then echo "adapter_name_or_path: ${SFT_MODEL_PATH}"; fi)
trust_remote_code: true

### method
stage: dpo
do_train: true
finetuning_type: lora
lora_target: all
lora_rank: ${LORA_R}
lora_alpha: ${LORA_ALPHA}
lora_dropout: 0.05
pref_beta: ${BETA}
pref_loss: sigmoid
ref_model: ${MODEL_NAME}
$(if [ -n "$SFT_MODEL_PATH" ]; then echo "ref_model_adapters: ${SFT_MODEL_PATH}"; fi)
$(if [ -n "$DEEPSPEED_LINE" ]; then echo "$DEEPSPEED_LINE"; fi)

### dataset
dataset: dpo_dataset
dataset_dir: ${DATASET_CONFIG_DIR}
template: ${chat_template}
cutoff_len: ${MAX_SEQ_LENGTH}
overwrite_cache: false
preprocessing_num_workers: 16
val_size: ${EVAL_SPLIT}

### output
output_dir: ${OUTPUT_DIR}
logging_steps: 10
save_strategy: epoch
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: ${BATCH_SIZE}
per_device_eval_batch_size: ${BATCH_SIZE}
gradient_accumulation_steps: ${GRAD_ACCUM}
learning_rate: ${LEARNING_RATE}
num_train_epochs: ${NUM_EPOCHS}
lr_scheduler_type: cosine
warmup_ratio: 0.1
weight_decay: 0.01
bf16: true
fp16: false
gradient_checkpointing: true
ddp_find_unused_parameters: false

### eval
eval_strategy: epoch
EOF
}

write_train_config

# Dataset registry file that will describe the converted DPO dataset.
DATASET_CONFIG_FILE="${DATASET_CONFIG_DIR}/dataset_info.json"
echo "Creating dataset config file: $DATASET_CONFIG_FILE"

# Convert the raw training data into the layout LLaMA-Factory expects for DPO.
# The converted file is named after the direction.
CONVERTED_DATA_DIR="${SCRIPT_DIR}/llamafactory/data"
mkdir -p "$CONVERTED_DATA_DIR"
CONVERTED_DATA="${CONVERTED_DATA_DIR}/dpo_dataset_${DIRECTION}.jsonl"

echo "Converting data format for LLaMA-Factory DPO..."
echo "  This will generate up to 3 DPO pairs per sample from increase/unchange/decrease fields"
CONVERT_SCRIPT="${SCRIPT_DIR}/convert_data_for_dpo.py"
if [[ ! -f "$CONVERT_SCRIPT" ]]; then
    echo "❌ Error: convert_data_for_dpo.py not found at $CONVERT_SCRIPT" >&2
    exit 1
fi
echo "  Using script: $CONVERT_SCRIPT"

# Run the converter inside `if` so that a failure yields a clear message on
# stderr instead of a bare `set -e` exit.
if ! python3 "$CONVERT_SCRIPT" \
    --input "$TRAIN_DATA" \
    --output "$CONVERTED_DATA" \
    --direction "$DIRECTION" \
    --data_format "$DATA_FORMAT"; then
    echo "❌ Error: Data conversion failed: $CONVERT_SCRIPT" >&2
    exit 1
fi

# A missing or empty converted file would only surface later as a confusing
# training-time error, so stop here.
if [[ ! -s "$CONVERTED_DATA" ]]; then
    echo "❌ Error: Converted data is missing or empty: $CONVERTED_DATA" >&2
    exit 1
fi

# (Re)write the dataset registry with a single "dpo_dataset" entry pointing at
# the converted file. Per the LLaMA-Factory documentation, DPO datasets use the
# ShareGPT format with "ranking": true.
# NOTE(review): "file_dir" does not look like a documented dataset_info.json
# key; the file is presumably resolved from "file_name" relative to the
# dataset directory - confirm.
printf '%s\n' \
    '{' \
    '  "dpo_dataset": {' \
    "    \"file_name\": \"dpo_dataset_${DIRECTION}.jsonl\"," \
    "    \"file_dir\": \"${CONVERTED_DATA_DIR}\"," \
    '    "formatting": "sharegpt",' \
    '    "ranking": true,' \
    '    "columns": {' \
    '      "messages": "conversations",' \
    '      "chosen": "chosen",' \
    '      "rejected": "rejected"' \
    '    }' \
    '  }' \
    '}' > "$DATASET_CONFIG_FILE"

printf '\n%s\n\n' "Starting training with LLaMA-Factory..."

# Launch training with the generated config; with `set -e` in effect a failed
# run stops the script before the completion banner below.
llamafactory-cli train "$CONFIG_FILE"

# Completion banner and a hint for trying out the trained adapter.
cat << DONE_EOF

========================================================================
✅ Training completed!
========================================================================
Model saved to: $OUTPUT_DIR

To test the model, run:
  ./test_trained_model.sh --adapter_path $OUTPUT_DIR
DONE_EOF

