#!/bin/bash
#SBATCH -p gpu22
#SBATCH -c 4
#SBATCH --gres=gpu:1
#SBATCH --mem=16G
#SBATCH -t 03:00:00
#SBATCH -J deltanet
#SBATCH -o delta.log
#
# DeltaNet training job: builds (or reuses) a venv under TMP_BASE, installs
# torch / numpy / tqdm (and optionally flash-linear-attention), prints the
# environment, then launches the training script.
#
# Strict mode: abort on any failing command, on use of an unset variable,
# and when any stage of a pipeline fails.
set -o errexit -o nounset -o pipefail

# ---------------------------------------------------------------------
# Tunables. Each one may be supplied through the environment; when it is
# unset or empty, the default written here is used instead.
# ---------------------------------------------------------------------
: "${CU_INDEX:=cu126}"           # suffix of the PyTorch wheel index URL

# Scratch location for the venv and pip temp files.
: "${TMP_BASE:=/tmp/$USER/deltanet_tf_nomod}"
: "${VENV_DIR:=$TMP_BASE/venv}"

# Inputs / outputs.
: "${DATA_DIR:=$PWD/data/mm_T100_bal_stratT_cap1e6}"
: "${TRAIN_PY:=$PWD/train_deltanet_binary.py}"   # DeltaNet TF-nomod training script
: "${SAVE_PATH:=$PWD/ckpt_deltanet_tf_nomod.pt}"

# Model shape.
: "${D_MODEL:=256}"
: "${HEADS:=4}"
: "${LAYERS:=1}"
: "${DROPOUT:=0.2}"
: "${MLP_HIDDEN_MULT:=4}"
: "${MLP_ACT:=gelu}"

# Optimisation schedule.
: "${BATCH_SIZE:=256}"
: "${EPOCHS:=200}"
: "${MAX_STEPS:=30000}"
: "${LR:=3e-4}"
: "${WD:=0.005}"
: "${GRAD_CLIP:=1.0}"

# Early stopping.
: "${PATIENCE:=20}"
: "${EARLY_STOP:=loss}"          # criterion name understood by TRAIN_PY, e.g. loss / stepAcc / finalAcc

# Tokenization / safety limits.
: "${MAX_LEN:=0}"                # 0 => let the training script infer it (recommended)
: "${STATE_CAP:=0}"              # 0 => disabled

# Data loading.
: "${NUM_WORKERS:=2}"

# Mixed precision.
: "${AMP:=1}"                    # 1 => pass --amp
: "${AMP_DTYPE:=bf16}"           # bf16 or fp16; must be a value TRAIN_PY accepts

# Optional flash-linear-attention install.
: "${INSTALL_FLA:=1}"            # 1 => pip install flash-linear-attention

echo "[1/8] Prepare /tmp: $TMP_BASE"
mkdir -p "$TMP_BASE"

echo "[2/8] Create/reuse venv: $VENV_DIR"
# An existing interpreter means the venv was already built; otherwise create it.
[[ -x "$VENV_DIR/bin/python" ]] || python3 -m venv "$VENV_DIR"
# shellcheck disable=SC1091
. "$VENV_DIR/bin/activate"

echo "[3/8] Pip cache/tmp -> /tmp"
# TMPDIR is exported, so everything launched later from this shell (including
# the training run) also puts its temp files here. Note that pip calls made
# with --no-cache-dir, such as the one below, do not use PIP_CACHE_DIR.
export PIP_CACHE_DIR="$TMP_BASE/pip-cache" TMPDIR="$TMP_BASE/pip-tmp"
mkdir -p "$PIP_CACHE_DIR" "$TMPDIR"

echo "[4/8] Upgrade pip tooling"
python -m pip install --no-cache-dir -U pip setuptools wheel

# Succeeds (exit 0) when the active python can import module $1, fails otherwise.
# The import is attempted inside python so any Exception counts as "missing".
have_py_module() {
  python - "$1" <<'PYCHECK'
import importlib, sys
try:
    importlib.import_module(sys.argv[1])
except Exception:
    sys.exit(1)
PYCHECK
}

echo "[5/8] Install torch if missing (CUDA wheels: ${CU_INDEX})"
if have_py_module torch; then
  echo "  [skip] torch already installed"
else
  python -m pip install --no-cache-dir --index-url "https://download.pytorch.org/whl/${CU_INDEX}" \
    torch torchvision torchaudio
fi

echo "[6/8] Install deps"
python -m pip install --no-cache-dir -U numpy tqdm

case "$INSTALL_FLA" in
  1)
    echo "[6.5/8] Install flash-linear-attention (fla)"
    # Pin a version here (flash-linear-attention==X.Y.Z) if the cluster needs one.
    python -m pip install --no-cache-dir -U flash-linear-attention
    ;;
