#!/bin/bash
# Launcher for `python3 -m verl.trainer.main_ppo`.
#
# Usage:
#   <this-script> [--model-path VALUE] [--res-length VALUE] [--train-type VALUE]
#                 [--val-before-train VALUE] [--use-debug VALUE] [extra args...]
#
# The options listed above are consumed by this script. Parsing stops at the
# first argument that is not one of them; that argument and everything after
# it are forwarded unchanged to the python command.
#
# -e: stop as soon as a command fails.
# -x: print every command (after expansion) before it is executed.
set -ex


# Parse command line arguments; this must be processed before conda activate.
#
# Recognised options (each one takes exactly one value):
#   --model-path VALUE        -> MODEL_PATH
#   --res-length VALUE        -> RES_LENGTH
#   --train-type VALUE        -> TRAIN_TYPE
#   --val-before-train VALUE  -> VAL_BEFORE_TRAIN
#   --use-debug VALUE         -> USE_DEBUG
# Parsing stops at the first argument that is not one of these options; that
# argument and everything after it remain in "$@".

# Abort with a readable message when an option is missing its value.
#   $1 - option name as typed by the user
#   $2 - number of positional parameters still pending (option included)
# Without this check a trailing option makes `shift 2` fail, which kills the
# script with no explanation under `set -e` (and loops forever without it).
_require_option_value() {
    if [[ $2 -lt 2 ]]; then
        echo "Error: option $1 requires a value" >&2
        exit 1
    fi
}

while [[ $# -gt 0 ]]; do
    case "$1" in
        --model-path)
            _require_option_value "$1" "$#"
            MODEL_PATH="$2"
            shift 2
            ;;
        --res-length)
            _require_option_value "$1" "$#"
            RES_LENGTH="$2"
            shift 2
            ;;
        --train-type)
            _require_option_value "$1" "$#"
            TRAIN_TYPE="$2"
            shift 2
            ;;
        --val-before-train)
            _require_option_value "$1" "$#"
            VAL_BEFORE_TRAIN="$2"
            shift 2
            ;;
        --use-debug)
            _require_option_value "$1" "$#"
            USE_DEBUG="$2"
            shift 2
            ;;
        *)
            # Keep all other arguments for passing to the python command
            break
            ;;
    esac
done

# Environment exported to the training process.
# WORLD_SIZE and RANK keep a value that is already present in the environment
# and otherwise fall back to 1 and 0.
export WORLD_SIZE="${WORLD_SIZE:-1}"
export RANK="${RANK:-0}"
export VLLM_ATTENTION_BACKEND=XFORMERS
export HYDRA_FULL_ERROR=1
export LIVECODEBENCH_DATA_PATH=/path/to/folder/data/livecodebench_2408_2502

# SECURITY: when xtrace (`set -x`) is active, an `export NAME=value` line is
# echoed verbatim to stderr, which would copy the API key into every job log.
# Tracing is therefore switched off just for this assignment and restored to
# its previous state afterwards.
# A WANDB_API_KEY already present in the environment takes precedence; the
# literal below is kept only as a backward-compatible fallback.
# TODO(security): rotate this key and remove it from the repository.
case "$-" in
    *x*) _xtrace_was_on=1; set +x ;;
    *)   _xtrace_was_on=0 ;;
esac
export WANDB_API_KEY="${WANDB_API_KEY:-0055fca53232a22a37d4fc9cf90b94ef85608e72}"
if [[ $_xtrace_was_on -eq 1 ]]; then
    set -x
fi
unset _xtrace_was_on

# Batch-size and sampling settings. These are assigned unconditionally, so
# they cannot be changed through command-line options or the environment.
BATCH_SIZE=1
BATCH_SIZE_ALL=1
PPO_MINI_BATH=$((BATCH_SIZE_ALL / 1))
GROUP_SIZE=4

# Default values for every setting that was not supplied (or supplied empty)
# through a command-line option or a pre-set variable.
RES_LENGTH=${RES_LENGTH:-$((1024 * 6))}
MODEL_PATH=${MODEL_PATH:-"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"}
TRAIN_TYPE=${TRAIN_TYPE:-"adversarial_test_case"}
VAL_BEFORE_TRAIN=${VAL_BEFORE_TRAIN:-"False"}
USE_DEBUG=${USE_DEBUG:-"True"}

