#!/usr/bin/env bash
# Train one MoE gate per LoRA checkpoint in the sheared-steps sweep.
#
# For each step count in `numbers`, this launches an 8-process torchrun job of
# train_moe_gate.py with the DeepSpeed config ds_config/ds_z2_config.json.
# Inputs and outputs per job:
#   data:   data/<step>.json
#   LoRA:   output/sheared-steps/checkpoint-<step>
#   output: output/sheared-steps/moe-gate-<k>, where k is the 1-based position
#           of <step> in `numbers`
# After each successful run, every directory named 'global*' under $base_path
# is removed, presumably DeepSpeed global_step* states, to save disk space
# (TODO confirm).
#
# NOTE(review): data/ and output/ are relative paths, so run this from the
# directory that contains them (presumably .../MoE, the parent of $base_path).
#
# Fails fast: if a run fails, the script stops before the cleanup step, so that
# run's DeepSpeed state is kept for inspection or resumption.
set -euo pipefail

readonly pretrained_path=/root/paddlejob/workspace/env_run/huitingfeng/models/llama-2.7b-hf
readonly base_path=/root/paddlejob/workspace/env_run/huitingfeng/MoE/output
readonly all_name_or_path=/root/paddlejob/workspace/env_run/huitingfeng/MoE/output/sheared-steps

readonly -a numbers=(12200 13176 14640 15616)

die() {
    printf 'error: %s\n' "$*" >&2
    exit 1
}

i=1

for num in "${numbers[@]}"; do
    output_path="output/sheared-steps/moe-gate-$i"
    data_path="data/$num.json"
    lora_name_or_path="output/sheared-steps/checkpoint-$num"

    # Check inputs before paying for a multi-GPU launch.
    [[ -f "$data_path" ]] || die "missing training data: $data_path"
    [[ -d "$lora_name_or_path" ]] || die "missing LoRA checkpoint: $lora_name_or_path"

    torchrun --nproc_per_node=8 --nnodes=1 --master_port=4741 train_moe_gate.py \
        --deepspeed ds_config/ds_z2_config.json \
        --model_name_or_path "$pretrained_path" \
        --data_path "$data_path" \
        --lora_name_or_path "$lora_name_or_path" \
        --all_name_or_path "$all_name_or_path" \
        --bf16 True \
        --model_max_length 2048 \
        --output_dir "$output_path" \
        --num_train_epochs 4 \
        --per_device_train_batch_size 1 \
        --gradient_accumulation_steps 16 \
        --evaluation_strategy "no" \
        --save_strategy "epoch" \
        --logging_strategy "steps" \
        --lr_scheduler_type "constant" \
        --save_total_limit 1 \
        --overwrite_output_dir \
        --logging_steps 1 \
        --learning_rate 2e-5 \
        --weight_decay 0.0 \
        --warmup_steps 0 \
        --tf32 True \
        || die "training failed for step $num (output: $output_path); 'global*' dirs left in place"

    # Plain assignment rather than ((i++)): it never returns a non-zero status,
    # so it is safe under set -e.
    i=$((i + 1))

    # Remove all 'global*' directories under $base_path.
    #   ${base_path:?} aborts if the variable is ever empty, instead of
    #     scanning the CWD.
    #   -prune stops find from descending into a directory that rm is about
    #     to delete, which avoids spurious "No such file" failures.
    #   -f keeps rm from prompting on write-protected files.
    find "${base_path:?}" -type d -name 'global*' -prune -exec rm -rf -- {} +
done