# Model section: `target` names the class to build and `params` holds the
# settings handed to it (presumably as constructor kwargs — confirm in sgm).
model:
  base_learning_rate: 1.0e-5
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.13025
    disable_first_stage_autocast: true
    no_cond_log: true

    # Checkpoint handling: the weights file to start from plus a list of
    # adapters. NOTE(review): the exact adapter semantics live in
    # sgm.modules.checkpoint — confirm there before changing keys or slices.
    ckpt_config:
      target: sgm.modules.checkpoint.CheckpointEngine
      params:
        ckpt_path: checkpoints/sd_xl_base_1.0.safetensors
        pre_adapters:
          # Regex over state-dict keys: the attn2 to_k / to_v weights of the
          # transformer blocks in input_blocks, middle_block and output_blocks.
          - target: sgm.modules.checkpoint.Finetuner
            params:
              keys:
                - model\.diffusion_model\.(input_blocks|middle_block|output_blocks)(\.[0-9])?\.[0-9]\.transformer_blocks\.[0-9]\.attn2\.(to_k|to_v)\.weight
          # Slice expression for label_emb.0.0.weight; the 1024 here equals
          # network_config.adm_in_channels (presumably it trims the checkpoint
          # tensor to the narrower input width — confirm against Pruner).
          - target: sgm.modules.checkpoint.Pruner
            params:
              keys:
                - model\.diffusion_model\.label_emb\.0\.0\.weight
              slices:
                - ":, :1024"
        print_sd_keys: false
        print_model: false

    # Learning-rate schedule with a single cycle entry per list.
    # f_max and f_min are both 1, so after the 1000 warm-up steps the
    # multiplier presumably stays constant — confirm against
    # LambdaLinearScheduler.
    scheduler_config:
      target: sgm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [1000]
        cycle_lengths: [10000000000000]
        f_start: [1.e-6]
        f_max: [1.]
        f_min: [1.]

    # Denoiser: discrete denoiser over 1000 indices, combining an
    # eps-style scaling with the legacy DDPM discretization.
    # num_idx here equals the sigma sampler's num_idx in loss_fn_config.
    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.DiscreteDenoiser
      params:
        num_idx: 1000

        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.EpsScaling

        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

    # UNet backbone settings.
    network_config:
      target: sgm.modules.diffusionmodules.openaimodel.UNetModel
      params:
        # Author's note on this value: 2816. The 1024 used here equals the
        # ":, :1024" slice applied to label_emb.0.0.weight in ckpt_config.
        adm_in_channels: 1024
        num_classes: sequential
        use_checkpoint: true
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2]
        num_res_blocks: 2
        channel_mult: [1, 2, 4]
        num_head_channels: 64
        use_linear_in_transformer: true
        # note: the first is unused (due to attn_res starting at 2)
        # 32, 16, 8 --> 64, 32, 16
        transformer_depth: [1, 2, 10]
        # Author's note on this value: 1280.
        context_dim: 1664
        spatial_transformer_attn_type: softmax-xformers

    # Conditioning inputs: one image embedder (cross-attention) and two
    # size/crop embedders (vector conditioning). All are non-trainable.
    # The two vector embedders are 256 "multiplied by two" each, i.e.
    # 512 + 512 = 1024, which equals network_config.adm_in_channels.
    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:
          # cross-attention conditioning, read from the `jpg` key
          - is_trainable: false
            input_key: jpg
            target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
            params:
              arch: ViT-bigG-14
              version: laion2b_s39b_b160k
              freeze: true
              repeat_to_max_len: false
              output_tokens: true
              only_tokens: true
          # vector conditioning: original image size
          - is_trainable: false
            input_key: original_size_as_tuple
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # vector conditioning: crop offset (top-left corner)
          - is_trainable: false
            input_key: crop_coords_top_left
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256  # multiplied by two
          # Disabled: vector conditioning on the target size.
          # - is_trainable: False
          #   input_key: target_size_as_tuple
          #   target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
          #   params:
          #     outdim: 256  # multiplied by two

    # First-stage autoencoder (KL) with a 4-channel latent.
    # The loss is torch.nn.Identity, i.e. no autoencoder loss is configured.
    first_stage_config:
      # Alternative target, currently disabled:
      # target: sgm.models.autoencoder.AutoencoderKLInferenceWrapper
      target: sgm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          attn_type: vanilla-xformers
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    # Training loss: standard diffusion loss with eps weighting and an
    # offset-noise level of 0.04. Sigmas are sampled discretely over 1000
    # indices using the legacy DDPM discretization, the same num_idx and
    # discretization that denoiser_config uses.
    loss_fn_config:
      target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
        offset_noise_level: 0.04
        sigma_sampler_config:
          target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
          params:
            num_idx: 1000

            discretization_config:
              target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization
        loss_weighting_config:
          target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting

    # Sampler: Euler EDM sampler, 50 steps, legacy DDPM discretization,
    # vanilla classifier-free guidance at scale 5.0.
    # NOTE(review): presumably used for the logged sample images — confirm.
    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50

        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization

        guider_config:
          target: sgm.modules.diffusionmodules.guiders.VanillaCFG
          params:
            scale: 5.0

# Training data: a shard pipeline read from S3, filtered, decoded to PIL,
# then batched by the multi-aspect cacher.
data:
  target: sgm.data.dataset.StableDataModuleFromConfig
  params:
    train:
      datapipeline:
        urls:
          - s3://stability-west/sddatasets/laiocosplitv1c/
        pipeline_config:
          shardshuffle: 10000
          sample_shuffle: 10000

        preprocessors:
          # Filter on the `txt` and `jpg` keys.
          - target: sdata.filters.SimpleKeyFilter
            params:
              keys:
                - txt
                - jpg
          # Attribute-based filter; NOTE(review): the meaning of each flag
          # is defined by sdata.filters.AttributeFilter — confirm there.
          - target: sdata.filters.AttributeFilter
            params:
              filter_dict:
                SSCD_65: false
                is_spawning: true
                is_getty: true

        decoders:
          - pil

      loader:
        # The loader batch size is 1 while the cacher below uses 16;
        # presumably the cacher assembles the effective batches — confirm.
        batch_size: 1
        num_workers: 4
        batched_transforms:
          - target: sdata.mappers.MultiAspectCacher
            params:
              batch_size: 16
              debug: false
              crop_coords_key: crop_coords_top_left
              target_size_key: target_size_as_tuple
              original_size_key: original_size_as_tuple
              max_pixels: 262144  # 512 * 512


# PyTorch Lightning settings: strategy, checkpoint cadence, callbacks and
# trainer arguments.
lightning:
  strategy:
    target: pytorch_lightning.strategies.DDPStrategy

  modelcheckpoint:
    params:
      every_n_train_steps: 100000

  callbacks:
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 5000

    # Periodic sample logging: every 2000 batches, at most 4 images,
    # sampled with 50 steps.
    image_logger:
      target: sgm.modules.loggers.train_logging.SampleLogger
      params:
        disabled: false
        enable_autocast: true
        batch_frequency: 2000
        max_images: 4
        increase_log_steps: true
        log_first_step: false
        log_before_first_step: true
        log_images_kwargs:
          N: 4
          num_steps: [50]
          ucg_keys: []

  trainer:
    # Quoted to make explicit that this is the string "0," (trailing comma
    # included), which is what the unquoted form already parsed to.
    devices: '0,'
    benchmark: false
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 1000
    precision: 16
