#!/bin/bash

# Built-in defaults. Every value below can be overridden by the matching
# command-line option. MODEL_NAME, PLACEHOLDER_FILE and INITIALIZER_FILE are
# deliberately not assigned here: they have no default in this script.

# Paths that the caller has to supply (--data_root / --output_dir).
DATA_ROOT=
OUTPUT_DIR=

# Training knobs: forwarded as --num_vectors and --max_train_steps.
VECTOR_NUM=1
EPOCHS=1000

# Class window: ST_CLS is the offset added to each class index when picking a
# token line; END_CLS is set by --end_cls.
ST_CLS=0
END_CLS=10

# Value exported as CUDA_VISIBLE_DEVICES for the training process (--cuda).
GPU_ID=0

# Parse command line arguments
#######################################
# Parse "--option value" pairs into the script's global settings.
# Globals:   DATA_ROOT, VECTOR_NUM, EPOCHS, MODEL_NAME, OUTPUT_DIR,
#            PLACEHOLDER_FILE, INITIALIZER_FILE, ST_CLS, END_CLS,
#            GPU_ID (written)
# Arguments: the script's command line, passed as "$@"
# Outputs:   a usage error on stderr for an unknown option or for an
#            option that is not followed by a value
# Returns:   0 on success; exits the script with status 1 on a usage error
#######################################
parse_args() {
    while (( $# > 0 )); do
        case "$1" in
            --data_root)        DATA_ROOT="${2-}" ;;
            --vector_num)       VECTOR_NUM="${2-}" ;;
            --epochs)           EPOCHS="${2-}" ;;
            --model_name)       MODEL_NAME="${2-}" ;;
            --output_dir)       OUTPUT_DIR="${2-}" ;;
            --placeholder_file) PLACEHOLDER_FILE="${2-}" ;;
            --initializer_file) INITIALIZER_FILE="${2-}" ;;
            --st_cls)           ST_CLS="${2-}" ;;
            --end_cls)          END_CLS="${2-}" ;;
            --cuda)             GPU_ID="${2-}" ;;
            *)
                echo "Unknown parameter passed: $1" >&2
                exit 1
                ;;
        esac

        # Every option takes exactly one value. Without this check a trailing
        # "--epochs" would silently overwrite the default with an empty string.
        if (( $# < 2 )); then
            echo "Missing value for parameter: $1" >&2
            exit 1
        fi
        shift 2
    done
}

parse_args "$@"

# Validate the required settings before any work starts. All diagnostics go to
# stderr and the script exits with status 1 on the first failed check.
missing_opts=()
[[ -n "$DATA_ROOT" ]]        || missing_opts+=("--data_root")
[[ -n "$OUTPUT_DIR" ]]       || missing_opts+=("--output_dir")
[[ -n "$PLACEHOLDER_FILE" ]] || missing_opts+=("--placeholder_file")
[[ -n "$INITIALIZER_FILE" ]] || missing_opts+=("--initializer_file")
if (( ${#missing_opts[@]} > 0 )); then
    echo "Data root, output directory, placeholder file, and initializer file are required." >&2
    # Name the options that are actually missing so the caller can fix them.
    echo "Missing: ${missing_opts[*]}" >&2
    exit 1
fi

# MODEL_NAME has no default and is forwarded to the training script as
# --pretrained_model_name_or_path, so an empty value cannot work.
if [[ -z "$MODEL_NAME" ]]; then
    echo "Model name is required (--model_name)." >&2
    exit 1
fi

# The class list is taken from the entries of DATA_ROOT, so it has to exist.
if [[ ! -d "$DATA_ROOT" ]]; then
    echo "Data root is not a directory: $DATA_ROOT" >&2
    exit 1
fi

# Read data from placeholder and initializer files into arrays: one token per
# line, trailing newline stripped (-t). The arrays are indexed later by
# (class index + ST_CLS). A failed read (missing or unreadable file) or an
# empty list aborts the script instead of continuing with empty tokens.
if ! readarray -t placeholder_lst < "${PLACEHOLDER_FILE}"; then
    echo "Failed to read placeholder file: ${PLACEHOLDER_FILE}" >&2
    exit 1
fi
if ! readarray -t initializer_lst < "${INITIALIZER_FILE}"; then
    echo "Failed to read initializer file: ${INITIALIZER_FILE}" >&2
    exit 1
fi
if (( ${#placeholder_lst[@]} == 0 || ${#initializer_lst[@]} == 0 )); then
    echo "Placeholder and initializer files must each contain at least one token." >&2
    exit 1
fi

# Iterate over the classes
#
# Every non-hidden entry of DATA_ROOT is treated as one class. The entries are
# sorted, and the class at position i is paired with element (i + ST_CLS) of
# placeholder_lst and initializer_lst, so the ordering decides which tokens a
# class is trained with.
# NOTE(review): END_CLS (--end_cls) is parsed but not consulted here; every
# entry of DATA_ROOT is trained. Confirm whether an upper bound was intended.

# Expand with a glob instead of parsing `ls`, so names containing whitespace
# or glob characters stay intact. An unmatched glob is left as the literal
# pattern, which the existence test below filters out.
class_paths=( "$DATA_ROOT"/* )
cls_lst=()
if [[ -e "${class_paths[0]}" || -L "${class_paths[0]}" ]]; then
    # Keep only the entry names and sort them explicitly with `sort`.
    mapfile -t cls_lst < <(printf '%s\n' "${class_paths[@]##*/}" | sort)
else
    echo "No class entries found in $DATA_ROOT" >&2
fi

# Pre-flight check: make sure every class has both tokens before the first
# (long-running) training job is started.
for (( cls_ind = 0; cls_ind < ${#cls_lst[@]}; cls_ind++ )); do
    SHIFTED_CLS=$(( cls_ind + ST_CLS ))
    if [[ -z "${placeholder_lst[$SHIFTED_CLS]}" || -z "${initializer_lst[$SHIFTED_CLS]}" ]]; then
        echo "No placeholder/initializer token at index $SHIFTED_CLS for class '${cls_lst[$cls_ind]}'." >&2
        exit 1
    fi
done

failed_classes=()
for (( cls_ind = 0; cls_ind < ${#cls_lst[@]}; cls_ind++ )); do
    cls_name="${cls_lst[$cls_ind]}"
    out_dir="$OUTPUT_DIR/$cls_name"
    train_dir="$DATA_ROOT/$cls_name"
    SHIFTED_CLS=$(( cls_ind + ST_CLS ))

    if ! mkdir -p -- "$out_dir"; then
        echo "Could not create output directory: $out_dir" >&2
        exit 1
    fi

    echo "initial token: ${initializer_lst[$SHIFTED_CLS]} check"
    echo "placeholder token: ${placeholder_lst[$SHIFTED_CLS]} check"
    echo "train_dir: $train_dir check"
    echo "shifted_cls: $SHIFTED_CLS check"

    # Construct and execute the command.
    # The command is an argv array rather than a string passed to eval, so
    # every value reaches the training script as exactly one argument, even
    # when it contains spaces, quotes or other shell metacharacters.
    train_cmd=(
        accelerate launch textual_inversion.py
        "--pretrained_model_name_or_path=$MODEL_NAME"
        "--train_data_dir=$train_dir"
        "--learnable_property=object"
        "--placeholder_token=${placeholder_lst[$SHIFTED_CLS]}"
        "--initializer_token=${initializer_lst[$SHIFTED_CLS]}"
        "--resolution=512"
        "--train_batch_size=1"
        "--gradient_accumulation_steps=4"
        "--max_train_steps=$EPOCHS"
        "--learning_rate=5.0e-04"
        "--scale_lr"
        "--lr_scheduler=constant"
        "--lr_warmup_steps=0"
        "--output_dir=$out_dir"
        "--save_steps=500"
        "--num_vectors=$VECTOR_NUM"
    )

    echo "Executing command for $cls_name"
    # A failed class does not stop the remaining classes; it is recorded and
    # reported once the loop has finished.
    if ! CUDA_VISIBLE_DEVICES="$GPU_ID" "${train_cmd[@]}"; then
        echo "Training failed for $cls_name" >&2
        failed_classes+=("$cls_name")
    fi
done

# Exit non-zero when any class failed, not only when the last one did.
if (( ${#failed_classes[@]} > 0 )); then
    echo "Training failed for ${#failed_classes[@]} class(es): ${failed_classes[*]}" >&2
    exit 1
fi