# @package _global_

# usage: +experiment=celeba/sm/pt

# Hydra defaults list: each entry swaps the option selected for an existing
# config group. The options named here live in other config files.
defaults:
    - override /ds: celeba/gender_smiling
    - override /labeller: gt
    - override /split: celeba/artifact/base
    # NOTE(review): `imagnet` looks like a misspelling of `imagenet` — confirm
    # against the actual file name under the ae_arch group before renaming.
    - override /ae_arch: vqgan/imagnet
    - override /disc_arch: set
    - override /scorer: none

# Fixed seed for this experiment.
seed: 0
# Artifact name is the selected split's artifact name plus a fixed suffix.
artifact_name: '${split.artifact_name}_in_vqgan_pt'
# `dm` settings (presumably the data module — confirm against the schema).
dm:
    num_samples_per_group_per_bag: 1
    # Separate batch sizes; `tr`/`te` presumably mean train/test — confirm.
    batch_size_te: 64
    batch_size_tr: 2

alg:
    # Step counts and frequencies.
    steps: 50000
    warmup_steps: 0
    ga_steps: 1
    val_freq: 1.0
    log_freq: ${alg.steps}  # only log on the final iteration

    # Gradient-norm limit.
    max_grad_norm: 5

    # Discriminator-related settings: zero updates and zero loss weight.
    num_disc_updates: 0
    disc_loss_w: 0
    twoway_disc_loss: false

    # The remaining loss weights are all zero in this experiment as well.
    prior_loss_w: 0
    pred_y_loss_w: 0
    pred_s_loss_w: 0

ae:
    # Optimiser configuration.
    optimizer_cls: ADAM
    lr: 1.e-4
    weight_decay: 0

    # Reconstruction loss selector.
    recon_loss: l1

    # `zs` settings. Note that `none` is the plain string "none", not YAML null.
    zs_dim: 1
    zs_transform: none

# Evaluation settings, including the optimiser and LR schedule used there.
eval:
    batch_size: 12
    balanced_sampling: true
    steps: 10000
    num_hidden: 1
    hidden_dim: null
    opt:
        lr: 1.e-4
        scheduler_cls: torch.optim.lr_scheduler.CosineAnnealingLR
        scheduler_kwargs:
            # T_max tracks eval.steps, so changing the step count above also
            # changes the annealing horizon. Written without padding inside
            # the braces, matching the other interpolations in this file.
            T_max: ${eval.steps}
            # Keep the explicit decimal point (as in the `lr` values): YAML 1.1
            # loaders such as stock PyYAML resolve a bare `5e-7` as a string,
            # not a float.
            eta_min: 5.e-7

# W&B group name: the selected split's artifact name plus a fixed suffix.
wandb:
    group: '${split.artifact_name}_ae_pretraining'
