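# Class-conditional image generation on ImageNet: a discrete diffusion
# transformer over 16x16 VQGAN tokens, conditioned on the 1000 class labels.
# Changed from the "o4" config.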
model:
  target: image_synthesis.modeling.models.conditional_dalle.C_DALLE
  params:
    content_info: {key: image}
    condition_info: {key: label}
    # Frozen image tokenizer: each image becomes a 16x16 grid of discrete tokens.
    content_codec_config:
      target: image_synthesis.modeling.codecs.image_codec.taming_gumbel_vqvae.TamingVQVAE
      params:
        trainable: false
        token_shape: [16, 16]
        config_path: '/checkpoints/pretrained_model/taming_dvae/vqgan_imagenet_f16_16384.yaml'
        ckpt_path: '/checkpoints/pretrained_model/taming_dvae/vqgan_imagenet_f16_16384.pth'
        num_tokens: 16384
        # NOTE(review): presumably the 16384 codebook entries are remapped onto
        # 974 codes via mapping_path -- confirm against TamingVQVAE.
        # Kept equal to content_emb_config.num_embed (both 974).
        quantize_number: 974
        mapping_path: './help_folder/statistics/taming_vqvae_974.pt'
        # return_logits: True
    diffusion_config:
      target: image_synthesis.modeling.transformers.diffusion_transformer.DiffusionTransformer
      params:
        diffusion_step: 100
        alpha_init_type: 'alpha1'
        auxiliary_loss_weight: 1.0e-3
        adaptive_auxiliary_loss: true
        mask_weight: [1, 1]    # the loss weight on mask region and non-mask region

        transformer_config:
          target: image_synthesis.modeling.transformers.transformer_utils.Condition2ImageTransformer
          params:
            attn_type: 'selfcondition'
            n_layer: 24
            class_type: 'adalayernorm'
            class_number: 1000
            content_seq_len: 256  # 16 x 16
            content_spatial_size: [16, 16]
            # Embedding width; kept equal to embed_dim in condition_emb_config
            # and content_emb_config (all 512).
            n_embd: 512
            n_head: 16
            attn_pdrop: 0.0
            resid_pdrop: 0.0
            block_activate: GELU2
            timestep_type: 'adalayernorm'    # adainsnorm, adalayernorm or abs
            mlp_hidden_times: 4
            mlp_type: 'conv_mlp'
        condition_emb_config:
          target: image_synthesis.modeling.embeddings.class_embedding.ClassEmbedding
          params:
            num_embed: 1000  # one entry per class; matches class_number
            embed_dim: 512
            identity: true
        content_emb_config:
          target: image_synthesis.modeling.embeddings.dalle_mask_image_embedding.DalleMaskImageEmbedding
          params:
            num_embed: 974  # kept equal to content_codec_config quantize_number
            # NOTE(review): spatial_size is 32x32 while the token grid is 16x16
            # (token_shape / content_spatial_size). If the positional embedding
            # is laid out on this 32x32 grid and truncated to 256 positions, its
            # rows/columns will not line up with the image's rows/columns. Left
            # unchanged because altering it presumably changes parameter shapes
            # and breaks existing checkpoints -- confirm before using [16, 16].
            # !!python/tuple needs a non-safe loader (e.g. yaml.full_load);
            # yaml.safe_load rejects this tag.
            spatial_size: !!python/tuple [32, 32]
            embed_dim: 512
            trainable: true
            pos_emb_type: embedding

solver:
  base_lr: 3.0e-6
  adjust_lr: none  # the string 'none': do not rescale lr by the total batch size
  max_epochs: 100
  save_epochs: 2
  validation_epochs: 100
  # How often to run sampling: 'epoch', or an iteration count such as 30000.
  sample_iterations: epoch
  print_specific_things: true

  # EMA settings
  ema:
    decay: 0.99
    update_interval: 25
    device: cpu

  clip_grad_norm:
    target: image_synthesis.engine.clip_grad_norm.ClipGradNorm
    params:
      start_iteration: 0
      end_iteration: 5000
      max_norm: 0.5

  # A list of configurations, so several optimizers and schedulers can be set up.
  optimizers_and_schedulers:
    - name: none  # the string 'none' (YAML null would be `null`); default is None
      optimizer:
        target: torch.optim.AdamW
        params:
          # !!python/tuple needs a non-safe loader (e.g. yaml.full_load).
          betas: !!python/tuple [0.9, 0.96]
          weight_decay: 4.5e-2
      scheduler:
        step_iteration: 1
        target: image_synthesis.engine.lr_scheduler.ReduceLROnPlateauWithWarmup
        params:
          factor: 0.5
          patience: 100000
          min_lr: 1.0e-6
          threshold: 1.0e-1
          threshold_mode: rel
          warmup_lr: 4.5e-4  # lr reached once warmup ends
          warmup: 5000

dataloader:
  data_root: /mnt/blob/datasets/ImageNet-2012
  batch_size: 32
  num_workers: 4
  # A list of dataset configurations, so several datasets can be combined.
  train_datasets:
    - target: image_synthesis.data.imagenet_dataset.ImageNetDataset
      params:
        data_root: /mnt/blob/datasets/ImageNet-2012  # same as dataloader.data_root
        phase: train
        input_file: ILSVRC2012_name_train.txt
        drop_caption_rate: 0.0
        im_preprocessor_config:
          target: image_synthesis.data.utils.image_preprocessor.DalleTransformerPreprocessor
          params:
            size: 256
            phase: train
  validation_datasets:
    - target: image_synthesis.data.imagenet_dataset.ImageNetDataset
      params:
        data_root: /mnt/blob/datasets/ImageNet-2012  # same as dataloader.data_root
        phase: val
        input_file: ILSVRC2012_name_val.txt
        im_preprocessor_config:
          target: image_synthesis.data.utils.image_preprocessor.DalleTransformerPreprocessor
          params:
            size: 256
            phase: val
