#!/usr/bin/env bash

# Strict mode: abort on a failing command (-e), on use of an unset variable
# (-u), and when any stage of a pipeline fails (pipefail).
set -euo pipefail

# Command tracing is opt-in (ONEPO_XTRACE=1) instead of always on: this script
# assigns, exports, and passes the external API key on the python3 command
# line, and xtrace would echo the expanded secret into the job log.
if [[ "${ONEPO_XTRACE:-0}" == "1" ]]; then
  set -x
fi

# ==================== Project Configuration ====================
# Identifiers reported to the trainer and used in the checkpoint path.
project_name="Medical-OnePO"
exp_name="run_onestage"
printf 'Project: %s, Experiment: %s\n' "${project_name}" "${exp_name}"

# Absolute path of the directory that holds this script.
script_location="$(dirname "${BASH_SOURCE[0]}")"
PROJECT_DIR="$(cd "${script_location}" && pwd)"

# Directory the script changes into before launching python3. It is a
# relative path, so it resolves against the caller's working directory.
VERL_DIR='verl'

# Passed to the trainer as actor_rollout_ref.model.path.
MODEL_PATH='Qwen3-8B-Base'

# Reward model settings; both are exported for the reward function. The
# server address may be overridden from the environment.
REWARD_MODEL_PATH=''
REWARD_MODEL_SERVER="${REWARD_MODEL_SERVER:-localhost:30000}"

# Training / validation data files (data.train_files / data.val_files).
TRAIN_FILE=''
TEST_FILE=''

# Checkpoint output directory (trainer.default_local_dir).
CKPTS_DIR="${PROJECT_DIR}/checkpoints/${project_name}/${exp_name}"

# External API settings; each value is forwarded to the trainer as an
# external_api.* override, and url/key/model are also exported as
# EXTERNAL_API_URL / EXTERNAL_API_KEY / EXTERNAL_API_MODEL.
external_api_enable='True'
external_api_url=''
external_api_key=''
external_api_model='deepseek-v3.2-thinking'

# Number of external responses requested per prompt.
n_external_per_prompt='1'

# Request parameters for the external API calls.
external_api_temperature='0.7'
external_api_max_workers='30'
external_api_timeout='180'


# Floor probability for old_log_prob of external responses. The trainer does
# not receive this value directly; it receives -ln(min_prob), computed below
# (0.1 -> approximately 2.3026).
external_api_min_prob=0.1

# One interpreter launch; the probability is passed as an argument rather than
# spliced into the Python source text.
# NOTE(review): the value handed to policy_loss.external_min_log_prob is
# positive (the negated natural log). Confirm that the consumer expects the
# magnitude rather than ln(p) itself.
external_api_min_log_prob="$(python3 -c 'import math, sys; print(-math.log(float(sys.argv[1])))' "${external_api_min_prob}")"

# ==================== DAPO Algorithm Configuration ====================
# Advantage estimator and KL terms (KL is switched off in reward and in loss).
adv_estimator="grpo"
use_kl_in_reward="False"
kl_coef="0.0"
use_kl_loss="False"
kl_loss_coef="0.0"

# Lower / upper clip ratios forwarded to actor.clip_ratio_low / clip_ratio_high.
clip_ratio_low="0.2"
clip_ratio_high="0.28"

# Loss aggregation mode (actor.loss_agg_mode).
loss_agg_mode='token-mean'

# Group filtering (algorithm.filter_groups.*).
enable_filter_groups="True"
filter_groups_metric="acc"
max_num_gen_batches="15"

# Exported to the reward function as REQUIRE_THINKING.
require_thinking="True"

# Prompt and response length limits (data.max_prompt_length / max_response_length).
max_prompt_length="4000"
max_response_length="8000"

# ==================== Overlong Buffer Configuration ====================
# Forwarded to reward_model.overlong_buffer.enable / .len / .penalty_factor.
enable_overlong_buffer="True"
overlong_buffer_len="4000"
overlong_penalty_factor="0.5"

# ==================== Batch Configuration ====================
# Prompt batch sizes: generation (data.gen_batch_size), training
# (data.train_batch_size) and PPO mini-batch (actor.ppo_mini_batch_size).
gen_prompt_bsz="192"
train_prompt_bsz="128"
train_prompt_mini_bsz="16"

# Responses per prompt (actor_rollout_ref.rollout.n).
n_resp_per_prompt="8"


# ==================== Sampling Configuration ====================
# temperature and top_k are used for both rollout and validation sampling;
# top_p applies to rollout, val_top_p to validation.
temperature="1.0"
top_k="-1"
top_p="1.0"
val_top_p="0.7"

# ==================== GPU Configuration ====================
# Cluster shape (trainer.nnodes / trainer.n_gpus_per_node).
NNODES="1"
N_GPUS_PER_NODE="8"

