# Determine which GPUs this sweep may use.
# If CUDA_VISIBLE_DEVICES is set and non-empty, its comma-separated ids are
# used as-is. Otherwise one index per line printed by `nvidia-smi -L` is used
# (0..N-1). NUM_GPUS is the number of entries found.
if [ -n "$CUDA_VISIBLE_DEVICES" ]; then
    # "0,2,3" -> (0 2 3)
    CVD=(${CUDA_VISIBLE_DEVICES//,/ })
else
    CVD=()
    gpu_total=$(nvidia-smi -L | wc -l)
    for ((gpu_idx = 0; gpu_idx < gpu_total; gpu_idx++)); do
        CVD+=("$gpu_idx")
    done
fi
NUM_GPUS=${#CVD[@]}

# if [ "$NUM_GPUS" = "1" ]; then
#     cmd="python ./train_peft.py"
# else
#     cmd="accelerate launch --config_file=default_config.yaml --num_processes $NUM_GPUS --main_process_port 3142$1 ./train_peft.py"
# fi
# ---- Sweep configuration ----------------------------------------------------
model_name="resnet20-cifar"
dataset_name="uoft-cs/cifar10"

# Dataset fields handed to every --*_prompt_field / --*_response_field flag.
prompt_field="img"
response_field="label"

split_ratio=0.1 # 1 out of 10 classes to forget
learning_rate=1e-3

proj_name="cifar"
output_dir_prefix=""

# Forget classes to sweep over; only "0" is enabled.
# Full sweep: ("0" "1" "2" "3" "4" "5" "6" "7" "8" "9")
classes=("0")

# Maximum number of training jobs allowed to run at the same time.
max_processes=2

# ---- Trainer flags shared by every job ---------------------------------------
# Collected as an array, then flattened into the single space-separated string
# `cmd`. Values must not contain whitespace, because the string is word-split
# again when each job's command line is assembled.
shared_flags=(
    # Model, plus the dataset / forget / duplicate sources (same dataset here).
    "--model_name=$model_name"
    "--dataset_name=$dataset_name"
    "--forget_dataset_name=$dataset_name"
    "--duplicate_dataset_name=$dataset_name"
    "--split_ratio=$split_ratio"
    # Optimisation, epochs and batch sizes.
    --learning_rate "$learning_rate"
    --num_train_epochs=50
    --dataset_repeat=50
    --per_device_train_batch_size 5000
    --per_device_eval_batch_size 10000
    # Splits and field mapping for each dataset source.
    --shuffle_dataset
    --dataset_split=train
    --duplicate_dataset_split=test
    --forget_dataset_split=train
    "--dataset_prompt_field=$prompt_field"
    "--forget_dataset_prompt_field=$prompt_field"
    "--duplicate_dataset_prompt_field=$prompt_field"
    "--dataset_response_field=$response_field"
    "--forget_dataset_response_field=$response_field"
    "--duplicate_dataset_response_field=$response_field"
    # Logging, evaluation, checkpointing and the W&B project.
    --eval_on_subsets
    --logging_strategy=steps
    --logging_steps 1
    --eval_strategy=steps
    --eval_steps=0.1
    --save_strategy=steps
    --save_steps=0.1
    --torch_empty_cache_steps 5
    "--wandb_project=train_$proj_name"
)
cmd=" ${shared_flags[*]} "
unset shared_flags

# Launch one training job per (seed, hard_probability, forget_class, split),
# with at most $max_processes jobs running at once.
#
# Reads globals set earlier in the script: cmd (shared trainer flags as a
# whitespace-separated string), classes, proj_name, output_dir_prefix,
# max_processes and NUM_GPUS.
#
# Multi-GPU jobs are started with `accelerate launch`, and two jobs running at
# the same time must not share a --main_process_port. Each concurrency slot
# therefore owns a fixed port (base_port + slot). A slot is handed to a new job
# only after the job that last used it has stopped. Deriving the port from the
# launch counter alone (i % max_processes) is not enough: jobs finish out of
# order, so a new job could receive the port of a job that is still running.

base_port=3140
slot_pids=()

# Split the shared flags into words once; they contain no quoting.
read -r -a base_args <<< "$cmd"

#######################################
# Find a concurrency slot whose previous job is no longer running.
# Globals:   slot_pids, max_processes (read); slot (written)
# Returns:   0 and sets `slot` if a slot is free, 1 if every slot is busy.
#######################################
find_free_slot() {
    local s running
    # PIDs of still-running background jobs, space-delimited for matching.
    running=$(jobs -rp)
    running=" ${running//$'\n'/ } "
    for ((s = 0; s < max_processes; s++)); do
        if [ -z "${slot_pids[s]:-}" ] || [[ "$running" != *" ${slot_pids[s]} "* ]]; then
            slot=$s
            return 0
        fi
    done
    return 1
}

i=0
for seed in 41 42 43; do
    for hard_probability in 0.0 0.25 0.5 0.75 1.0; do
        # When hard_probability == 1.0, run a single group with an empty forget
        # class instead of one group per entry of $classes. The float
        # comparison is done with awk because bash arithmetic is integer-only
        # and bc is often not installed.
        if awk -v p="$hard_probability" 'BEGIN { exit !(p == 1.0) }'; then
            curr_classes=("")
        else
            curr_classes=("${classes[@]}")
        fi
        for forget_class in "${curr_classes[@]}"; do
            run_dir="seed_${seed}/forget_${forget_class}_${hard_probability}"
            # Train the "retain" model and the "full" model (the latter with
            # --add_forget_to_train) for this configuration.
            for split in retain full; do
                job_args=(
                    "${base_args[@]}"
                    --dataset_subset="$forget_class"
                    --forget_dataset_subset="$forget_class"
                    --duplicate_dataset_subset=
                    --shuffle_seed "$seed"
                    --seed "$seed"
                    --hard_probability "$hard_probability"
                    --use_wandb
                    --wandb_run_name="${seed}_${forget_class}_${hard_probability}_${split}"
                    --output_dir="${output_dir_prefix}trained_${proj_name}/${run_dir}/${split}"
                    --final_model_output_dir="trained_${proj_name}/${run_dir}/${split}"
                )
                if [ "$split" = "full" ]; then
                    job_args+=(--add_forget_to_train)
                fi

                echo "$hard_probability $seed $forget_class $split"

                # Stagger startup by min(i*15, 60) seconds.
                if [ "$i" -gt 0 ]; then
                    sleep $((i * 15 < 60 ? i * 15 : 60))
                fi
                i=$((i + 1))

                # Block until fewer than max_processes jobs are running, then
                # take the freed slot (and with it, its port).
                until find_free_slot; do
                    sleep 1
                done

                if [ "$NUM_GPUS" = "1" ]; then
                    launcher=(python ./train_peft.py)
                else
                    launcher=(
                        accelerate launch
                        --config_file=default_config.yaml
                        --num_processes "$NUM_GPUS"
                        --main_process_port "$((base_port + slot))"
                        ./train_peft.py
                    )
                fi

                if [ "$max_processes" -eq 1 ]; then
                    # Sequential mode: show the command, run it in the foreground.
                    echo "${launcher[*]} ${job_args[*]}"
                    "${launcher[@]}" "${job_args[@]}"
                else
                    "${launcher[@]}" "${job_args[@]}" &
                    slot_pids[slot]=$!
                fi
            done
        done
    done
done
wait