# Run-level settings.
# NOTE(review): the meanings noted below are inferred from the key names;
# confirm against the code that loads this file.
# Identifier for this run; the same word "cat" is the concept configured in the
# dataset and model sections of this file.
name: cat
# presumably the random seed — TODO confirm
manual_seed: 0
# presumably the mixed-precision mode used for training — TODO confirm
mixed_precision: fp16
gradient_accumulation_steps: 1

# Datasets and data-loader settings: a training set and a prompt set.
datasets:
  train:
    name: LoraDataset
    concept_list: ./datasets/data_cfgs/MixofShow/single-concept/objects/real/cat.json
    use_caption: true
    use_mask: false
    # One entry per transform step, kept in the listed order.
    instance_transform:
      - type: HumanResizeCropFinalV3
        size: 512
        crop_p: 0.5
      - type: ToTensor
      - type: Normalize
        mean: [0.5]
        std: [0.5]
      - type: ShuffleCaption
        keep_token_num: 1
      - type: EnhanceText
        enhance_type: object
    # Maps a plain word to the concept tokens that stand in for it
    # (presumably applied to captions — confirm in the dataset class).
    replace_mapping:
      cat: "<cat1> <cat2>"
    batch_size_per_gpu: 1
    dataset_enlarge_ratio: 250

  # Prompts read from a text file under ./datasets/validation_prompts.
  val_vis:
    name: PromptDataset
    prompts: ./datasets/validation_prompts/single-concept/objects/test_cat.txt
    num_samples_per_prompt: 4
    latent_size: [4, 64, 64]
    # Same word-to-token mapping as the training set.
    replace_mapping:
      cat: "<cat1> <cat2>"
    batch_size_per_gpu: 4

# Model settings: base weights, per-component fine-tuning, concept tokens,
# and training-time switches.
models:
  pretrained_path: ./sd_15
  enable_edlora: true  # true means ED-LoRA, false means vanilla LoRA
  # Each component below has its own enable flag and learning rate.
  # The explicit !!float tag forces a float: exponent literals written without
  # a decimal point (such as 1e-3) are resolved as strings by YAML 1.1 loaders
  # like PyYAML, so the tag must stay.
  finetune_cfg:
    text_embedding:
      enable_tuning: true
      lr: !!float 1e-3
    text_encoder:
      enable_tuning: true
      lora_cfg:
        rank: 4
        alpha: 1.0
        # presumably the module class that receives LoRA layers — TODO confirm
        where: CLIPAttention
      lr: !!float 1e-5
    unet:
      enable_tuning: true
      lora_cfg:
        rank: 4
        alpha: 1.0
        # presumably the module class that receives LoRA layers — TODO confirm
        where: Attention
      lr: !!float 1e-4
  # Two '+'-separated tokens; they match the tokens used in the datasets'
  # replace_mapping entries.
  new_concept_token: <cat1>+<cat2>
  # Two '+'-separated initializers; presumably paired positionally with
  # new_concept_token — verify against the model code.
  initializer_token: <rand-0.011>+cat
  noise_offset: 0.01
  attn_reg_weight: 0.01
  reg_full_identity: true
  use_mask_loss: false
  gradient_checkpoint: false
  enable_xformers: true

# Filesystem paths.
path:
  # NOTE(review): presumably the directory for model weights used or produced
  # by this run — confirm against the training script.
  models: ./models

# Training settings.
train:
  optim_g:
    type: AdamW
    # Not used: each component's learning rate is defined under
    # models.finetune_cfg instead.
    lr: 0.0
    weight_decay: 0.01
    # align with taming
    betas: [0.9, 0.999]

  # dropkv
  unet_kv_drop_rate: 0
  scheduler: linear
  emb_norm_threshold: 0.55

# Validation settings.
val:
  val_during_save: true
  compose_visualize: true
  # An alpha of 0 means only the embedding is visualized (without LoRA weight).
  alpha_list:
    - 0
    - 0.7
    - 1.0
  sample:
    num_inference_steps: 50
    guidance_scale: 7.5

# Logging and checkpointing cadence.
logger:
  print_freq: 10
  # Written as a float to keep the value's existing type.
  # NOTE(review): an integer may be what the consumer expects — confirm.
  save_checkpoint_freq: 10000.0
