# set -x
# bash sync.sh
# bash sync.sh
# export CUDA_LAUNCH_BLOCKING=1

# conda create -n torch240 python=3.10.13
# conda activate torch240
# conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.1 -c pytorch -c nvidia
# pip install openrlhf --no-deps
# pip install openrlhf
# pip install wandb accelerate bitsandbytes deepspeed==0.15.0 isort jsonlines loralib optimum peft tensorboard torchmetrics transformers-stream-generator
# pip install antlr4-python3-runtime==4.11.0

# Request INFO-level NCCL diagnostics. The variable has to be exported:
# a plain assignment stays private to this shell and is never inherited by
# the deepspeed / python child processes launched further down.
export NCCL_DEBUG=INFO

# Report how many positional arguments were supplied.
echo "传入了 $# 个参数"

# Validate the argument count before doing any work:
#   0 arguments   -> exit 1
#   1-3 arguments -> exit 2
#   4 or more     -> list them and continue
# The previous test `[ $# -lt 4]` was missing the space before `]`, so it
# always failed with a `[` usage error and the "too few arguments" branch
# could never be taken.
# NOTE: this script also reads $5 and $6 further down; the minimum enforced
# here is kept at 4 to preserve the existing contract.
if (( $# == 0 )); then
    echo "没有提供任何参数"
    exit 1
elif (( $# < 4 )); then
    echo "提供参数个数 $# 小于程序所需参数个数 4"
    exit 2
else
    # Print every argument, one per line.
    echo "以下是传入的参数："
    printf '%s\n' "$@"
fi
# Model family to train, taken from the fifth positional argument
# (for example: train_model_type=llama).
train_model_type=$5

# Choose the cluster profile from the working directory: a path containing
# "apdcephfs_sh8" selects profile 2, every other location selects profile 0.
current_dir=$(pwd)
case "$current_dir" in
    *apdcephfs_sh8*)
        use_h20=2
        echo "use a100 pro"
        ;;
    *)
        use_h20=0
        echo "use a100"
        ;;
esac

# Launcher defaults (single node).
DATASET_NAME=Math
MAX_SEQ_LENGTH=3072
NUM_NODES=1
master_addr=11.215.34.238
HOST_FILE=./hostfile_2gpu_1

# Per-cluster micro batch size, ZeRO stage and project root, keyed on the
# cluster profile stored in use_h20.
case "$use_h20" in
    0)
        micro_train_batch_size=1
        zero_stage=2
        WORK_DIR=/root/workspace/self-improvement
        ;;
    2)
        # Earlier setting: micro_train_batch_size=4, zero_stage=2
        micro_train_batch_size=1
        zero_stage=3
        WORK_DIR=/apdcephfs_sh8/share_300719895/mask/workspace/code/EfficientPlanningTuning
        ;;
    *)
        # Earlier setting: micro_train_batch_size=4, zero_stage=1
        micro_train_batch_size=2
        zero_stage=3
        WORK_DIR=/apdcephfs_gy2_302625456/share_302625456/user/mask/workspace/code/EfficientPlanningTuning
        ;;
esac

# Number of nodes and maximum sequence length for this run.
MAX_SEQ_LENGTH=3072
NUM_NODES=1

# Multi-node runs override the hostfile and master address; a single-node
# run keeps whatever was configured before this point.
case "$NUM_NODES" in
    2)
        echo "training on 2 nodes"
        HOST_FILE=./hostfile_2gpu
        master_addr=30.159.160.180

        # HOST_FILE=./hostfile_2gpu_2
        # master_addr=11.215.34.238
        ;;
    4)
        echo "training on 4 nodes"
        HOST_FILE=./hostfile_4gpu
        master_addr=30.159.160.35
        ;;
    *)
        echo "training on 1 nodes"
        ;;
esac