esac

echo "[7/8] Environment"
# Log interpreter / library / GPU versions so the job log records exactly
# what the training run used. A failing torch or numpy import aborts the job
# here (errexit); a missing fla is only reported.
python - <<'PYREPORT'
import sys
import torch
import numpy as np

print("Python:", sys.version.split()[0])
print("Torch :", torch.__version__)
print("CUDA  :", torch.version.cuda)
cuda_ok = torch.cuda.is_available()
print("CUDA avail:", cuda_ok)
if cuda_ok:
    print("GPU:", torch.cuda.get_device_name(0))
print("NumPy :", np.__version__)

try:
    import fla
    fla_version = getattr(fla, "__version__", "unknown")
    print("fla   :", fla_version)
except Exception as err:
    print("fla   : not importable:", err)
PYREPORT

echo "[8/8] Run training"

# Fail fast with a clear message before launching training if either input is missing.
[[ -f "$TRAIN_PY" ]] || { echo "ERROR: TRAIN_PY not found: $TRAIN_PY" >&2; exit 1; }
[[ -d "$DATA_DIR" ]] || { echo "ERROR: DATA_DIR not found: $DATA_DIR" >&2; exit 1; }

# Build the training command from the knobs so env overrides actually reach
# the training script, and so the TRAIN_PY / DATA_DIR validated above and the
# SAVE_PATH reported at the end are the ones really used.
#
# Script-specific extras. Each is passed only when set, because not every
# training script is known to accept them (TODO: confirm for your TRAIN_PY).
CUDA="${CUDA:-1}"                    # 1 => pass --cuda
ALPHABET="${ALPHABET:-}"             # --alphabet (e.g. pm1)
DELTANET_MODE="${DELTANET_MODE:-}"   # --deltanet_mode (e.g. chunk)
AUX_W="${AUX_W:-}"                   # --aux_w (e.g. 0.005)
POS_WEIGHT="${POS_WEIGHT:-}"         # --pos_weight (e.g. 0.0)
#
# Example: the stepwise train_deltanet.py experiment is reproduced with
#   TRAIN_PY=$PWD/train_deltanet.py DATA_DIR=$PWD/data/mm_stepwise_m29_qk0 \
#   LAYERS=2 LR=1e-4 WD=0.05 EPOCHS=10 PATIENCE=10 \
#   ALPHABET=pm1 DELTANET_MODE=chunk AUX_W=0.005 POS_WEIGHT=0.0 \
#   SAVE_PATH=$PWD/ckpt_deltanet_tf_stepwise_lr1e4_do02_wd05_aux01.pt \
#   sbatch <this script>
# (sbatch forwards the submitting shell's environment by default.)

CMD=(
  python -u "$TRAIN_PY"
  --data_dir "$DATA_DIR"
  --d_model "$D_MODEL" --heads "$HEADS" --layers "$LAYERS"
  --dropout "$DROPOUT"
  --mlp_hidden_mult "$MLP_HIDDEN_MULT" --mlp_act "$MLP_ACT"
  --batch_size "$BATCH_SIZE"
  --lr "$LR"
  --weight_decay "$WD"
  --grad_clip "$GRAD_CLIP"
  --max_steps "$MAX_STEPS"
  --epochs "$EPOCHS"
  --patience "$PATIENCE"
  --early_stop "$EARLY_STOP"
  --num_workers "$NUM_WORKERS"
  --save_path "$SAVE_PATH"
)

if [[ "$CUDA" == "1" ]]; then
  CMD+=(--cuda)
fi

# Append "<flag> <value>" to CMD only when <value> is non-empty.
add_if_set() {
  local flag=$1 value=$2
  if [[ -n "$value" ]]; then
    CMD+=("$flag" "$value")
  fi
}
add_if_set --alphabet "$ALPHABET"
add_if_set --deltanet_mode "$DELTANET_MODE"
add_if_set --aux_w "$AUX_W"
add_if_set --pos_weight "$POS_WEIGHT"

# For these two knobs, 0 means "don't pass the flag".
if [[ "$MAX_LEN" != "0" ]]; then
  CMD+=(--max_len "$MAX_LEN")
fi
if [[ "$STATE_CAP" != "0" ]]; then
  CMD+=(--state_cap "$STATE_CAP")
fi

# AMP flags come only from here, so AMP=0 really disables mixed precision
# and --amp / --amp_dtype are never passed twice.
if [[ "$AMP" == "1" ]]; then
  CMD+=(--amp --amp_dtype "$AMP_DTYPE")
fi

# %q shell-quotes each argument so the logged command can be copy-pasted.
printf 'Command:'
printf ' %q' "${CMD[@]}"
printf '\n'
"${CMD[@]}"

echo "Done. Saved: $SAVE_PATH"
