# Run-level settings: names, output location, initial weights, and how often
# to save / evaluate / generate / log.
experiment:
    project: "maskgen_vq_l_stage2"
    name: "maskgen_vq_l_stage2_run1"
    output_dir: "maskgen_vq_l_stage2_run1"
    # Weights file to initialise from. The file name suggests a stage-1
    # checkpoint -- TODO confirm against the training script.
    init_weight: "maskgen_vq_l_stage1.bin"
    max_train_examples: 13_000_000
    # NOTE(review): the *_every intervals are presumably counted in training
    # steps -- confirm with the consumer of this config.
    save_every: 50_000
    eval_every: 50_000
    generate_every: 5_000
    log_every: 50
    log_grad_norm_every: 1_000
    resume: true

# Model hyper-parameters, grouped into the `vq_model` and `maskgen`
# sub-sections.
model:
    vq_model:
        quantize_mode: "vq"
        codebook_size: 8192
        token_size: 64
        use_l2_norm: false
        commitment_cost: 0.25
        clustering_vq: true
        # ViT encoder / decoder architecture: model size and patch size are
        # set separately for each side.
        vit_enc_model_size: "base"
        vit_dec_model_size: "large"
        vit_enc_patch_size: 16
        vit_dec_patch_size: 16
        num_latent_tokens: 128
        finetune_decoder: false
        is_legacy: false
    maskgen:
        decoder_embed_dim: 1024
        decoder_depth: 16
        decoder_num_heads: 16
        micro_condition: true
        micro_condition_embed_dim: 256
        text_drop_prob: 0.1
        condition_num_classes: 1000
        # NOTE(review): the four keys below look like sampling-time settings
        # (`cfg` presumably being a guidance scale) -- confirm with the
        # generation code.
        cfg: 12.0
        num_iter: 16
        temperature: 2.0
        sample_aesthetic_score: 6.5

# Loss-term settings.
# Indented with four spaces, matching every other section of this file.
losses:
    label_smoothing: 0.1
    loss_weight_unmasked_token: 0.1

# Data settings: where the shards come from (`params`) and how each image is
# prepared (`preprocessing`).
dataset:
    params:
        # A single string; it appears to list several sources separated by
        # "::" -- TODO confirm how the loader splits it.
        train_shards_path_or_url: "datacomp6+::laion-art::laion-pop::journeydb::dalle3"
        eval_shards_path_or_url: "coco"
        # NOTE(review): this is the quoted string "true", not a boolean.
        # Kept exactly as-is; confirm which type the consumer expects.
        pretokenization: "true"
        num_workers_per_gpu: 12
        dataset_with_class_label: false
        dataset_with_text_label: true
    preprocessing:
        # Resize target and crop size are both 256.
        resize_shorter_edge: 256
        crop_size: 256
        random_crop: true
        random_flip: true
        res_ratio_filtering: true

# Optimizer selection and its hyper-parameters.
optimizer:
    name: "adamw"
    params:
        # Written with an explicit mantissa dot on purpose: YAML 1.1 loaders
        # (e.g. plain PyYAML) only treat exponent notation as a float when the
        # mantissa contains ".", and would load a bare "1e-4" as a string.
        learning_rate: 1.0e-4
        beta1: 0.9
        beta2: 0.96
        weight_decay: 0.03

# Learning-rate schedule.
lr_scheduler:
    scheduler: "cosine"
    params:
        # Interpolation reference: takes its value from
        # optimizer.params.learning_rate. This needs a loader that resolves
        # `${...}` references (the syntax matches OmegaConf's).
        learning_rate: ${optimizer.params.learning_rate}
        warmup_steps: 10_000
        # Explicit mantissa dot so YAML 1.1 loaders (e.g. plain PyYAML) read
        # this as a float; a bare "1e-5" would be loaded as a string there.
        end_lr: 1.0e-5

# Training-loop settings: batch size and accumulation, numeric precision,
# logging / EMA switches, seed, step budget, and gradient-norm limit.
training:
    gradient_accumulation_steps: 1
    per_gpu_batch_size: 128
    # Precision-related switches.
    mixed_precision: "fp16"
    enable_tf32: true
    # Feature toggles.
    enable_wandb: true
    use_ema: true
    seed: 42
    max_train_steps: 250_000
    num_generated_images: 2
    # NOTE(review): presumably the gradient-clipping threshold -- confirm
    # with the training loop.
    max_grad_norm: 1.0