#!/bin/bash
#
# Slurm batch header. sbatch requires the very first line of a batch script
# to be a "#!" interpreter line; the "#SBATCH" comment lines below are read
# by sbatch as submission options.

# Resources: 224G of memory, 4 CPUs per task, one GPU of type nvidia_l40s.
#SBATCH --mem=224G
#SBATCH --cpus-per-task=4
#SBATCH --gres=gpu:nvidia_l40s:1
# stdout and stderr are written to the same log file.
# Filename patterns: %u = user name, %x = job name, %j = job id.
# NOTE(review): confirm that the slurm-%x-%j directory exists (or that the
# installed Slurm version creates it) before the job starts.
#SBATCH --output=/home/%u/slurm_logs/slurm-%x-%j/logs.out
#SBATCH --error=/home/%u/slurm_logs/slurm-%x-%j/logs.out

###############################################################################
# Robust training launch script (hardened)
# Submit with: sbatch --job-name=run-$(date +%Y%m%d-%H%M%S) speedrun.sh
# Example: sbatch --job-name=run-$(date +%Y%m%d-%H%M%S) /home/***/UT/run.sh pile llada 70 mlp
###############################################################################



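# NOTE(review): the description below appears to be left over from the
# "speedrun" script this file was adapted from: "the best ChatGPT clone that
# $100 can buy", designed to run in ~4 hours on an 8xH100 node at $3/GPU/hour.
# As written, this file requests a single nvidia_l40s GPU (see the #SBATCH
# header) and runs the benchmark/bench.py sweep further down instead.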

# 1) Example launch (simplest):
# bash speedrun.sh
# 2) Example launch in a screen session (because the run takes ~4 hours):
# screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
# 3) Example launch with wandb logging, but see below for setting up wandb first:
# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh

# Command-line arguments flattened into a single space-separated string.
SCRIPT_ARGS="$*"

# Limit OpenMP to a single thread.
export OMP_NUM_THREADS=1

# Load optional settings from a .env file located next to this script.
# `set -a` marks every variable assigned while sourcing for export;
# `set +a` switches that behaviour off again afterwards.
# NOTE(review): under sbatch, $0 may point at Slurm's spooled copy of the
# script rather than its original location - confirm the .env is picked up.
script_dir="$(dirname -- "$0")"
if [[ -f "${script_dir}/.env" ]]; then
  set -a
  # shellcheck disable=SC1091
  . "${script_dir}/.env"
  set +a
fi

# Intermediate artifacts go to ~/.cache/nanochat unless NANOCHAT_BASE_DIR is
# already set (in the environment or by the .env file above).
export NANOCHAT_BASE_DIR="${NANOCHAT_BASE_DIR:-$HOME/.cache/nanochat}"
# Quoted so that a base directory containing spaces is created as one path.
mkdir -p -- "$NANOCHAT_BASE_DIR"

# -----------------------------------------------------------------------------
# Python venv setup with uv

# # install uv (if not already installed)
# command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
# # create a .venv local virtual environment (if it doesn't exist)
# [ -d ".venv" ] || uv venv
# # install the repo dependencies
# uv sync --extra gpu
# # activate venv so that `python` uses the project's venv instead of system python
# source .venv/bin/activate

# Print the working directory; benchmark/bench.py below is resolved
# relative to it.
pwd

# Model names handed to bench.py through --model.
MODELS=("kla_recurrent" "kla_torch" "kla_triton" "mamba_fused" "mamba_torch" "mamba_triton_fused" "mamba_triton")

#######################################
# Run benchmark/bench.py once per (mode, sequence length, model) combination.
#
# Sweep order (outermost to innermost):
#   mode:     forward-only (adds --forward-only --batch-size 1), then training
#   seq_len:  2**7 .. 2**13 (128 .. 8192)
#   model:    every entry of MODELS
#
# A failing run does not stop the sweep; failed combinations are collected
# and reported on stderr once the sweep has finished.
#
# Globals:   MODELS (read)
# Arguments: extra arguments appended verbatim to every bench.py invocation;
#            each argument is forwarded as its own word (no re-splitting or
#            glob expansion).
# Outputs:   one progress line per run on stdout; failure summary on stderr
# Returns:   0 if every run succeeded, 1 if at least one run failed
#######################################
run_benchmark_sweep() {
  local forward_only seq_len_exp seq_len model mode_label
  local -a mode_args=()
  local -a failed=()

  for forward_only in True False; do
    if [[ "$forward_only" == True ]]; then
      mode_label="forward-only"
      mode_args=(--forward-only --batch-size 1)
    else
      mode_label="training"
      mode_args=()
    fi

    for seq_len_exp in {7..13}; do
      seq_len=$((2 ** seq_len_exp))
      for model in "${MODELS[@]}"; do
        echo "Running ${mode_label} benchmark for model $model with seq_len $seq_len"
        # Record the failure and keep going so one broken combination does
        # not hide the results of the remaining ones.
        python benchmark/bench.py --model "$model" "${mode_args[@]}" \
          --seq-len "$seq_len" "$@" \
          || failed+=("${mode_label} model=${model} seq_len=${seq_len}")
      done
    done
  done

  if ((${#failed[@]} > 0)); then
    printf 'ERROR: %d benchmark run(s) failed:\n' "${#failed[@]}" >&2
    printf '  %s\n' "${failed[@]}" >&2
    return 1
  fi
  return 0
}

run_benchmark_sweep "$@"
