#!/bin/bash

# Run mode, taken from the optional first positional argument:
#   (none)   -> defaults: partition=train, suffix="", regularizer=None
#   do_test  -> partition=validation (other defaults kept)
#   <name>   -> any other value (e.g. fisher, l2, lens) is used as the
#               regularizer name, and "_<name>" is stored in $suffix
testing_scripts=${1:-""}

partition="train"
suffix=""
regularizer=None
case "${testing_scripts}" in
  "")
    # No argument: keep the defaults above. This arm must come before the
    # catch-all, which would otherwise set regularizer to an empty string.
    ;;
  do_test)
    partition="validation"
    ;;
  *)
    # fisher, l2, lens and any other name are all handled identically.
    suffix="_${testing_scripts}"
    regularizer="${testing_scripts}"
    ;;
esac
# python -m prune.run_prune \
# python -m prune.run_prune \
# dataset=wiki
# for model_id in mistralai/Mistral-7B-Instruct-v0.3; do # meta-llama/Llama-2-7b-chat-hf mistralai/Mistral-7B-Instruct-v0.3 
#   if [ ${model_id} = "mistralai/Mistral-7B-Instruct-v0.2" ]; then
#     output_dir=./outputs/mistral-7b/${dataset}_structural_sparsity
#     model_type=mistral
#   elif [ ${model_id} = "mistralai/Mistral-7B-Instruct-v0.3" ]; then
#     output_dir=./outputs/mistralv3-7b/${dataset}_structural_sparsity
#     model_type=mistral
#   elif [ ${model_id} = "Qwen/Qwen3-8B" ]; then
#     output_dir=./outputs/qwen3-8b/${dataset}_structural_sparsity
#     model_type=qwen3
#   elif [ ${model_id} = "Qwen/Qwen3-1.7B" ]; then
#     output_dir=./outputs/qwen3-1.7b/${dataset}_structural_sparsity
#     model_type=qwen3
#   elif [ ${model_id} = "meta-llama/Llama-3.1-8B-Instruct" ]; then
#     output_dir=./outputs/llama-8b/${dataset}_structural_sparsity
#     model_type=llama
#   elif [ ${model_id} = "meta-llama/Llama-3.2-1B-Instruct" ]; then
#     output_dir=./outputs/llama-1b/${dataset}_structural_sparsity
#     model_type=llama
#   elif [ ${model_id} = "meta-llama/Llama-3.2-3B-Instruct" ]; then
#     output_dir=./outputs/llama-3b/${dataset}_structural_sparsity
#     model_type=llama
#   elif [ ${model_id} = "meta-llama/Llama-2-7b-hf" ]; then
#     output_dir=./outputs/llama-7b/${dataset}_structural_sparsity
#     model_type=llama
#   fi

#   if [ ${dataset} = 'wiki' ]; then
#     task=wikitext
#   elif [ ${dataset} = 'c4' ]; then
#     task=allenai/c4
#   fi

#   for pruning_ratio in 0.0; do # 0.6 0.7 0.8 0.9; do
#     for num_extra_neurons in 1; do # 1 2 4 8; do
#       CUDA_VISIBLE_DEVICES=2,3,4,5,6,7 torchrun --nnode 1 --nproc_per_node 6 -m prune.run_prune \
#           --model_name_or_path $model_id \
#           --dataset_name $task \
#           --dataset_config_name wikitext-2-raw-v1 \
#           --output_dir ${output_dir}_${num_extra_neurons}_${pruning_ratio} \
#           --partition $partition \
#           --do_train \
#           --bf16 \
#           --learning_rate 1e-5 \
#           --num_train_epochs 10 \
#           --per_device_train_batch_size 8 \
#           --per_device_eval_batch_size 8 \
#           --logging_strategy steps \
#           --logging_steps 100 \
#           --eval_strategy epoch \
#           --save_strategy epoch \
#           --load_best_model_at_end False \
#           --save_total_limit 3 \
#           --max_grad_norm  1.0 \
#           --seed 1337 \
#           --block_size 128 \
#           --model_type ${model_type} \
#           --num_extra_neurons $num_extra_neurons \
#           --pruning_ratio $pruning_ratio \
#           --input_dependent False \
#           --use_peft False \
#           --regularizer $regularizer 
#       wait
#     done
#   done
# done
# for rank in 4 8 16
# do
#   CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 torchrun --nnode 1 --nproc_per_node 6 -m prune.run_prune \
#       --model_name_or_path meta-llama/Llama-3.1-8B-Instruct \
#       --dataset_name wikitext \
#       --dataset_config_name wikitext-2-raw-v1 \
#       --output_dir ./outputs/lora-llama-8b/wiki_rank_${rank} \
#       --partition train \
#       --do_train \
#       --bf16 \
#       --learning_rate 1e-5 \
#       --num_train_epochs 10 \
#       --per_device_train_batch_size 8 \
#       --per_device_eval_batch_size 8 \
#       --logging_strategy steps \
#       --logging_steps 100 \
#       --eval_strategy epoch \
#       --save_strategy epoch \
#       --load_best_model_at_end False \
#       --save_total_limit 3 \
#       --max_grad_norm  1.0 \
#       --seed 1337 \
#       --block_size 128 \
#       --model_type llama \
#       --num_extra_neurons $rank \
#       --pruning_ratio 0.5 \
#       --input_dependent False \
#       --lora_finetuning True \
#       --regularizer $regularizer 
# done

