#!/usr/bin/env bash
# Launch configuration for Wan2.1-Fun control (MAE + depth) training.
# Fail fast on errors, unset variables, and failed pipeline stages.
set -euo pipefail

# Paths consumed by the `accelerate launch` command below.
# "xxx" values are placeholders that must be replaced before running.
export MODEL_NAME="models/Wan2.1-Fun-V1.1-14B-Control/PAI/Wan2.1-Fun-V1.1-14B-Control"
export DATASET_NAME="datasets/internal_datasets/"
export DATASET_META_NAME="/xxx.txt"

# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are needed for multi-node runs
# on clusters without RDMA (InfiniBand); uncomment them in that case.
# export NCCL_IB_DISABLE=1
# export NCCL_P2P_DISABLE=1

# NCCL logging level. It must be exported to reach the accelerate workers; a
# bare assignment stays local to this shell. Valid levels are VERSION, WARN,
# INFO and TRACE ("DEBUG" is not one of them). Callers may override it.
export NCCL_DEBUG="${NCCL_DEBUG:-INFO}"

# Start training through HF Accelerate.
# Requires MODEL_NAME, DATASET_NAME and DATASET_META_NAME to be set.
# Abort with a clear message instead of passing an empty "--flag=" value.
: "${MODEL_NAME:?MODEL_NAME must be set}"
: "${DATASET_NAME:?DATASET_NAME must be set}"
: "${DATASET_META_NAME:?DATASET_META_NAME must be set}"

# The first --mixed_precision (before the script path) is an accelerate
# option; the second one (L33 in the argument list) is passed to the
# training script itself. Both are intentional.
# Expansions are quoted so paths with spaces/glob characters stay one word.
# The last argument has no trailing "\" so lines added below cannot
# accidentally become part of this command.
accelerate launch --mixed_precision="bf16" scripts/4D-STraG_training/train_control_i2f_mae_depth.py \
  --config_path="config/wan2.1/wan_civitai_mae_depth.yaml" \
  --pretrained_model_name_or_path="$MODEL_NAME" \
  --train_data_dir="$DATASET_NAME" \
  --train_data_meta="$DATASET_META_NAME" \
  --image_sample_size=1024 \
  --video_sample_size=456 \
  --token_sample_size=512 \
  --checkpoints_total_limit 5 \
  --video_sample_stride=1 \
  --video_sample_n_frames=49 \
  --train_batch_size=3 \
  --video_repeat=1 \
  --gradient_accumulation_steps=1 \
  --dataloader_num_workers=1 \
  --num_train_epochs=10 \
  --checkpointing_steps=50 \
  --learning_rate=2e-05 \
  --lr_scheduler="constant_with_warmup" \
  --lr_warmup_steps=100 \
  --seed=42 \
  --vae_ckpt_dir="xxx/ckpt" \
  --output_dir="output_dir/" \
  --gradient_checkpointing \
  --mixed_precision="bf16" \
  --adam_weight_decay=3e-2 \
  --adam_epsilon=1e-10 \
  --vae_mini_batch=1 \
  --max_grad_norm=0.05 \
  --training_with_video_token_length \
  --enable_bucket \
  --uniform_sampling \
  --train_mode="control_ref" \
  --control_ref_image="first_frame" \
  --resume_from_checkpoint="latest" \
  --transformer_path="xxx/transformer" \
  --trainable_modules "." \
  --add_full_ref_image_in_self_attention \
  --normalize_track_z \
  --max_sample_dataset 6000 \
  --scene_flow_smoothness_loss \
  --freeze_wan \
  --use_omnimae_guidance
