#!/bin/bash

# Abort on command errors, on use of unset variables, and on failures in any
# stage of a pipeline, so a broken step is never reported as a success.
set -euo pipefail

# adjust these parameters
export CUDA_VISIBLE_DEVICES=0,1,2,3
experiment_NAME="pgd_FARE"                 # base run name; don't add model or dataset name (model and dataset already appear in the output_dir path)
dataset_NAME="imagenet"                    # imagenet or imagenet100
model_NAME="dinov2_vits14_reg_lc"          # dinov2_vit{s,b,l,g}14, or a suffixed variant such as dinov2_vits14_reg_lc
mode="final"                               # setup (wandb logging off) or final (wandb logging on)
loss_type="l2"                             # l2, ce, or kl

# Derive the wandb flag passed to the trainer from the run mode:
# "setup" runs are not logged to wandb, "final" runs are.
# Any other mode is a configuration error and aborts the script.
case "$mode" in
    setup) wandb=False ;;
    final) wandb=True ;;
    *)
        echo "Unknown mode: '$mode' (expected 'setup' or 'final')" >&2
        exit 1
        ;;
esac

#######################################
# Print the per-device batch size for a DINOv2 model name.
# The ViT-L/14 and ViT-g/14 backbones get a smaller batch (presumably to fit
# in GPU memory). Prefix matching ensures suffixed variants such as
# dinov2_vitl14_reg_lc are sized like their base model; exact equality
# would silently give them the larger batch.
# Arguments: $1 - model name
# Outputs:   batch size on stdout
#######################################
batch_size_for_model() {
    case "$1" in
        dinov2_vitg14* | dinov2_vitl14*) echo 32 ;;
        *) echo 128 ;;
    esac
}

per_device_batch_size=$(batch_size_for_model "$model_NAME")

# Resolve the on-disk root of the selected dataset (passed to the trainer as
# --imagenet_root). Replace /YOUR_ROOT_PATH with your local data location.
# An unrecognized dataset name is a configuration error and aborts the script.
case "$dataset_NAME" in
    imagenet)    dataset_root="/YOUR_ROOT_PATH/data/ILSVRC/Data/CLS-LOC" ;;
    imagenet100) dataset_root="/YOUR_ROOT_PATH/data/imagenet100/data" ;;
    *)
        echo "Unknown dataset name: $dataset_NAME" >&2
        exit 1
        ;;
esac

echo 'Start training with DinoV2!'

# Attack budgets to sweep; each value is passed to the trainer as --eps.
# (The separate --rho flag below is fixed at 0.15 and is not swept.)
eps_values=(8)

# eps values whose training run exited with a non-zero status.
failed_eps=()

for eps in "${eps_values[@]}"; do
    EXPERIMENT_NAME="${experiment_NAME}_eps${eps}"

    echo "Starting run with eps = ${eps}"
    echo "Experiment name: ${EXPERIMENT_NAME}"

    # Build the argument list as an array so every value is passed as exactly
    # one argument, even if a path contains spaces or glob characters.
    train_args=(
        --dinov2_model "$model_NAME"
        --pretrained True
        --dataset "$dataset_NAME"
        --imagenet_root "$dataset_root"
        --template std
        --output_normalize False
        --steps 21000
        --warmup 1925
        --per_device_batch_size "$per_device_batch_size"
        --loss "$loss_type"
        --loss_clean "$loss_type"
        --clean_weight 0.
        --inner_loss "$loss_type"
        --opt adamw
        --lr 1e-5
        --wd 1e-4
        --attack pgd
        --norm linf
        --rho 0.15
        --eps "$eps"
        --iterations_adv 10
        --stepsize_adv 1
        --wandb "$wandb"
        --output_dir "/YOUR_ROOT_PATH/checkpoints/${mode}/${model_NAME}-${dataset_NAME}"
        --experiment_name "$EXPERIMENT_NAME"
        --log_freq 10
        --eval_freq 10
        --online True
        --k_iter 5
        --lambda_net linear_mlp
        --anchor_option orig
        --grad_norm 1.0
        --lambda_lr 5e-4
        --rho_scheduler const
        --return_clean_dis_freq 40
        --lagrangian_type scalar
        # The head checkpoint is always read from the final/ tree, regardless of $mode.
        --checkpoint_dir_head "/YOUR_ROOT_PATH/checkpoints/final/dino_final/${model_NAME}/clean/dino_finetuned_${dataset_NAME}_head.pth"
    )

    # Run inside `if` so a failed run is recorded (and does not trip `set -e`)
    # instead of being reported as finished; remaining eps values still run.
    if python3 /YOUR_ROOT_PATH/src/train/dino_training/adversarial_FARE_DinoV2.py "${train_args[@]}"; then
        echo "Finished run with eps = ${eps}"
    else
        status=$?
        echo "Run with eps = ${eps} failed (exit status ${status})" >&2
        failed_eps+=("$eps")
    fi
    echo "-----------------------------------"
done

# Only claim success when every run exited cleanly; otherwise signal failure.
if (( ${#failed_eps[@]} > 0 )); then
    echo "DinoV2 runs failed for eps = ${failed_eps[*]}" >&2
    exit 1
fi

echo 'All DinoV2 runs completed successfully'
