# @package _global_
model:
  name: CDVAE
  # NOTE(review): presumably selects whether the WRep encoder below is trained
  # in this run — confirm against the training entry point.
  train_wrep_encoder: false

  # `_target_` is the class path resolved if this node is instantiated via
  # Hydra (hydra.utils.instantiate) — TODO confirm the call site.
  wrep_encoder:
    _target_: src.models.cdvae.WRep_encoder
    activation: leaky_relu
    fc_hidden_units: 14
    z_latent_dim: 12
    context_latent_dim: 12
    br_size: 12
    weighting_method: overlap
    num_layer: 1
    dropout_rate: 0.1
    min_timestep: 10

    balancing: null
    # rnn_type: gru
    tune_hparams: false

    batch_size: 256

    optimizer:
      lr_scheduler: false
      optimizer_cls: adam
      learning_rate: 0.001
      weight_decay: 0.0

  # NOTE(review): the shared architecture keys (activation .. min_timestep)
  # carry the same values as wrep_encoder above; batch_size and optimizer
  # differ. Verify whether the shared keys must be kept in sync.
  cdvae:
    _target_: src.models.cdvae.CDVAE
    activation: leaky_relu
    fc_hidden_units: 14
    z_latent_dim: 12
    context_latent_dim: 12
    br_size: 12
    use_deviance: false
    weighting_method: overlap
    num_layer: 1
    dropout_rate: 0.1
    y_dist_type: continuous
    min_timestep: 10
    percentage_steps_ipm: 0.1

    y_scale_require_grad: true

    gmm_prior:
      n_clusters: 5
      cov_type_p_z_given_c: full
      to_fix_pi_p_c: false
      init_type_p_z_given_c: gmm

    # Presumably loss-term weights — confirm in src.models.cdvae.CDVAE.
    lambda_ipm: 0.1
    lambda_mm: 0.1
    kld_weight: 0.01
    lambda_y: 1.0

    balancing: null
    # rnn_type: gru
    tune_hparams: false

    batch_size: 512

    # Separate optimizer settings for each head.
    optimizer:
      lr_scheduler: false
      non_treatment_head:
        optimizer_cls: rmsprop
        learning_rate: 0.001
        weight_decay: 0.0
        momentum: 0.9

      treatment_head:
        optimizer_cls: adamw
        learning_rate: 0.001
        # momentum left disabled — presumably because AdamW takes no momentum
        # argument; confirm against the optimizer factory.
        # momentum: 0.9
        weight_decay: 0.0

  tune_hparams: false                 # Hyperparameter tuning switch
  tune_range: 50
  # Explicit nulls (a bare `key:` also parses to null, but is ambiguous).
  hparams_grid: null
  resources_per_trial: null

exp:
  # NOTE(review): presumably toggles an exponential moving average of model
  # weights — confirm against the experiment code.
  weights_ema: false

callbacks:
  # Callback options; their meaning is defined by the code that consumes them.
  alpha_raise: {rate: exp}
  # NOTE(review): presumably the number of checks without improvement before
  # training stops — confirm against the early-stopping callback.
  early_stopping: {patience: 60}
