#!/bin/bash
# Submits the fileagent extracted-bench training run (verl.trainer.main_ppo)
# as a Ray job. Extra command-line arguments are appended to the trainer
# invocation at the end of the script.
set -x              # trace every command for easier debugging of the launch
ulimit -n 65535     # set the open-file-descriptor limit for this shell

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# 📋 Experiment metadata
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Project root = two directories above the directory holding this script.
PROJECT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
if [ -z "${PROJECT_DIR}" ]; then
    # Without this guard an empty value would make the later `cd` a silent
    # no-op (or a jump to $HOME when unquoted), and every relative path below
    # would resolve against the wrong directory.
    echo "❌ 无法确定项目根目录" >&2
    exit 1
fi
PROJECT_NAME="fileagent_agentic_rl"
# Timestamped so that each launch gets its own experiment name / save dir.
EXP_NAME="qwen25_7b_extracted_bench_$(date +%Y%m%d_%H%M%S)"
SAVE_DIR="/mnt/bn/fileagent-storage/users/<your_username>/verl/checkpoints/extracted_bench/${EXP_NAME}"

echo "📁 项目根目录: ${PROJECT_DIR}"

# Switch to the project root so the relative recipe/... paths used later
# resolve correctly; abort if the directory cannot be entered.
cd "${PROJECT_DIR}" || exit 1

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# ☁️  Ray & cluster configuration
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Cluster shape handed to the trainer (node count / GPUs per node).
NNODES=1
NGPUS_PER_NODE=8

# Each of the settings below keeps a value already present in the environment
# and only falls back to the default when the variable is unset or empty.

# Ray cluster address; export RAY_ADDRESS to target a remote cluster.
# The default "auto" is meant to select the local Ray instance.
: "${RAY_ADDRESS:=auto}"

# Directory shipped to Ray as the job's working dir (default: current dir).
PWD_NOW="$(pwd)"
: "${WORKING_DIR:=${PWD_NOW}}"

# Runtime-env YAML passed to `ray job submit`.
: "${RUNTIME_ENV:=recipe/fileagent/runtime_env_extracted_bench.yaml}"

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# 🤖 Model, data and config-file locations
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Shared prefixes; replace the <your_username> placeholder before running.
storage_root="/mnt/bn/fileagent-storage"
user_verl_root="${storage_root}/users/<your_username>/verl"
recipe_dir="recipe/fileagent"

# Initial checkpoint for the actor.
# NOTE(review): this path uses ".../user/..." while every other path in this
# script uses ".../users/..." — confirm this is not a typo.
MODEL_PATH="${storage_root}/user/<your_username>/LLaMA-Factory/Qwen2.5-7b-bench-v1-deepseek/checkpoint-120"

# Dataset files, formatted as Python-style list literals for the trainer.
DATA_HOME="${user_verl_root}/data/extracted_bench"
train_parquet="${DATA_HOME}/train.parquet"
TRAIN_FILES="['${train_parquet}']"
# NOTE(review): validation reads the same parquet file as training —
# presumably intentional because validation is effectively disabled; confirm.
VAL_FILES="['${train_parquet}']"

# Trainer config directory (absolute path) and config name.
cfg_path="${user_verl_root}/${recipe_dir}/config"
cfg_name="extracted_bench_trainer"

# Tool config, agent-loop config and replacement system prompt; these are
# relative paths, so they depend on the current working directory.
tool_cfg_path="${recipe_dir}/config/tool/extracted_bench_tool.yaml"
agent_loop_cfg_path="${recipe_dir}/config/agent_loop.yaml"
new_sp_path="${recipe_dir}/prompts/extracted_bench_sp.md"

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# 🔬 Algorithm hyperparameters
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# --- Advantage estimation and loss shaping ---
adv_estimator="grpo"
loss_agg_mode="token-mean"
clip_ratio_low=0.2
clip_ratio_high=0.28
clip_ratio_c=10.0

# --- KL terms: switched off both in the reward and in the loss ---
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

# --- Batch sizes and rollout fan-out ---
train_bsz=128
val_bsz=16
train_mini_bsz=64
train_micro_bsz_per_gpu=4
infer_micro_bsz_per_gpu=8
n_resp_per_prompt=8
max_turns=15

# --- Sequence lengths, in tokens ---
max_prompt_len=$((2 * 1024))
max_resp_len=$((30 * 1024))
max_tool_resp_len=$((20 * 1024))

# --- Performance knobs ---
offload=False  # set to True when memory is insufficient
train_sp_size=4
infer_tp_size=4
use_dynamic_bsz=True
# Prompt plus response length; the same budget is applied to actor training,
# log-prob computation and the rollout engine's batched-token cap.
seq_token_budget=$((max_prompt_len + max_resp_len))
actor_ppo_max_token_len=${seq_token_budget}
infer_ppo_max_token_len=${seq_token_budget}
max_num_batched_tokens=${seq_token_budget}

