#!/usr/bin/env bash
# Launch script for an OSFT training run (python3 -m recipe.osft.main_osft).
#
# Shell options:
#   -e           abort as soon as a command fails
#   -o pipefail  make a pipeline fail when ANY stage fails; without it the
#                final `python3 ... | tee` pipeline reports tee's status, so a
#                crashed training run would still exit 0
#   -x           trace each command to stderr before running it
set -e
set -o pipefail
set -x

# Comma-separated GPU ids passed to the training process via CUDA_VISIBLE_DEVICES.
VISIBLE_DEVICES="0,1,2,3,4,5,6,7"

# Ask Hydra to print complete stack traces instead of its shortened error summary.
export HYDRA_FULL_ERROR=1

# Directory the script is launched from; logs/ and outputs/ are created under it.
ROOT_DIR="$(pwd)"

# Training set (parquet).
TRAIN_FILE="${HOME}/data/deepscaler/train.parquet"

# Validation benchmarks: one parquet file per benchmark, all in one directory.
VAL_PREFIX="${HOME}/data/benchmarks"
MATH500_PATH="${VAL_PREFIX}/math500.parquet"
AIME_PATH="${VAL_PREFIX}/aime.parquet"
AIME25_PATH="${VAL_PREFIX}/aime25.parquet"
AMC_PATH="${VAL_PREFIX}/amc.parquet"
OLYMPIAD_PATH="${VAL_PREFIX}/olympiadbench.parquet"
MINERVA_PATH="${VAL_PREFIX}/minerva.parquet"

# Render the benchmark paths as a list literal: ['p1', 'p2', ...].
# printf repeats its format once per argument, which leaves a trailing ", "
# that is trimmed before the brackets are added.
VAL_FILE_LIST="$(printf "'%s', " \
  "${MATH500_PATH}" \
  "${AMC_PATH}" \
  "${AIME_PATH}" \
  "${AIME25_PATH}" \
  "${OLYMPIAD_PATH}" \
  "${MINERVA_PATH}")"
VAL_FILE_LIST="[${VAL_FILE_LIST%, }]"


# --- Model -------------------------------------------------------------------
BACKBONE='Qwen2.5-Math-7B'       # display name, used when composing MODEL
MODEL_ID='qwen2.5_math_7b'       # short id embedded in the experiment name
BACKBONE_PATH='your_model_path'  # placeholder; passed as actor_rollout_ref.model.path

# --- Optimisation and sampling -----------------------------------------------
LR='1e-7'                        # actor_rollout_ref.actor.optim.lr
TAU_S='0.6'                      # actor_rollout_ref.rollout.temperature
ROLLOUT_N='1'                    # actor_rollout_ref.rollout.n
ENABLE_TRAIN_TEMP='False'        # trainer.enable_train_temperature

# --- Sequence length limits ---------------------------------------------------
MAX_PROMPT_LENGTH='1024'         # data.max_prompt_length
MAX_GEN_LENGTH='3072'            # data.max_response_length

# --- Run naming ---------------------------------------------------------------
PROJECT_NAME='formal'            # trainer.project_name
TASK='OSFT'
DATASET_NAME='dsr'
EXPERIMENT="CE-${DATASET_NAME}"
DATE="$(date '+%m%d_%H%M')"      # MMDD_HHMM stamp that distinguishes runs

# Compose the per-run names and locations, then create the directories.
#   MODEL       <task>-<backbone>
#   OUTPUT_DIR  outputs/<model>/<task>/<experiment>/<date> under ROOT_DIR
#   EXP         full experiment name (task, model id, key hyperparameters, date)
#   LOG_FILE    logs/<EXP>.log under ROOT_DIR
MODEL="${TASK}-${BACKBONE}"
OUTPUT_DIR="${ROOT_DIR}/outputs/${MODEL}/${TASK}/${EXPERIMENT}/${DATE}"

EXP="${TASK}-${MODEL_ID}-${EXPERIMENT}-lr${LR}-TAUS${TAU_S}-rollout${ROLLOUT_N}-ttr${ENABLE_TRAIN_TEMP}-${DATE}"
LOG_FILE="${ROOT_DIR}/logs/${EXP}.log"

# Paths are quoted so a working directory containing spaces or glob characters
# is passed to mkdir as a single argument instead of being split or expanded.
mkdir -p -- "${ROOT_DIR}/logs"
mkdir -p -- "${ROOT_DIR}/outputs"
mkdir -p -- "${OUTPUT_DIR}"


# SwanLab settings, exported so the training process inherits them.

