

# Define parameters for each array task.
# Element i of MODEL_NAMES, MAX_ITERS_LIST and LEARNING_RATE_LIST together
# describe array task i, so the three arrays must have the same length and the
# sbatch --array range must stay within 0..(length-1).
MODEL_NAMES=("pythia-6.9b")
MAX_ITERS_LIST=(6573397)
LEARNING_RATE_LIST=(1e-4)
FLOPS=4e22

# Select parameters based on the array index. Outside a job array
# SLURM_ARRAY_TASK_ID is unset; fall back to task 0 explicitly instead of
# relying on bash evaluating an empty subscript as 0.
TASK_INDEX=${SLURM_ARRAY_TASK_ID:-0}
NUM_TASKS=${#MODEL_NAMES[@]}

# Mismatched list lengths would give some tasks empty values.
if (( ${#MAX_ITERS_LIST[@]} != NUM_TASKS || ${#LEARNING_RATE_LIST[@]} != NUM_TASKS )); then
  printf 'error: MODEL_NAMES, MAX_ITERS_LIST and LEARNING_RATE_LIST must have the same length\n' >&2
  exit 1
fi

# Reject non-numeric or out-of-range task ids. Otherwise the lookups below
# expand to empty strings and training would launch with missing arguments.
# 10# forces base-10 so a zero-padded id is not parsed as octal.
if [[ ! "$TASK_INDEX" =~ ^[0-9]+$ ]] || (( 10#$TASK_INDEX >= NUM_TASKS )); then
  printf 'error: array task index %s is out of range 0..%d\n' \
    "$TASK_INDEX" "$(( NUM_TASKS - 1 ))" >&2
  exit 1
fi
TASK_INDEX=$(( 10#$TASK_INDEX ))

MODEL_NAME=${MODEL_NAMES[$TASK_INDEX]}
MAX_ITERS=${MAX_ITERS_LIST[$TASK_INDEX]}
LEARNING_RATE=${LEARNING_RATE_LIST[$TASK_INDEX]}


# Static parameters: identical for every array task.

# Data location and evaluation/logging cadence.
DATA_DIR="/data/nhird_hf_rope"
EVAL_ITERS=200
LOG_INTERVAL=50

# Batch sizes, passed through unchanged as --batch_size / --micro_batch_size.
BATCH_SIZE=64
MICRO_BATCH_SIZE=2

# Optimizer-related settings.
BETA1=0.9
BETA2=0.95
WEIGHT_DECAY=1e-1
GRAD_CLIP=1.0

# Learning-rate schedule (warmup / stable / decay phases; the three ratios sum
# to 1.0 — presumably fractions of MAX_ITERS, confirm in nhird_flop.py).
DECAY_LR=true
WARMUP_RATIO=0.1
STABLE_RATIO=0.8
DECAY_RATIO=0.1

# Hardware layout, numeric precision and RNG seed.
NODES=1
DEVICES=8
PRECISION="bf16-mixed"
SEED=1337

# Run-specific identifiers derived from the selected model and FLOP budget.
# The checkpoint directory is keyed by FLOP budget, then by model name.
OUT_DIR="/pretrained_ckpt/flop_${FLOPS}/${MODEL_NAME}"

# The run name is exported so the training process inherits it.
WANDB_NAME="Foundation Model Pretraining ${MODEL_NAME} seq = 2048 WSD flop=${FLOPS} single_node"
export WANDB_NAME

# Print the chosen configuration to stdout for debugging.
printf 'Running %s with MAX_ITERS=%s, LEARNING_RATE=%s, MICRO_BATCH_SIZE=%s\n' \
  "$MODEL_NAME" "$MAX_ITERS" "$LEARNING_RATE" "$MICRO_BATCH_SIZE"

# Run the training script.
# Arguments are collected in an array and every expansion is quoted. An empty
# or whitespace-containing value therefore stays exactly one argument instead
# of vanishing and shifting the remaining flag/value pairs (ShellCheck SC2086).
TRAIN_ARGS=(
    --model_name "$MODEL_NAME"
    --out_dir "$OUT_DIR"
    --data_dir "$DATA_DIR"
    --eval_iters "$EVAL_ITERS"
    --log_interval "$LOG_INTERVAL"
    --learning_rate "$LEARNING_RATE"
    --batch_size "$BATCH_SIZE"
    --micro_batch_size "$MICRO_BATCH_SIZE"
    --max_iters "$MAX_ITERS"
    --weight_decay "$WEIGHT_DECAY"
    --beta1 "$BETA1"
    --beta2 "$BETA2"
    --grad_clip "$GRAD_CLIP"
    --decay_lr "$DECAY_LR"
    --warmup_ratio "$WARMUP_RATIO"
    --stable_ratio "$STABLE_RATIO"
    --decay_ratio "$DECAY_RATIO"
    --devices "$DEVICES"
    --nodes "$NODES"
    --precision "$PRECISION"
    --seed "$SEED"
)
srun python3 pretrain/nhird_flop.py "${TRAIN_ARGS[@]}"
