#!/bin/bash
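#
# 合并 LoRA 适配器脚本 — merges a LoRA adapter into its base model with `llamafactory-cli export`.
# 用法 / Usage:
#   ./merge_increase_rewriter.sh
#   ./merge_increase_rewriter.sh --adapter_path path/to/adapter --export_dir path/to/export
#   ./merge_increase_rewriter.sh --direction decrease --method dpo --checkpoint 788
#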

set -e

# Absolute directory of this script, and the project root (its parent directory).
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(cd "$SCRIPT_DIR/.." && pwd)"

# Work from the project root so the relative model/... paths used below resolve correctly.
cd "$PROJECT_ROOT"

# Temporary LlamaFactory YAML config, created next to this script.
CONFIG_FILE=$(mktemp "${SCRIPT_DIR}/merge_increase_rewriter_XXXXXX.yaml")
# Delete the temp file on EVERY exit path. Argument parsing and path validation
# can exit before the main `trap cleanup EXIT` further down is installed; without
# this early handler those exits would leave the file behind in SCRIPT_DIR.
# The later `trap cleanup EXIT` simply replaces this handler.
trap 'rm -f -- "$CONFIG_FILE"' EXIT
# NOTE(review): presumably read by llamafactory-cli to skip its dependency
# version check — confirm against the LlamaFactory version in use.
export DISABLE_VERSION_CHECK=1

# ============================================================================
# Model configuration - change BASE_MODEL here to switch models
# ============================================================================
BASE_MODEL="Qwen/Qwen3-4B-Instruct-2507"

# Names derived from BASE_MODEL:
#   MODEL_DISPLAY_NAME - last path component of BASE_MODEL (used in log output)
#   MODEL_ID           - MODEL_DISPLAY_NAME lowercased (used to build model/ paths)
MODEL_DISPLAY_NAME="${BASE_MODEL##*/}"
MODEL_ID=$(printf '%s' "$MODEL_DISPLAY_NAME" | tr '[:upper:]' '[:lower:]')

# ============================================================================
# Merge parameters (each one can be overridden by its command-line option)
# ============================================================================
DIRECTION="increase"        # "decrease" or "increase"
METHOD="sft"                # "sft" | "dpo"
ADAPTER_TYPE="lora"
CHECKPOINT=""               # checkpoint number, e.g. "788" or "checkpoint-788"
ADAPTER_PATH=""             # empty: built automatically from the parameters above
TEMPLATE="qwen3"            # chat template name, e.g. "qwen3", "llama3"
INFER_DTYPE="bfloat16"      # dtype of the merged weights: "auto" (usually resolves to bfloat16), "float32" (highest precision), "float16" or "bfloat16"; "bfloat16" or "float32" recommended for consistency
EXPORT_DIR=""               # empty: generated automatically
EXPORT_SIZE=5               # max size (GB) of each shard file; bigger models are split into several files (raise to e.g. 10 for fewer shards)
EXPORT_DEVICE="auto"        # "cpu" or "auto"; only affects the compute device, not the resulting weights
EXPORT_LEGACY_FORMAT=false  # true or false; only affects the save format, not model quality

# Parse command-line options. Each option overrides the matching default above;
# every option except -h/--help takes exactly one value.

# Print the usage line (lists every accepted option).
print_usage() {
    echo "Usage: $0 [--base_model MODEL] [--adapter_path PATH] [--direction DIRECTION] [--method METHOD] [--adapter_type TYPE] [--checkpoint CHECKPOINT] [--export_dir DIR] [--template TEMPLATE] [--infer_dtype DTYPE] [--export_size SIZE] [--export_device DEVICE] [--export_legacy_format true/false]"
}

# Exit with an error unless option $1 is followed by a value.
# Call as `require_value "$@"` so that $# counts the option plus what follows it.
# Without this check a trailing value-less option makes `shift 2` fail, which
# either kills the script silently (set -e) or loops forever.
require_value() {
    if [[ $# -lt 2 ]]; then
        echo "Error: option $1 requires a value" >&2
        print_usage >&2
        exit 1
    fi
}

while [[ $# -gt 0 ]]; do
    case "$1" in
        -h|--help)
            print_usage
            exit 0
            ;;
        --base_model|--base-model)
            require_value "$@"
            BASE_MODEL="$2"
            # Re-derive the names computed from BASE_MODEL: the display name is the
            # last path component, MODEL_ID its lowercase form (used in model/ paths).
            MODEL_DISPLAY_NAME="${BASE_MODEL##*/}"
            MODEL_ID=$(printf '%s' "$MODEL_DISPLAY_NAME" | tr '[:upper:]' '[:lower:]')
            shift 2
            ;;
        --adapter_path|--adapter-path)
            require_value "$@"
            ADAPTER_PATH="$2"
            shift 2
            ;;
        --direction)
            require_value "$@"
            DIRECTION="$2"
            shift 2
            ;;
        --method)
            require_value "$@"
            METHOD="$2"
            shift 2
            ;;
        --adapter-type|--adapter_type)
            require_value "$@"
            ADAPTER_TYPE="$2"
            shift 2
            ;;
        --checkpoint)
            require_value "$@"
            CHECKPOINT="$2"
            shift 2
            ;;
        --template)
            require_value "$@"
            TEMPLATE="$2"
            shift 2
            ;;
        --infer_dtype|--infer-dtype)
            require_value "$@"
            INFER_DTYPE="$2"
            shift 2
            ;;
        --export_dir|--export-dir)
            require_value "$@"
            EXPORT_DIR="$2"
            shift 2
            ;;
        --export_size|--export-size)
            require_value "$@"
            EXPORT_SIZE="$2"
            shift 2
            ;;
        --export_device|--export-device)
            require_value "$@"
            EXPORT_DEVICE="$2"
            shift 2
            ;;
        --export_legacy_format|--export-legacy-format)
            require_value "$@"
            EXPORT_LEGACY_FORMAT="$2"
            shift 2
            ;;
        *)
            echo "Unknown option: $1" >&2
            print_usage >&2
            exit 1
            ;;
    esac
done

# Resolve ADAPTER_PATH when it was not given explicitly. Searched layout:
#   model/<MODEL_ID>/<ADAPTER_TYPE>/<DIRECTION>/<METHOD>/checkpoint-<N>
# With CHECKPOINT set, that checkpoint directory is used as-is. Otherwise:
#   1. checkpoint-<global_step>, with global_step read from the run directory's
#      trainer_state.json (avoids picking a stale checkpoint left by an older run);
#   2. else the most recently modified checkpoint-* directory;
#   3. else the run directory itself (with a warning).
# Sets ADAPTER_PATH and AUTO_DETECTED_CHECKPOINT (how the checkpoint was chosen).

# Print the most recently modified checkpoint-* directory directly under $1, or
# nothing when there is none. Non-directory matches are ignored.
latest_checkpoint_by_mtime() {
    local parent="$1"
    local candidate newest=""
    for candidate in "$parent"/checkpoint-*; do
        # An unmatched glob stays literal and fails this test, so no nullglob is needed.
        [[ -d "$candidate" ]] || continue
        if [[ -z "$newest" || "$candidate" -nt "$newest" ]]; then
            newest="$candidate"
        fi
    done
    printf '%s' "$newest"
}

# Print the "global_step" value from the trainer_state.json at $1, or nothing if
# it cannot be read. The path goes through argv instead of being spliced into the
# Python source, so quotes or backslashes in the path cannot break the code.
read_global_step() {
    python3 -c 'import json, sys; print(json.load(open(sys.argv[1], "r")).get("global_step", ""))' "$1" 2>/dev/null || true
}

AUTO_DETECTED_CHECKPOINT=""  # records how the checkpoint was chosen, for the configuration summary

if [ -n "$ADAPTER_PATH" ] && [ -n "$CHECKPOINT" ]; then
    # An explicit adapter path wins; make it visible that CHECKPOINT has no effect.
    echo "⚠️  Warning: CHECKPOINT ($CHECKPOINT) is ignored because ADAPTER_PATH is set: $ADAPTER_PATH" >&2
fi

if [ -z "$ADAPTER_PATH" ]; then
    # Run directory (without the checkpoint component).
    BASE_ADAPTER_PATH="model/${MODEL_ID}/${ADAPTER_TYPE}/${DIRECTION}/${METHOD}"

    if [ -n "$CHECKPOINT" ]; then
        # User-selected checkpoint; accept both "788" and "checkpoint-788".
        CHECKPOINT_DIRNAME="checkpoint-${CHECKPOINT#checkpoint-}"
        ADAPTER_PATH="$BASE_ADAPTER_PATH/$CHECKPOINT_DIRNAME"
    elif [ ! -d "$BASE_ADAPTER_PATH" ]; then
        echo "❌ Error: Adapter base path not found: $BASE_ADAPTER_PATH" >&2
        exit 1
    else
        GLOBAL_STEP=""
        TRAINER_STATE_JSON="$BASE_ADAPTER_PATH/trainer_state.json"
        if [ -f "$TRAINER_STATE_JSON" ]; then
            GLOBAL_STEP=$(read_global_step "$TRAINER_STATE_JSON")
        fi

        if [[ "$GLOBAL_STEP" =~ ^[0-9]+$ ]] && [ -d "$BASE_ADAPTER_PATH/checkpoint-$GLOBAL_STEP" ]; then
            CHECKPOINT_DIRNAME="checkpoint-$GLOBAL_STEP"
            ADAPTER_PATH="$BASE_ADAPTER_PATH/$CHECKPOINT_DIRNAME"
            AUTO_DETECTED_CHECKPOINT="$CHECKPOINT_DIRNAME (from trainer_state.json, global_step=$GLOBAL_STEP)"
        else
            # Fallback: newest checkpoint directory by modification time.
            LATEST_CHECKPOINT=$(latest_checkpoint_by_mtime "$BASE_ADAPTER_PATH")
            if [ -n "$LATEST_CHECKPOINT" ]; then
                CHECKPOINT_DIRNAME="${LATEST_CHECKPOINT##*/}"
                ADAPTER_PATH="$LATEST_CHECKPOINT"
                AUTO_DETECTED_CHECKPOINT="$CHECKPOINT_DIRNAME (by mtime)"
            else
                echo "⚠️  Warning: No checkpoint found in $BASE_ADAPTER_PATH" >&2
                echo "    Using directory directly (may not work for LoRA merge)" >&2
                ADAPTER_PATH="$BASE_ADAPTER_PATH"
                AUTO_DETECTED_CHECKPOINT="(none - using directory)"
            fi
        fi
    fi
fi

# Default export location when EXPORT_DIR is empty or unset:
#   model/<DIRECTION>_rewriter/<METHOD>
: "${EXPORT_DIR:=model/${DIRECTION}_rewriter/${METHOD}}"

# EXIT handler: remove the temporary config file if it still exists and report it.
cleanup() {
    [[ -f "$CONFIG_FILE" ]] || return 0
    rm -f -- "$CONFIG_FILE"
    echo "Cleaned up temporary config file: $CONFIG_FILE"
}

# Run cleanup whenever the script exits, whether it succeeds or fails.
trap cleanup EXIT

# Announce the merge, verify that the resolved adapter directory exists, then
# print the effective configuration. The existence check runs before the
# summary, so an invalid path aborts without printing it.
printf '%s\n' \
    "========================================================================" \
    "Merging LoRA Adapter - $MODEL_DISPLAY_NAME" \
    "========================================================================" \
    ""

[[ -d "$ADAPTER_PATH" ]] || {
    printf '%s\n' "❌ Error: Adapter path not found: $ADAPTER_PATH"
    exit 1
}

printf '%s\n' \
    "Configuration:" \
    "  Base model: $BASE_MODEL" \
    "  Adapter path: $ADAPTER_PATH"
# Say where the checkpoint came from: an explicit CHECKPOINT, else the auto-detection result.
if [[ -n "$CHECKPOINT" ]]; then
    printf '  Checkpoint: %s (manual override)\n' "$CHECKPOINT"
elif [[ -n "$AUTO_DETECTED_CHECKPOINT" ]]; then
    printf '  Checkpoint: %s (auto-detected)\n' "$AUTO_DETECTED_CHECKPOINT"
fi
printf '%s\n' \
    "  Template: $TEMPLATE" \
    "  Infer dtype: $INFER_DTYPE (affects merged model precision)" \
    "  Export directory: $EXPORT_DIR" \
    "  Export size: ${EXPORT_SIZE}GB" \
    "  Export device: $EXPORT_DEVICE" \
    "  Export legacy format: $EXPORT_LEGACY_FORMAT" \
    ""

printf '%s\n' "✅ Adapter path found: $ADAPTER_PATH" ""

# Write the LlamaFactory export config into the temporary file, one YAML line per entry.
echo "Creating temporary configuration file: $CONFIG_FILE"
{
    printf '%s\n' \
        "### Note: DO NOT use quantized model or quantization_bit when merging lora adapters" \
        "" \
        "### model" \
        "model_name_or_path: ${BASE_MODEL}" \
        "adapter_name_or_path: ${ADAPTER_PATH}" \
        "template: ${TEMPLATE}" \
        "finetuning_type: lora" \
        "infer_dtype: ${INFER_DTYPE}" \
        "trust_remote_code: true" \
        "" \
        "### export" \
        "export_dir: ${EXPORT_DIR}" \
        "export_size: ${EXPORT_SIZE}" \
        "export_device: ${EXPORT_DEVICE}  # choices: [cpu, auto]" \
        "export_legacy_format: ${EXPORT_LEGACY_FORMAT}"
} > "$CONFIG_FILE"

echo "✅ Configuration file created"
echo ""

# Run the merge with LlamaFactory and report the outcome.
# The exit status is captured with `|| EXIT_CODE=$?`: because `set -e` is active,
# a bare failing command would terminate the script on the spot, and the failure
# report below would never be printed.
echo "Running llamafactory-cli export..."
echo ""

EXIT_CODE=0
llamafactory-cli export "$CONFIG_FILE" || EXIT_CODE=$?

if [ "$EXIT_CODE" -eq 0 ]; then
    echo ""
    echo "========================================================================"
    echo "✅ Merge completed successfully!"
    echo "========================================================================"
    echo ""
    echo "Merged model location: $EXPORT_DIR"
    echo ""
else
    echo ""
    echo "❌ Error: Merge failed with exit code $EXIT_CODE" >&2
    # Propagate the export tool's status to the caller.
    exit "$EXIT_CODE"
fi

