#!/usr/bin/env bash
set -euo pipefail

# =============================================================================
# CLPO V2 Training Script for GSM8K
# =============================================================================
# 
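# CLPO V2 (Curriculum Learning Policy Optimization V2) - refactored version
#
# Main improvements:
#   1. No custom sampler required - handled directly inside the trainer
#   2. Dynamic difficulty classification - adjusted from rewards observed
#      during training
#   3. Smart data buffers - five buffer categories are managed
#   4. Filter-and-replace strategy - Easy<=10%, Medium~=60%, Hard+Rewritten~=30%
#
# Usage:
#   # Standard CLPO V2 (default: curriculum learning on, query rewriting off)
#   bash run_clpo_v2_gsm8k_single.sh
#
#   # Full CLPO V2 (curriculum learning + query rewriting)
#   ENABLE_QR=true bash run_clpo_v2_gsm8k_single.sh
#
#   # Ablation experiments
#   ENABLE_CL=false bash run_clpo_v2_gsm8k_single.sh  # disable curriculum learning
#   ENABLE_QR=false bash run_clpo_v2_gsm8k_single.sh  # disable query rewriting (already the default)
#
#   # Any extra command-line arguments are forwarded verbatim to
#   # recipe.clpo.main_clpo_v2 after the built-in overrides.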
# =============================================================================

# Enable the IB network (placeholder: no IB-related settings follow this comment)


echo "🚀 ===== CLPO V2 Single Machine Training ====="

# 基本信息
echo "==== SYSTEM INFO ===="
echo "HOSTNAME=$(hostname)"
echo "CURRENT_DIR=$(pwd)"
echo "USER=$(whoami)"

# GPU信息检查
echo "==== GPU INFO ===="
if command -v nvidia-smi >/dev/null 2>&1; then
    nvidia-smi -L || echo "⚠️  nvidia-smi failed"
    NGPUS=$(nvidia-smi -L 2>/dev/null | wc -l || echo 1)
    echo "Available GPUs: $NGPUS"
else
    echo "⚠️  nvidia-smi not found, assuming 1 GPU"
    NGPUS=1
fi

# 数据文件路径
TRAIN_DATA="/primus_datasets/primus_data/iclr_gsm8k_Intity/train.parquet"
VAL_DATA="/primus_datasets/primus_data/iclr_gsm8k_Intity/test.parquet"

echo "==== DATA FILES ===="
echo "TRAIN_DATA: $TRAIN_DATA"
echo "VAL_DATA: $VAL_DATA"

# 设置环境变量
export HYDRA_FULL_ERROR=1
export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-"0,1,2,3,4,5,6,7"}

echo "==== TRAINING CONFIGURATION ===="
echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES"
echo "Number of GPUs to use: $NGPUS"
echo "Output directory: /primus_oss/_checkpoint/0903-Qwen3-0.6B-CLPO-V2"
echo ""

# 生成时间戳用于实验名称
TIMESTAMP=$(date +%Y%m%d_%H%M%S)

# 根据配置生成消融实验名称
ENABLE_CL=${ENABLE_CL:-true}           # Curriculum Learning
ENABLE_QR=${ENABLE_QR:-false}          # Query Rewriting (V2默认关闭)
ENABLE_DR=${ENABLE_DR:-false}          # Dynamic Rewards (V2暂未实现)

echo "==== CLPO V2 COMPONENTS ===="
echo "Curriculum Learning (CL): $ENABLE_CL"
echo "Query Rewriting (QR): $ENABLE_QR"
echo "Dynamic Rewards (DR): $ENABLE_DR"

# 构建消融实验名称
if [ "$ENABLE_CL" = "true" ] && [ "$ENABLE_QR" = "false" ] && [ "$ENABLE_DR" = "false" ]; then
    # 标准CLPO V2配置
    EXPERIMENT_NAME="gsm8k_CLPO_V2_${TIMESTAMP}"
elif [ "$ENABLE_CL" = "true" ] && [ "$ENABLE_QR" = "true" ] && [ "$ENABLE_DR" = "false" ]; then
    # 完整CLPO V2配置
    EXPERIMENT_NAME="gsm8k_CLPO_V2_Full_${TIMESTAMP}"
else
    # 消融实验
    ABLATION_PARTS=""
    [ "$ENABLE_CL" = "false" ] && ABLATION_PARTS="${ABLATION_PARTS}CL"
    [ "$ENABLE_QR" = "true" ] && ABLATION_PARTS="${ABLATION_PARTS}QR"  # V2中QR是额外功能
    [ "$ENABLE_DR" = "true" ] && ABLATION_PARTS="${ABLATION_PARTS}DR"
    
    if [ ${#ABLATION_PARTS} -gt 2 ]; then
        # 多个组件，用&连接
        FORMATTED_ABLATION=$(echo "$ABLATION_PARTS" | sed 's/\(..\)/\1\&/g' | sed 's/&$//')
    else
        FORMATTED_ABLATION="$ABLATION_PARTS"
    fi
    
    if [ "$ENABLE_CL" = "false" ]; then
        EXPERIMENT_NAME="gsm8k_CLPO_V2_wo_${FORMATTED_ABLATION}_${TIMESTAMP}"
    else
        EXPERIMENT_NAME="gsm8k_CLPO_V2_with_${FORMATTED_ABLATION}_${TIMESTAMP}"
    fi
fi

echo "Experiment name: $EXPERIMENT_NAME"
echo ""

# 启动训练
echo "🎯 ===== STARTING CLPO V2 TRAINING ====="
echo "Command will be:"
echo "python3 -m recipe.clpo.main_clpo_v2 \\"


python3 -m recipe.clpo.main_clpo_v2 \
    clpo.enable_curriculum_learning="$ENABLE_CL" \
    clpo.enable_query_rewriting="$ENABLE_QR" \
    clpo.enable_dynamic_rewards="$ENABLE_DR" \
    clpo.easy_threshold=0.7 \
    clpo.hard_threshold=0.3 \
    clpo.target_easy_ratio=0.1 \
    clpo.target_medium_ratio=0.6 \
    clpo.buffer_size_per_category=1000 \
    algorithm.adv_estimator=grpo \
    algorithm.norm_adv_by_std_in_grpo=true \
    algorithm.use_kl_in_reward=false \
    data.train_files="$TRAIN_DATA" \
    data.val_files="$VAL_DATA" \
    data.train_batch_size=64 \
    data.val_batch_size=32 \
    data.gen_batch_size=64 \
    data.max_prompt_length=512 \
    data.max_response_length=1024 \
    data.filter_overlong_prompts=true \
    data.truncation=error \
    data.dataloader_num_workers=4 \
    data.shuffle=true \
    actor_rollout_ref.model.path=/primus_datasets/primus_data/Qwen3_06B_RlhksV \
    actor_rollout_ref.model.enable_gradient_checkpointing=true \
    actor_rollout_ref.model.use_remove_padding=true \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.ppo_mini_batch_size=32 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
    actor_rollout_ref.actor.use_kl_loss=false \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.fsdp_config.param_offload=false \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=false \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.n=4 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
    actor_rollout_ref.rollout.max_model_len=1536 \
    actor_rollout_ref.rollout.max_num_batched_tokens=1536 \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \
    actor_rollout_ref.ref.fsdp_config.param_offload=false \
    trainer.logger='["console", "swanlab"]' \
    trainer.project_name=CLPO-GSM8K \
    trainer.experiment_name="'$EXPERIMENT_NAME'" \
    trainer.total_epochs=3 \
    trainer.critic_warmup=0 \
    trainer.test_freq=10 \
    trainer.save_freq=50 \
    trainer.val_before_train=true \
    trainer.n_gpus_per_node="$NGPUS" \
    trainer.nnodes=1 \
    trainer.default_local_dir=/primus_oss/_checkpoint/0903-Qwen3-0.6B-CLPO-V2 \
    critic.enable=false \
    reward_model.enable=false \
    reward_model.reward_manager=naive \
    "$@"

TRAINING_EXIT_CODE=$?

echo ""
echo "🏁 ===== TRAINING COMPLETED ====="

if [ $TRAINING_EXIT_CODE -eq 0 ]; then
    echo "✅ CLPO V2 training completed successfully!"
    echo "📁 Checkpoints saved to: /primus_oss/_checkpoint/0903-Qwen3-0.6B-CLPO-V2"
    echo "📊 Experiment name: $EXPERIMENT_NAME"
    
    if [ -d "/primus_oss/_checkpoint/0903-Qwen3-0.6B-CLPO-V2" ]; then
        echo ""
        echo "📋 Output directory contents:"
        ls -la "/primus_oss/_checkpoint/0903-Qwen3-0.6B-CLPO-V2" || true
    fi
    
    echo ""
    echo "🎉 CLPO V2 single machine training finished successfully!"
    echo "📈 Key improvements over V1:"
    echo "   - No custom sampler required"
    echo "   - Dynamic difficulty classification" 
    echo "   - Intelligent data buffering"
    echo "   - Faster startup time"
else
    echo "❌ CLPO V2 training failed with exit code: $TRAINING_EXIT_CODE"
    echo "💡 Please check the logs above for error details"
    echo "🔧 Common issues:"
    echo "   - Check if data files exist and are readable"
    echo "   - Verify GPU memory is sufficient"
    echo "   - Ensure all dependencies are installed"
    echo "   - Check CUDA_VISIBLE_DEVICES setting"
    echo "   - Verify main_clpo_v2.py exists"
fi

exit $TRAINING_EXIT_CODE
