# Adapted from launchers/sd15.sh for SDXL training (see MODEL_NAME and --sdxl below).
# ---------------------------------------------------------------------------
# Model / dataset identifiers (exported so child processes can read them).
# ---------------------------------------------------------------------------
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export VAE="madebyollin/sdxl-vae-fp16-fix"
export DATASET_NAME="yuvalkirstain/pickapic_v2"

# First positional argument; not referenced by the launch code in this file.
input_arg="${1}"

# ---------------------------------------------------------------------------
# Hardware / distributed launch.
# ---------------------------------------------------------------------------
num_process=2   # value for accelerate --num_processes
gpu_ids="0,1"   # used as CUDA_VISIBLE_DEVICES for each launch
port_num=25600  # value for accelerate --main_process_port

# ---------------------------------------------------------------------------
# Paths (placeholders -- replace before running).
# ---------------------------------------------------------------------------
cache_dir="PATH_FOR_TRAINING_DATASET"
base_data_dir="PATH_FOR_TRAINING_DATA"
base_output_dir="PATH_FOR_OUTPUT_DIR"
image_to_reward_path="${base_data_dir}/PATH_FOR_REWARD_OF_EACH_IMAGE"

# ---------------------------------------------------------------------------
# Training hyper-parameters.
# ---------------------------------------------------------------------------
batch_size=1

# Learning-rate scheduler.
lr_scheduler="piecewise_constant"
rule="1:200,0.1"  # passed verbatim as --lr_scheduler_rule
warmup_step=5

# Sweep configurations: lr_list[i] is paired with accum_list[i], so both
# arrays must have the same number of entries.
lr_list=(2e-8)
accum_list=(64)

# Run length.
max_train_steps=100
checkpointing_steps=50

# ---------------------------------------------------------------------------
# Inputs and resume.
# ---------------------------------------------------------------------------
# Each entry selects the file "${base_data_dir}/<entry>.json".
train_file_list=(
  "TRAIN_FILE_NAME"
)
# Passed as --resume_from_checkpoint even when empty; example: "checkpoint-10000".
resume_from_checkpoint=""


#######################################
# Print the output directory for one training run.
# Arguments:
#   $1 - base output directory
#   $2 - model tag (e.g. "sdxl_<train_file>")
#   $3 - learning rate
#   $4 - gradient-accumulation steps
#   $5 - number of (lr, accum) configs in the sweep
# Outputs:
#   the directory path on stdout
# With a single config the original "<base>/<tag>" path is kept, so runs
# (and --resume_from_checkpoint lookups) use the same location as before.
# With several configs the lr/accum values are appended so that runs in the
# same sweep do not overwrite each other's outputs.
#######################################
sweep_output_dir() {
  local base=$1 tag=$2 lr=$3 accum=$4 n_configs=$5
  if (( n_configs > 1 )); then
    printf '%s/%s_lr%s_accum%s\n' "$base" "$tag" "$lr" "$accum"
  else
    printf '%s/%s\n' "$base" "$tag"
  fi
}

# lr_list and accum_list are parallel arrays (entry i of each forms one
# config). A length mismatch would otherwise launch a run with an empty
# --gradient_accumulation_steps value.
if (( ${#lr_list[@]} != ${#accum_list[@]} )); then
  printf 'error: lr_list has %d entries but accum_list has %d\n' \
    "${#lr_list[@]}" "${#accum_list[@]}" >&2
  exit 1
fi

# Launch one job per (train file, sweep config) pair, sequentially.
# A failed job is reported and the sweep continues; the script exits
# non-zero at the end if any job failed.
failed_runs=()
for train_file in "${train_file_list[@]}"; do
  for i in "${!lr_list[@]}"; do
    lr=${lr_list[$i]}
    accum_step=${accum_list[$i]}
    train_data_path="${base_data_dir}/${train_file}.json"
    model_type="sdxl_${train_file}"
    output_dir=$(sweep_output_dir "$base_output_dir" "$model_type" \
      "$lr" "$accum_step" "${#lr_list[@]}")

    if ! CUDA_VISIBLE_DEVICES="${gpu_ids}" accelerate launch \
      --main_process_port="${port_num}" \
      --num_processes="${num_process}" \
      train.py \
      --pretrained_model_name_or_path="${MODEL_NAME}" \
      --pretrained_vae_model_name_or_path="${VAE}" \
      --dataset_name="${DATASET_NAME}" \
      --train_batch_size="${batch_size}" \
      --dataloader_num_workers=16 \
      --gradient_accumulation_steps="${accum_step}" \
      --max_train_steps="${max_train_steps}" \
      --lr_scheduler="${lr_scheduler}" \
      --learning_rate="${lr}" --scale_lr \
      --checkpointing_steps="${checkpointing_steps}" \
      --beta_dpo 5000 \
      --cache_dir="${cache_dir}" \
      --output_dir="${output_dir}" \
      --train_data_path="${train_data_path}" \
      --resume_from_checkpoint="${resume_from_checkpoint}" \
      --image_to_reward_path="${image_to_reward_path}" \
      --lr_warmup_steps="${warmup_step}" \
      --lr_scheduler_rule="${rule}" \
      --sdxl; then
      printf 'warn: training failed for %s (lr=%s, accum=%s)\n' \
        "$train_file" "$lr" "$accum_step" >&2
      failed_runs+=("${train_file}:lr=${lr}:accum=${accum_step}")
    fi
  done
done

if (( ${#failed_runs[@]} > 0 )); then
  printf 'error: %d run(s) failed: %s\n' \
    "${#failed_runs[@]}" "${failed_runs[*]}" >&2
  exit 1
fi