# Weights & Biases settings; entity is left unset (null).
wandb:
    entity: null

# Run bookkeeping: names, output location, dataset-size caps, and the
# cadence of saving / evaluating / generating / logging.
experiment:
    project: "TinyMuse"
    name: "Pretrain_Stage"
    output_dir: "ckpt/pretrain_stage"
    max_train_examples: 9088710
    max_eval_examples: 8118
    # NOTE(review): the *_every cadences are presumably counted in training
    # steps — confirm against the training script.
    save_every: 1000
    eval_every: 1000
    generate_every: 1000
    log_every: 50
    log_grad_norm_every: 500
    # Must be YAML `null`. The Python spelling `None` is loaded as the
    # non-empty string "None", which is truthy and not a null value.
    resume_from_checkpoint: null
    resume_lr_scheduler: true

# Model components: image tokenizer (vq_model), text encoder, and the
# transformer hyperparameters, plus two memory-related switches.
model:
    # Image tokenizer, loaded from a local checkpoint directory.
    vq_model:
        type: 'vqgan'
        pretrained: 'ckpt/VQGAN'

    # Text encoder, loaded from a local checkpoint directory.
    text_encoder:
        type: 'phi2'
        pretrained: 'ckpt/Phi2'

    transformer:
        # vocab_size is 64 larger than codebook_size (8192).
        # NOTE(review): presumably room for mask/special tokens — confirm.
        vocab_size: 8256
        hidden_size: 1024
        intermediate_size: 4096
        num_hidden_layers: 22
        num_attention_heads: 16
        # Same value as num_vq_tokens below.
        max_position_embeddings: 256
        in_channels: 768
        block_out_channels: [768]
        num_res_blocks: 3
        patch_size: 1
        # Same value as cond_embed_dim below.
        # NOTE(review): presumably the text encoder's hidden width — confirm.
        encoder_hidden_size: 2560
        add_cross_attention: true
        project_encoder_hidden_states: false
        codebook_size: 8192
        num_vq_tokens: 256
        initializer_range: 0.02
        norm_type: 'rmsnorm'
        layer_norm_eps: 1e-6
        use_normformer: false
        use_encoder_layernorm: true
        use_bias: false
        hidden_dropout: 0.0
        attention_dropout: 0.0
        use_codebook_size_for_output: true
        use_empty_embeds_for_uncond: true
        add_cond_embeds: true
        add_micro_cond_embeds: true
        micro_cond_embed_dim: 1280
        cond_embed_dim: 2560

    # Memory-related switches applied at the model level.
    gradient_checkpointing: true
    enable_xformers_memory_efficient_attention: true


# Text-to-image data source: shard locations, prompt files, loader settings,
# and preprocessing options.
dataset:
    type: 'text2image'
    params:
        # Training shards use the brace range 00000..01209; the evaluation
        # shard (01210) is the next index after that range.
        train_shards_path_or_url: 'dataset/stage1/{00000..01209}.tar'
        eval_shards_path_or_url: 'dataset/stage1/01210.tar'
        validation_prompts_file: 'validation_prompts/pretrain_prompts.txt'
        negative_prompts_file: 'negative_prompts/negative_prompts.txt'
        # Interpolated from training.batch_size so the two stay in sync.
        batch_size: '${training.batch_size}'
        shuffle_buffer_size: 1000
        num_workers: 32
        resolution: 256
        pin_memory: true
        persistent_workers: true
    preprocessing:
        max_seq_length: 200
        # Same value as params.resolution above.
        resolution: 256
        center_crop: false
        random_flip: false


# Optimizer selection and its hyperparameters.
optimizer:
    name: adamw
    # Hyperparameters must sit under `params`: nesting them directly beneath
    # the scalar `name: adamw` is not valid YAML, and lr_scheduler reads
    # optimizer.params.learning_rate.
    params:
        # Floats are written with a decimal point so that YAML 1.1 loaders
        # also resolve them as floats rather than strings.
        learning_rate: 1.0e-4
        scale_lr: false
        beta1: 0.9
        beta2: 0.999
        weight_decay: 0.01
        epsilon: 1.0e-8


# Learning-rate schedule and its parameters.
lr_scheduler:
    scheduler: 'cosine_with_restarts'
    params:
        # Interpolated from optimizer.params.learning_rate rather than
        # repeating the value here.
        learning_rate: '${optimizer.params.learning_rate}'
        warmup_steps: 2000


# Training-loop settings: batching, precision, run length, and the
# masking / guidance / generation knobs.
training:
    # Batching and numeric precision.
    gradient_accumulation_steps: 2
    # Also consumed by dataset.params.batch_size through interpolation.
    batch_size: 96
    mixed_precision: 'fp16'
    enable_tf32: true
    use_ema: false

    # Reproducibility and run length.
    seed: 9345104
    max_train_steps: 677300
    overfit_one_batch: false

    # Objective-related knobs.
    cond_dropout_prob: 0.1
    min_masking_rate: 0.0
    label_smoothing: 0.0
    # null here.
    # NOTE(review): presumably disables gradient clipping — confirm against the trainer.
    max_grad_norm: null

    # Settings used when generating images.
    guidance_scale: 4.0
    generation_timesteps: 12

    # Soft / stochastic code targets are both off; soft_code_temp is kept at 1.0.
    use_soft_code_target: false
    use_stochastic_code: false
    soft_code_temp: 1.0