# change from o4
# Model definition: `target` names the class path, `params` carries its settings.
model:
  target: image_synthesis.modeling.models.dalle.DALLE
  params:
    # Key used for the content entry.
    content_info:
      key: image
    # Key used for the condition entry.
    condition_info:
      key: text
    content_codec_config:
      # Alternative image codecs, kept for reference:
      # target: image_synthesis.modeling.codecs.image_codec.openai_dvae.OpenAIDiscreteVAE
      # target: image_synthesis.modeling.codecs.image_codec.taming_gumbel_vqvae.TamingGumbelVQVAE
      target: image_synthesis.modeling.codecs.image_codec.ema_vqvae.PatchVQVAE
      params:
        trainable: false
        token_shape: [32, 32]
        # return_logits: True
    condition_codec_config:
      target: image_synthesis.modeling.codecs.text_codec.tokenize.Tokenize
      params:
        context_length: 77  # 77 for clip and 256 for dalle
        add_start_and_end: true
        with_mask: true
        pad_value: 0  # 0 for clip embedding and -100 for others
        clip_embedding: false  # if we use clip embedding
        tokenizer_config:
          target: image_synthesis.modeling.modules.clip.simple_tokenizer.SimpleTokenizer
          params:
            end_idx: 49152
    diffusion_config:
      # Alternative transformer, kept for reference:
      # target: image_synthesis.modeling.transformers.gpt_like_transformer.GPTLikeTransformer
      target: image_synthesis.modeling.transformers.diffusion_transformer.DiffusionTransformer
      params:
        diffusion_step: 100
        alpha_init_type: alpha1  # init_type = fix or cos or linear
        auxiliary_loss_weight: 5.0e-4
        adaptive_auxiliary_loss: true
        # loss weight on the mask region and on the non-mask region
        mask_weight: [1, 1]
        learnable_cf: true

        transformer_config:
          target: image_synthesis.modeling.transformers.transformer_utils.Text2ImageTransformer
          params:
            attn_type: selfcross
            n_layer: 36
            condition_seq_len: 77  # 77 for clip and 256 for dalle
            content_seq_len: 1024  # 32 x 32
            content_spatial_size: [32, 32]
            n_embd: 1408  # embedding dimension
            condition_dim: 512
            n_head: 16
            attn_pdrop: 0.0
            resid_pdrop: 0.0
            block_activate: GELU2
            timestep_type: adalayernorm  # adainsnorm or adalayernorm and abs
            mlp_hidden_times: 4
            checkpoint: true

        condition_emb_config:
          target: image_synthesis.modeling.embeddings.clip_text_embedding.CLIPTextEmbedding
          params:
            clip_name: ViT-B/32
            num_embed: 49408  # 49152+256
            normalize: true
            # if True same as clip, but we need the embedding of each word
            pick_last_embedding: false
            keep_seq_len_dim: false
            additional_last_embedding: false
            embed_dim: 512

        content_emb_config:
          target: image_synthesis.modeling.embeddings.dalle_mask_image_embedding.DalleMaskImageEmbedding
          params:
            num_embed: 4096
            spatial_size: !!python/tuple [32, 32]
            embed_dim: 1408
            trainable: true
            pos_emb_type: embedding

# Optimisation settings: learning rate, epoch/iteration intervals, EMA,
# gradient clipping, and the optimizer/scheduler pairs.
solver:
  base_lr: 3.0e-6
  adjust_lr: none  # do not adjust lr according to total batch_size
  max_epochs: 50
  save_epochs: 1
  save_iterations: 1000
  validation_epochs: 20
  # how many iterations between sampling runs (earlier notes: epoch, 30000)
  sample_iterations: 2000
  print_specific_things: true

  # config for ema
  ema:
    decay: 0.99
    update_interval: 25
    device: cpu

  clip_grad_norm:
    target: image_synthesis.engine.clip_grad_norm.ClipGradNorm
    params:
      start_iteration: 0
      end_iteration: 5000
      max_norm: 0.5

  # a list of configs, so several optimizers and schedulers can be set up
  optimizers_and_schedulers:
    - name: none  # default is None
      optimizer:
        target: torch.optim.AdamW
        params:
          betas: !!python/tuple [0.9, 0.96]
          weight_decay: 4.5e-2
        # Alternative optimizer setup, kept for reference:
        # target: ZeroRedundancyOptimizer
        # optimizer_class: torch.optim.AdamW
        # params:
        # betas: !!python/tuple [0.9, 0.96]
        # weight_decay: 4.5e-2
      scheduler:
        step_iteration: 1
        target: image_synthesis.engine.lr_scheduler.ReduceLROnPlateauWithWarmup
        params:
          factor: 0.5
          patience: 50000
          min_lr: 1.0e-6
          threshold: 5.0e-2
          threshold_mode: rel
          # the lr reached after warmup; 4.5e-4 for training from scratch
          warmup_lr: 1.0e-4
          warmup: 7000

# Data loading: shared loader settings plus the train and validation dataset
# lists. Each dataset entry names its class in `target` and its settings in
# `params`, including the image preprocessor to apply.
dataloader:
  # data_root: data
  data_root: dataset
  batch_size: 8
  num_workers: 8
  # a list of configs, so several datasets can be combined
  train_datasets:
    - target: image_synthesis.data.cub200_dataset.Cub200Dataset
      params:
        data_root: dataset/CUB-200
        phase: train
        im_preprocessor_config:
          target: image_synthesis.data.utils.image_preprocessor.DalleTransformerPreprocessor
          params:
            size: 256
            phase: train
  validation_datasets:
    - target: image_synthesis.data.cub200_dataset.Cub200Dataset
      params:
        data_root: dataset/CUB-200
        phase: test
        im_preprocessor_config:
          target: image_synthesis.data.utils.image_preprocessor.DalleTransformerPreprocessor
          params:
            size: 256
            # NOTE(review): was `train`, which looked like a copy-paste from
            # the training entry; the validation split should not go through
            # the training-phase preprocessing. Confirm the preprocessor
            # accepts `val`.
            phase: val

