#!/bin/bash

# ==============================================================================
# Generate SFT Dataset from Solver's Natural Code Generation Failures
# ==============================================================================
#
# This script collects buggy code solutions from a Solver model attempting
# to solve coding problems. These "natural" failures can be used to warmstart
# a Generator model for Generator-Solver joint RL training.
#
# The data format is:
#   Input: _build_bug_generator_prompt(problem, correct_solution)
#   Output: Solver's buggy code (natural failure)
#
# Usage:
#   bash generate_sft_dataset.sh
#
#   # Or with custom settings:
#   SOLVER_BASE_URL=http://localhost:30000/v1 \
#   DATASET=bigcodebench \
#   NUM_SAMPLES=1000 \
#   VAL_RATIO=0.1 \
#   bash generate_sft_dataset.sh
#
# Outputs (written under $RUN_DIR unless SFT_TRAIN_PATH / SFT_VAL_PATH are set):
#   - solver_failures_sft_train.parquet (training data)
#   - solver_failures_sft_val.parquet (validation data)
# ==============================================================================

# Abort on the first failing command and trace every command that runs.
set -e
set -x

# Load conda environment
source /data/user/miniconda3/etc/profile.d/conda.sh
conda activate rllm2
cd /data/user/rllm

# Load any environment variables (e.g., for API keys)
if [ -f /data/user/LiveCodeBench/.env ]; then
    # Suspend xtrace while sourcing: with `set -x` active, every assignment in
    # the .env file (API keys included) would be echoed into the trace log.
    set +x
    set -a   # auto-export everything the .env file defines
    . /data/user/LiveCodeBench/.env
    set +a
    set -x   # resume tracing for the rest of the script
fi

# ==============================================================================
# Configuration
# ==============================================================================

# Every setting below keeps a non-empty value already present in the
# environment and otherwise falls back to the default shown.

# Run identity and output directory. RLLM_DIR is the physical (symlink-free)
# path of the working directory at this point in the script.
: "${RUN_NAME:=generator-sft-data}"
RLLM_DIR="$(pwd -P)"
: "${RUN_DIR:=$RLLM_DIR/runs/$RUN_NAME}"
mkdir -p "$RUN_DIR"

# Solver endpoint and sampling parameters
: "${SOLVER_MODEL:=Qwen/Qwen2.5-Coder-7B-Instruct}"
: "${SOLVER_BASE_URL:=http://server:30001/v1}"
: "${SOLVER_TEMPERATURE:=0.6}"
: "${SOLVER_TOP_P:=0.95}"

# Dataset selection and generation limits
: "${DATASET:=bigcodebench}"
: "${SPLIT:=train}"
: "${NUM_SAMPLES:=0}"  # 0 = all samples
: "${SAMPLES_PER_PROBLEM:=20}"
: "${N_PARALLEL:=128}"
: "${MAX_PROMPT_LENGTH:=8192}"
: "${MAX_RESPONSE_LENGTH:=8192}"

# Where the train/validation parquet files go, and the validation ratio
: "${SFT_TRAIN_PATH:=$RUN_DIR/solver_failures_sft_train.parquet}"
: "${SFT_VAL_PATH:=$RUN_DIR/solver_failures_sft_val.parquet}"
: "${VAL_RATIO:=0.1}"

# HuggingFace Hub settings. Note that pushing is ON by default
# (PUSH_TO_HF=true) and HF_PRIVATE defaults to false.
: "${PUSH_TO_HF:=true}"
: "${HF_REPO_ID:=anonymous/qwen7b-failures-bigcodebench}"
: "${HF_PRIVATE:=false}"

# ==============================================================================
# Environment Setup
# ==============================================================================

# Drop the ROCR/ROCM/HIP device-visibility variables so they are unset for
# everything started below.
unset ROCR_VISIBLE_DEVICES
unset ROCM_VISIBLE_DEVICES
unset HIP_VISIBLE_DEVICES

# Exported settings for the allocator, device ordering and tokenizers.
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \
       CUDA_DEVICE_ORDER=PCI_BUS_ID \
       TOKENIZERS_PARALLELISM=true

# (No HuggingFace temp-directory overrides are configured here.)

# Summarise the effective configuration, one line per setting.
printf '%s\n' \
    "==============================================================" \
    "Generate SFT Dataset from Solver Failures" \
    "==============================================================" \
    "Solver Model: $SOLVER_MODEL" \
    "Solver URL: $SOLVER_BASE_URL" \
    "Dataset: $DATASET / $SPLIT" \
    "Num Samples: $NUM_SAMPLES (0 = all)" \
    "Samples per Problem: $SAMPLES_PER_PROBLEM" \
    "N Parallel: $N_PARALLEL" \
    "Val Ratio: $VAL_RATIO" \
    "Output Train: $SFT_TRAIN_PATH" \
    "Output Val: $SFT_VAL_PATH"
# The Hub target is only reported when pushing is switched on.
case "$PUSH_TO_HF" in
    true)
        printf '%s\n' "Push to HF: $HF_REPO_ID (private=$HF_PRIVATE)"
        ;;
esac
printf '%s\n' "=============================================================="

# List GPUs; if nvidia-smi is missing or fails, print a CPU notice instead.
if ! nvidia-smi -L 2>/dev/null; then
    printf '%s\n' "No GPU detected (running on CPU)"
fi

# ==============================================================================
# Generate SFT Data
# ==============================================================================

#######################################
# Build the optional HuggingFace Hub arguments for the generator module.
# The result is an array rather than a flat string so that each value stays a
# single argument even if it contains whitespace or glob characters.
# Globals:
#   PUSH_TO_HF (read)  - push flags are emitted only when exactly "true"
#   HF_REPO_ID (read)  - target repo; flags are skipped when it is empty
#   HF_PRIVATE (read)  - when exactly "false", --hf_public is appended
#   HF_ARGS (written)  - array of arguments, empty when not pushing
#######################################
build_hf_args() {
    HF_ARGS=()
    if [[ "$PUSH_TO_HF" == "true" && -n "$HF_REPO_ID" ]]; then
        HF_ARGS+=(--push_to_hf --hf_repo_id "$HF_REPO_ID")
        if [[ "$HF_PRIVATE" == "false" ]]; then
            HF_ARGS+=(--hf_public)
        fi
    fi
}

build_hf_args

# Run the generator; "${HF_ARGS[@]}" expands to zero words when the array is
# empty, so nothing extra is passed in the no-push case.
python -m examples.bugs.sft_generator_on_solver_failures \
    --mode generate \
    --solver_model "$SOLVER_MODEL" \
    --solver_base_url "$SOLVER_BASE_URL" \
    --solver_temperature "$SOLVER_TEMPERATURE" \
    --solver_top_p "$SOLVER_TOP_P" \
    --dataset "$DATASET" \
    --split "$SPLIT" \
    --num_samples "$NUM_SAMPLES" \
    --samples_per_problem "$SAMPLES_PER_PROBLEM" \
    --n_parallel "$N_PARALLEL" \
    --max_prompt_length "$MAX_PROMPT_LENGTH" \
    --max_response_length "$MAX_RESPONSE_LENGTH" \
    --val_ratio "$VAL_RATIO" \
    --output_train "$SFT_TRAIN_PATH" \
    --output_val "$SFT_VAL_PATH" \
    "${HF_ARGS[@]}"

# Closing banner: report where the train and validation files were written.
printf '%s\n' \
    "==============================================================" \
    "Done! SFT datasets saved to:" \
    "  Train: $SFT_TRAIN_PATH" \
    "  Val: $SFT_VAL_PATH" \
    "=============================================================="

