#!/bin/bash
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --gres=gpu:8
#SBATCH --cpus-per-task=8
#SBATCH --account=test

# sbatch stops reading #SBATCH directives at the first executable line, so
# strict mode must come after them.
#   -e          abort on the first failing command (e.g. nvidia-smi failing
#               because no GPU is visible), instead of launching training anyway
#   -u          treat references to unset variables as errors (catches typos
#               in the configuration below)
#   -o pipefail a pipeline fails if any stage fails
set -euo pipefail

# Log the GPUs visible to this job.
nvidia-smi

# NOTE(review): the purpose of this short pause is not evident from this script.
sleep 1

# ---------------------------------------------------------------------------
# Model selection. Keep exactly one line uncommented; the rest are alternatives.
# ---------------------------------------------------------------------------
# model_name="mamba2-36-768"
model_name="mamba2-780m"
# model_name="mamba2-1.3b"
# model_name="mamba2-370m"
# model_name="mamba2-130m"
# model_name="mamba2-48-768"
# model_name="mamba2-768-12"
# model_name="mamba2-256-8"

model_config="configs/model/mamba2/${model_name}.json"  # Only used when rand_init == "1"
# model_config="../../ckpts/mamba/mamba2-780m/config.json"

pretrained_path="../../ckpts/mamba/${model_name}"  # Only used when rand_init != "1"
# pretrained_path="../../long-rnn/long-rnn/ckpts/mamba2-36-768/T8192_B4_GA2_P1_SR1_RD0_RI1_lr0.0003/ckpt_80000"  # Only used when rand_init != "1"
data_config="configs/data/redpajama_4k.json"

# ---------------------------------------------------------------------------
# Training hyper-parameters (all passed verbatim to train.py).
#
# NOTE(review): the token labels that used to annotate the step counts
# (n_train_steps "600B tokens", n_drop_steps "60B tokens", n_warmup_steps
# "500M tokens", save_interval "per 1B tokens") disagree with one another:
# they imply 1M, 1M, 100K and 500K tokens per step respectively, so at least
# some are stale. Tokens per optimizer step is presumably
# batch_size * max_length * grad_accum * num_processes (1 * 8192 * 1 * 8 =
# 65,536 here), possibly scaled by packing_count — TODO confirm against train.py.
# ---------------------------------------------------------------------------
batch_size="1"
max_length="8192"
grad_accum="1"
packing_count="8"
repeat_data="0"
state_reset_interval="64"
lr="5e-4"
n_train_steps="600000"
n_drop_steps="60000"
n_warmup_steps="5000"
rand_init="0"
save_interval="2000"
grad_ckpt="layer"

# ---------------------------------------------------------------------------
# Resuming. Keep exactly one resume_path line uncommented.
#
# NOTE(review): the active resume_path points at a mamba2-130m run
# (T16384, B4), while model_name above is mamba2-780m and max_length is 8192.
# The commented-out mamba2-780m checkpoint looks like the matching one, but
# load_start_step=100000 matches the 130m ckpt_100000. Confirm which
# model/checkpoint pair is intended before submitting.
# ---------------------------------------------------------------------------
# resume_path="output/mamba2-780m/mamba2-780m_lr0.0002_T8192_B1_GA2_P8_SR16_RD0_RI0/ckpt_42000"
resume_path="output/mamba2-130m/mamba2-130m_lr0.0005_T16384_B4_GA1_P1_SR1_RD0_RI0/ckpt_100000"
load_start_step="100000"

# Number of processes for `accelerate launch`. Slurm exports
# SLURM_GPUS_ON_NODE when GPUs are allocated to the job; otherwise fall back
# to 8, the value previously hard-coded here.
num_processes="${SLURM_GPUS_ON_NODE:-8}"

# NOTE(review): run directories appear to be named after the model
# (output/<model_name>/...), so a resume_path that does not mention the
# selected model most likely refers to an incompatible checkpoint.
if [[ -n "${resume_path}" && "${resume_path}" != *"${model_name}"* ]]; then
  printf 'WARNING: resume_path "%s" does not mention model "%s"\n' \
    "${resume_path}" "${model_name}" >&2
fi

# Build the launcher + training command as an array so every flag value stays
# exactly one argument. A string expanded unquoted would word-split paths with
# spaces, glob-expand '*'/'?', and drop empty values entirely (shifting the
# next flag into the empty flag's value slot).
cmd=(
  accelerate launch
  --num_processes "${num_processes}"
  --num_machines 1
  --mixed_precision no
  --dynamo_backend no

  train.py
  --model "${model_name}"
  --pretrained_path "${pretrained_path}"
  --model_config "${model_config}"
  --data_config "${data_config}"
  --batch_size "${batch_size}"
  --max_length "${max_length}"
  --grad_accum "${grad_accum}"
  --packing_count "${packing_count}"
  --repeat_data "${repeat_data}"
  --state_reset_interval "${state_reset_interval}"
  --lr "${lr}"
  --n_train_steps "${n_train_steps}"
  --n_drop_steps "${n_drop_steps}"
  --n_warmup_steps "${n_warmup_steps}"
  --rand_init "${rand_init}"
  --save_interval "${save_interval}"
  --grad_ckpt "${grad_ckpt}"
  --resume_path "${resume_path}"
  --load_start_step "${load_start_step}"

  --output_dir ./output
  --tensorboard ./tensorboard
)

# Log the exact command, shell-quoted so it can be copy-pasted to reproduce.
echo "=================="
printf '%q ' "${cmd[@]}"
printf '\n'
echo "=================="

# Run training; the script's exit status is the training command's status.
"${cmd[@]}"
