#!/usr/bin/env bash
# GPU discovery: CVD holds the GPU ids to use and NUM_GPUS their count.
# CUDA_VISIBLE_DEVICES (comma-separated) is honoured when non-empty; otherwise
# the GPUs listed by `nvidia-smi -L` are enumerated as 0..N-1.
CVD=(${CUDA_VISIBLE_DEVICES//,/ })
if [ -z "$CUDA_VISIBLE_DEVICES" ]; then
    # Count only the "GPU <n>: ..." lines. With MIG enabled `nvidia-smi -L`
    # also prints indented "MIG ... Device" lines, and when the driver is
    # unreachable it prints an error message -- neither is a GPU.
    gpu_count=$(nvidia-smi -L 2>/dev/null | grep -c '^GPU ')
    if [ "$gpu_count" -eq 0 ]; then
        echo "warning: no GPUs detected via nvidia-smi and CUDA_VISIBLE_DEVICES is empty" >&2
    fi
    CVD=($(seq 0 $((gpu_count - 1))))
fi
NUM_GPUS=${#CVD[@]}

# if [ "$NUM_GPUS" -le "1" ]; then
#     echo "Requires at least 2 GPUs for unlearning."
#     exit 1
# elif [ $((NUM_GPUS % 2)) -ne 0 ]; then
#     echo "Requires GPU count to be even for unlearning."
#     exit 1
# fi

# export TRAIN_LLM=true

# cmd="accelerate launch \
#     --config_file=default_config.yaml \
#     --main_process_port 3142$1 \
#     --num_processes $NUM_GPUS ./train_peft.py"

# Positional arguments:
#   $1 -> idx     run index; one digit group of the accelerate
#                 --main_process_port in multi-GPU launches
#   $2 -> lr_pwr  learning-rate exponent; the sweep uses lr=1e-$lr_pwr
idx=$1
lr_pwr=$2
# Maximum number of training jobs allowed to run concurrently.
max_processes=3
# use_optimizer=$3

# Dataset used for the main, forget and duplicate datasets, and the columns
# passed as the prompt field (model input) and response field (target).
dataset_name="uoft-cs/cifar10"
prompt_field="img"
response_field="label"
split_ratio=0.1 # 1 out of 10 classes to forget

proj_name="cifar" # the --wandb_project value is unlearn_<proj_name>
output_dir_prefix=""

# Forget classes to sweep over.
# classes=("0" "1" "2" "3" "4" "5" "6" "7" "8" "9")
classes=("0")

# Flags common to every run. Listed as an array for readability, then
# flattened into the plain string $cmd to which per-run flags get appended.
# Alternative B-distribution modes (currently unused): --full_grad, --distribute_B
_shared_flags=(
    --eval_on_start
    --use_wandb
    --freeze_batchnorm

    "--dataset_name=$dataset_name"
    "--forget_dataset_name=$dataset_name"
    "--duplicate_dataset_name=$dataset_name"
    "--split_ratio=$split_ratio"

    --num_train_epochs=50
    --dataset_repeat=50
    --per_device_train_batch_size 5000
    --per_device_eval_batch_size 10000

    --shuffle_dataset
    --dataset_split=train
    --duplicate_dataset_split=test
    --forget_dataset_split=train
    "--dataset_prompt_field=$prompt_field"
    "--forget_dataset_prompt_field=$prompt_field"
    "--duplicate_dataset_prompt_field=$prompt_field"
    "--dataset_response_field=$response_field"
    "--forget_dataset_response_field=$response_field"
    "--duplicate_dataset_response_field=$response_field"

    --eval_on_subsets
    --logging_strategy=steps
    --logging_steps 1
    --eval_strategy=steps
    --eval_steps=0.02
    --save_strategy=steps
    --save_steps=0.02
    --torch_empty_cache_steps 5
    "--wandb_project=unlearn_$proj_name"

    --interlace_forget
    --use_custom_optim
    --use_lr
    --max_grad_norm 1.0

    --distribute_B_gr_gf_norm
)
cmd="${_shared_flags[*]}"
unset _shared_flags


# ---------------------------------------------------------------------------
# Sweep: launch one training run per combination of seed, hard_probability,
# forget class, learning rate, optimizer flag, B and update type.
#
# Reads globals set earlier in this script: cmd (shared flags), idx, lr_pwr,
# max_processes, classes, NUM_GPUS. At most $max_processes runs execute at
# once; every running job occupies a numbered slot, and the slot number is part
# of its accelerate --main_process_port so concurrent runs never share a port.
# ---------------------------------------------------------------------------
postfix=""

# lr_pwr is interpolated into the eval'd command (lr=1e-$lr_pwr), so only a
# non-negative integer is accepted.
if ! [[ "$lr_pwr" =~ ^[0-9]+$ ]]; then
    echo "usage: $0 IDX LR_PWR (LR_PWR must be a non-negative integer, got '$lr_pwr')" >&2
    exit 1
fi
# In multi-GPU launches idx and lr_pwr are concatenated into the port number:
# they must be digits and the largest port produced must be a valid TCP port.
if [ "$NUM_GPUS" != "1" ]; then
    if ! [[ "$idx" =~ ^[0-9]*$ ]]; then
        echo "usage: $0 IDX LR_PWR (IDX must be numeric, got '$idx')" >&2
        exit 1
    fi
    max_port=1${idx}${lr_pwr}$((max_processes - 1))1
    if (( max_port > 65535 )); then
        echo "IDX='$idx' and LR_PWR='$lr_pwr' give port $max_port, which exceeds 65535" >&2
        exit 1
    fi
fi

# slot_pids[s] is the PID of the background run that last occupied slot s.
slot_pids=()

# Block until one of the $max_processes slots is free, polling once per second,
# and store its index in the global `slot`. A slot is free when it was never
# used or its recorded process is no longer alive. It sets a global rather than
# printing, so callers do not need to run it in a $(...) subshell.
wait_for_free_slot() {
    local s
    while :; do
        for ((s = 0; s < max_processes; s++)); do
            if [ -z "${slot_pids[s]}" ] || ! kill -0 "${slot_pids[s]}" 2>/dev/null; then
                slot=$s
                return 0
            fi
        done
        sleep 1
    done
}

i=0
for seed in 41 42 43; do
    for hard_probability in 0.0 0.25 0.5 0.75 1.0; do
        # hard_probability == 1.0 gets a single run with an empty forget class;
        # other values get one run per entry of $classes. Compared with awk
        # (used below as well) so no dependency on bc is needed.
        if awk -v p="$hard_probability" 'BEGIN { exit !(p + 0 == 1) }'; then
            curr_classess=("")
        else
            curr_classess=("${classes[@]}")
        fi
        for forget_class in "${curr_classess[@]}"; do
            for lr in 1e-$lr_pwr; do
                for use_optimizer in 0 1; do
                    # Per-run flag, kept out of the shared $cmd so that it does
                    # not carry over into later use_optimizer=0 runs.
                    optimizer_flag=""
                    if [ "$use_optimizer" = "1" ]; then
                        optimizer_flag="--use_optimizer"
                    fi
                    for B in 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9; do
                        for update_type in "dual" ""; do
                            if [ "$update_type" = "dual" ]; then
                                if [ "$B" = 0 ]; then
                                    echo "Skipping dual update with B=0"
                                    continue
                                fi
                                curr_postfix="${postfix}_dual"
                                curr_B=$(awk -v b="$B" 'BEGIN { print 1 - b }')
                            else
                                curr_postfix="${postfix}"
                                curr_B=$B
                            fi
                            # Scale B by the learning rate.
                            curr_B=$(awk -v lr="$lr" -v b="$curr_B" 'BEGIN { print lr * b }')

                            # The quoted '' becomes an explicitly empty argument under eval.
                            final_cmd="$cmd $optimizer_flag \
                                --dataset_subset=$forget_class \
                                --forget_dataset_subset=$forget_class \
                                --duplicate_dataset_subset='' \
                                --shuffle_seed $seed \
                                --seed $seed \
                                --hard_probability $hard_probability \
                                --learning_rate $lr \
                                --B $curr_B \
                                --model_name=trained_cifar_best_accuracy/seed_$seed/forget_0_0.0/full \
                                --wandb_run_name=${seed}_${forget_class}_${hard_probability}_lr${lr}_B${curr_B}${curr_postfix} "
                            # final_cmd="$final_cmd \
                            #     --output_dir=${output_dir_prefix}unlearned_$proj_name/seed_$seed/$dataset_subset/forget_${forget_class}_${hard_probability}/${split} \
                            #     --final_model_output_dir=unlearned_$proj_name/seed_${seed}/forget_${forget_class}_${hard_probability}/${split}"
                            if [ "$update_type" = "dual" ]; then
                                final_cmd="$final_cmd --dual_update"
                            fi

                            # Stagger startup by min(i*15, 60) seconds.
                            if [ "$i" -gt 0 ]; then
                                sleep $((i * 15 < 60 ? i * 15 : 60))
                            fi
                            ((i++))

                            # Wait until fewer than $max_processes runs are alive.
                            wait_for_free_slot

                            if [ "$NUM_GPUS" = "1" ]; then
                                python_cmd="python ./train_peft.py"
                            else
                                # The slot digit keeps concurrently running jobs on distinct ports.
                                port=1${idx}${lr_pwr}${slot}${use_optimizer}
                                python_cmd="accelerate launch \
                                    --config_file=default_config.yaml \
                                    --num_processes $NUM_GPUS \
                                    --main_process_port $port ./train_peft.py"
                            fi

                            final_cmd="$python_cmd $final_cmd"
                            # Collapse runs of whitespace (cosmetic; eval re-splits anyway).
                            final_cmd=$(echo $final_cmd | tr -s ' ')

                            # eval only sees text built in this script plus the validated idx/lr_pwr.
                            if [ "$max_processes" -eq "1" ]; then
                                eval $final_cmd
                            else
                                eval $final_cmd &
                                slot_pids[slot]=$!
                            fi
                        done
                    done
                done
            done
        done
    done
done
wait