### 350M: num_training_steps=60000, warmup_steps=6000, model_config=configs/llama_350m.json
# ---------------------------------------------------------------------------
# Sweep configuration. These are plain shell variables/arrays that the launch
# loop further down in this file expands into torchrun_main.py flags.
# ---------------------------------------------------------------------------

# Model definition and schedule length (--model_config, --num_training_steps,
# --warmup_steps).
model_config="configs/llama_350m.json"
num_training_steps="60000"
warmup_steps="6000"

# Scalar settings forwarded by the active launch command.
update_interval="100"                  # --update_interval
iterative_warmup_steps="20"            # --iterative_warmup_steps
remove_method="weight_magnitude_soft"  # --remove_method
zeta="0.1"                             # --zeta
ws_beta="0"                            # --ws_beta

# Within the launch loop below, only the commented-out CHTss commands
# reference these two (--pruning_method, --sparsity_distribution).
pruning_method="ri"
sparsity_distribution="uniform"

# Sweep axes: the launch loop iterates over every entry of each array.
datasets=(openwebtext)  # --dataset_name
lrs=("3e-3")            # --lr
las=("0")               # --lora_alpha
alws=("0.5")            # --aloss_weight

# sparsities[i] and ranks[i] are consumed together (same index i), so both
# arrays must have the same number of entries.
#
# Alternative grids kept for reference:
# sparsities=(0.95 0.95 0.9 0.85 0.95 0.9 0.85 0.8 0.75)
# # ranks=(24 72 48 24 120 96 72 48 24)
# ranks=(32 96 64 32 160 128 96 64 32)
# sparsities=(0.95)
# ranks=(32)
sparsities=("0.95" "0.9" "0.85")
ranks=("32" "64" "96")


for dataset in "${datasets[@]}"
do
    for lr in "${lrs[@]}"
    do
        for la in "${las[@]}"
        do
            for ((i=0; i<${#sparsities[@]}; i++)); do

                # # CHTs+az
                # HF_HOME="/data/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "CHTs_a_sp${sparsities[i]}+r${ranks[i]}+la$la" --model_config $model_config --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsities[i]} --only_save_last --dst_scheduler --remove_method $remove_method --regrow_method CH2_L3n_soft --zeta $zeta --adaptive_zeta --WS --ws_beta $ws_beta --no_log --log_to_file --save_dir checkpoints/ --only_save_last --sltrain --rank ${ranks[i]} --start_T 1 --end_T 3 --lora_alpha $la --no_compute_similarity

                # # CHTs+az
                # HF_HOME="/modelopsnas/modelops/463248/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "CHTs_a_sp${sparsities[i]}+r${ranks[i]}+la$la" --model_config $model_config --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsities[i]} --only_save_last --dst_scheduler --remove_method $remove_method --regrow_method CH2_L3n_soft --zeta $zeta --adaptive_zeta --WS --ws_beta $ws_beta --no_log --log_to_file --save_dir checkpoints/ --only_save_last --sltrain --rank ${ranks[i]} --start_T 1 --end_T 3 --lora_alpha $la --act_lora --no_compute_similarity

                # # CHTss
                # HF_HOME="/modelopsnas/modelops/463248/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "CHTs_s_sp${sparsities[i]}+r${ranks[i]}+la$la" --model_config $model_config --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsities[i]} --only_save_last --dst_scheduler --remove_method $remove_method --regrow_method CH2_L3n_soft --granet --granet_init_sparsity 0.5 --sparsity_distribution $sparsity_distribution --pruning_method $pruning_method --pruning_scheduler s_shape --zeta $zeta --adaptive_zeta --WS --ws_beta $ws_beta --pruning_T_end 15000 --no_log --log_to_file --save_dir checkpoints/ --only_save_last --sltrain --rank ${ranks[i]} --start_T 1 --end_T 3 --lora_alpha $la --act_lora
                
                for alw in "${alws[@]}"
                do

                    # # CHTs+az
                    # HF_HOME="/modelopsnas/modelops/463248/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "CHTs_a_sp${sparsities[i]}+r${ranks[i]}+la$la" --model_config $model_config --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsities[i]} --only_save_last --dst_scheduler --remove_method $remove_method --regrow_method CH2_L3n_soft --zeta $zeta --adaptive_zeta --WS --ws_beta $ws_beta --no_log --log_to_file --save_dir checkpoints/ --only_save_last --sltrain --rank ${ranks[i]} --start_T 1 --end_T 3 --lora_alpha $la --act_lora

                    # CHTs+az
                    HF_HOME="/data/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "CHTs_a_sp${sparsities[i]}+r${ranks[i]}+la$la" --model_config $model_config --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsities[i]} --only_save_last --dst_scheduler --remove_method $remove_method --regrow_method CH2_L3n_soft --zeta $zeta --adaptive_zeta --WS --ws_beta $ws_beta --no_log --log_to_file --save_dir checkpoints/ --only_save_last --sltrain --rank ${ranks[i]} --start_T 1 --end_T 3 --lora_alpha $la --act_lora --aloss "alignment_wx" --aloss_weight $alw --no_compute_similarity

                    # # CHTs+az
                    # HF_HOME="/modelopsnas/modelops/463248/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "CHTs_a_sp${sparsities[i]}+r${ranks[i]}+la$la" --model_config $model_config --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsities[i]} --only_save_last --dst_scheduler --remove_method $remove_method --regrow_method CH2_L3n_soft --zeta $zeta --adaptive_zeta --WS --ws_beta $ws_beta --no_log --log_to_file --save_dir checkpoints/ --only_save_last --sltrain --rank ${ranks[i]} --start_T 1 --end_T 3 --lora_alpha $la --act_lora --aloss "orthogonal_w" --aloss_weight 0.1

                    # # CHTss
                    # HF_HOME="/modelopsnas/modelops/463248/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "CHTs_s_sp${sparsities[i]}+r${ranks[i]}+la$la" --model_config $model_config --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsities[i]} --only_save_last --dst_scheduler --remove_method $remove_method --regrow_method CH2_L3n_soft --granet --granet_init_sparsity 0.5 --sparsity_distribution $sparsity_distribution --pruning_method $pruning_method --pruning_scheduler s_shape --zeta $zeta --adaptive_zeta --WS --ws_beta $ws_beta --pruning_T_end 15000 --no_log --log_to_file --save_dir checkpoints/ --only_save_last --sltrain --rank ${ranks[i]} --start_T 1 --end_T 3 --lora_alpha $la --act_lora --aloss "alignment_wx" --aloss_weight $alw
                done
            done
        done
    done

done

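# ### 130M (disabled): model_config=configs/llama_130m.json; the block below sets num_training_steps=20000, warmup_steps=2000 (this header previously said 40000/10000 -- confirm which is intended)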
# ws_beta=0
# update_interval=100
# pruning_method="ri"
# remove_method="weight_magnitude_soft"
# sparsity_distribution="uniform"
# zeta=0.1
# iterative_warmup_steps=20

# model_config="configs/llama_130m.json"
# num_training_steps=20000
# warmup_steps=2000

# datasets=("openwebtext" "c4_ant")



# lrs=(3e-3)
# las=(0)

# sparsities=(0.95 0.95 0.9 0.85 0.95 0.9 0.85 0.8 0.75)
# ranks=(24 72 48 24 120 96 72 48 24)

# alws=(0.3 0.7 0.1 1.0)


# for dataset in "${datasets[@]}"
# do
#     for lr in "${lrs[@]}"
#     do
#         for la in "${las[@]}"
#         do
#             for ((i=0; i<${#sparsities[@]}; i++)); do

#                 # # static
#                 # HF_HOME="/modelopsnas/modelops/463248/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "static_sp${sparsities[i]}+r${ranks[i]}+la$la" --model_config $model_config --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsities[i]} --only_save_last --dst_scheduler --static_dst --no_log --log_to_file --save_dir checkpoints/ --only_save_last --sltrain --rank ${ranks[i]} --lora_alpha $la --no_compute_similarity

#                 # # static
#                 # HF_HOME="/modelopsnas/modelops/463248/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "static_sp${sparsities[i]}+r${ranks[i]}+la$la" --model_config $model_config --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsities[i]} --only_save_last --dst_scheduler --static_dst --no_log --log_to_file --save_dir checkpoints/ --only_save_last --sltrain --rank ${ranks[i]} --lora_alpha $la --act_lora --no_compute_similarity

#                 # # CHTss
#                 # HF_HOME="/modelopsnas/modelops/463248/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "CHTs_s_sp${sparsities[i]}+r${ranks[i]}+la$la" --model_config $model_config --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsities[i]} --only_save_last --dst_scheduler --remove_method $remove_method --regrow_method CH2_L3n_soft --granet --granet_init_sparsity 0.5 --sparsity_distribution $sparsity_distribution --pruning_method $pruning_method --pruning_scheduler s_shape --zeta $zeta --adaptive_zeta --WS --ws_beta $ws_beta --pruning_T_end 15000 --no_log --log_to_file --save_dir checkpoints/ --only_save_last --sltrain --rank ${ranks[i]} --start_T 1 --end_T 3 --lora_alpha $la --act_lora
                
#                 for alw in "${alws[@]}"
#                 do

#                     # # CHTs+az
#                     # HF_HOME="/modelopsnas/modelops/463248/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "CHTs_a_sp${sparsities[i]}+r${ranks[i]}+la$la" --model_config $model_config --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsities[i]} --only_save_last --dst_scheduler --remove_method $remove_method --regrow_method CH2_L3n_soft --zeta $zeta --adaptive_zeta --WS --ws_beta $ws_beta --no_log --log_to_file --save_dir checkpoints/ --only_save_last --sltrain --rank ${ranks[i]} --start_T 1 --end_T 3 --lora_alpha $la --act_lora

#                     # static
#                     HF_HOME="/modelopsnas/modelops/463248/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "static_sp${sparsities[i]}+r${ranks[i]}+la$la" --model_config $model_config --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsities[i]} --only_save_last --dst_scheduler --static_dst --no_log --log_to_file --save_dir checkpoints/ --only_save_last --sltrain --rank ${ranks[i]} --lora_alpha $la --act_lora --aloss "alignment_wx" --aloss_weight $alw --no_compute_similarity

#                     # # CHTs+az
#                     # HF_HOME="/modelopsnas/modelops/463248/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "CHTs_a_sp${sparsities[i]}+r${ranks[i]}+la$la" --model_config $model_config --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsities[i]} --only_save_last --dst_scheduler --remove_method $remove_method --regrow_method CH2_L3n_soft --zeta $zeta --adaptive_zeta --WS --ws_beta $ws_beta --no_log --log_to_file --save_dir checkpoints/ --only_save_last --sltrain --rank ${ranks[i]} --start_T 1 --end_T 3 --lora_alpha $la --act_lora --aloss "orthogonal_w" --aloss_weight 0.1

#                     # # CHTss
#                     # HF_HOME="/modelopsnas/modelops/463248/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py --run_name "CHTs_s_sp${sparsities[i]}+r${ranks[i]}+la$la" --model_config $model_config --dataset_name $dataset --lr $lr --batch_size 64 --total_batch_size 512 --num_training_steps $num_training_steps --warmup_steps $warmup_steps --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam --iterative_warmup_steps $iterative_warmup_steps --update_interval $update_interval --sparsity ${sparsities[i]} --only_save_last --dst_scheduler --remove_method $remove_method --regrow_method CH2_L3n_soft --granet --granet_init_sparsity 0.5 --sparsity_distribution $sparsity_distribution --pruning_method $pruning_method --pruning_scheduler s_shape --zeta $zeta --adaptive_zeta --WS --ws_beta $ws_beta --pruning_T_end 15000 --no_log --log_to_file --save_dir checkpoints/ --only_save_last --sltrain --rank ${ranks[i]} --start_T 1 --end_T 3 --lora_alpha $la --act_lora --aloss "alignment_wx" --aloss_weight $alw
#                 done
#             done
#         done
#     done

# done
