wandb:
    # Weights & Biases entity (user/team) to log runs under; null leaves it unset.
    # NOTE(review): presumably falls back to the logged-in account's default entity — confirm.
    entity: null

experiment:
    # W&B project and run name for this stage.
    project: "TinyMuse"
    name: "Instruction-tuning"
    # Checkpoint directory of the earlier pretraining stage
    # (presumably the weights this stage is initialised from — confirm in the trainer).
    pretrain_dir: "ckpt/pretrain_stage"
    # Outputs are split: generator checkpoints and a separate Phi-2 directory.
    output_dir: "ckpt/instruct_stage/generator"
    phi2_output_dir: "ckpt/instruct_stage/phi2"
    max_train_examples: 7800000
    max_eval_examples: 35000
    # Periodic actions; NOTE(review): intervals presumably counted in optimizer steps.
    save_every: 1000
    eval_every: 1000
    generate_every: 1000
    log_every: 50
    log_grad_norm_every: 500
    resume_from_checkpoint: false
    resume_lr_scheduler: true

model:
    # VQGAN image tokenizer; presumably the transformer predicts its discrete codes.
    vq_model:
        type: "vqgan"
        pretrained: "ckpt/VQGAN"

    # Phi-2 language model used as the prompt (text) encoder.
    text_encoder:
        type: "phi2"
        pretrained: "ckpt/Phi2"
        # NOTE(review): presumably the tokenizer location; same path as the weights above.
        pretrain_token: "ckpt/Phi2"

    transformer:
        # codebook_size (8192) + 64 extra ids.
        # NOTE(review): extra ids presumably reserve special tokens such as [MASK] — confirm.
        vocab_size: 8256
        hidden_size: 1024
        intermediate_size: 4096
        num_hidden_layers: 6
        num_attention_heads: 16
        # Same value as num_vq_tokens below (one position per image token, presumably).
        max_position_embeddings: 256
        in_channels: 768
        block_out_channels:
            - 768
        num_res_blocks: 3
        patch_size: 1
        # Width of the text-encoder hidden states used for cross-attention;
        # presumably must equal the Phi-2 hidden size (2560).
        encoder_hidden_size: 2560
        add_cross_attention: true
        project_encoder_hidden_states: false
        codebook_size: 8192
        # NOTE(review): 256 = 16x16 code grid, presumably 256px images at 16x VQGAN downsampling.
        num_vq_tokens: 256
        initializer_range: 0.02
        norm_type: "rmsnorm"
        # Written as 1.0e-6 rather than 1e-6: YAML 1.1 loaders (e.g. PyYAML) only
        # resolve floats containing a '.', and would otherwise yield the string "1e-6".
        layer_norm_eps: 1.0e-6
        use_normformer: false
        use_encoder_layernorm: true
        use_bias: false
        hidden_dropout: 0.0
        attention_dropout: 0.0
        use_codebook_size_for_output: true
        use_empty_embeds_for_uncond: true
        add_cond_embeds: true
        add_micro_cond_embeds: true
        micro_cond_embed_dim: 1280
        # NOTE(review): presumably the pooled text-embedding width (Phi-2 hidden size, 2560).
        cond_embed_dim: 2560

    gradient_checkpointing: true
    enable_xformers_memory_efficient_attention: true


dataset:
    type: "text2image"
    params:
        # Brace-expanded shard range: shards 0000-1208 train, shard 1209 is held out for eval.
        train_shards_path_or_url: "dataset/stage2/{0000..1208}.tar"
        eval_shards_path_or_url: "dataset/stage2/1209.tar"
        # NOTE(review): this instruction-tuning config reuses the pretraining
        # validation prompts — confirm that is intended.
        validation_prompts_file: "validation_prompts/pretrain_prompts.txt"
        negative_prompts_file: "negative_prompts/negative_prompts.txt"
        # Interpolated from training.batch_size so the value is defined in one place.
        batch_size: ${training.batch_size}
        shuffle_buffer_size: 1000
        num_workers: 32
        resolution: 256
        pin_memory: true
        persistent_workers: true
    preprocessing:
        # NOTE(review): presumably the max prompt length in text-encoder tokens.
        max_seq_length: 200
        # Also set under params.resolution above; keep the two in sync.
        resolution: 256
        center_crop: false
        random_flip: false


optimizer:
    name: "adamw"
    params:
        # Floats are written with a '.' (1.0e-5, not 1e-5) so YAML 1.1 loaders
        # such as PyYAML resolve them as floats rather than strings.
        learning_rate: 1.0e-5
        scale_lr: false
        beta1: 0.9
        beta2: 0.999
        weight_decay: 0.01
        epsilon: 1.0e-8


lr_scheduler:
    # Cosine schedule with restarts, preceded by a warmup phase.
    scheduler: "cosine_with_restarts"
    params:
        # Peak LR is pulled from the optimizer section so it is only defined once.
        learning_rate: ${optimizer.params.learning_rate}
        # NOTE(review): presumably linear warmup over this many steps — confirm in the scheduler factory.
        warmup_steps: 2000


training:
    # Effective batch = batch_size x gradient_accumulation_steps = 48
    # (NOTE(review): presumably per process; multiply by the device count).
    gradient_accumulation_steps: 2
    batch_size: 24
    mixed_precision: "bf16"
    enable_tf32: true
    use_ema: false
    seed: 9345104
    max_train_steps: 500000
    overfit_one_batch: false
    # NOTE(review): presumably the probability of dropping the text condition for
    # classifier-free guidance; guidance_scale below would be the matching sampling weight.
    cond_dropout_prob: 0.1
    min_masking_rate: 0.0
    label_smoothing: 0.0
    # NOTE(review): null presumably disables gradient-norm clipping — confirm in the trainer.
    max_grad_norm: null
    guidance_scale: 4.0
    generation_timesteps: 12
    use_soft_code_target: false
    use_stochastic_code: false
    soft_code_temp: 1.0