#!/bin/bash

# Usage: bash recipe/fileagent/run_extracted_bench.sh
# Trace every command to stderr as it is executed.
set -x

# Raise the open-file-descriptor limit for this shell and the processes it starts.
ulimit -n 65535
# Drop any inherited RAY_ADDRESS so child processes do not see it
# (presumably to avoid attaching to a pre-existing Ray cluster — confirm).
unset RAY_ADDRESS
export VLLM_USE_RAY=0

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# 🔧 Environment configuration (must come first so that Ray workers inherit the variables)
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

# LLM Judge endpoint; each value can be overridden from the caller's environment.
export LLM_JUDGE_IP="${LLM_JUDGE_IP:-2605:340:cd51:4b00:140c:e92f:caf3:f57a}"
export LLM_JUDGE_PORT="${LLM_JUDGE_PORT:-18908}"
export LLM_JUDGE_MODEL_NAME="${LLM_JUDGE_MODEL_NAME:-Qwen2.5-32b-bench-v2-deepseek}"

# Build the LLM Judge base URL, taking care of IPv6 literals.
case "${LLM_JUDGE_IP}" in
    "["*)
        # Already wrapped in square brackets: use as-is.
        LLM_JUDGE_BASE_URL="http://${LLM_JUDGE_IP}:${LLM_JUDGE_PORT}/v1"
        ;;
    *:*)
        # Bare IPv6 address: it has to be wrapped in square brackets.
        LLM_JUDGE_BASE_URL="http://[${LLM_JUDGE_IP}]:${LLM_JUDGE_PORT}/v1"
        ;;
    *)
        # IPv4 address or host name.
        LLM_JUDGE_BASE_URL="http://${LLM_JUDGE_IP}:${LLM_JUDGE_PORT}/v1"
        ;;
esac
export LLM_JUDGE_BASE_URL

# FileAgent tool-metric collection (required!)
export USE_FILEAGENT_TRAINER=true

# Ray / vLLM / Hugging Face runtime switches
export VLLM_USE_V1=1 \
       RAY_JOB_ALLOW_DRIVER_ON_WORKER_NODES=1 \
       HF_HUB_ENABLE_HF_TRANSFER=1

# Logging configuration
export VERL_LOGGING_LEVEL=INFO \
       VLLM_LOGGING_LEVEL=INFO \
       RAY_DEDUP_LOGS=0 \
       RAY_LOG_TO_STDERR=1 \
       PYTHONUNBUFFERED=1

# Flash Attention configuration (uncomment if you hit compatibility problems)
# export DISABLE_FLASH_ATTN=1
# export TRANSFORMERS_NO_FLASH_ATTN=1

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# 📁 Project path configuration
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

PROJECT_DIR="/mnt/bn/fileagent-storage/users/<your_username>/verl"
PROJECT_NAME="fileagent_agentic_rl"
# The timestamp suffix gives every launch its own experiment name.
EXP_NAME="qwen25_7b_extracted_bench_$(date +%Y%m%d_%H%M%S)"
# Checkpoints are stored inside the project tree, one directory per experiment.
SAVE_DIR="${PROJECT_DIR}/checkpoints/extracted_bench/${EXP_NAME}"

# Switch to the project root (important: relative paths used later depend on it).
# Abort right away instead of silently continuing from the wrong directory.
if ! cd "${PROJECT_DIR}"; then
    echo "❌ 无法进入项目目录: ${PROJECT_DIR}" >&2
    exit 1
fi

# Make sure PYTHONPATH contains the project directory. The previous value is
# appended only when it is non-empty, so no empty entry (which Python treats as
# the current directory) is left behind.
export PYTHONPATH="${PROJECT_DIR}${PYTHONPATH:+:${PYTHONPATH}}"

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# 🤖 Model and data configuration
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

# NOTE(review): this path contains ".../user/..." while the other /mnt/bn paths
# in this script use ".../users/..." — confirm which one is intended.
MODEL_PATH='/mnt/bn/fileagent-storage/user/<your_username>/LLaMA-Factory/Qwen2.5-7b-bench-v1-deepseek/checkpoint-120'

# Dataset location
DATA_HOME='/mnt/bn/fileagent-storage/users/<your_username>/verl/data/all_pdfs_short_1000_verl'
# List-literal strings that are handed to the trainer as file lists.
TRAIN_FILES="['${DATA_HOME}/train.parquet']"
VAL_FILES=${TRAIN_FILES}  # validation reuses the very same data as training

# Config file locations
cfg_path=${PROJECT_DIR}/recipe/fileagent/config
cfg_name=extracted_bench_trainer
tool_cfg_path=recipe/fileagent/config/tool/extracted_bench_tool.yaml
agent_loop_cfg_path=recipe/fileagent/config/agent_loop.yaml
new_sp_path=recipe/fileagent/prompts/extracted_bench_sp.md

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# 🔬 Algorithm hyperparameters
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

