#!/usr/bin/env bash
#
# Launch train.py under torchrun for concept-forgetting training on top of
# runwayml/stable-diffusion-v1-5: CONCEPT_TO_FORGET is the concept to remove and
# CONCEPT_TO_OVERRIDE the concept passed as its override (see train.py).
#
# Every setting resolved below can be overridden from the environment, e.g.
#   NUM_GPUS=8 BATCH_SIZE=8 bash <this script>
#
# bash is requested explicitly; the commented-out `read -p` at the end of the
# file is a bash extension.

# One OpenMP thread per process -- presumably to avoid CPU oversubscription when
# torchrun starts several worker processes.
export OMP_NUM_THREADS=1
# Uncomment the following line if the base SD model cannot be downloaded from
# the default Hugging Face endpoint (it redirects Hub downloads to a mirror).
# export HF_ENDPOINT=https://hf-api.gitee.com

# ---------------------------------------------------------------------------
# Default settings. Each DEFAULT_<NAME> is used only when <NAME> is not set in
# the environment when the script runs.
# ---------------------------------------------------------------------------

# Multi-GPU reference setting (8 GPUs, global batch 8). The active values are a
# single-GPU configuration.
#DEFAULT_NUM_GPUS=8
#DEFAULT_BATCH_SIZE=8
DEFAULT_BATCH_SIZE=1
DEFAULT_NUM_GPUS=1
##################################################################################
# Learning rates, passed to train.py as --lr (SG) and --glr (G).
DEFAULT_SG_LEARNING_RATE=0.000004
DEFAULT_G_LEARNING_RATE=0.000006
# Concept to forget and the concept passed as its override.
DEFAULT_CONCEPT_TO_FORGET="brad pitt"
DEFAULT_CONCEPT_TO_OVERRIDE="a middle aged man"
# Prompt files for the forget concept (training / validation).
DEFAULT_FORGET_DATA_PROMPT_TEXT="./prompts/brad_pitt_prompts.txt"
DEFAULT_FORGET_DATA_PROMPT_TEXT_VAL="./prompts/brad_pitt_prompts_val.txt"
# Optional prompt files; empty means "not provided". These names must match the
# DEFAULT_* names read when POS_/NEG_DATA_PROMPT_TEXT are resolved (they were
# previously declared with a stray _VAL suffix and never read).
DEFAULT_OVERRIDE_DATA_PROMPT_TEXT=
DEFAULT_POS_DATA_PROMPT_TEXT=
DEFAULT_NEG_DATA_PROMPT_TEXT=
# Loss weights, passed to train.py as --sg_/--g_{remain,forget}_coef.
DEFAULT_SG_REMAIN_COEF=1.0
DEFAULT_SG_FORGET_COEF=1.0
DEFAULT_G_REMAIN_COEF=1.0
DEFAULT_G_FORGET_COEF=1.0
DEFAULT_SID_W_NEG=0
# Comma-separated list passed verbatim as --use_neg.
# NOTE(review): per-position meaning is defined in train.py -- confirm there.
DEFAULT_USE_NEG=0,0,1
DEFAULT_SG_W_OVERRIDE=0

# ---------------------------------------------------------------------------
# Resolve effective settings: a same-named environment variable wins, otherwise
# the DEFAULT_* value is used.
#   ${VAR:-default}  -> default when VAR is unset OR empty
#   ${VAR-default}   -> default only when VAR is unset, so an explicitly empty
#                       value (e.g. CONCEPT_TO_OVERRIDE="") is passed through.
# ---------------------------------------------------------------------------
NUM_GPUS=${NUM_GPUS:-$DEFAULT_NUM_GPUS}
BATCH_SIZE=${BATCH_SIZE:-$DEFAULT_BATCH_SIZE}

