# @package _global_

# Hydra defaults list: picks one option for each config group named below.
# Each `override` entry replaces the option chosen for that group by another
# defaults list (the config that makes the initial choice is not in this file).
defaults:
    - override /ds: celeba
    - override /split: celeba/random/no_nonblond_females
    - override /labeller: gt
    - override /ae_arch: resnet/rn50_256_pre
    # `_self_` is listed last, so the keys set in this file are merged after
    # the group options above and take precedence on conflicts.
    - _self_

# Settings under `ae` (presumably the autoencoder — confirm against the schema
# that consumes this config). `ae.lr` is also referenced by `alg.pred.lr`
# through interpolation, so changing it here changes both.
ae:
    lr: 1.e-5
    zs_dim: 6
    # Plain scalar `none` is the string "none", not a YAML null (compare
    # `prior_loss_w: null` under `alg`).
    zs_transform: none

# Settings under `alg` (presumably the training algorithm — confirm against
# the schema that consumes this config).
alg:
    use_amp: true  # presumably automatic mixed precision — TODO confirm
    pred:
        # Interpolation: resolves to `ae.lr` (1.e-5 as set in this file).
        lr: ${ ae.lr }
    # Interpolation: resolves to `alg.steps`, which is not set in this file
    # and therefore has to be supplied by another config.
    log_freq: ${ alg.steps }
    val_freq: 200
    num_disc_updates: 5
    # Loss weights. The commented-out lines are alternative values kept for
    # reference.
    # enc_loss_w: 0.0001
    enc_loss_w: 1
    disc_loss_w: 0.03
    # prior_loss_w: 0.01
    prior_loss_w: null
    pred_y_loss_w: 1
    # Weight of 0 — presumably disables this loss term; verify in the consumer.
    pred_s_loss_w: 0
    pred_y:
        num_hidden: 1  # for decoding the pre-trained RN50 output
        dropout_prob: 0.1
    s_pred_with_bias: false
    s_as_zs: false

# Settings under `disc` (presumably the discriminator — confirm). Its learning
# rate is set independently of `ae.lr` and is 10x larger in this file.
disc:
    lr: 1.e-4

# disc_arch:
#     dropout_prob: 0.1

# Settings under `dm` (presumably the data module — confirm against the schema
# that consumes this config).
dm:
    stratified_sampler: exact
    num_workers: 4
    # Separate batch sizes for the `_tr` and `_te` suffixes (presumably train
    # and test — confirm).
    batch_size_tr: 10
    batch_size_te: 20

# Settings under `eval` (presumably the downstream evaluation stage — confirm
# against the schema that consumes this config).
eval:
    batch_size: 10
    balanced_sampling: true
    hidden_dim: null
    num_hidden: 1
    steps: 10000
    opt:
        lr: 1.e-4
        scheduler_cls: torch.optim.lr_scheduler.CosineAnnealingLR
        scheduler_kwargs:
            # Interpolation: follows `eval.steps` (10000 in this file), so the
            # two stay in sync when `steps` is overridden.
            T_max: ${ eval.steps }
            # Written with a decimal point like every other float in this
            # file: a dot-less `5e-7` is not a float under YAML 1.1 resolvers
            # (e.g. stock PyYAML) and would be loaded as a string.
            eta_min: 5.e-7
