# @package _global_

# usage: +experiment=nih/sb10_sm

# Hydra defaults list. Each `override /<group>: <option>` entry replaces the
# option already selected for that config group by the primary config.
defaults:
    - override /ds: nih/gender_infiltration
    - override /labeller: gt
    - override /split: nih/sb10_gender_infiltration_art
    # Alternative `ae_arch` option (disabled):
    # - override /ae_arch: resnet
    - override /ae_arch: artifact/nih/gender_infiltration_sb10_0_vqgan
    - override /disc_arch: fcn
    # `_self_` is listed last so that the values set in this file are merged
    # after the group options above and take precedence over them.
    - _self_

# Overrides for the `dm` config node (presumably the data module — confirm
# against the primary config).
dm:
    num_samples_per_group_per_bag: 1
    # Separate batch sizes; `tr`/`te` presumably mean train/test.
    batch_size_tr: 4
    batch_size_te: 64

# Overrides for the `alg` config node.
alg:
    # Step counts and logging cadence. `steps` is also interpolated into
    # `eval.opt.scheduler_kwargs.T_max` further down in this file.
    steps: 20000
    warmup_steps: 0
    log_freq: 1000

    # Update and gradient-norm settings.
    num_disc_updates: 5
    max_grad_norm: 5

    # Loss weights. The ones set to 0 presumably disable the corresponding
    # loss term — confirm against the code that consumes them.
    disc_loss_w: 1.0
    prior_loss_w: 0
    pred_y_loss_w: 0
    pred_s_loss_w: 0

# Overrides for the `ae` config node.
ae:
    recon_loss: l1
    zs_dim: 1
    # Plain `none` is the string "none", not a YAML null.
    zs_transform: none

    # Optimiser settings. The commented-out `lr` is an alternative value kept
    # for reference.
    optimizer_cls: ADAM
    # lr: 1.e-4
    lr: 0.0000018348295456200725
    weight_decay: 0

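# Disabled `ae_arch` settings, kept for reference. They presumably belong to
# the `resnet` option that is commented out in the defaults list — confirm
# before re-enabling.
# ae_arch:
#     version: RN18
#     latent_dim: 128
#     first_conv: false
#     maxpool1: false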

# Overrides for the `disc` config node.
disc:
    optimizer_cls: ADAM
    # The commented-out `lr` is an alternative value kept for reference.
    # lr: 4.e-4
    lr: 0.000027080358427157673

# Field overrides for the `disc_arch` option selected in the defaults list
# (`fcn`).
disc_arch:
    input_norm: true
    num_hidden: 2
    # Explicit YAML null; how a null `hidden_dim` is interpreted is decided by
    # the consuming code — confirm there.
    hidden_dim: null
    final_bias: true

# Overrides for the `eval` config node.
eval:
    batch_size: 12
    balanced_sampling: true
    steps: 10000
    num_hidden: 1
    opt:
        lr: 1.e-4
        scheduler_cls: torch.optim.lr_scheduler.CosineAnnealingLR
        scheduler_kwargs:
            # NOTE(review): T_max follows `alg.steps` (20000), not
            # `eval.steps` (10000). If the scheduler is stepped once per eval
            # step, the cosine schedule would only cover half of its period —
            # confirm this is intended.
            T_max: ${alg.steps}
            # Written with an explicit decimal point, like the other floats in
            # this file: YAML 1.1 loaders (e.g. plain PyYAML) read `5e-7` as a
            # string, whereas `5.e-7` is a float under every loader.
            eta_min: 5.e-7

# Overrides for the `wandb` config node: sets the `group` value for runs
# launched with this experiment config.
wandb:
    group: nih_sb10_mimin
