---
# Weights & Biases sweep configuration for the vqvae_ablation project.
# Register with `wandb sweep <this file>`, then launch `wandb agent`s.
program: scripts/sweep_cli_entry.py  # script each agent runs per trial
method: random  # random search over the `parameters` space below
name: sweep_cifar_cnn
project: vqvae_ablation
entity: your_wandb_name  # placeholder: replace with your W&B user or team
metric:
  # Must match a metric key the program logs to W&B.
  name: val_loss
  goal: minimize
parameters:
  vq_decay:
    # NOTE(review): presumably the EMA decay rate of the codebook; confirm.
    values: [0.5, 0.8, 0.9, 0.95, 0.98]
  commit_weight:
    # NOTE(review): presumably the commitment-loss weight; confirm.
    values: [0.1, 0.2, 0.5, 0.8, 1.0, 1.2, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0]
  threshold_ema_dead_code:
    # The right scale depends on the global batch_size and codebook_size:
    # expect threshold = constant * (total tokens per global batch /
    # codebook_size).
    # NOTE(review): these values are fixed while codebook_size is swept, so
    # the implied constant changes with codebook_size; confirm intended.
    values: [0.008, 0.032, 0.125, 0.5]
  model:
    # Model variants compared in this ablation.
    values: ['vqvae', 'rot_vqvae']
  codebook_size:
    # Presumably the number of code vectors in the codebook.
    values: [512, 1024, 2048, 8192]
  codebook_dim:
    # Presumably the dimensionality of each code vector (listed descending).
    values: [256, 128, 64, 32, 16, 8, 6, 4, 3]
  # Each setting below lists exactly one value, so it is held constant
  # across runs rather than actually swept.
  dataset:
    values: ['cifar10']
  arch:
    values: ['taming']
  batch_size:
    # Global batch size; the threshold_ema_dead_code scale depends on it.
    values: [256]
  gpu:
    # NOTE(review): presumably the device index the program runs on.
    values: [0]
  f:
    # NOTE(review): meaning not visible here (downsampling factor?); confirm.
    values: [4]