# Location of the Omni-Judge model used for the judging step, per cluster
# profile (0, 1, or anything else).
case "$use_h20" in
    0)
        # judge_llm_path=/apdcephfs_qy3/share_301812049/mask/workspace/hf_models/KbsdJames/Omni-Judge
        judge_llm_path=/root/workspace/hf_models/KbsdJames/Omni-Judge
        ;;
    1)
        judge_llm_path=/apdcephfs_gy2_302625456/share_302625456/user/mask/workspace/models/KbsdJames/Omni-Judge
        ;;
    *)
        judge_llm_path=/apdcephfs_sh8/share_300719895/mask/workspace/hf_models/KbsdJames/Omni-Judge
        ;;
esac
# Resolve, for the requested model family:
#   MODEL_PATH          - pretrained checkpoint to fine-tune
#   OUTPUT_DIR          - root directory for this family's training outputs
#   CHAT_TEMPLATE_NAME  - chat template passed to training and evaluation
# An unknown family prints an error and exits with status 3.
case "$train_model_type" in
    llama)
        model_repo=meta-llama/Llama-3.1-8B-Instruct
        CHAT_TEMPLATE_NAME=llama-3.1-chat
        ;;
    qwen)
        model_repo=Qwen/Qwen2.5-Math-7B-Instruct
        CHAT_TEMPLATE_NAME=qwen-math
        ;;
    deepseek_R1)
        model_repo=deepseek-ai/DeepSeek-R1-Distill-Qwen-7B
        CHAT_TEMPLATE_NAME=deepseek_R1
        ;;
    *)
        echo "模型类型错误"
        exit 3
        ;;
esac

# Root directory holding the pretrained checkpoints on each cluster profile.
# Alternatives used previously (kept for reference):
#   /apdcephfs_qy3/share_301812049/mask/workspace/hf_models
#   meta-llama/Llama-3.1-8B, deepseek-ai/DeepSeek-R1-Distill-Qwen-32B
case "$use_h20" in
    0)
        model_root=/root/workspace/hf_models
        ;;
    2)
        model_root=/apdcephfs_sh8/share_300719895/mask/workspace/hf_models
        ;;
    *)
        # No DeepSeek-R1 checkpoint is configured for this profile. Fail here
        # instead of continuing with MODEL_PATH and OUTPUT_DIR left unset.
        if [ "$train_model_type" = "deepseek_R1" ]; then
            echo "no deepseek_R1 checkpoint configured for use_h20='${use_h20}'" >&2
            exit 3
        fi
        model_root=/apdcephfs_gy2_302625456/share_302625456/user/mask/workspace/models
        ;;
esac