# The global batch is split evenly across GPUs. Shell arithmetic truncates, so
# refuse settings where the per-GPU batch would be 0, would not multiply back to
# BATCH_SIZE, or would divide by zero.
batch_split_ok=1
case "$NUM_GPUS" in '' | *[!0-9]*) batch_split_ok=0 ;; esac
case "$BATCH_SIZE" in '' | *[!0-9]*) batch_split_ok=0 ;; esac
if [ "$batch_split_ok" -eq 1 ]; then
    if [ "$NUM_GPUS" -eq 0 ] || [ "$BATCH_SIZE" -eq 0 ] ||
        [ $((BATCH_SIZE % NUM_GPUS)) -ne 0 ]; then
        batch_split_ok=0
    fi
fi
if [ "$batch_split_ok" -ne 1 ]; then
    echo "error: BATCH_SIZE (${BATCH_SIZE}) must be a positive multiple of NUM_GPUS (${NUM_GPUS})" >&2
    exit 1
fi
unset batch_split_ok
BATCH_SIZE_PER_GPU=$((BATCH_SIZE / NUM_GPUS))

SG_LEARNING_RATE=${SG_LEARNING_RATE:-$DEFAULT_SG_LEARNING_RATE}
G_LEARNING_RATE=${G_LEARNING_RATE:-$DEFAULT_G_LEARNING_RATE}
CONCEPT_TO_FORGET=${CONCEPT_TO_FORGET-$DEFAULT_CONCEPT_TO_FORGET}
CONCEPT_TO_OVERRIDE=${CONCEPT_TO_OVERRIDE-$DEFAULT_CONCEPT_TO_OVERRIDE}
FORGET_DATA_PROMPT_TEXT=${FORGET_DATA_PROMPT_TEXT-$DEFAULT_FORGET_DATA_PROMPT_TEXT}
FORGET_DATA_PROMPT_TEXT_VAL=${FORGET_DATA_PROMPT_TEXT_VAL-$DEFAULT_FORGET_DATA_PROMPT_TEXT_VAL}
OVERRIDE_DATA_PROMPT_TEXT=${OVERRIDE_DATA_PROMPT_TEXT:-$DEFAULT_OVERRIDE_DATA_PROMPT_TEXT}
POS_DATA_PROMPT_TEXT=${POS_DATA_PROMPT_TEXT:-$DEFAULT_POS_DATA_PROMPT_TEXT}
NEG_DATA_PROMPT_TEXT=${NEG_DATA_PROMPT_TEXT:-$DEFAULT_NEG_DATA_PROMPT_TEXT}
SG_REMAIN_COEF=${SG_REMAIN_COEF:-$DEFAULT_SG_REMAIN_COEF}
SG_FORGET_COEF=${SG_FORGET_COEF:-$DEFAULT_SG_FORGET_COEF}
G_REMAIN_COEF=${G_REMAIN_COEF:-$DEFAULT_G_REMAIN_COEF}
G_FORGET_COEF=${G_FORGET_COEF:-$DEFAULT_G_FORGET_COEF}
SID_W_NEG=${SID_W_NEG:-$DEFAULT_SID_W_NEG}
USE_NEG=${USE_NEG:-$DEFAULT_USE_NEG}
SG_W_OVERRIDE=${SG_W_OVERRIDE:-$DEFAULT_SG_W_OVERRIDE}

