# change from o4
model:
  target: image_synthesis.modeling.models.unconditional_dalle.UC_DALLE
  params:
    content_info: {key: image}
    # Pretrained VQGAN codec (not trained here) that turns images into a
    # 16 x 16 grid of tokens drawn from a 1024-entry codebook.
    content_codec_config:
      target: image_synthesis.modeling.codecs.image_codec.taming_gumbel_vqvae.TamingFFHQVQVAE
      params:
        trainable: false
        token_shape: [16, 16]
        config_path: '/checkpoints/pretrained_model/taming_dvae/vqgan_ffhq_f16_1024.yaml'
        ckpt_path: '/checkpoints/pretrained_model/taming_dvae/vqgan_ffhq_f16_1024.pth'
        num_tokens: 1024
        quantize_number: 0
        # Use YAML `null`. A bare `None` is parsed as the string 'None', not a
        # null value. No mapping file is given; presumably it is only consulted
        # when quantize_number != 0 (confirm against TamingFFHQVQVAE).
        mapping_path: null
        # return_logits: True
    diffusion_config:
      target: image_synthesis.modeling.transformers.diffusion_transformer.DiffusionTransformer
      params:
        diffusion_step: 100
        alpha_init_type: 'alpha1'
        auxiliary_loss_weight: 1.0e-3
        adaptive_auxiliary_loss: true
        mask_weight: [1, 1]    # the loss weight on mask region and non-mask region

        transformer_config:
          target: image_synthesis.modeling.transformers.transformer_utils.UnCondition2ImageTransformer
          params:
            attn_type: 'self'
            n_layer: 20
            content_seq_len: 256  # 16 x 16 tokens, i.e. content_spatial_size flattened
            content_spatial_size: [16, 16]
            n_embd: 512  # embedding dim; keep equal to content_emb_config's embed_dim
            n_head: 16
            attn_pdrop: 0.0
            resid_pdrop: 0.0
            block_activate: GELU2
            timestep_type: 'adalayernorm'    # one of: adainsnorm, adalayernorm, abs
            mlp_hidden_times: 4
            mlp_type: 'conv_mlp'
        content_emb_config:
          target: image_synthesis.modeling.embeddings.dalle_mask_image_embedding.DalleMaskImageEmbedding
          params:
            num_embed: 1024  # same as the codec's num_tokens
            # NOTE(review): spatial_size is 32 x 32, but the codec produces a
            # 16 x 16 token grid (content_spatial_size above). It is left
            # unchanged because it presumably sizes learned positional-embedding
            # weights, so changing it would break existing checkpoints.
            # Confirm before editing.
            # `!!python/tuple` needs a PyYAML loader that accepts python/* tags
            # (e.g. yaml.full_load); yaml.safe_load rejects it.
            spatial_size: !!python/tuple [32, 32]
            embed_dim: 512
            trainable: true
            pos_emb_type: embedding

solver:
  base_lr: 3.0e-6
  adjust_lr: none  # the string 'none' (not null): do not scale lr by total batch_size
  max_epochs: 300
  save_epochs: 5
  validation_epochs: 300
  sample_iterations: 3000  # run sampling once every this many iterations
  print_specific_things: true

  # config for ema (exponential moving average)
  ema:
    decay: 0.99
    update_interval: 10
    device: cpu

  clip_grad_norm:
    target: image_synthesis.engine.clip_grad_norm.ClipGradNorm
    params:
      start_iteration: 0
      end_iteration: 5000
      max_norm: 0.5
  # A list of configs, so several optimizer/scheduler pairs can be set up.
  optimizers_and_schedulers:
    - name: none  # default is None
      optimizer:
        target: torch.optim.AdamW
        params:
          # `!!python/tuple` needs a PyYAML loader that accepts python/* tags.
          betas: !!python/tuple [0.9, 0.96]
          weight_decay: 4.5e-2
      scheduler:
        step_iteration: 1
        target: image_synthesis.engine.lr_scheduler.ReduceLROnPlateauWithWarmup
        params:
          factor: 0.5
          patience: 5000
          min_lr: 1.0e-6
          threshold: 1.0e-1
          threshold_mode: rel
          warmup_lr: 4.5e-4  # the lr reached at the end of warmup
          warmup: 2000

dataloader:
  # NOTE(review): this path is repeated in each dataset's params below; keep
  # the copies in sync.
  data_root: /mnt/blob/datasets/FFHQ/image256_PIL
  batch_size: 32
  num_workers: 8
  train_datasets:  # a list of dataset configs, so several datasets can be combined
    - target: image_synthesis.data.ffhq_dataset.FFHQDataset
      params:
        data_root: /mnt/blob/datasets/FFHQ/image256_PIL
        drop_caption_rate: 0.0
        im_preprocessor_config:
          target: image_synthesis.data.utils.image_preprocessor.DalleTransformerPreprocessor
          params:
            size: 256
            phase: train
  validation_datasets:
    - target: image_synthesis.data.ffhq_dataset.FFHQDataset
      params:
        data_root: /mnt/blob/datasets/FFHQ/image256_PIL
        im_preprocessor_config:
          target: image_synthesis.data.utils.image_preprocessor.DalleTransformerPreprocessor
          params:
            size: 256
            phase: val
