#!/bin/bash
#SBATCH -p gpu22
#SBATCH -c 4
#SBATCH --gres=gpu:1
#SBATCH --mem=16G
#SBATCH -t 02:00:00
#SBATCH -J mamba
#SBATCH -o mamba.log

set -euo pipefail

# -----------------------
# Tunables (env overrides)
# -----------------------
# Each value keeps whatever the environment already provides and only falls
# back to the default shown when the variable is unset or empty.

# Suffix of the PyTorch wheel index used when installing torch:
# https://download.pytorch.org/whl/<CU_INDEX>
CU_INDEX="${CU_INDEX:-cu118}"

# Per-user scratch root for the venv and pip's cache/temp dirs.
# USER is not guaranteed to exist in every batch environment, and `set -u`
# would abort right here if it were missing, so fall back to `id -un`.
TMP_BASE="${TMP_BASE:-/tmp/${USER:-$(id -un)}/mamba}"
VENV_DIR="${VENV_DIR:-$TMP_BASE/venv}"
# Honour a pre-set PIP_CACHE_DIR / TMPDIR; otherwise keep both under TMP_BASE.
PIP_CACHE_DIR_LOCAL="${PIP_CACHE_DIR:-$TMP_BASE/pip-cache}"
TMPDIR_LOCAL="${TMPDIR:-$TMP_BASE/pip-tmp}"

# Project layout; everything defaults to the submission directory.
WORKDIR="${WORKDIR:-$PWD}"
DATA_DIR="${DATA_DIR:-$WORKDIR/data/mm_T100s}"
ALPHABET="${ALPHABET:-pm1}"

# Training script path.
PYFILE="${PYFILE:-$WORKDIR/train_mamba.py}"

# Exported so child processes (e.g. the pip-driven CUDA extension builds)
# can see it; presumably caps parallel compile jobs to keep peak RAM low.
export MAX_JOBS="${MAX_JOBS:-4}"

# CUDA robustness: ask the CUDA runtime to load kernel modules lazily.
export CUDA_MODULE_LOADING=LAZY

# -----------------------
# setup: scratch dirs, virtualenv, pip environment
# -----------------------
echo "[1/10] Prepare /tmp dirs: $TMP_BASE"
mkdir -p -- "$TMP_BASE" "$PIP_CACHE_DIR_LOCAL" "$TMPDIR_LOCAL"

echo "[2/10] Create/reuse venv: $VENV_DIR"
# Reuse the venv when its interpreter is already executable; otherwise
# (re)create it. Everything below runs inside this venv.
[[ -x "$VENV_DIR/bin/python" ]] || python3 -m venv "$VENV_DIR"
# The activate script only exists at runtime, so shellcheck cannot follow it.
# shellcheck disable=SC1091
. "$VENV_DIR/bin/activate"

echo "[3/10] Force pip cache/tmp to /tmp"
# Point pip's cache and temporary build files at the dirs created above.
export PIP_CACHE_DIR="$PIP_CACHE_DIR_LOCAL" TMPDIR="$TMPDIR_LOCAL"

echo "[4/10] Upgrade pip tooling + build tools"
# Step 7 installs with --no-build-isolation, so the build tooling has to be
# present in the venv itself rather than in a throwaway build environment.
build_tools=(pip setuptools wheel packaging ninja cmake)
python -m pip install --no-cache-dir -U "${build_tools[@]}"

echo "[5/10] Install torch (${CU_INDEX} pinned) if missing"
# Skip the large wheel download when the venv already has an importable torch.
# NOTE(review): this only checks that torch imports, not that it is a CUDA
# build; a CPU-only torch in a reused venv passes here and fails later
# (at the latest in the step-8 CUDA check).
if python - <<'PY'
import sys
try:
    import torch
except Exception as e:
    print("torch not installed:", e)
    sys.exit(1)
print("torch already installed:", torch.__version__)
PY
then
  echo "[5/10] torch present"
else
  # Report the index actually used; CU_INDEX is a tunable, not always cu118.
  echo "[5/10] Installing torch from ${CU_INDEX} wheels..."
  python -m pip install --no-cache-dir \
    --index-url "https://download.pytorch.org/whl/${CU_INDEX}" \
    torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1
fi

echo "[5.5/10] Verify torch"
# Informational dump only; nothing here fails the job (step 8 is the gate).
python - <<'PY'
import os
import torch

print("torch:", torch.__version__)
print("torch.version.cuda:", torch.version.cuda)
print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))
has_cuda = torch.cuda.is_available()
print("cuda available:", has_cuda)
n_devices = torch.cuda.device_count()
print("device count:", n_devices)
if has_cuda and n_devices > 0:
    print("device0:", torch.cuda.get_device_name(0))
PY

echo "[6/10] Install deps (numpy/tqdm)"
# Plain runtime dependencies; upgraded on every run.
runtime_deps=(numpy tqdm)
python -m pip install --no-cache-dir -U "${runtime_deps[@]}"