#######################################
# Select the training data and test-case settings for a train type.
# Globals:
#   train_files        (written) list literal of training data files
#   DYNAMIC_TEST_CASE  (written) fixed | dynamic | adversarial
#   USE_TEST_CASE      (written) True | False
# Arguments:
#   $1 - train type: data0409, code_generation, test_generation, code_mix,
#        dynamic_test_case or adversarial_test_case
# Outputs:
#   On an unknown train type, an error listing all supported types on stderr.
# Returns:
#   0 on success, 1 when the train type is unknown.
#######################################
select_train_config() {
    # Baseline; individual train types override these below.
    DYNAMIC_TEST_CASE=fixed
    USE_TEST_CASE=True

    case "$1" in
        data0409)
            train_files=[\"/path/to/folder/verl_code/verl/verl_data/code/train/0409_without_other/train_code_data0409_del32all0_del1p5all8_public_test.pkl\"]
            USE_TEST_CASE=False
            ;;
        code_generation)
            train_files=[\"/path/to/folder/data/taco_for_code_generation.pkl\"]
            USE_TEST_CASE=False
            ;;
        test_generation)
            train_files=[\"/path/to/folder/data/taco_for_test_generation_debug.pkl\"]
            USE_TEST_CASE=False
            ;;
        code_mix)
            train_files=[\"/path/to/folder/data/taco_for_code_generation_four.pkl\",\"/path/to/folder/data/taco_for_test_generation.pkl\"]
            ;;
        dynamic_test_case)
            train_files=[\"/path/to/folder/data/taco_for_code_generation.pkl\"]
            DYNAMIC_TEST_CASE=dynamic
            ;;
        adversarial_test_case)
            train_files=[\"/path/to/folder/data/taco_for_code_generation.pkl\"]
            DYNAMIC_TEST_CASE=adversarial
            ;;
        *)
            # The list must name every branch above, including the default.
            echo "Error: Unknown train_type: $1" >&2
            echo "Supported types: data0409, code_generation, test_generation, code_mix, dynamic_test_case, adversarial_test_case" >&2
            return 1
            ;;
    esac
}

# Set train_files based on TRAIN_TYPE
select_train_config "$TRAIN_TYPE" || exit 1

# Validation data, written as a bracketed list literal with embedded double
# quotes; it is handed to the trainer as data.val_files.
test_files=[\"/path/to/folder/data/livecodebench_2408_2502_tagged.parquet\"]

PROJECT_NAME="debug"
# Experiment name layout:
#   <train type>_L<response length in KiB>k_<last path component of the model>
#   _AllBs<..>_GenBs<..>_MiniBs<..>_GS<..>_<WORLD_SIZE>Nodes_<MMDDhhmm>
# MODEL_PATH is quoted (and preceded by `--`) so that a path containing
# whitespace or starting with a dash is still handled as a single operand.
EXP_NAME="${TRAIN_TYPE}_L$((RES_LENGTH / 1024))k_$(basename -- "$MODEL_PATH")_AllBs${BATCH_SIZE_ALL}_GenBs${BATCH_SIZE}_MiniBs${PPO_MINI_BATH}_GS${GROUP_SIZE}_${WORLD_SIZE}Nodes_$(date +"%m%d%H%M")"

# PYTHONBREAKPOINT=0 turns Python's built-in breakpoint() into a no-op.
export PYTHONBREAKPOINT=0

# Checkpoints go to <parent of this script's directory>/rl_ckpt/<project>/<experiment>.
# The path is resolved in its own assignment so that a failing realpath/dirname
# stops the script under `set -e`, and every expansion is quoted so a checkout
# path containing spaces stays one word.
ckpt_root="$(dirname -- "$(dirname -- "$(realpath -- "$0")")")/rl_ckpt"

# Start the trainer. Each setting is passed as one key=value argument; all of
# them are quoted so values are never word-split or treated as glob patterns
# (the bracketed list values would otherwise be glob bracket expressions).
# "trainer.logger=[console,wandb]" is exactly the string the shell produced
# from the previous spelling with inner single quotes.
# Arguments left over from option parsing are appended last, unchanged.
python3 -m verl.trainer.main_ppo \
    "algorithm.adv_estimator=grpo" \
    "trainer.project_name=$PROJECT_NAME" \
    "trainer.experiment_name=$EXP_NAME" \
    "trainer.n_gpus_per_node=1" \
    "trainer.nnodes=$WORLD_SIZE" \
    "trainer.total_epochs=30" \
    "trainer.save_freq=-1" \
    "trainer.test_freq=-1" \
    "trainer.val_before_train=$VAL_BEFORE_TRAIN" \
    "trainer.logger=[console,wandb]" \
    "trainer.default_local_dir=$ckpt_root/$PROJECT_NAME/$EXP_NAME" \
    "trainer.default_hdfs_dir=null" \
    "data.train_files=$train_files" \
    "data.val_files=$test_files" \
    "data.train_batch_size=$BATCH_SIZE" \
    "data.max_response_length=$RES_LENGTH" \
    "actor_rollout_ref.model.path=$MODEL_PATH" \
    "actor_rollout_ref.rollout.name=vllm" \
    "actor_rollout_ref.rollout.n=$GROUP_SIZE" \
    "actor_rollout_ref.rollout.temperature=0.6" \
    "actor_rollout_ref.rollout.val_temperature=0.6" \
    "actor_rollout_ref.rollout.gpu_memory_utilization=0.75" \
    "actor_rollout_ref.rollout.tensor_model_parallel_size=1" \
    "actor_rollout_ref.actor.optim.lr=1e-6" \
    "actor_rollout_ref.actor.optim.lr_warmup_steps=10" \
    "actor_rollout_ref.actor.clip_higher_ratio=0.00" \
    "actor_rollout_ref.actor.ppo_mini_batch_size=$PPO_MINI_BATH" \
    "actor_rollout_ref.actor.ppo_max_token_len_per_gpu=28000" \
    "actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1" \
    "actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1" \
    "reward_model.reward_manager=yr" \
    "reward_model.max_test_cases=3" \
    "data.test_case.type=$DYNAMIC_TEST_CASE" \
    "data.test_case.use=$USE_TEST_CASE" \
    "debug.use_debug=$USE_DEBUG" \
    "$@"
