#!/usr/bin/env bash
#
# Launches a single-machine DAPO baseline training run through
# `python3 -m recipe.dapo.main_dapo`. Any command-line arguments given to this
# script are forwarded to the trainer after the built-in overrides.
#
# Environment:
#   SWANLAB_API_KEY       SwanLab API key; the caller's value is used when set.
#   CUDA_VISIBLE_DEVICES  GPUs to expose; defaults to 0-7 when unset.
set -euo pipefail

# SECURITY NOTE(review): a credential is committed here as the fallback value.
# Prefer exporting SWANLAB_API_KEY before invoking the script; the committed key
# should be rotated and this default removed once all callers supply their own.
export SWANLAB_API_KEY="${SWANLAB_API_KEY:-kUaCfvi1P5G0e6aM7HgJ0}"
export NCCL_IB_DISABLE=0      # enable the IB (InfiniBand) network


echo "🚀 ===== DAPO Single Machine Training (Baseline) ====="

# Basic host information
echo "==== SYSTEM INFO ===="
echo "HOSTNAME=$(hostname)"
echo "CURRENT_DIR=$(pwd)"
echo "USER=$(whoami)"

# GPU information check

#######################################
# Prints the number of lines `nvidia-smi -L` emits, or 1 when that number is
# not a positive integer (nvidia-smi failed, is missing, or listed nothing).
# Outputs:   a single positive integer on stdout
# Returns:   0 always
#######################################
count_gpus() {
    local count
    # Under `pipefail` a failing nvidia-smi fails the whole pipeline even though
    # `wc -l` still prints 0. Assign the fallback here instead of echoing a
    # second value inside the substitution, which would yield "0<newline>1".
    count=$(nvidia-smi -L 2>/dev/null | wc -l) || count=0
    count=${count//[[:space:]]/}   # some wc implementations pad the number
    if [[ "$count" =~ ^[1-9][0-9]*$ ]]; then
        printf '%s\n' "$count"
    else
        printf '1\n'
    fi
}

echo "==== GPU INFO ===="
if command -v nvidia-smi >/dev/null 2>&1; then
    nvidia-smi -L || echo "⚠️  nvidia-smi failed"
    NGPUS=$(count_gpus)
    echo "Available GPUs: $NGPUS"
else
    echo "⚠️  nvidia-smi not found, assuming 1 GPU"
    NGPUS=1
fi

# Dataset file paths
TRAIN_DATA="/primus_datasets/primus_data/iclr_gsm8k_Intity/train.parquet"
VAL_DATA="/primus_datasets/primus_data/iclr_gsm8k_Intity/test.parquet"

printf '%s\n' \
    "==== DATA FILES ====" \
    "TRAIN_DATA: ${TRAIN_DATA}" \
    "VAL_DATA: ${VAL_DATA}"

# Environment variables for the trainer
export HYDRA_FULL_ERROR=1
# Keep the caller's device list when one is set; otherwise expose devices 0-7.
: "${CUDA_VISIBLE_DEVICES:=0,1,2,3,4,5,6,7}"
export CUDA_VISIBLE_DEVICES

printf '%s\n' \
    "==== TRAINING CONFIGURATION ====" \
    "CUDA_VISIBLE_DEVICES: ${CUDA_VISIBLE_DEVICES}" \
    "Number of GPUs to use: ${NGPUS}" \
    "Output directory: /primus_oss/_checkpoint/0903-Qwen3-0.6B-DAPO-Baseline"

# Timestamp suffix for the experiment name
TIMESTAMP="$(date '+%Y%m%d_%H%M%S')"
EXPERIMENT_NAME="gsm8k_dapo_baseline_${TIMESTAMP}"

printf '%s\n' "Experiment name: ${EXPERIMENT_NAME}"
printf '\n'

# Launch training
echo "🎯 ===== STARTING DAPO BASELINE TRAINING ====="
echo "Command will be:"
echo "python3 -m recipe.dapo.main_dapo \\"

#######################################
# Runs the DAPO trainer with this experiment's baseline overrides.
# Globals:   TRAIN_DATA, VAL_DATA, EXPERIMENT_NAME, NGPUS (read)
# Arguments: extra trainer overrides, appended after the built-in ones
# Returns:   the exit status of the python3 process
#######################################
run_dapo_training() {
    # Use the DAPO-specific configuration
    python3 -m recipe.dapo.main_dapo \
        algorithm.adv_estimator=grpo \
        algorithm.use_kl_in_reward=false \
        data.train_files="$TRAIN_DATA" \
        data.val_files="$VAL_DATA" \
        data.train_batch_size=64 \
        data.val_batch_size=32 \
        data.max_prompt_length=512 \
        data.max_response_length=1024 \
        data.filter_overlong_prompts=true \
        data.truncation=error \
        data.dataloader_num_workers=0 \
        actor_rollout_ref.model.path=/primus_datasets/primus_data/Qwen3_06B_RlhksV \
        actor_rollout_ref.model.enable_gradient_checkpointing=true \
        actor_rollout_ref.model.use_remove_padding=true \
        actor_rollout_ref.actor.optim.lr=1e-6 \
        actor_rollout_ref.actor.ppo_mini_batch_size=32 \
        actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
        actor_rollout_ref.actor.use_kl_loss=false \
        actor_rollout_ref.actor.kl_loss_coef=0.0 \
        actor_rollout_ref.actor.clip_ratio_low=0.2 \
        actor_rollout_ref.actor.clip_ratio_high=0.28 \
        actor_rollout_ref.actor.clip_ratio_c=10.0 \
        actor_rollout_ref.actor.entropy_coeff=0 \
        actor_rollout_ref.actor.fsdp_config.param_offload=false \
        actor_rollout_ref.actor.fsdp_config.optimizer_offload=false \
        actor_rollout_ref.rollout.name=vllm \
        actor_rollout_ref.rollout.n=4 \
        actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
        actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
        actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
        actor_rollout_ref.rollout.max_model_len=1536 \
        actor_rollout_ref.rollout.max_num_batched_tokens=1536 \
        actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
        actor_rollout_ref.ref.fsdp_config.param_offload=false \
        reward_model.reward_manager=dapo \
        trainer.logger='["console", "swanlab"]' \
        trainer.project_name=CLPO-GSM8K \
        trainer.experiment_name="$EXPERIMENT_NAME" \
        trainer.total_epochs=3 \
        trainer.critic_warmup=0 \
        trainer.test_freq=10 \
        trainer.save_freq=50 \
        trainer.val_before_train=true \
        trainer.n_gpus_per_node="$NGPUS" \
        trainer.nnodes=1 \
        trainer.default_local_dir=/primus_oss/_checkpoint/0903-Qwen3-0.6B-DAPO-Baseline \
        "$@"
}

# Capture the status through `||`: with `set -e` active, a bare failing command
# would terminate the script before `$?` could be read, and the failure report
# further down would never be printed.
TRAINING_EXIT_CODE=0
run_dapo_training "$@" || TRAINING_EXIT_CODE=$?

printf '\n'
printf '%s\n' "🏁 ===== TRAINING COMPLETED ====="

# Summarize the run based on the trainer's exit status, then exit with it.
checkpoint_dir="/primus_oss/_checkpoint/0903-Qwen3-0.6B-DAPO-Baseline"

if [[ "$TRAINING_EXIT_CODE" -ne 0 ]]; then
    # Failure: show the status and a short troubleshooting checklist.
    printf '%s\n' \
        "❌ Training failed with exit code: ${TRAINING_EXIT_CODE}" \
        "💡 Please check the logs above for error details" \
        "🔧 Common issues:" \
        "   - Check if data files exist and are readable" \
        "   - Verify GPU memory is sufficient" \
        "   - Ensure all dependencies are installed" \
        "   - Check CUDA_VISIBLE_DEVICES setting"
else
    printf '%s\n' \
        "✅ DAPO Baseline training completed successfully!" \
        "📁 Checkpoints saved to: ${checkpoint_dir}" \
        "📊 Experiment name: ${EXPERIMENT_NAME}"

    # Listing the directory is best-effort; a failing `ls` must not change the
    # script's outcome.
    if [[ -d "${checkpoint_dir}" ]]; then
        printf '\n'
        printf '%s\n' "📋 Output directory contents:"
        ls -la "${checkpoint_dir}" || true
    fi

    printf '\n'
    printf '%s\n' "🎉 DAPO baseline training finished successfully!"
fi

exit "$TRAINING_EXIT_CODE"
