#!/usr/bin/env bash
# Batch generation driver: runs `python3 -m verl.trainer.main_generation`
# once for every selected dataset.
#
# Options (all optional; defaults are defined below):
#   --model <model_path>          --datasets <name> [<name> ...]
#   --output-dir <directory>      --n <number_of_passes>
#   --max-length <tokens>         --tp <tensor_parallel_size>
#
# The shebang matters: the script relies on bash-only features (arrays,
# [[ ]], =~) and must not be interpreted by a plain POSIX sh.

# Trace every command as it executes (debug aid).
set -x

# Exported so that child processes (the python3 generation run) see it.
export VLLM_ATTENTION_BACKEND=XFORMERS

# Default values; each can be overridden by the matching command-line option.
MODEL_PATH="$HOME/DeepScaleR-1.5B-Preview"
# Possible values: aime, amc, math, minerva, olympiad_bench
DATATYPES=("aime")
OUTPUT_DIR="$HOME"  # Directory that receives one <dataset>.parquet per run
N_PASSES=16         # Passed to the generator as data.n_samples
MAX_LENGTH=32768    # Passed to the generator as rollout.response_length
TP_SIZE=1           # Passed as rollout.tensor_model_parallel_size

# Make sure a Ray cluster is available. When `ray status` exits successfully
# an existing cluster is reused; otherwise a head node is started on port 6379.
# RAY_STARTED records whether this script launched the cluster, so the exit
# handler only stops what the script itself started.
if ray status &>/dev/null; then
    echo "Ray cluster already running"
    RAY_STARTED=false
else
    echo "Starting Ray cluster..."
    ray start --head --port=6379
    RAY_STARTED=true
fi

# EXIT handler: stop the Ray cluster again, but only when this script was the
# one that started it (RAY_STARTED=true). A cluster that was already running
# is left untouched.
# Globals: RAY_STARTED (read)
# Returns: 0 when there is nothing to stop, otherwise the status of `ray stop`
cleanup() {
    [[ "$RAY_STARTED" == true ]] || return 0
    echo "Stopping Ray cluster..."
    ray stop
}
# Registered on EXIT so it runs on every exit path of the script.
trap cleanup EXIT

# Report a command-line error together with the usage line, then abort.
# Arguments: $1 - error message
# Outputs:   message and usage on stderr
# Exits:     always, with status 1
usage_error() {
    echo "$1" >&2
    echo "Usage: $0 --model <model_path> --datasets dataset1 dataset2 ... --output-dir <output_directory> --n <number_of_passes> --max-length <max_response_length> --tp <tensor_parallel_size>" >&2
    exit 1
}

# Parse named arguments into the configuration globals.
# Globals:   MODEL_PATH, DATATYPES, OUTPUT_DIR, N_PASSES, MAX_LENGTH, TP_SIZE
#            (written only for the options that are present)
# Arguments: the script's command-line arguments
# Exits:     status 1 on an unknown option, an option without its value, or
#            --datasets without at least one dataset name
parse_args() {
    while [[ $# -gt 0 ]]; do
        case "$1" in
            --model | --output-dir | --n | --max-length | --tp)
                # `shift 2` fails WITHOUT shifting when only the flag is left,
                # which would make this loop spin forever; check arity first.
                if [[ $# -lt 2 ]]; then
                    usage_error "Missing value for argument: $1"
                fi
                case "$1" in
                    --model)      MODEL_PATH="$2" ;;
                    --output-dir) OUTPUT_DIR="$2" ;;
                    --n)          N_PASSES="$2" ;;
                    --max-length) MAX_LENGTH="$2" ;;
                    --tp)         TP_SIZE="$2" ;;
                esac
                shift 2
                ;;
            --datasets)
                # Collect every following word up to the next "--option"
                # (or the end of the arguments) into the array.
                shift
                DATATYPES=()
                while [[ $# -gt 0 && ! $1 =~ ^-- ]]; do
                    DATATYPES+=("$1")
                    shift
                done
                # An empty list would make the run a silent no-op.
                if [[ ${#DATATYPES[@]} -eq 0 ]]; then
                    usage_error "Missing value for argument: --datasets"
                fi
                ;;
            *)
                usage_error "Unknown argument: $1"
                ;;
        esac
    done
}

parse_args "$@"
# The arguments are fully consumed at this point; clear them so the
# positional parameters are empty after parsing.
set --

# Echo the effective configuration for verification.
# Globals: MODEL_PATH, DATATYPES, OUTPUT_DIR, N_PASSES, MAX_LENGTH, TP_SIZE (read)
# Outputs: one "Label: value" line per setting on stdout
print_config() {
    echo "Model Path: ${MODEL_PATH}"
    # [*] joins the array into ONE word; [@] inside a larger string would
    # split the echo argument into several words (ShellCheck SC2145).
    echo "Datasets: ${DATATYPES[*]}"
    echo "Output Directory: ${OUTPUT_DIR}"
    echo "Number of Passes: ${N_PASSES}"
    echo "Max Response Length: ${MAX_LENGTH}"
    echo "Tensor Parallel Size: ${TP_SIZE}"
}
print_config


# Run the generation job for a single dataset.
# Globals:   OUTPUT_DIR, N_PASSES, MODEL_PATH, MAX_LENGTH, TP_SIZE (read)
# Arguments: $1 - dataset name; <name>.parquet is used for both the input
#                 file and the output file
# Returns:   exit status of the python3 command
run_generation() {
    local data_type=$1
    # Every override that contains a variable is quoted, so paths with spaces
    # or glob characters reach python3 as a single, unmodified argument.
    python3 -m verl.trainer.main_generation \
        trainer.nnodes=1 \
        trainer.n_gpus_per_node=8 \
        "data.path=/fsx/training/data/user/l1_data/no_budget/${data_type}.parquet" \
        "data.output_path=${OUTPUT_DIR}/${data_type}.parquet" \
        "data.n_samples=${N_PASSES}" \
        data.batch_size=1024 \
        "model.path=${MODEL_PATH}" \
        rollout.temperature=0.6 \
        "rollout.response_length=${MAX_LENGTH}" \
        rollout.top_k=-1 \
        rollout.top_p=0.95 \
        rollout.gpu_memory_utilization=0.8 \
        "rollout.tensor_model_parallel_size=${TP_SIZE}"
}

# Loop through all datatypes. A failing dataset does not stop the remaining
# ones, but it is reported and remembered.
FAILED_DATATYPES=()
for DATA_TYPE in "${DATATYPES[@]}"; do
    if ! run_generation "$DATA_TYPE"; then
        echo "Generation failed for dataset: ${DATA_TYPE}" >&2
        FAILED_DATATYPES+=("$DATA_TYPE")
    fi
done

# Deliberately the final command: its status becomes the script's exit status,
# so a failure of ANY dataset (not only the last one) yields a non-zero exit.
[[ ${#FAILED_DATATYPES[@]} -eq 0 ]]

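# Optional extra override, currently disabled. To enable it, add the following
# line to the python3 command above:
##        +rollout.repetition_penalty=1.1 \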
