#!/usr/bin/env bash
# Launch fine-tuning of Wan2.1-Fun-1.3B-Control with --model_type="causal"
# on six GPUs (ids 2-7) via `accelerate launch`.
# All paths below are machine-specific; edit them before running elsewhere.
# Kept POSIX-sh compatible (no pipefail/bashisms) so `sh <script>` still works.
set -eu

export MODEL_NAME="/home4/jiaxin/ckpt/Wan2.1-Fun-1.3B-Control"
export DATASET_NAME="/home4/jiaxin/data/video/"
# Comma-separated list of metadata files, passed verbatim to --train_data_meta.
export DATASET_META_NAME="/home4/jiaxin/data/video/trajectory_control/dec_fps_24.jsonl,/home4/jiaxin/data/video/trajectory_control_2.0/dec_anime_fps_24.jsonl,/home4/jiaxin/data/video/trajectory_control_pexels/dec_fps_24.jsonl"
export OUTPUT_DIR="/home4/jiaxin/exp/wan2.1_fun_causal_control_animal"
# Disable NCCL's InfiniBand and peer-to-peer transports on this host.
export NCCL_IB_DISABLE=1
export NCCL_P2P_DISABLE=1
# Must be exported: a plain assignment is not inherited by accelerate's
# child processes, so NCCL would never see it.
export NCCL_DEBUG=INFO

# Note: --num_processes must match the number of ids in --gpu_ids.
# Variable expansions are quoted so paths with spaces/globs stay one argument.
accelerate launch --mixed_precision="bf16" --num_machines=1 --num_processes=6 --gpu_ids=2,3,4,5,6,7 --main_process_port=25010 \
  scripts/wan2.1_fun/my_train_control.py \
  --model_type="causal" \
  --denoising_step_list="(1000,757,522,0)" \
  --config_path="config/wan2.1/wan_civitai.yaml" \
  --pretrained_model_name_or_path="$MODEL_NAME" \
  --train_data_dir="$DATASET_NAME" \
  --train_data_meta="$DATASET_META_NAME" \
  --caption_column="prompt" \
  --video_column="video" \
  --video_sample_stride="(3,3,3)" \
  --train_batch_size=20 \
  --video_repeat=1 \
  --gradient_accumulation_steps=1 \
  --dataloader_num_workers=16 \
  --num_train_epochs=40 \
  --checkpointing_steps=200 \
  --learning_rate=5e-05 \
  --lr_scheduler="constant_with_warmup" \
  --lr_warmup_steps=10 \
  --seed=42 \
  --output_dir="$OUTPUT_DIR" \
  --gradient_checkpointing \
  --mixed_precision="bf16" \
  --adam_weight_decay=3e-2 \
  --adam_epsilon=1e-10 \
  --max_grad_norm=0.05 \
  --uniform_sampling \
  --train_mode="control_ref" \
  --trainable_modules "." \
  --validation_path="config/wan2.1/val_control.json" \
  --validation_steps=100 \
  --use_ema \
  --validate_before_train

# # Reference only (not executed): upstream training command for text-to-video (T2V) with Wan2.1-Fun-14B-InP
# export MODEL_NAME="models/Diffusion_Transformer/Wan2.1-Fun-14B-InP"
# export DATASET_NAME="datasets/internal_datasets/"
# export DATASET_META_NAME="datasets/internal_datasets/metadata.json"
# export NCCL_IB_DISABLE=1
# export NCCL_P2P_DISABLE=1
# export NCCL_DEBUG=INFO

# accelerate launch --mixed_precision="bf16" scripts/wan2.1_fun/train.py \
#   --config_path="config/wan2.1/wan_civitai.yaml" \
#   --pretrained_model_name_or_path=$MODEL_NAME \
#   --train_data_dir=$DATASET_NAME \
#   --train_data_meta=$DATASET_META_NAME \
#   --image_sample_size=1024 \
#   --video_sample_size=256 \
#   --token_sample_size=512 \
#   --video_sample_stride=2 \
#   --video_sample_n_frames=81 \
#   --train_batch_size=1 \
#   --video_repeat=1 \
#   --gradient_accumulation_steps=1 \
#   --dataloader_num_workers=8 \
#   --num_train_epochs=100 \
#   --checkpointing_steps=50 \
#   --learning_rate=2e-05 \
#   --lr_scheduler="constant_with_warmup" \
#   --lr_warmup_steps=100 \
#   --seed=42 \
#   --output_dir="output_dir" \
#   --gradient_checkpointing \
#   --mixed_precision="bf16" \
#   --adam_weight_decay=3e-2 \
#   --adam_epsilon=1e-10 \
#   --vae_mini_batch=1 \
#   --max_grad_norm=0.05 \
#   --random_hw_adapt \
#   --training_with_video_token_length \
#   --enable_bucket \
#   --uniform_sampling \
#   --low_vram \
#   --train_mode="normal" \
#   --trainable_modules "."