# Run bookkeeping: identifiers, output location, and periodic-action cadence.
experiment:
    project: "titok_b_64_stage2"
    name: "titok_b_64_stage2_run1"
    output_dir: "titok_b_64_stage2_run1"
    # NOTE(review): 1_281_167 is presumably the ImageNet-1k train-set size
    # (the dataset shards in this file are named imagenet-train) — confirm.
    max_train_examples: 1_281_167
    # Intervals for checkpointing, evaluation, sample generation and logging.
    # Presumably counted in training steps (compare training.max_train_steps)
    # — confirm against the training loop.
    save_every: 50_000
    eval_every: 50_000
    generate_every: 5_000
    log_every: 50
    log_grad_norm_every: 1_000
    resume: true
    # Weights file to start from; the name suggests a stage-1 checkpoint.
    init_weight: "titok_b_64_stage1.bin"


# Model definition. Only the vector-quantised (vq_model) sub-config is set here.
model:
    vq_model:
        # Quantizer settings: 4096 codebook entries, 12-dimensional tokens.
        codebook_size: 4096
        token_size: 12
        use_l2_norm: true
        # Set to 0.0 in this config (losses.quantizer_weight is 0.0 as well).
        commitment_cost: 0.0
        # ViT architecture: encoder and decoder both use the "base" size and a
        # patch size of 16.
        vit_enc_model_size: "base"
        vit_dec_model_size: "base"
        vit_enc_patch_size: 16
        vit_dec_patch_size: 16
        num_latent_tokens: 64
        # NOTE(review): presumably enables decoder fine-tuning on top of the
        # pretrained tokenizer weights named below — confirm in the model code.
        finetune_decoder: true
        pretrained_tokenizer_weight: "maskgit-vqgan-imagenet-f16-256.bin"

# Loss configuration, grouped by term. Key order carries no meaning; values
# are looked up by name.
losses:
    # Reconstruction term.
    reconstruction_loss: "l2"
    reconstruction_weight: 1.0
    # Perceptual term.
    perceptual_loss: "convnext_s"
    perceptual_weight: 0.1
    # Quantizer term; weighted 0.0 in this config.
    quantizer_weight: 0.0
    # Adversarial (discriminator) term.
    # NOTE(review): discriminator_start is presumably the training step at
    # which the discriminator loss is switched on — confirm in the loss module.
    discriminator_start: 20_000
    discriminator_factor: 1.0
    discriminator_weight: 0.01
    lecam_regularization_weight: 0.001


# Input data: shard locations and image preprocessing.
dataset:
    params:
        # Shard patterns use brace ranges ({0000..0252}, {0000..0009});
        # presumably expanded by the data loader into one .tar path per index
        # — confirm against the dataset code.
        train_shards_path_or_url: "imagenet_sharded/train/imagenet-train-{0000..0252}.tar"
        eval_shards_path_or_url: "imagenet_sharded/val/imagenet-val-{0000..0009}.tar"
        num_workers_per_gpu: 12
    preprocessing:
        # Resize and crop both target 256; random crop and flip are enabled.
        resize_shorter_edge: 256
        crop_size: 256
        random_crop: true
        random_flip: true


# Optimizer selection and hyper-parameters.
optimizer:
    name: "adamw"
    params:
        # Exponent-form floats carry an explicit decimal point (1.0e-4 rather
        # than 1e-4): YAML 1.1 loaders such as stock PyYAML resolve the
        # dot-less form as a string instead of a float.
        # learning_rate is also referenced by lr_scheduler.params.learning_rate.
        learning_rate: 1.0e-4
        discriminator_learning_rate: 1.0e-4
        beta1: 0.9
        beta2: 0.999
        weight_decay: 1.0e-4

# Learning-rate schedule.
lr_scheduler:
    scheduler: "cosine"
    params:
        # Interpolated from the optimizer section so the two values cannot
        # drift apart. Quoted so the YAML layer always sees a plain string;
        # resolution is left to the config loader (the ${...} syntax looks
        # like OmegaConf interpolation — confirm).
        learning_rate: "${optimizer.params.learning_rate}"
        warmup_steps: 5_000
        # Written with an explicit decimal point so YAML 1.1 loaders resolve
        # it as a float rather than the string "1e-5".
        end_lr: 1.0e-5

# Training-loop settings: batching, numeric precision, tracking and limits.
training:
    # Batching. Per-GPU batch of 32 with no gradient accumulation.
    gradient_accumulation_steps: 1
    per_gpu_batch_size: 32
    # Numeric precision.
    mixed_precision: "fp16"
    enable_tf32: true
    # Experiment tracking and weight averaging.
    enable_wandb: true
    use_ema: true
    seed: 42
    max_train_steps: 500_000
    # NOTE(review): presumably the number of samples produced at each
    # experiment.generate_every interval — confirm against the trainer.
    num_generated_images: 2
    # NOTE(review): presumably the gradient-clipping threshold — confirm.
    max_grad_norm: 1.0