# @package _global_
model:
  # NOTE(review): "Pretained" looks like a typo of "Pretrained"; left as-is
  # because this name may be used for run/checkpoint naming -- confirm.
  name: Causal_CPCPretained
  activation: selu
  # Whether to pretrain the rep_encoder sub-model configured below.
  pretrain_rep_encoder: true

  rep_encoder:
    _target_: src.models.causal_cpc_ed.Causal_CPC_Encoder
    downsampling_factor: 1
    subsample_win_ratio: 0.1
    # Dropout of the recurrent hidden layers + output layers
    # (the recurrent cell type is selected by rnn_type below).
    dropout_rate: 0.1
    num_layer: 1
    br_size: 14
    # Must satisfy fc_hidden_units <= br_size; used when building the
    # treatment and outcome heads.
    fc_hidden_units: 14
    # Hidden dim of the local features (genc).
    genc_hidden: 12
    # Dim of the features summary (context); tied to br_size.
    context_latent_dim: ${model.rep_encoder.br_size}

    alpha_recons: 1
    alpha_infonce: 1
    alpha_mse: 10
    balancing: mutual_info
    label_smoothing: 0
    use_spectral_norm: true
    use_instance_noise: false
    activation: selu
    cpc_lb: infonce
    infomax_lb: infonce
    rnn_type: gru
    weighting: 'uniform'

    static_features_in_CPC: false
    dim_random_vitals: 0

    use_causalconv: false
    input_channels: 1
    hidden_channels: 16
    kernel_size: 4
    dilation: 1

    batch_size: 256
    optimizer:
      lr_scheduler: false
      non_treatment_head:
        optimizer_cls: adamw
        learning_rate: 0.001
        weight_decay: 0.0
        # NOTE(review): AdamW takes no momentum argument; presumably this is
        # only read when optimizer_cls is sgd -- confirm before relying on it.
        momentum: 0.0

      treatment_head:
        optimizer_cls: sgd
        learning_rate: 0.001
        momentum: 0.9
        weight_decay: 0.0

    use_attention: false

    # Hparam tuning
    tune_hparams: false
    tune_range: 50
    # Explicitly unset; presumably only consulted when tune_hparams is true.
    hparams_grid: null
    resources_per_trial: null

  train_head: true
  est_head:
    _target_: src.models.causal_cpc_ed.RNNEstHead
    # "rnn_hidden_units" in the original terminology; should equal
    # rep_encoder.br_size.
    seq_hidden_units: ${model.rep_encoder.br_size}
    # Currently tied to rep_encoder.br_size, although a value smaller than
    # the encoder's is preferable.
    br_size: ${model.rep_encoder.br_size}
    # Dropout of the recurrent hidden layers + output layers
    # (the recurrent cell type is selected by rnn_type below).
    dropout_rate: 0.1
    num_layer: 1
    y_dist_type: "continuous"
    teacher_forcing: false
    treat_hidden_dim: 6

    finetune_rep_encoder: true
    retrain_rep_encoder: false
    fc_hidden_units: 14
    rnn_type: gru

    batch_size: 128
    likelihood_training: true
    historical_avg: false

    random_indices: true
    percentage_to_keep: 0.1

    balancing: mutual_info
    label_smoothing: 0
    use_spectral_norm: true
    use_instance_noise: false
    activation: selu
    alpha_recons: 0
    alpha_infonce: 0
    # Not used when likelihood_training is true.
    alpha_mse: 10

    step_mse_loss_weights_type: avg

    optimizer:
      lr_scheduler: false
      non_treatment_head:
        optimizer_cls: adamw
        learning_rate: 0.005
        weight_decay: 0.0
        # Separate learning rate, presumably applied to the rep_encoder
        # parameters while fine-tuning -- confirm against the model code.
        rep_encoder:
          learning_rate: 0.0005

      treatment_head:
        optimizer_cls: sgd
        learning_rate: 0.01
        momentum: 0.9
        weight_decay: 0.0


  train_context_decoder: false
  context_decoder:
    activation: selu
    batch_size: 32
    # Tied to the encoder's context_latent_dim.
    enc_hidden_dim: ${model.rep_encoder.context_latent_dim}

    optimizer:
      lr_scheduler: false
      context_decoder:
        optimizer_cls: adamw
        learning_rate: 0.0005
        weight_decay: 0.0

    # Hparam tuning
    tune_hparams: false
    tune_range: 30
    # Explicitly unset; presumably only consulted when tune_hparams is true.
    hparams_grid: null
    resources_per_trial: null


exp:
  eval_only: true
  weights_ema: false
  # Mirrors the rep_encoder balancing setting.
  balancing: ${model.rep_encoder.balancing}
  alpha: 1
  update_alpha: false
