#!/usr/bin/env bash
set -euo pipefail

# =============================================================================
# CLPO Ablation Study Script for GSM8K
# =============================================================================
# 
# CLPO (Curriculum Learning Policy Optimization) supports ablation studies:
# 
# Usage examples:
#   # Full CLPO model (default): all components enabled
#   bash run_clpo_gsm8k_single.sh
#   
#   # CLPO w/o Curriculum Learning
#   ENABLE_CL=false bash run_clpo_gsm8k_single.sh
#   
#   # CLPO w/o Query Rewriting
#   ENABLE_QR=false bash run_clpo_gsm8k_single.sh
#   
#   # CLPO w/o Dynamic Rewards
#   ENABLE_DR=false bash run_clpo_gsm8k_single.sh
#   
#   # CLPO w/o Curriculum Learning & Query Rewriting
#   ENABLE_CL=false ENABLE_QR=false bash run_clpo_gsm8k_single.sh
#
# Components:
#   - CL: Curriculum Learning
#   - QR: Query Rewriting  
#   - DR: Dynamic Rewards
# =============================================================================

# Placeholder: enable the InfiniBand (IB) network. No IB settings are configured here at present.


echo "🚀 ===== CLPO Single Machine Training ====="

# Basic host information, handy when reading the log afterwards.
echo "==== SYSTEM INFO ===="
echo "HOSTNAME=$(hostname)"
echo "CURRENT_DIR=$(pwd)"
echo "USER=$(whoami)"

#######################################
# Detect how many GPUs `nvidia-smi -L` reports and store the result in NGPUS.
#
# Only lines beginning with "GPU " are counted, so indented MIG sub-device
# lines or messages such as "No devices found" do not inflate the count.
# Whenever nvidia-smi is missing, fails, or lists no GPU, NGPUS falls back
# to 1 so that it is always a single positive integer.
#
# Globals:   NGPUS (written)
# Arguments: none
# Outputs:   GPU listing and status messages on stdout
# Returns:   0 always
#######################################
detect_gpus() {
    local gpu_listing gpu_count

    if ! command -v nvidia-smi >/dev/null 2>&1; then
        echo "⚠️  nvidia-smi not found, assuming 1 GPU"
        NGPUS=1
        return 0
    fi

    NGPUS=1
    # Run nvidia-smi once; testing its status inside `if` keeps `set -e`
    # from aborting the script when the driver is unavailable.
    if gpu_listing=$(nvidia-smi -L); then
        if [[ -n "$gpu_listing" ]]; then
            printf '%s\n' "$gpu_listing"
        fi
        # grep -c exits 1 when it counts zero lines; `|| true` keeps that
        # from tripping errexit/pipefail while still yielding "0".
        gpu_count=$(grep -c '^GPU ' <<<"$gpu_listing" || true)
        if (( gpu_count > 0 )); then
            NGPUS=$gpu_count
        else
            echo "⚠️  nvidia-smi listed no GPUs, assuming 1 GPU"
        fi
    else
        echo "⚠️  nvidia-smi failed"
    fi
    echo "Available GPUs: $NGPUS"
}

# GPU information check.
# NOTE(review): the count comes from nvidia-smi and does not take
# CUDA_VISIBLE_DEVICES into account — confirm this is intended when only a
# subset of devices is exposed.
echo "==== GPU INFO ===="
detect_gpus

# Dataset locations (parquet files).
TRAIN_DATA="/primus_datasets/primus_data/iclr_gsm8k_Intity/train.parquet"
VAL_DATA="/primus_datasets/primus_data/iclr_gsm8k_Intity/test.parquet"

echo "==== DATA FILES ===="
echo "TRAIN_DATA: $TRAIN_DATA"
echo "VAL_DATA: $VAL_DATA"


# Environment variables for the training process.
export HYDRA_FULL_ERROR=1
export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-"0,1,2,3,4,5,6,7"}

# Ablation switches. The defaults are resolved BEFORE the first use: the
# script runs under `set -u`, so referencing an unset ENABLE_* variable
# would abort with "unbound variable".
ENABLE_CL=${ENABLE_CL:-true}           # Curriculum Learning
ENABLE_QR=${ENABLE_QR:-true}           # Query Rewriting
ENABLE_DR=${ENABLE_DR:-true}           # Dynamic Rewards

# The experiment name is derived from the literal strings "true"/"false";
# warn (without failing) about any other spelling so a mislabelled run is
# noticed.
for flag_name in ENABLE_CL ENABLE_QR ENABLE_DR; do
    case "${!flag_name}" in
        true|false) ;;
        *)
            echo "⚠️  ${flag_name}='${!flag_name}' is neither 'true' nor 'false'; the experiment name will not reflect it" >&2
            ;;
    esac
done

echo "==== TRAINING CONFIGURATION ===="
echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"
echo "Number of GPUs to use: $NGPUS"
echo "Output directory: /primus_oss/_checkpoint/0903-Qwen3-0.6B-CLPO"
echo ""
echo "==== CLPO COMPONENTS ===="
echo "Curriculum Learning (CL): $ENABLE_CL"
echo "Query Rewriting (QR): $ENABLE_QR"
echo "Dynamic Rewards (DR): $ENABLE_DR"

