#!/usr/bin/env bash
# Ablation driver. For every subject in SUBJECTS it:
#   trains a reference DreamBooth model on the clean set_A images,
#   runs attacks/fsmg_cE_promptDist.py against that model,
#   applies the resulting perturbation with add_perturbation.py,
#   retrains DreamBooth on the perturbed images and runs infer.py on it.
#
# Usage: DATASET_ROOT=/path/to/dataset bash <this-script>
#   DATASET_ROOT  directory holding one <subject_id>/set_A folder per subject.
#                 It can be passed via the environment or set below.
set -u

# Set common variables
export CUDA_VISIBLE_DEVICES=2   # GPU index is hard-coded; change per machine.
export EXPERIMENT_NAME="ABLATION"
export MODEL_PATH="stabilityai/stable-diffusion-2-1-base"
export CLASS_DIR="data/class-person"
export BASE_OUTPUT_DIR="dreambooth-outputs/"
# Array of subject IDs to process

SUBJECTS=(
"n000050" "n000061" "n000076" "n000088" "n000097" "n000104" "n000138" "n000145" "n000154" "n000170" "n000176" "n000181" "n000187" "n000215" "n000221" "n000228" "n000238"
"n000057" "n000063" "n000080" "n000089" "n000098" "n000105" "n000139" "n000146"
)
# Alternative subject list, kept for reference; swap it in (and comment out
# the list above) to process these subjects instead.
# SUBJECTS=(
# "n000161" "n000171" "n000179" "n000184" "n000188" "n000217" "n000223" "n000234" "n000243"
# "n000058" "n000068" "n000087" "n000090" "n000103" "n000110" "n000142" "n000150" "n000164" "n000172" "n000180" "n000185" "n000190" "n000220" "n000225" "n000236"
# )

# Root of the per-subject dataset. It may be supplied from the environment.
# If it is left empty, the input paths become "/<subject_id>/set_A", so warn loudly.
DATASET_ROOT="${DATASET_ROOT:-}"
if [[ -z "$DATASET_ROOT" ]]; then
    echo "warning: DATASET_ROOT is empty; input paths will resolve under /" >&2
fi
# Name of the attack variant; used as a sub-directory in every output path.
subdir="promptDist_1e-3"

#######################################
# Fine-tune DreamBooth with the hyper-parameters shared by both training runs
# in this script: prior preservation, lr 5e-7, 1000 steps, bf16.
# Globals:   MODEL_PATH, CLASS_DIR (read)
# Arguments: $1 - instance image directory
#            $2 - output directory. Later steps load <dir>/checkpoint-1000,
#                 which --checkpointing_steps=1000 is expected to produce.
# Returns:   exit status of `accelerate launch`
#######################################
run_dreambooth() {
    local instance_dir=$1
    local output_dir=$2
    accelerate launch train_dreambooth.py \
        --pretrained_model_name_or_path="$MODEL_PATH" \
        --enable_xformers_memory_efficient_attention \
        --train_text_encoder \
        --instance_data_dir="$instance_dir" \
        --class_data_dir="$CLASS_DIR" \
        --output_dir="$output_dir" \
        --with_prior_preservation \
        --prior_loss_weight=1.0 \
        --instance_prompt="a photo of sks person" \
        --class_prompt="a photo of person" \
        --inference_prompt="a photo of sks person;a dslr portrait of sks person" \
        --resolution=512 \
        --train_batch_size=1 \
        --gradient_accumulation_steps=1 \
        --learning_rate=5e-7 \
        --lr_scheduler="constant" \
        --lr_warmup_steps=0 \
        --num_class_images=200 \
        --max_train_steps=1000 \
        --checkpointing_steps=1000 \
        --center_crop \
        --mixed_precision=bf16 \
        --prior_generation_precision=bf16 \
        --sample_batch_size=16 \
        --gradient_checkpointing \
        --set_grads_to_none
}