# for rank in 4 8 16; do #llama-8b 
#     model_name="./neural-pruning/outputs/lora-llama-8b/wiki_rank_${rank}"
#     CUDA_VISIBLE_DEVICES=1 python -m eval.eval \
#         --model_name "${model_name}" \
#         --performance_dir ./performance/ \
#         --task_name piqa \
#         --num_extra_neurons $rank \
#         --pruning_ratio 0.5 \
#         --lora True \ &
#     CUDA_VISIBLE_DEVICES=2 python -m eval.eval \
#         --model_name "${model_name}" \
#         --performance_dir ./performance/ \
#         --task_name winogrande \
#         --num_extra_neurons $rank \
#         --pruning_ratio 0.5 \
#         --lora True \ &
#     CUDA_VISIBLE_DEVICES=3 python -m eval.eval \
#         --model_name "${model_name}" \
#         --performance_dir ./performance/ \
#         --task_name boolq \
#         --num_extra_neurons $rank \
#         --pruning_ratio 0.5 \
#         --lora True \ &
#     # CUDA_VISIBLE_DEVICES=5 python -m eval.eval \
#     #     --model_name "${model_name}" \
#     #     --task_name hellaswag \
#     #     --num_extra_neurons 1 \
#     #     --pruning_ratio 0.6 \
#     #     --baseline False &
#     CUDA_VISIBLE_DEVICES=6 python -m eval.eval \
#         --model_name "${model_name}" \
#         --performance_dir ./performance/ \
#         --task_name arc_easy \
#         --num_extra_neurons $rank \
#         --pruning_ratio 0.5 \
#         --lora True \ &
#     CUDA_VISIBLE_DEVICES=0 python -m eval.eval \
#         --model_name "${model_name}" \
#         --performance_dir ./performance/ \
#         --task_name arc_challenge \
#         --num_extra_neurons $rank \
#         --pruning_ratio 0.5 \
#         --lora True \ &
#     wait
        
# done
# for model in llama-7b; do # mistral-7b llama-8b mistralv3-7b 
#     model_name="./neural-pruning/outputs/${model}/wiki_structural_sparsity_1_0.5"
#     for task in openbookqa piqa winogrande boolq sciq hellaswag arc_easy arc_challenge; do # sciq hellaswag arc_easy arc_challenge
#         for num_extra_neurons in 1; do
#             python -m eval.eval \
#                 --model_name "${model_name}" \
#                 --task_name ${task} \
#                 --num_extra_neurons ${num_extra_neurons} \
#                 --baseline False 
#         done
#     done
# done