MODEL_PATH=${model_root}/${model_repo}
# Outputs are grouped under the model's short name (repo id without the org).
OUTPUT_DIR=ckpts/${model_repo##*/}/${DATASET_NAME}_sft/negative_planning_pruning
    

# DATASET_LENGTH=60W_self_training
# DATASET_LENGTH=60W
# DATASET_LENGTH=1W
# DATASET_LENGTH=1W_negative
# DATASET_LENGTH=Qwen25_1W
# DATASET_LENGTH=Qwen25_1W_negative

# Dataset key (fourth positional argument); it selects the training file.
DATASET_LENGTH=$4
save_only=0
# if [ "$save_only" == "1" ]; then
#     DATASET_LENGTH=1W
# fi

# data_ratio is -1 for every dataset key.
# NOTE(review): presumably -1 means "use the whole file" - confirm in train.train_sft.
data_ratio=-1

# Map the dataset key to its training file. An unknown key is reported on
# stderr and aborts with status 4 (previously it exited silently).
case "$DATASET_LENGTH" in
    1W)
        DATA_PATH="data/Math/Llama-3.1-8B-Instruct_self_training_positive.jsonl"
        ;;
    Qwen25_1W)
        DATA_PATH="data/Math/Qwen2.5-Math-7B-Instruct_one_shot_step_baseline.jsonl"
        ;;
    1W_negative)
        DATA_PATH="data/Math/Llama-3.1-8B-Instruct_self_training_positive_negative_11860.jsonl"
        ;;
    1W_noposterior)
        DATA_PATH="rollout_outputs/noposterior_sync_-1_rollout_dataset_llama1W_prefix_length_0_n_32_data_length_11484.json"
        ;;
    1W_noposterior_boxed)
        DATA_PATH="rollout_outputs/sync_-1_rollout_dataset_1W_prefix_length_0_n_1_data_length_11484_add_prompt.json"
        ;;
    Qwen25_1W_noposterior)
        DATA_PATH="rollout_outputs/sync_-1_rollout_dataset_Qwen25_1W_prefix_length_0_n_1_data_length_11405.json"
        ;;
    Qwen25_1W_negative)
        DATA_PATH="data/Math/Qwen2.5-Math-7B-Instruct_self_training_positive_negative_11994.jsonl"
        ;;
    60W)
        DATA_PATH="data/Math/open_math_instruct_2.train.jsonl.dedup"
        ;;
    60W_self_training)
        DATA_PATH="data/sampling/OpenMathInstruct2_self_training_v2.jsonl"
        ;;
    60W_self_training_t0_negative)
        DATA_PATH="data/sampling/OpenMathInstruct2_self_training_v2_t0.jsonl"
        ;;
    60W_self_training_no_posterior)
        DATA_PATH="data/sampling/OpenMathInstruct2_self_training_v3_noposterior.jsonl"
        ;;
    60W_self_training_no_posterior_whole)
        DATA_PATH="data/sampling/OpenMathInstruct2_self_training_v4_noposterior_whole.jsonl"
        ;;
    Qwen_60W_self_training_no_posterior_whole)
        DATA_PATH="data/sampling/60W_open_math_instruct_2.train.Qwen2.5-Math-7B-Instruct.t0.6.n1_v2.jsonl"
        ;;
    100W)
        DATA_PATH="/root/workspace/hf_datasets/nvidia/OpenMathInstruct-1/correct_solutions/train.jsonl"
        ;;
    1W_subset_positive)
        DATA_PATH="data/evaluated_data/noposterior_sync_-1_rollout_dataset_llama1W_prefix_length_0_n_32_data_length_11484_after_eval_generation_positive.jsonl"
        ;;
    1W_subset_negative)
        DATA_PATH="data/evaluated_data/noposterior_sync_-1_rollout_dataset_llama1W_prefix_length_0_n_32_data_length_11484_after_eval_generation_negative.jsonl"
        ;;
    1W_subset_positive_boxed)
        DATA_PATH="data/evaluated_data/sync_-1_rollout_dataset_1W_prefix_length_0_n_1_data_length_11484_add_prompt_after_eval_generation_positive.jsonl"
        ;;
    1W_subset_negative_boxed)
        DATA_PATH="data/evaluated_data/sync_-1_rollout_dataset_1W_prefix_length_0_n_1_data_length_11484_add_prompt_after_eval_generation_negative.jsonl"
        ;;
    zhiwei_hard_boxed_noposterior | zhiwei_hard_boxed_noposterior_check)
        # Both keys point at the same file.
        DATA_PATH="data/Math/train.annotated_subset_356748_genrerations.jsonl"
        ;;
    zhiwei_hard_boxed_noposterior_check_v2)
        DATA_PATH="rollout_outputs/hard_322101_v2.json"
        ;;
    llama_zhiwei_hard_boxed_noposterior_check_v2)
        DATA_PATH="rollout_outputs/hard_299019_llama_v2.json"
        ;;
    deepseek_zhiwei_hard_boxed_noposterior_check_v2)
        DATA_PATH="rollout_outputs/hard_356720_deepseek_v2.json"
        ;;
    deepseek_zhiwei_hard_random_boxed_noposterior_check_v3)
        # DATA_PATH="rollout_outputs/hard_random_349750_deepseek_v3.json"
        DATA_PATH="rollout_outputs/hard_random_337325_deepseek_max_length_3072_v3.json"
        ;;
    deepseek_32B_zhiwei_hard_random_boxed_noposterior)
        # DATA_PATH="rollout_outputs/hard_random_349750_deepseek_v3.json"
        DATA_PATH="data/Math/v2hq.train.deepseek-r1-distil-qwen-32b.one.jsonl"
        ;;
    deepseek_7B_zhiwei_hard_boxed_noposterior_check_v1)
        # DATA_PATH="rollout_outputs/hard_356690_deepseek_r1_qwen_max_length_4096_v1.json"
        DATA_PATH="rollout_outputs/hard_348159_deepseek_r1_qwen_max_length_3072_v1.json"
        ;;
    deepseek_7B_zhiwei_hard_boxed_noposterior_check_v1_8192)
        DATA_PATH="rollout_outputs/hard_356551_deepseek_r1_qwen_max_length8192_v1.jsonl"
        ;;
    deepseek_7B_zhiwei_hard_boxed_noposterior_check_v1_8192_01)
        DATA_PATH="rollout_outputs/hard_356551_deepseek_r1_qwen_max_length8192_v1_ratio_01.jsonl"
        ;;
    limo)
        DATA_PATH="data/Math/limo.jsonl"
        ;;
    limo_deepseek_r1_7B_noposterior)
        DATA_PATH="rollout_outputs/limo_817_deepseek_r1_qwen_max_length16384_v1.jsonl"
        ;;
    limo_qwen_7B_noposterior)
        DATA_PATH="rollout_outputs/limo_whole_solution_qwen_-1.json"
        ;;
    *)
        echo "unknown dataset key: '${DATASET_LENGTH}'" >&2
        exit 4
        ;;
