#######################################
# Work out which GPU ids this script may use.
# Globals:
#   CUDA_VISIBLE_DEVICES (read)  - comma-separated device list, may be unset
#   CVD (written)                - array of visible device ids
#   NUM_GPUS (written)           - number of entries in CVD
# Outputs:
#   A warning on stderr when no GPU could be determined.
#######################################
detect_gpus() {
    local gpu_total gpu_idx
    CVD=()
    if [[ -n "${CUDA_VISIBLE_DEVICES:-}" ]]; then
        # Commas become blanks and `read` does the splitting, so the ids are
        # word-split but never glob-expanded (an unquoted array assignment
        # would expand patterns such as `*` against the current directory).
        read -r -a CVD <<< "${CUDA_VISIBLE_DEVICES//,/ }"
    elif command -v nvidia-smi >/dev/null 2>&1; then
        # Count only the lines that start with "GPU " so that any other text
        # printed by nvidia-smi is not mistaken for a device.
        gpu_total=$(nvidia-smi -L 2>/dev/null | grep -c '^GPU ')
        for ((gpu_idx = 0; gpu_idx < gpu_total; gpu_idx++)); do
            CVD+=("$gpu_idx")
        done
    fi
    NUM_GPUS=${#CVD[@]}
    if ((NUM_GPUS == 0)); then
        echo "warning: no GPUs detected (CUDA_VISIBLE_DEVICES is empty and nvidia-smi listed none)" >&2
    fi
}
detect_gpus

# ---------------------------------------------------------------------------
# Sweep configuration: positional arguments, dataset settings and the flag
# string ($cmd) shared by every training run launched further down.
# ---------------------------------------------------------------------------

# Positional arguments.
idx=$1      # 1st argument: becomes part of the accelerate port number
lr_pwr=$2   # 2nd argument: exponent of the learning rate (lr = 1e-<lr_pwr>)

# Upper bound on the number of training jobs running at the same time.
max_processes=3

# Dataset and field names.
dataset_name="uoft-cs/cifar10"
prompt_field="img"
response_field="label"

split_ratio=0.1 # 1 out of 10 classes to forget

# Project naming.
proj_name="cifar"
output_dir_prefix=""
postfix=""

# Classes to sweep over as the forget subset.
# classes=("0" "1" "2" "3" "4" "5" "6" "7" "8" "9") #
classes=("0") #

# Shared flags, appended group by group.
cmd=""

cmd+=" \
    --eval_on_start "

cmd+=" \
    --use_wandb "

cmd+=" \
    --freeze_batchnorm "

# The same dataset is passed for the main, forget and duplicate roles.
cmd+=" \
    --dataset_name=$dataset_name \
    --forget_dataset_name=$dataset_name \
    --duplicate_dataset_name=$dataset_name \
    --split_ratio=$split_ratio "

# Training length and batch sizes.
cmd+=" \
    --num_train_epochs=50 \
    --dataset_repeat=50 \
    --per_device_train_batch_size 5000 \
    --per_device_eval_batch_size 10000 "

# Splits and prompt/response field names for each dataset role.
cmd+=" \
    --shuffle_dataset \
    --dataset_split=train \
    --duplicate_dataset_split=test \
    --forget_dataset_split=train \
    --dataset_prompt_field=$prompt_field \
    --forget_dataset_prompt_field=$prompt_field \
    --duplicate_dataset_prompt_field=$prompt_field \
    --dataset_response_field=$response_field \
    --forget_dataset_response_field=$response_field \
    --duplicate_dataset_response_field=$response_field "

# Logging, evaluation and checkpoint cadence.
cmd+=" \
    --eval_on_subsets \
    --logging_strategy=steps \
    --logging_steps 1 \
    --eval_strategy=steps \
    --eval_steps=0.02 \
    --save_strategy=steps \
    --save_steps=0.02 \
    --torch_empty_cache_steps 5 \
    --wandb_project=unlearn_$proj_name "

cmd+=" \
    --use_lr \
    --max_grad_norm 1.0 "

# Exactly one of the following variants is enabled at a time.
# cmd="$cmd --full_grad"
# cmd="$cmd --distribute_B"
cmd+=" --distribute_B_gr_gf_norm"

# ---------------------------------------------------------------------------
# Sweep launcher.
#
# Iterates over seed x hard_probability x forget_class x lr x use_optimizer x
# method, builds one train_peft.py command line per combination from the
# shared $cmd flags, and keeps at most $max_processes of them running at once.
# Reads: cmd, classes, idx, lr_pwr, max_processes, NUM_GPUS, postfix.
# ---------------------------------------------------------------------------

# The learning rate is spelled "1e-<lr_pwr>"; anything other than a plain
# non-negative integer would yield a malformed value for every single job.
if [[ ! "$lr_pwr" =~ ^[0-9]+$ ]]; then
    echo "error: lr_pwr (2nd argument) must be a non-negative integer, got '${lr_pwr}'" >&2
    exit 2
fi

# Multi-process runs use the port "1<idx><lr_pwr><slot><use_optimizer>".
# Check the longest port this sweep can produce before anything is launched:
# with the leading 1, at most five digits keeps it inside the valid range.
if [ "$NUM_GPUS" != "1" ]; then
    port_probe="1${idx}${lr_pwr}$((max_processes - 1))1"
    if [[ ! "$port_probe" =~ ^[0-9]{1,5}$ ]]; then
        echo "error: idx='${idx}' and lr_pwr='${lr_pwr}' give an invalid port (${port_probe}); use single-digit values" >&2
        exit 2
    fi
fi

# slot_pids[s] is the PID of the background job currently occupying slot s.
# A job's slot also selects its port, so two live jobs never share a port.
slot_pids=()

i=0 # number of jobs started so far (drives the start-up stagger)
for seed in 41 42 43
do
    for hard_probability in 0.0 0.25 0.5 0.75 1.0
    do
        # With hard_probability equal to 1 the class loop runs once with an
        # empty subset; otherwise it runs over every entry of $classes.
        # Pure-bash match of 1, 1., 1.0, 1.00, ... (no dependency on bc).
        if [[ "$hard_probability" =~ ^1(\.0*)?$ ]]; then
            curr_classes=("")
        else
            curr_classes=("${classes[@]}")
        fi
    for forget_class in "${curr_classes[@]}"
    do
    for lr in "1e-$lr_pwr"
    do
    for use_optimizer in 0 #1
    do
    # Derive the per-iteration flags from $cmd without modifying $cmd itself,
    # so --use_optimizer cannot pile up or leak into later iterations.
    base_cmd="$cmd"
    if [ "$use_optimizer" == "1" ]; then
        base_cmd="$base_cmd --use_optimizer"
    fi
        for method in "kl" "scrub" "gdiff" "gd" "ga"
        do
                job_args="$base_cmd \
                    --dataset_subset=$forget_class \
                    --forget_dataset_subset=$forget_class \
                    --duplicate_dataset_subset='' \
                    --shuffle_seed $seed \
                    --seed $seed \
                    --hard_probability $hard_probability "

                job_args="$job_args \
                    --learning_rate $lr \
                    --model_name=trained_cifar_best_accuracy/seed_$seed/forget_0_0.0/full "

                # curr_postfix is never assigned in this script (presumably a
                # leftover name); fall back to $postfix when it is unset/empty.
                job_args="$job_args \
                    --wandb_run_name=${seed}_${forget_class}_${hard_probability}_lr${lr}_$method${curr_postfix:-$postfix} "

                job_args="$job_args \
                    --use_$method "
                # Every method except gd and ga also gets --interlace_forget.
                if [ "$method" != "gd" ] && [ "$method" != "ga" ]; then
                    job_args="$job_args --interlace_forget"
                fi

                # echo "Running unlearning with seed=$seed, lr=$lr, forget_class=$forget_class, hard_probability=$hard_probability, method=$method"

                # stagger startup with min(i*15, 60) seconds
                if ((i > 0)); then
                    sleep $((i * 15 < 60 ? i * 15 : 60))
                fi
                i=$((i + 1))

                # Wait until one of the $max_processes slots is free: a slot is
                # free when it was never used or its recorded PID is gone.
                slot=-1
                while ((slot < 0)); do
                    for ((slot_idx = 0; slot_idx < max_processes; slot_idx++)); do
                        slot_pid=${slot_pids[slot_idx]:-}
                        if [[ -z "$slot_pid" ]] || ! kill -0 "$slot_pid" 2>/dev/null; then
                            slot=$slot_idx
                            break
                        fi
                    done
                    if ((slot < 0)); then
                        sleep 1
                    fi
                done

                if [ "$NUM_GPUS" = "1" ]; then
                    python_cmd="python ./train_peft.py"
                else
                    # The port is keyed on the slot the job actually holds,
                    # not on the job counter, so it is unique among live jobs.
                    port=1${idx}${lr_pwr}${slot}${use_optimizer}
                    python_cmd="accelerate launch \
                        --config_file=default_config.yaml \
                        --num_processes $NUM_GPUS \
                        --main_process_port $port ./train_peft.py"
                fi

                # Collapse the padding left by the line continuations into
                # single blanks (word-split via read, so nothing is globbed).
                read -r -a cmd_words <<< "$python_cmd $job_args"
                final_cmd="${cmd_words[*]}"

                # eval is required so that the literal '' in
                # --duplicate_dataset_subset='' turns into an empty value.
                if [ "$max_processes" -eq "1" ]; then
                    eval "$final_cmd"
                else
                    # Execute the command in the background and remember
                    # which slot it occupies.
                    eval "$final_cmd" &
                    slot_pids[slot]=$!
                fi
        done
    done
    done
    done
    done
done
# Block until every background job has finished.
wait