# Resolve the GPU ids this sweep may use: honour CUDA_VISIBLE_DEVICES when it
# is non-empty (comma-separated ids), otherwise enumerate the GPUs reported by
# nvidia-smi. `read -a` splits on whitespace without pathname expansion,
# unlike an unquoted array assignment.
read -r -a CVD <<< "${CUDA_VISIBLE_DEVICES//,/ }"
if [ -z "${CUDA_VISIBLE_DEVICES:-}" ]; then
    gpu_count=0
    if command -v nvidia-smi >/dev/null 2>&1; then
        # Count only top-level "GPU N: ..." lines; MIG instances are listed as
        # extra indented lines that a plain `wc -l` would also count.
        gpu_count=$(nvidia-smi -L | grep -c '^GPU ')
    else
        echo "nvidia-smi not found; cannot detect GPUs (set CUDA_VISIBLE_DEVICES)." >&2
    fi
    CVD=()
    for ((gpu_idx = 0; gpu_idx < gpu_count; gpu_idx++)); do
        CVD+=("$gpu_idx")
    done
fi
NUM_GPUS=${#CVD[@]}

# Unlearning needs an even number of GPUs, and at least two.
if [ "$NUM_GPUS" -le 1 ]; then
    echo "Requires at least 2 GPUs for unlearning." >&2
    exit 1
elif [ $((NUM_GPUS % 2)) -ne 0 ]; then
    echo "Requires GPU count to be even for unlearning." >&2
    exit 1
fi

# Exported so the launched training processes inherit it.
# NOTE(review): presumably read by train_peft.py to enable LLM training — confirm there.
export TRAIN_LLM=true

# Base accelerate launcher. The optional script argument $1 is glued onto the
# port (2142$1), so only an empty or single-digit suffix keeps it <= 65535.
cmd="accelerate launch \
    --config_file=default_config.yaml \
    --main_process_port 2142$1 \
    --num_processes $NUM_GPUS ./train_peft.py"

# One dataset/subset feeds the retain, forget and semantic-duplicate splits.
# Alternative dataset: dataset_name="locuslab/TOFU"
dataset_name="dataset/Glow-AI/WaterDrum-TOFU"
dataset_subset="forget_10"
# Prepended verbatim to the output directory in run_cmd (empty = relative path).
output_dir_prefix=""

cmd="$cmd \
    --dataset_name=$dataset_name \
    --dataset_subset=$dataset_subset \
    --dataset_split=retain \
    --dataset_prompt_field=question \
    --dataset_response_field=answer \
    --forget_dataset_name=$dataset_name \
    --forget_dataset_subset=$dataset_subset \
    --forget_dataset_split=forget \
    --forget_dataset_prompt_field=question \
    --forget_dataset_response_field=answer \
    --duplicate_dataset_name=$dataset_name \
    --duplicate_dataset_subset=$dataset_subset \
    --duplicate_dataset_split=semantic_duplicate \
    --duplicate_dataset_prompt_field=question \
    --duplicate_dataset_response_field=answer \
    --eval_on_subsets \
    --shuffle_dataset \
    --use_lr \
    --per_device_train_batch_size 50 \
    --per_device_eval_batch_size 200 \
    --eval_strategy no \
    --max_grad_norm 1.0 "
cmd="$cmd --use_wandb"
cmd="$cmd --wandb_project=unlearn_TOFU"
# cmd="$cmd --torch_empty_cache_steps 1 "
cmd="$cmd --eval_on_start"
cmd="$cmd --use_optimizer"

#######################################
# Launch one unlearning run.
# Globals:
#   curr_cmd (read)                - launcher plus shared flags for this model type
#   curr_postfix / postfix (read)  - suffix appended to the W&B run name
#   dataset_subset, output_dir_prefix (read)
# Arguments:
#   $1 - model type (last path component under trained_models/seed_<seed>/<subset>/)
#   $2 - seed (also used as --shuffle_seed)
#   $3 - learning rate
#   $4 - unlearning method, passed as --use_<method>
# Outputs:
#   A progress line on stdout, then whatever the launched command prints.
# Returns:
#   Exit status of the launched command.
#######################################
run_cmd() {
    local model_type=$1 seed=$2 lr=$3 method=$4
    local final_cmd run_postfix
    # The sweep loop has historically set `postfix` while this function read
    # `curr_postfix`, silently dropping the run-name suffix; accept either.
    run_postfix=${curr_postfix:-${postfix:-}}

    echo "Running unlearning with seed=$seed, lr=$lr, method=$method, model_type=$model_type"
    final_cmd="$curr_cmd \
        --use_$method \
        --seed $seed \
        --shuffle_seed $seed \
        --model_name=trained_models/seed_$seed/$dataset_subset/$model_type \
        --output_dir=${output_dir_prefix}trained_models/seed_$seed/$dataset_subset/$model_type \
        --learning_rate $lr \
        --wandb_run_name=${seed}_lr${lr}_${method}${run_postfix} "
    # Every method other than gd and ga also interlaces the forget set.
    case $method in
        gd | ga) ;;
        *) final_cmd="$final_cmd --interlace_forget" ;;
    esac
    final_cmd="$final_cmd --num_train_epochs 10 "

    # Squeeze the indentation left by the multi-line literals so the command
    # reads cleanly if logged; quoting avoids accidental glob expansion.
    final_cmd=$(tr -s ' ' <<< "$final_cmd")
    # printf '%s\n' "$final_cmd"
    # NOTE(review): the script's $1 flows into this string (port, seed); eval
    # would execute any shell syntax it contains — keep $1 trusted/numeric.
    eval "$final_cmd"
}

# Sweep: model type x seed x unlearning method x learning rate.
# The seed is "4" followed by the optional script argument $1 (the same suffix
# that is glued onto the launcher port).
for model_type in "full"; do  # also available: "full_semanticdup"
    if [ "$model_type" = "full" ]; then
        curr_cmd="$cmd \
        --logging_steps 4 \
        --eval_steps 8 \
        --longer_eval_steps 72 "
        curr_postfix=""
    else
        curr_cmd="$cmd --add_duplicate_to_retain \
        --logging_steps 4 \
        --eval_steps 8 \
        --longer_eval_steps 80 "
        curr_postfix="_sem"
    fi
    for seed in "4$1"; do
        # Previously `method="gd" "ga" ...`, which ran a command named `ga`
        # with method=gd only in its environment and never looped.
        for method in gd ga gdiff kl scrub; do
            for lr in 1e-5 1e-4 1e-3; do
                run_cmd "$model_type" "$seed" "$lr" "$method"
            done
        done
    done
done