esac
echo "$DATA_PATH"

# Checkpoint interval by dataset key: keys containing "1W" save every 90
# steps, keys containing "limo" use -1, everything else saves every 500.
# The "1W" pattern is tested first, matching the original precedence.
case "$DATASET_LENGTH" in
    *1W*)
        SAVE_STEPS=90
        ;;
    *limo*)
        SAVE_STEPS=-1
        ;;
    *)
        SAVE_STEPS=500
        ;;
esac

# Fixed switches.
DEBUG=0
without_ass_token=0

# Prompt text handed to training (--add_prompt / --add_planning_prompt).
# ADD_SFT_PROMPT=" Please wrap the final answer in $\\boxed{}$ tag. Let's think step by step."
ADD_SFT_PROMPT="1"
ADD_PLANNING_PROMPT=" Please provide the initial step towards resolving the question. This step may serve as a foundation but might not encompass the entire solution."

# Planning-pruning configuration; the switch and the mode come from the
# first and third positional arguments.
planning_pruning=$1
planning_pruning_mode=$3
planning_pruning_ratio=0.90
planning_pruning_token=1

# Epoch count: 1 by default (previously 5), 2 for dataset keys containing
# "1W" or "limo".
max_epochs=1
if [[ "$DATASET_LENGTH" == *1W* ]]; then
    max_epochs=2
fi

# "limo" keys additionally raise the maximum sequence length.
if [[ "$DATASET_LENGTH" == *limo* ]]; then
    max_epochs=2
    MAX_SEQ_LENGTH=4096
fi

# deepseek_R1 always gets the longest sequence length, whatever the dataset.
if [[ "$train_model_type" == "deepseek_R1" ]]; then
    MAX_SEQ_LENGTH=16384
fi

# A save-only run needs a single epoch.
if [[ "$save_only" == "1" ]]; then
    max_epochs=1
fi

# 1 = an existing run directory may be resumed.
load_checkpoint=1

# Whether to include the question tokens in the loss (sixth positional argument).
compute_q_loss=$6
# compute_q_loss=1
# negative_mode="negative_whole"
# negative_mode="negative_planning"
# negative_mode_list="negative_whole"

