#!/bin/bash
set -x
# Abort on the first failing command (e.g. a directory that cannot be
# created) instead of launching a long training run on a broken setup.
set -eo pipefail

# AUTHOR_NAME removed for anonymity
# Create the W&B output directory only when WANDB_DIR is actually set;
# an unquoted, empty $WANDB_DIR would make `mkdir -p` fail with no operand.
if [[ -n "${WANDB_DIR:-}" ]]; then
    mkdir -p -- "${WANDB_DIR}"
fi
export CUDA_DEVICE_MAX_CONNECTIONS=1
# Make Hydra print full stack traces on configuration errors.
export HYDRA_FULL_ERROR=1
# Select vLLM's V0 engine for rollouts.
export VLLM_USE_V1=0

DATA_DIR="../../data/math_curriculum_sampled"
FULL_DATA_DIR="../../data/math_curriculum"

# Use the generated files
TRAIN_FILE="${DATA_DIR}/<your-train-file>"
TEST_FILE="${FULL_DATA_DIR}/math__full_difficulty_ordered_test_500.parquet"

# Fail fast if the prepared data is missing. Checking `$?` here would be
# useless: after a plain variable assignment it is always 0.
for required_file in "${TRAIN_FILE}" "${TEST_FILE}"; do
    if [[ ! -f "${required_file}" ]]; then
        printf 'ERROR: Data preparation failed! Missing file: %s\n' "${required_file}" >&2
        exit 1
    fi
done

# Shared online-evaluation benchmarks live under the guru_verl data tree.
SHARED_DATA_PATH=../../data/guru_verl
VAL_DATA_DIR=${SHARED_DATA_PATH}/online_eval/

# List-literal string for data.train_files: a single training parquet.
train_files="['${TRAIN_FILE}']"

# List-literal string for data.val_files: the curriculum test split first,
# followed by the shared online-eval benchmarks.
val_parquets=(
    "${TEST_FILE}"
    "${VAL_DATA_DIR}/math__math_500.parquet"
    "${VAL_DATA_DIR}/logic__zebra_puzzle_dataset_200.parquet"
    "${VAL_DATA_DIR}/stem__supergpqa_200.parquet"
    "${VAL_DATA_DIR}/codegen__humaneval_164.parquet"
    "${VAL_DATA_DIR}/aime2024.parquet"
    "${VAL_DATA_DIR}/gsm8k.parquet"
    "${VAL_DATA_DIR}/amc2023.parquet"
)
# Prefix every entry with "', '", then drop the leading 4-character separator
# so entries end up joined as 'a', 'b', ... inside [ ].
printf -v val_files "', '%s" "${val_parquets[@]}"
val_files="['${val_files:4}']"
unset val_parquets


# Batch geometry, in prompts:
#   train_prompt_bsz      -> data.train_batch_size (prompts per training step)
#   n_resp_per_prompt     -> rollout.n (sampled responses per prompt, the GRPO group)
#   train_prompt_mini_bsz -> actor.ppo_mini_batch_size
train_prompt_bsz=512
n_resp_per_prompt=8
train_prompt_mini_bsz=128


RESULTS_DIR=<your-results-dir>
# Checkpoint directory derived from RESULTS_DIR and created up front.
CHECKPOINT_DIR="${RESULTS_DIR}/checkpoints"
# Quoted so a results path containing spaces or glob characters is created
# as exactly one directory.
mkdir -p -- "${CHECKPOINT_DIR}"

# =================== Model Configuration ===================
# MODEL_NAME is not referenced by the launch command below; BASE_MODEL is
# passed as actor_rollout_ref.model.path.
MODEL_NAME=<your-model-name>
BASE_MODEL=<your-model-path>



# =================== GRPO Training Parameters ===================
# Advantage estimator: group-relative baseline (no critic is trained).
adv_estimator=grpo

# KL regularisation is applied as an actor loss term (low-variance KL
# estimator, coefficient 0.001) rather than being folded into the reward.
use_kl_in_reward=False
use_kl_loss=True
kl_loss_coef=0.001
kl_loss_type=low_var_kl

# Symmetric PPO-style clipping of the policy ratio (still used by GRPO).
clip_ratio_low=0.2
clip_ratio_high=0.2

# Sequence length limits, in tokens.
max_prompt_length=2048
max_response_length=4096

# Hardware Platform
# num_nodes only takes effect if it is forwarded to the trainer (trainer.nnodes).
num_nodes=<your-num-nodes>
n_gpus_per_node=<your-n-gpus-per-node>


# Passed to the trainer as trainer.total_epochs.
EPOCHS=<data_repeated_num>

