#!/bin/bash
# Multi-Puzzle GSPO for OLMo-3-7B — from v2 SFT ep3 (4 GPUs)
#
# Fresh GSPO from v2 SFT ep3 merged checkpoint (AIME24 23.3%, val_loss=0.383).
# A/B comparison with ep5 run to test GSPO from earlier SFT checkpoint with more headroom.
#
# Key changes vs ep5 script (multi_puzzle_gspo_olmo3_v2_sft.sh):
#   - Base model: v2 SFT ep3 merged (earlier checkpoint for A/B comparison)
#   - 4 GPUs: 4 independent vLLM engines, FSDP across 4
#   - n_gen=4: matched to GPU count (same per-GPU load as ep5's 8-GPU/8-gen)
#   - MAX_RESPONSE_LENGTH=28000 (vs 22000): more token budget for reasoning
#   - KL beta=0.001: kept same as ep5 for clean A/B comparison
#   - format_weight=0.1: kept same as ep5 for clean A/B comparison
#   - All other hyperparameters identical to ep5 for clean comparison
#
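# Data: ~9,200 examples (subsample targets sum to 9,194) → ~71 steps at batch=128 (1 epoch)
# Total responses/step: 128 prompts x 4 gens = 512. ppo_mini_batch_size=128 is counted in
# prompts (VERL scales it by rollout.n), so each step is 1 mini-batch x ppo_epochs=4
# = 4 optimizer updates (NOTE(review): confirm train_main.py does not override this).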
#
# Usage:
#   Local:  ./train/verl_grpo/multi_puzzle_gspo_olmo3_v2_sft_ep3.sh

# Abort on unhandled errors and trace every command into the job log.
set -ex

# ============================================
# Python environment: vLLM 0.12.0 venv (B200)
# ============================================
# VLLM_VENV_PATH overrides the default venv location under $HOME.
VLLM012_VENV="${VLLM_VENV_PATH:-$HOME/verl-vllm012}"
if [[ -f "$VLLM012_VENV/bin/activate" ]]; then
    source "$VLLM012_VENV/bin/activate"
else
    echo "ERROR: vLLM 0.12.0 venv not found at $VLLM012_VENV"
    echo "Run: bash scripts/setup_vllm012_venv.sh"
    exit 1
fi
echo "Activated vLLM 0.12.0 venv: $VIRTUAL_ENV"

# Log library versions; vLLM and torch must import, VERL is only warned about.
python3 -c "import vllm; print(f'vLLM version: {vllm.__version__}')"
python3 -c "import torch; print(f'PyTorch version: {torch.__version__}')"
python3 -c "import verl; print(f'VERL: installed')" 2>/dev/null || echo "WARNING: VERL not importable"

# Detect project directory: use /opt/ml/code when it exists (container layout),
# otherwise the repository root two levels above this script.
PROJECT_DIR="/opt/ml/code"
if [[ ! -d "$PROJECT_DIR" ]]; then
    PROJECT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
fi
cd "$PROJECT_DIR"
echo "Project directory: $PROJECT_DIR"

# GPU selection and vLLM runtime flags:
#   - CUDA_VISIBLE_DEVICES: expose GPUs 0-3 to this job
#   - VLLM_USE_V1: vLLM V1 engine (required by VERL 0.7.0+)
#   - VLLM_USE_TRTLLM_ATTENTION / VLLM_ATTENTION_BACKEND: B200/SM100a required settings
export CUDA_VISIBLE_DEVICES=0,1,2,3 \
       VLLM_USE_V1=1 \
       VLLM_USE_TRTLLM_ATTENTION=0 \
       VLLM_ATTENTION_BACKEND=FLASH_ATTN

# Fix vLLM 0.12.0 LoRA PDL crash on B200 (SM100a)

# patch_lora_pdl UTILS_FILE
#   Rewrites the PDL capability check in vLLM's LoRA Triton utils so it is
#   false on SM100 (B200). Skipped when has_device_capability(90) is absent
#   or has_device_capability(100) is already present (i.e. already patched).
#   After editing, re-checks the file and warns on stderr if the sed pattern
#   did not match (e.g. the upstream line was reformatted).
patch_lora_pdl() {
    local utils_file=$1
    if grep -q "has_device_capability(90)" "$utils_file" && ! grep -q "has_device_capability(100)" "$utils_file"; then
        echo "Patching vLLM LoRA PDL: disabling for SM100a (B200)..."
        sed -i 's/return current_platform.is_cuda() and current_platform.has_device_capability(90)/return current_platform.is_cuda() and current_platform.has_device_capability(90) and not current_platform.has_device_capability(100)/' "$utils_file"
        if grep -q "has_device_capability(100)" "$utils_file"; then
            echo "PDL disabled for SM100a. LoRA Triton kernels will compile without gdc_wait()."
        else
            echo "WARNING: vLLM LoRA PDL patch pattern did not match in $utils_file; PDL may still be enabled on SM100a." >&2
        fi
    else
        echo "vLLM LoRA PDL patch already applied or not needed."
    fi
}

UTILS_FILE=$(python3 -c "import vllm; import os; print(os.path.join(os.path.dirname(vllm.__file__), 'lora/ops/triton_ops/utils.py'))" 2>/dev/null)
if [ -n "$UTILS_FILE" ] && [ -f "$UTILS_FILE" ]; then
    patch_lora_pdl "$UTILS_FILE"
fi

# Runtime knobs for torch, Triton, Ray and Python:
#   TORCHDYNAMO_DISABLE      - disable torch dynamo/compile (avoids triton autotuning CUDA errors)
#   TRITON_CACHE_DIR         - fresh per-process triton cache (suffix is this shell's PID)
#   RAY_object_store_memory  - Ray object store size (150000000000; presumably bytes)
#   RAY_DISABLE_DASHBOARD    - no Ray dashboard
#   PYTHONUNBUFFERED         - flush Python output immediately
TORCHDYNAMO_DISABLE=1
TRITON_CACHE_DIR="/tmp/triton_cache_grpo_$$"
RAY_object_store_memory=150000000000
RAY_DISABLE_DASHBOARD=1
PYTHONUNBUFFERED=1
export TORCHDYNAMO_DISABLE TRITON_CACHE_DIR RAY_object_store_memory RAY_DISABLE_DASHBOARD PYTHONUNBUFFERED

# Drop any inherited allocator config so expandable_segments cannot conflict
# with vLLM V1's CuMemAllocator.
unset PYTORCH_CUDA_ALLOC_CONF

# Credentials and caches for WandB / HuggingFace.
# The script enables `set -x` at the top, which would print each export below
# with its expanded secret value into the job log. Pause xtrace here and
# restore the previous state afterwards.
creds_xtrace_was_on=false
case $- in *x*) creds_xtrace_was_on=true ;; esac
{ set +x; } 2>/dev/null

# WandB configuration (empty API key if none was provided)
export WANDB_API_KEY="${WANDB_API_KEY:-}"
export WANDB_MODE="${WANDB_MODE:-online}"
export WANDB_CONSOLE=off

# HuggingFace configuration; VERL_HF_TOKEN mirrors HF_TOKEN for the VERL helpers
export HF_TOKEN="${HF_TOKEN}"
export VERL_HF_TOKEN="$HF_TOKEN"
export HF_HOME="${HF_HOME:-/tmp/hf_cache}"
export HF_HUB_CACHE="${HF_HUB_CACHE:-/tmp/hf_cache}"

if [ "$creds_xtrace_was_on" = true ]; then set -x; fi

# ============================================
# Configuration
# ============================================

# Model — v2 SFT ep3 merged (AIME24 23.3%, val_loss 0.383)
# This serves as BOTH the init weights AND the KL reference model.
MERGED_CHECKPOINT="${PROJECT_DIR}/checkpoints/olmo3_7b_multi_puzzle_dsr_v2/merged_ep3_fp32"

# LoRA config
LORA_RANK=64
LORA_ALPHA=128          # 2x scaling (alpha / rank)
LORA_TARGET_MODULES="all-linear"

# Training parameters
LEARNING_RATE=5e-5      # Standard (exploration fix value)
TRAIN_BATCH_SIZE=128    # Prompts per step; same as ep5
NUM_GENERATIONS=4       # 4 samples/prompt (4 GPUs, same per-GPU load as ep5's 8-GPU/8-gen)
TEMPERATURE=0.8
TOP_P=1.0
BETA=0.001              # KL loss coefficient; same as ep5 for clean A/B comparison
NUM_EPOCHS=1            # Single epoch, no data recycling

# Sequence lengths
MAX_PROMPT_LENGTH=3000          # Bridges 7x7/8x8 prompts
MAX_RESPONSE_LENGTH=28000       # More token budget for reasoning (vs 22K in ep5)

# PPO mini/micro batch sizes
PPO_MINI_BATCH_SIZE=128
MICRO_BATCH_SIZE=2              # Chunked entropy avoids OOM
                                # NOTE(review): with USE_DYNAMIC_BSZ=true the token budget
                                # below likely governs micro-batching instead — confirm.

# Forward-only micro batch sizes
REF_LOG_PROB_MICRO_BATCH=4
ROLLOUT_LOG_PROB_MICRO_BATCH=4

# Dynamic batch sizing
USE_DYNAMIC_BSZ=true
PPO_MAX_TOKEN_LEN_PER_GPU=32000  # Must be >= MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH (3000 + 28000 = 31000)

# Enforce the invariant above so a future length change fails here, at
# startup, instead of deep inside training.
if (( PPO_MAX_TOKEN_LEN_PER_GPU < MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH )); then
    echo "ERROR: PPO_MAX_TOKEN_LEN_PER_GPU=$PPO_MAX_TOKEN_LEN_PER_GPU must be >= MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH = $((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH))" >&2
    exit 1
fi

# Evaluation
EVAL_EVERY=5

# GPU config — 4 GPUs
N_GPUS=4

# Prompt templates (per puzzle)
BRIDGES_PROMPT="${PROJECT_DIR}/prompts/bridges_intformat.txt"
GALAXIES_PROMPT="${PROJECT_DIR}/prompts/galaxies_intformat.txt"
PATTERN_PROMPT="${PROJECT_DIR}/prompts/pattern_formatted.txt"
UNDEAD_PROMPT="${PROJECT_DIR}/prompts/undead_formatted.txt"

# ============================================
# Training datasets — adapted to v2 SFT feasibility
# ============================================
# Each HF dataset id is paired with the number of rows to subsample from it.
# Targets add up to 9,194 examples → ~71 steps at batch=128.

# --- Bridges ---
# Primary: 7x7dm (22% accuracy, strong partial)
BRIDGES_7x7DM_TRAIN="anon-neurips26/bridges_7x7dm_grpo_5k_intformat_json"
BRIDGES_7x7DM_N=2000
# Secondary: 7x7dh (10% accuracy, partial=0.30)
BRIDGES_7x7DH_TRAIN="anon-neurips26/bridges_7x7dh_grpo_5k_intformat_json"
BRIDGES_7x7DH_N=1500
# Stabilizer: 5x5dm (96%, prevents regression)
BRIDGES_5x5DM_STAB="anon-neurips26/bridges_5x5dm_grpo_train_5k_intformat_json"
BRIDGES_5x5DM_N=200

# --- Pattern ---
# DEFERRED: 5x5de (12% at 24K tokens, causes 51% truncation at 28K — add in curriculum stage)
# Primary: 4x4 (42% accuracy, excellent signal) — all available examples
PATTERN_4x4_TRAIN="anon-neurips26/pattern_4x4_grpo_train"
PATTERN_4x4_N=1166
# Stabilizer: 3x3 rows filtered out of the mixed 3x3+4x4 dataset
PATTERN_3x3_4x4_TRAIN="anon-neurips26/pattern_3x3_4x4_grpo_train"
PATTERN_3x3_STAB_N=200

# --- Undead ---
# Primary: 4x4de (10% accuracy, partial improved after reward fix)
UNDEAD_4x4DE_TRAIN="anon-neurips26/undead_4x4de_grpo_5k"
UNDEAD_4x4DE_N=2000
# Stabilizer: 3x3 (83%, prevents regression)
UNDEAD_3x3_STAB="anon-neurips26/undead_3x3_grpo_train"
UNDEAD_3x3_N=200

# --- Galaxies ---
# Primary + stabilizer: 3x3+4x4 mix (~1875 4x4 + ~53 3x3) — all available
GALAXIES_TRAIN="anon-neurips26/galaxies_3x3_4x4_grpo_train_intformat_json"
GALAXIES_N=1928

# ============================================
# Eval datasets — 4 datasets
# ============================================
BRIDGES_7x7DE_EVAL="anon-neurips26/bridges_7x7de_test200_intformat_json"
PATTERN_4x4_EVAL="anon-neurips26/pattern_4x4_test200"
UNDEAD_4x4DE_EVAL="anon-neurips26/undead_4x4de_test200"
GALAXIES_3x3_EVAL="anon-neurips26/galaxies_3x3_test200_intformat_json"

ALL_EVAL_DATASETS=(
    "$BRIDGES_7x7DE_EVAL"
    "$PATTERN_4x4_EVAL"
    "$UNDEAD_4x4DE_EVAL"
    "$GALAXIES_3x3_EVAL"
)

# Build eval dataset names for best checkpoint tracking: each HF repo id becomes
# a slug with every '/' and '-' replaced by '_', and the slugs are joined with
# commas into VERL_EVAL_DATASETS. Parameter expansion avoids two sed forks per id.
eval_slugs=()
for dataset in "${ALL_EVAL_DATASETS[@]}"; do
    dataset_slug=${dataset//\//_}
    dataset_slug=${dataset_slug//-/_}
    eval_slugs+=("$dataset_slug")
done
EVAL_DATASET_NAMES=$(IFS=,; printf '%s' "${eval_slugs[*]}")
unset eval_slugs
export VERL_EVAL_DATASETS="$EVAL_DATASET_NAMES"

# Reward: reward_function/all.py with the partial_v2_plus_format method.
REWARD_FN_PATH="${PROJECT_DIR}/reward_function/all.py"
REWARD_METHOD="partial_v2_plus_format"

# Run naming; checkpoints live under checkpoints/<project>/<experiment>,
# with TensorBoard logs in a subdirectory of that.
PROJECT_NAME="olmo3-puzzle-grpo"
EXPERIMENT_NAME="multi_puzzle_gspo_olmo3_v2_sft_ep3"
CHECKPOINT_DIR="${PROJECT_DIR}/checkpoints/${PROJECT_NAME}/${EXPERIMENT_NAME}"
TENSORBOARD_DIR="${CHECKPOINT_DIR}/tensorboard"

# Target HF repo for the post-training upload: <org>/<project>-<experiment>.
VERL_HF_REPO_ID="anon-neurips26/${PROJECT_NAME}-${EXPERIMENT_NAME}"
export VERL_HF_REPO_ID

# Where the prepared/combined parquet files for this run are written.
DATA_DIR="${PROJECT_DIR}/data/multi_puzzle/gspo_olmo3_v2_sft_ep3"

# ============================================
# Validate merged checkpoint exists
# ============================================
# The checkpoint directory must exist and contain config.json; otherwise print
# the command that produces it and abort.
printf '%s\n' \
    "==========================================" \
    "Validating v2 SFT ep3 merged checkpoint" \
    "=========================================="

if [[ -d "$MERGED_CHECKPOINT" && -f "$MERGED_CHECKPOINT/config.json" ]]; then
    echo "Checkpoint validated: $MERGED_CHECKPOINT"
else
    echo "ERROR: Merged checkpoint not found at $MERGED_CHECKPOINT"
    echo "Expected: config.json in $MERGED_CHECKPOINT"
    echo ""
    echo "To create it:"
    echo "  python src/verl_helpers/merge_lora.py --base_model allenai/OLMo-3-7B-Instruct-SFT --lora_path checkpoints/olmo3_7b_multi_puzzle_dsr_v2/global_step_204 --output_dir checkpoints/olmo3_7b_multi_puzzle_dsr_v2/merged_ep3_fp32 --torch_dtype float32"
    exit 1
fi

# ============================================
# Prepare HuggingFace datasets (per puzzle)
# ============================================
echo "=========================================="
echo "Preparing HuggingFace datasets for VERL"
echo "=========================================="

mkdir -p "$DATA_DIR"

# prepare_split LABEL OUT_SUBDIR DATA_SOURCE PROMPT_TEMPLATE TRAIN_DATASET [EVAL_DATASET]
#   Prints a blank line and a "--- LABEL ---" header, then runs
#   prepare_hf_datasets.py for one training dataset (plus an eval dataset when
#   given), writing into $DATA_DIR/OUT_SUBDIR with --system_prompt rsft --no_cache.
prepare_split() {
    local label=$1 out_subdir=$2 data_source=$3 template=$4 train_ds=$5
    local eval_ds=${6:-}
    local -a dataset_args=(--train_datasets "$train_ds")
    if [[ -n "$eval_ds" ]]; then
        dataset_args+=(--eval_datasets "$eval_ds")
    fi

    echo ""
    echo "--- ${label} ---"
    python3 src/verl_helpers/prepare_hf_datasets.py \
        "${dataset_args[@]}" \
        --prompt_template "$template" \
        --output_dir "$DATA_DIR/$out_subdir" \
        --data_source "$data_source" \
        --system_prompt rsft \
        --no_cache
}

# Bridges: 7x7dm primary (also emits the 7x7de eval split), 7x7dh secondary, 5x5dm stabilizer
prepare_split "Bridges (7x7dm, primary)" bridges_7x7dm bridges "$BRIDGES_PROMPT" "$BRIDGES_7x7DM_TRAIN" "$BRIDGES_7x7DE_EVAL"
prepare_split "Bridges (7x7dh, secondary)" bridges_7x7dh bridges "$BRIDGES_PROMPT" "$BRIDGES_7x7DH_TRAIN"
prepare_split "Bridges (5x5dm, stabilizer)" bridges_5x5dm_stab bridges "$BRIDGES_PROMPT" "$BRIDGES_5x5DM_STAB"

# Pattern: 4x4 primary (best signal, with eval), 3x3+4x4 mix used later for 3x3 stabilizer extraction
prepare_split "Pattern (4x4, primary)" pattern_4x4 pattern "$PATTERN_PROMPT" "$PATTERN_4x4_TRAIN" "$PATTERN_4x4_EVAL"
prepare_split "Pattern (3x3+4x4, for stabilizer extraction)" pattern_3x3_4x4 pattern "$PATTERN_PROMPT" "$PATTERN_3x3_4x4_TRAIN"

# Undead: 4x4de primary (with eval), 3x3 stabilizer
prepare_split "Undead (4x4de, primary)" undead_4x4de undead "$UNDEAD_PROMPT" "$UNDEAD_4x4DE_TRAIN" "$UNDEAD_4x4DE_EVAL"
prepare_split "Undead (3x3, stabilizer)" undead_3x3_stab undead "$UNDEAD_PROMPT" "$UNDEAD_3x3_STAB"

# Galaxies: single 3x3+4x4 mix serving as primary + stabilizer (with 3x3 eval)
prepare_split "Galaxies (3x3+4x4, primary)" galaxies galaxies "$GALAXIES_PROMPT" "$GALAXIES_TRAIN" "$GALAXIES_3x3_EVAL"

# ============================================
# Concatenate + subsample training data
# ============================================
echo "=========================================="
echo "Concatenating + subsampling training data (~9K v2 SFT ep3 mix)"
echo "=========================================="

# The Python program is a double-quoted shell string: bash substitutes
# $DATA_DIR, the *_N targets and $TRAIN_BATCH_SIZE before Python runs. Keep
# Python string literals single-quoted and avoid '$', '"' and backticks inside.
# It runs as an `if` condition because under `set -e` a bare failing command
# exits the script immediately, so a later `$?` check could never report it.
if ! python3 -c "
import json

import numpy as np
import pandas as pd

np.random.seed(42)

dfs = []

# Define all dataset components: (name, parquet_path, target_n, filter_fn)
# filter_fn is currently always None. Each component is uniformly subsampled
# (fixed seed) down to target_n rows; smaller datasets are used in full.
components = [
    # Bridges
    ('bridges_7x7dm (primary)',    '$DATA_DIR/bridges_7x7dm/train_rl.parquet',      $BRIDGES_7x7DM_N, None),
    ('bridges_7x7dh (secondary)',  '$DATA_DIR/bridges_7x7dh/train_rl.parquet',      $BRIDGES_7x7DH_N, None),
    ('bridges_5x5dm (stab)',       '$DATA_DIR/bridges_5x5dm_stab/train_rl.parquet', $BRIDGES_5x5DM_N, None),
    # Pattern (5x5de deferred to curriculum stage)
    ('pattern_4x4 (primary)',      '$DATA_DIR/pattern_4x4/train_rl.parquet',        $PATTERN_4x4_N, None),
    # Undead
    ('undead_4x4de (primary)',     '$DATA_DIR/undead_4x4de/train_rl.parquet',       $UNDEAD_4x4DE_N, None),
    ('undead_3x3 (stab)',          '$DATA_DIR/undead_3x3_stab/train_rl.parquet',    $UNDEAD_3x3_N, None),
    # Galaxies (all)
    ('galaxies_3x3+4x4 (primary)','$DATA_DIR/galaxies/train_rl.parquet',           $GALAXIES_N, None),
]

for name, path, target_n, filter_fn in components:
    print(f'\n--- {name} ---')
    df = pd.read_parquet(path)
    print(f'  Raw: {len(df)}')
    if len(df) > target_n:
        df = df.sample(n=target_n, random_state=42)
    print(f'  Used: {len(df)}')
    dfs.append(df)

# Pattern 3x3 stabilizer — extract 3x3 examples from mixed 3x3+4x4 dataset
# Filter by gridsize in extra_info if available, otherwise by prompt length (3x3 prompts are shorter)
print(f'\n--- pattern_3x3 (stab, from mixed 3x3+4x4) ---')
df_pattern_mix = pd.read_parquet('$DATA_DIR/pattern_3x3_4x4/train_rl.parquet')
print(f'  Mixed dataset raw: {len(df_pattern_mix)}')

def is_3x3(row):
    # extra_info may be a dict or a JSON string; gridsize may be '3x3' or 3.
    try:
        ei = row.get('extra_info', {})
        if isinstance(ei, str):
            ei = json.loads(ei)
        gs = ei.get('gridsize', '')
        return '3x3' in str(gs) or '3' == str(gs)
    except Exception:
        # Missing or malformed extra_info: treat the row as not 3x3.
        return False

# Check whether extra_info carries a gridsize field (first row only; empty frame -> no)
sample_ei = df_pattern_mix.iloc[0].get('extra_info', {}) if len(df_pattern_mix) else {}
if isinstance(sample_ei, str):
    try:
        sample_ei = json.loads(sample_ei)
    except Exception:
        sample_ei = {}

if 'gridsize' in (sample_ei if isinstance(sample_ei, dict) else {}):
    df_3x3 = df_pattern_mix[df_pattern_mix.apply(is_3x3, axis=1)]
    print(f'  3x3 filtered by gridsize: {len(df_3x3)}')
else:
    # Heuristic fallback: 3x3 grids have shorter prompts, so keep prompts at or
    # below the median length (roughly the shorter half of the mix).
    prompt_lens = df_pattern_mix['prompt'].apply(lambda p: len(str(p)))
    median_len = prompt_lens.median()
    df_3x3 = df_pattern_mix[prompt_lens <= median_len]
    print(f'  3x3 filtered by prompt length (<= median {median_len:.0f}): {len(df_3x3)}')

target_3x3 = $PATTERN_3x3_STAB_N
if len(df_3x3) > target_3x3:
    df_3x3 = df_3x3.sample(n=target_3x3, random_state=42)
elif len(df_3x3) == 0:
    print(f'  WARNING: No 3x3 examples found, taking {target_3x3} random from mixed')
    df_3x3 = df_pattern_mix.sample(n=min(target_3x3, len(df_pattern_mix)), random_state=42)
print(f'  Used: {len(df_3x3)}')
dfs.append(df_3x3)

# Concatenate and shuffle
combined = pd.concat(dfs, ignore_index=True)
combined = combined.sample(frac=1, random_state=42).reset_index(drop=True)

# Print data_source distribution
print(f'\n=== Combined training set: {len(combined)} examples ===')
print(f'=== Steps per epoch at batch=$TRAIN_BATCH_SIZE: {len(combined) // $TRAIN_BATCH_SIZE} ===')
print(combined['data_source'].value_counts())

output_path = '$DATA_DIR/train_combined_v2_sft_ep3.parquet'
combined.to_parquet(output_path, index=False)
print(f'\nSaved to {output_path}')
"; then
    echo "ERROR: Failed to concatenate/subsample training data"
    exit 1
fi

# ============================================
# Build data file lists
# ============================================
TRAIN_DATA="$DATA_DIR/train_combined_v2_sft_ep3.parquet"

# eval_puzzle_dir DATASET_ID
#   Prints the per-puzzle subdirectory whose prepare_hf_datasets.py call was
#   given DATASET_ID via --eval_datasets. Returns 1 for an unknown id, so a new
#   eval set cannot silently reuse the previous iteration's directory.
eval_puzzle_dir() {
    case "$1" in
        *bridges_7x7*) echo "bridges_7x7dm" ;;
        *pattern_4x4*) echo "pattern_4x4" ;;
        *undead_4x4*) echo "undead_4x4de" ;;
        *galaxies_3x3*) echo "galaxies" ;;
        *) return 1 ;;
    esac
}

# Hydra list literal of eval parquet paths: ['<dir>/eval_<slug>.parquet',...]
# where <slug> is the dataset id with every '/' and '-' replaced by '_'.
TEST_FILES="["
for dataset in "${ALL_EVAL_DATASETS[@]}"; do
    dataset_slug=${dataset//\//_}
    dataset_slug=${dataset_slug//-/_}
    if ! puzzle_dir=$(eval_puzzle_dir "$dataset"); then
        echo "ERROR: no puzzle data directory mapped for eval dataset '$dataset'" >&2
        exit 1
    fi
    TEST_FILES+="'$DATA_DIR/${puzzle_dir}/eval_${dataset_slug}.parquet',"
done
TEST_FILES="${TEST_FILES%,}]"

echo "Train data: $TRAIN_DATA"
echo "Test files: $TEST_FILES"

# ============================================
# Create output directories
# ============================================
# Quoted so paths containing spaces or glob characters stay single operands.
mkdir -p -- "$CHECKPOINT_DIR" "$TENSORBOARD_DIR"

echo "=========================================="
echo "Starting VERL GSPO Training (OLMo-3-7B v2 SFT ep3, 4 GPUs)"
echo "=========================================="
echo "Checkpoint: $MERGED_CHECKPOINT (v2 SFT ep3 merged)"
echo "KL reference: same checkpoint (KL starts ~0, grows as policy diverges)"
echo "LoRA: rank=$LORA_RANK, alpha=$LORA_ALPHA"
echo "LR: $LEARNING_RATE (cosine decay, min_ratio=0.3)"
echo "Batch: $TRAIN_BATCH_SIZE prompts x $NUM_GENERATIONS gens = $((TRAIN_BATCH_SIZE * NUM_GENERATIONS)) responses/step"
echo "Max prompt: $MAX_PROMPT_LENGTH tokens"
echo "Max response: $MAX_RESPONSE_LENGTH tokens"
echo "Temperature: $TEMPERATURE"
echo "Beta (KL): $BETA (same as ep5 for A/B comparison)"
echo "Epochs: $NUM_EPOCHS (single epoch, no data recycling)"
echo "Reward: $REWARD_METHOD (exact + partial_v2 + 0.1*xmlcount, power=3.0) [bridges fix: normalize extracted answer]"
echo "Policy loss: GSPO (sequence-level importance ratios)"
echo "GSPO clip: clip_ratio_low=5e-3, clip_ratio_high=8e-3, ppo_epochs=4"
echo "GPUs: $N_GPUS (4 independent vLLM engines)"
echo "Data: ~9K v2 SFT ep3 mix (same datasets as ep5)"
echo "Evals: ${#ALL_EVAL_DATASETS[@]} datasets every $EVAL_EVERY steps"
echo "Checkpoints: $CHECKPOINT_DIR"
echo "WandB: $PROJECT_NAME / $EXPERIMENT_NAME"
echo "=========================================="

# ============================================
# Run VERL GSPO Training
# ============================================
# vLLM rollout token budget, kept equal to the longest possible sequence
# (max prompt + max response; 3000 + 28000 = 31000 with the current settings)
# so it follows any change to those lengths instead of being hard-coded.
ROLLOUT_MAX_NUM_BATCHED_TOKENS=$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH))

# Capture the trainer's exit status explicitly. Under `set -e` a bare failing
# command aborts the script at once, so the failure report below (and the
# final `exit $EXIT_CODE`) would otherwise never run.
EXIT_CODE=0
TENSORBOARD_DIR="$TENSORBOARD_DIR" python3 "${PROJECT_DIR}/src/verl_helpers/train_main.py" \
    algorithm.adv_estimator=grpo \
    data.train_files="['$TRAIN_DATA']" \
    data.val_files="$TEST_FILES" \
    data.train_batch_size="$TRAIN_BATCH_SIZE" \
    data.max_prompt_length="$MAX_PROMPT_LENGTH" \
    data.max_response_length="$MAX_RESPONSE_LENGTH" \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path="$MERGED_CHECKPOINT" \
    actor_rollout_ref.model.lora_rank="$LORA_RANK" \
    actor_rollout_ref.model.lora_alpha="$LORA_ALPHA" \
    actor_rollout_ref.model.target_modules="$LORA_TARGET_MODULES" \
    actor_rollout_ref.rollout.load_format=safetensors \
    actor_rollout_ref.actor.optim.lr="$LEARNING_RATE" \
    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.03 \
    actor_rollout_ref.actor.optim.lr_scheduler_type=cosine \
    actor_rollout_ref.actor.optim.min_lr_ratio=0.3 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size="$PPO_MINI_BATCH_SIZE" \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu="$MICRO_BATCH_SIZE" \
    actor_rollout_ref.actor.use_dynamic_bsz="$USE_DYNAMIC_BSZ" \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu="$PPO_MAX_TOKEN_LEN_PER_GPU" \
    actor_rollout_ref.actor.policy_loss.loss_mode=gspo \
    actor_rollout_ref.actor.ppo_epochs=4 \
    actor_rollout_ref.actor.clip_ratio_low=5e-3 \
    actor_rollout_ref.actor.clip_ratio_high=8e-3 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef="$BETA" \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.entropy_from_logits_with_chunking=True \
    actor_rollout_ref.actor.entropy_checkpointing=True \
    actor_rollout_ref.actor.strategy=fsdp \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.model_dtype=bfloat16 \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.actor.fsdp_config.forward_prefetch=True \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu="$ROLLOUT_LOG_PROB_MICRO_BATCH" \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.70 \
    actor_rollout_ref.rollout.enforce_eager=False \
    actor_rollout_ref.rollout.free_cache_engine=True \
    actor_rollout_ref.rollout.enable_chunked_prefill=True \
    actor_rollout_ref.rollout.enable_prefix_caching=True \
    actor_rollout_ref.rollout.n="$NUM_GENERATIONS" \
    actor_rollout_ref.rollout.temperature="$TEMPERATURE" \
    actor_rollout_ref.rollout.top_p="$TOP_P" \
    actor_rollout_ref.rollout.max_num_batched_tokens="$ROLLOUT_MAX_NUM_BATCHED_TOKENS" \
    actor_rollout_ref.rollout.disable_log_stats=False \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu="$REF_LOG_PROB_MICRO_BATCH" \
    actor_rollout_ref.ref.fsdp_config.param_offload=False \
    trainer.critic_warmup=0 \
    'trainer.logger=[console,wandb]' \
    trainer.project_name="$PROJECT_NAME" \
    trainer.experiment_name="$EXPERIMENT_NAME" \
    trainer.default_local_dir="$CHECKPOINT_DIR" \
    +trainer.remove_previous_ckpt_in_save=False \
    trainer.n_gpus_per_node="$N_GPUS" \
    trainer.nnodes=1 \
    trainer.save_freq=5 \
    trainer.test_freq="$EVAL_EVERY" \
    trainer.total_epochs="$NUM_EPOCHS" \
    custom_reward_function.path="$REWARD_FN_PATH" \
    custom_reward_function.name=compute_score \
    +custom_reward_function.reward_kwargs.method="$REWARD_METHOD" \
    +custom_reward_function.reward_kwargs.power_exponent=3.0 \
    +custom_reward_function.reward_kwargs.format_weight=0.1 \
    +custom_reward_function.reward_kwargs.changed_cell_weight=2.0 \
    +ray_kwargs.ray_init.include_dashboard=False \
    || EXIT_CODE=$?

echo ""
echo "=========================================="
if [ "$EXIT_CODE" -eq 0 ]; then
    echo "VERL GSPO OLMo-3 v2 SFT ep3 training completed successfully"
    echo "Checkpoints: $CHECKPOINT_DIR"

    # Post-training upload of best checkpoints
    if [ -n "$VERL_HF_REPO_ID" ]; then
        echo ""
        echo "Uploading best checkpoints to HuggingFace (via WandB metrics)"
        # The HF token is passed on the command line; pause xtrace so it is
        # not echoed into the job log, then restore the previous xtrace state.
        upload_xtrace_was_on=false
        case $- in *x*) upload_xtrace_was_on=true ;; esac
        { set +x; } 2>/dev/null
        python3 src/verl_helpers/upload_best_checkpoints.py \
            --output_dir "$CHECKPOINT_DIR" \
            --repo_id "$VERL_HF_REPO_ID" \
            --token "$VERL_HF_TOKEN" \
            --from_wandb \
            --wandb_entity "${WANDB_ENTITY:-anonymous}" \
            --wandb_project "$PROJECT_NAME" \
            --wandb_run_name "$EXPERIMENT_NAME" \
            || echo "Warning: HF upload failed"
        if [ "$upload_xtrace_was_on" = true ]; then set -x; fi
    fi
else
    echo "VERL GSPO OLMo-3 v2 SFT ep3 training failed with exit code $EXIT_CODE"
fi
echo "=========================================="

exit "$EXIT_CODE"
