# from launchers/sd15.sh
# Configuration for the Stable Diffusion 1.5 DPO training launcher.
# Replace every PATH_FOR_* / TRAIN_FILE_NAME placeholder before running.

# -u: abort on a reference to an unset (e.g. misspelled) variable instead of
#     silently expanding it to an empty string.
# -e is intentionally NOT set: a failed run for one entry of train_file_list
#     must not prevent the remaining entries from being launched.
set -uo pipefail

export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export DATASET_NAME="yuvalkirstain/pickapic_v2"

# Optional first positional argument. It is not referenced by this launcher;
# kept so existing invocations that pass an argument keep working.
input_arg="${1:-}"

# system
gpu_ids="0,1"
port_num=25600
# Launch one process per GPU listed in gpu_ids, so the process count cannot
# drift out of sync when the GPU list is edited.
IFS=',' read -r -a gpu_id_list <<< "$gpu_ids"
num_process=${#gpu_id_list[@]}
if (( num_process < 1 )); then
  printf 'error: gpu_ids must list at least one GPU id\n' >&2
  exit 1
fi

# path (placeholders -- set these before launching)
# Passed to train.py as --cache_dir.
cache_dir="PATH_FOR_TRAINING_DATASET"
# Directory holding the <train_file>.json training files.
base_data_dir="PATH_FOR_TRAINING_DATA"
# Each run writes to ${base_output_dir}/sd15_<train_file>.
base_output_dir="PATH_FOR_OUTPUT_DIR"
image_to_reward_path="${base_data_dir}/PATH_FOR_REWARD_OF_EACH_IMAGE"

# train hyperparam
# Passed as --train_batch_size (presumably per device -- confirm in train.py).
batch_size=8

# scheduler
lr_scheduler="piecewise_constant"
# Step rules, presumably in diffusers' piecewise_constant format
# "multiplier:until_step,...,final_multiplier": x1 until step 200,
# x0.25 until step 400, x0.1 afterwards -- confirm against train.py.
rule="1:200,0.25:400,0.1"
warmup_step=10
# Base learning rate. train.py is also given --scale_lr; if it follows the
# usual diffusers convention (verify), the effective LR is
# lr * accum_step * batch_size * num_process.
lr=1e-7
accum_step=8

# steps
max_train_steps=500 # alternative value: 1000
checkpointing_steps=100

# Basenames (without the .json extension) of files under base_data_dir;
# one training run is launched per entry.
train_file_list=(
  "TRAIN_FILE_NAME"
)
# Passed verbatim as --resume_from_checkpoint; empty by default.
# Example value: "checkpoint-10000".
resume_from_checkpoint=""


# Launch one multi-GPU training run per entry of train_file_list, sequentially.
# All expansions are quoted so paths containing spaces or glob characters reach
# train.py intact. A failed run is reported on stderr and the remaining entries
# are still attempted; the script exits non-zero if any run failed (previously
# only the last run's status was visible).
failed_runs=()
for train_file in "${train_file_list[@]}"; do
  train_data_path="${base_data_dir}/${train_file}.json"
  model_type="sd15_${train_file}"
  output_dir="${base_output_dir}/${model_type}"

  # Arguments forwarded to train.py, kept in an array so each value stays one word.
  train_args=(
    --pretrained_model_name_or_path="$MODEL_NAME"
    --dataset_name="$DATASET_NAME"
    --train_batch_size="$batch_size"
    --dataloader_num_workers=16
    --gradient_accumulation_steps="$accum_step"
    --max_train_steps="$max_train_steps"
    --lr_scheduler="$lr_scheduler"
    --learning_rate="$lr" --scale_lr
    --checkpointing_steps="$checkpointing_steps"
    --beta_dpo 5000
    --cache_dir="$cache_dir"
    --output_dir="$output_dir"
    --train_data_path="$train_data_path"
    # Always passed, even when empty, exactly as before.
    --resume_from_checkpoint="$resume_from_checkpoint"
    --image_to_reward_path="$image_to_reward_path"
    --lr_warmup_steps="$warmup_step"
    --lr_scheduler_rule="$rule"
  )

  if ! CUDA_VISIBLE_DEVICES="$gpu_ids" accelerate launch \
      --mixed_precision="fp16" \
      --main_process_port="$port_num" \
      --num_processes="$num_process" \
      --gpu_ids="$gpu_ids" \
      train.py "${train_args[@]}"; then
    printf 'error: training failed for train file: %s\n' "$train_file" >&2
    failed_runs+=("$train_file")
  fi
done

if (( ${#failed_runs[@]} > 0 )); then
  printf 'error: %d training run(s) failed: %s\n' \
    "${#failed_runs[@]}" "${failed_runs[*]}" >&2
  exit 1
fi