# Advantage estimator and loss aggregation
adv_estimator=grpo
loss_agg_mode=token-mean

# Clipping ratios
clip_ratio_low=0.2
clip_ratio_high=0.28
clip_ratio_c=10.0

# KL terms: switched off both in the reward and in the loss
use_kl_in_reward=False
use_kl_loss=False
kl_coef=0.0
kl_loss_coef=0.0

# Batch sizes
train_bsz=128
train_mini_bsz=64
train_micro_bsz_per_gpu=4
infer_micro_bsz_per_gpu=8
val_bsz=16            # kept small to reduce LLM Judge concurrency during validation
n_resp_per_prompt=8
max_turns=15          # Extracted Bench tasks may need more turns

# Sequence-length limits. The original notes call these "2KB/30KB/20KB"; the
# values feed length / token-budget options, so the unit is presumably tokens
# rather than bytes — confirm.
(( max_prompt_len = 2 * 1024 ))       # 2048
(( max_resp_len = 30 * 1024 ))        # 30720, multi-turn responses
(( max_tool_resp_len = 20 * 1024 ))   # 20480, tool responses

# Performance-related parameters
train_sp_size=4
infer_tp_size=4
use_dynamic_bsz=True
# All three budgets are the same value: prompt length + response length.
(( actor_ppo_max_token_len = infer_ppo_max_token_len = max_num_batched_tokens = max_prompt_len + max_resp_len ))

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# 📊 Training schedule and logging
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

# Ray and cluster settings
NNODES=1
NGPUS_PER_NODE=8

# Trainer schedule and logging
val_before_train=False  # no validation before training, avoiding an initial burst of LLM Judge calls
test_freq=1000          # very large value, which essentially disables intermediate validation
save_freq=20
total_epochs=1
log_val_generations=5

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# 📊 Data preprocessing
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
echo "🔄 数据预处理"
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"

# Convert the raw dataset only when the training parquet does not exist yet.
if [[ ! -f "${DATA_HOME}/train.parquet" ]]; then
    echo "⚠️  训练数据不存在，开始转换..."
    # NOTE(review): the --input path below looks like a placeholder
    # ("/path/to/project/...") — confirm the real location of dataset_original.json.
    # Every path is quoted so that spaces or glob characters cannot split an argument,
    # and the converter's exit status is tested directly.
    if ! python3 "${PROJECT_DIR}/recipe/fileagent/convert_extracted_bench.py" \
        --input /path/to/project/verl/data/all_pdfs_short_1000_verl/dataset_original.json \
        --output "${DATA_HOME}/train.parquet" \
        --system_prompt "${PROJECT_DIR}/recipe/fileagent/prompts/extracted_bench_sp.md"; then
        # Diagnostics go to stderr; a failed conversion must stop the run.
        echo "❌ 数据转换失败" >&2
        exit 1
    fi
    echo "✅ 数据转换完成"
else
    echo "✅ 训练数据已存在: ${DATA_HOME}/train.parquet"
fi

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# 🎯 Compute the number of training steps
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

# Read the dataset size (row count of the training parquet). The path is passed
# as an argument instead of being spliced into the Python source, so quotes in
# the path cannot break the snippet.
if ! DATASET_SIZE=$(python3 -c 'import sys, pandas as pd; print(len(pd.read_parquet(sys.argv[1])))' "${DATA_HOME}/train.parquet"); then
    echo "❌ 无法读取数据集大小: ${DATA_HOME}/train.parquet" >&2
    exit 1
fi
# Refuse anything that is not a plain non-negative integer; an empty value would
# otherwise silently evaluate to 0 steps in the arithmetic below.
if [[ ! "${DATASET_SIZE}" =~ ^[0-9]+$ ]]; then
    echo "❌ 数据集大小无效: '${DATASET_SIZE}'" >&2
    exit 1
fi
TOTAL_STEPS=$(( (DATASET_SIZE + train_bsz - 1) / train_bsz ))  # round up

# Print a summary of the run configuration before launching.
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
echo "📋 训练配置总结"
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
echo "🤖 模型: ${MODEL_PATH}"
echo "📊 数据集大小: ${DATASET_SIZE} 样本"
echo "🎯 训练批次大小: ${train_bsz}"
echo "📈 总训练步数: ${TOTAL_STEPS}"
echo "🔄 最大轮次: ${max_turns}"
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
echo "🔧 环境变量"
echo "   USE_FILEAGENT_TRAINER=${USE_FILEAGENT_TRAINER} ✅"
echo "   LLM_JUDGE_BASE_URL=${LLM_JUDGE_BASE_URL}"
echo "   LLM_JUDGE_MODEL_NAME=${LLM_JUDGE_MODEL_NAME}"
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
echo "🚀 开始训练..."
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# 🚀 Launch training
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

