# Weights & Biases run settings.
wandb:
    # No entity is pinned in this config.
    entity: null
    # NOTE(review): "auto" is presumably passed through as the W&B resume mode - confirm in the training script.
    resume: "auto"

# Run identity, example budgets, and the cadence of periodic actions.
experiment:
    # Identity of the run; `name` and `output_dir` carry the same value.
    project: "mmada-training-stage4"
    name: "mmada-training-stage4-llada-instruct"
    output_dir: "mmada-training-stage4-llada-instruct"
    resume_from_checkpoint: "latest"

    # Per-task caps on the number of examples.
    max_train_examples_t2i: 40000000
    max_train_examples_mmu: 40000000
    max_val_examples_t2i: 2000

    # Intervals for periodic actions, listed from most to least frequent.
    # NOTE(review): units are presumably training steps - confirm against the training script.
    log_every: 50
    val_every: 50
    log_grad_norm_every: 100
    generate_every: 1000
    eval_every: 2500
    save_every: 10000

# Model components: the VQ image tokenizer and the MMaDA backbone.
model:
    vq_model:
        type: "magvitv2"
        vq_model_name: "showlab/magvitv2"

    mmada:
        tokenizer_path: "GSAI-ML/LLaDA-8B-Instruct"
        # Relative path to a stage-3 checkpoint directory (checkpoint-210000).
        pretrained_model_path: "mmada-training-stage3-llada-instruct-512-cot-uni/checkpoint-210000/unwrapped_model"
        w_clip_vit: false
        # The sizes below satisfy
        #   new_vocab_size = llm_vocab_size + codebook_size + num_new_special_tokens
        #   134656         = 126464         + 8192          + 0
        # NOTE(review): presumably they must be kept in sync - confirm before changing one.
        new_vocab_size: 134656
        llm_vocab_size: 126464
        codebook_size: 8192
        num_vq_tokens: 1024
        num_new_special_tokens: 0
        tie_word_embeddings: false

    gradient_checkpointing: true

# Data sources, loader settings, and preprocessing.
# All "/path/to/your/datasets/..." values are placeholders to be replaced
# with real locations.
dataset:
    gen_type: "t2i"
    und_type: "captioning"
    combined_loader_mode: "max_size_cycle"
    params:
        # Text-to-image training shards.
        train_t2i_shards_path_or_url:
            - "/path/to/your/datasets/JourneyDB/train/imgs/data/train/imgs/{000..199}.tgz"
            - "/path/to/your/datasets/laion-aesthetics-12m-filter/{00000..00999}.tar"
            # Disabled source:
            # - "/path/to/your/datasets/text-to-image-2M/data_512_2M/data_{000000..000046}.tar"

        # Multimodal-understanding training shards.
        train_mmu_shards_path_or_url:
            - "/path/to/your/datasets/multimodal_cot/ai2d/new_images.tar"
            - "/path/to/your/datasets/multimodal_cot/clevr/images.tar"
            - "/path/to/your/datasets/multimodal_cot/docvqa/images.tar"
            - "/path/to/your/datasets/multimodal_cot/geo/images.tar"

        # Text-only training data (parquet globs).
        train_lm_shards_path_or_url: "/path/to/your/datasets/falcon-refinedweb/data/data/*.parquet"
        train_instruct_shards_path_or_url: "/path/to/your/datasets/stage4_instruct/*.parquet"

        # Caption / annotation sources kept outside the image shards.
        add_caption_prompt: true
        external_caption_path: "/path/to/your/datasets/SAM-LLaVA-Captions10M"
        external_journeydb_caption_path: "/path/to/your/datasets/journeydb_anno/train_journeydb_anno.json"
        external_laion12m_caption_path: "/path/to/your/datasets/laion-aesthetics-12m-images-2"
        external_cc12m_caption_path: "/path/to/your/datasets/cc12m/new_captions"
        external_text_to_image_2M_512_caption_path: "/path/to/your/datasets/text-to-image-2M/data_512_2M_captions"
        external_ai2d_caption_path: "/path/to/your/datasets/multimodal_cot/ai2d/new_metadata.csv"
        external_clevr_caption_path: "/path/to/your/datasets/multimodal_cot/clevr/metadata.csv"
        external_docvqa_caption_path: "/path/to/your/datasets/multimodal_cot/docvqa/metadata.csv"
        external_geo_caption_path: "/path/to/your/datasets/multimodal_cot/geo/metadata.csv"
        external_vqa_caption_path: "/path/to/your/datasets/LLaVA-Instruct-150K/llava_v1_5_mix665k.json"
        # Same dataset directory as `clevr2_images_path` below.
        external_clevr2_caption_path: "/path/to/your/datasets/Clevr_CoGenT_TrainA_70K_Complex/captions.json"
        external_geo170k_caption_path: "/path/to/your/datasets/Geo170K/Geo170K/all.json"

        # Image roots for the annotation files above.
        vqa_images_path: "/path/to/your/datasets/LLaVA-Instruct-150K-images"
        clevr2_images_path: "/path/to/your/datasets/Clevr_CoGenT_TrainA_70K_Complex/images"
        geo170k_images_path: "/path/to/your/datasets/Geo170K/Geo170K/images"

        # Validation inputs. `validation_prompts_file` is a relative path;
        # the others are placeholder absolute paths.
        validation_prompts_file: "validation_prompts/text2image_prompts.txt"
        mmu_image_root: "/path/to/your/datasets/MMaDA/mmu_validation"
        mmu_validation_prompts_file: "/path/to/your/datasets/MMaDA/mmu_validation/prompts_with_vqa.json"
        lm_chat_validation_jsonl: "/path/to/your/datasets/MMaDA/lm_chat_validation/questions.jsonl"

        # Loader settings.
        shuffle_buffer_size: 1000
        num_workers: 16
        # NOTE(review): `resolution` also appears under `preprocessing` with the
        # same value (512); presumably the two should stay equal - confirm.
        resolution: 512
        pin_memory: true
        persistent_workers: true

    preprocessing:
        max_seq_length: 512  # for text tokens in t2i & mmu
        max_lm_text_length: 1536  # for text tokens in lm/lm_chat
        resolution: 512
        center_crop: false
        random_flip: false

