set -x

# Hydra environment switch; see the Hydra docs for HYDRA_FULL_ERROR.
export HYDRA_FULL_ERROR=1
# export CUDA_VISIBLE_DEVICES=0,1,2,3

PROJECT_NAME=fortune

## task config
# TASK=formula
TASK=text


# model config (supported keys: qwen, llama, qwen14b)
# MODEL=qwen
MODEL=llama
# MODEL=qwen14b


# training config
NODE_NUM=4
GPU_NUM_PER_NODE=2
MINI_BATCH_SIZE=64
MICRO_BATCH_SIZE_PER_GPU=1


# NOTE: EXP_NAME is computed before model resolution, so a qwen14b run keeps
# "qwen14b" in its experiment name even though MODEL is rewritten to "qwen"
# by select_init_model below.
EXP_NAME=${TASK}_${MODEL}_all
OUTPUT_DIR=Fortune-private/training_outputs/rl/${PROJECT_NAME}/${EXP_NAME}


set +x
#######################################
# Resolve a model key to its checkpoint and rollout memory settings.
# Globals:
#   INIT_MODEL              (written) model identifier passed to the trainer
#   GPU_MEMORY_UTILIZATION  (written) value for rollout.gpu_memory_utilization
#   MODEL                   (written for qwen14b only: reset to "qwen")
#   NODE_NUM                (read/written for qwen14b only: doubled)
# Arguments:
#   $1 - model key: qwen | llama | qwen14b
# Returns:
#   0 on success; 1 (with a message on stderr) for an unknown key, instead of
#   silently leaving INIT_MODEL / GPU_MEMORY_UTILIZATION unset.
#######################################
select_init_model() {
    case "$1" in
        qwen)
            INIT_MODEL=Qwen/Qwen2.5-Coder-7B-Instruct
            GPU_MEMORY_UTILIZATION=0.3
            ;;
        llama)
            INIT_MODEL=meta-llama/Llama-3.1-8B-Instruct
            GPU_MEMORY_UTILIZATION=0.3
            ;;
        qwen14b)
            INIT_MODEL=Qwen/Qwen2.5-Coder-14B-Instruct
            GPU_MEMORY_UTILIZATION=0.4
            # MODEL is used later in this script to build the dataset paths,
            # so the 14B run reads from the "qwen" data directory.
            MODEL=qwen
            # The 14B configuration runs on twice as many nodes.
            NODE_NUM=$((NODE_NUM * 2))
            ;;
        *)
            printf 'error: unsupported MODEL "%s" (expected qwen, llama or qwen14b)\n' "$1" >&2
            return 1
            ;;
    esac
}

select_init_model "$MODEL" || exit 1

# Dataset names whose train split is used for training, and the names whose
# test split is used for validation (space-separated).
TRAIN_DATASET_LIST="wikitq tabfact finqa hitab multihiertt"
TEST_DATASET_LIST="wikitq tabfact finqa hitab multihiertt aitqa tablebench"

# Every parquet file sits under <root>/<dataset>/<split>.parquet, where the
# root is keyed by TASK and MODEL.
data_root="data/processed_data/${TASK}/${MODEL}"

# Build comma-separated lists of single-quoted paths. The list variables are
# expanded unquoted on purpose so each dataset name becomes one loop item;
# ${var:+...} emits the ", " separator only once the list is non-empty.
train_files=""
for name in $TRAIN_DATASET_LIST; do
    train_files="${train_files:+$train_files, }'${data_root}/${name}/train.parquet'"
done

test_files=""
for name in $TEST_DATASET_LIST; do
    test_files="${test_files:+$test_files, }'${data_root}/${name}/test.parquet'"
done
# Re-enable command tracing now that the loops are done.
set -x

# Wrap both lists in brackets: "['a', 'b', ...]".
train_files="[$train_files]"
test_files="[$test_files]"



# Launch verl PPO training via `python3 -m verl.trainer.main_ppo`, passing the
# configuration as key=value command-line overrides.
#
# Inputs (shell variables set earlier in this script):
#   train_files / test_files         bracketed lists of parquet paths
#   MINI_BATCH_SIZE                  used for both data.train_batch_size and
#                                    actor ppo_mini_batch_size
#   MICRO_BATCH_SIZE_PER_GPU         shared by actor, rollout, ref and critic
#   INIT_MODEL                       used as both actor and critic model path
#   GPU_NUM_PER_NODE, NODE_NUM       cluster shape; GPU_NUM_PER_NODE is also the
#                                    rollout tensor-parallel size
#   GPU_MEMORY_UTILIZATION           rollout gpu_memory_utilization
#   PROJECT_NAME, EXP_NAME, OUTPUT_DIR
# Any arguments given to this script are appended last as extra overrides.
#
# Notes:
# - actor_rollout_ref.model.enable_gradient_checkpointing used to be passed
#   twice (False, then True). Only the later value, True, is kept.
#   NOTE(review): assumes the last override won in Hydra, so this is
#   unchanged — confirm.
# - The trainer.logger override is single-quoted as a whole so the shell never
#   treats the [...] as a glob pattern; the program receives the same text the
#   original produced after quote removal.
# - "$@" is quoted so pass-through arguments containing spaces or glob
#   characters reach the trainer unmodified.
python3 -m verl.trainer.main_ppo \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.train_batch_size="$MINI_BATCH_SIZE" \
    data.max_prompt_length=8192 \
    data.max_response_length=512 \
    data.shuffle=True \
    data.filter_overlong_prompts=False \
    data.truncation='error' \
    actor_rollout_ref.model.path="$INIT_MODEL" \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size="$MINI_BATCH_SIZE" \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu="$MICRO_BATCH_SIZE_PER_GPU" \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu="$MICRO_BATCH_SIZE_PER_GPU" \
    actor_rollout_ref.rollout.tensor_model_parallel_size="$GPU_NUM_PER_NODE" \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.temperature=0.6 \
    actor_rollout_ref.rollout.max_num_batched_tokens=10240 \
    actor_rollout_ref.rollout.gpu_memory_utilization="$GPU_MEMORY_UTILIZATION" \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu="$MICRO_BATCH_SIZE_PER_GPU" \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    critic.optim.lr=1e-5 \
    critic.model.use_remove_padding=True \
    critic.model.path="$INIT_MODEL" \
    critic.model.enable_gradient_checkpointing=False \
    critic.ppo_micro_batch_size_per_gpu="$MICRO_BATCH_SIZE_PER_GPU" \
    critic.model.fsdp_config.param_offload=False \
    critic.model.fsdp_config.optimizer_offload=False \
    algorithm.kl_ctrl.kl_coef=0.0001 \
    trainer.val_before_train=True \
    trainer.critic_warmup=0 \
    'trainer.logger=[console,wandb]' \
    trainer.project_name="$PROJECT_NAME" \
    trainer.experiment_name="$EXP_NAME" \
    trainer.n_gpus_per_node="$GPU_NUM_PER_NODE" \
    trainer.nnodes="$NODE_NUM" \
    trainer.save_freq=-1 \
    trainer.test_freq=20 \
    trainer.max_actor_ckpt_to_keep=20 \
    trainer.max_critic_ckpt_to_keep=1 \
    trainer.default_local_dir="$OUTPUT_DIR" \
    trainer.resume_mode=disable \
    trainer.total_epochs=10 "$@"