#   for layer_idx in 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15; do
#     for pruning_ratio in 0.5; do # 0.6 0.7 0.8 0.9; do
#       for num_extra_neurons in 1; do # 1 2 4 8; do
#         CUDA_VISIBLE_DEVICES=1,2,3,4,6,7 torchrun --nnode 1 --nproc_per_node 6 -m prune.run_prune \
#             --model_name_or_path $model_id \
#             --dataset_name wikitext \
#             --dataset_config_name wikitext-2-raw-v1 \
#             --output_dir ${output_dir}_${num_extra_neurons}_${layer_idx}.down_${pruning_ratio} \
#             --partition $partition \
#             --do_train \
#             --bf16 \
#             --learning_rate 1e-5 \
#             --num_train_epochs 10 \
#             --per_device_train_batch_size 8 \
#             --per_device_eval_batch_size 8 \
#             --logging_strategy steps \
#             --logging_steps 100 \
#             --eval_strategy epoch \
#             --save_strategy epoch \
#             --load_best_model_at_end False \
#             --save_total_limit 3 \
#             --max_grad_norm  1.0 \
#             --seed 1337 \
#             --block_size 128 \
#             --model_type ${model_type} \
#             --num_extra_neurons $num_extra_neurons \
#             --pruning_ratio $pruning_ratio \
#             --input_dependent False \
#             --spontaneous_strategy "layers.${layer_idx}.mlp.down" \
#             --regularizer $regularizer 
#         wait
#       done
#     done
#   done
# done
# allenai/c4 wikitext\
# CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnode 1 --nproc_per_node 4 -m prune.run_prune \
#     --model_name_or_path deepseek-ai/DeepSeek-V2-Lite \
#     --dataset_name wikitext \
#     --dataset_config_name wikitext-2-raw-v1 \
#     --output_dir ./outputs/deepseekv2-spon-wiki-1e-6/ \
#     --partition train \
#     --do_train \
#     --bf16 \
#     --learning_rate 1e-6 \
#     --num_train_epochs 5 \
#     --per_device_train_batch_size 16 \
#     --per_device_eval_batch_size 16 \
#     --logging_strategy steps \
#     --logging_steps 100 \
#     --eval_strategy epoch \
#     --save_strategy epoch \
#     --preprocessing_num_workers  4 \
#     --dataloader_num_workers 4 \
#     --load_best_model_at_end False \
#     --save_total_limit 1 \
#     --max_grad_norm  1.0 \
#     --seed 1337 \
#     --pruning_ratio 0.0 \
#     --block_size 256 \
#     --model_type deepseekv2 \
#     --trust_remote_code True \
#     --num_extra_neurons 1 \
#     --regularizer $regularizer 

# CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnode 1 --nproc_per_node 4 -m prune.run_prune \
#     --model_name_or_path allenai/OLMoE-1B-7B-0125-Instruct \
#     --dataset_name wikitext \
#     --dataset_config_name wikitext-2-raw-v1 \
#     --output_dir ./outputs/olmoe-spon-wiki-1e-6/ \
#     --partition train \
#     --do_train \
#     --bf16 \
#     --learning_rate 1e-6 \
#     --num_train_epochs 5 \
#     --per_device_train_batch_size 16 \
#     --per_device_eval_batch_size 16 \
#     --logging_strategy steps \
#     --logging_steps 100 \
#     --eval_strategy epoch \
#     --save_strategy epoch \
#     --preprocessing_num_workers  4 \
#     --dataloader_num_workers 4 \
#     --load_best_model_at_end False \
#     --save_total_limit 1 \
#     --max_grad_norm  1.0 \
#     --seed 1337 \
#     --pruning_ratio 0.0 \
#     --block_size 512 \
#     --model_type olmoe \
#     --num_extra_neurons 1 \
#     --regularizer $regularizer 
# wait

# CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnode 1 --nproc_per_node 4 -m prune.run_prune \
#     --model_name_or_path deepseek-ai/DeepSeek-V2-Lite \
#     --dataset_name allenai/c4 \
#     --dataset_config_name wikitext-2-raw-v1 \
#     --output_dir ./outputs/deepseekv2-spon-c4/ \
#     --partition train \
#     --do_train \
#     --bf16 \
#     --learning_rate 1e-5 \
#     --num_train_epochs 3 \
#     --per_device_train_batch_size 16 \
#     --per_device_eval_batch_size 16 \
#     --logging_strategy steps \
#     --logging_steps 100 \
#     --eval_strategy epoch \
#     --save_strategy epoch \
#     --preprocessing_num_workers  4 \
#     --dataloader_num_workers 4 \
#     --load_best_model_at_end False \
#     --save_total_limit 1 \
#     --max_grad_norm  1.0 \
#     --seed 1337 \
#     --pruning_ratio 0.0 \
#     --block_size 256 \
#     --model_type deepseekv2 \
#     --trust_remote_code True \
#     --num_extra_neurons 1 \
#     --regularizer $regularizer 


# CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnode 1 --nproc_per_node 4 -m prune.run_prune \
#     --model_name_or_path meta-llama/Llama-3.3-70B-Instruct \
#     --dataset_name wikitext \
#     --dataset_config_name wikitext-2-raw-v1 \
#     --output_dir ./outputs/llama-70b-spon-wiki/ \
#     --partition train \
#     --do_train \
#     --bf16 \
#     --learning_rate 1e-6 \
#     --num_train_epochs 5 \
#     --per_device_train_batch_size 1 \
#     --per_device_eval_batch_size 1 \
#     --logging_strategy steps \
#     --logging_steps 100 \
#     --eval_strategy epoch \
#     --save_strategy steps \
#     --save_steps 3527 \
#     --preprocessing_num_workers  4 \
#     --dataloader_num_workers 4 \
#     --load_best_model_at_end False \
#     --save_total_limit 1 \
#     --max_grad_norm  1.0 \
#     --seed 1337 \
#     --pruning_ratio 0.0 \
#     --block_size 16 \
#     --model_type llama \
#     --trust_remote_code True \
#     --num_extra_neurons 1 \
#     --spontaneous_strategy down \
#     --regularizer $regularizer 

# CUDA_VISIBLE_DEVICES=0,1,2,3 deepspeed --num_gpus 4 --module prune.run_prune \
#     --model_name_or_path Qwen/Qwen3-32B \
#     --dataset_name wikitext \
#     --dataset_config_name wikitext-2-raw-v1 \
#     --output_dir ./outputs/qwen3-32b-spon-wiki/ \
#     --partition train \
#     --do_train \
#     --bf16 \
#     --learning_rate 1e-6 \
#     --num_train_epochs 1 \
#     --per_device_train_batch_size 1 \
#     --per_device_eval_batch_size 1 \
#     --logging_strategy steps \
#     --logging_steps 100 \
#     --eval_strategy epoch \
#     --save_strategy steps \
#     --save_steps 3527 \
#     --preprocessing_num_workers  4 \
#     --dataloader_num_workers 4 \
#     --load_best_model_at_end False \
#     --save_total_limit 1 \
#     --max_grad_norm  1.0 \
#     --seed 1337 \
#     --pruning_ratio 0.0 \
#     --block_size 64 \
#     --model_type qwen3 \
#     --trust_remote_code True \
#     --num_extra_neurons 1 \
#     --regularizer $regularizer \
#     --deepspeed ./deepspeed_config/ds_config.json

# Launch prune.run_prune on meta-llama/Llama-3.3-70B-Instruct with WikiText-2,
# sharded with FSDP over 4 local GPUs (one node, 4 processes).
#
# Inputs from the mode parsing at the top of this script:
#   $partition   - passed to --partition ("train" unless mode is do_test)
#   $regularizer - passed to --regularizer; an empty value falls back to
#                  "None" so the flag is never emitted without an argument.
# NOTE(review): $suffix is computed from the mode argument but is not used in
# --output_dir, so runs with different regularizers write to the same
# directory — confirm whether that is intended.

# Configure PyTorch's CUDA caching allocator to use expandable segments.
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True"

CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnodes 1 --nproc_per_node 4 -m prune.run_prune \
    --model_name_or_path meta-llama/Llama-3.3-70B-Instruct \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --output_dir ./outputs/llama-70b-spon-wiki/ \
    --partition "${partition}" \
    --do_train \
    --fsdp "full_shard auto_wrap" \
    --fsdp_config '{"fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer", "mixed_precision": {"param_dtype": "bfloat16","reduce_dtype": "bfloat16","buffer_dtype": "bfloat16"}}' \
    --bf16 True \
    --learning_rate 1e-6 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --logging_strategy steps \
    --logging_steps 1 \
    --eval_strategy epoch \
    --save_strategy steps \
    --save_steps 3527 \
    --preprocessing_num_workers 4 \
    --dataloader_num_workers 4 \
    --load_best_model_at_end False \
    --save_total_limit 1 \
    --max_grad_norm 1.0 \
    --seed 1337 \
    --pruning_ratio 0.0 \
    --block_size 64 \
    --model_type llama \
    --trust_remote_code True \
    --num_extra_neurons 1 \
    --regularizer "${regularizer:-None}"