# --- Trainer schedule and logging ---
val_before_train=False
test_freq=1000  # deliberately large, which effectively disables mid-run validation
save_freq=20
total_epochs=1
log_val_generations=5

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# 🔧 LLM judge settings (forwarded via custom_reward_function.reward_kwargs)
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Values already exported in the environment win; the defaults below apply
# only when a variable is unset or empty.
# NOTE(review): the default host is a raw IPv6 literal — confirm the reward
# function formats it correctly when building the judge endpoint.
: "${LLM_JUDGE_IP:=2605:340:cd51:4b00:140c:e92f:caf3:f57a}"
: "${LLM_JUDGE_PORT:=18908}"
: "${LLM_JUDGE_MODEL_NAME:=Qwen2.5-32b-bench-v2-deepseek}"

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# 📊 Data preprocessing
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Converts the raw JSON benchmark into train.parquet when that file does not
# exist yet; exits with status 1 if the conversion script fails.
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
echo "🔄 数据预处理检查"
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"

if [ ! -f "${DATA_HOME}/train.parquet" ]; then
    echo "⚠️  训练数据不存在，开始转换..."
    # Every path is quoted: left bare, the "<your_username>" placeholder is
    # parsed by the shell as input/output redirections, and a DATA_HOME
    # containing whitespace would be split into several arguments.
    if ! python3 "/mnt/bn/fileagent-storage/users/<your_username>/verl/recipe/fileagent/convert_extracted_bench.py" \
        --input "/mnt/bn/fileagent-storage/users/<your_username>/data/extracted_bench-v2.json" \
        --output "${DATA_HOME}/train.parquet" \
        --system_prompt "/mnt/bn/fileagent-storage/users/<your_username>/verl/recipe/fileagent/prompts/extracted_bench_sp.md"; then
        echo "❌ 数据转换失败"
        exit 1
    fi
    echo "✅ 数据转换完成"
else
    echo "✅ 训练数据已存在: ${DATA_HOME}/train.parquet"
fi

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# 🎯 Compute the number of training steps
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Best-effort row count of the training parquet, used only for the summary
# printed below. The path is passed as an argument instead of being spliced
# into the Python source, so quotes in the path cannot break the snippet.
# stderr is silenced on purpose: any failure simply yields a size of 0.
DATASET_SIZE="$(python3 -c 'import sys, pandas as pd; print(len(pd.read_parquet(sys.argv[1])))' "${DATA_HOME}/train.parquet" 2>/dev/null)" || DATASET_SIZE="0"
# Anything that is not a plain non-negative integer (empty or partial output)
# is treated as "unknown" rather than tripping the numeric test below.
if [[ ! "${DATASET_SIZE}" =~ ^[0-9]+$ ]]; then
    DATASET_SIZE="0"
fi
if [ "${DATASET_SIZE}" -gt 0 ]; then
    # Ceiling division: a final partial batch still counts as one step.
    TOTAL_STEPS=$(( (DATASET_SIZE + train_bsz - 1) / train_bsz ))
else
    TOTAL_STEPS="Unknown"
fi

