#!/bin/sh

# Positional arguments:
#   $1 - dataset identifier; compared below against "merve/lego_sets_latest"
#        and "linoyts/3d_icon", and handed to prepare_dataset.py.
#   $2 - PEFT method name; compared below against lora/oft/hra/hoft/shoft and
#        forwarded to the training script as --peft_type.
# target_folder is the directory passed to prepare_dataset.py as
# --result_folder and used as the prefix of the training --dataset_name.
target_folder="datasets"
dataset_name="$1"
peft="$2"

# Choose the instance prompt, validation prompt and test-prompt file for the
# requested dataset. Only the 3d-icon dataset has its own values; the lego
# dataset and every unrecognised name share the lego-set values.
case "$dataset_name" in
  "linoyts/3d_icon")
    instance_prompt="a TOK 3d icon"
    validation_prompt="a TOK 3d icon of an orange llama eating ramen, in the style of TOK"
    test_path="test_file_icon.txt"
    ;;
  *)
    # Covers "merve/lego_sets_latest" as well as the fallback case.
    instance_prompt="a TOK lego set"
    validation_prompt="a TOK lego set of an orange llama eating ramen, in the style of TOK"
    test_path="test_file_lego.txt"
    ;;
esac

# Choose the rank passed to the trainer as --rank. "hra" is the only PEFT
# method that gets 32; lora, oft, hoft, shoft and any other value get 16.
case "$peft" in
  hra)
    rank=32
    ;;
  *)
    rank=16
    ;;
esac


# Prepare the dataset with prepare_dataset.py before training.
# Reads: dataset_name, target_folder, instance_prompt (set earlier in this
# script). Expansions are quoted so that each value is passed as exactly one
# argument, even when it is empty or contains whitespace/glob characters.
# NOTE(review): the flag is spelled "--huggigface_dataset" (sic); it is kept
# unchanged because it must match whatever prepare_dataset.py's argument
# parser declares, which is not visible here.
# If preparation fails, stop here instead of launching a training run.
if ! python prepare_dataset.py \
  --huggigface_dataset "$dataset_name" \
  --result_folder "$target_folder" \
  --caption_prefix "$instance_prompt" \
  --task "dreambooth"; then
  echo "prepare_dataset.py failed for dataset '$dataset_name'; aborting before training" >&2
  exit 1
fi


# Launch train_dreambooth_sdxl.py through `accelerate launch`, using the
# stabilityai/stable-diffusion-xl-base-1.0 base model and writing to out/sdxl.
# Reads: target_folder, dataset_name, instance_prompt, validation_prompt,
# test_path, rank and peft (all set earlier in this script). Every expansion
# is quoted so each value reaches the trainer as a single argument.
# NOTE(review): --checkpointing_steps (2000) is larger than --max_train_steps
# (1000), so presumably no intermediate checkpoint is written — confirm this
# is intended.
accelerate launch train_dreambooth_sdxl.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
  --dataset_name="$target_folder/$dataset_name" \
  --instance_prompt="$instance_prompt" \
  --validation_prompt="$validation_prompt" \
  --test_prompts_path="$test_path" \
  --num_validation_images="6" \
  --num_test_images="10" \
  --output_dir="out/sdxl" \
  --caption_column="prompt" \
  --mixed_precision="fp16" \
  --resolution=1024 \
  --train_batch_size=4 \
  --report_to="wandb" \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --learning_rate="1e-4" \
  --text_encoder_lr="5e-6" \
  --adam_beta2=0.99 \
  --optimizer="AdamW" \
  --train_text_encoder \
  --snr_gamma=5.0 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --rank="$rank" \
  --peft_type="$peft" \
  --max_train_steps=1000 \
  --checkpointing_steps=2000 \
  --seed="0"