# The tool / agent-loop / prompt paths passed below are relative, so the trainer
# has to start from the project root; abort if that directory is unreachable.
if ! cd "${PROJECT_DIR}"; then
    echo "❌ 无法进入项目目录: ${PROJECT_DIR}" >&2
    exit 1
fi

# Every override is quoted so that its value reaches the trainer as exactly one
# argument: no word-splitting, and no pathname expansion of values such as
# ['...'] that contain glob characters. Extra command-line arguments given to
# this script are forwarded unchanged via "$@" and come last.
python3 -m verl.trainer.main_ppo \
    --config-path="${cfg_path}" \
    --config-name="${cfg_name}" \
    algorithm.adv_estimator="${adv_estimator}" \
    algorithm.use_kl_in_reward="${use_kl_in_reward}" \
    algorithm.kl_ctrl.kl_coef="${kl_coef}" \
    data.train_batch_size="${train_bsz}" \
    data.val_batch_size="${val_bsz}" \
    data.max_prompt_length="${max_prompt_len}" \
    data.max_response_length="${max_resp_len}" \
    data.filter_overlong_prompts=True \
    data.truncation="error" \
    data.return_raw_chat=True \
    data.train_files="${TRAIN_FILES}" \
    data.val_files="${VAL_FILES}" \
    +data.replace_system_prompt=True \
    +data.new_sp_path="${new_sp_path}" \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.ppo_mini_batch_size="${train_mini_bsz}" \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu="${train_micro_bsz_per_gpu}" \
    actor_rollout_ref.actor.use_kl_loss="${use_kl_loss}" \
    actor_rollout_ref.actor.kl_loss_coef="${kl_loss_coef}" \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0.0 \
    actor_rollout_ref.actor.loss_agg_mode="${loss_agg_mode}" \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size="${train_sp_size}" \
    actor_rollout_ref.actor.clip_ratio_low="${clip_ratio_low}" \
    actor_rollout_ref.actor.clip_ratio_high="${clip_ratio_high}" \
    actor_rollout_ref.actor.clip_ratio_c="${clip_ratio_c}" \
    actor_rollout_ref.actor.use_dynamic_bsz="${use_dynamic_bsz}" \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu="${actor_ppo_max_token_len}" \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu="${infer_micro_bsz_per_gpu}" \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.mode=async \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
    actor_rollout_ref.rollout.n="${n_resp_per_prompt}" \
    actor_rollout_ref.rollout.tensor_model_parallel_size="${infer_tp_size}" \
    actor_rollout_ref.rollout.multi_turn.max_assistant_turns="${max_turns}" \
    actor_rollout_ref.rollout.multi_turn.max_parallel_calls=1 \
    actor_rollout_ref.rollout.multi_turn.max_tool_response_length="${max_tool_resp_len}" \
    actor_rollout_ref.rollout.multi_turn.tool_config_path="${tool_cfg_path}" \
    actor_rollout_ref.rollout.agent.agent_loop_config_path="${agent_loop_cfg_path}" \
    actor_rollout_ref.rollout.agent.num_workers=8 \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz="${use_dynamic_bsz}" \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu="${infer_ppo_max_token_len}" \
    actor_rollout_ref.rollout.max_num_batched_tokens="${max_num_batched_tokens}" \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz="${use_dynamic_bsz}" \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu="${infer_micro_bsz_per_gpu}" \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu="${infer_ppo_max_token_len}" \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    trainer.critic_warmup=0 \
    trainer.val_before_train="${val_before_train}" \
    trainer.logger='["console","wandb"]' \
    trainer.project_name="${PROJECT_NAME}" \
    trainer.experiment_name="${EXP_NAME}" \
    trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
    trainer.nnodes="${NNODES}" \
    trainer.save_freq="${save_freq}" \
    trainer.test_freq="${test_freq}" \
    trainer.log_val_generations="${log_val_generations}" \
    trainer.total_epochs="${total_epochs}" \
    trainer.default_local_dir="${SAVE_DIR}" \
    trainer.resume_mode=auto \
    reward_model.reward_manager=batch \
    custom_reward_function.reward_kwargs.llm_judge_ip="${LLM_JUDGE_IP}" \
    custom_reward_function.reward_kwargs.llm_judge_port="${LLM_JUDGE_PORT}" \
    custom_reward_function.reward_kwargs.llm_judge_model_name="${LLM_JUDGE_MODEL_NAME}" \
    "$@"
# Capture the trainer's exit status before any other command overwrites $?.
train_status=$?

# Do not report success (or exit 0) when the trainer failed; hand its exit
# status on to whoever launched this script.
if (( train_status != 0 )); then
    echo "❌ 训练失败 (exit code: ${train_status})" >&2
    exit "${train_status}"
fi

echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
echo "✅ 训练完成！"
echo "   检查点保存在: ${SAVE_DIR}"
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"