# Print the run summary (model, dataset, cluster and LLM-judge settings) to
# stdout, one line per argument, before the job is submitted.
printf '%s\n' \
    "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" \
    "📋 训练配置总结" \
    "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" \
    "🤖 模型: ${MODEL_PATH}" \
    "📊 数据集大小: ${DATASET_SIZE} 样本" \
    "🎯 训练批次大小: ${train_bsz}" \
    "📈 总训练步数: ${TOTAL_STEPS}" \
    "🔄 最大轮次: ${max_turns}" \
    "☁️  Ray 地址: ${RAY_ADDRESS}" \
    "🖥️  节点数: ${NNODES}" \
    "🎮 GPU/节点: ${NGPUS_PER_NODE}" \
    "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" \
    "🔧 LLM Judge 配置" \
    "   IP: ${LLM_JUDGE_IP}" \
    "   Port: ${LLM_JUDGE_PORT}" \
    "   Model: ${LLM_JUDGE_MODEL_NAME}" \
    "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" \
    "🚀 开始提交训练任务到 Ray..." \
    "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# 🚀 Submit the Ray job
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Submits verl.trainer.main_ppo without waiting for it to finish. Every
# expansion is quoted so values containing spaces or glob characters (for
# example the "['...']" file lists) reach the trainer unchanged. "$@" forwards
# the script's own arguments as additional overrides, one word per argument.
ray job submit --no-wait \
    --runtime-env="${RUNTIME_ENV}" \
    --working-dir "${WORKING_DIR}" \
    --address "${RAY_ADDRESS}" \
    -- python3 -m verl.trainer.main_ppo \
    --config-path="${cfg_path}" \
    --config-name="${cfg_name}" \
    algorithm.adv_estimator="${adv_estimator}" \
    algorithm.use_kl_in_reward="${use_kl_in_reward}" \
    algorithm.kl_ctrl.kl_coef="${kl_coef}" \
    data.train_batch_size="${train_bsz}" \
    data.val_batch_size="${val_bsz}" \
    data.max_prompt_length="${max_prompt_len}" \
    data.max_response_length="${max_resp_len}" \
    data.filter_overlong_prompts=True \
    data.truncation="error" \
    data.return_raw_chat=True \
    data.train_files="${TRAIN_FILES}" \
    data.val_files="${VAL_FILES}" \
    +data.replace_system_prompt=True \
    +data.new_sp_path="${new_sp_path}" \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.ppo_mini_batch_size="${train_mini_bsz}" \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu="${train_micro_bsz_per_gpu}" \
    actor_rollout_ref.actor.use_kl_loss="${use_kl_loss}" \
    actor_rollout_ref.actor.kl_loss_coef="${kl_loss_coef}" \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0.0 \
    actor_rollout_ref.actor.loss_agg_mode="${loss_agg_mode}" \
    actor_rollout_ref.actor.fsdp_config.param_offload="${offload}" \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload="${offload}" \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size="${train_sp_size}" \
    actor_rollout_ref.actor.clip_ratio_low="${clip_ratio_low}" \
    actor_rollout_ref.actor.clip_ratio_high="${clip_ratio_high}" \
    actor_rollout_ref.actor.clip_ratio_c="${clip_ratio_c}" \
    actor_rollout_ref.actor.use_dynamic_bsz="${use_dynamic_bsz}" \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu="${actor_ppo_max_token_len}" \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu="${infer_micro_bsz_per_gpu}" \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.mode=async \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
    actor_rollout_ref.rollout.n="${n_resp_per_prompt}" \
    actor_rollout_ref.rollout.tensor_model_parallel_size="${infer_tp_size}" \
    actor_rollout_ref.rollout.multi_turn.max_assistant_turns="${max_turns}" \
    actor_rollout_ref.rollout.multi_turn.max_parallel_calls=1 \
    actor_rollout_ref.rollout.multi_turn.max_tool_response_length="${max_tool_resp_len}" \
    actor_rollout_ref.rollout.multi_turn.tool_config_path="${tool_cfg_path}" \
    actor_rollout_ref.rollout.agent.agent_loop_config_path="${agent_loop_cfg_path}" \
    actor_rollout_ref.rollout.agent.num_workers=8 \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz="${use_dynamic_bsz}" \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu="${infer_ppo_max_token_len}" \
    actor_rollout_ref.rollout.max_num_batched_tokens="${max_num_batched_tokens}" \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz="${use_dynamic_bsz}" \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu="${infer_micro_bsz_per_gpu}" \
    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu="${infer_ppo_max_token_len}" \
    actor_rollout_ref.ref.fsdp_config.param_offload="${offload}" \
    trainer.critic_warmup=0 \
    trainer.val_before_train="${val_before_train}" \
    trainer.logger='["console","wandb"]' \
    trainer.project_name="${PROJECT_NAME}" \
    trainer.experiment_name="${EXP_NAME}" \
    trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
    trainer.nnodes="${NNODES}" \
    trainer.save_freq="${save_freq}" \
    trainer.test_freq="${test_freq}" \
    trainer.log_val_generations="${log_val_generations}" \
    trainer.total_epochs="${total_epochs}" \
    trainer.default_local_dir="${SAVE_DIR}" \
    +trainer.log_train_freq=5 \
    trainer.resume_mode=auto \
    reward_model.reward_manager=batch \
    custom_reward_function.reward_kwargs.llm_judge_ip="${LLM_JUDGE_IP}" \
    custom_reward_function.reward_kwargs.llm_judge_port="${LLM_JUDGE_PORT}" \
    custom_reward_function.reward_kwargs.llm_judge_model_name="${LLM_JUDGE_MODEL_NAME}" \
    "$@"
submit_status=$?

# Only report success when the submission itself succeeded; otherwise
# propagate the failure status to the caller instead of exiting 0.
if [ "${submit_status}" -ne 0 ]; then
    echo "❌ Ray 任务提交失败 (exit code: ${submit_status})" >&2
    exit "${submit_status}"
fi

echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
echo "✅ 任务已提交到 Ray！"
echo "   查看任务状态: ray job status"
echo "   查看任务日志: ray job logs"
echo "   停止任务: ray job stop <job-id>"
echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"

