
# # # Part1: 60m, c4, DST-[CHTs(BRF, r=0.25), RigL, MEST, SET, GMP, GraNet], lr=3e-3

# # sparsities=(0.7)
# # # sparsities=(0.8)

# # granet_init_sparsity=0.5
# # # sparsities=(0.7)

# # delta=0.25
# # update_interval=100
# # pruning_method="ri"
# # remove_method="weight_magnitude_soft"
# # sparsity_distribution="uniform"
# # zeta=0.1
# # iterative_warmup_steps=20

# # config_file="configs/llama_60m.json"
# # num_training_steps=10000
# # warmup_steps=1000
# # lr=3e-3

# # for dataset in "c4"
# # do

# # for sparsity in "${sparsities[@]}"
# # do
# #   # CHTs + az
# #   HF_HOME="/data/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "CHTs_a_nosl" --model_config $config_file --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity $sparsity --only_save_last --dst_scheduler --remove_method $remove_method --regrow_method CH2_L3n_soft --zeta $zeta --adaptive_zeta --WS3 --delta $delta --degree_dist uniform --no_log --log_to_file --save_dir checkpoints/ --only_save_last --history_weights --new_history_weights --start_T 1 --end_T 3

# #   # RigL
# #   HF_HOME="/data/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "rigl_nosl" --model_config $config_file --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsity} --zeta $zeta --only_save_last --dst_scheduler --remove_method weight_magnitude --regrow_method gradient --adaptive_zeta --no_log --log_to_file --save_dir checkpoints --only_save_last

# #   # MEST
# #   HF_HOME="/data/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "MEST_nosl" --model_config $config_file --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsity} --zeta $zeta --only_save_last --dst_scheduler --remove_method MEST --regrow_method random --EM_S --no_log --log_to_file --save_dir checkpoints --only_save_last

# #   # SET
# #   HF_HOME="/data/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "SET_nosl" --model_config $config_file --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsity} --zeta $zeta --only_save_last --dst_scheduler --remove_method weight_magnitude --regrow_method random --no_log --log_to_file --save_dir checkpoints --only_save_last

# #   # # GMP
# #   # HF_HOME="/data/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "gmp_nosl" --model_config $config_file --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsity} --zeta $zeta --only_save_last --dst_scheduler --gmp --granet_init_sparsity $granet_init_sparsity --sparsity_distribution uniform --pruning_method weight_magnitude --pruning_scheduler granet --history_weights --no_log --log_to_file --save_dir checkpoints --only_save_last

# #   # # GraNet
# #   # HF_HOME="/data/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "granet_nosl" --model_config $config_file --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsity} --zeta $zeta --only_save_last --dst_scheduler --remove_method weight_magnitude --regrow_method gradient --granet --granet_init_sparsity $granet_init_sparsity --sparsity_distribution uniform --pruning_method weight_magnitude --pruning_scheduler granet --adaptive_zeta --no_log --log_to_file --save_dir checkpoints --only_save_last


# # done
# # done




# # # Part2: 130m, [openwebtext, c4], DST-[CHTs(BRF, r=0.25), RigL, MEST, SET, GMP, GraNet], lr=3e-3 (NOTE: the dataset loop below currently runs c4 only)

# # sparsities=(0.7)
# # # sparsities=(0.8)

# # granet_init_sparsity=0.5
# # # sparsities=(0.7)

# # delta=0.25
# # update_interval=100
# # pruning_method="ri"
# # remove_method="weight_magnitude_soft"
# # sparsity_distribution="uniform"
# # zeta=0.1
# # iterative_warmup_steps=20

# # config_file="configs/llama_130m.json"
# # num_training_steps=20000
# # warmup_steps=2000
# # lr=3e-3

# # for dataset in "c4"
# # do

# # for sparsity in "${sparsities[@]}"
# # do
# # #   # CHTs + az
# # #   HF_HOME="/data/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "CHTs_a_nosl" --model_config $config_file --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity $sparsity --only_save_last --dst_scheduler --remove_method $remove_method --regrow_method CH2_L3n_soft --zeta $zeta --adaptive_zeta --WS3 --delta $delta --degree_dist uniform --no_log --log_to_file --save_dir checkpoints/ --only_save_last --history_weights --new_history_weights --start_T 1 --end_T 3

# # #   # RigL
# # #   HF_HOME="/data/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "rigl_nosl" --model_config $config_file --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsity} --zeta $zeta --only_save_last --dst_scheduler --remove_method weight_magnitude --regrow_method gradient --adaptive_zeta --no_log --log_to_file --save_dir checkpoints --only_save_last

# # #   # MEST
# # #   HF_HOME="/data/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "MEST_nosl" --model_config $config_file --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsity} --zeta $zeta --only_save_last --dst_scheduler --remove_method MEST --regrow_method random --EM_S --no_log --log_to_file --save_dir checkpoints --only_save_last

# #   # SET
# #   HF_HOME="/data/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "SET_nosl" --model_config $config_file --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsity} --zeta $zeta --only_save_last --dst_scheduler --remove_method weight_magnitude --regrow_method random --no_log --log_to_file --save_dir checkpoints --only_save_last

