# @package _global_

# usage: +experiment=nih/sm/sb10_sm_winter_pine

# Hydra defaults list. Entries are composed in order, top to bottom.
# `override` swaps a config-group choice already made by the primary config.
# `_self_` is listed last, so the values set in this file are merged after
# (and therefore take precedence over) the group configs selected here.
defaults:
    - override /ds: nih/gender_infiltration
    - override /labeller: gt
    - override /split: nih/sb10_gender_infiltration_art
    # Alternative autoencoder architecture, kept disabled for reference:
    # - override /ae_arch: vqgan/imagnet
    - override /ae_arch: artifact/nih/gender_infiltration_sb10_0_vqgan
    - override /disc_arch: set
    - override /scorer: nih/sm
    - _self_

seed: 0

dm:
    # Bag composition (presumably samples drawn per group for each bag).
    num_samples_per_group_per_bag: 1
    # Batch sizes (suffixes presumably mean train / test).
    batch_size_tr: 4
    batch_size_te: 64

alg:
    # Schedule and update settings.
    steps: 20000
    warmup_steps: 0
    ga_steps: 4
    num_disc_updates: 5
    max_grad_norm: 5

    # Logging / validation cadence.
    log_freq: ${alg.steps}  # only log on the final iteration
    val_freq: 0.2

    # Loss weights and discriminator-loss options.
    disc_loss_w: 0.05
    twoway_disc_loss: false
    prior_loss_w: 0
    pred_y_loss_w: 0
    pred_s_loss_w: 0

ae:
    # Optimizer settings.
    optimizer_cls: ADAM
    lr: 0.0000018348295456200725
    weight_decay: 0

    # Objective and latent settings.
    recon_loss: l1
    zs_dim: 1
    # Note: `none` is the plain string "none", not YAML null.
    zs_transform: none

disc:
    # Discriminator training objective and optimizer.
    criterion: LOGISTIC_NS
    optimizer_cls: ADAM
    lr: 0.000027080358427157673

disc_arch:
    # Pre-/post-aggregation layer settings (names suggest layers on either
    # side of the set aggregation -- confirm against the `set` architecture).
    num_hidden_pre: 1
    num_hidden_post: ${disc_arch.num_hidden_pre}  # mirrors num_hidden_pre
    hidden_dim_pre: null
    hidden_dim_post: null
    agg_input_dim: null
    input_norm: true
    final_bias: true

    # Attention-related settings.
    num_blocks: 0
    num_heads: 1
    head_dim: null
    mean_query: true

eval:
    batch_size: 12
    balanced_sampling: true
    steps: 10000
    num_hidden: 1
    hidden_dim: null
    opt:
        lr: 1.0e-4
        scheduler_cls: torch.optim.lr_scheduler.CosineAnnealingLR
        scheduler_kwargs:
            # Anneal over the full evaluation run.
            T_max: ${eval.steps}
            # Written with a decimal point so YAML 1.1 loaders (e.g. plain
            # PyYAML) resolve it as a float rather than the string "5e-7".
            eta_min: 5.0e-7

wandb:
    # Weights & Biases run group for this experiment.
    group: 'nih_sb10_sm_winter_pine'
