# Run identity and bookkeeping cadence for this experiment.
# NOTE(review): the *_every values are presumably intervals counted in training
# steps (same unit as training.max_train_steps) — confirm against the trainer.
experiment:
    project: "rar_generation"
    name: "rar-b"
    # Integer 1281167. The "_" digit separators are YAML 1.1 integer syntax
    # (accepted by PyYAML-based loaders); a strict YAML 1.2 loader would read
    # this as a string.
    # Presumably the ImageNet-1k training-set size, given the imagenet shard
    # paths in dataset.params — confirm.
    max_train_examples: 1_281_167
    # save_every and eval_every share the same interval (50_000), and
    # 250_000 (training.max_train_steps) is an exact multiple of it.
    save_every: 50_000
    eval_every: 50_000
    generate_every: 5_000
    log_every: 50
    log_grad_norm_every: 1_000
    # presumably resumes from an existing checkpoint when one is found — confirm.
    resume: True

# Model configuration, split into two sub-sections:
#   vq_model  - settings for the (pretrained) VQ image tokenizer
#   generator - size, regularisation and sampling settings for the token generator
model:
    vq_model:
        codebook_size: 1024
        token_size: 256
        # Equals generator.image_seq_len (256); presumably the two must be kept
        # in sync — confirm. Also consistent with a 16x-downsampling tokenizer on
        # 256x256 crops (16 * 16 = 256), as the "f16-256" in the weight filename
        # suggests — TODO confirm.
        num_latent_tokens: 256
        finetune_decoder: False
        # Relative path; NOTE(review): the base directory it is resolved against
        # is not visible in this file — confirm.
        pretrained_tokenizer_weight: "rar/models/rar/maskgit-vqgan-imagenet-f16-256.bin"
    
    generator:
        # 768 / 16 heads = 48 channels per head; intermediate_size is 4 * hidden_size.
        hidden_size: 768
        num_hidden_layers: 24
        num_attention_heads: 16
        intermediate_size: 3072
        dropout: 0.1
        attn_drop: 0.1
        # presumably the probability of dropping the class condition during
        # training (used together with guidance_scale at sampling time) — confirm.
        class_label_dropout: 0.1
        image_seq_len: 256
        # presumably the number of ImageNet-1k classes — confirm.
        condition_num_classes: 1000

        # sampling hyper-params for RAR-B
        randomize_temperature: 1.0
        guidance_scale: 16.0
        guidance_scale_pow: 2.75
        use_checkpoint: False # True to save memory

        # Step counts. The epoch figures use 625 steps per epoch
        # (1_281_167 examples / 2048 batch size is about 625.6, rounded down):
        # 125000 = 200 * 625 and 187500 = 300 * 625. Both lie inside
        # training.max_train_steps (250_000).
        randomness_anneal_start: 125000 # 200 epoch
        randomness_anneal_end: 187500 # 300 epoch

# Data sources (params) and image-space transforms (preprocessing).
dataset:
    params:
        # use pretokenized dataset for speed-up
        pretokenization: "maskgitvq.jsonl"
        # The {0000..0252} range covers 253 shard files; presumably expanded by a
        # WebDataset-style brace expansion in the data loader — confirm.
        train_shards_path_or_url: "imagenet_sharded/train/imagenet-train-{0000..0252}.tar"
        # The {0000..0009} range covers 10 shard files.
        eval_shards_path_or_url: "imagenet_sharded/val/imagenet-val-{0000..0009}.tar"
        num_workers_per_gpu: 12
    # NOTE(review): with pretokenization set above, it is not visible from this
    # file whether these image transforms are still applied to the training
    # stream or only to the raw-image (e.g. eval) path — confirm.
    preprocessing:
        resize_shorter_edge: 256
        crop_size: 256
        random_crop: False
        random_flip: True

# Optimizer selection and hyper-parameters.
# params.learning_rate is also read by lr_scheduler.params.learning_rate through
# the ${optimizer.params.learning_rate} interpolation, so this is the single
# place to change the base learning rate.
optimizer:
    name: adamw
    params:
        # Written with an explicit decimal point on purpose: YAML 1.1 resolvers
        # (e.g. stock PyYAML) only recognise exponent notation as a float when
        # the mantissa contains a ".", so a bare "4e-4" loads as a string there.
        # "4.0e-4" is the float 0.0004 under YAML 1.1, YAML 1.2 and OmegaConf.
        learning_rate: 4.0e-4
        beta1: 0.9
        beta2: 0.96
        weight_decay: 0.03


# Learning-rate schedule.
lr_scheduler:
    scheduler: "cosine"
    params:
        # Interpolation: takes its value from optimizer.params.learning_rate so
        # the two cannot drift apart. Requires a loader that resolves ${...}
        # references (OmegaConf-style) — confirm.
        learning_rate: ${optimizer.params.learning_rate}
        # 62_500 is 25% of training.max_train_steps (250_000).
        warmup_steps: 62_500 # 100 epochs with bsz 2048
        # Written with an explicit decimal point on purpose: YAML 1.1 resolvers
        # (e.g. stock PyYAML) load a bare "1e-5" as a string, not a float.
        # presumably the learning rate reached at the end of the schedule — confirm.
        end_lr: 1.0e-5

# Training-loop settings: batch sizing, numeric precision, logging and stopping.
training:
    gradient_accumulation_steps: 1
    # The GPU count is not set anywhere in this file. The 2048 total (64 * 32,
    # with gradient_accumulation_steps = 1) and every "epoch" comment elsewhere
    # in this config only hold when the job is launched on 32 GPUs.
    per_gpu_batch_size: 64 # 32 GPU, total batch size 2048
    mixed_precision: "bf16"
    enable_tf32: True
    enable_wandb: True
    use_ema: False
    seed: 42
    # 250_000 steps = 400 epochs at the 625 steps-per-epoch conversion used by
    # the warmup / randomness-anneal comments (1_281_167 / 2048 is about 625.6).
    max_train_steps: 250_000
    # presumably the gradient-norm clipping threshold — confirm.
    max_grad_norm: 1.0