# #   # # GMP
# #   # HF_HOME="/data/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "gmp_nosl" --model_config $config_file --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsity} --zeta $zeta --only_save_last --dst_scheduler --gmp --granet_init_sparsity $granet_init_sparsity --sparsity_distribution uniform --pruning_method weight_magnitude --pruning_scheduler granet --history_weights --no_log --log_to_file --save_dir checkpoints --only_save_last

# #   # # GraNet
# #   # HF_HOME="/data/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "granet_nosl" --model_config $config_file --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsity} --zeta $zeta --only_save_last --dst_scheduler --remove_method weight_magnitude --regrow_method gradient --granet --granet_init_sparsity $granet_init_sparsity --sparsity_distribution uniform --pruning_method weight_magnitude --pruning_scheduler granet --adaptive_zeta --no_log --log_to_file --save_dir checkpoints --only_save_last


# # done
# # done


# # Part3: 130m, openwebtext and c4, Full (lr=1e-3)

# sparsities=(0.9 0.8 0.7)
# # sparsities=(0.8)

# granet_init_sparsity=0.5
# # sparsities=(0.7)

# delta=0.25
# update_interval=100
# pruning_method="ri"
# remove_method="weight_magnitude_soft"
# sparsity_distribution="uniform"
# zeta=0.1
# iterative_warmup_steps=20

# config_file="configs/llama_130m.json"
# num_training_steps=20000
# warmup_steps=2000
# lr=1e-3

# for dataset in "openwebtext" "c4"
# do

#   # Dense
#   HF_HOME="/data/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "dense" --model_config $config_file --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --only_save_last --no_log --log_to_file --save_dir checkpoints/ --only_save_last --scheduler cosine

# done


# Part4: 60m, c4, Full (dense baseline, lr=1e-3)
#
# Launches one dense LLaMA-60M training run per dataset in the loop below,
# using torch.distributed.run on 8 local GPUs (CUDA_VISIBLE_DEVICES=0..7).
# The sparsity / DST variables (sparsities, granet_init_sparsity, delta,
# pruning_method, remove_method, sparsity_distribution, zeta) are NOT passed
# to the dense command. They are kept so this part mirrors the commented-out
# DST parts above and can be switched back easily.

sparsities=(0.9 0.8 0.7)
# sparsities=(0.8)

granet_init_sparsity=0.5
# sparsities=(0.7)

delta=0.25
update_interval=100
pruning_method="ri"
remove_method="weight_magnitude_soft"
sparsity_distribution="uniform"
zeta=0.1
iterative_warmup_steps=20

config_file="configs/llama_60m.json"
num_training_steps=10000
warmup_steps=1000
lr=1e-3

for dataset in "c4"; do
  # Dense
  # Flags are collected in an array so every value reaches torchrun_main.py as
  # exactly one argument (no word-splitting/globbing of unquoted expansions).
  # "--only_save_last" used to be passed twice; a single occurrence suffices.
  dense_args=(
    --run_name "dense"
    --model_config "$config_file"
    --dataset_name "$dataset"
    --lr "$lr"
    --batch_size 64
    --total_batch_size 512
    --num_training_steps "$num_training_steps"
    --warmup_steps "$warmup_steps"
    --weight_decay 0
    --dtype bfloat16
    --eval_every 1000
    --optimizer adam
    --iterative_warmup_steps "$iterative_warmup_steps"
    --update_interval "$update_interval"
    --only_save_last
    --no_log
    --log_to_file
    --save_dir checkpoints/
    --scheduler cosine
  )

  HF_HOME="/data/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
    python -m torch.distributed.run --standalone --nproc_per_node 8 \
    torchrun_main.py "${dense_args[@]}"
done















# ### 130M: num_training_steps=40000, warmup_steps=10000, --model_config configs/llama_130m.json (NOTE: the variables below currently use num_training_steps=20000, warmup_steps=2000)
# ws_beta=0
# update_interval=100
# pruning_method="ri"
# remove_method="weight_magnitude_soft"
# sparsity_distribution="uniform"
# zeta=0.1
# iterative_warmup_steps=20

# model_config="configs/llama_130m.json"
# num_training_steps=20000
# warmup_steps=2000

# datasets=("c4")



# lrs=(3e-3)
# las=(0)

# sparsities=(0.95 0.95 0.9 0.85 0.95 0.9 0.85 0.8 0.75)
# ranks=(24 72 48 24 120 96 72 48 24)

# for dataset in "${datasets[@]}"
# do
#     for lr in "${lrs[@]}"
#     do
#         for la in "${las[@]}"
#         do
#             for ((i=0; i<${#sparsities[@]}; i++)); do
#                 # static
#                 HF_HOME="/modelopsnas/modelops/463248/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "static_sp${sparsities[i]}+r${ranks[i]}+la$la" --model_config $model_config --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsities[i]} --only_save_last --dst_scheduler --static_dst --no_log --log_to_file --save_dir checkpoints/ --only_save_last --sltrain --rank ${ranks[i]} --lora_alpha $la --no_compute_similarity

#             done
#         done
#     done

# done
