#!/bin/bash

# Launcher setup: validate the command line and choose the visible GPUs.
#
# Arguments:
#   $1 - number of GPUs to use (positive integer)

# Require exactly one argument: the number of GPUs.
if [ "$#" -ne 1 ]; then
    echo "Usage: $0 <num_gpus>" >&2
    echo "Example: $0 2" >&2
    exit 1
fi

NUM_GPUS="$1"

# NUM_GPUS is used in shell arithmetic and as the process count further down,
# so reject anything that is not a positive decimal integer. Leading zeros are
# rejected as well because $(( )) would parse them as octal.
case "$NUM_GPUS" in
    '' | 0* | *[!0-9]*)
        echo "Error: <num_gpus> must be a positive integer, got '$NUM_GPUS'" >&2
        exit 1
        ;;
esac

# Generate CUDA_VISIBLE_DEVICES as a range from 0 to NUM_GPUS-1 so the visible
# device list always matches the requested GPU count.
CUDA_VISIBLE_DEVICES=0
gpu_index=1
while [ "$gpu_index" -lt "$NUM_GPUS" ]; do
    CUDA_VISIBLE_DEVICES="$CUDA_VISIBLE_DEVICES,$gpu_index"
    gpu_index=$((gpu_index + 1))
done
unset gpu_index
export CUDA_VISIBLE_DEVICES

# Export both W&B variables as empty strings.
# NOTE(review): presumably placeholders for the Weights & Biases project and
# entity used by the training run - fill in before launching; confirm.
WANDB_PROJECT=""
WANDB_ENTITY=""
export WANDB_PROJECT WANDB_ENTITY

# Launch one single-node torchrun job per tokens-per-batch value.
#
# Reads:
#   NUM_GPUS            set earlier in this script from the first argument
#   MODEL_NAME_OR_PATH  (required, environment) value for --model_name_or_path
#   RUN_NAME            (optional, environment) value for --run_name; the flag
#                       is left out entirely when this is unset or empty
#
# A flag that expects a value must not be followed directly by the next flag,
# so the model is required up front instead of being passed as a bare flag.
: "${MODEL_NAME_OR_PATH:?set MODEL_NAME_OR_PATH to the value for --model_name_or_path}"

# NOTE(review): 131071 is one less than 131072 (2^17); the integer divisions
# below floor it to 63 sequences. Confirm the value is intended.
for tpb in 131071; do

    TOKENS_PER_BATCH=$tpb

    MAX_LENGTH=2048
    BATCH_SIZE=8

    # Sequences per batch, then accumulation steps per device. Both divisions
    # are integer divisions and round down.
    TOTAL_BATCH_SIZE=$((TOKENS_PER_BATCH / MAX_LENGTH))
    GRADIENT_ACC_STEPS=$((TOTAL_BATCH_SIZE / NUM_GPUS / BATCH_SIZE))

    # The flooring can reach 0 (for example 63 / 8 / 8), which is not a usable
    # accumulation step count, so stop before starting the job.
    if [ "$GRADIENT_ACC_STEPS" -lt 1 ]; then
        echo "Error: gradient accumulation steps computed as $GRADIENT_ACC_STEPS" \
            "(TOKENS_PER_BATCH=$TOKENS_PER_BATCH MAX_LENGTH=$MAX_LENGTH" \
            "NUM_GPUS=$NUM_GPUS BATCH_SIZE=$BATCH_SIZE);" \
            "lower BATCH_SIZE or NUM_GPUS, or raise TOKENS_PER_BATCH" >&2
        exit 1
    fi

    # Report when rounding makes the realised batch differ from the target.
    EFFECTIVE_TOKENS=$((GRADIENT_ACC_STEPS * NUM_GPUS * BATCH_SIZE * MAX_LENGTH))
    if [ "$EFFECTIVE_TOKENS" -ne "$TOKENS_PER_BATCH" ]; then
        echo "Warning: requested $TOKENS_PER_BATCH tokens per batch, but" \
            "$GRADIENT_ACC_STEPS x $NUM_GPUS x $BATCH_SIZE x $MAX_LENGTH" \
            "= $EFFECTIVE_TOKENS" >&2
    fi

    # echo $TOKENS_PER_BATCH $MAX_LENGTH $BATCH_SIZE $TOTAL_BATCH_SIZE $GRADIENT_ACC_STEPS
    # ${RUN_NAME:+...} expands to the flag and its value only when RUN_NAME is
    # non-empty, and to nothing otherwise.
    torchrun --nproc_per_node="$NUM_GPUS" --nnodes=1 --master-port=12345 routing/train_routing.py \
            --model_name_or_path "$MODEL_NAME_OR_PATH" \
            --dataset_name HuggingFaceFW/fineweb-edu \
            --dataset_config_name sample-10BT \
            --ddp_timeout 1000000 \
            --per_device_train_batch_size "$BATCH_SIZE" \
            --per_device_eval_batch_size "$BATCH_SIZE" \
            --do_train True \
            --do_eval True \
            --do_predict False \
            --model_max_position_embeddings "$MAX_LENGTH" \
            --output_dir outputs \
            --gradient_accumulation_steps "$GRADIENT_ACC_STEPS" \
            --evaluation_strategy "steps" \
            --eval_steps 200000 \
            --save_steps 5000 \
            --num_train_epochs 1 \
            --seed=2222 \
            ${RUN_NAME:+--run_name "$RUN_NAME"} \
            --save_total_limit 3 \
            --learning_rate 4e-4 \
            --weight_decay 0.01 \
            --lr_scheduler_type cosine \
            --warmup_ratio 0.1 \
            --fp16 True \
            --num_levels 3 \
            --beam_width 8 \
            --local_window_size 64
done