# Model section: `target` is a dotted class path and `params` holds its
# settings (presumably passed as constructor keyword arguments — confirm
# against the config loader).
model:
  base_learning_rate: 1.0e-05
  target: adapters.coadapters.CoAdapter
  params:
    # Adapter entries, one per `cond_name`: style, canny, sketch, color, depth.
    # Each entry names a class (`target`), a checkpoint path (`pretrained`) and
    # the class settings (`params`).
    # NOTE(review): `n_layes` looks like a misspelling of `n_layers`; it is
    # spelled the same way under coadapter_fuser_config, so presumably it
    # mirrors the constructor argument name — confirm before renaming.
    adapter_configs:
      - target: ldm.modules.encoders.adapter.StyleAdapter
        cond_name: style
        pretrained: models/t2iadapter_style_sd14v1.pth
        params:
          width: 1024
          context_dim: 768
          num_head: 8
          n_layes: 3
          num_token: 8

      - target: ldm.modules.encoders.adapter.Adapter
        cond_name: canny
        pretrained: models/t2iadapter_canny_sd14v1.pth
        params:
          cin: 64
          channels: [320, 640, 1280, 1280]
          nums_rb: 2
          ksize: 1
          sk: true
          use_conv: false

      - target: ldm.modules.encoders.adapter.Adapter
        cond_name: sketch
        pretrained: models/t2iadapter_sketch_sd14v1.pth
        params:
          cin: 64
          channels: [320, 640, 1280, 1280]
          nums_rb: 2
          ksize: 1
          sk: true
          use_conv: false

      # The color entry uses the Adapter_light class and twice the nums_rb of
      # the other spatial adapters.
      - target: ldm.modules.encoders.adapter.Adapter_light
        cond_name: color
        pretrained: models/t2iadapter_color_sd14v1.pth
        params:
          cin: 192
          channels: [320, 640, 1280, 1280]
          nums_rb: 4

      - target: ldm.modules.encoders.adapter.Adapter
        cond_name: depth
        pretrained: models/t2iadapter_depth_sd14v1.pth
        params:
          cin: 192
          channels: [320, 640, 1280, 1280]
          nums_rb: 2
          ksize: 1
          sk: true
          use_conv: false

    # Fuser settings. `unet_channels` lists the same four values as the
    # `channels` entries under adapter_configs, and `width` equals the
    # `context_dim` (768) given to the style adapter and the UNet.
    # NOTE(review): `n_layes` is presumably the constructor's own spelling —
    # confirm before renaming.
    coadapter_fuser_config:
      target: ldm.modules.encoders.adapter.CoAdapterFuser
      params:
        unet_channels: [320, 640, 1280, 1280]
        width: 768
        num_head: 8
        n_layes: 3

    # Noise schedule: linear, from linear_start to linear_end over `timesteps`.
    noise_schedule: linear
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    # Batch keys: "jpg" matches `image_key` in the data section.
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    image_size: 64
    channels: 4
    cond_stage_trainable: false   # Note: different from the one we trained before
    conditioning_key: crossattn
    scale_factor: 0.18215
    use_ema: false

    # NOTE(review): presumably the probability of dropping the `txt` condition
    # during training — confirm against the training code.
    ucg_training:
      txt: 0.5

    # UNet settings. `context_dim` (768) equals the fuser `width` and the style
    # adapter `context_dim` elsewhere in this file.
    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        image_size: 32  # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_heads: 8
        use_spatial_transformer: true
        transformer_depth: 1
        context_dim: 768
        use_checkpoint: true
        legacy: false

    # First-stage autoencoder settings. `embed_dim` and `z_channels` are both 4,
    # the same value as the model-level `channels`.
    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        # NOTE(review): the loss is an identity module; presumably the
        # autoencoder is not trained in this run — confirm.
        lossconfig:
          target: torch.nn.Identity

    # Conditioning stage: FrozenCLIPEmbedder with the
    # `openai/clip-vit-large-patch14` version, the same version string given to
    # AddStyle in the data section. A concrete encoder is configured here, so
    # this stage is not the `__is_unconditional__` placeholder.
    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
      params:
        version: openai/clip-vit-large-patch14

# Dataset section. `tar_base` and `shards` are placeholders that must be
# replaced before use.
data:
  target: ldm.data.dataset_laion.WebDataModuleFromConfig
  params:
    tar_base: YOUR_DATASET_ROOT  # need to change
    batch_size: 2
    num_workers: 8
    multinode: true
    train:
      shards: YOUR_TAR_LIST  # need to change
      shuffle: 10000
      image_key: jpg
      # Resize (size 512) followed by RandomCrop (size 512).
      # NOTE(review): interpolation 3 is presumably bicubic — confirm against
      # the installed torchvision version.
      image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 512
            interpolation: 3
        - target: torchvision.transforms.RandomCrop
          params:
            size: 512
      # NOTE(review): adapter_configs also lists `sketch` and `depth`
      # conditions, but no step below names them; presumably those inputs are
      # produced elsewhere — confirm.
      process:
        - target: ldm.data.utils.AddStyle
          params:
            version: openai/clip-vit-large-patch14

        - target: ldm.data.utils.AddCannyRandomThreshold
          params:
            low_threshold: 40
            high_threshold: 110
            shift_range: 10

        - target: ldm.data.utils.AddSpatialPalette
          params:
            downscale_factor: 64

        - target: ldm.data.utils.PILtoTensor

# Lightning section: checkpointing and trainer flags.
lightning:
  find_unused_parameters: false

  modelcheckpoint:
    params:
      # Checkpoint every 5000 training steps; save_top_k is -1 and no metric
      # is monitored.
      every_n_train_steps: 5000
      save_top_k: -1
      monitor: null
  trainer:
    benchmark: true
    # Zero sanity-check steps and zero validation batches (presumably this
    # disables validation entirely — confirm).
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    limit_val_batches: 0