#!/bin/bash
#SBATCH -p gpu22
#SBATCH --gres=gpu:1
#SBATCH -c 8
#SBATCH --mem=32G
#SBATCH -t 02:00:00
#SBATCH -J rwkv7_train
#SBATCH -o rwkv.log
 
set -euo pipefail

# Interpreter used to bootstrap the per-job virtualenv (override via env).
PYTHON_BIN="${PYTHON_BIN:-python3}"

# Per-job scratch area on node-local /tmp. SLURM_JOB_ID (and occasionally
# USER) can be unset when this script is run outside sbatch; under `set -u`
# that would abort, so fall back to `id -un` / the shell PID.
BASE="/tmp/${USER:-$(id -un)}/rwkv7_train_${SLURM_JOB_ID:-local_$$}"
VENV_DIR="${BASE}/venv"
export TMPDIR="${BASE}/pip-tmp"
export PIP_CACHE_DIR="${BASE}/pip-cache"
mkdir -p -- "$TMPDIR" "$PIP_CACHE_DIR"

# The venv (with a CUDA torch build) and caches under BASE are large and are
# rebuilt from scratch every job, so remove them on exit to keep node-local
# /tmp from filling up across jobs. Set KEEP_BASE=1 to keep them for debugging.
cleanup_base() {
  if [[ "${KEEP_BASE:-0}" != "1" ]]; then
    rm -rf -- "${BASE:?}"
  fi
}
trap cleanup_base EXIT

echo "[Host] $(hostname)"
echo "[BASE] $BASE"
echo "[VENV] $VENV_DIR"

# Always start from a fresh venv; `:?` guards against ever running `rm -rf ""`.
rm -rf -- "${VENV_DIR:?}"
"$PYTHON_BIN" -m venv "$VENV_DIR"
# shellcheck disable=SC1091
source "$VENV_DIR/bin/activate"

# -----------------------
# Dependency installation
# -----------------------
# Every install goes through this wrapper so the shared pip flag lives in one
# place. Note: --no-cache-dir means PIP_CACHE_DIR is effectively unused here.
pip_install() {
  python -m pip install --no-cache-dir "$@"
}

# Build tooling, then lightweight numeric/utility packages.
pip_install -U pip setuptools wheel packaging ninja
pip_install -U numpy einops tqdm

# PyTorch stack, pinned to builds from the CUDA 11.8 wheel index.
torch_pins=(torch==2.5.1+cu118 torchvision==0.20.1+cu118 torchaudio==2.5.1+cu118)
pip_install "${torch_pins[@]}" --index-url https://download.pytorch.org/whl/cu118

# Optional Hugging Face helpers.
pip_install -U transformers accelerate safetensors

# FLA backend (RWKV7Attention): flash-linear-attention is installed without
# dependency resolution, then fla-core at the same version is installed with
# its dependencies. Versions are pinned deliberately; if Triton kernels
# misbehave, see the RWKV7 stability environment variables set further down.
pip_install --no-deps flash-linear-attention==0.4.1
pip_install fla-core==0.4.1

# Report the installed torch build and whether a GPU is visible to it.
printf '\n'
python - <<'PY'
import torch

cuda_ok = torch.cuda.is_available()
print("torch:", torch.__version__, "cuda:", torch.version.cuda, "cuda_available:", cuda_ok)
if cuda_ok:
    print("gpu:", torch.cuda.get_device_name(0))
PY

# -----------------------
# RWKV7 stability knobs
# -----------------------
# 1) Opt out of the Triton fused_addcmul path in RWKV7, reported flaky with
#    some FLA builds. NOTE(review): presumably read by FLA / train_rwkv7.py at
#    import time — confirm the consumer honors this variable.
export DISABLE_RWKV7_FUSED_ADDCMUL="${DISABLE_RWKV7_FUSED_ADDCMUL:-1}"

# 2) Keep Triton / TorchInductor compile caches under the per-job BASE on
#    node-local /tmp rather than $HOME (which may be on NFS).
export TRITON_CACHE_DIR="${TRITON_CACHE_DIR:-$BASE/triton-cache}"
export TORCHINDUCTOR_CACHE_DIR="${TORCHINDUCTOR_CACHE_DIR:-$BASE/torchinductor-cache}"
mkdir -p -- "$TRITON_CACHE_DIR" "$TORCHINDUCTOR_CACHE_DIR"

# 3) Size CPU thread pools to the CPUs Slurm actually allocated (-c), so that
#    changing the allocation on the sbatch command line does not cause
#    oversubscription. Falls back to 8 (the #SBATCH -c value) outside Slurm;
#    explicit OMP_NUM_THREADS / MKL_NUM_THREADS still win.
cpu_threads="${SLURM_CPUS_PER_TASK:-8}"
export OMP_NUM_THREADS="${OMP_NUM_THREADS:-$cpu_threads}"
export MKL_NUM_THREADS="${MKL_NUM_THREADS:-$cpu_threads}"

# -----------------------
# Run training
# -----------------------
# DATA_DIR / SAVE_PATH may be overridden from the environment, e.g.
#   DATA_DIR=data/other SAVE_PATH=ckpt_other.pt sbatch <this script>
# The defaults are the values this command previously hard-coded (the old,
# unused SAVE_PATH default disagreed with the path actually passed).
DATA_DIR="${DATA_DIR:-data/mm_stepwise_m29_qk0}"
SAVE_PATH="${SAVE_PATH:-ckpt_rwkv7_stepwise_m29_qk0.pt}"

train_args=(
  --data_dir "$DATA_DIR"
  --m_max 29 --cuda --amp --amp_dtype bf16
  --mat_mode signed --state_method embed_sum --t_feature none
  --d_model 256 --rwkv7_depth 2 --rwkv7_head_dim 64 --rwkv7_mode chunk --dropout 0.2
  --batch_size 256 --lr 3e-4 --weight_decay 1e-3 --grad_clip 1.0
  --epochs 200 --patience 20 --seed 0
  --save_path "$SAVE_PATH"
)

# `python` is the interpreter of the venv activated earlier in this script.
# Extra arguments given to the batch script are appended after the defaults
# so they can override them. NOTE(review): assumes train_rwkv7.py parses
# options last-wins (argparse default behavior) — confirm.
python -u train_rwkv7.py "${train_args[@]}" "$@"
