#!/bin/bash
# Simplified launcher: submits a verl.trainer.main_ppo run as a Ray job.

# Trace every command so the fully expanded submission shows up in the log,
# and stop at the first failing command instead of submitting from a
# half-initialised state.
set -x
set -eo pipefail

# Raise the open-file limit for this shell and the processes it starts.
# Not being allowed to raise it is reported but is not treated as fatal.
if ! ulimit -n 65535; then
  printf 'warning: could not raise the open-file limit to 65535 (current: %s)\n' \
    "$(ulimit -n)" >&2
fi

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Basic settings
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Both directory variables point at the directory the script is launched from.
PROJECT_DIR="$(pwd)"
PWD_NOW="${PROJECT_DIR}"

PROJECT_NAME="agentic_rl_extracted_bench"
# NOTE(review): the launch time is part of EXP_NAME, so the default SAVE_DIR
# below is a fresh directory on every launch; export SAVE_DIR to reuse an
# existing output directory.
EXP_NAME="qwen3_32b_all_2epochs-sft_$(date +%Y%m%d_%H%M%S)"

# The settings below keep a non-empty value supplied by the caller's
# environment and fall back to the default shown otherwise.
: "${SAVE_DIR:=${PROJECT_DIR}/outputs/${EXP_NAME}}"

