# Proxy endpoint exported for tools that honor http_proxy / https_proxy.
# Both variables carry the same URL, so it is written once and reused.
export http_proxy="http://9.21.0.122:11113"
export https_proxy="${http_proxy}"

# Run configuration exported for the launch command further down in this script.

# Change this to your own checkpoint folder (this should be a central directory shared across nodes).
export CHECKPOINT_PATH="/apdcephfs_nj3/share_301053287/guanjiechen/dmd2_ckpts"
export WANDB_ENTITY="2936127220" # change this to your own wandb entity
export WANDB_PROJECT="DMD2_SPO"  # change this to your own wandb project

# SECURITY(review): a W&B API key is committed here in plain text.
# A WANDB_API_KEY already present in the environment now takes precedence; the
# literal below is kept only as a fallback so existing launches keep working.
# TODO: rotate this key and remove the fallback from version control.
export WANDB_API_KEY="${WANDB_API_KEY:-a6c8a30b64e2f56be9823a8d41275a617c1835bd}"
# wandb login $WANDB_API_KEY
# wandb online

# Rank of this machine, taken from the first positional argument of the script.
# "${1:-}" expands to an empty string when no argument is given (also safe under `set -u`).
export RANK="${1:-}"
export MASTER_IP="29.76.21.57"
export MASTER_PORT="29500"
export MACHINE_RANK="${RANK}"


# NCCL setting (important!)
# Variables are grouped by name; every value below is exported to child processes.

# Interface / device selection.
export NCCL_SOCKET_IFNAME=bond1 UCX_NET_DEVICES=bond1
export NCCL_IB_HCA=mlx5_bond_1,mlx5_bond_5,mlx5_bond_3,mlx5_bond_7,mlx5_bond_4,mlx5_bond_8,mlx5_bond_2,mlx5_bond_6

# NCCL_IB_* knobs.
export \
  NCCL_IB_DISABLE=0 \
  NCCL_IB_GID_INDEX=3 \
  NCCL_IB_SL=3 \
  NCCL_IB_TC=160 \
  NCCL_IB_QPS_PER_CONNECTION=4 \
  NCCL_IB_CUDA_SUPPORT=1

# Remaining NCCL / SHARP switches.
export \
  NCCL_CHECK_DISABLE=1 \
  NCCL_P2P_DISABLE=0 \
  NCCL_PXN_DISABLE=1 \
  NCCL_COLLNET_ENABLE=0 \
  SHARP_COLL_ENABLE_SAT=0 \
  NCCL_NET_GDR_LEVEL=2 \
  NCCL_LL_THRESHOLD=16384

# NCCL log verbosity.
export NCCL_DEBUG=WARN



# Launch main/train_sd_no_fsdp.py through `accelerate launch`.
#
# Reads from the environment: RANK, MASTER_IP, MASTER_PORT, CHECKPOINT_PATH,
# WANDB_ENTITY, WANDB_PROJECT.
#
# RANK picks the per-machine accelerate config
# (./fsdp_configs/dmd2-no-fsdp/config_rank<RANK>.yaml). Refuse to launch without
# a numeric rank: an empty RANK would otherwise expand the config path to
# config_rank.yaml instead of failing with a clear message.
case "${RANK:-}" in
  '' | *[!0-9]*)
    printf 'usage: %s <machine_rank>  (non-negative integer, got "%s")\n' "$0" "${RANK:-}" >&2
    exit 2
    ;;
esac

# All expansions are quoted so that paths containing spaces or glob characters
# reach the trainer as single arguments.
# NOTE(review): `--initialie_generator` and the `bszie` output directory are
# kept exactly as they were; the flag spelling presumably matches the argument
# name declared in train_sd_no_fsdp.py — confirm before "correcting" either.
accelerate launch --config_file "./fsdp_configs/dmd2-no-fsdp/config_rank${RANK}.yaml" \
  --main_process_ip="${MASTER_IP}" \
  --main_process_port="${MASTER_PORT}" \
  main/train_sd_no_fsdp.py \
  --num_workers 8 \
  --generator_lora \
  --lora_rank 64 \
  --lora_alpha 8 \
  --generator_lr 5e-5 \
  --guidance_lr 5e-5 \
  --train_iters 100000000 \
  --output_path ./train_results/dmd2-spo-lora-bszie-32-prefer \
  --batch_size 2 \
  --grid_size 1 \
  --initialie_generator --log_iters 1000 \
  --resolution 1024 \
  --latent_resolution 128 \
  --seed 10 \
  --real_guidance_scale 8 \
  --fake_guidance_scale 1.0 \
  --max_grad_norm 10.0 \
  --model_id /root/cgj/models/SDXL \
  --ref_model_id "${CHECKPOINT_PATH}/SPO-SDXL_4k-p_10ep/unet" \
  --wandb_iters 100 \
  --wandb_entity "${WANDB_ENTITY}" \
  --wandb_project "${WANDB_PROJECT}" \
  --wandb_name "prefer_spo_lora" \
  --log_loss \
  --dfake_gen_update_ratio 5 \
  --gradient_checkpointing \
  --sdxl \
  --use_fp16 \
  --max_step_percent 0.98 \
  --cls_on_clean_image \
  --gen_cls_loss \
  --gen_cls_loss_weight 5e-3 \
  --guidance_cls_loss_weight 1e-2 \
  --diffusion_gan \
  --diffusion_gan_max_timestep 1000 \
  --denoising \
  --num_denoising_step 4 \
  --denoising_timestep 1000 \
  --backward_simulation \
  --train_prompt_path "${CHECKPOINT_PATH}/captions_laion_score6.25.pkl" \
  --real_image_path /root/cgj/datasets/sdxl_vae_latents_laion_500k_lmdb_2/ \
  --generator_ckpt_path /root/cgj/models/DMD2/dmd2_sdxl_4step_unet_fp16.bin

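# Reference-model checkpoint directories (the spo one, with /unet appended, is
# what --ref_model_id points at; the inversion one is presumably an alternative):
#   spo:       $CHECKPOINT_PATH/SPO-SDXL_4k-p_10ep
#   inversion: $CHECKPOINT_PATH/Inversion-DPO-whole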