# The API key is exported with command tracing suspended: under `set -x` the
# assignment would otherwise be echoed to stderr and the credential would end
# up in terminal scrollback and job logs. Tracing is restored afterwards only
# if it was on. A key already present in the environment takes precedence, so
# the real secret does not have to be written into this file.
case "$-" in
  *x*) swanlab_xtrace_was_on=1; { set +x; } 2>/dev/null ;;
  *)   swanlab_xtrace_was_on=0 ;;
esac
export SWANLAB_API_KEY="${SWANLAB_API_KEY:-your_swanlab_api_key}"
if [ "${swanlab_xtrace_was_on}" -eq 1 ]; then set -x; fi
unset swanlab_xtrace_was_on

export SWANLAB_LOG_DIR="${ROOT_DIR}/logs/swanlab/${EXP}"
export SWANLAB_MODE=cloud


# Launch training: run recipe.osft.main_osft with the Hydra overrides below and
# copy its stdout to LOG_FILE via tee (stderr is not redirected, so it is not
# captured in the log file).
#
# pipefail makes the pipeline return python's failure status instead of tee's,
# so a crashed run is not reported as success. Enabling it here is harmless if
# it is already on.
set -o pipefail

# Each override is one quoted array element, so values containing spaces,
# brackets or quotes reach Hydra intact and are never word-split or globbed.
hydra_overrides=(
  # Data
  "data.train_files=${TRAIN_FILE}"
  "data.val_files=${VAL_FILE_LIST}"
  "data.train_batch_size=128"
  "data.filter_overlong_prompts=True"
  "data.max_prompt_length=${MAX_PROMPT_LENGTH}"
  "data.max_response_length=${MAX_GEN_LENGTH}"

  # Model
  "actor_rollout_ref.model.path=${BACKBONE_PATH}"
  "actor_rollout_ref.model.use_liger=False"
  "actor_rollout_ref.model.use_shm=True"
  "actor_rollout_ref.model.use_remove_padding=True"
  "actor_rollout_ref.model.enable_gradient_checkpointing=True"

  # Actor / optimiser
  "actor_rollout_ref.actor.optim.lr_warmup_steps=10"
  "actor_rollout_ref.actor.optim.lr=${LR}"
  "actor_rollout_ref.actor.ppo_mini_batch_size=64"
  "actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4"
  "actor_rollout_ref.actor.use_kl_loss=False"
  "actor_rollout_ref.actor.fsdp_config.param_offload=False"
  "actor_rollout_ref.actor.fsdp_config.optimizer_offload=False"

  # Rollout (training-time sampling and validation sampling)
  "actor_rollout_ref.rollout.tensor_model_parallel_size=1"
  "actor_rollout_ref.rollout.name=vllm"
  "actor_rollout_ref.rollout.gpu_memory_utilization=0.85"
  "actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4"
  "actor_rollout_ref.rollout.n=${ROLLOUT_N}"
  "actor_rollout_ref.rollout.temperature=${TAU_S}"
  "actor_rollout_ref.rollout.val_kwargs.temperature=1"
  "actor_rollout_ref.rollout.val_kwargs.n=8"
  "actor_rollout_ref.rollout.val_kwargs.do_sample=True"

  # Reference model
  "actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4"

  # Trainer
  "trainer.enable_train_temperature=${ENABLE_TRAIN_TEMP}"
  "trainer.rejection_sampling=False"
  "trainer.logger=['console','swanlab']"
  "trainer.project_name=${PROJECT_NAME}"
  "trainer.experiment_name=${EXP}"
  "trainer.val_before_train=True"
  "trainer.default_local_dir=${OUTPUT_DIR}"
  # NOTE(review): presumably must equal the number of ids in VISIBLE_DEVICES;
  # keep the two in sync when changing the GPU set.
  "trainer.n_gpus_per_node=8"
  "trainer.default_hdfs_dir=null"
  "trainer.nnodes=1"
  "trainer.save_freq=100"
  "trainer.rollout_data_dir=${OUTPUT_DIR}/rollout_data"
  "trainer.validation_data_dir=${OUTPUT_DIR}/rollout_eval_data"
  "trainer.test_freq=10"
  # Leading '+' is Hydra syntax for adding a key to the config.
  "+trainer.log_freq=1"
  "trainer.total_epochs=1"
)

CUDA_VISIBLE_DEVICES="${VISIBLE_DEVICES}" \
python3 -m recipe.osft.main_osft "${hydra_overrides[@]}" | tee "${LOG_FILE}"
