# sparsities=(0.8 0.9 0.95)
# deltas=(0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1)
deltas=(0.7)
sparsities=(0.8)
update_interval=100
pruning_method="ri"
remove_method="weight_magnitude_soft"
sparsity_distribution="uniform"
lr=3e-3
zeta=0.1
iterative_warmup_steps=10

for sparsity in ${sparsities[@]}
do

    for delta in ${deltas[@]}
    do
        # # CHTs + az
        CUDA_VISIBLE_DEVICES=0,1,2,8,4,5,6,7 python3 -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "llama60m" --model_config configs/llama_60m.json --dataset_name openwebtext --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps 10000 --warmup_steps 1000 --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity $sparsity --only_save_last --dst_scheduler --remove_method $remove_method --regrow_method CH2_L3n_soft --zeta $zeta --adaptive_zeta --WS3 --delta $delta --log_to_file --save_dir checkpoints/ --only_save_last --degree_dist uniform --start_T 1 --end_T 9 --no_log --evolution --evolution_strategy lora --evolution_every 1 --sltrain

        # # s-shape
        # CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "llama60m" --model_config configs/llama_60m.json --dataset_name c4 --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps 10000 --warmup_steps 1000 --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity $sparsity --only_save_last --dst_scheduler --remove_method $remove_method --regrow_method CH2_L3n_soft --granet --granet_init_sparsity 0.5 --sparsity_distribution $sparsity_distribution --pruning_method $pruning_method --pruning_scheduler s_shape --zeta $zeta --adaptive_zeta --WS3 --delta $delta --pruning_T_end 8000 --log_to_file --save_dir checkpoints/ --only_save_last --degree_dist uniform --start_T 1.0 --end_T 9.0

        # # granet
        # CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "llama60m" --model_config configs/llama_60m.json --dataset_name c4 --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps 10000 --warmup_steps 1000 --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity $sparsity --only_save_last --dst_scheduler --remove_method $remove_method --regrow_method CH2_L3n_soft --granet --granet_init_sparsity 0.5 --sparsity_distribution $sparsity_distribution --pruning_method $pruning_method --pruning_scheduler granet --zeta $zeta --adaptive_zeta --WS3 --delta $delta --pruning_T_end 8000 --log_to_file --save_dir checkpoints/ --only_save_last --degree_dist uniform --start_T 1 --end_T 9
    done


    # for delta in ${deltas[@]}
    # do
    #     # # CHTs + az
    #     CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "llama60m" --model_config configs/llama_60m.json --dataset_name c4 --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps 10000 --warmup_steps 1000 --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity $sparsity --only_save_last --dst_scheduler --remove_method $remove_method --regrow_method CH2_L3n_soft --zeta $zeta --adaptive_zeta --WS --ws_beta $delta --log_to_file --save_dir checkpoints/ --only_save_last --start_T 1.0 --end_T 9.0

    #     # s-shape
    #     CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "llama60m" --model_config configs/llama_60m.json --dataset_name c4 --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps 10000 --warmup_steps 1000 --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity $sparsity --only_save_last --dst_scheduler --remove_method $remove_method --regrow_method CH2_L3n_soft --granet --granet_init_sparsity 0.5 --sparsity_distribution $sparsity_distribution --pruning_method $pruning_method --pruning_scheduler s_shape --zeta $zeta --adaptive_zeta --WS --ws_beta $delta --pruning_T_end 8000 --log_to_file --save_dir checkpoints/ --only_save_last --start_T 1.0 --end_T 9.0

    #     # # granet
    #     # CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "llama60m" --model_config configs/llama_60m.json --dataset_name c4 --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps 10000 --warmup_steps 1000 --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity $sparsity --only_save_last --dst_scheduler --remove_method $remove_method --regrow_method CH2_L3n_soft --granet --granet_init_sparsity 0.5 --sparsity_distribution $sparsity_distribution --pruning_method $pruning_method --pruning_scheduler granet --zeta $zeta --adaptive_zeta --WS --ws_beta $delta --pruning_T_end 8000 --log_to_file --save_dir checkpoints/ --only_save_last --start_T 1 --end_T 9
    # done
done