# Rollout tensor-parallel size and Ulysses sequence-parallel size.
gen_tp="1"
sp_size="1"

# ==================== Performance Configuration ====================
# Dynamic batch sizing for actor, ref and rollout log-prob passes.
use_dynamic_bsz="True"

# Per-GPU token budgets; increased for thinking mode.
actor_ppo_max_token_len="12000"
infer_ppo_max_token_len="12000"

# FSDP parameter / optimizer offload switch.
offload="True"

# Best-effort bump of the open-file (-n) and process (-u) limits; a refusal is
# deliberately ignored so the script keeps going.
for limit_flag in -n -u; do
  ulimit "${limit_flag}" 65535 2>/dev/null || true
done

# Runtime environment for the training process. PYTHONPATH is reset to empty.
PYTHONPATH=''
VERL_LOGGING_LEVEL='INFO'
NCCL_DEBUG='WARN'
VLLM_LOGGING_LEVEL='WARN'
TOKENIZERS_PARALLELISM='true'
SWANLAB_LOG_DIR="${PROJECT_DIR}/swanlog"
export PYTHONPATH VERL_LOGGING_LEVEL NCCL_DEBUG VLLM_LOGGING_LEVEL
export TOKENIZERS_PARALLELISM SWANLAB_LOG_DIR

# Reward model server / path, made visible to the reward function.
export REWARD_MODEL_SERVER REWARD_MODEL_PATH

# External API credentials and endpoint.
EXTERNAL_API_KEY="${external_api_key}"
EXTERNAL_API_URL="${external_api_url}"
EXTERNAL_API_MODEL="${external_api_model}"
export EXTERNAL_API_KEY EXTERNAL_API_URL EXTERNAL_API_MODEL

# Reward configuration.
REQUIRE_THINKING="${require_thinking}"
export REQUIRE_THINKING

# ==================== Print Configuration ====================
# Writes a run summary to stdout, one argument per output line.
banner_rule='=========================================='
model_sample_count=$((n_resp_per_prompt - n_external_per_prompt))
printf '%s\n' \
  "${banner_rule}" \
  "Starting One-Stage DAPO Training with Thinking Mode" \
  "Model: ${MODEL_PATH}" \
  "External API: ${external_api_model} (${n_external_per_prompt} per prompt, reasoning auto-detected)" \
  "Reward Model Server: ${REWARD_MODEL_SERVER}" \
  "Train data: ${TRAIN_FILE}" \
  "Test data: ${TEST_FILE}" \
  "Output: ${CKPTS_DIR}" \
  "" \
  "Key Configuration:" \
  "  - Data types: open-end (rubric) + MC (answer matching)" \
  "  - Response distribution: ${model_sample_count} model + ${n_external_per_prompt} external" \
  "  - max_response_length: ${max_response_length} (increased for thinking)" \
  "  - overlong_buffer: ${enable_overlong_buffer}" \
  "  - Dynamic Sampling: ${enable_filter_groups}" \
  "  - Training GPUs: ${N_GPUS_PER_NODE}" \
  "${banner_rule}"


# Enter the verl directory before launching; the entry point is started with
# `python3 -m main_run_onepo` from there.
cd "${VERL_DIR}"