# Print the effective configuration, one "NAME:value" line per setting. Values
# are expanded inside quotes so prompts/paths containing spaces or glob
# characters are shown verbatim. RESUME_TRAINING is shown as "None" when unset;
# SEED is shown as 0 when unset or empty, the seed given to train.py then.
printf '%s\n' \
    "NUM_GPUS:${NUM_GPUS}" \
    "BATCH_SIZE:${BATCH_SIZE}" \
    "BATCH_SIZE_PER_GPU:${BATCH_SIZE_PER_GPU}" \
    "SG_LEARNING_RATE:${SG_LEARNING_RATE}" \
    "G_LEARNING_RATE:${G_LEARNING_RATE}" \
    "RESUME_TRAINING:${RESUME_TRAINING-None}" \
    "FORGET_DATA_PROMPT_TEXT:${FORGET_DATA_PROMPT_TEXT}" \
    "FORGET_DATA_PROMPT_TEXT_VAL:${FORGET_DATA_PROMPT_TEXT_VAL}" \
    "CONCEPT_TO_FORGET:${CONCEPT_TO_FORGET}" \
    "CONCEPT_TO_OVERRIDE:${CONCEPT_TO_OVERRIDE}" \
    "OVERRIDE_DATA_PROMPT_TEXT:${OVERRIDE_DATA_PROMPT_TEXT}" \
    "POS_DATA_PROMPT_TEXT:${POS_DATA_PROMPT_TEXT}" \
    "NEG_DATA_PROMPT_TEXT:${NEG_DATA_PROMPT_TEXT}" \
    "SG_REMAIN_COEF:${SG_REMAIN_COEF}" \
    "SG_FORGET_COEF:${SG_FORGET_COEF}" \
    "G_REMAIN_COEF:${G_REMAIN_COEF}" \
    "G_FORGET_COEF:${G_FORGET_COEF}" \
    "SID_W_NEG:${SID_W_NEG}" \
    "USE_NEG:${USE_NEG}" \
    "SG_W_OVERRIDE:${SG_W_OVERRIDE}" \
    "SEED:${SEED:-0}"

# Launch train.py on this node with NUM_GPUS worker processes (torchrun
# --standalone). Settings resolved from the environment/defaults are passed
# through; the remaining flags are fixed for this experiment.
# SEED may be set in the environment and falls back to 0, the seed that was
# previously hard-coded here (it was already echoed as if configurable).
# NOTE(review): --resume receives an empty string when RESUME_TRAINING is unset;
#   presumably train.py treats that as "start fresh" -- confirm.
# NOTE(review): the --from_distill_ema filename appears to encode cfg 4.5 and
#   t625, matching --cfg_* and --init_timestep below -- keep them in sync.
torchrun \
    --standalone \
    --nproc_per_node "${NUM_GPUS}" \
    train.py \
    --outdir 'image_experiment/df2-stage2-train-runs/' \
    --train_mode 1 \
    --cfg_train_fake 4.5 \
    --cfg_eval_fake 4.5 \
    --cfg_eval_real 4.5 \
    --optimizer 'adamw' \
    --data_prompt_text './' \
    --resolution 512 \
    --alpha 1 \
    --init_timestep 625 \
    --batch "${BATCH_SIZE}" \
    --fp16 1 \
    --batch-gpu "${BATCH_SIZE_PER_GPU}" \
    --sd_model 'runwayml/stable-diffusion-v1-5' \
    --tick 1 \
    --snap 5 \
    --dump 5 \
    --lr "${SG_LEARNING_RATE}" \
    --glr "${G_LEARNING_RATE}" \
    --duration 1 \
    --ema 0.0 \
    --sg_remain_coef "${SG_REMAIN_COEF}" \
    --sg_forget_coef "${SG_FORGET_COEF}" \
    --g_remain_coef "${G_REMAIN_COEF}" \
    --g_forget_coef "${G_FORGET_COEF}" \
    --forget_data_prompt_text "${FORGET_DATA_PROMPT_TEXT}" \
    --forget_data_prompt_text_val "${FORGET_DATA_PROMPT_TEXT_VAL}" \
    --concept_to_forget "${CONCEPT_TO_FORGET}" \
    --concept_to_override "${CONCEPT_TO_OVERRIDE}" \
    --from_distill_ema 'batch512_cfg4.54.54.5_t625_8380_v2.pkl' \
    --resume "${RESUME_TRAINING}" \
    --override_data_prompt_text "${OVERRIDE_DATA_PROMPT_TEXT}" \
    --pos_data_prompt_text "${POS_DATA_PROMPT_TEXT}" \
    --neg_data_prompt_text "${NEG_DATA_PROMPT_TEXT}" \
    --sid_w_neg "${SID_W_NEG}" \
    --use_neg "${USE_NEG}" \
    --sg_w_override "${SG_W_OVERRIDE}" \
    --seed "${SEED:-0}"
#read -p "Press any key to continue..."