echo "[7/10] Install Mamba deps (no build isolation; low-memory friendly)"
# --no-build-isolation builds against the torch already in this venv instead
# of a fresh isolated build env. A failure here is tolerated on purpose (the
# combined-extra install below is the fallback), but it is reported, not hidden.
for spec in "causal-conv1d==1.6.0" "mamba-ssm==2.3.0"; do
  if ! python -m pip install --no-cache-dir --no-build-isolation "$spec"; then
    echo "[warn] pip install $spec failed; will retry via fallback if needed" >&2
  fi
done

# The verification step below imports BOTH packages, so fall back when either
# one is missing. Checking only mamba_ssm would let a failed causal-conv1d
# build slip through and crash the verify step with an ImportError.
if ! python -c "import causal_conv1d, mamba_ssm" >/dev/null 2>&1; then
  echo "[fallback] try mamba-ssm[causal-conv1d]==2.3.0 (no-build-isolation)"
  python -m pip install --no-cache-dir --no-build-isolation "mamba-ssm[causal-conv1d]==2.3.0"
fi

echo "[7.5/10] Verify mamba imports"
# Hard gate: any ImportError here aborts the job (set -e).
python - <<'PY'
import torch, os
print("torch:", torch.__version__, "cuda:", torch.cuda.is_available(), "torch.version.cuda:", torch.version.cuda)
print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))
import causal_conv1d
print("causal_conv1d ok:", getattr(causal_conv1d, "__version__", "unknown"))
import mamba_ssm
print("mamba_ssm ok:", getattr(mamba_ssm, "__version__", "unknown"))
from mamba_ssm.modules.mamba_simple import Mamba
print("Mamba class:", Mamba)
PY

# -----------------------
# CUDA sanity tests: fail fast before training if no usable GPU is visible
# -----------------------
echo "[8/10] CUDA sanity tests (nvidia-smi + cuBLAS matmul)"
echo "==== nvidia-smi ===="
# Informational only: a missing or failing nvidia-smi must not abort the job;
# the Python check below is the real gate.
if ! nvidia-smi; then :; fi
echo "CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-<unset>}"

python - <<'PY'
import os
import torch

for label, value in (
    ("torch:", torch.__version__),
    ("torch.version.cuda:", torch.version.cuda),
    ("cuda available:", torch.cuda.is_available()),
    ("device count:", torch.cuda.device_count()),
    ("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES")),
):
    print(label, value)

# Exit non-zero (SystemExit with a message) so `set -e` stops the job here.
if not (torch.cuda.is_available() and torch.cuda.device_count() > 0):
    raise SystemExit("[FAIL] No CUDA device visible. This job is not on a GPU or GPU allocation failed.")

torch.cuda.init()
name = torch.cuda.get_device_name(0)
free, total = torch.cuda.mem_get_info()
print("device0:", name)
print(f"mem free={free/1e9:.2f}GB total={total/1e9:.2f}GB")

# Exercise cuBLAS with a real GEMM; synchronize so any asynchronous kernel
# error surfaces here rather than later in training.
lhs = torch.randn(1024, 1024, device="cuda")
rhs = torch.randn(1024, 1024, device="cuda")
product = torch.matmul(lhs, rhs)
torch.cuda.synchronize()
print("[OK] cuBLAS matmul works:", product.shape, product.dtype)
PY

echo "[9/10] Run training"

# Checkpoint name follows ALPHABET so runs with different alphabets don't
# overwrite each other (default for pm1 is unchanged); override via SAVE_PATH.
SAVE_PATH="${SAVE_PATH:-ckpt_mamba_query_${ALPHABET}.pt}"

# Fail fast with a clear message instead of Python's "can't open file" error.
if [[ ! -f "$PYFILE" ]]; then
  echo "[FAIL] training script not found: $PYFILE" >&2
  exit 1
fi

# Use the venv's `python` (same interpreter as every step above) and the
# configurable PYFILE rather than a cwd-relative "train_mamba.py".
CMD=(
  python -u "$PYFILE"
  --data_dir "$DATA_DIR"
  --alphabet "$ALPHABET"
  --target_mode multiclass
  --cuda --amp --amp_dtype bf16
  --d_model 256 --layers 4 --dropout 0.1
  --mamba_d_state 16 --mamba_d_conv 4 --mamba_expand 2
  --batch_size 128
  --num_workers 0
  --lr 3e-4 --weight_decay 1e-3 --grad_clip 1.0
  --aux_w 0.1
  --max_steps 30000 --epochs 200 --patience 30 --early_stop acc
  --save_path "$SAVE_PATH"
)

# %q-quote each word so the logged command is copy-pasteable even when paths
# contain spaces or shell metacharacters.
printf 'Command:'
printf ' %q' "${CMD[@]}"
printf '\n'
"${CMD[@]}"

echo "[10/10] Done. venv=$VENV_DIR"
