#!/bin/bash
# Multi-Puzzle DSR SFT Training for OLMo-3-7B-Instruct-SFT (v2)
#
# v2: Fixed {"response": "..."} JSON wrapper in answer tags (rsft_trainer.py fix).
# Training data now has clean grids/arrays in <answer> tags instead of JSON wrappers.
# Expected improvement: xmlcount baseline ~90%+ (was 70-85% in v1).
#
# OLMo-3-7B-Instruct-SFT published baselines: AIME24 6.7%, MATH-500 65.1%
# Uses same ChatML template (<|im_start|>/<|im_end|>) as Qwen — no adaptation needed.
# Falls through to default Qwen2.5-style format in rsft_trainer.py
# (system prompt + <reasoning>/<answer> tags).
#
# Datasets (same as Qwen3 DSR v3, ~3,374 examples):
# - anon-neurips26/bridges_5x5de_dsr_intformat_json (887)
# - anon-neurips26/galaxies_3x3de_dsr_intformat_json (191)
# - anon-neurips26/galaxies_4x4de_dsr_intformat_json (727)
# - anon-neurips26/pattern_3x3_dsr (83)
# - anon-neurips26/pattern_4x4_dsr (587)
# - anon-neurips26/undead_3x3de_dsr (899)
#
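# Relation to Qwen3 v3 (as actually configured below):
# - LR: 2e-4, the same value as the Qwen2.5 and Qwen3 DSR SFT runs
#   (this header previously listed a conservative 5e-5; the script sets
#   LEARNING_RATE=2e-4).
# - data.ignore_input_ids_mismatch=true is still passed to the trainer
#   (this header previously described it as Qwen3-specific and omitted).
# - No puzzle_mode flag (not relevant for OLMo)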
#
# Usage:
#   ./train/verl_sft/multi_puzzle_dsr_olmo3.sh [N_GPUS]

# Abort on the first failing command and trace every command as it runs.
set -e
set -x

# ============================================
# Python environment detection
# ============================================
# Resolution order:
#   1. The python3 on PATH can already import both verl and vllm -> use it.
#   2. Otherwise, if no virtualenv is active, activate the first known venv
#      that exists on disk.
#   3. Otherwise an already-active virtualenv is used as-is (not verified).
# Exits with status 1 when no venv can be found in step 2.
if python3 -c "import verl; import vllm" 2>/dev/null; then
    echo "VERL+vLLM found in system Python (Docker image mode)"
    # Version banners are best-effort; a failure to print one must not abort
    # the run under `set -e`, hence the `|| true`.
    python3 -c "import verl; print(f'VERL version: {verl.__version__}')" 2>/dev/null || true
    python3 -c "import vllm; print(f'vLLM version: {vllm.__version__}')" 2>/dev/null || true
    python3 -c "import torch; print(f'PyTorch version: {torch.__version__}')" 2>/dev/null || true
elif [ -z "${VIRTUAL_ENV:-}" ]; then
    # Paths are quoted so a HOME containing spaces or glob characters cannot
    # be word-split into several arguments.
    if [ -f "$HOME/verl-vllm012/bin/activate" ]; then
        # shellcheck disable=SC1091
        source "$HOME/verl-vllm012/bin/activate"
        echo "Activated verl-vllm012 venv"
    elif [ -f "$HOME/verl-latest/bin/activate" ]; then
        # shellcheck disable=SC1091
        source "$HOME/verl-latest/bin/activate"
        echo "Activated verl-latest venv"
    elif [ -f "$HOME/src/kernels/verl-latest/.venv/bin/activate" ]; then
        # shellcheck disable=SC1091
        source "$HOME/src/kernels/verl-latest/.venv/bin/activate"
        echo "Activated verl-latest venv (VERL 0.7.0 from PyPI)"
    else
        # Diagnostics belong on stderr so they are not lost when stdout is
        # redirected or piped.
        echo "ERROR: No Python environment with VERL found" >&2
        exit 1
    fi
else
    echo "Using existing venv: $VIRTUAL_ENV"
fi

# Credentials (WandB API key, HuggingFace token for private datasets).
# This script runs under `set -x`, which would print the expanded secrets in
# the command trace. Tracing is therefore suspended while they are exported
# and restored to whatever state it was in before.
case $- in
    *x*) _xtrace_was_on=1 ;;
    *)   _xtrace_was_on=0 ;;
esac
{ set +x; } 2>/dev/null
# Both values come from the caller's environment and are exported as an empty
# string when unset.
export WANDB_API_KEY="${WANDB_API_KEY:-}"
export HF_TOKEN="${HF_TOKEN:-}"
export VERL_HF_TOKEN="$HF_TOKEN"
if [ "$_xtrace_was_on" = 1 ]; then
    set -x
fi
unset _xtrace_was_on

# WandB configuration: online unless the caller chose another mode; WandB's
# own console capture is switched off.
export WANDB_MODE="${WANDB_MODE:-online}"
export WANDB_CONSOLE=off

# Ensure Python output is flushed immediately (critical for managed-cluster log visibility)
export PYTHONUNBUFFERED=1

# HuggingFace cache locations; both default to /tmp/hf_cache unless the caller
# already set them.
export HF_HOME="${HF_HOME:-/tmp/hf_cache}"
export HF_HUB_CACHE="${HF_HUB_CACHE:-/tmp/hf_cache}"

# ---- GPU / allocator / vLLM runtime environment ----
# Values are assigned first and exported together at the end of this section.

# Use all 4 GPUs
CUDA_VISIBLE_DEVICES=0,1,2,3

# Note: expandable_segments is needed for training (OOM without it on 7B + 32k seq).
# vLLM V1's CuMemAllocator may conflict on B200, but works on H100.
# On managed-cluster/B200, unset before eval phase if needed.
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Enable vLLM V1 engine (needed for offline eval phase)
VLLM_USE_V1=1

# B200/SM100a required env vars (needed for offline eval phase with vLLM)
VLLM_USE_TRTLLM_ATTENTION=0
VLLM_ATTENTION_BACKEND=FLASH_ATTN

# Disable torch dynamo/compile to avoid triton autotuning CUDA errors on B200
TORCHDYNAMO_DISABLE=1

# Use fresh triton cache: the directory name embeds this shell's PID ($$).
TRITON_CACHE_DIR="/tmp/triton_cache_sft_$$"

export CUDA_VISIBLE_DEVICES \
    PYTORCH_CUDA_ALLOC_CONF \
    VLLM_USE_V1 \
    VLLM_USE_TRTLLM_ATTENTION \
    VLLM_ATTENTION_BACKEND \
    TORCHDYNAMO_DISABLE \
    TRITON_CACHE_DIR

# Fix vLLM 0.12.0 LoRA PDL crash on B200 (SM100a).
# The installed vllm/lora/ops/triton_ops/utils.py is edited in place so that
# the capability check that enables PDL also requires "not capability 100".
# UTILS_FILE ends up empty when vllm cannot be imported.
UTILS_FILE=$(python3 -c "import vllm; import os; print(os.path.join(os.path.dirname(vllm.__file__), 'lora/ops/triton_ops/utils.py'))" 2>/dev/null || true)
# `[ -f "" ]` is false, so this one test also covers the empty-path case.
if [ -f "$UTILS_FILE" ]; then
    # The patch is needed only when the file still has the capability(90)
    # check and carries no capability(100) reference yet.
    _pdl_patch_needed=no
    if grep -q "has_device_capability(90)" "$UTILS_FILE"; then
        grep -q "has_device_capability(100)" "$UTILS_FILE" || _pdl_patch_needed=yes
    fi
    case "$_pdl_patch_needed" in
        yes)
            echo "Patching vLLM LoRA PDL: disabling for SM100a (B200)..."
            sed -i 's/return current_platform.is_cuda() and current_platform.has_device_capability(90)/return current_platform.is_cuda() and current_platform.has_device_capability(90) and not current_platform.has_device_capability(100)/' "$UTILS_FILE"
            echo "PDL disabled for SM100a. LoRA Triton kernels will compile without gdc_wait()."
            ;;
        *)
            echo "vLLM LoRA PDL patch already applied or not needed."
            ;;
    esac
    unset _pdl_patch_needed
fi

# Configuration
# N_GPUS: worker-process count handed to torchrun (first CLI argument,
# default 4). It is also used as a divisor in the startup banner, so anything
# other than a positive integer is rejected here with a clear message instead
# of failing later with a division-by-zero or a torchrun error.
# NOTE(review): CUDA_VISIBLE_DEVICES is exported as 0,1,2,3 earlier in this
# script, so values above 4 presumably cannot work - confirm.
N_GPUS=${1:-4}
if ! [[ "${N_GPUS}" =~ ^[0-9]+$ ]] || (( 10#${N_GPUS} < 1 )); then
    echo "ERROR: N_GPUS must be a positive integer, got '${N_GPUS}'" >&2
    exit 1
fi
# Normalise to plain decimal so a leading zero (e.g. "08") is never read as
# octal by later arithmetic.
N_GPUS=$(( 10#${N_GPUS} ))

# Number of epochs; overridable from the environment.
NUM_EPOCHS=${NUM_EPOCHS:-8}

# Optimizer steps per epoch; used only to derive the checkpoint names
# (global_step_N) selected for offline evaluation.
# NOTE(review): the original note read "~3374 examples / batch_size 32 /
# 4 GPUs, rounded up", but 3374 / 32 is ~106 and dividing again by 4 gives
# ~27; 68 presumably reflects the dataset size after the <=28k-token filter
# (68 * 32 = 2176) - confirm against a training log.
STEPS_PER_EPOCH=68

# Learning rate: 2e-4 (matching Qwen2.5 and Qwen3 DSR SFT runs)
LEARNING_RATE=2e-4

# Training hyperparameters
TRAIN_BATCH_SIZE=32
MICRO_BATCH_SIZE=1
EVAL_BATCH_SIZE=16
MAX_SEQ_LENGTH=32000  # After filtering to <=28k tokens
LORA_RANK=64
LORA_ALPHA=64

# Base model - OLMo-3-7B-Instruct-SFT (SFT-only, no RL/DPO/RLVR)
BASE_MODEL="allenai/OLMo-3-7B-Instruct-SFT"

# Experiment name (also used for the WandB run and the local checkpoint dir)
EXP_NAME="olmo3_7b_multi_puzzle_dsr_v2"

# Evaluation frequency: -1 means "after_each_epoch"
# (the same value is passed as both trainer.save_freq and trainer.test_freq)
EVAL_FREQ=-1

# HuggingFace repo for uploading best checkpoint
HF_REPO="anon-neurips26/${EXP_NAME}"

# Offline eval settings
EVAL_MAX_NEW_TOKENS=32000
PARALLEL_EVAL=true

# Per-puzzle prompt templates
PROMPT_GALAXIES="prompts/galaxies_intformat.txt"
PROMPT_PATTERN="prompts/pattern_formatted.txt"
PROMPT_UNDEAD="prompts/undead_formatted.txt"
PROMPT_BRIDGES="prompts/bridges_intformat.txt"

# Weighted sampling: equal weight across puzzle types
SAMPLING_WEIGHTS="galaxies:0.25,pattern:0.25,undead:0.25,bridges:0.25"

# Training datasets - DeepSeek R1 traces (same as Qwen3 v3).
# Built up one puzzle family at a time; element order matches the order in
# which the lists are later joined for the trainer.
TRAIN_DATASETS=()
# galaxies
TRAIN_DATASETS+=("anon-neurips26/galaxies_3x3de_dsr_intformat_json")
TRAIN_DATASETS+=("anon-neurips26/galaxies_4x4de_dsr_intformat_json")
# pattern
TRAIN_DATASETS+=("anon-neurips26/pattern_3x3_dsr")
TRAIN_DATASETS+=("anon-neurips26/pattern_4x4_dsr")
# undead
TRAIN_DATASETS+=("anon-neurips26/undead_3x3de_dsr")
# bridges
TRAIN_DATASETS+=("anon-neurips26/bridges_5x5de_dsr_intformat_json")

# Evaluation datasets (same as Qwen3 v3), in the same family order.
EVAL_DATASETS=()
# galaxies
EVAL_DATASETS+=("anon-neurips26/galaxies_3x3de_rsft_1k_intformat_json")
EVAL_DATASETS+=("anon-neurips26/galaxies_4x4de_rsft_1k_intformat_json")
# pattern
EVAL_DATASETS+=("anon-neurips26/pattern_3x3_rsft_1k")
EVAL_DATASETS+=("anon-neurips26/pattern_4x4_rsft_1k")
# undead
EVAL_DATASETS+=("anon-neurips26/undead_3x3de_rsft_1k")
# bridges
EVAL_DATASETS+=("anon-neurips26/bridges_5x5de_test200_intformat_json")

# Detect project and output directories.
# The presence of /opt/ml/code marks the managed-cluster layout: code lives
# there, checkpoints go to /opt/ml/checkpoints, and offline evaluation is
# switched off via SKIP_OFFLINE_EVAL. Anywhere else, the project root is two
# directory levels above this script and checkpoints live under it, in a
# directory named after the experiment.
if [ -d /opt/ml/code ]; then
    PROJECT_DIR="/opt/ml/code"
    OUTPUT_DIR="/opt/ml/checkpoints"
    export SKIP_OFFLINE_EVAL=1
else
    PROJECT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/../.." && pwd)"
    OUTPUT_DIR="${PROJECT_DIR}/checkpoints/${EXP_NAME}"
fi

# Combined stdout/stderr of the training run is tee'd into this file.
LOG_FILE="${OUTPUT_DIR}/training_log.txt"

# Startup banner: print the effective run configuration to stdout so it is
# visible at the top of the job output.
_banner_rule="=========================================="
printf '%s\n' \
    "${_banner_rule}" \
    "OLMo-3-7B Multi-Puzzle DSR Training" \
    "  (Qwen2.5-style: <reasoning>/<answer>)" \
    "${_banner_rule}" \
    "Base Model: ${BASE_MODEL}" \
    "Training Datasets:"
for train_ds in "${TRAIN_DATASETS[@]}"; do
    printf '  - %s\n' "${train_ds}"
done
printf '%s\n' "Evaluation Datasets:"
for eval_ds in "${EVAL_DATASETS[@]}"; do
    printf '  - %s\n' "${eval_ds}"
done
printf '%s\n' \
    "Sampling Weights: ${SAMPLING_WEIGHTS}" \
    "Max Seq Length: ${MAX_SEQ_LENGTH}" \
    "Epochs: ${NUM_EPOCHS}" \
    "GPUs: ${N_GPUS}" \
    "Train Batch Size: ${TRAIN_BATCH_SIZE}" \
    "Micro Batch Size: ${MICRO_BATCH_SIZE}"
# Kept as its own command: the value is integer arithmetic over three
# settings (global batch / (micro batch * GPU count)).
printf 'Gradient Accumulation Steps: %s\n' "$((TRAIN_BATCH_SIZE / (MICRO_BATCH_SIZE * N_GPUS)))"
printf '%s\n' \
    "LoRA Rank: ${LORA_RANK}" \
    "LoRA Alpha: ${LORA_ALPHA}" \
    "Learning Rate: ${LEARNING_RATE}" \
    "Mode: bf16 LoRA (no 4-bit quantization)" \
    "Format: system prompt + <reasoning>/<answer> tags" \
    "${_banner_rule}"
unset _banner_rule

# Make sure the output directory exists; the training log is written inside it.
mkdir -p "${OUTPUT_DIR}"

# Format datasets as Hydra list: "[a,b,c]". The join runs in a subshell so the
# IFS change does not leak into the rest of the script.
TRAIN_DATASETS_STR=$(IFS=,; echo "${TRAIN_DATASETS[*]}")
TRAIN_DATASETS_HYDRA="[${TRAIN_DATASETS_STR}]"
EVAL_DATASETS_STR=$(IFS=,; echo "${EVAL_DATASETS[*]}")
EVAL_DATASETS_HYDRA="[${EVAL_DATASETS_STR}]"

# Optional: limit total training steps (for sanity checks).
# EXTRA_TRAINER_ARGS stays a plain string because the launch command expands
# it unquoted; it is either empty or a single key=value override.
EXTRA_TRAINER_ARGS=""
if [ -n "${SANITY_TOTAL_TRAINING_STEPS:-}" ]; then
    EXTRA_TRAINER_ARGS="trainer.total_training_steps=${SANITY_TOTAL_TRAINING_STEPS}"
    echo "SANITY CHECK MODE: limiting to ${SANITY_TOTAL_TRAINING_STEPS} steps"
fi

# Run from the project root so the trainer module and relative paths resolve.
cd "$PROJECT_DIR"
# Prepend the project root to PYTHONPATH. The separator is only emitted when
# PYTHONPATH already has content, so an unset/empty PYTHONPATH does not leave
# a trailing ":" (an empty entry) on the search path.
export PYTHONPATH="${PROJECT_DIR}${PYTHONPATH:+:${PYTHONPATH}}"

# Launch single-node training: one torchrun worker per requested GPU, running
# the train.verl_sft.rsft_trainer module with Hydra-style overrides.
# stdout and stderr are merged and copied to LOG_FILE through tee.
#
# EXTRA_TRAINER_ARGS is expanded unquoted on purpose: it is either empty (and
# must then disappear from the command line) or a single override.
# shellcheck disable=SC2086
torchrun --standalone --nnodes=1 --nproc_per_node="${N_GPUS}" \
    -m train.verl_sft.rsft_trainer \
    +hf_datasets="${TRAIN_DATASETS_HYDRA}" \
    +hf_eval_datasets="${EVAL_DATASETS_HYDRA}" \
    data.multiturn.enable=true \
    +data.ignore_input_ids_mismatch=true \
    data.max_length="${MAX_SEQ_LENGTH}" \
    data.micro_batch_size_per_gpu="${MICRO_BATCH_SIZE}" \
    data.train_batch_size="${TRAIN_BATCH_SIZE}" \
    model.partial_pretrain="${BASE_MODEL}" \
    model.fsdp_config.model_dtype=bf16 \
    model.lora_rank="${LORA_RANK}" \
    model.lora_alpha="${LORA_ALPHA}" \
    "model.target_modules=[q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj]" \
    model.trust_remote_code=true \
    model.enable_gradient_checkpointing=true \
    ++model.strategy=fsdp \
    +model.attn_implementation=flash_attention_2 \
    optim.lr="${LEARNING_RATE}" \
    trainer.total_epochs="${NUM_EPOCHS}" \
    trainer.save_freq="${EVAL_FREQ}" \
    trainer.test_freq="${EVAL_FREQ}" \
    trainer.default_local_dir="${OUTPUT_DIR}" \
    trainer.project_name="sft-verl" \
    trainer.experiment_name="${EXP_NAME}" \
    "trainer.logger=[console,wandb]" \
    +max_token_length=28000 \
    "+sampling_weights='${SAMPLING_WEIGHTS}'" \
    +prompt_template_galaxies="${PROMPT_GALAXIES}" \
    +prompt_template_pattern="${PROMPT_PATTERN}" \
    +prompt_template_undead="${PROMPT_UNDEAD}" \
    +prompt_template_bridges="${PROMPT_BRIDGES}" \
    +token_averaging=true \
    +training_mode=rsft \
    +disable_accuracy_eval=true \
    +val_loss_topk=0 \
    +save_epoch_checkpoints=true \
    +max_eval_samples=200 \
    +eval_batch_size="${EVAL_BATCH_SIZE}" \
    +eval_max_new_tokens="${EVAL_MAX_NEW_TOKENS}" \
    ${EXTRA_TRAINER_ARGS} \
    2>&1 | tee "${LOG_FILE}"
# Without pipefail the pipeline's status is tee's, so `set -e` never sees a
# torchrun failure. Read torchrun's own status from PIPESTATUS (this must be
# the very next command) and stop here instead of reporting success and going
# on to evaluate checkpoints from a failed run.
train_status=${PIPESTATUS[0]}
if [ "${train_status}" -ne 0 ]; then
    echo "ERROR: training failed (torchrun exit status ${train_status}); see ${LOG_FILE}" >&2
    exit "${train_status}"
fi

echo "Done OLMo-3-7B Multi-Puzzle DSR Training"
echo ""

# Offline evaluation is skipped entirely when SKIP_OFFLINE_EVAL=1.
if [ "${SKIP_OFFLINE_EVAL:-}" = "1" ]; then
    echo "Skipping offline evaluation (SKIP_OFFLINE_EVAL=1)"
    exit 0
fi

# ============================================
# Offline Evaluation with vLLM
# ============================================
echo "=========================================="
echo "Starting Offline Evaluation with vLLM"
echo "=========================================="

WANDB_PROJECT="sft-verl"
WANDB_RUN_NAME="${EXP_NAME}"

# Only eval epochs 5-8 (early epochs always worse, saves GPU hours).
# When fewer epochs than that were trained, fall back to the final epoch so
# the checkpoint pattern is never the empty "global_step_{}".
EVAL_EPOCHS_START=5
if (( NUM_EPOCHS < EVAL_EPOCHS_START )); then
    echo "WARNING: NUM_EPOCHS=${NUM_EPOCHS} < ${EVAL_EPOCHS_START}; evaluating only the final epoch" >&2
    EVAL_EPOCHS_START=$(( NUM_EPOCHS ))
fi
EVAL_STEP_START=$(( STEPS_PER_EPOCH * EVAL_EPOCHS_START ))

# One checkpoint step per evaluated epoch: STEPS_PER_EPOCH * epoch.
EVAL_STEPS=()
for (( _epoch = EVAL_EPOCHS_START; _epoch <= NUM_EPOCHS; _epoch++ )); do
    EVAL_STEPS+=("$(( STEPS_PER_EPOCH * _epoch ))")
done
unset _epoch
# Brace-style list, e.g. global_step_{340,408,476,544}. It is passed to the
# eval tool as a literal string; the shell does not expand it.
EVAL_PATTERN="global_step_{$(IFS=,; echo "${EVAL_STEPS[*]}")}"
echo "Eval checkpoint pattern: ${EVAL_PATTERN}"

# JSON object mapping puzzle type -> prompt template path.
printf -v PUZZLE_TEMPLATES_JSON \
    '{"galaxies": "%s", "pattern": "%s", "undead": "%s", "bridges": "%s"}' \
    "${PROMPT_GALAXIES}" "${PROMPT_PATTERN}" "${PROMPT_UNDEAD}" "${PROMPT_BRIDGES}"

# The command is assembled as an array (one element per argument) and run
# directly, so values containing spaces or shell metacharacters reach the tool
# unchanged and nothing is re-parsed through `eval`.
# NOTE(review): --max_model_len is set to the same value as --max_new_tokens;
# if the tool treats max_model_len as prompt + generation this leaves no room
# for the prompt - confirm against tools/eval_lora_checkpoints.py.
EVAL_CMD=(
    python tools/eval_lora_checkpoints.py
    --base_model "${BASE_MODEL}"
    --checkpoint_dir "${OUTPUT_DIR}"
    --checkpoint_pattern "${EVAL_PATTERN}"
    --eval_datasets "${EVAL_DATASETS[@]}"
    --dump_generations
    --max_new_tokens "${EVAL_MAX_NEW_TOKENS}"
    --max_model_len "${EVAL_MAX_NEW_TOKENS}"
    --temperature 0.0
    --prompt_style rsft
    --puzzle_templates_json "${PUZZLE_TEMPLATES_JSON}"
)

if [ "${PARALLEL_EVAL:-}" = true ]; then
    EVAL_CMD+=(--parallel --num_gpus "${N_GPUS}")
    echo "Mode: Parallel evaluation on ${N_GPUS} GPUs"
else
    echo "Mode: Sequential evaluation"
fi

if [ -n "${HF_REPO:-}" ]; then
    EVAL_CMD+=(--upload_best_to_hf "${HF_REPO}")
    EVAL_CMD+=(--upload_results_to_hf "${HF_REPO}")
    echo "Will upload best checkpoint to: ${HF_REPO}"
fi

EVAL_CMD+=(--wandb_project "${WANDB_PROJECT}" --wandb_run_name "${WANDB_RUN_NAME}")
echo "WandB logging: ${WANDB_PROJECT}/${WANDB_RUN_NAME}"

# Log the command in shell-quoted form so it can be copy-pasted.
echo ""
printf 'Running:'
printf ' %q' "${EVAL_CMD[@]}"
printf '\n'
echo ""

"${EVAL_CMD[@]}"

# Closing summary: where the training output and eval results were written,
# plus the HuggingFace URL when an upload repo is configured.
printf '%s\n' \
    "" \
    "==========================================" \
    "All Done!" \
    "==========================================" \
    "Training output: ${OUTPUT_DIR}" \
    "Eval results: ${OUTPUT_DIR}/eval_results.json"
if [[ -n "${HF_REPO}" ]]; then
    printf 'Best checkpoint uploaded to: https://huggingface.co/%s\n' "${HF_REPO}"
fi
