#!/usr/bin/env bash
set -euo pipefail

# KBQA-R1 SFT launcher from rejection-sampling dumps (no CLI args required)
# Defaults mirror the style of train_kbqa_sexpr_generation.sh.
# Every setting below can be overridden through the environment, e.g.
#   DATASET_TYPE=grailqa TOTAL_STEPS=2000 bash <this script>

# -----------------------------
# Repo and dataset defaults
# -----------------------------
REPO_ROOT=${REPO_ROOT:-"/ossfs/workspace/kbqa-r1"}
export WANDB_MODE=${WANDB_MODE:-offline}
export WANDB_API_KEY=${WANDB_API_KEY:-""}
PROJECT_NAME=${PROJECT_NAME:-"KBQA-R1-SFT"}

# Dataset type to shape paths (webqsp|grailqa|...)
DATASET_TYPE=${DATASET_TYPE:-webqsp}

# SFT train parquet generated by the rejection sampling pipeline.
# scripts/data_process/run_rejection_sampling.sh writes
#   data/${DATASET_TYPE}_rl_dataset_sft/train_sft.parquet
# but this default points at the train_sft_new.parquet file in that directory.
TRAIN_PARQUET=${TRAIN_PARQUET:-"${REPO_ROOT}/data/${DATASET_TYPE}_rl_dataset_sft/train_sft_new.parquet"}

# Optional: validation parquet (not required by trainer; only echoed)
VAL_PARQUET=${VAL_PARQUET:-"${REPO_ROOT}/data/${DATASET_TYPE}_rl_dataset/test.parquet"}

# Base model (can override via env BASE_MODEL)
BASE_MODEL=${BASE_MODEL:-"/ossfs/workspace/aml2/aml_ri/fengyi/Llama-3.1-8B-Instruct"}

# Total training steps (set via env TOTAL_STEPS, default 5000)
TOTAL_STEPS=${TOTAL_STEPS:-5000}

# Max tokens per GPU used by dynamic batcher (must be >= longest sequence length)
# Defaults to 16k to match data.max_length override; adjust based on memory.
MAX_TOKEN_LEN_PER_GPU=${MAX_TOKEN_LEN_PER_GPU:-16384}

# -----------------------------
# GPU detection and config
# -----------------------------
#######################################
# Detect how many GPUs are visible and store the result in GPU_COUNT.
# Detectors, in order: `nvidia-smi --list-gpus`, then
# torch.cuda.device_count() via `python`, then via `python3`.
# Globals:
#   GPU_COUNT (written) - non-negative integer
# Outputs:
#   One "Detected ..." line on stdout, or a warning on stderr.
# Returns:
#   0 when a detector ran; 1 when none was available and GPU_COUNT fell
#   back to 8 (callers running under `set -e` must tolerate this status).
#######################################
detect_gpu_count() {
  local listing interp source_label
  # Python probe: prints 0 when torch cannot be imported or errors out.
  local -r torch_probe='try:
    import torch
    print(torch.cuda.device_count())
except Exception:
    print(0)'

  # A present-but-failing nvidia-smi (e.g. no driver access) is treated as
  # unavailable rather than aborting the script via `set -e -o pipefail`.
  if command -v nvidia-smi &>/dev/null && listing=$(nvidia-smi --list-gpus 2>/dev/null); then
    # Count only "GPU N: ..." lines; MIG instances show up as extra indented
    # lines. grep -c exits 1 on zero matches, hence `|| true`.
    GPU_COUNT=$(grep -c '^GPU' <<<"${listing}" || true)
    echo "Detected ${GPU_COUNT} GPUs via nvidia-smi"
    return 0
  fi

  for interp in python python3; do
    command -v "${interp}" &>/dev/null || continue
    GPU_COUNT=$("${interp}" -c "${torch_probe}") || GPU_COUNT=0
    # Anything other than a bare integer on stdout is treated as "no GPUs".
    [[ "${GPU_COUNT}" =~ ^[0-9]+$ ]] || GPU_COUNT=0
    if [[ "${interp}" == python ]]; then
      source_label="PyTorch"
    else
      source_label="PyTorch (python3)"
    fi
    echo "Detected ${GPU_COUNT} GPUs via ${source_label}"
    return 0
  done

  echo "Warning: Could not detect GPU count, defaulting to 8" >&2
  GPU_COUNT=8
  return 1
}

# detect_gpu_count returns 1 when it falls back to GPU_COUNT=8; under `set -e`
# a bare call would abort the script instead of using that fallback.
detect_gpu_count || true

if (( GPU_COUNT == 0 )); then
  # CPU-only fallback: run single process without CUDA devices
  unset CUDA_VISIBLE_DEVICES
  NGPUS=1
  echo "No GPUs detected, falling back to single process"
else
  # Expose devices 0..GPU_COUNT-1 as "0,1,...". The sed strips a trailing
  # comma in case the local `seq` emits one after the last number.
  CUDA_VISIBLE_DEVICES=$(seq -s, 0 $((GPU_COUNT - 1)) | sed 's/,$//')
  export CUDA_VISIBLE_DEVICES
  NGPUS=${GPU_COUNT}
  echo "Using ${NGPUS} GPUs"
fi

# -----------------------------
# Experiment naming & logs
# -----------------------------
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
# Default experiment name encodes dataset to separate checkpoints by dataset.
# It contains a "/", so it must be flattened before being used as a file name.
EXP_NAME=${EXPERIMENT_NAME:-"${DATASET_TYPE}-sft-from-rs/${TIMESTAMP}"}
LOG_DIR=${LOG_DIR:-./logs}
mkdir -p -- "${LOG_DIR}"

