# Base model and numeric precision for the run.
# NOTE(review): the value looks like a Hugging Face Hub repo id; whether a
# local path is also accepted depends on the loader — confirm.
flux_path: "black-forest-labs/FLUX.1-dev"
# NOTE(review): presumably mapped to the torch dtype of the same name by the
# training script — confirm.
dtype: "bfloat16"

# Model-level switches.
model:
  # NOTE(review): presumably controls whether the condition branch is
  # processed independently of the main image/text tokens. The code that reads
  # this flag is not visible here — confirm before changing it.
  independent_condition: false

# Everything under `train` configures the training run: loop settings first,
# followed by the condition type, dataset, logging, LoRA and optimizer sections.
# NOTE(review): several key names mirror common PyTorch Lightning options, but
# the consumer is not visible here — confirm how each one is read.
train:
  # 1 presumably means an optimizer step after every batch (no accumulation).
  accumulate_grad_batches: 1
  dataloader_workers: 5
  # NOTE(review): both intervals are presumably counted in training steps.
  save_interval: 1000
  sample_interval: 100
  # NOTE(review): -1 conventionally means "no step limit" — confirm.
  max_steps: -1
  gradient_checkpointing: true # (Turn off for faster training, at the cost of higher memory use)
  # Relative path; presumably the output directory for checkpoints/samples.
  save_path: "runs"

  # Specify the type of condition to use.
  # NOTE(review): the value is spelled "intergration" (sic). The string is
  # presumably matched verbatim by the training code, so do not "correct" the
  # spelling here without updating the consumer as well.
  condition_type: "token_intergration"
  # Training data source.
  dataset:
    type: "img"
    # Shards to load. Only shards 000045 and 000046 are enabled;
    # 000040-000044 are present but commented out.
    urls:
      # (Uncomment the following lines to use more data)
      # - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000040.tar"
      # - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000041.tar"
      # - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000042.tar"
      # - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000043.tar"
      # - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000044.tar"
      - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000045.tar"
      - "https://huggingface.co/datasets/jackyhate/text-to-image-2M/resolve/main/data_512_2M/data_000046.tar"
    cache_name: "data_512_2M"
    # Two-element sizes, written in flow style because they are short leaf
    # values. NOTE(review): presumably pixel dimensions — confirm the axis
    # order with the dataset code.
    condition_size: [512, 512]
    target_size: [512, 512]
    # NOTE(review): presumably the per-sample probabilities of dropping the
    # text prompt / the condition image during training — confirm.
    drop_text_prob: 0.1
    drop_image_prob: 0.1


  # Experiment-tracking settings.
  # NOTE(review): presumably the Weights & Biases project that runs are
  # logged under — confirm.
  wandb:
    project: "OminiControl"

  # LoRA adapter settings.
  # NOTE(review): the key names match the fields of PEFT's `LoraConfig`;
  # presumably this mapping is forwarded to it — confirm in the training code.
  lora_config:
    r: 4
    lora_alpha: 4
    init_lora_weights: "gaussian"
    # One regex whose alternatives select the modules that receive LoRA weights:
    #   - x_embedder
    #   - transformer_blocks.N (the `(?<!single_)` lookbehind excludes
    #     single_transformer_blocks): norm1.linear, attn.to_k, attn.to_q,
    #     attn.to_v, attn.to_out.0, ff.net.2
    #   - single_transformer_blocks.N: norm.linear, proj_mlp, proj_out,
    #     attn.to_k, attn.to_q, attn.to_v, attn.to_out
    # Every literal dot is escaped (`\\.` inside this double-quoted scalar) so
    # it cannot match an arbitrary character; only the leading `.*` of each
    # alternative is a deliberate wildcard.
    # NOTE(review): the last alternative ends in `attn\\.to_out` without the
    # `\\.0` suffix used for transformer_blocks — confirm that single-stream
    # blocks actually expose a module with that name.
    target_modules: "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn\\.to_out)"
    # (Uncomment the following line to train fewer parameters while keeping similar performance)
    # target_modules: "(.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q)"

  # Optimizer selection: `type` names the optimizer and `params` holds its
  # options (presumably forwarded as keyword arguments — confirm in the
  # training code).
  optimizer:
    type: "Prodigy"
    params:
      # NOTE(review): Prodigy estimates the step size itself and its
      # documentation recommends leaving lr at 1 — confirm before tuning.
      lr: 1
      use_bias_correction: true
      safeguard_warmup: true
      weight_decay: 0.01

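  # (To use the AdamW optimizer instead, comment out the `optimizer` block above
  # and uncomment the following lines; leaving both active would create a
  # duplicate `optimizer` key. Options are nested under `params` to mirror the
  # Prodigy block, and the learning rate is written as 1.0e-4 because YAML 1.1
  # loaders such as PyYAML read a bare 1e-4 as a string, not a float.)
  # optimizer:
  #   type: "AdamW"
  #   params:
  #     lr: 1.0e-4
  #     weight_decay: 0.001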