#!/bin/bash

# ==============================================================================
# Train Generator with SFT on Solver's Failed Code Generation (Buggy Code)
# ==============================================================================
#
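# This script fine-tunes (SFT with LoRA) a bug-Generator model on train/val
# parquet files built from Solver failures. Each file should contain:
#   - messages: list of {role, content} dicts (user prompt + assistant response)
#   - task_id, passed_tests, total_tests (optional metadata)
#
# The Generator learns to produce buggy code given (problem, correct_solution),
# serving as a warm start for Generator-Solver joint RL training.
#
# Usage:
#   SFT_TRAIN_PATH=/path/to/train.parquet SFT_VAL_PATH=/path/to/val.parquet \
#       bash sft_generator_on_solver_failures.sh
#
#   # Or point at a directory containing the default file names
#   # (qwen7b_bugbench_human_train.parquet / qwen7b_bugbench_human_val.parquet):
#   SFT_DATA_DIR=/path/to/dir bash sft_generator_on_solver_failures.sh
#
#   # Or use the default paths under $RLLM_DIR/generator-sft-data:
#   bash sft_generator_on_solver_failures.sh
#
# Other environment overrides: RUN_DIR, OUTPUT_DIR, SAVE_FREQ, TEST_FREQ,
# HF_HOME, HF_HUB_CACHE.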
# ==============================================================================

# Activate the project's conda env and move to the repo root. Later settings
# derive RLLM_DIR from the working directory, so a failure in any of these
# steps would silently send runs/checkpoints to the wrong place: abort instead.
source /data/user/miniconda3/etc/profile.d/conda.sh || {
    echo "ERROR: cannot source /data/user/miniconda3/etc/profile.d/conda.sh" >&2
    exit 1
}
conda activate rllm2 || {
    echo "ERROR: 'conda activate rllm2' failed" >&2
    exit 1
}
cd /data/user/rllm || {
    echo "ERROR: cannot cd to /data/user/rllm" >&2
    exit 1
}

# Load env vars (HF token, etc.). `set -a` auto-exports every variable the file
# assigns so child processes (torchrun, huggingface-cli) inherit them.
# Best-effort: a missing .env is reported but does not stop the run.
env_file=/data/user/rllm/.env
if [ -f "$env_file" ]; then
    set -a
    . "$env_file"
    set +a
else
    echo "WARNING: $env_file not found; continuing without it." >&2
fi

# Trace all subsequent commands. Enabled only after .env is loaded so the
# secret assignments it contains are not echoed to the log.
set -x

# ==============================================================================
# Configuration
# ==============================================================================
# Each setting below is a default that can be overridden from the environment,
# e.g. `EPOCHS=1 LR=5e-5 bash sft_generator_on_solver_failures.sh`. Note that
# .env is sourced with `set -a` earlier in this script, so any of these names
# defined there also override the defaults.

# Run name and directories.
# NOTE(review): the default RUN_NAME contains "/", so RUN_DIR and OUTPUT_DIR end
# up one directory level deeper (".../solver-failures-1/8"). Presumably "1/8"
# marks a data fraction — confirm the nesting is intended.
RUN_NAME="${RUN_NAME:-sft-qwen7b-bcb-generator-solver-failures-1/8}"
RLLM_DIR="$(pwd -P)"
RUN_DIR="${RUN_DIR:-$RLLM_DIR/runs/$RUN_NAME}"
OUTPUT_DIR="${OUTPUT_DIR:-/data/user/rllm/checkpoints/rllm-agent/$RUN_NAME}"
mkdir -p "$RUN_DIR" "$OUTPUT_DIR"

# Paths to the SFT parquet files (train and validation splits).
SFT_DATA_DIR="${SFT_DATA_DIR:-$RLLM_DIR/generator-sft-data}"
SFT_TRAIN_PATH="${SFT_TRAIN_PATH:-$SFT_DATA_DIR/qwen7b_bugbench_human_train.parquet}"
SFT_VAL_PATH="${SFT_VAL_PATH:-$SFT_DATA_DIR/qwen7b_bugbench_human_val.parquet}"

# Training settings (model, schedule, LoRA adapter size, GPU count).
TRAIN_MODEL="${TRAIN_MODEL:-Qwen/Qwen2.5-Coder-7B-Instruct}"
EPOCHS="${EPOCHS:-3}"
BATCH_SIZE="${BATCH_SIZE:-64}"
MICRO_BATCH_SIZE_PER_GPU="${MICRO_BATCH_SIZE_PER_GPU:-1}"
MAX_LENGTH="${MAX_LENGTH:-8192}"
LR="${LR:-1e-4}"
LORA_RANK="${LORA_RANK:-32}"
LORA_ALPHA="${LORA_ALPHA:-16}"
NUM_GPUS="${NUM_GPUS:-4}"

# Validation and checkpoint frequency (in steps, -1 = disabled)
SAVE_FREQ="${SAVE_FREQ:-100}"
TEST_FREQ="${TEST_FREQ:-50}"

# Wandb settings. Kept fixed (not env-overridable) because .env may already
# define WANDB_* variables for other purposes.
export WANDB_PROJECT="rllm-agent"
export WANDB_NAME="$RUN_NAME"
export WANDB_DIR="$RUN_DIR"

# ==============================================================================
# Environment Setup
# ==============================================================================

# --- GPU / PyTorch runtime ----------------------------------------------------
# Clear any ROCm/HIP device masks inherited from the caller's environment, then
# pin the CUDA allocator, device ordering, and tokenizer threading behavior.
unset ROCR_VISIBLE_DEVICES ROCM_VISIBLE_DEVICES HIP_VISIBLE_DEVICES
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
    CUDA_DEVICE_ORDER=PCI_BUS_ID \
    TOKENIZERS_PARALLELISM=true

# --- HuggingFace cache --------------------------------------------------------
# Use a persistent local cache to avoid repeated downloads; either location may
# be overridden from the environment.
: "${HF_HOME:=/data/user/hf_home}"
: "${HF_HUB_CACHE:=$HF_HOME/hub}"
export HF_HOME HF_HUB_CACHE
mkdir -p "$HF_HOME" "$HF_HUB_CACHE"
printf 'HF_HOME: %s\n' "$HF_HOME"
printf 'HF_HUB_CACHE: %s\n' "$HF_HUB_CACHE"

# --- Model pre-fetch ----------------------------------------------------------
# The hub cache stores repo "org/name" under "models--org--name". If that
# directory is absent, fetch the weights before launching torchrun. A failed
# fetch is only reported; the script still proceeds to training.
# NOTE(review): the directory existing does not guarantee a complete download
# (an interrupted fetch leaves it behind) — confirm this check is sufficient.
model_cache_dir="$HF_HUB_CACHE/models--${TRAIN_MODEL//\//--}"
if [[ ! -d "$model_cache_dir" ]]; then
    echo "Model not found in cache. Downloading $TRAIN_MODEL..."
    if ! huggingface-cli download "$TRAIN_MODEL" --cache-dir "$HF_HUB_CACHE"; then
        echo "WARNING: Failed to pre-download model. Training may fail if download is slow."
    fi
fi

# --- Diagnostics --------------------------------------------------------------
echo "GPU Info:"
nvidia-smi -L

# ==============================================================================
# Validate Data Path
# ==============================================================================

# Abort the script (exit 1) if a parquet split does not exist. Diagnostics go
# to stderr so they are not lost when stdout is redirected to a log file.
#   $1 - split label used in the message ("Train" / "Val")
#   $2 - path to the parquet file
#   $3 - name of the environment variable that overrides the path
require_parquet() {
    local label=$1 path=$2 var_name=$3
    if [[ ! -f "$path" ]]; then
        {
            echo "ERROR: $label data file not found at: $path"
            echo "Please set $var_name to a valid parquet file."
            echo "You can generate one using: scripts/train/bugs/generate_sft_dataset.sh"
        } >&2
        exit 1
    fi
}

require_parquet "Train" "$SFT_TRAIN_PATH" SFT_TRAIN_PATH
require_parquet "Val" "$SFT_VAL_PATH" SFT_VAL_PATH

echo "Using SFT train data: $SFT_TRAIN_PATH"
echo "Using SFT val data: $SFT_VAL_PATH"

# ==============================================================================
# Train Generator with SFT
# ==============================================================================

echo "=============================================================="
echo "Training Generator with SFT"
echo "=============================================================="
echo "Model: $TRAIN_MODEL"
echo "Train Data: $SFT_TRAIN_PATH"
echo "Val Data: $SFT_VAL_PATH"
echo "Output: $OUTPUT_DIR"
echo "Epochs: $EPOCHS"
echo "Batch Size: $BATCH_SIZE"
echo "Learning Rate: $LR"
echo "LoRA Rank: $LORA_RANK"
echo "Save Freq: $SAVE_FREQ steps"
echo "Test Freq: $TEST_FREQ steps"
echo "=============================================================="

# Run SFT training using verl.trainer.fsdp_sft_trainer on NUM_GPUS local GPUs
# (single node). Settings are passed as key=value config overrides: LoRA
# adapter on all linear layers, multi-turn chat data read from the `messages`
# column, and checkpoints/validation every SAVE_FREQ/TEST_FREQ steps.
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
torchrun --standalone --nnodes=1 --nproc_per_node="$NUM_GPUS" \
    -m verl.trainer.fsdp_sft_trainer \
    model.partial_pretrain="$TRAIN_MODEL" \
    model.trust_remote_code=true \
    model.enable_gradient_checkpointing=true \
    model.lora_rank="$LORA_RANK" \
    model.lora_alpha="$LORA_ALPHA" \
    model.target_modules=all-linear \
    model.strategy=fsdp \
    trainer.total_epochs="$EPOCHS" \
    data.train_batch_size="$BATCH_SIZE" \
    data.micro_batch_size_per_gpu="$MICRO_BATCH_SIZE_PER_GPU" \
    data.max_length="$MAX_LENGTH" \
    data.truncation=right \
    data.multiturn.enable=true \
    data.multiturn.messages_key=messages \
    data.train_files="$SFT_TRAIN_PATH" \
    data.val_files="$SFT_VAL_PATH" \
    trainer.default_local_dir="$OUTPUT_DIR" \
    trainer.logger='["console", "wandb"]' \
    trainer.project_name="$WANDB_PROJECT" \
    trainer.experiment_name="$WANDB_NAME" \
    trainer.save_freq="$SAVE_FREQ" \
    trainer.test_freq="$TEST_FREQ" \
    optim.lr="$LR"
train_status=$?

# Propagate a failed run instead of unconditionally reporting success, so that
# job schedulers and wrapper scripts see the real outcome.
if (( train_status != 0 )); then
    echo "ERROR: SFT training failed with exit code $train_status" >&2
    exit "$train_status"
fi

echo "=============================================================="
echo "Training complete!"
echo "Checkpoints saved to: $OUTPUT_DIR"
echo "=============================================================="