echo "========================================"
echo "KBQA-R1 SFT from Rejection Sampling"
echo "========================================"
echo "Train parquet : ${TRAIN_PARQUET}"
echo "Val parquet   : ${VAL_PARQUET}"
echo "Base model    : ${BASE_MODEL}"
echo "Project       : ${PROJECT_NAME}"
echo "Experiment    : ${EXP_NAME}"
# CUDA_VISIBLE_DEVICES is unset on the CPU-only fallback; a bare expansion
# would abort the script under `set -u`, so substitute a placeholder.
echo "GPUs          : ${NGPUS} (CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-<unset>})"
echo "Total steps   : ${TOTAL_STEPS}"
echo "MaxTok/GPU    : ${MAX_TOKEN_LEN_PER_GPU}"
echo "========================================"

# -----------------------------
# torch.distributed launcher
# -----------------------------
# Launcher command; may be multi-word (e.g. "python -m torch.distributed.run").
TORCHRUN=${TORCHRUN:-torchrun}
# Random port in [20000, 29999], presumably to avoid collisions between
# concurrent launches on one host. Bash's $RANDOM avoids depending on python3,
# which may not be installed. NOTE(review): with --standalone, torchrun
# assigns its own rendezvous endpoint and may ignore --master_port; confirm.
MASTER_PORT=${MASTER_PORT:-$((20000 + RANDOM % 10000))}

# -----------------------------
# Launch SFT training (choose trainer via TRAINER env: engine|fsdp)
# Default to engine (verl.trainer.sft_trainer) which is sufficient for 8B on 8x80GB
# -----------------------------
TRAINER=${TRAINER:-engine}
echo "Trainer       : ${TRAINER}"

# Sequence-length cap passed as data.max_length (previously hard-coded 16384).
MAX_LENGTH=${MAX_LENGTH:-16384}
# The dynamic batcher's per-GPU token budget must cover the longest sequence.
if (( MAX_TOKEN_LEN_PER_GPU < MAX_LENGTH )); then
  echo "Warning: MAX_TOKEN_LEN_PER_GPU (${MAX_TOKEN_LEN_PER_GPU}) < MAX_LENGTH (${MAX_LENGTH})" >&2
fi

LOG_FILE="${LOG_DIR}/${EXP_NAME//\//_}.log"

# TORCHRUN may hold several words, so split it into an argv array instead of
# relying on unquoted word-splitting at the call site.
read -r -a torchrun_cmd <<<"${TORCHRUN}"
launcher=("${torchrun_cmd[@]}" --standalone "--nproc_per_node=${NGPUS}" "--master_port=${MASTER_PORT}")

# Hydra overrides shared by both trainers. List-valued overrides are quoted so
# the shell neither treats `[...]` as a glob nor strips the inner quotes.
common_args=(
  "data.train_files=['${TRAIN_PARQUET}']"
  "data.max_token_len_per_gpu=${MAX_TOKEN_LEN_PER_GPU}"
  "data.max_length=${MAX_LENGTH}"
  data.truncation=right
  "data.train_batch_size=$((NGPUS * 4))"
  data.micro_batch_size_per_gpu=4
  model.enable_gradient_checkpointing=true
  "trainer.default_local_dir=checkpoints/${PROJECT_NAME}/${EXP_NAME}"
  "trainer.project_name=${PROJECT_NAME}"
  "trainer.experiment_name=${EXP_NAME}"
  "trainer.total_training_steps=${TOTAL_STEPS}"
  trainer.save_freq=after_each_epoch
  "trainer.logger=['console','tensorboard']"
  optim.lr=1e-5
  optim.clip_grad=1.0
)

case "${TRAINER}" in
  engine)
    # sft_trainer (engine config: sft_trainer_engine.yaml)
    trainer_module=verl.trainer.sft_trainer
    trainer_args=(
      data.pad_mode=no_padding
      "checkpoint.save_contents=['hf_model']"
      "model.path=${BASE_MODEL}"
      model.use_remove_padding=true
    )
    ;;
  fsdp)
    # fsdp_sft_trainer (legacy path; no data.pad_mode key)
    trainer_module=verl.trainer.fsdp_sft_trainer
    trainer_args=(
      data.multiturn.enable=true
      data.multiturn.messages_key=messages
      "trainer.checkpoint.save_contents=['hf_model']"
      "model.partial_pretrain=${BASE_MODEL}"
      use_remove_padding=true
      optim.warmup_steps_ratio=0.1
    )
    ;;
  *)
    echo "Error: unknown TRAINER='${TRAINER}' (expected engine|fsdp)" >&2
    exit 1
    ;;
esac

# Mirror all launcher output to the console and the log file. With pipefail a
# failing launcher propagates through tee and stops the script under `set -e`.
"${launcher[@]}" -m "${trainer_module}" "${common_args[@]}" "${trainer_args[@]}" 2>&1 | tee "${LOG_FILE}"

# Training runs in the foreground, so with `set -euo pipefail` this point is
# reached only after the training pipeline has exited successfully.
echo ""
echo "========================================"
echo "SFT Training Finished"
echo "========================================"
echo "Experiment: ${EXP_NAME}"
echo "Logs: ${LOG_DIR}/${EXP_NAME//\//_}.log"
echo "========================================"

# NOTE(review): runs a separate train.py from /ossfs/workspace after SFT has
# completed; its purpose is not visible from this script — confirm intended.
cd /ossfs/workspace
python train.py