## 60M num_training_steps=10000, warmup_steps=1000 -model_config configs/llama_60m.json
# Sweep section: launches torchrun_main.py once per element of the cartesian
# product datasets x lrs x las x (sparsity, rank) pairs x alws. With the values
# below that is 6 sequential runs, each started through torch.distributed.run
# with 8 processes on GPUs 0-7.
#
# NOTE(review): run_name encodes only sparsity, rank and lora_alpha -- not the
# dataset, the model config or the aloss weight -- so runs that differ only in
# those settings share a name. Confirm torchrun_main.py keeps their logs and
# checkpoints apart.
ws_beta=0
update_interval=100
pruning_method="ri"              # assigned but not passed to the launch command below
remove_method="weight_magnitude_soft"
sparsity_distribution="uniform"  # assigned but not passed to the launch command below
zeta=0.1
iterative_warmup_steps=20

model_config="configs/llama_60m.json"
num_training_steps=10000
warmup_steps=1000

datasets=("openwebtext")

lrs=(3e-3)
las=(0)

# Parallel lists: sparsities[i] is always run together with ranks[i].
sparsities=(0.975 0.925 0.875 0.825 0.775 0.725)
ranks=(88 72 56 40 24 8)

alws=(0.5)

# The index loop below reads both lists with the same index; refuse to launch
# anything if they have drifted out of sync (otherwise --rank would silently
# receive an empty value).
if (( ${#sparsities[@]} != ${#ranks[@]} )); then
    printf 'error: sparsities (%d entries) and ranks (%d entries) must have the same length\n' \
        "${#sparsities[@]}" "${#ranks[@]}" >&2
    exit 1
fi

for dataset in "${datasets[@]}"; do
    for lr in "${lrs[@]}"; do
        for la in "${las[@]}"; do
            for ((i = 0; i < ${#sparsities[@]}; i++)); do
                for alw in "${alws[@]}"; do
                    # CHTs+az
                    # Same arguments in the same order as the former one-line
                    # command, except that the repeated --only_save_last is
                    # passed once and every expansion is quoted.
                    HF_HOME="/data/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
                        python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py \
                        --run_name "CHTs_a_sp${sparsities[i]}+r${ranks[i]}+la$la" \
                        --model_config "$model_config" \
                        --dataset_name "$dataset" \
                        --lr "$lr" \
                        --batch_size 64 --total_batch_size 512 \
                        --num_training_steps "$num_training_steps" \
                        --warmup_steps "$warmup_steps" \
                        --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam \
                        --iterative_warmup_steps "$iterative_warmup_steps" \
                        --update_interval "$update_interval" \
                        --sparsity "${sparsities[i]}" \
                        --only_save_last --dst_scheduler \
                        --remove_method "$remove_method" --regrow_method CH2_L3n_soft \
                        --zeta "$zeta" --adaptive_zeta \
                        --WS --ws_beta "$ws_beta" \
                        --no_log --log_to_file --save_dir checkpoints/ \
                        --sltrain --rank "${ranks[i]}" \
                        --start_T 1 --end_T 3 \
                        --lora_alpha "$la" --act_lora \
                        --aloss "alignment_wx" --aloss_weight "$alw" \
                        --no_compute_similarity
                done
            done
        done
    done
done


## 60M num_training_steps=10000, warmup_steps=1000 -model_config configs/llama_60m.json
# Sweep section: launches torchrun_main.py once per element of the cartesian
# product datasets x lrs x las x (sparsity, rank) pairs x alws. With the values
# below that is 6 sequential runs on the c4 dataset, each started through
# torch.distributed.run with 8 processes on GPUs 0-7.
#
# NOTE(review): run_name encodes only sparsity, rank and lora_alpha -- not the
# dataset, the model config or the aloss weight -- so runs that differ only in
# those settings share a name. Confirm torchrun_main.py keeps their logs and
# checkpoints apart.
ws_beta=0
update_interval=100
pruning_method="ri"              # assigned but not passed to the launch command below
remove_method="weight_magnitude_soft"
sparsity_distribution="uniform"  # assigned but not passed to the launch command below
zeta=0.1
iterative_warmup_steps=20

model_config="configs/llama_60m.json"
num_training_steps=10000
warmup_steps=1000

datasets=("c4")

lrs=(3e-3)
las=(0)

# Parallel lists: sparsities[i] is always run together with ranks[i].
sparsities=(0.975 0.925 0.875 0.825 0.775 0.725)
ranks=(88 72 56 40 24 8)

alws=(0.3)

# The index loop below reads both lists with the same index; refuse to launch
# anything if they have drifted out of sync (otherwise --rank would silently
# receive an empty value).
if (( ${#sparsities[@]} != ${#ranks[@]} )); then
    printf 'error: sparsities (%d entries) and ranks (%d entries) must have the same length\n' \
        "${#sparsities[@]}" "${#ranks[@]}" >&2
    exit 1
fi

for dataset in "${datasets[@]}"; do
    for lr in "${lrs[@]}"; do
        for la in "${las[@]}"; do
            for ((i = 0; i < ${#sparsities[@]}; i++)); do
                for alw in "${alws[@]}"; do
                    # CHTs+az
                    # Same arguments in the same order as the former one-line
                    # command, except that the repeated --only_save_last is
                    # passed once and every expansion is quoted.
                    HF_HOME="/data/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
                        python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py \
                        --run_name "CHTs_a_sp${sparsities[i]}+r${ranks[i]}+la$la" \
                        --model_config "$model_config" \
                        --dataset_name "$dataset" \
                        --lr "$lr" \
                        --batch_size 64 --total_batch_size 512 \
                        --num_training_steps "$num_training_steps" \
                        --warmup_steps "$warmup_steps" \
                        --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam \
                        --iterative_warmup_steps "$iterative_warmup_steps" \
                        --update_interval "$update_interval" \
                        --sparsity "${sparsities[i]}" \
                        --only_save_last --dst_scheduler \
                        --remove_method "$remove_method" --regrow_method CH2_L3n_soft \
                        --zeta "$zeta" --adaptive_zeta \
                        --WS --ws_beta "$ws_beta" \
                        --no_log --log_to_file --save_dir checkpoints/ \
                        --sltrain --rank "${ranks[i]}" \
                        --start_T 1 --end_T 3 \
                        --lora_alpha "$la" --act_lora \
                        --aloss "alignment_wx" --aloss_weight "$alw" \
                        --no_compute_similarity
                done
            done
        done
    done
done


## 130M num_training_steps=20000, warmup_steps=2000 -model_config configs/llama_130m.json
# Sweep section: launches torchrun_main.py once per element of the cartesian
# product datasets x lrs x las x (sparsity, rank) pairs x alws. With the values
# below that is 12 sequential runs (2 datasets x 6 sparsity/rank pairs), each
# started through torch.distributed.run with 8 processes on GPUs 0-7.
#
# NOTE(review): run_name encodes only sparsity, rank and lora_alpha. This
# section loops over two datasets, so the openwebtext and c4 runs for a given
# sparsity/rank pair are launched with the SAME run_name. Confirm
# torchrun_main.py keeps their logs and checkpoints apart; if it does not,
# the dataset should be added to --run_name.
ws_beta=0
update_interval=100
pruning_method="ri"              # assigned but not passed to the launch command below
remove_method="weight_magnitude_soft"
sparsity_distribution="uniform"  # assigned but not passed to the launch command below
zeta=0.1
iterative_warmup_steps=20

model_config="configs/llama_130m.json"
num_training_steps=20000
warmup_steps=2000

datasets=("openwebtext" "c4")

lrs=(3e-3)
las=(0)

# Parallel lists: sparsities[i] is always run together with ranks[i].
sparsities=(0.975 0.925 0.875 0.825 0.775 0.725)
ranks=(132 108 84 60 36 12)

alws=(0.5)

# The index loop below reads both lists with the same index; refuse to launch
# anything if they have drifted out of sync (otherwise --rank would silently
# receive an empty value).
if (( ${#sparsities[@]} != ${#ranks[@]} )); then
    printf 'error: sparsities (%d entries) and ranks (%d entries) must have the same length\n' \
        "${#sparsities[@]}" "${#ranks[@]}" >&2
    exit 1
fi

for dataset in "${datasets[@]}"; do
    for lr in "${lrs[@]}"; do
        for la in "${las[@]}"; do
            for ((i = 0; i < ${#sparsities[@]}; i++)); do
                for alw in "${alws[@]}"; do
                    # CHTs+az
                    # Same arguments in the same order as the former one-line
                    # command, except that the repeated --only_save_last is
                    # passed once and every expansion is quoted.
                    HF_HOME="/data/hf_cache" HF_DATASETS_OFFLINE=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
                        python -m torch.distributed.run --standalone --nproc_per_node 8 torchrun_main.py \
                        --run_name "CHTs_a_sp${sparsities[i]}+r${ranks[i]}+la$la" \
                        --model_config "$model_config" \
                        --dataset_name "$dataset" \
                        --lr "$lr" \
                        --batch_size 64 --total_batch_size 512 \
                        --num_training_steps "$num_training_steps" \
                        --warmup_steps "$warmup_steps" \
                        --weight_decay 0 --dtype bfloat16 --eval_every 1000 --optimizer adam \
                        --iterative_warmup_steps "$iterative_warmup_steps" \
                        --update_interval "$update_interval" \
                        --sparsity "${sparsities[i]}" \
                        --only_save_last --dst_scheduler \
                        --remove_method "$remove_method" --regrow_method CH2_L3n_soft \
                        --zeta "$zeta" --adaptive_zeta \
                        --WS --ws_beta "$ws_beta" \
                        --no_log --log_to_file --save_dir checkpoints/ \
                        --sltrain --rank "${ranks[i]}" \
                        --start_T 1 --end_T 3 \
                        --lora_alpha "$la" --act_lora \
                        --aloss "alignment_wx" --aloss_weight "$alw" \
                        --no_compute_similarity
                done
            done
        done
    done
done