#!/bin/bash
#SBATCH -p gpu22
#SBATCH --gres=gpu:1
#SBATCH -c 8
#SBATCH --mem=32G
#SBATCH -t 02:00:00
#SBATCH -J rwkv7_train
#SBATCH -o rwkv.log

set -euo pipefail

# BASE MUST be defined before any use (because of -u).
# Per-job scratch root on node-local /tmp. USER and SLURM_JOB_ID get
# fallbacks (id -un / the shell PID) so that running this script directly
# with `bash` outside of sbatch, or in an environment without USER, does not
# abort immediately with an "unbound variable" error under `set -u`.
BASE="/tmp/${USER:-$(id -un)}/rwkv7_train_${SLURM_JOB_ID:-local_$$}"

# -----------------------
# stability / cluster knobs
# -----------------------
# Disable torch.compile / TorchDynamo. DISABLE_RWKV7_FUSED_ADDCMUL is not a
# PyTorch variable; presumably read by train_rwkv7.py — TODO confirm.
export TORCH_COMPILE_DISABLE=1 \
  TORCHDYNAMO_DISABLE=1 \
  DISABLE_RWKV7_FUSED_ADDCMUL=1

# Redirect compiler caches and pip scratch space into the per-job BASE so
# nothing lands under $HOME (important on NFS). TMPDIR is a general-purpose
# variable, so it affects other tools that honour it, not only pip.
export TRITON_CACHE_DIR="${BASE}/triton-cache" \
  TORCHINDUCTOR_CACHE_DIR="${BASE}/torchinductor-cache" \
  TMPDIR="${BASE}/pip-tmp" \
  PIP_CACHE_DIR="${BASE}/pip-cache"

# Quieter logs and more predictable CUDA allocation (optional but recommended).
# PATCH_TORCH_LERP_DTYPE is presumably consumed by train_rwkv7.py — TODO confirm.
export HF_HUB_DISABLE_TELEMETRY=1 \
  TOKENIZERS_PARALLELISM=false \
  CUDA_DEVICE_MAX_CONNECTIONS=1 \
  PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \
  PATCH_TORCH_LERP_DTYPE=1

# Create every scratch directory up front.
mkdir -p -- "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR" "$TMPDIR" "$PIP_CACHE_DIR"

# avoid CPU oversubscription during preload / tokenization:
# cap the OpenMP/MKL thread pools at the CPUs SLURM actually granted
# (SLURM_CPUS_PER_TASK, populated from `#SBATCH -c` / --cpus-per-task),
# falling back to 8 when that is unset. Explicit OMP_NUM_THREADS /
# MKL_NUM_THREADS values from the caller's environment still take precedence.
cpu_budget="${SLURM_CPUS_PER_TASK:-8}"
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-$cpu_budget}"
export MKL_NUM_THREADS="${MKL_NUM_THREADS:-$cpu_budget}"

# Interpreter used only to create the venv; overridable from the environment.
PYTHON_BIN="${PYTHON_BIN:-python3}"
VENV_DIR="${BASE}/venv"

echo "[Host] $(hostname)"
echo "[BASE] $BASE"
echo "[VENV] $VENV_DIR"
echo "[TMPDIR] $TMPDIR"
echo "[PIP_CACHE_DIR] $PIP_CACHE_DIR"
echo "[TRITON_CACHE_DIR] $TRITON_CACHE_DIR"
echo "[TORCHINDUCTOR_CACHE_DIR] $TORCHINDUCTOR_CACHE_DIR"
echo "[OMP_NUM_THREADS] $OMP_NUM_THREADS"

# -----------------------
# venv
# -----------------------
# Always start from a fresh virtualenv. ${VENV_DIR:?} makes the expansion
# abort instead of running `rm -rf ''` if VENV_DIR were ever empty/unset.
rm -rf -- "${VENV_DIR:?}"
"$PYTHON_BIN" -m venv "$VENV_DIR"
# Activation prepends $VENV_DIR/bin to PATH, so later bare `python` calls
# resolve to the venv interpreter.
# shellcheck disable=SC1091
source "$VENV_DIR/bin/activate"

# -----------------------
# installs
# -----------------------
# Every install goes through the venv interpreter activated above (bare
# `python`). NOTE: --no-cache-dir disables pip's cache, so PIP_CACHE_DIR is
# effectively unused by these commands.
pip_install() {
  python -m pip install --no-cache-dir "$@"
}

# build tooling + light numeric deps
pip_install -U pip setuptools wheel packaging ninja
pip_install -U numpy einops tqdm

# torch cu118 (this will install triton==3.1.0 as required by torch 2.5.1+cu118)
torch_pkgs=(
  torch==2.5.1+cu118
  torchvision==0.20.1+cu118
  torchaudio==2.5.1+cu118
)
pip_install "${torch_pkgs[@]}" --index-url https://download.pytorch.org/whl/cu118

# HF deps required by flash-linear-attention
pip_install -U transformers accelerate safetensors

# FLA (RWKV7Attention): core package first, then the wrapper with --no-deps
# so pip skips dependency resolution for it (presumably to leave the pinned
# torch stack above untouched — confirm).
pip_install fla-core==0.4.1
pip_install --no-deps flash-linear-attention==0.4.1

# Sanity check from the venv interpreter: report torch / CUDA / triton
# versions and, if CUDA is visible, the name of device 0.
printf '\n'
python - <<'PY'
import torch
import triton

cuda_ok = torch.cuda.is_available()
print("torch:", torch.__version__, "cuda:", torch.version.cuda, "cuda_available:", cuda_ok)
print("triton:", triton.__version__)
if cuda_ok:
    print("gpu:", torch.cuda.get_device_name(0))
PY

# -----------------------
# run training
# -----------------------
# Both paths can be overridden from the submitting environment.
DATA_DIR="${DATA_DIR:-data/mm_T100s}"
SAVE_PATH="${SAVE_PATH:-ckpt_rwkv7_tf_query_pm1_balfr.pt}"

# Launch with the venv interpreter (bare `python`, resolved through the
# activated venv's PATH), which is where every package above was installed.
# Using "$PYTHON_BIN" here would silently fall back to the base interpreter,
# which lacks torch/FLA, whenever PYTHON_BIN is an absolute path.
python -u train_rwkv7.py \
  --data_dir "$DATA_DIR" \
  --alphabet pm1 \
  --target_mode multiclass \
  --cuda --amp --amp_dtype bf16 \
  --d_model 256 \
  --rwkv7_head_dim 64 \
  --rwkv7_depth 2 \
  --rwkv7_mode chunk \
  --dropout 0.1 \
  --batch_size 256 \
  --lr 3e-4 \
  --weight_decay 0.001 \
  --grad_clip 1.0 \
  --aux_w 0.1 \
  --max_steps 30000 \
  --early_stop acc --patience 30 \
  --save_path "$SAVE_PATH"
