#!/usr/bin/env bash
# Strict mode: abort on unhandled errors, on use of unset variables, and on a
# failure in any stage of a pipeline.
set -o errexit -o nounset -o pipefail

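# =============================================================================
# DCKL Qwen3-8B Static-KL Training Script
#
# Launches verl.trainer.main_ppo with trainer.task=clpo. Any extra command-line
# arguments given to this script are forwarded verbatim as Hydra overrides.
# Optional environment overrides: CUDA_VISIBLE_DEVICES, NGPUS_PER_NODE, NNODES.
# =============================================================================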

# Project and experiment configuration
project_name="DCKL-Qwen3-8B"
# Launch timestamp (YYYYmmdd_HHMMSS) appended to the experiment name so that
# launches are distinguishable (to one-second resolution).
TIMESTAMP="$(date '+%Y%m%d_%H%M%S')"
# Static-KL run.
# NOTE(review): "stastic" is a misspelling of "static", but the same spelling
# is used in the checkpoint paths; renaming it would change where outputs land,
# so it is intentionally kept as-is.
exp_name="stastic_kl_from_scratch_${TIMESTAMP}"

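# Environment variables
# Enable InfiniBand (IB) networking.
# NOTE(review): no IB/NCCL-related variables are currently exported below;
# add them here if the cluster requires them.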


# Ask Hydra to print full stack traces on configuration errors.
export HYDRA_FULL_ERROR=1
# Default to GPUs 0-7 unless the caller already set a non-empty device list.
: "${CUDA_VISIBLE_DEVICES:=0,1,2,3,4,5,6,7}"
export CUDA_VISIBLE_DEVICES

# Data paths
TRAIN_FILE="/primus_datasets/primus_data/clpo_SKYRTP/DAPO-Math-17k/data/dapo-math-17k.parquet"
VAL_FILE="/primus_datasets/primus_data/aime_2B4pCq/train-00000-of-00001-fixed.parquet"

# Model path
# No trailing whitespace: this value is passed verbatim to
# actor_rollout_ref.model.path, and a stray space would become part of the path.
MODEL_PATH="/primus_datasets/primus_data/Qwen3_rNrLUi/Qwen3-8B"

# Output directory
CKPTS_DIR="/primus_oss/_checkpoint/0918-Qwen3-8B-AIME2024/stastic_kl_from_scratch"

# Data configuration
# Length limits, forwarded as data.max_prompt_length / data.max_response_length.
max_prompt_length=1024
max_response_length=12000
# Batch sizes, forwarded as data.train_batch_size / data.val_batch_size.
train_batch_size=64 val_batch_size=32
# Overlong-prompt handling: filtering is on and truncation mode is "error"
# (presumably any prompt still over the limit raises — confirm in verl).
filter_overlong_prompts=true truncation='error'
dataloader_num_workers=4

# Algorithm configuration
# Advantage estimator, forwarded as algorithm.adv_estimator.
adv_estimator="grpo"
# KL-in-reward is disabled; the per-difficulty reward-KL coefficients are
# still forwarded (presumably inert while use_kl_in_reward is false).
use_kl_in_reward="false"
kl_in_reward_coef_hard=1.0 kl_in_reward_coef_nonhard=1.0

# CLPO specific configuration
# Accuracy thresholds for CLPO difficulty buckets, forwarded as
# data.clpo_hard_acc_upper / data.clpo_medium_acc_{lower,upper}.
# NOTE(review): hard upper == medium lower (0.3); which bucket gets the
# boundary value is decided by the trainer, not here.
clpo_hard_acc_upper=0.3
clpo_med_acc_lower=0.3 clpo_med_acc_upper=0.7
# KL-loss coefficient scales for hard / non-hard samples, forwarded to
# actor_rollout_ref.actor.kl_loss_coef_{hard,nonhard}_scale.
kl_loss_coef_hard_scale=1.0 kl_loss_coef_nonhard_scale=1.0

# CLPO rewrite data saving configuration
clpo_save_rewrite_data=true
# The rewrite dumps live inside the checkpoint directory. Deriving them from
# CKPTS_DIR keeps them in sync if the output directory is ever changed
# (previously the directory prefix was duplicated literally on each line).
clpo_rewrite_save_path="${CKPTS_DIR}/rewrite_data.json"
clpo_hard_rewrite_save_path="${CKPTS_DIR}/hard_rewrite_data.json"
clpo_medium_rewrite_save_path="${CKPTS_DIR}/medium_rewrite_data.json"

# Model configuration
# Model options, forwarded to actor_rollout_ref.model.*.
enable_gradient_checkpointing=true use_remove_padding=true

# Actor configuration
# Optimizer settings, forwarded to actor_rollout_ref.actor.optim.*.
actor_lr=1e-6 actor_lr_warmup_steps=10 warmup_style=constant
# PPO batching.
ppo_mini_batch_size=8 ppo_micro_batch_size_per_gpu=1
# KL loss on the actor with its base coefficient; entropy bonus disabled.
use_kl_loss=true kl_loss_coef=0.001
entropy_coeff=0
# FSDP offloading; param_offload is also reused for the reference model.
param_offload=false optimizer_offload=false

# Rollout configuration
# Rollout engine and training-time sampling settings.
rollout_name=vllm
n_resp_per_prompt=4
tensor_model_parallel_size=1
gpu_memory_utilization=0.5
log_prob_micro_batch_size_per_gpu=1
# Context budget is prompt + response; the batched-token cap uses the same value.
max_model_len=$(( max_prompt_length + max_response_length ))
max_num_batched_tokens=${max_model_len}
# Validation sampling, forwarded to actor_rollout_ref.rollout.val_kwargs.*.
val_rollout_n=32 val_do_sample=true
val_temperature=1.0 val_top_k=-1 val_top_p=0.7


# Trainer configuration
# Trainer schedule, forwarded to trainer.*.
total_epochs=1 critic_warmup=0
test_freq=1 save_freq=10 val_before_train=true
# Cluster shape; each defaults only when unset or empty in the environment.
# NOTE(review): keep NGPUS_PER_NODE consistent with CUDA_VISIBLE_DEVICES.
: "${NGPUS_PER_NODE:=8}"
: "${NNODES:=1}"

# Print the effective configuration, one line per setting, before launching.
printf '%s\n' \
    "==== DCKL Qwen3-8B STASTIC KL TRAINING CONFIGURATION ====" \
    "Project: $project_name" \
    "Experiment: $exp_name" \
    "Train Data: $TRAIN_FILE" \
    "Val Data: $VAL_FILE" \
    "Model Path: $MODEL_PATH" \
    "Output Dir: $CKPTS_DIR" \
    "CLPO Hard Acc Upper: $clpo_hard_acc_upper" \
    "CLPO Med Acc Lower: $clpo_med_acc_lower" \
    "CLPO Med Acc Upper: $clpo_med_acc_upper" \
    "CLPO Save Rewrite Data: $clpo_save_rewrite_data" \
    "CLPO Rewrite Save Path: $clpo_rewrite_save_path" \
    "CLPO Hard Rewrite Save Path: $clpo_hard_rewrite_save_path" \
    "CLPO Medium Rewrite Save Path: $clpo_medium_rewrite_save_path"

# Launch verl's PPO entry point with every setting above as a Hydra override;
# any extra command-line arguments to this script are forwarded after them.
#
# The exit status is captured with `|| TRAINING_EXIT_CODE=$?`. Under `set -e`
# a bare failing command would abort the script on the spot, so neither the
# status assignment nor the failure report below would ever run.
TRAINING_EXIT_CODE=0
python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator="${adv_estimator}" \
    algorithm.norm_adv_by_std_in_grpo=true \
    algorithm.use_kl_in_reward="${use_kl_in_reward}" \
    algorithm.kl_in_reward_coef_hard="${kl_in_reward_coef_hard}" \
    algorithm.kl_in_reward_coef_nonhard="${kl_in_reward_coef_nonhard}" \
    data.train_files="${TRAIN_FILE}" \
    data.val_files="${VAL_FILE}" \
    data.train_batch_size="${train_batch_size}" \
    data.val_batch_size="${val_batch_size}" \
    data.max_prompt_length="${max_prompt_length}" \
    data.max_response_length="${max_response_length}" \
    data.filter_overlong_prompts="${filter_overlong_prompts}" \
    data.truncation="${truncation}" \
    data.dataloader_num_workers="${dataloader_num_workers}" \
    data.clpo_hard_acc_upper="${clpo_hard_acc_upper}" \
    data.clpo_medium_acc_lower="${clpo_med_acc_lower}" \
    data.clpo_medium_acc_upper="${clpo_med_acc_upper}" \
    data.clpo_save_rewrite_data="${clpo_save_rewrite_data}" \
    data.clpo_rewrite_save_path="${clpo_rewrite_save_path}" \
    data.clpo_hard_rewrite_save_path="${clpo_hard_rewrite_save_path}" \
    data.clpo_medium_rewrite_save_path="${clpo_medium_rewrite_save_path}" \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.model.enable_gradient_checkpointing="${enable_gradient_checkpointing}" \
    actor_rollout_ref.model.use_remove_padding="${use_remove_padding}" \
    actor_rollout_ref.actor.optim.lr="${actor_lr}" \
    actor_rollout_ref.actor.optim.warmup_style="${warmup_style}" \
    actor_rollout_ref.actor.optim.lr_warmup_steps="${actor_lr_warmup_steps}" \
    actor_rollout_ref.actor.ppo_mini_batch_size="${ppo_mini_batch_size}" \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu="${ppo_micro_batch_size_per_gpu}" \
    actor_rollout_ref.actor.use_kl_loss="${use_kl_loss}" \
    actor_rollout_ref.actor.kl_loss_coef="${kl_loss_coef}" \
    actor_rollout_ref.actor.kl_loss_coef_hard_scale="${kl_loss_coef_hard_scale}" \
    actor_rollout_ref.actor.kl_loss_coef_nonhard_scale="${kl_loss_coef_nonhard_scale}" \
    actor_rollout_ref.actor.entropy_coeff="${entropy_coeff}" \
    actor_rollout_ref.actor.fsdp_config.param_offload="${param_offload}" \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload="${optimizer_offload}" \
    actor_rollout_ref.rollout.name="${rollout_name}" \
    actor_rollout_ref.rollout.n="${n_resp_per_prompt}" \
    actor_rollout_ref.rollout.tensor_model_parallel_size="${tensor_model_parallel_size}" \
    actor_rollout_ref.rollout.gpu_memory_utilization="${gpu_memory_utilization}" \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu="${log_prob_micro_batch_size_per_gpu}" \
    actor_rollout_ref.rollout.max_model_len="${max_model_len}" \
    actor_rollout_ref.rollout.max_num_batched_tokens="${max_num_batched_tokens}" \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu="${log_prob_micro_batch_size_per_gpu}" \
    actor_rollout_ref.ref.fsdp_config.param_offload="${param_offload}" \
    actor_rollout_ref.rollout.val_kwargs.n="${val_rollout_n}" \
    actor_rollout_ref.rollout.val_kwargs.do_sample="${val_do_sample}" \
    actor_rollout_ref.rollout.val_kwargs.temperature="${val_temperature}" \
    actor_rollout_ref.rollout.val_kwargs.top_k="${val_top_k}" \
    actor_rollout_ref.rollout.val_kwargs.top_p="${val_top_p}" \
    trainer.logger='["console", "swanlab"]' \
    trainer.project_name="${project_name}" \
    trainer.experiment_name="${exp_name}" \
    trainer.total_epochs="${total_epochs}" \
    trainer.critic_warmup="${critic_warmup}" \
    trainer.test_freq="${test_freq}" \
    trainer.save_freq="${save_freq}" \
    trainer.val_before_train="${val_before_train}" \
    trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
    trainer.nnodes="${NNODES}" \
    trainer.default_local_dir="${CKPTS_DIR}" \
    trainer.task=clpo \
    "$@" || TRAINING_EXIT_CODE=$?

echo ""
echo "🏁 ===== TRAINING COMPLETED ====="

if [[ "${TRAINING_EXIT_CODE}" -eq 0 ]]; then
    echo "✅ DCKL Qwen3-8B STASTIC KL training completed successfully!"
    echo "📁 Checkpoints saved to: ${CKPTS_DIR}"
    echo "📊 Experiment name: ${exp_name}"

    if [[ -d "${CKPTS_DIR}" ]]; then
        echo ""
        echo "📋 Output directory contents:"
        # Informational only; a failed listing must not change the exit status.
        ls -la "${CKPTS_DIR}" || true
    fi

    echo ""
    echo "🎉 DCKL Qwen3-8B STASTIC KL training finished successfully!"
else
    # Failure diagnostics go to stderr.
    {
        echo "❌ Training failed with exit code: ${TRAINING_EXIT_CODE}"
        echo "💡 Please check the logs above for error details"
        echo "🔧 Common issues:"
        echo "   - Check if data files exist and are readable"
        echo "   - Verify GPU memory is sufficient"
        echo "   - Ensure all dependencies are installed"
        echo "   - Check CUDA_VISIBLE_DEVICES setting"
    } >&2
fi

# Propagate the trainer's status to the caller (e.g. a job scheduler).
exit "${TRAINING_EXIT_CODE}"
