#!/usr/bin/env bash

# ---- Sweep axes ----------------------------------------------------------
# The active launch loop iterates over the product of: datasets, degdists,
# deltas, seeds, regrow_methods, REMOVE_METHODS and sparsitys.
datasets=(MNIST EMNIST Fashion_MNIST)
degdists=(uniform fixed)
deltas=(0 0.25 0.5 0.75 1.0)
seeds=({0..2})
regrow_methods=(CH2_L3n_soft)
REMOVE_METHODS=(weight_magnitude_soft weight_magnitude ri ri_soft)
sparsitys=(0.99)

# Not read by the active loop; pruning_schedules and sparsity_distribution
# are referenced by the commented-out sweep at the bottom of the file.
pruning_methods=(ri)
pruning_schedules=(granet s_shape)
sparsity_distribution=(uniform)

# ---- GPU scheduling ------------------------------------------------------
# Runs are assigned round-robin over these device ids.
gpu_ids=({0..4})
gpu_count=${#gpu_ids[@]}
# NOTE(review): not read by the launch loop, which caps each batch at
# gpu_count * 10 instead -- confirm which limit is intended.
max_parallel_jobs=75

# Job counters updated by the launch loop.
running_jobs=0 total_jobs=0

# Wait for every job launched in the current batch, adding one to
# failed_jobs for each job that exited non-zero, then reset the batch
# bookkeeping. A bare `wait` would discard these exit statuses, so crashed
# runs would go unnoticed.
drain_jobs() {
    local pid
    for pid in "${batch_pids[@]}"; do
        wait "$pid" || failed_jobs=$((failed_jobs + 1))
    done
    batch_pids=()
    running_jobs=0
}

# Runs are launched in whole batches of gpu_count * 10; when a batch is full
# we wait for all of it before starting the next. Because gpu_id is assigned
# round-robin from total_jobs and batches start at multiples of the batch
# size, no GPU receives more than 10 concurrent jobs.
# NOTE(review): max_parallel_jobs (75) is declared above but never used;
# the effective cap is gpu_count * 10 -- confirm which limit is intended.
batch_limit=$((gpu_count * 10))
batch_pids=()
failed_jobs=0

for dataset in "${datasets[@]}"; do
    # --dim passed to run.py: CIFAR10 uses 1, every other dataset uses 2.
    case "$dataset" in
        CIFAR10) dim=1 ;;
        *) dim=2 ;;
    esac
    for dd in "${degdists[@]}"; do
        for d in "${deltas[@]}"; do
            for seed in "${seeds[@]}"; do
                for regrow_method in "${regrow_methods[@]}"; do
                    for remove_method in "${REMOVE_METHODS[@]}"; do
                        for sparsity in "${sparsitys[@]}"; do
                            gpu_id=$((total_jobs % gpu_count))
                            # Log every varying axis so each run's line is unique.
                            printf 'Launching: dataset=%s, degree_dist=%s, delta=%s, seed=%s, regrow=%s, remove=%s, sparsity=%s on GPU %s\n' \
                                "$dataset" "$dd" "$d" "$seed" "$regrow_method" \
                                "$remove_method" "$sparsity" "$gpu_id"
                            python run.py \
                                --batch_size 32 \
                                --dataset "$dataset" \
                                --network_structure mlp \
                                --weight_decay 5e-04 \
                                --regrow_method "$regrow_method" \
                                --init_mode swi \
                                --record_anp \
                                --linearlr \
                                --WS3 \
                                --delta "$d" \
                                --degree_dist "$dd" \
                                --epochs 100 \
                                --learning_rate 0.025 \
                                --cuda_device "$gpu_id" \
                                --dim "$dim" \
                                --update_interval 1 \
                                --self_correlated_sparse \
                                --no_log \
                                --check_exist \
                                --chain_removal \
                                --zeta 0.3 \
                                --remove_method "$remove_method" \
                                --seed "$seed" \
                                --sparsity "$sparsity" \
                                --adaptive_zeta \
                                --dst_scheduler &
                            batch_pids+=("$!")

                            running_jobs=$((running_jobs + 1))
                            total_jobs=$((total_jobs + 1))

                            # Batch full: wait for all of it before launching more.
                            if (( running_jobs >= batch_limit )); then
                                drain_jobs
                            fi
                        done
                    done
                done
            done
        done
    done
done

# Wait for the final, possibly partial, batch.
drain_jobs

# Report failures on stderr; the script's own exit status is unchanged.
if (( failed_jobs > 0 )); then
    printf 'WARNING: %d of %d runs exited with a non-zero status\n' \
        "$failed_jobs" "$total_jobs" >&2
fi

# for dataset in "${datasets[@]}"; do
#     # Conditionally set dim based on dataset
#     if [ "$dataset" = "CIFAR10" ]; then
#         dim=1
#     elif [ "$dataset" = "MNIST" ]; then
#         dim=2
#     else
#         # Default or add more conditions if necessary
#         dim=2
#     fi

#     for seed in "${seeds[@]}"; do
#         for regrow_method in "${regrow_methods[@]}"; do
#             for remove_method in "${REMOVE_METHODS[@]}"; do
#                 for sparsity in "${sparsitys[@]}"; do
#                     for ps in "${pruning_schedules[@]}"; do
#                         for sd in "${sparsity_distribution[@]}"; do
#                             gpu_id=$((total_jobs % gpu_count))

#                             python run.py \
#                                 --batch_size 32 \
#                                 --dataset "$dataset" \
#                                 --network_structure mlp \
#                                 --weight_decay 5e-04 \
#                                 --regrow_method "$regrow_method" \
#                                 --init_mode swi \
#                                 --record_anp \
#                                 --linearlr \
#                                 --epochs 100 \
#                                 --learning_rate 0.025 \
#                                 --cuda_device "$gpu_id" \
#                                 --dim "$dim" \
#                                 --update_interval 1 \
#                                 --self_correlated_sparse \
#                                 --no_log \
#                                 --chain_removal \
#                                 --zeta 0.3 \
#                                 --remove_method "$remove_method" \
#                                 --seed "$seed" \
#                                 --sparsity "$sparsity" \
#                                 --adaptive_zeta \
#                                 --granet \
#                                 --check_exist \
#                                 --granet_init_sparsity 0.5 \
#                                 --pruning_method ri \
#                                 --pruning_scheduler "$ps" \
#                                 --sparsity_distribution "$sd" \
#                                 --dst_scheduler &

#                             running_jobs=$((running_jobs + 1))
#                             total_jobs=$((total_jobs + 1))

#                             # If we've launched the maximum tasks at once, wait for them to finish
#                             if [ "$running_jobs" -eq $((gpu_count * 10)) ]; then
#                                 wait
#                                 running_jobs=0
#                             fi
#                         done
#                     done
#                 done
#             done
#         done
#     done
# done
# wait