# @package _global_

# usage: +experiment=nih/sm/sb10_sm

# Hydra defaults list for this experiment. Each `override /group: option`
# entry replaces the option already selected for that config group.
defaults:
    - override /ds: nih/gender_infiltration
    - override /labeller: gt
    - override /split: nih/sb10_gender_infiltration_art
    # Alternative ae_arch options, currently disabled:
    # - override /ae_arch: resnet
    # - override /ae_arch: artifact
    - override /ae_arch: artifact/nih/gender_infiltration_sb10_0_vqgan
    - override /disc_arch: set
    - override /scorer: nih/sm
    # `_self_` comes last so that the values set in this file are merged after
    # (and therefore take precedence over) the group options selected above.
    - _self_

# Seed value used for this experiment.
seed: 0
# Settings for the `dm` node.
# NOTE(review): the `_tr` / `_te` suffixes presumably denote the training and
# test batch sizes -- confirm against the dm config schema.
dm:
    batch_size_tr: 4
    batch_size_te: 64
    num_samples_per_group_per_bag: 1

# Settings for the `alg` node.
# NOTE(review): field semantics are defined by the alg config schema, which is
# not visible from this file; hedged notes below should be confirmed there.
alg:
    steps: 10000
    ga_steps: 2
    num_disc_updates: 5
    twoway_disc_loss: true
    # The `*_loss_w` weights set to 0 in this block presumably switch the
    # corresponding loss terms off -- confirm against the implementation.
    prior_loss_w: 0
    # Interpolation: resolves to the value of `alg.steps` above.
    log_freq: ${alg.steps} # only log on the final iteration
    # NOTE(review): fractional value; presumably interpreted as a fraction of
    # `steps` rather than an absolute step count -- confirm.
    val_freq: 0.2
    pred_y_loss_w: 0
    pred_s_loss_w: 0
    warmup_steps: 0
    max_grad_norm: 5

# Settings for the `ae` node (presumably the autoencoder whose architecture is
# selected through the `ae_arch` config group -- confirm).
ae:
    recon_loss: l1
    # Written with an explicit decimal point (`1.e-5`) so that YAML 1.1
    # loaders also resolve the value as a float rather than a string.
    lr: 1.e-5
    zs_transform: none
    optimizer_cls: ADAM
    weight_decay: 0
    zs_dim: 1

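# NOTE: disabled inline `ae_arch` settings; the active architecture is the one
# selected by the `override /ae_arch` entry in the defaults list.
# ae_arch:
#     # version: RN18
#     # latent_dim: 128
#     # first_conv: false
#     # maxpool1: false
#     artifact_name: nih_gender_infiltration_sb10_0_ae
#     version: 0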

# Settings for the `disc` node (presumably the discriminator whose
# architecture is selected through the `disc_arch` config group -- confirm).
disc:
    # Explicit decimal point keeps the value a float under YAML 1.1 loaders.
    lr: 8.e-5
    criterion: LOGISTIC_NS
    optimizer_cls: ADAM

# Field overrides applied on top of the `disc_arch: set` option chosen in the
# defaults list.
# NOTE(review): the `null` entries presumably defer to a default chosen by the
# implementation -- confirm against the disc_arch schema.
disc_arch:
    hidden_dim_pre: null
    hidden_dim_post: null
    num_hidden_pre: 1
    # Interpolation: always takes the same value as `num_hidden_pre`.
    num_hidden_post: ${disc_arch.num_hidden_pre}
    agg_input_dim: null
    input_norm: true
    final_bias: true
    agg_fn: GATED

# Settings for the `eval` node, including its optimiser (`opt`) and learning-
# rate scheduler configuration.
eval:
    batch_size: 12
    balanced_sampling: true
    steps: 10000
    num_hidden: 1
    hidden_dim: null
    opt:
        lr: 1.e-4
        scheduler_cls: torch.optim.lr_scheduler.CosineAnnealingLR
        scheduler_kwargs:
            # Interpolation: resolves to `eval.steps` above. Written without
            # padding whitespace, matching the other interpolations in this
            # file.
            T_max: ${eval.steps}
            # Explicit decimal point, matching the other floats in this file:
            # YAML 1.1 resolvers (e.g. plain PyYAML) read `5e-7` as a string.
            eta_min: 5.e-7

# Settings for the `wandb` node.
# NOTE(review): `group` is presumably forwarded to Weights & Biases so that
# runs of this experiment are grouped together -- confirm against the logger
# setup.
wandb:
    group: nih_sb10_sm