# Ray cluster address (adjust for your environment); the default assumes the
# head node runs on this machine.
: "${RAY_ADDRESS:=http://127.0.0.1:8265}"
: "${WORKING_DIR:=${PWD_NOW}}"
: "${RUNTIME_ENV:=runtime_env.yaml}"

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Model / data (paths may be injected through environment variables;
# the defaults below are used otherwise)
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
: "${MODEL_PATH:=/path/to/your/qwen3-32b-checkpoint}"
: "${DATA_HOME:=/path/to/your/rl_data}"

# Training parquet files, relative to DATA_HOME. The order is the order in
# which they appear in TRAIN_FILES. Both the train and the test split of each
# listed dataset are included.
# An alternative list without the cc_pdf_18_1_rl pair used to be kept here as
# a commented-out line; delete the last two entries to get it back.
train_parquets=(
  ppt_out_rl/train.parquet
  gov_excel_gpt5_rl/train.parquet
  cc_pdf_21_rl/train.parquet
  cc_ppt_26_rl/train.parquet
  cc_doc_21_rl/train.parquet
  ppt_out_rl/test.parquet
  gov_excel_gpt5_rl/test.parquet
  cc_pdf_21_rl/test.parquet
  cc_ppt_26_rl/test.parquet
  cc_doc_21_rl/test.parquet
  wiki-train-rl-600/train.parquet
  wiki-train-rl-600/test.parquet
  cc_pdf_18_rl/train.parquet
  cc_pdf_18_rl/test.parquet
  cc_pdf_18_1_rl/train.parquet
  cc_pdf_18_1_rl/test.parquet
)

# Render the list as ['<path>','<path>',...] with single-quoted absolute paths.
TRAIN_FILES=""
for rel_path in "${train_parquets[@]}"; do
  TRAIN_FILES+="'${DATA_HOME}/${rel_path}',"
done
TRAIN_FILES="[${TRAIN_FILES%,}]"

# Validation uses both splits of bench-v2, in the same list syntax.
printf -v VAL_FILES "['%s','%s']" \
  "${DATA_HOME}/bench-v2/test.parquet" \
  "${DATA_HOME}/bench-v2/train.parquet"

# Config directory and config name handed to --config-path / --config-name,
# plus the tool config, agent-loop config and replacement system prompt that
# are forwarded as overrides.
cfg_path="recipe/fileagent/config"
cfg_name="extracted_bench_trainer"
tool_cfg_path="${cfg_path}/tool/extracted_bench_tool.yaml"
agent_loop_cfg_path="${cfg_path}/agent_loop.yaml"
new_sp_path="recipe/fileagent/prompts/extracted_bench_sp.md"

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Algorithm parameters
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
adv_estimator="grpo"
loss_agg_mode="token-mean"

# NOTE(review): the three clip_ratio_* values are assigned here, but none of
# the overrides in this script's `ray job submit` command references them.
clip_ratio_low=0.2
clip_ratio_high=0.28
clip_ratio_c=10.0

# KL is switched off both as a reward term and as a loss term.
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

# Batch sizes and rollout fan-out.
train_bsz=256
val_bsz=32
train_mini_bsz=32
train_micro_bsz_per_gpu=8
infer_micro_bsz_per_gpu=8
n_resp_per_prompt=16
max_turns=15

# Dynamic sampling configuration.
# Disabled: generation batch size (3x the training batch)
#   gen_bsz=$((train_bsz * 3))
# NOTE(review): the filter_groups_* values are assigned here, but none of the
# overrides in this script's `ray job submit` command references them.
filter_groups_enable=True
filter_groups_metric="score"
filter_groups_max_num_gen_batches=10

# Length limits, expressed in units of 1024.
max_prompt_len=$((2 * 1024))
max_resp_len=$((30 * 1024))
max_tool_resp_len=$((20 * 1024))

# Parallelism.
train_sp_size=4
infer_tp_size=4
use_dynamic_bsz=True

# One prompt plus one response; every per-GPU token budget is set to this sum.
full_seq_len=$((max_prompt_len + max_resp_len))
actor_ppo_max_token_len=${full_seq_len}
infer_ppo_max_token_len=${full_seq_len}
max_num_batched_tokens=${full_seq_len}

# Validation, checkpoint and logging cadence.
val_before_train=True
test_freq=10
save_freq=10
total_epochs=2
log_val_generations=20

# Sampling parameters used during validation (optional).
# temperature: 0 = greedy sampling, >0 = random sampling
val_temperature=0.7
# nucleus sampling
val_top_p=0.95
# -1 disables top-k
val_top_k=-1
# False = greedy, True = sampling
val_do_sample=True
# number of responses generated per prompt during validation
val_n=1

# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
# Submit the job (kept as simple as the sample)
# ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
echo "RAY_ADDRESS: ${RAY_ADDRESS}"

# Date/time sub-directory for Hydra's output. It is rendered here with `date`
# rather than with Hydra's ${now:...} resolver: written unquoted on a bash
# command line, that syntax is consumed by bash as a parameter expansion of a
# shell variable named `now`, so Hydra never receives it. A single `date` call
# keeps the date and time parts consistent with each other.
hydra_stamp="$(date +%Y-%m-%d/%H-%M-%S)"

# Hydra overrides, one array element per override. Every element is quoted so
# a value is passed as exactly one argument and is never word-split or
# glob-expanded by this shell (the file lists contain '[' and ']').
hydra_overrides=(
  # Hydra output directories
  "hydra.run.dir=${SAVE_DIR}/hydra/${hydra_stamp}"
  "hydra.sweep.dir=${SAVE_DIR}/hydra/${hydra_stamp}"

  # Algorithm
  "algorithm.adv_estimator=${adv_estimator}"
  "algorithm.use_kl_in_reward=${use_kl_in_reward}"
  "algorithm.kl_ctrl.kl_coef=${kl_coef}"

  # Data
  "data.train_batch_size=${train_bsz}"
  "data.val_batch_size=${val_bsz}"
  "data.max_prompt_length=${max_prompt_len}"
  "data.max_response_length=${max_resp_len}"
  "data.filter_overlong_prompts=False"
  "data.truncation=error"
  "data.return_raw_chat=True"
  "data.custom_cls.path=recipe/fileagent/rl_dataset.py"
  "data.custom_cls.name=CustomRLHFDataset"
  "data.train_files=${TRAIN_FILES}"
  "data.val_files=${VAL_FILES}"
  "+data.replace_system_prompt=True"
  "+data.new_sp_path=${new_sp_path}"
  "+data.allow_heterogeneous_schemas=True"

  # Model
  "actor_rollout_ref.model.path=${MODEL_PATH}"
  "actor_rollout_ref.model.use_remove_padding=True"
  "actor_rollout_ref.model.enable_gradient_checkpointing=True"

  # Actor
  "actor_rollout_ref.actor.optim.lr=1e-6"
  "actor_rollout_ref.actor.ppo_mini_batch_size=${train_mini_bsz}"
  "actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_micro_bsz_per_gpu}"
  "actor_rollout_ref.actor.use_kl_loss=${use_kl_loss}"
  "actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef}"
  "actor_rollout_ref.actor.kl_loss_type=low_var_kl"
  "actor_rollout_ref.actor.entropy_coeff=0.0"
  "actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode}"
  "actor_rollout_ref.actor.ulysses_sequence_parallel_size=${train_sp_size}"
  "actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz}"
  "actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len}"

  # Rollout
  "actor_rollout_ref.rollout.name=vllm"
  "actor_rollout_ref.rollout.mode=async"
  "actor_rollout_ref.rollout.gpu_memory_utilization=0.7"
  "actor_rollout_ref.rollout.enforce_eager=True"
  "actor_rollout_ref.rollout.n=${n_resp_per_prompt}"
  "actor_rollout_ref.rollout.tensor_model_parallel_size=${infer_tp_size}"
  "actor_rollout_ref.rollout.multi_turn.max_assistant_turns=${max_turns}"
  "actor_rollout_ref.rollout.multi_turn.max_parallel_calls=1"
  "actor_rollout_ref.rollout.multi_turn.max_tool_response_length=${max_tool_resp_len}"
  "actor_rollout_ref.rollout.multi_turn.tool_config_path=${tool_cfg_path}"
  "actor_rollout_ref.rollout.agent.agent_loop_config_path=${agent_loop_cfg_path}"
  "actor_rollout_ref.rollout.agent.num_workers=8"
  "actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz}"
  "actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len}"
  "actor_rollout_ref.rollout.max_num_batched_tokens=${max_num_batched_tokens}"
  "actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature}"
  "actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p}"
  "actor_rollout_ref.rollout.val_kwargs.top_k=${val_top_k}"
  "actor_rollout_ref.rollout.val_kwargs.do_sample=${val_do_sample}"
  "actor_rollout_ref.rollout.val_kwargs.n=${val_n}"

  # Reference model
  "actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz}"
  "actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${infer_micro_bsz_per_gpu}"
  "actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len}"

  # Trainer
  "trainer.critic_warmup=0"
  "trainer.val_before_train=${val_before_train}"
  'trainer.logger=["console","wandb"]'
  "trainer.project_name=${PROJECT_NAME}"
  "trainer.experiment_name=${EXP_NAME}"
  "trainer.n_gpus_per_node=8"
  "trainer.nnodes=8"
  "trainer.save_freq=${save_freq}"
  "trainer.test_freq=${test_freq}"
  "trainer.log_val_generations=${log_val_generations}"
  "trainer.total_epochs=${total_epochs}"
  "trainer.default_local_dir=${SAVE_DIR}"
  "+trainer.log_train_freq=5"
  "trainer.resume_mode=auto"
)

# --no-wait: return as soon as the job has been submitted.
# Everything after the bare `--` is the entrypoint command of the job.
ray job submit --no-wait \
  --runtime-env="${RUNTIME_ENV}" \
  --working-dir "${WORKING_DIR}" \
  --address "${RAY_ADDRESS}" \
  -- python3 -m verl.trainer.main_ppo \
  "--config-path=${cfg_path}" \
  "--config-name=${cfg_name}" \
  "${hydra_overrides[@]}"
