# configs/gepc_cifar10.yaml
---
# Resolution settings. Both are 32 in this config.
# NOTE(review): presumably `image_size` is the model's input resolution and
# `data_image_size` is the size dataset images are resized to; this file does
# not show how the two are used — confirm against the loader.
image_size: 32
data_image_size: 32

# Checkpoint file for the improved-diffusion model.
# NOTE(review): this config uses CIFAR-10 as in-distribution data
# (eval.id_train / eval.id_test), and the improved_args below match the
# CIFAR-10 model flags published with improved-diffusion. The checkpoint name,
# however, says "celeba", and CelebA is also the OOD set under eval.ood. If this
# checkpoint really was trained on CelebA, the "OOD" data is in-distribution for
# the model and the run is not a CIFAR-10 OOD benchmark. Confirm which
# checkpoint is intended.
model_path: 'checkpoints/celeba_ema_0.9999_499999.pt'

# Explicit UNet/diffusion overrides. These must describe the architecture the
# checkpoint was trained with, or its weights presumably will not load.
# NOTE(review): diffusion_steps / noise_schedule are not overridden here;
# confirm that the defaults the loader falls back to match the checkpoint's
# training setup.
improved_args:
  num_channels: 128
  num_res_blocks: 3
  learn_sigma: true
  dropout: 0.3

# Runtime and data-loading settings.
data_root: './data'        # dataset directory (relative path)
batch_size: 128
device: 0                  # presumably a CUDA device index — confirm
strict_determinism: true   # presumably requests deterministic execution — confirm

# Evaluation datasets. Each entry names a dataset and split. `limit`
# presumably caps the number of samples drawn, and `download` presumably
# allows fetching the dataset when it is missing — confirm in the loader.
eval:
  # In-distribution data.
  id_train:
    name: cifar10
    split: train
    limit: 2000
    download: true
  id_test:
    name: cifar10
    split: test
    limit: 1000
    download: true
  # Out-of-distribution sets. Uncomment entries to evaluate them as well.
  # NOTE(review): check that the checkpoint in model_path was not trained on
  # any dataset listed here.
  ood:
    # - name: svhn
    #   split: test
    #   limit: 1000
    #   download: true
    - name: celeba
      split: test
      limit: 1000
      download: true
    # - name: cifar100
    #   split: test
    #   limit: 1000
    #   download: true

# GEPC scoring configuration.
gepc:
  seed: 1337
  verbose: true
  amp: 'fp32'  # precision mode, kept as a string

  # Features to compute, and the one reported as the default metric.
  # Commented-out entries are alternative features that can be enabled.
  features:
    - 'gepc_s'
    # - 'gepc_s_cos'
    # - 'gepc_s_pair'
  metric_default: 'gepc_s'

  # Density model settings.
  # NOTE(review): a KDE bandwidth of 0.0 is degenerate unless the consumer
  # treats 0.0 as "select automatically" — confirm.
  density_mode: 'kde'
  bandwidth: 0.0
  # NOTE(review): fit_batches (128) times the top-level batch_size (128) is
  # far more than eval.id_train.limit (2000) samples; presumably fit_batches
  # only acts as an upper bound — confirm.
  fit_batches: 128
  mc_samples: 1

  # Spatial pooling of per-location scores.
  spatial_pool: 'topk'
  topk_rho: 0.3  # presumably the fraction of locations kept — confirm

  # Shift-group settings.
  group_shifts: true
  shift_px: 1

  # Timestep selection. Timesteps are given as SNR-based levels; keep_k
  # presumably selects how many of them are retained — confirm.
  t_mode: 'snr'
  snr_levels:
    - 0.99997
    - 0.99990
    - 0.99790
    - 0.99690
  keep_k: 2

  # Aggregation across timesteps and across features.
  agg_t: 'wmean'
  weight_t: 'inv_cv'
  agg_feat: 'sum'
  vector_mode: 'none'  # quoted so it is always read as the string "none"

  internal_bs: 64
