# @package _global_

# usage: +experiment=nih/sb10_gender_infiltration_pt

# Hydra defaults list: selects the config-group options this experiment is composed from.
# `override` replaces an option that is already chosen in the primary config's defaults list.
# `_self_` is listed last, so the values defined in this file take precedence over
# the values contributed by the options composed above.
defaults:
    - override /ds: nih/gender_infiltration
    - override /labeller: gt
    - override /split: nih/sb10_gender_infiltration_art
    # NOTE(review): option name is spelled `imagnet`; it must match the file name in the
    # ae_arch group exactly — confirm this is intentional rather than a typo of `imagenet`.
    - override /ae_arch: vqgan/imagnet
    - override /disc_arch: set
    - override /scorer: none
    - _self_

# Global RNG seed (presumably passed to the experiment's seeding routine — confirm).
seed: 0
dm:
    # Bag composition: samples drawn per group for each bag.
    num_samples_per_group_per_bag: 1
    # Batch sizes (suffixes presumably mean train / test — confirm against the data module).
    batch_size_tr: 12
    batch_size_te: 64

alg:
    # --- optimisation schedule ---
    steps: 10000
    ga_steps: 1
    warmup_steps: 0
    max_grad_norm: 5
    # --- logging / validation cadence ---
    log_freq: ${alg.steps}  # only log on the final iteration
    val_freq: 1.0
    # --- auxiliary objectives: every weight below is 0 in this experiment ---
    prior_loss_w: 0
    pred_y_loss_w: 0
    pred_s_loss_w: 0
    # --- discriminator: zero updates and zero loss weight ---
    # (presumably the discriminator plays no role in this pretraining run — confirm)
    num_disc_updates: 0
    disc_loss_w: 0
    twoway_disc_loss: false
    # --- output: derived from the selected split's artifact name ---
    artifact_name: ${split.artifact_name}_in_vqgan_pt

ae:
    # Latent settings (architecture itself is chosen via the `ae_arch` default).
    zs_dim: 1
    zs_transform: none
    # Reconstruction objective.
    recon_loss: l1
    # Optimiser.
    optimizer_cls: ADAM
    lr: 1.e-4
    weight_decay: 0

eval:
    # Evaluation-stage settings (NOTE(review): the consumer is not visible from this file).
    batch_size: 12
    balanced_sampling: true
    steps: 10000
    num_hidden: 1
    # presumably null means "use the consumer's default width" — confirm.
    hidden_dim: null
    opt:
        lr: 1.e-4
        scheduler_cls: torch.optim.lr_scheduler.CosineAnnealingLR
        scheduler_kwargs:
            # Anneal over the whole evaluation run, assuming the scheduler is stepped
            # once per training step — confirm.
            T_max: ${eval.steps}
            # Written with an explicit '.' (matching `1.e-4` above): the YAML 1.1 float
            # pattern used by plain PyYAML requires a decimal point, so `5e-7` would load
            # there as the string "5e-7" rather than a float.
            eta_min: 5.e-7

wandb:
    # Run group (assuming `wandb` configures the Weights & Biases logger), derived from the
    # selected split's artifact name so pretraining runs for the same split share a group.
    group: ${split.artifact_name}_ae_pretraining