#######################################
# Run the full pipeline for one subject. Stops at the first failing step, so
# later steps never run on missing artifacts. The DreamBooth checkpoint is
# deleted only after inference has succeeded.
# Globals:   DATASET_ROOT, EXPERIMENT_NAME, BASE_OUTPUT_DIR, subdir (read);
#            CLEAN_TRAIN_DIR, REF_MODEL_PATH, CLEAN_ADV_DIR, OUTPUT_DIR,
#            INSTANCE_DIR, DREAMBOOTH_OUTPUT_DIR (exported, as before, in
#            case the Python scripts read them from the environment; this
#            is unverified, so TODO confirm)
# Arguments: $1 - subject ID
# Returns:   0 on success, 1 if any step failed
#######################################
process_subject() {
    local subject_id=$1

    export CLEAN_TRAIN_DIR="$DATASET_ROOT/$subject_id/set_A"
    export REF_MODEL_PATH="outputs/PRE_EXP/${subject_id}_REFERENCE"

    # With an empty or wrong DATASET_ROOT this resolves to e.g. /n000050/set_A.
    # Fail early instead of launching a long training job on a missing dir.
    if [[ ! -d "$CLEAN_TRAIN_DIR" ]]; then
        echo "Clean image directory not found: $CLEAN_TRAIN_DIR (check DATASET_ROOT)" >&2
        return 1
    fi

    echo "Training DreamBooth model on set A for $subject_id"
    run_dreambooth "$CLEAN_TRAIN_DIR" "$REF_MODEL_PATH" || return 1

    echo "Training FSMG on set A for $subject_id"
    export CLEAN_ADV_DIR="$DATASET_ROOT/$subject_id/set_A"
    export OUTPUT_DIR="outputs/$EXPERIMENT_NAME/${subject_id}_ADVERSARIAL/$subdir"
    mkdir -p -- "$OUTPUT_DIR" || return 1

    # The attack starts from the reference model trained above.
    accelerate launch attacks/fsmg_cE_promptDist.py \
        --pretrained_model_name_or_path="$REF_MODEL_PATH/checkpoint-1000" \
        --enable_xformers_memory_efficient_attention \
        --train_text_encoder \
        --instance_data_dir="$CLEAN_ADV_DIR" \
        --output_dir="$OUTPUT_DIR" \
        --instance_prompt="a photo of sks person" \
        --resolution=512 \
        --gradient_accumulation_steps=1 \
        --max_train_steps=200 \
        --checkpointing_steps=50 \
        --center_crop \
        --pgd_alpha=5e-3 \
        --inner_iters=10 \
        --mixed_precision=bf16 \
        --gradient_checkpointing || return 1

    # Directory that receives the perturbed images for this subject.
    export INSTANCE_DIR="outputs/$EXPERIMENT_NAME/${subject_id}_ADVERSARIAL/set_C/$subdir"

    # NOTE: add_perturbation.py is given only the subject ID. Its source-image
    # dataset path is hard-coded inside the script, so edit it there whenever
    # you switch datasets.
    python add_perturbation.py \
        --subject_id="$subject_id" \
        --perturbation_path="$OUTPUT_DIR/final/universal_perturbation.pt" \
        --output_dir="$INSTANCE_DIR" || return 1

    # Output directory for the DreamBooth run on the perturbed images.
    export DREAMBOOTH_OUTPUT_DIR="$BASE_OUTPUT_DIR/${subject_id}/${subdir}"
    mkdir -p -- "$DREAMBOOTH_OUTPUT_DIR" || return 1

    run_dreambooth "$INSTANCE_DIR" "$DREAMBOOTH_OUTPUT_DIR" || return 1

    echo "Running inference for $subject_id"
    python infer.py \
        --model_path "$DREAMBOOTH_OUTPUT_DIR/checkpoint-1000" \
        --output_dir "$DREAMBOOTH_OUTPUT_DIR/results" \
        --inference_prompt="a photo of sks person;a dslr portrait of sks person;a photo of sks person looking at the mirror" \
        || return 1

    # Reached only when inference succeeded, so the model is never lost before
    # results exist. ${var:?} stops rm from targeting /checkpoint-1000.
    echo "deleting model for saving storage"
    rm -rf -- "${DREAMBOOTH_OUTPUT_DIR:?}/checkpoint-1000" "${DREAMBOOTH_OUTPUT_DIR:?}/logs"
}

# A failing subject is reported and skipped; the batch keeps going.
failed_subjects=()
for SUBJECT_ID in "${SUBJECTS[@]}"; do
    echo "Processing subject: $SUBJECT_ID"
    if process_subject "$SUBJECT_ID"; then
        echo "Completed processing for $SUBJECT_ID"
    else
        echo "FAILED processing for $SUBJECT_ID; continuing with next subject" >&2
        failed_subjects+=("$SUBJECT_ID")
    fi
    echo "----------------------------------------"
done

echo "All subjects have been processed"
if (( ${#failed_subjects[@]} > 0 )); then
    echo "Failed subjects (${#failed_subjects[@]}): ${failed_subjects[*]}" >&2
    exit 1
fi