# Derive task, data type and the list of negative modes from the dataset key.
# A key containing "posterior" takes precedence over one containing
# "negative"; any other key is plain SFT on positive data.
case "$DATASET_LENGTH" in
    *posterior*)
        TASK=noposterior
        data_type="negative"
        negative_mode_list=("noposterior")
        ;;
    *negative*)
        TASK=negative
        data_type="negative"
        negative_mode_list=("negative_whole")
        ;;
    *)
        TASK=sft
        data_type="positive"
        negative_mode_list=("negative_none")
        ;;
esac
#for lr in 1e-5 5e-6 1e-6 2e-5; do \
# Sweep over negative_mode x prefix_length x mpm_p. Each combination derives
# a run name (SUFFIX), trains or resumes the model unless a finished
# checkpoint is already present, then runs generation and judging on every
# split listed in EVAL_DATASETS.
for negative_mode in "${negative_mode_list[@]}"; do
    # $2 is left unquoted on purpose: a space-separated list of prefix
    # lengths is split into one iteration per value.
    for prefix_length in $2; do
        for mpm_p in 0.5; do
            planning_prefix_tuning_length=${prefix_length}
            planning_suffix_tuning_length=0
            lr=1e-6
            # lr=2e-5
            lr_s=constant_with_warmup
            mpm_enable=0
            mpm_mode="91"
            mpm_ratio=1
            mpm_prefix_length=0

            # ---- Build the run identifier ----
            mpm_suffix=mpm_enable${mpm_enable}_p_${mpm_p}_mode_${mpm_mode}_mpm_ratio_${mpm_ratio}_prefix_l_${mpm_prefix_length}
            planning_suffix=planning_pruning_${planning_pruning}_ratio_${planning_pruning_ratio}_planning_pruning_token_${planning_pruning_token}
            SUFFIX=${data_type}_epoch_${max_epochs}_lr${lr}_ratio_${data_ratio}_${DATASET_LENGTH}_${CHAT_TEMPLATE_NAME}_pl_${planning_prefix_tuning_length}
            if [ "$mpm_enable" = "1" ]; then
                SUFFIX=${SUFFIX}_${mpm_suffix}
                # Enabling mpm switches planning pruning off.
                planning_pruning=0
            fi
            if [ "$planning_pruning" = "1" ]; then
                SUFFIX=${SUFFIX}_${planning_suffix}_mixture_ppm_${planning_pruning_mode}
            fi
            if [ "$ADD_SFT_PROMPT" = "" ]; then
                if [ "$DATASET_LENGTH" != "Qwen25_1W" ]; then
                    SUFFIX=${SUFFIX}_no_sft_prompt
                    echo "不添加SFT提示"
                fi
            else
                echo "添加SFT提示"
            fi
            if [ "$data_type" = "negative" ]; then
                SUFFIX=${SUFFIX}_${negative_mode}
            fi
            if [ "$DEBUG" = "1" ]; then
                SUFFIX=debug_${SUFFIX}
            fi
            if [ "$compute_q_loss" = "1" ]; then
                SUFFIX=${SUFFIX}_compute_q_loss
                echo "计算 question 损失"
            fi

            case "$train_model_type" in
                llama)
                    RUN_NAME=Llama-3.1-8B-Instruct_${SUFFIX}
                    # RUN_NAME=Llama-3.1-8B_${SUFFIX}
                    ;;
                qwen)
                    RUN_NAME=Qwen2.5-Math-7B-Instruct_${SUFFIX}
                    ;;
                deepseek_R1)
                    RUN_NAME=DeepSeek-R1-Distill-Qwen-7B_${SUFFIX}
                    ;;
            esac

            run_dir=${OUTPUT_DIR}/${SUFFIX}
            # Presence of the first weight shard is used as the
            # "training finished" marker.
            done_marker=${run_dir}/model-00001-of-00004.safetensors

            # Debug runs start from a clean directory. The removal happens
            # only after SUFFIX is final so that it targets the directory
            # this run really uses; ':?' aborts instead of expanding to a
            # bare "/" should either variable be empty.
            if [ "$DEBUG" = "1" ]; then
                rm -rf -- "${OUTPUT_DIR:?}/${SUFFIX:?}"
            fi

            if [ "$use_h20" = "1" ]; then
                # rm -rf ${OUTPUT_DIR}/${SUFFIX}
                echo "使用H20训练"
            fi

            # Keep a copy of the previous log, if there is one, before a new
            # run overwrites it.
            if [ -f "${run_dir}/training.log" ]; then
                cp -- "${run_dir}/training.log" "${run_dir}/training.log.bak"
            fi

            # Decide per iteration whether an existing directory may be
            # resumed. A local copy is used so that one finished run does not
            # leave load_checkpoint at 0 and silently make later iterations
            # skip training for unfinished directories.
            resume=${load_checkpoint}
            if [ -e "${done_marker}" ]; then
                echo "训练已完成"
                echo "达到训练结束条件，退出循环"
                resume=0
            fi

            if [[ -d "${run_dir}" && "${resume}" = "0" ]]; then
                echo "文件夹已经存在: ${run_dir}"
            else
                echo "文件夹: ${run_dir}"
                mkdir -p -- "${run_dir}"

                # Arguments for train.train_sft shared by the single-node and
                # multi-node launches.
                train_args=(
                    --max_len "${MAX_SEQ_LENGTH}"
                    --dataset "${DATASET_NAME}"
                    --input_key question
                    --output_key response
                    --debug "${DEBUG}"
                    --train_batch_size 64
                    --micro_train_batch_size "${micro_train_batch_size}"
                    --lr_scheduler "${lr_s}"
                    --max_samples 500000
                    --pretrain "${MODEL_PATH}"
                    --save_path "${run_dir}"
                    --save_steps "${SAVE_STEPS}"
                    --logging_steps 1
                    --eval_steps -1
                    --zero_stage "${zero_stage}"
                    --bf16
                    --flash_attn
                    --max_epochs "${max_epochs}"
                    --learning_rate "${lr}"
                    --load_checkpoint
                    --gradient_checkpointing
                    --chat_template_name "${CHAT_TEMPLATE_NAME}"
                    --apply_chat_template
                    --tasks "${TASK}"
                    --overlap_comm
                    --data_ratio "${data_ratio}"
                    --data_path "${DATA_PATH}"
                    --add_prompt "${ADD_SFT_PROMPT}"
                    --add_planning_prompt "${ADD_PLANNING_PROMPT}"
                    --planning_pruning_ratio "${planning_pruning_ratio}"
                    --compute_q_loss "${compute_q_loss}"
                    --negative_mode "${negative_mode}"
                    --dataset_length "${DATASET_LENGTH}"
                    --save_only "${save_only}"
                    --without_ass_token "${without_ass_token}"
                    --mpm_enable "${mpm_enable}"
                    --mpm_p "${mpm_p}"
                    --mpm_mode "${mpm_mode}"
                    --mpm_ratio "${mpm_ratio}"
                    --mpm_prefix_length "${mpm_prefix_length}"
                    --planning_pruning "${planning_pruning}"
                    --planning_pruning_token "${planning_pruning_token}"
                    --planning_pruning_mode "${planning_pruning_mode}"
                    --planning_prefix_tuning_length "${planning_prefix_tuning_length}"
                    --planning_suffix_tuning_length "${planning_suffix_tuning_length}"
                )
                # --packing_samples

                if [ "$NUM_NODES" = "1" ]; then
                    echo "单节点训练"
                    # Relaunch up to 11 times; every launch passes
                    # --load_checkpoint, and the loop stops as soon as the
                    # finished-model marker appears.
                    for attempt in {0..10}; do
                        deepspeed --module train.train_sft "${train_args[@]}" \
                            2>&1 | tee "${run_dir}/training.log"
                        if [ -e "${done_marker}" ]; then
                            echo "达到训练结束条件，退出循环"
                            break
                        fi
                    done
                else
                    # NCCL / InfiniBand settings for multi-node training.
                    # export NCCL_IB_TIMEOUT=22
                    # export NCCL_DEBUG=INFO
                    # export NCCL_DEBUG_SUBSYS=ALL
                    # export TORCH_DISTRIBUTED_DEBUG=INFO
                    export NCCL_NET_GDR_READ=1
                    export NCCL_IB_TIMEOUT=24
                    export NCCL_IB_GID_INDEX=3
                    export NCCL_IB_SL=3
                    export NCCL_CHECKS_DISABLE=1
                    export NCCL_P2P_DISABLE=0
                    export NCCL_IB_DISABLE=0
                    export NCCL_LL_THRESHOLD=16384
                    export NCCL_IB_CUDA_SUPPORT=1
                    export NCCL_SOCKET_IFNAME=bond1
                    export UCX_NET_DEVICES=bond1
                    export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6
                    export NCCL_COLLNET_ENABLE=0
                    export SHARP_COLL_ENABLE_SAT=0
                    export NCCL_NET_GDR_LEVEL=2
                    export NCCL_IB_QPS_PER_CONNECTION=4
                    export NCCL_IB_TC=160
                    export NCCL_PXN_DISABLE=1
                    # Only the multi-node launch adds --adam_offload.
                    deepspeed --hostfile "${HOST_FILE}" \
                        --num_nodes="${NUM_NODES}" \
                        --num_gpus=8 \
                        --master_addr="${master_addr}" \
                        --master_port=29500 \
                        --module train.train_sft "${train_args[@]}" \
                        --adam_offload \
                        2>&1 | tee "${run_dir}/training.log"
                fi
            fi

            # ---- Evaluation ----
            # ADD_PROMPT=" Please wrap the final answer in $\\boxed{}$ tag. Let's think step by step."
            ADD_PROMPT=${ADD_SFT_PROMPT}
            GEN_DIR=${WORK_DIR}/gen

            MODEL_NAME=${RUN_NAME}
            EXP_DIR=${GEN_DIR}/${MODEL_NAME}

            # EVAL_DATASETS=(aime gsm8k math500 gpqa arithmetic college_math gaokao2023en minerva_math olympiadbench primary)
            EVAL_DATASETS=(aime25 aime24 gpqa aime math500 gsm8k)

            for split in "${EVAL_DATASETS[@]}"; do
                MATH_EXP_DIR=${EXP_DIR}/${split}
                EVAL_MODEL=${run_dir}

                # Generation step: skipped when config.json already exists.
                if [ -e "${MATH_EXP_DIR}/config.json" ]; then
                    echo "file exists: ${MATH_EXP_DIR}/config.json"
                else
                    echo "file does not exist: $MATH_EXP_DIR/config.json, begin evaluation..."
                    # CUDA_VISIBLE_DEVICES=3,4,6,7
                    python3 "${WORK_DIR}/inference/eval.py" --type eval \
                        --base_model "${EVAL_MODEL}" \
                        --chat_template_name "${CHAT_TEMPLATE_NAME}" \
                        --output_dir "${MATH_EXP_DIR}" \
                        --bf16 False \
                        --split "${split}" \
                        --llm_judge True \
                        --add_prompt "${ADD_PROMPT}"
                fi

                cat "${MATH_EXP_DIR}/config.json"

                # Judging step: skipped when result.log already exists.
                if [ -e "${MATH_EXP_DIR}/result.log" ]; then
                    echo "file exists: ${MATH_EXP_DIR}/result.log"
                else
                    # CUDA_VISIBLE_DEVICES=3,4,6,7
                    python3 "${WORK_DIR}/inference/eval.py" --type judge \
                        --config_file "${MATH_EXP_DIR}/config.json" \
                        --judge_llm_path "${judge_llm_path}"
                fi
                echo "----------------------- ${split} done -----------------------"
                cat "${MATH_EXP_DIR}/result.log"
            done
        done
    done
done