# Optimizer selection and hyperparameters.
optimizer:
    name: adamw
    params:  # default adamw params
        # Exponent-notation floats carry an explicit decimal point so that
        # YAML 1.1 loaders (e.g. PyYAML) resolve them as floats, not strings.
        learning_rate: 5.0e-5
        scale_lr: false  # scale learning rate by total batch size
        beta1: 0.9
        beta2: 0.999
        weight_decay: 0.01
        epsilon: 1.0e-8

# Learning-rate schedule.
lr_scheduler:
    scheduler: "cosine"
    params:
        # Interpolation: reuses the optimizer's learning rate so the value
        # is defined in one place only.
        learning_rate: "${optimizer.params.learning_rate}"
        warmup_steps: 5000
        min_lr_scale: 0.1

# Training-loop settings.
training:
    # Run length, reproducibility, and numerics.
    max_train_steps: 1000000
    seed: 10086
    mixed_precision: "bf16"
    enable_tf32: true
    overfit_one_batch: false

    # Batch sizes per task, and gradient accumulation.
    batch_size_t2i: 1
    batch_size_lm: 2
    batch_size_mmu: 1
    gradient_accumulation_steps: 4

    # Noise / masking and regularisation.
    noise_type: "mask"
    min_masking_rate: 0.0
    cond_dropout_prob: 0.1
    label_smoothing: 0.0
    max_grad_norm: 1

    # Image generation settings.
    guidance_scale: 5
    generation_timesteps: 20

    # Per-task coefficients (as written they total 1.05).
    t2i_coeff: 0.05
    lm_coeff: 0.6
    mmu_coeff: 0.4

    # Coefficients among the MMU sources (as written they total 10.0).
    cot_in_mmu_coeff: 3.5
    vqa_in_mmu_coeff: 5.5
    clevr2_in_mmu_coeff: 0.5
    geo170k_in_mmu_coeff: 0.5

    # Coefficients among the LM sources (as written they total 1.0).
    base_in_lm_coeff: 0.02
    instruct_in_lm_coeff: 0.98

# Prompt file and batch size for validation.
# NOTE(review): the "quantative" spelling is part of the key names and the
# file path, which are read at runtime, so it is intentionally left unchanged.
validation:
    quantative_prompts_file: "validation_prompts/quantative.txt"
    quantative_batch_size: 8