# Dynamic batch size configuration (actor.use_dynamic_bsz and
# rollout.log_prob_use_dynamic_bsz).
use_dynamic_bsz=True

# Calculate max sequence length from input/output settings
max_seq_length=$((max_prompt_length + max_response_length))

# Per-GPU token budgets are expressed as multiples of max_seq_length; the
# actor and rollout multipliers are configured independently.
actor_seq_multiplier=<your-actor-seq-multiplier>
rollout_seq_multiplier=<your-rollout-seq-multiplier>
# Calculate token limits for GRPO (no critic needed)
# -> actor.ppo_max_token_len_per_gpu
actor_ppo_max_token_len=$((max_seq_length * actor_seq_multiplier))
# -> rollout.log_prob_max_token_len_per_gpu (uses rollout_seq_multiplier,
#    so it can differ from the actor budget above)
rollout_log_prob_max_token_len=$((max_seq_length * rollout_seq_multiplier))

# Sampling parameters
temperature=1.0


#validation parameters
val_temperature=0.7


# Model parallelism settings
# gen_tp -> rollout.tensor_model_parallel_size.
# sp_size is the intended sequence-parallel degree; it only takes effect if
# forwarded to the trainer (actor.ulysses_sequence_parallel_size).
gen_tp=<your-gen-tp>
sp_size=<your-sp-size>

# Memory optimization
# offload: True/False, used for both FSDP param_offload and optimizer_offload.
offload=<your-offload-choice>
# Fraction of GPU memory the vLLM rollout engine may use
# (rollout.gpu_memory_utilization). An earlier fixed value of 0.65 was lowered
# for stability with mixed-domain data.
gpu_memory_utilization=<your-gpu-memory-utilization>



# =================== Start GRPO Training ===================
# Launch verl's PPO entry point configured for GRPO (adv_estimator=grpo).
# Any extra arguments given to this script are appended as additional Hydra
# overrides via "$@".
# NOTE: every continuation line must end in a bare backslash. Whitespace after
# the `\` terminates the command early and turns each following override line
# into a separate, invalid command.

python3 -m verl.trainer.main_ppo \
    hydra.run.dir=<your-run-dir> \
    hydra.sweep.dir=<your-sweep-dir> \
    hydra.job.chdir=False \
    algorithm.adv_estimator=${adv_estimator} \
    algorithm.use_kl_in_reward=${use_kl_in_reward} \
    algorithm.gamma=1.0 \
    algorithm.lam=0.95 \
    data.train_files="${train_files}" \
    data.val_files="${val_files}" \
    data.prompt_key=prompt \
    data.truncation='right' \
    data.max_prompt_length=${max_prompt_length} \
    data.max_response_length=${max_response_length} \
    data.train_batch_size=${train_prompt_bsz} \
    data.filter_overlong_prompts=True \
    data.shuffle=False \
    data.trust_remote_code=True \
    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
    actor_rollout_ref.actor.kl_loss_type=${kl_loss_type} \
    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
    actor_rollout_ref.actor.strategy="fsdp" \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.optim.lr_warmup_steps=0 \
    actor_rollout_ref.actor.optim.weight_decay=0.1 \
    actor_rollout_ref.actor.optim.warmup_style=constant \
    actor_rollout_ref.actor.optim.min_lr_ratio=0. \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.grad_clip=1.0 \
    actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${rollout_log_prob_max_token_len} \
    actor_rollout_ref.rollout.gpu_memory_utilization=${gpu_memory_utilization} \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
    actor_rollout_ref.rollout.temperature=${temperature} \
    actor_rollout_ref.rollout.val_kwargs.temperature=${val_temperature} \
    actor_rollout_ref.rollout.val_kwargs.n=1 \
    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
    actor_rollout_ref.rollout.enforce_eager=False \
    actor_rollout_ref.rollout.free_cache_engine=False \
    actor_rollout_ref.rollout.max_num_seqs=64 \
    actor_rollout_ref.model.path="${BASE_MODEL}" \
    actor_rollout_ref.model.trust_remote_code=True \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    trainer.critic_warmup=0 \
    trainer.logger='["console", "wandb"]' \
    trainer.project_name=<your-wandb-project> \
    trainer.experiment_name=<your-wandb-experiment-name> \
    trainer.val_before_train=True \
    trainer.nnodes=${num_nodes} \
    trainer.n_gpus_per_node=${n_gpus_per_node} \
    trainer.test_freq=1 \
    trainer.save_freq=10 \
    trainer.total_epochs=${EPOCHS} \
    trainer.max_actor_ckpt_to_keep=1 \
    trainer.resume_mode=auto \
    trainer.default_local_dir="${CHECKPOINT_DIR}" \
    "$@"