#######################################
# Build the experiment name for a given component configuration.
#
# With all three components "true" the name is "gsm8k_CLPO_<stamp>".
# Otherwise it is "gsm8k-CLPO w/o <parts>_<stamp>", where <parts> lists the
# components whose flag is exactly "false", in the order CL, QR, DR, joined
# by "&" (e.g. "CL&QR").
#
# Arguments: $1 - CL flag, $2 - QR flag, $3 - DR flag, $4 - timestamp
# Outputs:   the experiment name on stdout
# Returns:   0
#######################################
build_experiment_name() {
    local cl=$1 qr=$2 dr=$3 stamp=$4
    local removed=""

    if [[ "$cl" == "true" && "$qr" == "true" && "$dr" == "true" ]]; then
        # Full CLPO model.
        printf '%s\n' "gsm8k_CLPO_${stamp}"
        return 0
    fi

    # Ablation run: "${removed:+${removed}&}" inserts the "&" separator only
    # when something has already been collected.
    if [[ "$cl" == "false" ]]; then removed="${removed:+${removed}&}CL"; fi
    if [[ "$qr" == "false" ]]; then removed="${removed:+${removed}&}QR"; fi
    if [[ "$dr" == "false" ]]; then removed="${removed:+${removed}&}DR"; fi

    printf '%s\n' "gsm8k-CLPO w/o ${removed}_${stamp}"
}

# Timestamp that makes each experiment name unique.
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
EXPERIMENT_NAME=$(build_experiment_name "$ENABLE_CL" "$ENABLE_QR" "$ENABLE_DR" "$TIMESTAMP")

echo "Experiment name: $EXPERIMENT_NAME"
echo ""

# Launch training.
echo "🎯 ===== STARTING CLPO TRAINING ====="

# Directory passed to trainer.default_local_dir and inspected after the run.
CHECKPOINT_DIR="/primus_oss/_checkpoint/0903-Qwen3-0.6B-CLPO"

# The command is assembled as an array so that it can be both printed and
# executed from a single definition, and so that any extra command-line
# arguments given to this script ("$@") are appended as additional overrides.
TRAIN_CMD=(
    python3 -m recipe.clpo.main_clpo
    clpo.enable_curriculum_learning="$ENABLE_CL"
    clpo.enable_query_rewriting="$ENABLE_QR"
    clpo.enable_dynamic_rewards="$ENABLE_DR"
    algorithm.adv_estimator=grpo
    algorithm.use_kl_in_reward=false
    data.train_files="$TRAIN_DATA"
    data.val_files="$VAL_DATA"
    data.train_batch_size=64
    data.val_batch_size=32
    data.max_prompt_length=512
    data.max_response_length=1024
    data.filter_overlong_prompts=true
    data.truncation=error
    data.dataloader_num_workers=0
    actor_rollout_ref.model.path=/primus_datasets/primus_data/Qwen3_06B_RlhksV
    actor_rollout_ref.model.enable_gradient_checkpointing=true
    actor_rollout_ref.model.use_remove_padding=true
    actor_rollout_ref.actor.optim.lr=1e-6
    actor_rollout_ref.actor.ppo_mini_batch_size=32
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8
    actor_rollout_ref.actor.use_kl_loss=false
    actor_rollout_ref.actor.kl_loss_coef=0.001
    actor_rollout_ref.actor.kl_loss_type=low_var_kl
    actor_rollout_ref.actor.entropy_coeff=0
    actor_rollout_ref.actor.fsdp_config.param_offload=false
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=false
    actor_rollout_ref.rollout.name=vllm
    actor_rollout_ref.rollout.n=4
    actor_rollout_ref.rollout.tensor_model_parallel_size=1
    actor_rollout_ref.rollout.gpu_memory_utilization=0.7
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8
    actor_rollout_ref.rollout.max_model_len=1536
    actor_rollout_ref.rollout.max_num_batched_tokens=1536
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8
    actor_rollout_ref.ref.fsdp_config.param_offload=false
    trainer.logger='["console", "swanlab"]'
    trainer.project_name=CLPO-GSM8K
    # The inner single quotes are part of the value: the experiment name may
    # contain spaces, "/" and "&".
    trainer.experiment_name="'$EXPERIMENT_NAME'"
    trainer.total_epochs=1
    trainer.critic_warmup=0
    trainer.test_freq=10
    trainer.save_freq=50
    trainer.val_before_train=true
    trainer.n_gpus_per_node="$NGPUS"
    trainer.nnodes=1
    trainer.default_local_dir="$CHECKPOINT_DIR"
    "$@"
)

# Print the exact command, shell-quoted so it can be copy-pasted.
echo "Command will be:"
printf '%q ' "${TRAIN_CMD[@]}"
printf '\n'

# Capture the exit status without letting `set -e` abort the script, so the
# failure report below is actually reachable.
TRAINING_EXIT_CODE=0
"${TRAIN_CMD[@]}" || TRAINING_EXIT_CODE=$?

echo ""
echo "🏁 ===== TRAINING COMPLETED ====="

if [[ "$TRAINING_EXIT_CODE" -eq 0 ]]; then
    echo "✅ Training completed successfully!"
    echo "📁 Checkpoints saved to: $CHECKPOINT_DIR"
    echo "📊 Experiment name: $EXPERIMENT_NAME"

    if [[ -d "$CHECKPOINT_DIR" ]]; then
        echo ""
        echo "📋 Output directory contents:"
        # Listing is informational only; never fail the run because of it.
        ls -la "$CHECKPOINT_DIR" || true
    fi

    echo ""
    echo "🎉 CLPO single machine training finished successfully!"
else
    echo "❌ Training failed with exit code: $TRAINING_EXIT_CODE"
    echo "💡 Please check the logs above for error details"
    echo "🔧 Common issues:"
    echo "   - Check if data files exist and are readable"
    echo "   - Verify GPU memory is sufficient"
    echo "   - Ensure all dependencies are installed"
    echo "   - Check CUDA_VISIBLE_DEVICES setting"
fi

# Propagate the trainer's exit status to the caller.
exit "$TRAINING_EXIT_CODE"
