# change from o4
# Text-to-image model definition. Every `target` is a dotted import path and
# the sibling `params` mapping holds its arguments (presumably passed as
# constructor keyword arguments — confirm against the project's config loader).
model:
  target: image_synthesis.modeling.models.dalle.DALLE
  params:
    # Presumably the sample keys for the generated content and its condition.
    content_info: {key: image}
    condition_info: {key: text}

    # Codec for the content (image) side.
    content_codec_config:
      # Alternative codec, kept for reference:
      # target: image_synthesis.modeling.codecs.image_codec.openai_dvae.OpenAIDiscreteVAE
      target: image_synthesis.modeling.codecs.image_codec.taming_gumbel_vqvae.TamingGumbelVQVAE
      params:
        trainable: false
        token_shape: [32, 32]
        config_path: '/checkpoints/pretrained_model/taming_dvae/taming_f8_8192_openimages.yaml'
        ckpt_path: '/checkpoints/pretrained_model/taming_dvae/taming_f8_8192_openimages_last.pth'
        num_tokens: 8192
        # content_emb_config.params.num_embed is expected to equal this value.
        quantize_number: 2887
        mapping_path: './help_folder/statistics/taming_vqvae_2887.pt'
        # return_logits: True

    # Codec for the condition (text) side.
    condition_codec_config:
      target: image_synthesis.modeling.codecs.text_codec.tokenize.Tokenize
      params:
        context_length: 77  # 77 for CLIP, 256 for DALL-E
        add_start_and_end: true
        with_mask: true
        pad_value: 0  # 0 for CLIP embedding, -100 for others
        clip_embedding: false  # if we use CLIP embedding
        tokenizer_config:
          target: image_synthesis.modeling.modules.clip.simple_tokenizer.SimpleTokenizer
          params:
            end_idx: 49152

    diffusion_config:
      # Alternative transformer, kept for reference:
      # target: image_synthesis.modeling.transformers.gpt_like_transformer.GPTLikeTransformer
      target: image_synthesis.modeling.transformers.diffusion_transformer.DiffusionTransformer
      params:
        diffusion_step: 100
        # NOTE(review): the earlier note listed "fix or cos or linear" as the
        # init types, which does not include 'alpha1' — confirm the valid values.
        alpha_init_type: 'alpha1'
        auxiliary_loss_weight: 5.0e-4
        adaptive_auxiliary_loss: true
        mask_weight: [1, 1]  # loss weight on the mask region and the non-mask region

        transformer_config:
          target: image_synthesis.modeling.transformers.transformer_utils.Text2ImageTransformer
          params:
            attn_type: 'selfcross'
            n_layer: 19
            condition_seq_len: 77  # 77 for CLIP, 256 for DALL-E (cf. context_length)
            content_seq_len: 1024  # 32 x 32
            content_spatial_size: [32, 32]
            n_embd: 1024  # embedding dimension
            condition_dim: 512
            n_head: 16
            attn_pdrop: 0.0
            resid_pdrop: 0.0
            block_activate: GELU2
            timestep_type: 'adalayernorm'  # options noted by the author: adainsnorm, adalayernorm, abs
            mlp_hidden_times: 4
        condition_emb_config:
          target: image_synthesis.modeling.embeddings.clip_text_embedding.CLIPTextEmbedding
          params:
            clip_name: 'ViT-B/32'
            num_embed: 49408  # 49152 + 256
            normalize: true
            # If true, same as CLIP; false here because the embedding of each
            # word is needed.
            pick_last_embedding: false
            keep_seq_len_dim: false
            additional_last_embedding: false
            embed_dim: 512
        content_emb_config:
          target: image_synthesis.modeling.embeddings.dalle_mask_image_embedding.DalleMaskImageEmbedding
          params:
            num_embed: 2887  # should be quantize_number
            # `!!python/tuple` is a PyYAML language-specific tag; a safe loader
            # will reject it, so this file needs a loader that constructs it.
            spatial_size: !!python/tuple [32, 32]
            embed_dim: 1024
            trainable: true
            pos_emb_type: embedding

# Training-loop settings: learning rate, epochs, EMA, gradient clipping, and
# the optimizer/scheduler pairs.
solver:
  base_lr: 3.0e-6
  # Do not adjust lr according to the total batch_size. Note this is the plain
  # string `none`, not a YAML null.
  adjust_lr: none
  max_epochs: 400
  save_epochs: 30
  validation_epochs: 400
  # How many iterations between sampling runs; values noted by the author:
  # `epoch`, 30000.
  sample_iterations: epoch
  print_specific_things: true

  # config for ema
  ema:
    decay: 0.99
    update_interval: 25
    device: cpu

  clip_grad_norm:
    target: image_synthesis.engine.clip_grad_norm.ClipGradNorm
    params:
      start_iteration: 0
      end_iteration: 5000
      max_norm: 0.5

  # A list of configs, so several optimizers and schedulers can be configured.
  optimizers_and_schedulers:
    - name: none  # default is None
      optimizer:
        target: torch.optim.AdamW
        params:
          # `!!python/tuple` is a PyYAML language-specific tag; a safe loader
          # will reject it.
          betas: !!python/tuple [0.9, 0.96]
          weight_decay: 4.5e-2
        # Alternative optimizer config, kept for reference:
        # target: ZeroRedundancyOptimizer
        # optimizer_class: torch.optim.AdamW
        # params:
        #   betas: !!python/tuple [0.9, 0.96]
        #   weight_decay: 4.5e-2
      scheduler:
        step_iteration: 1
        target: image_synthesis.engine.lr_scheduler.ReduceLROnPlateauWithWarmup
        params:
          factor: 0.5
          patience: 25000
          min_lr: 1.0e-6
          threshold: 1.0e-1
          threshold_mode: rel
          warmup_lr: 4.5e-4  # the lr reached at the end of warmup
          warmup: 1000

# Data loading for training and validation. Both dataset lists use
# Cub200Dataset with the same data root and a 256-pixel preprocessor.
dataloader:
  # Alternative data root, kept for reference:
  # data_root: data
  data_root: /mnt/blob/datasets/CUB-200
  batch_size: 4
  num_workers: 4
  train_datasets: # a list of configs, so several datasets can be combined
    - target: image_synthesis.data.cub200_dataset.Cub200Dataset
      params:
        # NOTE(review): repeats dataloader.data_root; whether the top-level
        # value acts as a fallback is not visible here, so keep them in sync.
        data_root: /mnt/blob/datasets/CUB-200
        phase: train
        drop_caption_rate: 0.0
        im_preprocessor_config:
          target: image_synthesis.data.utils.image_preprocessor.DalleTransformerPreprocessor   # ImageNet
          params:
            size: 256
            phase: train
  # Validation reads the dataset's `test` phase, while the preprocessor below
  # is configured with phase `val`.
  validation_datasets:
    - target: image_synthesis.data.cub200_dataset.Cub200Dataset
      params:
        # NOTE(review): repeats dataloader.data_root; keep in sync.
        data_root: /mnt/blob/datasets/CUB-200
        phase: test
        im_preprocessor_config:
          target: image_synthesis.data.utils.image_preprocessor.DalleTransformerPreprocessor
          params:
            size: 256
            phase: val
