#!/bin/bash

# Require exactly one argument: the number of GPUs to train on.
# Usage/diagnostic text goes to stderr so it never mixes with regular output.
if [ "$#" -ne 1 ]; then
    echo "Usage: $0 <num_gpus>" >&2
    echo "Example: $0 2" >&2
    exit 1
fi

NUM_GPUS="$1"

# NUM_GPUS is used below as a divisor in shell arithmetic and as the upper
# bound of the CUDA device list, so it must be a positive decimal integer.
# Rejected: empty strings, anything containing a non-digit, and values with a
# leading zero (this covers "0" itself and avoids octal interpretation of
# inputs such as "08").
case "$NUM_GPUS" in
    '' | *[!0-9]* | 0*)
        echo "Error: <num_gpus> must be a positive integer, got '$NUM_GPUS'" >&2
        exit 1
        ;;
esac

# Build the comma-separated device list "0,1,...,NUM_GPUS-1" and export it as
# CUDA_VISIBLE_DEVICES. The list is empty when NUM_GPUS evaluates to 0 or less.
gpu_ids=""
gpu=0
while [ "$gpu" -lt "$((NUM_GPUS))" ]; do
    # Prefix a comma only once the list already has an entry.
    gpu_ids="${gpu_ids:+${gpu_ids},}${gpu}"
    gpu=$((gpu + 1))
done
export CUDA_VISIBLE_DEVICES="$gpu_ids"

# Weights & Biases project/entity exported for the training processes.
export WANDB_PROJECT="name" WANDB_ENTITY="name"

# Sweep over target global batch sizes (measured in tokens) and launch one
# single-node torchrun job per value.
#
# For each value the gradient accumulation steps are derived as
#   TOKENS_PER_BATCH / MAX_LENGTH / NUM_GPUS / BATCH_SIZE
# using integer (truncating) division.

# Loop-invariant settings: sequence length and per-device batch size.
MAX_LENGTH=8192
BATCH_SIZE=8

# NOTE(review): 131071 is one less than 2^17 (131072), so it is not a multiple
# of MAX_LENGTH and the division below truncates. It may be a typo, but the
# value is also embedded in the checkpoint paths passed to torchrun, so it is
# kept as-is here - confirm against the existing output directories.
for tpb in 524288 131071; do

    TOKENS_PER_BATCH=$tpb

    TOTAL_BATCH_SIZE=$((TOKENS_PER_BATCH / MAX_LENGTH))
    GRADIENT_ACC_STEPS=$((TOTAL_BATCH_SIZE / NUM_GPUS / BATCH_SIZE))

    # Truncating division can yield 0 (e.g. too many GPUs for the requested
    # token budget). Fail with a clear message instead of handing a
    # non-positive accumulation count to the trainer.
    if [ "${GRADIENT_ACC_STEPS:-0}" -lt 1 ]; then
        echo "Error: tokens_per_batch=${TOKENS_PER_BATCH} is too small for" \
             "NUM_GPUS=${NUM_GPUS}, BATCH_SIZE=${BATCH_SIZE}," \
             "MAX_LENGTH=${MAX_LENGTH}: gradient accumulation steps would be" \
             "${GRADIENT_ACC_STEPS:-0}" >&2
        exit 1
    fi

    # Tell the user when rounding makes the real token batch differ from the
    # requested one; the run still proceeds with the rounded-down value.
    EFFECTIVE_TOKENS=$((GRADIENT_ACC_STEPS * NUM_GPUS * BATCH_SIZE * MAX_LENGTH))
    if [ "$EFFECTIVE_TOKENS" -ne "$TOKENS_PER_BATCH" ]; then
        echo "Warning: requested ${TOKENS_PER_BATCH} tokens per batch, but" \
             "the derived settings give ${EFFECTIVE_TOKENS}" >&2
    fi

    # The model is loaded from the "...-tbp_<tokens>" directory and results
    # are written to a sibling "...-tuned_<max_length>_<tokens>" directory.
    torchrun --nproc_per_node="${NUM_GPUS}" --nnodes=1 --master-port=12345 routing/train_routing.py \
            --model_name_or_path "outputs/fineweb/pythia-410m-routing_from_scratch-tbp_${TOKENS_PER_BATCH}" \
            --dataset_name HuggingFaceFW/fineweb-edu \
            --dataset_config_name sample-10BT \
            --ddp_timeout 1000000 \
            --per_device_train_batch_size "${BATCH_SIZE}" \
            --per_device_eval_batch_size "${BATCH_SIZE}" \
            --do_train True \
            --do_eval True \
            --do_predict False \
            --model_max_position_embeddings "${MAX_LENGTH}" \
            --output_dir "outputs/fineweb/pythia-410m-routing_from_scratch-tbp_${TOKENS_PER_BATCH}-tuned_${MAX_LENGTH}_${TOKENS_PER_BATCH}" \
            --gradient_accumulation_steps "${GRADIENT_ACC_STEPS}" \
            --evaluation_strategy "steps" \
            --eval_steps 1000 \
            --save_steps 5000 \
            --num_train_epochs 1 \
            --seed=2222 \
            --run_name nips \
            --save_total_limit 3 \
            --learning_rate 4e-4 \
            --weight_decay 0.01 \
            --lr_scheduler_type cosine \
            --warmup_ratio 0.1 \
            --local_window_size 64 --num_levels 5 --beam_width 4 --causal
done