# Configuration overrides handed to main_run_onepo as key=value arguments.
# They are collected in an array so every expansion is quoted and the groups
# can be commented, which a backslash-continued command line does not allow.
# A leading `+` is kept exactly as written (presumably the config system's
# "add new key" syntax; confirm against the Hydra override grammar).
# NOTE: external_api.key travels on the command line and is therefore visible
# in process listings.
train_overrides=(
  # --- Data ---
  "data.train_files=${TRAIN_FILE}"
  "data.val_files=${TEST_FILE}"
  "data.prompt_key=prompt"
  "data.truncation=left"
  "data.max_prompt_length=${max_prompt_length}"
  "data.max_response_length=${max_response_length}"
  "data.gen_batch_size=${gen_prompt_bsz}"
  "data.train_batch_size=${train_prompt_bsz}"
  "data.return_raw_chat=True"
  "data.custom_cls.path=${PROJECT_DIR}/dataset_qa_mc.py"
  "data.custom_cls.name=MixedQADataset"
  "data.dataloader_num_workers=2"
  "ray_kwargs.ray_init.num_cpus=48"

  # --- Custom reward function ---
  "custom_reward_function.path=${PROJECT_DIR}/reward_function_qa_mc.py"
  "custom_reward_function.name=compute_score"

  # --- Responses per prompt ---
  "actor_rollout_ref.rollout.n=${n_resp_per_prompt}"

  # --- Advantage estimation / KL in reward ---
  "algorithm.adv_estimator=${adv_estimator}"
  "algorithm.use_kl_in_reward=${use_kl_in_reward}"
  "algorithm.kl_ctrl.kl_coef=${kl_coef}"

  # --- Policy loss ---
  "actor_rollout_ref.actor.use_kl_loss=${use_kl_loss}"
  "actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef}"
  "actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low}"
  "actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high}"
  "actor_rollout_ref.actor.clip_ratio_c=10.0"
  "actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode}"
  "actor_rollout_ref.actor.policy_loss.loss_mode=external_adjusted"
  "+actor_rollout_ref.actor.policy_loss.external_min_log_prob=${external_api_min_log_prob}"
  "actor_rollout_ref.actor.entropy_coeff=0"
  "actor_rollout_ref.actor.grad_clip=1.0"

  # --- Group filtering ---
  "algorithm.filter_groups.enable=${enable_filter_groups}"
  "algorithm.filter_groups.max_num_gen_batches=${max_num_gen_batches}"
  "algorithm.filter_groups.metric=${filter_groups_metric}"

  # --- Model ---
  "actor_rollout_ref.model.path=${MODEL_PATH}"
  "actor_rollout_ref.model.enable_gradient_checkpointing=True"
  "actor_rollout_ref.model.use_remove_padding=True"
  "+actor_rollout_ref.model.override_config.attn_implementation=flash_attention_2"

  # --- Dynamic batch sizing and per-GPU token budgets ---
  "actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz}"
  "actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz}"
  "actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz}"
  "actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len}"
  "actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len}"
  "actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len}"

  # --- Optimizer / actor FSDP ---
  "actor_rollout_ref.actor.optim.lr=2e-6"
  "actor_rollout_ref.actor.optim.lr_warmup_steps=1"
  "actor_rollout_ref.actor.optim.weight_decay=0.1"
  "actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz}"
  "actor_rollout_ref.actor.fsdp_config.param_offload=${offload}"
  "actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload}"
  "actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size}"

  # --- Rollout engine and sampling ---
  "actor_rollout_ref.rollout.name=sglang"
  "actor_rollout_ref.hybrid_engine=True"
  "actor_rollout_ref.rollout.gpu_memory_utilization=0.6"
  "actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp}"
  "actor_rollout_ref.rollout.enable_chunked_prefill=True"
  "actor_rollout_ref.rollout.max_num_batched_tokens=${actor_ppo_max_token_len}"
  "actor_rollout_ref.rollout.temperature=${temperature}"
  "actor_rollout_ref.rollout.top_p=${top_p}"
  "actor_rollout_ref.rollout.top_k=${top_k}"
  "actor_rollout_ref.rollout.val_kwargs.temperature=${temperature}"
  "actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p}"
  "actor_rollout_ref.rollout.val_kwargs.top_k=${top_k}"
  "actor_rollout_ref.rollout.val_kwargs.do_sample=True"
  "actor_rollout_ref.rollout.val_kwargs.n=1"

  # --- Reference model / FSDP size ---
  "actor_rollout_ref.ref.fsdp_config.param_offload=${offload}"
  "actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size}"
  "actor_rollout_ref.actor.fsdp_config.fsdp_size=-1"

  # --- Reward manager and overlong buffer ---
  "reward_model.enable=False"
  "reward_model.reward_manager=dapo"
  "reward_model.use_reward_loop=True"
  "reward_model.overlong_buffer.enable=${enable_overlong_buffer}"
  "reward_model.overlong_buffer.len=${overlong_buffer_len}"
  "reward_model.overlong_buffer.penalty_factor=${overlong_penalty_factor}"
  "reward_model.overlong_buffer.log=True"

  # --- External API ---
  "external_api.enable=${external_api_enable}"
  "external_api.url=${external_api_url}"
  "external_api.key=${external_api_key}"
  "external_api.model=${external_api_model}"
  "external_api.n_per_prompt=${n_external_per_prompt}"
  "external_api.temperature=${external_api_temperature}"
  "external_api.max_workers=${external_api_max_workers}"
  "external_api.timeout=${external_api_timeout}"
  "external_api.max_retries=3"
  "external_api.fallback_on_failure=True"
  "external_api.debug=True"

  # --- Trainer ---
  'trainer.logger=["console","swanlab"]'
  "trainer.project_name=${project_name}"
  "trainer.experiment_name=${exp_name}"
  "trainer.n_gpus_per_node=${N_GPUS_PER_NODE}"
  "trainer.nnodes=${NNODES}"
  "trainer.val_before_train=False"
  "trainer.test_freq=5"
  "trainer.save_freq=20"
  "trainer.total_epochs=4"
  "trainer.log_val_generations=20"
  "trainer.default_local_dir=${CKPTS_DIR}"
  "trainer.resume_mode=disable"
)

# Caller-supplied arguments come last, after every override defined above.
python3 -m main_run_onepo "${train_overrides[@]}" "$@"

# Reached only when the training command exits successfully (set -e).
echo "=========================================="
echo "Training completed!"
echo "=========================================="
